In [1]:
import torch
import torchvision, torchaudio
from torch import optim, nn, utils, Tensor
from dataclasses import dataclass
import einops
from einops import einsum
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
from easy_transformer import EasyTransformer
import pandas as pd
import json
import lightning as L
import numpy as np
from torch.utils.data import Dataset
from copy import deepcopy
from bisect import bisect_right

from datasets import load_dataset, load_dataset_builder

print(torch.cuda.is_available())
print(torch.__version__)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

True
2.4.0+cu121


In [2]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
target_tokens = torch.cat((tokens[0][1:], torch.tensor([50256])), 0).view(1, -1)
logits, cache = reference_gpt2.run_with_cache(reference_text)
log_probs = logits.log_softmax(dim=-1)
probs = logits.log_softmax(dim=-1)
reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])[-1]

data = pd.read_json('../data/train_1M.jsonl', lines=True)
data.insert(3, "tokens", [reference_gpt2.to_tokens(data.contents[i])[0] for i, x in enumerate(data.contents)], True)
data



Moving model to device:  cuda
Finished loading pretrained model gpt2-small into EasyTransformer!


Token indices sequence length is longer than the specified maximum sequence length for this model (1288 > 1024). Running this sequence through the model will result in indexing errors


Unnamed: 0,contents,metadata,id,tokens
0,Alsatian Cheese Tart\n\nFrench Chef Michel Ber...,"{'pile_set_name': ['Pile-CC', 'OpenWebText2']}",21,"[tensor(50256), tensor(2348), tensor(49720), t..."
1,depicted by four young winged men in Roman-lik...,"{'pile_set_name': ['Wikipedia (en)', 'Pile-CC']}",81,"[tensor(50256), tensor(10378), tensor(5722), t..."
2,Available in:\n\ndescription\n\nOn Tuesday Mar...,"{'pile_set_name': ['Pile-CC', 'Pile-CC']}",102,"[tensor(50256), tensor(10493), tensor(287), te..."
3,Date:\n\nDiscipline:\n\nSource:\n\nProduct num...,"{'pile_set_name': ['Pile-CC', 'Pile-CC']}",107,"[tensor(50256), tensor(10430), tensor(25), ten..."
4,other meetings she had chaired. Stockholders w...,"{'pile_set_name': ['Pile-CC', 'Pile-CC']}",110,"[tensor(50256), tensor(847), tensor(8292), ten..."
...,...,...,...,...
999995,wedding day and itinerary runs without a hitch...,"{'pile_set_name': ['Pile-CC', 'USPTO Backgroun...",34118665,"[tensor(50256), tensor(86), tensor(6048), tens..."
999996,a faster motor. Motor speed is already a param...,"{'pile_set_name': ['USPTO Backgrounds', 'Pile-...",34118691,"[tensor(50256), tensor(64), tensor(5443), tens..."
999997,Flag Law on 29 May 1936.\n\nFlag of Turkey\n\n...,"{'pile_set_name': ['Pile-CC', 'OpenWebText2']}",34118794,"[tensor(50256), tensor(34227), tensor(3854), t..."
999998,"president; he should, at best, be reviving cel...","{'pile_set_name': ['OpenWebText2', 'Pile-CC']}",34118819,"[tensor(50256), tensor(22540), tensor(26), ten..."


In [3]:
len(data.tokens.values[0])

323

In [4]:
data.contents.values[1]

"depicted by four young winged men in Roman-like dresses, driving vessels and blowing air into horns. The central upper square is an old man representing the Year, with the Wheel of Time, while at the upper corners are the personifications of the Rivers of Paradise. The other six upper squares depict the Four Seasons, as well as Samson and Abel (or Cain).\n\nThe two lower corners show the personifications of the Sun (left, symbolizing Sunday) and the Moon (right, much deteriorated, symbolizing Monday), while the  side outer squares represent the months (only eight of which survive). At the bottom are incomplete scenes of the discovery of Holy Cross.\n\nSources\n\nExternal links\n\nOfficial cathedral's website \nPage with details of the figures \nPage with links to websites and the newest literature  (2012)Publication Date:\n\nDiscipline:\n\nSource:\n\nProduct number:\n\nLength:\n\nAlso Available in:\n\ndescription\n\nIn October 2004 Fernández Pujals, founder of Telepizza, an internatio

In [7]:
class CustomGPTDataset(Dataset):
    def __init__(self, data, sequence_length):
        self.data = data
        self.sequence_length = sequence_length
        self.n_samples = len(self.data)
        self.li = [len(tokens) for tokens in self.data.tokens.values]
        self.cumulative_li = deepcopy(self.li)
        self.cumulative_li[0] += 1 - self.sequence_length 
        for i in range(1, len(self.cumulative_li)):
            self.cumulative_li[i] = self.cumulative_li[i-1] + self.cumulative_li[i] + 1 - self.sequence_length
            
    def __len__(self):
        return self.cumulative_li[-1]

    def __getitem__(self, idx):
        string_i = bisect_right(self.cumulative_li, idx)
        diff = self.cumulative_li[string_i] - idx
        context_window = self.data.tokens.values[string_i][-diff:self.sequence_length-diff]
        if diff == 0:
            target_window = torch.cat((self.data.tokens.values[string_i][-diff+1:self.sequence_length-diff], torch.tensor([50256])), 0).view(1, -1) 
        else:
            target_window = self.data.tokens.values[string_i][-diff+1:self.sequence_length-diff+1]
        return context_window, target_window

