In [58]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass
from jaxtyping import Int, Float
from einops import einsum

In [2]:
size = (10, 5, 768)   # shape of the demo inputs: [batch, seq_len, d_model]

In [3]:
# Random half-precision demo activations laid out as [batch, seq_len, d_model].
x = torch.rand(*size, dtype=torch.float16)

In [66]:
class LayerNorm(nn.Module):
    """
    Layer normalization over the last (feature) dimension of the input.

    Every token embedding is normalized on its own, independently of the other
    tokens and of the rest of the batch, and is then rescaled and shifted by
    learnable per-feature parameters.

    algorithm:
        mean and var of every token-embedding are calculated -> [batch, seq_len, 1] means and vars
        every embedding is normalized independently -> mean = 0, variance = 1
        the normalized embedding is scaled by w [d_model] and shifted by b [d_model]

    Args:
        d_model: size of the feature dimension that is normalized (default 768).
        eps: constant added to the variance to prevent division by 0 (default 1e-5).
    """
    def __init__(self, d_model: int = 768, eps: float = 1e-5):
        super().__init__()
        self.w = nn.Parameter(torch.ones(d_model))      # learnable scaling parameter [d_model]
        self.b = nn.Parameter(torch.zeros(d_model))     # learnable bias parameter [d_model]
        self.eps = eps                                  # constant to prevent division by 0

    def forward(self, x: Float[torch.Tensor, "batch seq_len d_model"]) -> Float[torch.Tensor, "batch seq_len d_model"]:
        """
        Normalize every embedding along the last dimension, then apply w and b.

        Returns a tensor of the same shape as ``x``.
        """
        mu = x.mean(dim=-1, keepdim=True)
        # Layer normalization uses the biased (population) variance, i.e. division by
        # d_model rather than d_model - 1; torch.var defaults to the unbiased estimator.
        var = x.var(dim=-1, keepdim=True, unbiased=False)

        x = (x - mu) / torch.sqrt(var + self.eps)
        return x * self.w + self.b

layer_norm = LayerNorm()

In [5]:
# Call the module itself rather than .forward() so that registered hooks run.
layer_norm(x).size()

torch.Size([10, 5, 768])

