In [4]:
import torch
from torch import nn
from typing import Optional, Union
import logging
import math



In [37]:
def broadcast_shaping(x: Optional[torch.Tensor], ferq: Optional[torch.Tensor]):
    ndim = x.ndim
    logger.debug(f"x Shape at broadcast_shaping {x.shape}")
    logger.debug(f"frq Shape at broadcast_shaping {frq.shape}")
    assert 0 <= 1 < ndim
    assert ferq.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return ferq.view(*shape)


def rotary_embedding(xq: Optional[torch.Tensor], xk: Optional[torch.Tensor], ferq: Optional[torch.Tensor]):
    xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().view(*xk.shape[:-1], -1, 2))
    ferq = broadcast_shaping(xq_, ferq)
    xq_out = torch.view_as_real(xq_ * ferq).flatten(3)
    xk_out = torch.view_as_real(xk_ * ferq).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)



In [22]:
class LLamaConfig:
    eps: Optional[float] = 1e-6
    hidden_size: Optional[int] = 680
    n_heads: Optional[int] = 12
    n_layers: Optional[int] = 8
    vocab_size: Optional[int] = 200
    max_sentence_length: Optional[int] = 512
    max_batch_size: Optional[int] = 32
    device: Union[torch.device, str] = 'cuda' if torch.cuda.is_available() else 'cpu'



In [23]:
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.DEBUG)

In [24]:
vocab_size: int = 200

batch: int = 12

In [25]:
x = torch.randint(0, vocab_size - 1, size=(batch, 80))

In [26]:
class LLamaAttention(nn.Module):
    def __init__(self, config: LLamaConfig):
        super(LLamaAttention, self).__init__()
        self.local_rank = config.n_heads // 1
        self.head_dim = config.hidden_size // config.n_heads
        self.wq = nn.Linear(config.hidden_size, config.n_heads * self.head_dim, bias=False,
                            )
        self.wk = nn.Linear(config.hidden_size, config.n_heads * self.head_dim, bias=False,
                            )
        self.wv = nn.Linear(config.hidden_size, config.n_heads * self.head_dim, bias=False,
                            )
        self.wo = nn.Linear(config.n_heads * self.head_dim, config.hidden_size, bias=False,
                            )
        self.cash_k = torch.zeros(
            (config.max_batch_size, config.max_sentence_length, self.local_rank, self.head_dim)).to(config.device)
        self.cash_v = torch.zeros(
            (config.max_batch_size, config.max_sentence_length, self.local_rank, self.head_dim)).to(config.device)

    def forward(self, x: Optional[torch.Tensor], pos_start: int, frq: Optional[torch.Tensor],
                mask: Optional[torch.Tensor] = None):
        batch_, seq_len_, _ = x.shape
        xq = self.wq(x).view(batch_, seq_len_, self.local_rank, self.head_dim)
        xv = self.wv(x).view(batch_, seq_len_, self.local_rank, self.head_dim)
        xk = self.wk(x).view(batch_, seq_len_, self.local_rank, self.head_dim)
        logger.debug(f'xq : {xq.shape} \nxv : {xv.shape}\nxk : {xk.shape}')
        # using rotary embedding for key and query
        xq, xk = apply_rotary_emb(xq=xq, xk=xk, freqs_cis=frq)
        # we need to cash key and values
        self.cash_v = self.cash_v.to(xv)
        self.cash_k = self.cash_k.to(xk)
        self.cash_k[:batch_, pos_start:pos_start + seq_len_] = xk
        self.cash_v[:batch_, pos_start:pos_start + seq_len_] = xq
        key = self.cash_k[:batch_, pos_start:pos_start + seq_len_]
        value = self.cash_v[:batch_, pos_start:pos_start + seq_len_]
        # [batch, seq_len , num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
        key = key.permute(0, 2, 1, 3)
        # [batch, seq_len , num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
        value = value.permute(0, 2, 1, 3)
        # [batch, seq_len , num_heads, head_dim] -> [batch, num_heads, seq_len, head_dim]
        query = xq.permute(0, 2, 1, 3)
        logger.debug(f'key : {key.shape} \nvalue : {value.shape}\nquery : {query.shape}')
        # key : [batch, num_heads, seq_len, head_dim] -> [batch, seq_len , num_heads, head_dim]
        # score : [batch, num_heads, seq_len , head_dim]
        attention = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim)
        logger.debug(f'score : {attention.shape}')
        if mask is not None:
            attention += mask
        attention = nn.functional.softmax(attention, dim=-1)
        # after matmul [batch, num_heads, seq_len , head_dim]
        comb = torch.matmul(attention, value).permute(0, 2, 1, 3).contiguous().view(batch_, seq_len_, -1)
        return self.wo(comb)

In [27]:
config = LLamaConfig()
attention = LLamaAttention(config)
embedding = nn.Embedding(config.vocab_size, config.hidden_size)

In [28]:
attention.head_dim, attention.local_rank

(56, 12)

In [29]:
def precompute_frq_cis(dim: int, end: int, theta: float = 10000.0):
    frq = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=frq.device)  # type: ignore
    frq = torch.outer(t, frq).float()  # type: ignore
    frq_cis = torch.polar(torch.ones_like(frq), frq)  # complex64
    return frq_cis

In [30]:
frq = precompute_frq_cis(attention.head_dim, config.max_sentence_length * 2)

In [31]:
x.shape

torch.Size([12, 80])

In [32]:
embedded = embedding(x)

In [33]:
embedded.shape

torch.Size([12, 80, 680])

In [34]:
start_pos = 0
mask = None


In [35]:
mask = torch.full((1, 1, config.max_sentence_length, config.max_sentence_length), float('-inf'), device=x.device)

mask = torch.triu(mask, diagonal=start_pos + 1).type_as(embedded)
mask

tensor([[[[0., -inf, -inf,  ..., -inf, -inf, -inf],
          [0., 0., -inf,  ..., -inf, -inf, -inf],
          [0., 0., 0.,  ..., -inf, -inf, -inf],
          ...,
          [0., 0., 0.,  ..., 0., -inf, -inf],
          [0., 0., 0.,  ..., 0., 0., -inf],
          [0., 0., 0.,  ..., 0., 0., 0.]]]])

In [36]:
at = attention(embedded, start_pos, frq, mask)

DEBUG:__main__:xq : torch.Size([12, 80, 12, 56]) 
xv : torch.Size([12, 80, 12, 56])
xk : torch.Size([12, 80, 12, 56])
DEBUG:__main__:x Shape at broadcast_shaping torch.Size([12, 80, 12, 28])
DEBUG:__main__:frq Shape at broadcast_shaping torch.Size([1024, 28])


AssertionError: 

In [1]:
# if you follow your anger you power will answer .