In [11]:
training_data = CustomGPTDataset(data, sequence_length = 47)

In [12]:
min(training_data.li), max(training_data.li)

(47, 13375)

In [28]:
train_dataloader = utils.data.DataLoader(training_data, batch_size=64, shuffle=True)
len(next(iter(train_dataloader)))

RuntimeError: stack expects each tensor to be equal size, but got [0] at entry 0 and [128] at entry 2

In [None]:
class LitAutoEncoder(L.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.transformer = DemoTransformer(cfg)
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        
        logits = self.transformer(text)
        loss = self.criterion(logits, target)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


In [31]:
criterion(logits.view(-1, 50257), target_tokens.flatten().to(device))

tensor(4.5890, device='cuda:0', grad_fn=<NllLossBackward0>)

In [5]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    random_input = torch.randn(shape).cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    random_input = torch.randint(100, 1000, shape).cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str): 
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

In [6]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


In [7]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))
    
    def forward(self, residual):
        # residual: [batch, position, d_model]
        "YOUR CODE HERE"
        if self.cfg.debug: print("Residual:", residual.shape)
        residual = residual - einops.reduce(residual, "batch position d_model -> batch position 1", "mean")
        # Calculate the variance, square root it. Add in an epsilon to prevent divide by zero.
        scale = (einops.reduce(residual.pow(2), "batch position d_model -> batch position 1", "mean") + cfg.layer_norm_eps).sqrt()
        normalized = residual / scale
        normalized = normalized * self.w + self.b
        if self.cfg.debug: print("Normalized:", residual.shape)
        return normalized
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Residual: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


In [8]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        "YOUR CODE HERE"
        if self.cfg.debug: print("Tokens:", tokens.shape)
        embed = self.W_E[tokens, :] # [batch, position, d_model]
        if self.cfg.debug: print("Embeddings:", embed.shape)
        return embed
        
rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
Embeddings: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
Embeddings: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       device='cuda:0', grad_fn=<IndexBackward0>)

In [9]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        if self.cfg.debug: print("Tokens:", tokens.shape)
        pos_embed = self.W_pos[:tokens.size(1), :] # [position, d_model]
        pos_embed = einops.repeat(pos_embed, "position d_model -> batch position d_model", batch=tokens.size(0))
        if self.cfg.debug: print("pos_embed:", pos_embed.shape)
        return pos_embed

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Tokens: torch.Size([1, 35])
pos_embed: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         ...,
         [ 4.6277e-04,  2.3037e-02,  4.1227e-02,  ..., -1.9287e-03,
          -2.3037e-03, -4.3189e-03],
         [-2.7136e-03,  2.1724e-02,  3.9675e-02,  ...,  4.2048e-04,
          -4.8160e-03, -9.2252e-04],
         [ 6.6815e-03,  2.0595e-02,  3.6596e-02,  ..., -9.5090e-04,
          -3.2512e-03, -9.6509e-04]]], device='cuda:0',
       grad_fn=<ExpandBackward0>)

In [21]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        self.register_buffer("IGNORE", torch.tensor(0.0, dtype=torch.float32, device="cuda"))
    
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        "YOUR CODE HERE"
        q = einsum(self.W_Q, normalized_resid_pre, "n_heads d_model d_head, batch position d_model -> batch position n_heads d_head") + self.b_Q
        k = einsum(self.W_K, normalized_resid_pre, "n_heads d_model d_head, batch position d_model -> batch position n_heads d_head") + self.b_K
        v = einsum(self.W_V, normalized_resid_pre, "n_heads d_model d_head, batch position d_model -> batch position n_heads d_head") + self.b_V
        qk = einsum(q, k, "batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos")
        attn_scores = qk / math.sqrt(self.cfg.d_head)
        attn_scores = self.apply_causal_mask(attn_scores)
        pattern = attn_scores.softmax(dim=-1) # [batch, n_heads, query_pos, key_pos]
        print(pattern.shape, v.shape)
        z = einsum(pattern, v, "batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head")
        attn_out = einsum(z, self.W_O, "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model") + self.b_O
        return attn_out
        
    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.ln1.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