In [46]:
class Attention(nn.Module):
    """
    Multi-head causal self-attention over the residual stream.

    Every head projects the input to queries, keys and values of width d_head and
    computes scaled dot-product attention in which a query position can only attend
    to key positions at or before it. The per-head results are projected back to
    d_model, summed over the heads and shifted by the output bias b_O.

    Args:
        n_heads: number of attention heads (default 12).
        d_model: width of the residual stream (default 768).
        d_head: width of each head's query/key/value vectors (default 64).
        init_std: standard deviation of the normal init of the weight matrices (default 0.02).
    """
    IGNORE: Float[torch.Tensor, ""]

    def __init__(self, n_heads: int = 12, d_model: int = 768, d_head: int = 64, init_std: float = 0.02):
        super().__init__()
        self.n_heads = n_heads
        self.d_model = d_model
        self.d_head = d_head
        self.W_Q = nn.Parameter(torch.empty((self.n_heads, self.d_model, self.d_head)))
        self.W_K = nn.Parameter(torch.empty((self.n_heads, self.d_model, self.d_head)))
        self.W_V = nn.Parameter(torch.empty((self.n_heads, self.d_model, self.d_head)))
        self.W_O = nn.Parameter(torch.empty((self.n_heads, self.d_head, self.d_model)))
        self.b_Q = nn.Parameter(torch.zeros((self.n_heads, self.d_head)))
        self.b_K = nn.Parameter(torch.zeros((self.n_heads, self.d_head)))
        self.b_V = nn.Parameter(torch.zeros((self.n_heads, self.d_head)))
        self.b_O = nn.Parameter(torch.zeros((self.d_model)))
        nn.init.normal_(self.W_Q, std=init_std)
        nn.init.normal_(self.W_K, std=init_std)
        nn.init.normal_(self.W_V, std=init_std)
        nn.init.normal_(self.W_O, std=init_std)
        # Large negative fill value for masked scores; after softmax these positions get ~0 weight.
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cpu"))

    def forward(self, resid: Float[torch.Tensor, "batch posn d_model"]) -> Float[torch.Tensor, "batch posn d_model"]:
        """
        Apply causal multi-head self-attention to ``resid``.

        Returns a tensor of the same shape as ``resid``.
        """
        q = einsum(resid, self.W_Q, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head") + self.b_Q
        k = einsum(resid, self.W_K, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head") + self.b_K
        v = einsum(resid, self.W_V, "batch posn d_model, n_heads d_model d_head -> batch posn n_heads d_head") + self.b_V

        attn_scores = einsum(q, k, "batch posn_q n_heads d_head, batch posn_k n_heads d_head -> batch n_heads posn_q posn_k")
        # Scale by sqrt(d_head) first, then mask, so masked entries hold exactly IGNORE.
        attn_scores = self.apply_causal_mask(attn_scores / (self.d_head ** 0.5))
        attn_pattern = torch.softmax(attn_scores, dim=-1)

        z = einsum(attn_pattern, v, "batch n_heads posn_q posn_k, batch posn_k n_heads d_head -> batch n_heads posn_q d_head")
        # Project every head back to d_model, sum over heads, and add the output bias.
        out = einsum(z, self.W_O, "batch n_heads posn_q d_head, n_heads d_head d_model -> batch posn_q d_model") + self.b_O
        return out

    def apply_causal_mask(self, attn_scores: Float[torch.Tensor, "batch n_heads query_pos key_pos"]) -> Float[torch.Tensor, "batch n_heads query_pos key_pos"]:
        """
        Return a copy of ``attn_scores`` in which every entry whose key position lies
        after its query position is replaced by ``self.IGNORE``.

        The input tensor is not modified.
        """
        _, _, query_len, key_len = attn_scores.shape

        # True strictly above the main diagonal, i.e. where key_pos > query_pos.
        causal_mask = torch.triu(torch.ones(query_len, key_len, device=attn_scores.device), diagonal=1).bool()
        return attn_scores.masked_fill(causal_mask, self.IGNORE)


In [48]:
attn = Attention()
# Call the module itself rather than .forward() so that registered hooks run.
attn(torch.rand(size=size, dtype=torch.float32)).size()

torch.Size([10, 5, 768])

In [61]:
class MLP(nn.Module):
    """
    Position-wise feed-forward network: d_model -> d_mlp -> d_model with a GELU in between.

    Both weight matrices are drawn from a normal distribution with std 0.02;
    both biases start at zero.
    """
    def __init__(self):
        super().__init__()
        self.d_model = 768
        self.d_mlp = 768*4

        # Parameters are registered in the order W_in, W_out, b_in, b_out.
        self.W_in = nn.Parameter(nn.init.normal_(torch.empty(self.d_model, self.d_mlp), std=0.02))
        self.W_out = nn.Parameter(nn.init.normal_(torch.empty(self.d_mlp, self.d_model), std=0.02))
        self.b_in = nn.Parameter(torch.zeros(self.d_mlp))
        self.b_out = nn.Parameter(torch.zeros(self.d_model))

    def forward(self, resid: Float[torch.Tensor, "batch posn d_model"]) -> Float[torch.Tensor, "batch posn d_model"]:
        """Expand to d_mlp, apply GELU, and project back to d_model."""
        pre_act = einsum(resid, self.W_in, "batch posn d_model, d_model d_mlp -> batch posn d_mlp")
        hidden = F.gelu(pre_act + self.b_in)
        projected = einsum(hidden, self.W_out, "batch posn d_mlp, d_mlp d_model -> batch posn d_model")
        return projected + self.b_out


In [62]:
mlp = MLP()
# Call the module itself rather than .forward() so that registered hooks run.
mlp(torch.rand(size=size, dtype=torch.float32)).size()

torch.Size([10, 5, 768])

In [67]:
class TransformerBlock(nn.Module):
    """
    Pre-LayerNorm transformer block.

    The input is layer-normalized and passed through attention, and the result is
    added back to the input; that sum is layer-normalized and passed through the
    MLP, and the result is added back to it.
    """
    def __init__(self):
        super().__init__()
        # Sub-modules in the order they are applied in forward().
        self.ln1 = LayerNorm()
        self.attn = Attention()
        self.ln2 = LayerNorm()
        self.mlp = MLP()

    def forward(self, resid_pre: Float[torch.Tensor, "batch position d_model"]) -> Float[torch.Tensor, "batch position d_model"]:
        """Run the attention and MLP sub-layers, each wrapped in a residual connection."""
        attn_out = self.attn(self.ln1(resid_pre))
        resid_mid = attn_out + resid_pre

        mlp_out = self.mlp(self.ln2(resid_mid))
        return mlp_out + resid_mid


In [68]:
block = TransformerBlock()
# Call the module itself rather than .forward() so that registered hooks run.
block(torch.rand(size=size, dtype=torch.float32)).size()

torch.Size([10, 5, 768])