[WIP] GPT2 implementation based on [Neel Nanda's Clean Transformer Video Tutorial](https://www.youtube.com/watch?v=bOYE6E8JrtU&list=PL7m7hLIqA0hoIUPhC26ASCVs_VrqcDpAz&index=2&ab_channel=NeelNanda) and Template.

In [20]:
from dataclasses import dataclass
import torch
from torch import nn
import einops
import unittest
from fancy_einsum import einsum
import math
from easy_transformer import EasyTransformer

In [21]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

In [22]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    max_context: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()

In [23]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        mean = einops.reduce(residual, 'b p d -> b p', 'mean')
        broadcast_mean = einops.repeat(mean,'b p -> b p d', d=cfg.d_model)
        residual -= broadcast_mean
        std_dev = torch.sqrt(einops.reduce(residual ** 2, 'b p d -> b p', 'mean') + cfg.layer_norm_eps)
        broadcast_std_dev = einops.repeat(std_dev, 'b p -> b p d', d=cfg.d_model)
        normalized = residual / broadcast_std_dev
        return normalized * self.w + self.b


In [24]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        one_hot_tokens = nn.functional.one_hot(tokens, num_classes = cfg.d_vocab).float()
        return torch.matmul(one_hot_tokens, self.W_E)

In [25]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.max_context, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)
    
    def forward(self, tokens):
        # tokens: [batch, position]
        batch_size, max_tokens = tokens.shape
        truncuated_W_pos = self.W_pos[:max_tokens, :]
        return torch.broadcast_to(truncuated_W_pos, (batch_size, max_tokens, self.cfg.d_model))

In [73]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_query = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_query, std=self.cfg.init_range)
        self.b_query = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_key = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_key, std=self.cfg.init_range)
        self.b_key = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_value = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_value, std=self.cfg.init_range)
        self.b_value = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        
        self.W_out = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

        if IN_COLAB:
            self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cuda"))
        else:
            self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32, device="cpu"))
            
    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]

        queries = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_query) + self.b_query

        keys = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_key) + self.b_key

        values = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_value) + self.b_value
        
        attention_scores = einsum(
            'batch query_position n_heads d_head, batch key_position n_heads d_head -> batch n_heads query_position key_position',
            queries,
            keys)
        attention_scores = attention_scores / math.sqrt(cfg.d_head)
        
        mask = torch.triu(torch.ones(attention_scores.shape[-2], attention_scores.shape[-1]), diagonal=1).bool()
        attention_scores.masked_fill_(mask, self.IGNORE)
        prob_dist = torch.softmax(attention_scores, dim=3)

        z = einsum("batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head", prob_dist, values)

        # attn_out = einsum("batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model", z, self.W_out) + self.b_out
        # print(attn_out)
        # return attn_out

        
        sum_after_attention = einsum(
            'batch key_position n_heads d_head, batch n_heads query_position key_position -> batch n_heads query_position d_head',
            values,
            prob_dist)

        out_per_heads = einsum(
            'batch n_heads query_position d_head, n_heads d_head d_model -> batch n_heads query_position d_model',
            sum_after_attention,
            self.W_out) + self.b_out

        print(normalized_resid_pre + torch.sum(out_per_heads, dim=1))

        return normalized_resid_pre + torch.sum(out_per_heads, dim=1)

## Tests

