[WIP] GPT2 implementation based on [Neel Nanda's Clean Transformer Video Tutorial](https://www.youtube.com/watch?v=bOYE6E8JrtU&list=PL7m7hLIqA0hoIUPhC26ASCVs_VrqcDpAz&index=2&ab_channel=NeelNanda) and Template.

In [14]:
import transformer_lens
from dataclasses import dataclass
import torch
from torch import nn
import einops
import unittest
from fancy_einsum import einsum

In [3]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

In [4]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    max_context: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()

In [5]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        mean = einops.reduce(residual, 'b p d -> b p', 'mean')
        broadcast_mean = einops.repeat(mean,'b p -> b p d', d=cfg.d_model)
        residual -= broadcast_mean
        std_dev = torch.sqrt(einops.reduce(residual ** 2, 'b p d -> b p', 'mean') + cfg.layer_norm_eps)
        broadcast_std_dev = einops.repeat(std_dev, 'b p -> b p d', d=cfg.d_model)
        normalized = residual / broadcast_std_dev
        return normalized * self.w + self.b


In [6]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        one_hot_tokens = nn.functional.one_hot(tokens, num_classes = cfg.d_vocab).float()
        return torch.matmul(one_hot_tokens, self.W_E)

In [11]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.max_context, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        batch_size, max_tokens = tokens.shape
        truncuated_W_pos = self.W_pos[:max_tokens, :]
        return torch.broadcast_to(truncuated_W_pos, (batch_size, max_tokens, self.cfg.d_model))

In [None]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cuda"))
    
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        batch_size, max_position, _ = normalized_resid_pre.shape
        broadcast_b_Q = torch.broadcast_to(
            [batch_size, max_position, cfg.n_heads, cfg.d_head],
            self.b_Q)
        queries = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_head d_head',
            normalized_resid_pre, self.W_Q) + broadcast_b_Q

        broadcast_b_K = torch.broadcast_to(
            [batch_size, max_position, cfg.n_heads, cfg.d_head],
            self.b_K)
        keys = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_head d_head',
            normalized_resid_pre, self.W_K) + broadcast_b_Q

    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        "YOUR CODE HERE"



In [12]:
class Tests(unittest.TestCase):
    def setUp(self):
        self.reference_gpt2 = self.get_reference_gpt2()
        reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
        self.tokens = self.reference_gpt2.to_tokens(reference_text)
        if IN_COLAB:
            self.tokens = self.tokens.cuda()
        self.cache_dict = self.get_gpt2_cache_dict(self.tokens)
        self.cfg = Config(debug=True)
        
    def get_reference_gpt2(self):
        return transformer_lens.HookedTransformer.from_pretrained(
            "gpt2-small",
            fold_ln=False,
            center_unembed=False,
            center_writing_weights=False)

    def get_gpt2_cache_dict(self, tokens):    
        _, cache = self.reference_gpt2.run_with_cache(tokens)
        return cache.cache_dict

    def rand_float_test(self, cls, shape):
        layer = cls(self.cfg)
        if IN_COLAB:
            layer = layer.cuda()
        random_input = torch.randn(shape)
        if IN_COLAB:
            random_input = random_input.cuda()
        output = layer(random_input)
        return output

    def rand_int_test(self, cls, shape):
        layer = cls(self.cfg)
        if IN_COLAB:
            layer = layer.cuda()
        random_input = torch.randint(100, 1000, shape)
        if IN_COLAB:
            random_input = random_input.cuda()
        output = layer(random_input)
        return output

    def load_gpt2_test(self, cls, gpt2_layer, input_name):
        layer = cls(cfg)
        if IN_COLAB:
            layer = layer.cuda()
        layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
        # Allow inputs of strings or tensors
        if isinstance(input_name, str):
            reference_input = self.cache_dict[input_name]
        else:
            reference_input = input_name
        output = layer(reference_input)
        reference_output = gpt2_layer(reference_input)
        comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
        correct_ratio = comparison.sum()/comparison.numel()
        self.assertEqual(correct_ratio, 1)
        return output

    def test_layer_norm(self):
        self.rand_float_test(LayerNorm, [2, 4, 768])
        self.load_gpt2_test(LayerNorm, self.reference_gpt2.ln_final, "blocks.11.hook_resid_post")

    def test_embed(self):
        self.rand_int_test(Embed, [2, 4])
        self.load_gpt2_test(Embed, self.reference_gpt2.embed, self.tokens)

    def test_pos_embed(self):
        self.rand_int_test(PosEmbed, [2, 4])
        self.load_gpt2_test(PosEmbed, self.reference_gpt2.pos_embed, self.tokens)

    def test_attention(self):
        self.rand_float_test(Attention, [2, 4, 768])
        self.load_gpt2_test(Attention, self.reference_gpt2.blocks[0].attn, "blocks.0.ln1.hook_normalized")

In [13]:
suite = unittest.TestSuite()
suite.addTest(Tests('test_layer_norm'))
suite.addTest(Tests('test_embed'))
suite.addTest(Tests('test_pos_embed'))
runner = unittest.TextTestRunner()
runner.run(suite)

Using pad_token, but it is not set yet.
.

Loaded pretrained model gpt2-small into HookedTransformer


Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


.Using pad_token, but it is not set yet.
.
----------------------------------------------------------------------
Ran 3 tests in 9.144s

OK


Loaded pretrained model gpt2-small into HookedTransformer


<unittest.runner.TextTestResult run=3 errors=0 failures=0>