In [1]:

import torch as t
import torch.nn as nn 
import torch.nn.functional as F 
from torch import Tensor
from jaxtyping import Int, Float
from dataclasses import dataclass

In [2]:

@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()

In [54]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(t.ones(cfg.d_model))
        self.b = nn.Parameter(t.ones(cfg.d_model))

    def forward(self, resid: Float[Tensor, "batch seq d_model"]):
        mean = resid.mean(dim=-1, keepdim=True)
        std = (
            resid.var(dim=-1, keepdim=True, unbiased=False) + self.cfg.layer_norm_eps
        ).sqrt()
        normalized = (resid - mean) / std
        return normalized * self.w + self.b

In [None]:

class SHA(nn.Module):
    def __init__(self, cfg: Config): 
        super().__init__()
        DM, DH = cfg.d_model, cfg.d_head
        self.D = DM
        self.DH = DH
        # TODO: where do I chop up the d_model dim?
        self.scale = t.sqrt(t.tensor(DM, dtype=t.float32))
        self.wq = nn.Linear(DM, DH) # query projection
        self.wk = nn.Linear(DM, DH) # key projection  
        self.wv = nn.Linear(DM, DH) # value projection
        self.wo = nn.Linear(DH, DM) # final output projection

    def forward(self, x: Float[Tensor, "B S D"]) -> Float[Tensor, "B S D"]:
        Q, K, V = self.wq(x), self.wk(x), self.wv(x) # BNH S DH
        K_t = K.transpose(1, 2) # B DH S
        A_logits = (Q @ K_t)/self.scale # B S S
        A_masked = t.tril(A_logits)
        A = F.softmax(A_masked, dim=-1) # [B, S, S]
        Z = A @ V
        return self.wo(Z)

In [None]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.W_E = nn.Parameter(
            num_embeddings=cfg.n_ctx,
            embedding_dim=cfg.d_model,
        )
        
    def forward(self, resid: Int[Tensor, "batch seq"]) -> Float[Tensor, "batch seq d_model"]:
        _, seq = resid.shape
        return self.W_E[None, :seq, :]
    

In [70]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.ln1 = LayerNorm(cfg)
        self.attn = SHA(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, x: Float[Tensor, "batch seq d_model"]) -> Float[Tensor, "batch seq d_model"]:
        attn_output = self.ln1(self.attn(x))
        x += attn_output
        mlp_output = self.ln2(self.mlp(x))
        x += mlp_output
        return x
        
        

In [None]:
class Unembed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.W =nn.Linear(
            in_features=cfg.d_model,
            out_features=cfg.d_vocab
        )

    def forward(self, x: Float[Tensor, "batch seq d_model"]) -> Float[Tensor, "batch seq d_vocab"]:
        a = self.W(x)
        return t.softmax(a, dim=-1)

In [73]:

class FullTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.token_embedding = nn.Embedding(
            num_embeddings=cfg.d_vocab,
            embedding_dim=cfg.d_model,
        )
        self.pos_embedding = PosEmbed(cfg)
        self.blocks = nn.ModuleList([
            TransformerBlock(cfg)
        ])
        self.unemb = Unembed(cfg)
    
    def forward(self, input: Int[Tensor, "b s"]):
        tok_emb = self.token_embedding(input)
        pos_emb = self.pos_embedding(input)
        resid = tok_emb + pos_emb
        logits = self.blocks(resid)
        probs = self.unemb(logits)
        return probs 

In [74]:
toks = t.tensor([
    [3, 8, 4]
], dtype=t.int)

In [75]:
model = FullTransformer(cfg=cfg)

In [76]:
model(toks)

TypeError: 'Embedding' object is not subscriptable