[WIP] GPT2 implementation based on [Neel Nanda's Clean Transformer Video Tutorial](https://www.youtube.com/watch?v=bOYE6E8JrtU&list=PL7m7hLIqA0hoIUPhC26ASCVs_VrqcDpAz&index=2&ab_channel=NeelNanda) and Template.

In [20]:
from dataclasses import dataclass
import torch
from torch import nn
import einops
import unittest
from fancy_einsum import einsum
import math
from easy_transformer import EasyTransformer

In [21]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

In [22]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    max_context: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()

In [23]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        mean = einops.reduce(residual, 'b p d -> b p', 'mean')
        broadcast_mean = einops.repeat(mean,'b p -> b p d', d=cfg.d_model)
        residual -= broadcast_mean
        std_dev = torch.sqrt(einops.reduce(residual ** 2, 'b p d -> b p', 'mean') + cfg.layer_norm_eps)
        broadcast_std_dev = einops.repeat(std_dev, 'b p -> b p d', d=cfg.d_model)
        normalized = residual / broadcast_std_dev
        return normalized * self.w + self.b


In [24]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        one_hot_tokens = nn.functional.one_hot(tokens, num_classes = cfg.d_vocab).float()
        return torch.matmul(one_hot_tokens, self.W_E)

In [25]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.max_context, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        batch_size, max_tokens = tokens.shape
        truncuated_W_pos = self.W_pos[:max_tokens, :]
        return torch.broadcast_to(truncuated_W_pos, (batch_size, max_tokens, self.cfg.d_model))

In [73]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        if IN_COLAB:
            self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cuda"))
        else:
            self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cpu"))
            
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        
        queries = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_Q) + self.b_Q

        keys = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_K) + self.b_K

        values = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_V) + self.b_V

        attention_scores = einsum(
            'batch query_position n_heads d_head, batch key_position n_heads d_head -> batch n_heads query_position key_position',
            queries,
            keys)
        attention_scores = attention_scores / math.sqrt(cfg.d_head)

        mask = torch.triu(torch.ones(attention_scores.shape[-2], attention_scores.shape[-1]), diagonal=1).bool()
        attention_scores.masked_fill_(mask, self.IGNORE)
        prob_dist = torch.softmax(attention_scores, dim=3)

        sum_after_attention = einsum(
            'batch key_position n_heads d_head, batch n_heads query_position key_position -> batch n_heads query_position d_head',
            values,
            prob_dist)
        print(sum_after_attention.sum())
        out_per_heads = einsum(
            'batch n_heads query_position d_head, n_heads d_head d_model -> batch n_heads query_position d_model',
            sum_after_attention,
            self.W_O) + self.b_O
        print('aaaaaaaaaa')
        print(out_per_heads.sum( dim=1).sum())
        print(out_per_heads.sum())

        out_per_heads = einsum(
            'batch n_heads query_position d_head, n_heads d_head d_model -> batch query_position d_model',
            sum_after_attention,
            self.W_O) + self.b_O
        print('bbbbb')
        print(out_per_heads.sum())
        
        return out_per_heads

In [151]:
a = torch.rand(2,3)

In [152]:
b = torch.rand(2,3)

In [155]:
einsum('i j, i j -> i j', a, b)

tensor([[0.1533, 0.0137, 0.7793],
        [0.3396, 0.0423, 0.4133]])

In [156]:
a

tensor([[0.8646, 0.1245, 0.8098],
        [0.5283, 0.5886, 0.4926]])

In [157]:
b

tensor([[0.1774, 0.1100, 0.9624],
        [0.6429, 0.0719, 0.8390]])

## Tests

