In [1]:
import math

In [2]:
import struct

In [3]:
import inspect

In [4]:
from dataclasses import dataclass

In [5]:
from typing import Any,Optional,Tuple

In [6]:
import numpy as np

In [7]:
import torch

In [8]:
import torch.nn.functional as F

In [9]:
from torch import nn

In [10]:
@dataclass
class ModelArgs:
    """Hyperparameters that configure the Transformer and its sub-modules."""
    # default hyperparameters for the Llama 7B model
    dim: int = 4096  # model (embedding) width
    n_layers: int = 32  # number of TransformerBlock layers stacked by Transformer
    n_heads: int = 32  # number of query heads; head_dim is dim // n_heads
    n_kv_heads: Optional[int] = None  # key/value heads; None means "use n_heads" (see Attention)
    vocab_size: int = 32000  # rows of the token embedding and width of the output logits
    hidden_dim: Optional[int] = None  # FeedForward hidden width; None lets FeedForward derive it from dim
    multiple_of: int = 256  # MLP hidden layer size will be multiple of
    norm_eps: float = 1e-5  # epsilon passed to every RMSNorm built from these args
    max_seq_len: int = 2048  # length of the precomputed rotary tables and of the fallback causal mask
    dropout: float = 0.0  # dropout probability used by the attention, feed-forward and embedding dropouts

