In [1]:
from dataclasses import dataclass
from typing import Optional
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [3]:
@dataclass
class ModelArgs:
    """Hyperparameters shared by the model components defined in this notebook."""

    # Width of the token embeddings and of every block's input/output.
    dim: int = 4096
    # Number of stacked EncoderBlock layers.
    n_layers: int = 32
    # Number of query heads; head_dim is derived as dim // n_heads.
    n_heads: int = 32
    # Number of key/value heads; None means "same as n_heads".
    n_kv_heads: Optional[int] = None
    vocab_size: int = -1 # Later set in the build method
    # The feed-forward hidden size is rounded up to a multiple of this value.
    multiple_of: int = 256
    # Optional extra scaling applied to the feed-forward hidden size.
    ffn_dim_multiplier: Optional[float] = None
    # Epsilon handed to the RMSNorm layers.
    norm_eps: float = 1e-5

    # Needed for KV cache
    max_batch_size: int = 32
    max_seq_len: int = 2048

    # Device for the precomputed rotary frequencies; None leaves the choice
    # to PyTorch's default device.
    device: Optional[str] = None

In [5]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Computes ``weights * x / sqrt(mean(x**2) + eps)``.

    Args:
        dim: Size of the last dimension of the inputs (length of the gain vector).
        eps: Small constant added to the mean square before the root is taken.
    """

    def __init__(self,dim: int, eps: float=1e-6):
        super().__init__()
        self.eps = eps
        # Per-feature gain, initialised to one (i.e. a pure normalisation).
        self.weights = nn.Parameter(torch.ones(dim))

    def _norm(self,x):
        """Scale ``x`` so that its mean square along the last axis is ~1."""
        mean_square = x.pow(2).mean(dim=-1,keepdim=True)
        # eps goes *inside* the root: taking sqrt(0) first and adding eps
        # afterwards yields a NaN gradient for an all-zero row, because the
        # derivative of sqrt at 0 is infinite.
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self,x):
        """Return the normalised input multiplied by the learnable gain."""
        return self.weights * self._norm(x)

In [6]:
# The gain vector has to match the last dimension of the tensor that is
# normalised in the following cells (18); a size of 4 cannot broadcast against it.
norm = RMSNorm(18)

In [23]:
# Random real-valued sample input; the last dimension is even so that it can
# be split into pairs in the complex-view experiments further down.
x =torch.rand(5,18)

In [12]:
# Sanity check: mean of the squared normalised values along the last axis.
(norm(x) ** 2).mean(dim=-1)

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<MeanBackward1>)

In [13]:
def precompute_theta_pos_frequencies(head_dim:int, seq_len: int, device: str, theta: float = 10000.0):
    """Build the table of complex rotation factors used by the rotary embeddings.

    Args:
        head_dim: Size of one attention head; must be even.
        seq_len: Number of positions (rows) to precompute.
        device: Device on which the table is created.
        theta: Base of the geometric frequency progression.

    Returns:
        Complex tensor of shape ``(seq_len, head_dim // 2)`` whose entry
        ``[m, i]`` is ``exp(1j * m * theta ** (-2 * i / head_dim))``.
    """
    assert head_dim % 2 == 0, "Dimension must be divisible by 2"

    # Exponents 0, 2, ..., head_dim - 2, normalised by head_dim.
    exponents = torch.arange(0,head_dim,2).float() / head_dim
    base_powers = (theta ** exponents).to(device)
    inv_freq = 1.0 / base_powers

    # One angle per (position, frequency) pair.
    positions = torch.arange(seq_len,device=device)
    angles = torch.outer(positions,inv_freq).float()

    # Unit-magnitude complex numbers: cos(angle) + 1j * sin(angle).
    unit_magnitude = torch.ones_like(angles)
    return torch.polar(unit_magnitude,angles)

In [21]:
# Bound to its own name so the real-valued `x` defined earlier is not
# overwritten; the reshape experiments in the following cells still use that `x`.
freqs_demo = precompute_theta_pos_frequencies(16,10,'cpu')

In [25]:
# Split the last dimension into consecutive pairs: (..., d) -> (..., d / 2, 2).
x.float().reshape(*x.shape[:-1],-1,2).shape

torch.Size([5, 9, 2])

In [26]:
# Reinterpret every (a, b) pair of the reshaped tensor as the complex number a + bj.
y = torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2))

In [29]:
# Display the complex view.
y

tensor([[0.9634+0.9510j, 0.2736+0.3525j, 0.5917+0.7835j, 0.7921+0.5045j,
         0.2107+0.0653j, 0.3504+0.5839j, 0.6104+0.6343j, 0.4109+0.4540j,
         0.0039+0.9517j],
        [0.0099+0.4052j, 0.7689+0.1905j, 0.2378+0.1365j, 0.7101+0.7956j,
         0.2141+0.5726j, 0.8201+0.8848j, 0.7426+0.0547j, 0.0101+0.5260j,
         0.1623+0.8415j],
        [0.9587+0.6623j, 0.5116+0.6582j, 0.9878+0.6390j, 0.9034+0.7015j,
         0.5972+0.0058j, 0.8090+0.8477j, 0.8951+0.0786j, 0.2375+0.1210j,
         0.0922+0.7294j],
        [0.1064+0.1917j, 0.9676+0.0022j, 0.1191+0.4788j, 0.5936+0.5327j,
         0.0418+0.2735j, 0.0733+0.2542j, 0.3116+0.7680j, 0.4421+0.3954j,
         0.7363+0.2031j],
        [0.0786+0.4917j, 0.1203+0.1728j, 0.9636+0.5069j, 0.3092+0.7395j,
         0.8178+0.0180j, 0.9085+0.4730j, 0.9913+0.2056j, 0.7448+0.9696j,
         0.6131+0.1815j]])

In [32]:
# view_as_real undoes the complex view: each complex entry becomes a trailing (real, imag) pair.
torch.view_as_real(y).shape

torch.Size([5, 9, 2])

In [33]:
def apply_rotary_embeddings(x: torch.Tensor, freqs_complex: torch.Tensor, device: str):
    """Rotate consecutive feature pairs of ``x`` by precomputed complex factors.

    Args:
        x: Tensor whose last dimension is even. The attention layer in this
            notebook passes ``(batch, seq_len, n_heads, head_dim)``.
        freqs_complex: Complex rotation factors; after the two ``unsqueeze``
            calls below they must broadcast against the complex view of ``x``
            (the attention layer passes ``(seq_len, head_dim // 2)``).
        device: Device the result is moved to.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    # Pair up the last dimension and view every pair as one complex number.
    x_complex = torch.view_as_complex(x.float().reshape(*x.shape[:-1],-1,2))
    # Insert singleton axes at positions 0 and 2 so the factors broadcast
    # over the batch and head dimensions.
    freqs_complex = freqs_complex.unsqueeze(0).unsqueeze(2)
    # Complex multiplication by a unit-magnitude number is a 2-D rotation.
    x_rotated = x_complex * freqs_complex
    x_out = torch.view_as_real(x_rotated)
    # Flatten the (pair, 2) axes back to the input layout. The original code
    # referenced an undefined name `xshape` here and raised NameError.
    x_out = x_out.reshape(*x.shape)
    return x_out.type_as(x).to(device)

In [40]:
def repeat_kv(x: torch.Tensor,n_rep: int) -> torch.Tensor:
    """Duplicate every key/value head ``n_rep`` times along the head axis.

    Args:
        x: Tensor of shape ``(batch, seq_len, n_kv_heads, head_dim)``.
        n_rep: Number of copies of each head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        ``(batch, seq_len, n_kv_heads * n_rep, head_dim)`` in which the copies
        of one head sit next to each other.
    """
    bsz, slen, kv_heads, hdim = x.shape
    if n_rep != 1:
        # Add a repeat axis right after the head axis, broadcast it, then
        # merge it into the head axis.
        widened = x.unsqueeze(3).expand(bsz, slen, kv_heads, n_rep, hdim)
        return widened.reshape(bsz, slen, kv_heads * n_rep, hdim)
    return x

In [41]:
class SelfAttention(nn.Module):
    """Multi-head self-attention with shared key/value heads and a KV cache.

    Keys and values of every processed position are written into
    ``cache_k`` / ``cache_v`` so that later calls only have to project the
    new tokens.

    Args:
        args: Model hyperparameters (``dim``, ``n_heads``, ``n_kv_heads``,
            ``max_batch_size`` and ``max_seq_len`` are read here).
    """

    def __init__(self,args:ModelArgs):
        super().__init__()
        # Without an explicit KV head count every query head gets its own KV head.
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_heads_q = args.n_heads
        # How many query heads share one key/value head.
        self.n_rep = self.n_heads_q // self.n_kv_heads
        self.head_dim = args.dim // args.n_heads
        # Fixed: the original used the non-existent `args.head_dim` and
        # `self.n_heads`, and multiplied by `args.n_kv_heads`, which may be None.
        self.wq = nn.Linear(args.dim,self.n_heads_q * self.head_dim,bias=False)
        self.wk = nn.Linear(args.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wv = nn.Linear(args.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim,args.dim,bias=False)
        self.cache_k = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))
        self.cache_v = torch.zeros((args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim))

    def forward(self,x:torch.Tensor,start_pos:int,freqs_complex:torch.Tensor):
        """Attend from the new tokens to every cached position.

        Args:
            x: Input of shape ``(batch, seq_len, dim)``.
            start_pos: Cache index at which the first token of ``x`` is stored.
            freqs_complex: Rotary factors for the positions covered by ``x``.

        Returns:
            Tensor of shape ``(batch, seq_len, dim)``.

        Note:
            No causal mask is applied, so when ``seq_len > 1`` a query also
            attends to later positions of the same call.
        """
        batch_size, seq_len, _ = x.shape
        xq = self.wq(x)
        xk = self.wk(x)
        xv = self.wv(x)
        # Split the projected features into heads.
        xq = xq.view(batch_size,seq_len,self.n_heads_q,self.head_dim)
        xk = xk.view(batch_size,seq_len,self.n_kv_heads,self.head_dim)
        xv = xv.view(batch_size,seq_len,self.n_kv_heads,self.head_dim)
        # Positional information is rotated into queries and keys only.
        xq = apply_rotary_embeddings(xq,freqs_complex,device=x.device)
        xk = apply_rotary_embeddings(xk,freqs_complex,device=x.device)
        # The caches are plain tensors, so they do not follow `module.to(...)`;
        # align them with the activations before writing into them.
        self.cache_k = self.cache_k.to(xq)
        self.cache_v = self.cache_v.to(xq)
        self.cache_k[:batch_size,start_pos:start_pos+seq_len] = xk
        self.cache_v[:batch_size,start_pos:start_pos+seq_len] = xv
        # Everything stored so far, including the tokens just written.
        keys = self.cache_k[:batch_size,:start_pos+seq_len]
        values = self.cache_v[:batch_size,:start_pos+seq_len]
        # Expand the KV heads so that every query head has a partner.
        keys = repeat_kv(keys,self.n_rep)
        values = repeat_kv(values,self.n_rep)
        # (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
        xq = xq.transpose(1,2)
        keys = keys.transpose(1,2)
        values = values.transpose(1,2)
        scores = torch.matmul(xq,keys.transpose(2,3)) / math.sqrt(self.head_dim)
        scores = F.softmax(scores.float(),dim=-1).type_as(xq)
        # Fixed: the weighted sum over the values was missing, which left
        # `output` undefined on the next line.
        output = torch.matmul(scores,values)
        # Merge the heads back into one feature axis.
        output = (output.transpose(1,2).contiguous().view(batch_size,seq_len,-1))
        return self.wo(output)

In [43]:
class Feedward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))``.

    The hidden width starts at two thirds of ``4 * dim``, is optionally scaled
    by ``args.ffn_dim_multiplier`` and is finally rounded up to the next
    multiple of ``args.multiple_of``.
    """

    def __init__(self,args:ModelArgs):
        super().__init__()
        width = int(2 * (4 * args.dim) / 3)
        if args.ffn_dim_multiplier is not None:
            width = int(args.ffn_dim_multiplier * width)
        # Round up to a whole number of `multiple_of`-sized steps.
        step = args.multiple_of
        n_steps = (width + step - 1) // step
        width = step * n_steps

        self.w1 = nn.Linear(args.dim,width,bias=False)
        self.w2 = nn.Linear(width,args.dim,bias=False)
        self.w3 = nn.Linear(args.dim,width,bias=False)

    def forward(self,x:torch.Tensor):
        """Apply the gated projection and map back to the model width."""
        gate = F.silu(self.w1(x))
        gated = gate * self.w3(x)
        return self.w2(gated)

In [44]:
class EncoderBlock(nn.Module):
    """One transformer layer: pre-normalised attention and feed-forward,
    each wrapped in a residual connection.

    Args:
        args: Model hyperparameters; forwarded to the attention and
            feed-forward sub-modules and used to size the two RMSNorm layers.
    """

    def __init__(self,args:ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.attention = SelfAttention(args)
        self.feed_forward = Feedward(args)
        # Normalisation applied to the input of the attention sub-layer.
        self.attention_norm = RMSNorm(args.dim,eps=args.norm_eps)
        # Normalisation applied to the input of the feed-forward sub-layer.
        self.ffn_norm = RMSNorm(args.dim,eps=args.norm_eps)

    def forward(self,x: torch.Tensor,start_pos:int, freqs_complex: torch.Tensor):
        """Run the layer.

        Args:
            x: Layer input.
            start_pos: Cache position of the first token in ``x``; passed on
                to the attention sub-layer.
            freqs_complex: Rotary factors; passed on to the attention sub-layer.

        Returns:
            Layer output, computed as ``h + feed_forward(ffn_norm(h))`` with
            ``h = x + attention(attention_norm(x))``.
        """
        # Fixed: the skip connection around attention was missing, so the
        # block input never reached the output directly.
        h = x + self.attention.forward(
            self.attention_norm(x),start_pos,freqs_complex
        )
        out = h + self.feed_forward.forward(self.ffn_norm(h))
        return out

In [None]:
class Transformer(nn.Module):

    def __init__(self,args:ModelArgs):
        super().__init__()
        assert args.vocab_size != -1, 'Vocab_szie must be set'
        self.args = args
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.tok_embeddings = nn.Embedding(self.vocab_size,args.dim)
        self.layers = nn.ModuleList()
        for layer_id in range(args.n_layers):
            self.layers.append(EncoderBlock(args))
        self.norm = RMSNorm(args.dim,self.vocab_size)
        self.output = nn.Linear(args.dim,self.vocab_size,bias=False)
        self.freqs_complex = precompute_theta_pos_frequencies(self.args.dim // self.args.n_heads,self.args.max_seq_len * 2, device=self.args.device)
    def forward(self,tokens: torch.Tensor,start_pos: int):
        batch_size, seq_len = tokens.shape
        assert seq_len == 1, "only one token at a time"
        h = self.tok_embeddings(tokens)
        freqs_complex = self.freqs_complex[start_pos:start_pos+seq_len]
        for layer in self.la