In [149]:
class Tests(unittest.TestCase):
    def setUp(self):
        self.reference_gpt2 = self.get_reference_gpt2()
        reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
        self.tokens = self.reference_gpt2.to_tokens(reference_text)
        if IN_COLAB:
            self.tokens = self.tokens.cuda()
        self.cache_dict = self.get_gpt2_cache_dict(self.tokens)
        self.cfg = Config(debug=True)
        
    def get_reference_gpt2(self):
        return EasyTransformer.from_pretrained(
            "gpt2-small",
            fold_ln=False,
            center_unembed=False,
            center_writing_weights=False)

    def get_gpt2_cache_dict(self, tokens):    
        _, cache = self.reference_gpt2.run_with_cache(tokens)
        return cache.cache_dict

    def rand_float_test(self, cls, shape):
        layer = cls(self.cfg)
        if IN_COLAB:
            layer = layer.cuda()
        random_input = torch.randn(shape)
        if IN_COLAB:
            random_input = random_input.cuda()
        output = layer(random_input)
        return output

    def rand_int_test(self, cls, shape):
        layer = cls(self.cfg)
        if IN_COLAB:
            layer = layer.cuda()
        random_input = torch.randint(100, 1000, shape)
        if IN_COLAB:
            random_input = random_input.cuda()
        output = layer(random_input)
        return output

    def load_gpt2_test(self, cls, gpt2_layer, input_name):
        layer = cls(cfg)
        if IN_COLAB:
            layer = layer.cuda()
        layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
        # Allow inputs of strings or tensors
        if isinstance(input_name, str):
            reference_input = self.cache_dict[input_name]
        else:
            reference_input = input_name
        output = layer(reference_input)
        reference_output = gpt2_layer(reference_input)
        comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
        correct_ratio = comparison.sum()/comparison.numel()
        self.assertEqual(correct_ratio, 1, f'{torch.round(correct_ratio * 100)}% of values are correct')
        return output

    def test_layer_norm(self):
        self.rand_float_test(LayerNorm, [2, 4, 768])
        self.load_gpt2_test(LayerNorm, self.reference_gpt2.ln_final, "blocks.11.hook_resid_post")

    def test_embed(self):
        self.rand_int_test(Embed, [2, 4])
        self.load_gpt2_test(Embed, self.reference_gpt2.embed, self.tokens)

    def test_pos_embed(self):
        self.rand_int_test(PosEmbed, [2, 4])
        self.load_gpt2_test(PosEmbed, self.reference_gpt2.pos_embed, self.tokens)

    def test_attention(self):
        # self.rand_float_test(Attention, [2, 4, 768])
        self.load_gpt2_test(Attention, self.reference_gpt2.blocks[0].attn, "blocks.0.ln1.hook_normalized")

In [150]:
suite = unittest.TestSuite()
# suite.addTest(Tests('test_layer_norm'))
# suite.addTest(Tests('test_embed'))
# suite.addTest(Tests('test_pos_embed'))
suite.addTest(Tests('test_attention'))
runner = unittest.TextTestRunner()
runner.run(suite)

Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


.
----------------------------------------------------------------------
Ran 1 test in 3.058s

OK


tensor(157.6061, grad_fn=<SumBackward0>)
aaaaaaaaaa
tensor(-1716.8521, grad_fn=<SumBackward0>)
tensor(-1716.8517, grad_fn=<SumBackward0>)
bbbbb
tensor(326.3628, grad_fn=<SumBackward0>)


<unittest.runner.TextTestResult run=1 errors=0 failures=0>

In [125]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))
        
        if IN_COLAB:
            self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cuda"))
        else:
            self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cpu"))
            
    
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        if self.cfg.debug: print("Normalized_resid_pre:", normalized_resid_pre.shape)
        q = einsum("batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head", normalized_resid_pre, self.W_Q) + self.b_Q
        k = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_K) + self.b_K
        
        attn_scores = einsum("batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos", q, k)
        attn_scores = attn_scores / math.sqrt(self.cfg.d_head)

        attn_scores = self.apply_causal_mask(attn_scores)

        pattern = attn_scores.softmax(dim=-1) # [batch, n_head, query_pos, key_pos]

        
        v = einsum("batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head", normalized_resid_pre, self.W_V) + self.b_V
        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", pattern, v)
        print(z.sum())
        attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_O) + self.b_O
        print(attn_out)
        return attn_out

    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(torch.ones(attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device), diagonal=1).bool()
        attn_scores.masked_fill_(mask, self.IGNORE)
        return attn_scores

In [126]:
suite = unittest.TestSuite()
# suite.addTest(Tests('test_layer_norm'))
# suite.addTest(Tests('test_embed'))
# suite.addTest(Tests('test_pos_embed'))
suite.addTest(Tests('test_attention'))
runner = unittest.TextTestRunner()
runner.run(suite)

Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


.
----------------------------------------------------------------------
Ran 1 test in 3.229s

OK


Normalized_resid_pre: torch.Size([1, 35, 768])
tensor(157.6061, grad_fn=<SumBackward0>)
tensor([[[ 7.9663e-01,  1.6985e-02,  3.4781e-02,  ...,  3.3120e-02,
          -2.3129e-02,  1.8103e-01],
         [ 1.3167e-03,  1.5750e-01, -1.4059e-01,  ..., -8.1997e-03,
           5.3075e-03,  1.3511e-01],
         [ 8.9738e-02, -7.2411e-01, -6.9866e-01,  ...,  5.5321e-02,
           2.7958e-03,  9.0785e-02],
         ...,
         [-3.0286e-01,  4.9638e-02, -6.0990e-01,  ..., -3.7084e-02,
          -4.9524e-04, -8.6008e-03],
         [-1.0844e+00, -6.1457e-02,  2.2966e-01,  ..., -2.6688e-02,
          -1.4368e-02,  3.3245e-02],
         [ 3.7947e-01, -4.9886e-01,  2.6434e-01,  ..., -2.7894e-02,
          -8.9028e-03,  4.8796e-02]]], grad_fn=<AddBackward0>)


<unittest.runner.TextTestResult run=1 errors=0 failures=0>