GPT2 implementation [WIP] based on [Neel Nanda's Clean Transformer Template](https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/clean-transformer-demo/Clean_Transformer_Demo_Template.ipynb)

In [1]:
from easy_transformer import EasyTransformer
from dataclasses import dataclass
import torch
from torch import nn
import einops

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
try:
  import google.colab
  IN_COLAB = True
except:
  IN_COLAB = False

In [3]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
if IN_COLAB:
    tokens = tokens.cuda()
logits, cache = reference_gpt2.run_with_cache(tokens)

Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


In [4]:
@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()

In [5]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    if IN_COLAB:
        layer = layer.cuda()
    random_input = torch.randn(shape)
    if IN_COLAB:
        random_input = random_input.cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg)
    if IN_COLAB:
        layer = layer.cuda()
    random_input = torch.randint(100, 1000, shape)
    if IN_COLAB:
        random_input = random_input.cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg)
    if IN_COLAB:
        layer = layer.cuda()
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str):
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)
    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

In [6]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        mean = einops.reduce(residual, 'b p d -> b p', 'mean')
        broadcast_mean = einops.repeat(mean,'b p -> b p d', d=cfg.d_model)
        print(residual.shape)
        print(broadcast_mean.shape)
        residual -= broadcast_mean
        std_dev = torch.sqrt(einops.reduce(residual ** 2, 'b p d -> b p', 'mean') + cfg.layer_norm_eps)
        broadcast_std_dev = einops.repeat(std_dev, 'b p -> b p d', d=cfg.d_model)
        normalized = residual / broadcast_std_dev
        return normalized * self.w + self.b


In [7]:
rand_float_test(LayerNorm, [2, 4, 768])
load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
torch.Size([1, 35, 768])
torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.0667,  0.0881, -0.3085,  ...,  0.0307,  0.0512, -0.0019],
         [ 0.0278, -0.2843,  0.2504,  ...,  0.0993,  0.0567,  0.0519],
         [-0.5468, -0.5119, -0.6429,  ...,  0.2615, -0.1498,  0.1759],
         ...,
         [ 0.3988, -0.1717,  0.0907,  ..., -0.0095,  0.5077,  0.1327],
         [-0.0164, -0.3170, -1.5848,  ..., -0.0970,  0.3219,  0.1132],
         [ 0.2771, -0.4338, -0.2735,  ...,  0.1036, -0.1892,  0.0365]]],
       grad_fn=<AddBackward0>)

In [8]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens):
        # tokens: [batch, position]
        one_hot_tokens = nn.functional.one_hot(tokens, num_classes = cfg.d_vocab).float()
        # one_hot_tokens = einops.rearrange(one_hot_tokens, 'b p v -> (b p) v')
        print(one_hot_tokens.shape)
        print(self.W_E.shape)
        print(one_hot_tokens)
        print(self.W_E)
        return torch.matmul(one_hot_tokens, self.W_E)

rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
torch.Size([2, 4, 50257])
torch.Size([50257, 768])
tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]])
Parameter containing:
tensor([[-0.0062,  0.0263,  0.0154,  ..., -0.0125, -0.0275, -0.0333],
        [-0.0087,  0.0454,  0.0400,  ..., -0.0160,  0.0122, -0.0161],
        [-0.0011, -0.0049,  0.0183,  ...,  0.0088,  0.0225,  0.0035],
        ...,
        [ 0.0214, -0.0059, -0.0252,  ..., -0.0079,  0.0016, -0.0199],
        [ 0.0416, -0.0020,  0.0497,  ..., -0.0060, -0.0022,  0.0104],
        [ 0.0101,  0.0414, -0.0059,  ...,  0.0102, -0.0337, -0.0387]],
       requires_grad=True)
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
torch.Size([1, 35, 50257])
tor

tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       grad_fn=<UnsafeViewBackward0>)