[WIP] GPT-2 implementation based on [Neel Nanda's Clean Transformer Video Tutorial](https://www.youtube.com/watch?v=bOYE6E8JrtU&list=PL7m7hLIqA0hoIUPhC26ASCVs_VrqcDpAz&index=2&ab_channel=NeelNanda) and its accompanying template.

In [1]:
from dataclasses import dataclass
import torch
from torch import nn
import einops
import unittest
from fancy_einsum import einsum
import math
from easy_transformer import EasyTransformer
from easy_transformer.utils import gelu_new

In [2]:
@dataclass
class Config:
    """Model hyperparameters; the defaults match GPT-2 small (loaded later via from_pretrained)."""

    d_model: int = 768            # width of the residual stream
    debug: bool = True            # not read by any module shown in this notebook
    layer_norm_eps: float = 1e-5  # added to the variance inside LayerNorm's sqrt
    d_vocab: int = 50_257         # number of token ids (rows of W_E / columns of W_U)
    init_range: float = 0.02      # std of the normal initialisation for weight matrices
    max_context: int = 1_024      # number of learned positional embeddings
    d_head: int = 64              # per-head query/key/value width
    d_mlp: int = 3_072            # hidden width of the MLP
    n_heads: int = 12             # attention heads per block
    n_layers: int = 12            # number of transformer blocks

In [3]:
class LayerNorm(nn.Module):
    """LayerNorm over the last (d_model) dimension with a learned scale `w` and shift `b`."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        """Normalise `residual` [..., d_model] to zero mean / unit variance, then scale and shift.

        Accepts any number of leading dimensions (e.g. [batch, position, d_model]).
        The input tensor is never modified.
        """
        # Must stay out-of-place (this answers the old TODO): `residual -= mean` would
        # modify the caller's tensor. In TransformerBlock that tensor is `resid_pre`,
        # which is reused for the skip connection `resid_pre + attn`, so an in-place
        # edit silently corrupts the residual stream.
        centered = residual - residual.mean(dim=-1, keepdim=True)
        # Biased (population) variance, with eps inside the sqrt for numerical stability.
        std_dev = torch.sqrt(centered.pow(2).mean(dim=-1, keepdim=True) + self.cfg.layer_norm_eps)
        normalized = centered / std_dev
        return normalized * self.w + self.b

In [4]:
class Embed(nn.Module):
    """Token embedding: maps integer token ids to rows of `W_E` [d_vocab, d_model]."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        """tokens: integer ids [batch, position] -> embeddings [batch, position, d_model]."""
        # Direct row lookup; mathematically identical to one_hot(tokens) @ W_E, but it
        # avoids materialising a [batch, position, d_vocab] float tensor (d_vocab floats
        # per token) and a full matmul. F.embedding still rejects out-of-range ids.
        return nn.functional.embedding(tokens, self.W_E)

In [5]:
class PosEmbed(nn.Module):
    """Learned absolute positional embedding: position p gets row p of `W_pos`."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.max_context, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        """tokens: [batch, position] -> positional embeddings [batch, position, d_model].

        Only the shape of `tokens` is used, never the token ids themselves.
        """
        n_batch, n_positions = tokens.shape
        rows_in_use = self.W_pos[:n_positions, :]
        # expand() gives a broadcast view over the batch dimension (no copy).
        return rows_in_use.unsqueeze(0).expand(n_batch, n_positions, self.cfg.d_model)

In [6]:
class Attention(nn.Module):
    """Causal multi-head self-attention, GPT-2 style.

    Per-head weights: W_Q / W_K / W_V are [n_heads, d_model, d_head] and W_O is
    [n_heads, d_head, d_model]. Attribute names match the reference model so its
    state_dict can be loaded into this module.
    """

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        # Parameters to calculate queries
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        # Parameters to calculate keys
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        # Parameters to calculate values
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        # Parameters to combine head outputs
        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        # Score written into masked (future) positions before the softmax; exp() of it
        # underflows to ~0. No device is pinned here: as a buffer it follows the module
        # through .to()/.cuda(), so the layer also works on CPU.
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        """normalized_resid_pre: [batch, position, d_model] -> [batch, position, d_model]."""
        queries = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_Q) + self.b_Q

        keys = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_K) + self.b_K

        values = einsum(
            'batch position d_model, n_heads d_model d_head -> batch position n_heads d_head',
            normalized_resid_pre, self.W_V) + self.b_V

        prob_dist = self._get_attention(queries, keys)

        # Per head, mix the value vectors of the attended key positions.
        sum_after_attention = einsum(
            'batch key_position n_heads d_head, batch n_heads query_position key_position -> batch n_heads query_position d_head',
            values,
            prob_dist)

        # Project each head back to d_model and sum over heads.
        out = einsum(
            'batch n_heads query_position d_head, n_heads d_head d_model -> batch query_position d_model',
            sum_after_attention,
            self.W_O) + self.b_O

        return out

    def _get_attention(self, queries, keys):
        """Return causal attention probabilities [batch, n_heads, query_position, key_position]."""
        attention_scores = einsum(
            'batch query_position n_heads d_head, batch key_position n_heads d_head -> batch n_heads query_position key_position',
            queries,
            keys)
        attention_scores = attention_scores / math.sqrt(self.cfg.d_head)
        n_queries, n_keys = attention_scores.shape[-2:]
        # True strictly above the diagonal, i.e. where key_position > query_position.
        # Built on the scores' own device rather than hard-coding .cuda().
        future_mask = torch.triu(
            torch.ones(n_queries, n_keys, dtype=torch.bool, device=attention_scores.device),
            diagonal=1)
        attention_scores = attention_scores.masked_fill(future_mask, self.IGNORE)
        return torch.softmax(attention_scores, dim=-1)

In [7]:
class MLP(nn.Module):
    """Position-wise feed-forward layer: d_model -> d_mlp -> gelu_new -> d_model."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        # Up-projection into the hidden layer.
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        # Down-projection back to the residual stream.
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        """normalized_resid_mid: [batch, position, d_model] -> [batch, position, d_model]."""
        pre_activation = einsum(
            'batch position d_model, d_model d_mlp -> batch position d_mlp',
            normalized_resid_mid,
            self.W_in) + self.b_in
        post_activation = gelu_new(pre_activation)
        return einsum(
            'batch position d_mlp, d_mlp d_model -> batch position d_model',
            post_activation,
            self.W_out) + self.b_out


In [8]:
class TransformerBlock(nn.Module):
    """One pre-LayerNorm block: attention, then MLP, each added back into the residual stream."""

    def __init__(self, cfg):
        super().__init__()
        # Submodule names must match the reference block for load_state_dict to find the weights.
        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        """resid_pre: [batch, position, d_model] -> residual stream after this block."""
        resid_mid = resid_pre + self.attn(self.ln1(resid_pre))
        return resid_mid + self.mlp(self.ln2(resid_mid))

In [9]:
class Unembed(nn.Module):
    """Map the final normalised residual stream [..., d_model] to vocabulary logits [..., d_vocab]."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=cfg.init_range)
        # Unembedding bias. A `requires_grad=False` on the inner tensor has no effect:
        # nn.Parameter defaults to requires_grad=True, so it is omitted here.
        # NOTE(review): the reference test passed while b_U was never applied, so the
        # loaded GPT-2 values are presumably all zero; confirm.
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab))

    def forward(self, normalized_resid_final):
        """normalized_resid_final: [batch, position, d_model] -> logits [batch, position, d_vocab]."""
        # The bias was previously created but never added.
        return normalized_resid_final @ self.W_U + self.b_U

In [10]:
class DemoTransformer(nn.Module):
    """Decoder-only GPT-2-style transformer: token ids [batch, position] -> logits [batch, position, d_vocab]."""

    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(TransformerBlock(cfg) for _ in range(cfg.n_layers))
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        """Embed tokens and positions, run every block, then normalise and unembed."""
        residual = self.embed(tokens) + self.pos_embed(tokens)
        for block in self.blocks:
            residual = block(residual)
        return self.unembed(self.ln_final(residual))

## Tests

In [11]:
class Tests(unittest.TestCase):
    """Check each hand-written module against the reference EasyTransformer GPT-2.

    Every test first runs the module on random input (a smoke/shape check), then loads
    the matching reference module's weights and requires the two outputs to agree on
    real cached GPT-2 activations.
    """

    @classmethod
    def setUpClass(cls):
        cls.reference_gpt2 = cls.get_reference_gpt2()
        # Run everything on whichever device from_pretrained placed the reference model
        # (the GPU when available) instead of hard-coding CUDA.
        cls.device = next(cls.reference_gpt2.parameters()).device
        reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
        cls.tokens = cls.reference_gpt2.to_tokens(reference_text).to(cls.device)
        cls.cache = cls.get_gpt2_cache_dict(cls.tokens)
        cls.cfg = Config(debug=True)

    @classmethod
    def get_reference_gpt2(cls):
        """Load GPT-2 small without weight folding/centering, so raw weights match our modules."""
        return EasyTransformer.from_pretrained(
            "gpt2-small",
            fold_ln=False,
            center_unembed=False,
            center_writing_weights=False)

    @classmethod
    def get_gpt2_cache_dict(cls, tokens):
        """Run the reference model on `tokens` and return its activation cache."""
        _, cache = cls.reference_gpt2.run_with_cache(tokens)
        return cache

    def _run_fresh_layer(self, cls, layer_input):
        """Build `cls` from the test config on the test device and apply it to `layer_input`."""
        layer = cls(self.cfg).to(self.device)
        with torch.no_grad():
            return layer(layer_input.to(self.device))

    def rand_float_test(self, cls, shape):
        """Smoke-test `cls` on standard-normal float input of the given shape."""
        return self._run_fresh_layer(cls, torch.randn(shape))

    def rand_int_test(self, cls, shape):
        """Smoke-test `cls` on random token ids in [100, 1000) of the given shape."""
        return self._run_fresh_layer(cls, torch.randint(100, 1000, shape))

    def load_gpt2_test(self, cls, gpt2_layer, input_name):
        """Load `gpt2_layer`'s weights into a fresh `cls` and require matching outputs.

        `input_name` is either a hook name looked up in the cached reference run, or a
        tensor that is used directly as the input.
        """
        layer = cls(self.cfg).to(self.device)
        # strict=False: the two state_dicts need not have exactly the same key set.
        layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
        # Allow inputs of strings or tensors
        if isinstance(input_name, str):
            reference_input = self.cache.cache_dict[input_name]
        else:
            reference_input = input_name
        reference_input = reference_input.to(self.device)
        with torch.no_grad():
            output = layer(reference_input)
            reference_output = gpt2_layer(reference_input)
        comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
        correct_ratio = comparison.float().mean().item()
        self.assertTrue(comparison.all().item(), f'{correct_ratio:.2%} of values are correct')
        return output

    def test_layer_norm(self):
        self.rand_float_test(LayerNorm, [2, 4, 768])
        self.load_gpt2_test(LayerNorm, self.reference_gpt2.ln_final, "blocks.11.hook_resid_post")

    def test_embed(self):
        self.rand_int_test(Embed, [2, 4])
        self.load_gpt2_test(Embed, self.reference_gpt2.embed, self.tokens)

    def test_pos_embed(self):
        self.rand_int_test(PosEmbed, [2, 4])
        self.load_gpt2_test(PosEmbed, self.reference_gpt2.pos_embed, self.tokens)

    def test_attention(self):
        self.rand_float_test(Attention, [2, 4, 768])
        self.load_gpt2_test(Attention, self.reference_gpt2.blocks[0].attn, "blocks.0.ln1.hook_normalized")

    def test_mlp(self):
        self.rand_float_test(MLP, [2, 4, 768])
        self.load_gpt2_test(MLP, self.reference_gpt2.blocks[0].mlp, "blocks.0.ln2.hook_normalized")

    def test_transformer_block(self):
        self.rand_float_test(TransformerBlock, [2, 4, 768])
        self.load_gpt2_test(TransformerBlock, self.reference_gpt2.blocks[0], self.cache["resid_pre", 0])

    def test_unembed(self):
        self.rand_float_test(Unembed, [2, 4, 768])
        self.load_gpt2_test(Unembed, self.reference_gpt2.unembed, "ln_final.hook_normalized")

    def test_full_transformer(self):
        self.rand_int_test(DemoTransformer, [2, 4])
        self.load_gpt2_test(DemoTransformer, self.reference_gpt2, self.tokens)

In [12]:
# Run the tests in a fixed order: individual components first, the full model last.
test_names = [
    'test_layer_norm',
    'test_embed',
    'test_pos_embed',
    'test_attention',
    'test_mlp',
    'test_transformer_block',
    'test_unembed',
    'test_full_transformer',
]
suite = unittest.TestSuite(Tests(name) for name in test_names)

runner = unittest.TextTestRunner()
runner.run(suite)

Moving model to device:  cuda
Finished loading pretrained model gpt2-small into EasyTransformer!


........
----------------------------------------------------------------------
Ran 8 tests in 7.754s

OK


<unittest.runner.TextTestResult run=8 errors=0 failures=0>

## Try with pretrained weights

In [13]:
demo_gpt2 = DemoTransformer(Config(debug=False))
reference_gpt2 = EasyTransformer.from_pretrained(
            "gpt2-small",
            fold_ln=False,
            center_unembed=False,
            center_writing_weights=False)
# strict=False: the reference state_dict need not have exactly our key set.
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
# Place the demo model wherever from_pretrained put the reference (the GPU when
# available) instead of hard-coding CUDA.
device = next(reference_gpt2.parameters()).device
demo_gpt2.to(device)

test_string = """Mini scule is a species of microhylid frog endemic to Madagascar that was described in 2019. The scientific name of the species refers to its size, being a pun on the word minuscule. It is very small, measuring only 8.4 to 10.8 mm (0.33 to 0.43 in) in snout–vent length. It has bronze underparts with a brown groin and back of the thigh, cream upperparts with brown flecking, a dark brown side of the head, and a red iris. On the hind feet, the first toe is absent and the second and fifth toes are strongly reduced. The frog is known only from the Sainte Luce Reserve, where it inhabits areas with deep leaf litter near semi-permanent water bodies. Specimens of frogs from Mandena, the Vohimena mountains, the southern Anosy Mountains, and Tsitongambarika may also be of this species. Along with Mini mum and Mini ature, the other two species in its genus, it received media attention when first described due to the wordplay in its scientific name. (Full article...)"""
test_tokens = reference_gpt2.to_tokens(test_string).to(device)
# Inference only: don't build an autograd graph through all 12 blocks.
with torch.no_grad():
    demo_logits = demo_gpt2(test_tokens)

Moving model to device:  cuda
Finished loading pretrained model gpt2-small into EasyTransformer!


In [14]:
# logits is [batch, position, vocab_size]. The logits at position i predict token i + 1,
# so drop the last position's logits and the first token as a target.
pred_logits = demo_logits[:, :-1, :]
target_tokens = test_tokens[:, 1:]
# reshape (not view): these slices are non-contiguous whenever batch > 1, and .view
# would raise on them.
loss = torch.nn.functional.cross_entropy(
    pred_logits.reshape(-1, pred_logits.shape[-1]),
    target_tokens.reshape(-1))
print(loss)
# exp(-loss) is the geometric-mean probability assigned to the correct next token.
print("Loss as average prob", (-loss).exp())
print("Loss as 'uniform over this many variables'", (loss).exp())
print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))

tensor(3.7186, device='cuda:0', grad_fn=<NllLossBackward0>)
Loss as average prob tensor(0.0243, device='cuda:0', grad_fn=<ExpBackward0>)
Loss as 'uniform over this many variables' tensor(41.2080, device='cuda:0', grad_fn=<ExpBackward0>)
Uniform loss over the vocab 10.82490511970208