torch.Size([2, 12, 4, 4]) torch.Size([2, 4, 12, 64])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
torch.Size([1, 12, 35, 35]) torch.Size([1, 35, 12, 64])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 7.9663e-01,  1.6985e-02,  3.4781e-02,  ...,  3.3120e-02,
          -2.3129e-02,  1.8103e-01],
         [ 1.3162e-03,  1.5750e-01, -1.4059e-01,  ..., -8.1997e-03,
           5.3076e-03,  1.3511e-01],
         [ 8.9738e-02, -7.2411e-01, -6.9866e-01,  ...,  5.5321e-02,
           2.7959e-03,  9.0785e-02],
         ...,
         [-3.0286e-01,  4.9638e-02, -6.0990e-01,  ..., -3.7084e-02,
          -4.9527e-04, -8.6008e-03],
         [-1.0844e+00, -6.1457e-02,  2.2966e-01,  ..., -2.6688e-02,
          -1.4368e-02,  3.3245e-02],
         [ 3.7947e-01, -4.9886e-01,  2.6434e-01,  ..., -2.7894e-02,
          -8.9028e-03,  4.8796e-02]]], device='cuda:0', grad_fn=<AddBackward0>)

In [24]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))
    
    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_mid:", normalized_resid_mid.shape)
        pre = einsum(normalized_resid_mid, self.W_in, "batch position d_model, d_model d_mlp -> batch position d_mlp") + self.b_in
        post = gelu_new(pre)
        mlp_out = einsum(post, self.W_out, "batch position d_mlp, d_mlp d_model -> batch position d_model") + self.b_out
        return mlp_out

rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Normalized_resid_mid: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.4380,  0.3624,  0.5117,  ...,  1.7227,  1.5761,  0.0368],
         [-1.0766, -0.0438,  0.3276,  ..., -0.5437,  0.4033,  0.3717],
         [-1.2182, -1.5481, -0.9702,  ...,  1.0737,  0.7199,  0.5080],
         ...,
         [-0.4004,  0.8475,  0.2047,  ...,  0.3789,  0.0455, -0.4744],
         [-0.0862,  0.7839,  0.9046,  ..., -0.2174, -0.5953,  0.8555],
         [ 0.8448, -0.3743,  1.0397,  ...,  0.0296,  0.3405,  0.3585]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [26]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))
    
    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_final:", normalized_resid_final.shape)
        logits = einsum(normalized_resid_final, self.W_U, "batch position d_model, d_model d_vocab -> batch position d_vocab") + self.b_U
        return logits

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Normalized_resid_final: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35, 768])
Normalized_resid_final: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0878,  -54.3452,
           -42.3645],
         [-128.0392, -127.9936, -130.7010,  ..., -136.7121, -129.9261,
          -129.3965],
         [-119.8521, -121.0064, -123.8819,  ..., -128.5180, -126.6027,
          -121.9060],
         ...,
         [-112.9815, -112.7750, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6725, -104.4889, -108.7361,  ..., -118.3552, -113.8767,
          -106.3604],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1970, -138.5883,
          -122.3697]]], device='cuda:0', grad_fn=<AddBackward0>)

In [30]:
class AttentionOnly(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
    
    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out
        
        normalized_resid_mid = self.ln2(resid_mid)
        resid_post = resid_mid + normalized_resid_mid
        return resid_post

In [28]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)
    
    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out
        
        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post
rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
torch.Size([2, 12, 4, 4]) torch.Size([2, 4, 12, 64])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Residual: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
torch.Size([1, 12, 35, 35]) torch.Size([1, 35, 12, 64])
Residual: torch.Size([1, 35, 768])
Normalized: torch.Size([1, 35, 768])
Normalized_resid_mid: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.3911,  0.1543,  0.6005,  ...,  1.7198,  1.7365,  0.3930],
         [-0.9039, -0.0360,  0.2351,  ..., -0.4148,  0.3562,  0.3936],
         [-0.9647, -2.4819, -1.4995,  ...,  1.4046,  0.7616,  0.5918],
         ...,
         [-0.7421,  0.9251, -0.3218,  ...,  0.2921,  0.1097, -0.5344],
         [-1.3221,  0.8960,  1.1795,  ..., -0.5544, -0.4071,  0.9255],
         [ 1.1209, -0.8919,  1.3737,  ..., -0.1356,  0.3434,  0.4517]]],
       device='cuda:0', grad_fn=<AddBackward0>)

In [29]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)
    
    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Tokens: torch.Size([2, 4])
Embeddings: torch.Size([2, 4, 768])
Tokens: torch.Size([2, 4])
pos_embed: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
torch.Size([2, 12, 4, 4]) torch.Size([2, 4, 12, 64])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
torch.Size([2, 12, 4, 4]) torch.Size([2, 4, 12, 64])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
torch.Size([2, 12, 4, 4]) torch.Size([2, 4, 12, 64])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
Normalized_resid_mid: torch.Size([2, 4, 768])
Residual: torch.Size([2, 4, 768])
Normalized: torch.Size([2, 4, 768])
torch.Size([2, 12, 4, 4]) torch.Size([2, 4, 12, 64])
Res

tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0877,  -54.3452,
           -42.3644],
         [-128.0391, -127.9936, -130.7010,  ..., -136.7120, -129.9261,
          -129.3965],
         [-119.8521, -121.0064, -123.8819,  ..., -128.5180, -126.6027,
          -121.9060],
         ...,
         [-112.9815, -112.7749, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6724, -104.4888, -108.7361,  ..., -118.3552, -113.8766,
          -106.3604],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1970, -138.5883,
          -122.3697]]], device='cuda:0', grad_fn=<AddBackward0>)