In [11]:
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Each vector along the last axis is divided by ``sqrt(mean(x**2) + eps)`` and
    then multiplied element-wise by ``weights``.

    Args:
        dim: size of the last (feature) dimension; one gain value per feature.
        eps: small constant added to the mean square before the reciprocal
            square root, guarding against division by zero.
    """

    def __init__(self,dim,eps=0.0):
        super().__init__()
        # Gain starts at 1 so the freshly constructed module is a pure normaliser.
        self.weights = nn.Parameter(torch.ones(dim))
        self.eps = eps

    def forward(self,x):
        """Return ``x`` scaled to unit RMS along the last axis, times ``weights``."""
        # eps belongs inside the rsqrt only. The previous version added it a
        # second time to the reciprocal itself, which biased the output scale
        # for any non-zero eps.
        inv_rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return self.weights * (x * inv_rms)

In [12]:
# Build an RMSNorm over a 4-wide feature dimension (eps left at its default of 0.0).
rms = RMSNorm(4)

In [13]:
# Random (5, 4) input: five rows, each normalised independently by RMSNorm below.
x = torch.rand((5,4))

In [14]:
# Apply the normalisation; y has the same shape as x.
y=rms(x)

In [15]:
y

tensor([[1.7215, 0.5718, 0.2738, 0.7965],
        [1.0627, 1.0709, 0.9630, 0.8925],
        [0.6462, 1.3204, 1.3475, 0.1526],
        [0.6397, 0.9529, 1.3306, 0.9552],
        [0.6170, 1.1709, 1.0955, 1.0238]], grad_fn=<MulBackward0>)

In [16]:
# Sanity check: the mean square of every normalised row should come out as 1.
(y ** 2).mean(dim=-1) 

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000], grad_fn=<MeanBackward1>)

In [17]:
def precompute_freqs_cis(dim:int,end:int,theta:float = 10000.0):
    """Build the cosine and sine lookup tables used by the rotary embedding.

    Args:
        dim: per-head dimension; one frequency is produced per pair of features.
        end: number of positions to tabulate.
        theta: base of the geometric frequency progression.

    Returns:
        ``(freqs_cos, freqs_sin)``, each of shape ``(end, dim // 2)``, holding
        ``cos``/``sin`` of ``position * theta ** (-2k / dim)``.
    """
    # Exponents 0, 2/dim, 4/dim, ... — one per feature pair.
    exponents = torch.arange(0,dim,2)[:(dim//2)].float() / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(end)
    # Outer product gives the rotation angle for every (position, frequency) pair.
    angles = torch.outer(positions,inv_freq).float()
    return torch.cos(angles),torch.sin(angles)

In [18]:
# Small example: head dimension 4 over 8 positions; x1 is the cosine table, x2 the sine table.
x1,x2 = precompute_freqs_cis(4,8)

In [19]:
def reshape_for_broadcast(freqs_cis:torch.Tensor,x:torch.Tensor):
    """View ``freqs_cis`` so that it broadcasts against ``x``.

    ``freqs_cis`` must have shape ``(x.shape[1], x.shape[-1])``. The result
    keeps those two sizes at axes 1 and -1 and uses size 1 on every other
    axis of ``x``.

    Raises:
        AssertionError: if ``x`` has fewer than two dimensions or the shapes
            do not line up.
    """
    rank = x.ndim
    assert 0 <= 1 < rank
    assert freqs_cis.shape == (x.shape[1],x.shape[-1])
    # Start from all-ones and fill in only the two axes that carry data.
    target = [1] * rank
    target[1] = x.shape[1]
    target[rank - 1] = x.shape[rank - 1]
    return freqs_cis.view(target)

In [20]:
x1

tensor([[ 1.0000,  1.0000],
        [ 0.5403,  0.9999],
        [-0.4161,  0.9998],
        [-0.9900,  0.9996],
        [-0.6536,  0.9992],
        [ 0.2837,  0.9988],
        [ 0.9602,  0.9982],
        [ 0.7539,  0.9976]])

In [21]:
# The (8, 2) cosine table viewed against a (1, 8, 3, 2) tensor should come back as (1, 8, 1, 2).
reshape_for_broadcast(x1,torch.ones([1,8,3,2])).shape

torch.Size([1, 8, 1, 2])

In [23]:
def apply_rotary_emb(xq:torch.Tensor,xk:torch.Tensor,freqs_cos:torch.Tensor,freqs_sin:torch.Tensor) -> Tuple[torch.Tensor,torch.Tensor]:
    """Rotate query and key features by position-dependent angles (rotary embedding).

    The last dimension of ``xq``/``xk`` is read as consecutive (real, imaginary)
    pairs. Each pair is rotated by the angle whose cosine and sine are given in
    ``freqs_cos``/``freqs_sin``.

    Args:
        xq: query tensor; axis 1 is the position axis and the last axis must be even.
        xk: key tensor laid out like ``xq`` (the head axis may differ in size).
        freqs_cos: cosine table of shape ``(xq.shape[1], xq.shape[-1] // 2)``.
        freqs_sin: sine table of the same shape as ``freqs_cos``.

    Returns:
        ``(xq_out, xk_out)`` with the same shapes and dtypes as the inputs.
    """
    # Split the feature axis into pairs and separate the two components.
    xq_r,xq_i = xq.float().reshape(xq.shape[:-1]+(-1,2)).unbind(-1)
    xk_r,xk_i = xk.float().reshape(xk.shape[:-1]+(-1,2)).unbind(-1)

    # Insert singleton axes so the tables broadcast over batch and heads.
    freqs_cos = reshape_for_broadcast(freqs_cos,xq_r)
    freqs_sin = reshape_for_broadcast(freqs_sin,xq_r)

    # Complex multiplication (a + ib)(cos + i sin), written out on real tensors.
    xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
    xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
    xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
    xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos

    # Re-interleave the pairs and merge them back into the original feature axis.
    xq_out = torch.stack([xq_out_r,xq_out_i],dim=-1).flatten(3)
    # Keys must be rebuilt from their own imaginary part, not the queries'.
    xk_out = torch.stack([xk_out_r,xk_out_i],dim=-1).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

In [29]:
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head ``n_rep`` times along the head axis.

    ``x`` has shape ``(bs, slen, n_kv_heads, head_dim)``. The result has shape
    ``(bs, slen, n_kv_heads * n_rep, head_dim)``, with the copies of each head
    placed next to each other. When ``n_rep`` is 1 the input is returned as is.
    """
    batch,seq,kv_heads,width = x.shape
    if n_rep != 1:
        # Add a repeat axis after the head axis, broadcast it, then fold it into the heads.
        widened = x.unsqueeze(3).expand(batch,seq,kv_heads,n_rep,width)
        return widened.reshape(batch,seq,kv_heads * n_rep,width)
    return x

In [30]:
# Toy key/value tensor laid out as (bs=1, slen=2, n_kv_heads=3, head_dim=4) for the repeat_kv demo.
x = torch.ones(1,2,3,4)

In [34]:
x

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

In [33]:
# Repeating each of the 3 heads twice should give 6 heads per position.
repeat_kv(x,2)

tensor([[[[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]],

         [[1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.],
          [1., 1., 1., 1.]]]])

In [36]:
class Attention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and shared key/value heads.

    Queries use ``args.n_heads`` heads. Keys and values use ``args.n_kv_heads``
    heads (defaulting to ``args.n_heads``) and are repeated to match the query
    head count. When ``torch.nn.functional.scaled_dot_product_attention`` is
    available it is used; otherwise a manual implementation with a precomputed
    causal mask is used.
    """

    def __init__(self,args:ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        assert args.n_heads % self.n_kv_heads == 0
        model_parallel_size = 1
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        # How many query heads share each key/value head.
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        # The query projection is sized from args.n_heads; this module has no
        # self.n_heads attribute.
        self.wq = nn.Linear(args.dim,args.n_heads * self.head_dim,bias=False)
        self.wk = nn.Linear(args.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wv = nn.Linear(args.dim,self.n_kv_heads * self.head_dim,bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim,bias=False)
        self.attn_dropout = nn.Dropout(args.dropout)
        self.resid_dropout = nn.Dropout(args.dropout)
        self.dropout = args.dropout

        # Use the fused kernel when this PyTorch build provides it.
        self.flash = hasattr(torch.nn.functional,'scaled_dot_product_attention')
        if not self.flash:
            print('warning')
            # Fallback path: -inf strictly above the diagonal blocks attention to future positions.
            mask = torch.full((1,1,args.max_seq_len,args.max_seq_len),float('-inf'))
            mask = torch.triu(mask,diagonal=1)
            self.register_buffer("mask",mask)

    def forward(self,x:torch.Tensor,freqs_cos:torch.Tensor,freqs_sin:torch.Tensor):
        """Attend over ``x`` of shape ``(bsz, seqlen, dim)`` and return a tensor of the same shape.

        ``freqs_cos`` / ``freqs_sin`` are the rotary tables for the first
        ``seqlen`` positions.
        """
        bsz, seqlen,_ = x.shape

        # Project, then split the feature axis into heads.
        xq,xk,xv = self.wq(x),self.wk(x),self.wv(x)
        xq = xq.view(bsz,seqlen,self.n_local_heads,self.head_dim)
        xk = xk.view(bsz,seqlen,self.n_local_kv_heads,self.head_dim)
        xv = xv.view(bsz,seqlen,self.n_local_kv_heads,self.head_dim)

        # Positional rotation applies to queries and keys; values stay untouched.
        xq,xk = apply_rotary_emb(xq,xk,freqs_cos,freqs_sin)

        # Expand key/value heads so every query head has a partner.
        xk = repeat_kv(xk,self.n_rep)
        xv = repeat_kv(xv,self.n_rep)

        # Move heads ahead of the sequence axis: (bsz, heads, seqlen, head_dim).
        xq = xq.transpose(1,2)
        xk = xk.transpose(1,2)
        xv = xv.transpose(1,2)

        if self.flash:
            output = torch.nn.functional.scaled_dot_product_attention(
                xq,xk,xv,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True,
            )
        else:
            scores = torch.matmul(xq,xk.transpose(2,3)) / math.sqrt(self.head_dim)
            assert hasattr(self,'mask')
            scores = scores + self.mask[:,:,:seqlen,:seqlen]
            # Softmax in float32 for stability, then return to the working dtype.
            scores = F.softmax(scores.float(),dim=-1).type_as(xq)
            scores = self.attn_dropout(scores)
            output = torch.matmul(scores,xv)

        # Merge the heads back into a single feature axis.
        output = output.transpose(1,2).contiguous().view(bsz,seqlen,-1)
        output = self.wo(output)
        output = self.resid_dropout(output)
        return output

In [39]:
class FeedForward(nn.Module):
    """Gated MLP: ``w2(silu(w1(x)) * w3(x))`` followed by dropout.

    Args:
        dim: input and output width.
        hidden_dim: inner width. When ``None`` it is derived as two thirds of
            ``4 * dim``, rounded up to the next multiple of ``multiple_of``.
        multiple_of: granularity used when deriving ``hidden_dim``.
        dropout: dropout probability applied to the output.
    """

    def __init__(self,dim: int,hidden_dim: int, multiple_of: int, dropout: float):
        super().__init__()
        if hidden_dim is None:
            two_thirds = int(2 * (4 * dim) / 3)
            # Ceil-divide, then scale back up to land on a multiple of `multiple_of`.
            n_chunks = (two_thirds + multiple_of - 1) // multiple_of
            hidden_dim = multiple_of * n_chunks
        self.w1 = nn.Linear(dim,hidden_dim,bias=False)
        self.w2 = nn.Linear(hidden_dim,dim,bias=False)
        self.w3 = nn.Linear(dim,hidden_dim,bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x):
        """Apply the gated MLP to ``x`` along its last dimension."""
        gate = F.silu(self.w1(x))
        up = self.w3(x)
        projected = self.w2(gate * up)
        return self.dropout(projected)

In [40]:
class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention and feed-forward, each with a residual.

    Args:
        layer_id: index of this layer; stored on the instance.
        args: model hyperparameters used to size the sub-modules.
    """

    def __init__(self,layer_id:int, args: ModelArgs):
        super().__init__()
        self.n_heads = args.n_heads
        self.dim = args.dim
        # ModelArgs names the field `n_heads`; `n_head` does not exist.
        self.head_dim = args.dim // args.n_heads
        self.attention = Attention(args)
        self.feed_forward = FeedForward(
            dim=args.dim,
            hidden_dim=args.hidden_dim,
            multiple_of=args.multiple_of,
            dropout=args.dropout
        )
        self.layer_id = layer_id
        self.attention_norm = RMSNorm(args.dim,eps=args.norm_eps)
        self.ffn_norm = RMSNorm(args.dim,eps=args.norm_eps)

    def forward(self,x,freqs_cos,freqs_sin):
        """Return ``x`` after the attention and feed-forward residual updates."""
        # Call the module rather than .forward() so registered hooks still run.
        h = x + self.attention(self.attention_norm(x),freqs_cos,freqs_sin)
        out = h + self.feed_forward(self.ffn_norm(h))
        return out

In [41]:
class Transformer(nn.Module):
    """Decoder-only transformer: embedding, stacked TransformerBlocks, final norm, output head.

    The token embedding and the output projection share one weight matrix.
    After a forward call with ``targets``, the cross-entropy loss is available
    as ``last_loss``; otherwise ``last_loss`` is ``None``.
    """

    # Annotation only: the value is set per instance in __init__ / forward.
    last_loss: Optional[torch.Tensor]

    def __init__(self,params:ModelArgs):
        super().__init__()
        self.params = params
        self.vocab_size = params.vocab_size
        self.n_layers = params.n_layers

        self.tok_embedding = nn.Embedding(params.vocab_size,params.dim)
        self.dropout = nn.Dropout(params.dropout)
        self.layers = nn.ModuleList()
        for layer_id in range(params.n_layers):
            self.layers.append(TransformerBlock(layer_id,params))
        self.norm = RMSNorm(params.dim,eps=params.norm_eps)
        self.output = nn.Linear(params.dim,params.vocab_size,bias=False)

        # Weight tying: share the matrix, but keep tok_embedding as an
        # nn.Embedding so forward() can still call it.
        self.tok_embedding.weight = self.output.weight

        # Rotary tables are derived data, so they are kept out of the state dict.
        freqs_cos,freqs_sin = precompute_freqs_cis(self.params.dim // self.params.n_heads, self.params.max_seq_len)
        self.register_buffer("freqs_cos", freqs_cos, persistent=False)
        self.register_buffer("freqs_sin", freqs_sin, persistent=False)

        # init all weights
        self.apply(self._init_weights)
        # apply special scaled init to the residual projections, per GPT-2 paper
        # NOTE(review): this targets `w3`, while FeedForward writes to the
        # residual through `w2` — confirm which one is intended.
        for pn, p in self.named_parameters():
            if pn.endswith('w3.weight') or pn.endswith('wo.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * params.n_layers))

        # Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
        self.last_loss = None

    def _init_weights(self,module):
        """Initialise Linear and Embedding weights from N(0, 0.02) and zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self,tokens: torch.Tensor, targets: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute logits for ``tokens`` of shape ``(bsz, seqlen)``.

        With ``targets``, logits are returned for every position and
        ``last_loss`` holds the cross-entropy (target value -1 is ignored).
        Without ``targets``, only the last position's logits are returned, with
        shape ``(bsz, 1, vocab_size)``, and ``last_loss`` is reset to ``None``.
        """
        _bsz, seqlen = tokens.shape
        h = self.tok_embedding(tokens)
        h = self.dropout(h)
        # Only the first `seqlen` rows of the rotary tables are needed.
        freqs_cos = self.freqs_cos[:seqlen]
        freqs_sin = self.freqs_sin[:seqlen]
        for layer in self.layers:
            h = layer(h,freqs_cos,freqs_sin)
        h = self.norm(h)
        if targets is not None:
            logits = self.output(h)
            self.last_loss = F.cross_entropy(logits.view(-1,logits.size(-1)),targets.view(-1),ignore_index=-1)
        else:
            # Indexing with a list keeps the time dimension.
            logits = self.output(h[:,[-1],:])
            self.last_loss = None
        return logits