In [1]:
import torch
import torch.nn as nn
import os

In [2]:
def toy_softmax(x):
    sx = x -  x.max(-1,keepdim=True).values
    ex = sx.exp()
    return ex/ex.sum(-1,keepdim=True)

def toy_product_atte(Q, K, V, mask=None):
    d_k = torch.tensor(Q.shape[-1]) 
    Qk = Q @ K.transpose(-2,-1)/ torch.sqrt(d_k)   
    if mask is not None:
        Qk = Qk.masked_fill(mask==False,float('-inf'))
    sQk = toy_softmax(Qk)
    return sQk @ V

def toy_multihead_atte(d_model,num_heads,Qp,Kp,Vp,proj,in_features,posistion=None):
    Qs = (in_features @ Qp.transpose(-2,-1)).split(d_model//num_heads,-1)
    Ks = (in_features @ Kp.transpose(-2,-1)).split(d_model//num_heads,-1)
    Vs = (in_features @ Vp.transpose(-2,-1)).split(d_model//num_heads,-1)
    
    seq_len = Qs[0].size(-2)
    mask = torch.tril(torch.ones(seq_len,seq_len))
    
    atts = [toy_product_atte(Qs[i],Ks[i],Vs[i],mask) for i in range(num_heads)]
    atts = torch.cat(atts,-1)
    return atts @  proj.transpose(-2,-1)

In [None]:
class toy_Liner(nn.Module):
    def __init__(self, in_features, out_features, bias=None, device=None, dtype=torch.float32):
        super().__init__()
        self.weight = nn.Parameter(torch.empty(out_features, in_features,dtype=dtype))
        self.bias = nn.Parameter(torch.empty(out_features,dtype=dtype)) if bias else None
        self.device = device
        
        self.set_weights()
    
    def set_weights(self,w=None):
        if w == None:
            nn.init.trunc_normal_(self.weight)
        else:
            self.weight.data = w
    
    def forward(self,x):
        out = x @ self.weight.transpose(-2,-1)
        if self.bias != None:
            out += self.bias
        return out
    

class toy_Embedding(nn.Module):
    def __init__(self, num_embd, embd_dim, device = None,dtype = torch.float32) -> None:
        super().__init__()
        self.embd = nn.Parameter(torch.empty(num_embd,embd_dim,dtype=dtype))
        self.device = device

        self.set_para()
        
    def set_para(self,embd=None):
        if embd == None:
            nn.init.trunc_normal_(self.embd)
        else:
            self.embd.data = embd 
    
    def forward(self,x):
        return self.embd[x]
    
class toy_RMSnorm(nn.Module):
    def __init__(self, d_model, eps: float = 1e-5, device = None, dtype = torch.float32):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        self.gain = nn.Parameter(torch.empty(d_model,dtype=dtype))
        self.device = device

        self.set_para()
        
    def set_para(self,g=None):
        if g==None:
            nn.init.trunc_normal_(self.gain,1,0.02)
        else:
            self.gain.data = g
    
    def forward(self,x):
        rmsx = x.square().mean(-1,keepdim=True) 
        out = x*self.gain/torch.sqrt(rmsx+self.eps)
        return out 
        
        
class toy_SwiGLU(nn.Module):
    def __init__(self, d_model, d_ff, device=None, dtype=torch.float32):
        super().__init__()
        self.W1 = nn.Parameter(torch.empty(d_ff,d_model,dtype=dtype))
        self.W2 = nn.Parameter(torch.empty(d_model,d_ff,dtype=dtype))
        self.W3 = nn.Parameter(torch.empty(d_ff,d_model,dtype=dtype))
        self.device = device
    
        self.set_para()
    def set_para(self,w1=None,w2=None,w3=None):
        if w1 == None:
            nn.init.trunc_normal_(self.W1)
        else:
            self.W1.data = w1
        if w2 == None:
            nn.init.trunc_normal_(self.W2)
        else:
            self.W2.data = w2
        if w3 == None:
            nn.init.trunc_normal_(self.W3)
        else:
            self.W3.data = w3
    
    def forward(self,x):
        W3x = x @ self.W3.transpose(-2,-1)
        W1x = x @ self.W1.transpose(-2,-1)
        Slu = W1x * torch.sigmoid(W1x)
        return (Slu * W3x)@self.W2.transpose(-2,-1)

class toy_RoPE(nn.Module):
    def __init__(self, d_k, theta, max_len, device = None, dtype = torch.float32):
        super().__init__()
        
        self.rot_d = d_k//2
        i = torch.arange(self.rot_d, device=device, dtype=dtype)         
        j = torch.arange(max_len, device=device, dtype=dtype)      

        inv_freq = torch.exp(-(2*i)/d_k * torch.log(torch.tensor(theta, device=device, dtype=dtype)))                   
        thetas = j[:, None] * inv_freq[None, :]  
        
        cos_table = torch.cos(thetas)  #cos_table [token posistion, feature posistion]
        sin_table = torch.sin(thetas)
        
        self.register_buffer("cos_table",cos_table,persistent=False)
        self.register_buffer("sin_table",sin_table,persistent=False)
    
    def forward(self,x,tk_posistions):
        cos = self.cos_table[tk_posistions] #(T,d/2)
        sin = self.sin_table[tk_posistions] #(T,d/2)
        
        
        x_rot = x[..., :2*self.rot_d]
        x_pass = x[..., 2*self.rot_d:]
        x1 = x_rot[...,0::2] #(T,d/2 + 1) ?
        x2 = x_rot[...,1::2]
        y1 = x1 * cos - x2 * sin
        y2 = x1 * sin + x2 * cos
        y_rot = torch.stack([y1, y2], dim=-1).flatten(-2)
        return torch.cat([y_rot, x_pass], dim=-1)


        
        
    

In [None]:
class multi_attention(nn.Module):
    def __init__(self, d_in, num_heads, max_seq_len, theta, device = None) -> None:
        super().__init__()
        self.c_attention = toy_Liner(d_in, 3* d_in)
        self.proj = toy_Liner(d_in,d_in)
        self.num_head = num_heads
        self.d_head = d_in//self.num_head       
        self.ropez = toy_RoPE(self.d_head, theta ,max_seq_len)
        self.device = device
        
        trill_mask = torch.tril(torch.ones(max_seq_len,max_seq_len,dtype=torch.bool))
        self.register_buffer("trill",trill_mask,persistent=False)
    
    
    def forward(self,x):
        B, T, C = x.shape
        qkv = self.c_attention(x) #B,T,C @ C 3C -> B,T,3C
        Q,K,V = qkv.split(C,-1) # B,T,C
        
        qs = Q.view(B,T,self.num_head,self.d_head).transpose(1,2)
        ks = K.view(B,T,self.num_head,self.d_head).transpose(1,2) #B,h,T,d_h
        vs = V.view(B,T,self.num_head,self.d_head).transpose(1,2)
        
        tk_ps = torch.arange(T)
        qs = self.ropez.forward(qs,tk_ps)  
        ks = self.ropez.forward(ks,tk_ps)
        
        atts = toy_product_atte(qs,ks,vs,self.trill[:T,:T]).transpose(1, 2).contiguous().view(B,T,C) # B, T ,C
        return self.proj(atts)
    
class transformer_block(nn.Module):
    def __init__(self, d_in, num_heads, d_ff, max_seq_len, theta, device=None) -> None:
        super().__init__()   
        self.norm1 = toy_RMSnorm(d_in)
        self.atte = multi_attention(d_in,num_heads,max_seq_len,theta,device)
        self.norm2 = toy_RMSnorm(d_in)
        self.ff = toy_SwiGLU(d_in,d_ff)
        self.device = device
    
    def set_para(self,para_dict):
        q_proj_weight = para_dict["attn.q_proj.weight"]
        k_proj_weight = para_dict["attn.k_proj.weight"]
        v_proj_weight = para_dict["attn.v_proj.weight"]
        o_proj_weight = para_dict["attn.output_proj.weight"]
        ln1_weight = para_dict["ln1.weight"]
        ln2_weight = para_dict["ln2.weight"]
        ff_w1 = para_dict["ffn.w1.weight"]
        ff_w2 = para_dict["ffn.w2.weight"]
        ff_w3 = para_dict["ffn.w3.weight"]
        c_atte_weight = torch.cat([q_proj_weight,k_proj_weight,v_proj_weight],0)
        self.atte.c_attention.data = c_atte_weight
        self.atte.proj.data = o_proj_weight
        self.norm1.set_para(ln1_weight)
        self.norm2.set_para(ln2_weight)
        self.ff.set_para(ff_w1,ff_w2,ff_w3)
        
    def forward(self,x):
        x = x + self.atte(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

In [None]:
import re
from collections import defaultdict
class toy_Transformer_lm(nn.Module):
    def __init__(self, vocab_size, context_length, d_model, num_layers, num_heads, d_ff, rope_theta, device = None) -> None:
        super().__init__()
        self.tk_embd = toy_Embedding(vocab_size,d_model)
        self.blocks =nn.ModuleList([transformer_block(d_model,num_heads,d_ff,context_length,rope_theta) for _ in range(num_layers)])
        self.norm = toy_RMSnorm(d_model)
        self.out_embd = toy_Liner(d_model,vocab_size)
    
    def set_para(self,para_dict):
        self.tk_embd.set_para(para_dict["token_embeddings.weight"])
        self.out_embd.set_weights(para_dict["lm_head.weight"])
        self.norm.set_para(para_dict["ln_final.weight"])
        grouped = defaultdict(dict)
        pat = re.compile(r"^layers\.(\d+)\.(.+)$")  # layers.{i}.rest

        for k, v in para_dict.items():
            m = pat.match(k)
            if m:
                i = int(m.group(1))
                rest = m.group(2)  # e.g. "ffn.w3.weight"
                grouped[i][rest] = v
        layer_dict = dict(grouped)
        for _ ,blk in enumerate(self.blocks):
            blk.set_para(layer_dict[_])
    def forward(self,x):
        #x : ids
        x = self.tk_embd(x)
        for blk in self.blocks:
            x = blk(x)
        x = self.norm(x)
        return self.out_embd(x)

In [None]:
test_h = multi_attention(10,2,10,0.0001)
a = torch.randn(2,5,10)
test_h(a)

tensor([[[  5.0490,   9.1225,   4.8232,   3.0160, -12.8134,  -0.6967,  -6.0283,
          -10.6801,   0.1312,   7.3410],
         [  6.1181,   7.3358,  -4.6733,   5.7041,  -6.1997,  -6.0524,  -5.4670,
            0.2124,   1.9641,   1.8383],
         [ 10.8416,   9.0004,   8.1528,   1.5137, -13.1621,   6.4140,  -8.8128,
           -4.4009,   5.0310,  -1.4037],
         [  0.6109,  -1.7073, -12.9099,   4.9808,  13.2151, -10.3835,   1.6771,
           10.4893,  -8.0457, -10.0227],
         [-11.4117,  -8.4786, -12.5582,   7.1547,   8.3548,  -8.7325,   2.2082,
            2.9028,  18.2975,  -7.4372]],

        [[ -2.0266,  12.7991,   3.1344,   1.3523, -19.5812,  -9.1501,  -3.1588,
          -19.6919,   5.2371,   8.3103],
         [  2.2641,   5.2216,   1.3730,   6.4510, -11.3216,  -2.9140,  -1.9544,
           -6.0966,   8.3469,  -3.0424],
         [  7.8703,  15.6284,  24.5607,  -9.3389,  -2.4780,  11.6780,  -3.2490,
           -7.7913,  -3.9369,  14.9561],
         [  3.2591,   2.8341, 

In [None]:


def toy_multihead_atte_rope(d_model: int,
    num_heads: int,
    max_seq_len: int,
    theta: float,
    Qp, #Float[Tensor, " d_k d_in"],
    Kp, #Float[Tensor, " d_k d_in"],
    Vp, #Float[Tensor, " d_v d_in"],
    proj, #Float[Tensor, " d_model d_v"],
    in_features, #Float[Tensor, " ... sequence_length d_in"],
    token_positions, #Int[Tensor, " ... sequence_length"] | None = None,
):# -> Float[Tensor, " ... sequence_length d_out"]
    Ro = toy_RoPE(d_model//num_heads,theta,max_seq_len) #
    Qs = (in_features @ Qp.transpose(-2,-1)).split(d_model//num_heads,-1)
    Ks = (in_features @ Kp.transpose(-2,-1)).split(d_model//num_heads,-1)
    Vs = (in_features @ Vp.transpose(-2,-1)).split(d_model//num_heads,-1)
    
    Qs = [Ro.forward(Qs[i],token_positions) for i in range(num_heads)] #
    Ks = [Ro.forward(Ks[i],token_positions) for i in range(num_heads)]
    
    seq_len = Qs[0].size(-2)
    mask = torch.tril(torch.ones(seq_len,seq_len))
    atts = [toy_product_atte(Qs[i],Ks[i],Vs[i],mask) for i in range(num_heads)]
    atts = torch.cat(atts,-1)
    return atts @  proj.transpose(-2,-1)