In [74]:
class Tests(unittest.TestCase):
    def setUp(self):
        self.reference_gpt2 = self.get_reference_gpt2()
        reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
        self.tokens = self.reference_gpt2.to_tokens(reference_text)
        if IN_COLAB:
            self.tokens = self.tokens.cuda()
        self.cache_dict = self.get_gpt2_cache_dict(self.tokens)
        self.cfg = Config(debug=True)
        
    def get_reference_gpt2(self):
        return EasyTransformer.from_pretrained(
            "gpt2-small",
            fold_ln=False,
            center_unembed=False,
            center_writing_weights=False)

    def get_gpt2_cache_dict(self, tokens):    
        _, cache = self.reference_gpt2.run_with_cache(tokens)
        return cache.cache_dict

    def rand_float_test(self, cls, shape):
        layer = cls(self.cfg)
        if IN_COLAB:
            layer = layer.cuda()
        random_input = torch.randn(shape)
        if IN_COLAB:
            random_input = random_input.cuda()
        output = layer(random_input)
        return output

    def rand_int_test(self, cls, shape):
        layer = cls(self.cfg)
        if IN_COLAB:
            layer = layer.cuda()
        random_input = torch.randint(100, 1000, shape)
        if IN_COLAB:
            random_input = random_input.cuda()
        output = layer(random_input)
        return output

    def load_gpt2_test(self, cls, gpt2_layer, input_name):
        layer = cls(cfg)
        if IN_COLAB:
            layer = layer.cuda()
        layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
        # Allow inputs of strings or tensors
        if isinstance(input_name, str):
            reference_input = self.cache_dict[input_name]
        else:
            reference_input = input_name
        print(reference_input.shape)
        output = layer(reference_input)
        print(output.shape)
        reference_output = gpt2_layer(reference_input)
        print(reference_output.shape)
        comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
        correct_ratio = comparison.sum()/comparison.numel()
        self.assertEqual(correct_ratio, 1, f'{torch.round(correct_ratio * 100)}% of values are correct')
        return output

    def test_layer_norm(self):
        self.rand_float_test(LayerNorm, [2, 4, 768])
        self.load_gpt2_test(LayerNorm, self.reference_gpt2.ln_final, "blocks.11.hook_resid_post")

    def test_embed(self):
        self.rand_int_test(Embed, [2, 4])
        self.load_gpt2_test(Embed, self.reference_gpt2.embed, self.tokens)

    def test_pos_embed(self):
        self.rand_int_test(PosEmbed, [2, 4])
        self.load_gpt2_test(PosEmbed, self.reference_gpt2.pos_embed, self.tokens)

    def test_attention(self):
        self.rand_float_test(Attention, [2, 4, 768])
        self.load_gpt2_test(Attention, self.reference_gpt2.blocks[0].attn, "blocks.0.ln1.hook_normalized")

In [75]:
suite = unittest.TestSuite()
# suite.addTest(Tests('test_layer_norm'))
# suite.addTest(Tests('test_embed'))
# suite.addTest(Tests('test_pos_embed'))
suite.addTest(Tests('test_attention'))
runner = unittest.TextTestRunner()
runner.run(suite)

Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


F
FAIL: test_attention (__main__.Tests)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/tmp/ipykernel_8680/1241486984.py", line 76, in test_attention
    self.load_gpt2_test(Attention, self.reference_gpt2.blocks[0].attn, "blocks.0.ln1.hook_normalized")
  File "/tmp/ipykernel_8680/1241486984.py", line 59, in load_gpt2_test
    self.assertEqual(correct_ratio, 1, f'{torch.round(correct_ratio * 100)}% of values are correct')
AssertionError: tensor(0.0007) != 1 : 0.0% of values are correct

----------------------------------------------------------------------
Ran 1 test in 3.956s

FAILED (failures=1)


tensor([[[ 1.4781e+00,  9.4271e-02,  1.1739e+00,  ...,  4.5701e-01,
          -5.3349e-01, -1.5206e+00],
         [ 9.9553e-01,  1.1775e+00, -2.8085e-01,  ...,  1.1971e+00,
          -1.9991e-01, -1.8756e+00],
         [ 8.9338e-02,  1.2168e-04, -1.7418e+00,  ...,  1.2769e+00,
          -4.6312e-01, -1.0442e+00],
         [ 1.5478e+00, -1.9147e+00, -7.7552e-01,  ...,  1.0997e+00,
          -7.1781e-01, -1.0318e+00]],

        [[ 7.6765e-01,  1.6938e-01,  7.3310e-02,  ..., -1.4813e+00,
           2.3008e-01, -2.2283e-01],
         [ 2.3846e+00,  1.6257e+00,  5.1552e-01,  ...,  1.0848e+00,
          -8.0254e-01, -1.6505e-01],
         [ 6.4801e-01,  9.7900e-02,  2.8726e-01,  ..., -5.5491e-01,
          -1.6021e+00, -3.6980e-01],
         [-1.4114e+00,  3.2381e+00,  1.8587e-01,  ..., -2.7156e-01,
          -7.7272e-01,  4.2269e-01]]], grad_fn=<AddBackward0>)
torch.Size([1, 35, 768])
tensor([[[ 0.0184, -0.0613, -0.0540,  ..., -0.0196,  0.1006,  0.0794],
         [ 0.1735, -0.0779, -0.0592,

<unittest.runner.TextTestResult run=1 errors=0 failures=1>