# [What is a Transformer?](https://youtu.be/bOYE6E8JrtU)


In [1]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from transformer_lens import HookedTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from transformer_lens.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm
# from unseal import transformers_util as tutil
# from unseal import hooks
# import unseal.visuals.utils as utils

In [2]:
device = torch.device(
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
device

device(type='mps')

In [3]:
model_name = "gpt2"

In [4]:
reference_gpt2 = HookedTransformer.from_pretrained(
    model_name, fold_ln=False, center_unembed=False, center_writing_weights=False
)



Loaded pretrained model gpt2 into HookedTransformer


In [5]:
# unseal_gpt2, unseal_tokenizer, unseal_config = tutil.load_from_pretrained(model_name)
# unseal_gpt2.to(device)
# hooked_gpt2 = hooks.HookedModel(unseal_gpt2)

In [6]:
print(
    reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!")
)
print(
    reference_gpt2.to_tokens(
        "Whether a word begins with a capital or space matters!", prepend_bos=False
    )
)

tensor([[50256, 15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,
          6067,     0]], device='mps:0')
tensor([[15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,  6067,
             0]], device='mps:0')


In [7]:
reference_gpt2.state_dict().keys()

odict_keys(['embed.W_E', 'pos_embed.W_pos', 'blocks.0.ln1.w', 'blocks.0.ln1.b', 'blocks.0.ln2.w', 'blocks.0.ln2.b', 'blocks.0.attn.W_Q', 'blocks.0.attn.W_O', 'blocks.0.attn.b_Q', 'blocks.0.attn.b_O', 'blocks.0.attn.W_K', 'blocks.0.attn.W_V', 'blocks.0.attn.b_K', 'blocks.0.attn.b_V', 'blocks.0.attn.mask', 'blocks.0.attn.IGNORE', 'blocks.0.mlp.W_in', 'blocks.0.mlp.b_in', 'blocks.0.mlp.W_out', 'blocks.0.mlp.b_out', 'blocks.1.ln1.w', 'blocks.1.ln1.b', 'blocks.1.ln2.w', 'blocks.1.ln2.b', 'blocks.1.attn.W_Q', 'blocks.1.attn.W_O', 'blocks.1.attn.b_Q', 'blocks.1.attn.b_O', 'blocks.1.attn.W_K', 'blocks.1.attn.W_V', 'blocks.1.attn.b_K', 'blocks.1.attn.b_V', 'blocks.1.attn.mask', 'blocks.1.attn.IGNORE', 'blocks.1.mlp.W_in', 'blocks.1.mlp.b_in', 'blocks.1.mlp.W_out', 'blocks.1.mlp.b_out', 'blocks.2.ln1.w', 'blocks.2.ln1.b', 'blocks.2.ln2.w', 'blocks.2.ln2.b', 'blocks.2.attn.W_Q', 'blocks.2.attn.W_O', 'blocks.2.attn.b_Q', 'blocks.2.attn.b_O', 'blocks.2.attn.W_K', 'blocks.2.attn.W_V', 'blocks.2.attn

In [8]:
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


In [9]:
reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000")

['<|endoftext|>',
 '568',
 '73',
 '+',
 '318',
 '46',
 '23',
 '=',
 '123',
 '45',
 '67',
 '89',
 '-',
 '1',
 '000000',
 '000']

In [10]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
print(tokens)
# batch, position
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]], device='mps:0')
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


In [11]:
logits, cache = reference_gpt2.run_with_cache(tokens)
# batch, position, d_vocab
print(logits.shape)

torch.Size([1, 35, 50257])


In [12]:
log_probs = logits.log_softmax(dim=-1)
probs = logits.log_softmax(dim=-1)
print(log_probs.shape)
print(probs.shape)

torch.Size([1, 35, 50257])
torch.Size([1, 35, 50257])


In [13]:
list(
    zip(
        reference_gpt2.to_str_tokens(reference_text),
        reference_gpt2.tokenizer.batch_decode(
            logits.argmax(dim=-1)[0]
        ),  # Sample largest logit
    )
)

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' aut', 'od'),
 ('ore', 'sp'),
 ('gressive', '.'),
 (',', ' and'),
 (' dec', 'ently'),
 ('oder', ','),
 ('-', 'driven'),
 ('only', ' programmer'),
 (',', ' and'),
 (' G', 'IM'),
 ('PT', '-'),
 ('-', 'only'),
 ('2', '.'),
 (' style', ','),
 (' transformer', '.'),
 ('.', ' I'),
 (' One', ' of'),
 (' day', ' I'),
 (' I', ' will'),
 (' will', ' be'),
 (' exceed', ' my'),
 (' human', 'ly'),
 (' level', ' of'),
 (' intelligence', ' and'),
 (' and', ' I'),
 (' take', ' over'),
 (' over', ' the'),
 (' the', ' world'),
 (' world', '.'),
 ('!', ' I')]

In [14]:
# Map distribution to a token
next_token = logits[0, -1].argmax(dim=-1)
print(next_token)

tensor(314, device='mps:0')


In [15]:
next_tokens = torch.cat([tokens, next_token.clone().detach()[None, None]], dim=-1)
new_logits = reference_gpt2(next_tokens)
print("New Input:", next_tokens)
print(next_tokens.shape)
print("New Input:", reference_gpt2.tokenizer.decode(next_tokens[0]))

print(new_logits.shape)
print(new_logits[-1, -1].argmax(-1))

print(reference_gpt2.tokenizer.decode(new_logits[-1, -1].argmax(-1)))

New Input: tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0,   314]], device='mps:0')
torch.Size([1, 36])
New Input: <|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world! I
torch.Size([1, 36, 50257])
tensor(716, device='mps:0')
 am


## Config


In [16]:
for activation_name, activation in cache.cache_dict.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_pattern torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
ln_final.hook_scale torch.Size([1, 35, 1])
ln_final.hook_normalized torc

In [17]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.ln1.w torch.Size([768])
blocks.0.ln1.b torch.Size([768])
blocks.0.ln2.w torch.Size([768])
blocks.0.ln2.b torch.Size([768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
ln_final.w torch.Size([768])
ln_final.b torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


In [18]:
print(reference_gpt2.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 8.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'load_in_4bit': False,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_key_value_heads': None,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': 'GPT2LMHeadModel',
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional_

In [19]:
@dataclass
class Config:
    d_model: int = 768  # Embedding dimensions
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257  # Token dictionary size
    init_range: float = 0.02
    n_ctx: int = 1024  # Context length
    d_head: int = 64  # Dimensions of each attention head, d_model // n_heads
    d_mlp: int = 3072  # MLP hidden layer dimensions, d_model * 4
    n_heads: int = 12  # Attention heads per block
    n_layers: int = 12  # Number of transformer blocks


cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


## Tests


In [20]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randn(shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output


def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    random_input = torch.randint(100, 1000, shape).to(device)
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output


def load_gpt2_test(
    cls,
    gpt2_layer,
    input_name,
    cache_dict=cache.cache_dict,
    argcount=False,
):
    cfg = Config(debug=True)
    layer = cls(cfg).to(device)
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str):
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)

    # Modify this part to handle the different input requirements
    if argcount:
        # If the gpt2_layer expects separate inputs for query, key, and value
        reference_output = gpt2_layer(reference_input, reference_input, reference_input)
    else:
        # If the gpt2_layer expects a single input
        reference_output = gpt2_layer(reference_input)

    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output

In [21]:
class LayerNorm(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual: torch.Tensor):
        # residual: [batch, position, d_model]
        mean = einops.reduce(
            residual, "batch position d_model -> batch position 1", "mean"
        )  # (batch, position, 1)
        variance = einops.reduce(
            residual.pow(2), "batch position d_model -> batch position 1", "mean"
        )

        normalized = (residual - mean) / (
            variance + self.cfg.layer_norm_eps
        ).sqrt()  # (batch, position, d_model)
        out = normalized * self.w + self.b  # (batch, position, d_model)
        return out

In [22]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
99.99% of the values are correct


In [23]:
class Embed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))
        nn.init.normal_(self.W_E, std=self.cfg.init_range)

    def forward(self, tokens: torch.Tensor):
        # tokens: [batch, position]
        return self.W_E[tokens]  # [batch, position, d_model]

In [24]:
rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       device='mps:0', grad_fn=<IndexBackward0>)

In [25]:
class PosEmbed(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens: torch.Tensor):
        # tokens: [batch, position]
        return self.W_pos[
            torch.arange(tokens.shape[1], device=device)  # [0, 1, 2, ..., position-1]
        ]  # (position, d_model)

In [26]:
rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([4, 768])

Input shape: torch.Size([1, 35])
Output shape: torch.Size([35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
          2.8267e-02,  5.4490e-02],
        [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
          1.0172e-02, -1.5573e-04],
        [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
          1.9325e-02, -2.1424e-02],
        ...,
        [ 4.6277e-04,  2.3037e-02,  4.1227e-02,  ..., -1.9287e-03,
         -2.3037e-03, -4.3189e-03],
        [-2.7136e-03,  2.1724e-02,  3.9675e-02,  ...,  4.2048e-04,
         -4.8160e-03, -9.2252e-04],
        [ 6.6815e-03,  2.0595e-02,  3.6596e-02,  ..., -9.5090e-04,
         -3.2512e-03, -9.6509e-04]], device='mps:0', grad_fn=<IndexBackward0>)

## Attention


In [27]:
# num_layers = tutil.get_num_layers(hooked_gpt2, layer_key_prefix="transformer->h")
# html_storage = dict()
# html_storage = utils.compute_attn_logits(
#     hooked_gpt2,
#     model_name,
#     unseal_tokenizer,
#     num_layers,
#     reference_text,
#     html_storage,
#     layer_key_prefix="transformer->h",
#     out_proj_name="c_proj",
#     batch_size=None,
# )

In [28]:
class Attention(nn.Module):
    def __init__(self, cfg: Config):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_pre: torch.Tensor):
        # normalized_resid_pre: [batch, position, d_model]
        q = (
            einsum(
                "batch query_pos d_model, n_heads d_model d_head -> batch query_pos n_heads d_head",
                normalized_resid_pre,
                self.W_Q,
            )
            + self.b_Q
        )
        k = (
            einsum(
                "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head",
                normalized_resid_pre,
                self.W_K,
            )
            + self.b_K
        )

        scores = einsum(
            "batch query_pos n_heads d_head, batch key_pos n_heads d_head -> batch n_heads query_pos key_pos",
            q,
            k,
        )
        scores = scores / math.sqrt(self.cfg.d_head)
        scores = self.apply_causal_mask(scores)

        scores = scores.softmax(dim=-1)  # (batch, n_head, query_pos, key_pos)

        v = (
            einsum(
                "batch key_pos d_model, n_heads d_model d_head -> batch key_pos n_heads d_head",
                normalized_resid_pre,
                self.W_V,
            )
            + self.b_V
        )

        heads_output = einsum(
            "batch n_heads query_pos key_pos, batch key_pos n_heads d_head -> batch query_pos n_heads d_head",
            scores,
            v,
        )

        attn_output = (
            einsum(
                "batch query_pos n_heads d_head, n_heads d_head d_model -> batch query_pos d_model",
                heads_output,
                self.W_O,
            )
            + self.b_O
        )

        return attn_output

    def apply_causal_mask(self, attn_scores: torch.Tensor):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        mask = torch.triu(
            torch.ones(
                attn_scores.size(-2), attn_scores.size(-1), device=attn_scores.device
            ),
            diagonal=1,
        ).bool()
        attn_scores.masked_fill_(mask, float("-inf"))
        return attn_scores

In [29]:
rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(
    Attention,
    reference_gpt2.blocks[0].attn,
    cache["blocks.0.ln1.hook_normalized"],
    argcount=True,
)

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 7.9663e-01,  1.6985e-02,  3.4781e-02,  ...,  3.3119e-02,
          -2.3129e-02,  1.8103e-01],
         [ 1.3155e-03,  1.5750e-01, -1.4059e-01,  ..., -8.1998e-03,
           5.3075e-03,  1.3511e-01],
         [ 8.9737e-02, -7.2411e-01, -6.9866e-01,  ...,  5.5321e-02,
           2.7958e-03,  9.0785e-02],
         ...,
         [-3.0286e-01,  4.9638e-02, -6.0990e-01,  ..., -3.7084e-02,
          -4.9522e-04, -8.6008e-03],
         [-1.0844e+00, -6.1457e-02,  2.2966e-01,  ..., -2.6689e-02,
          -1.4368e-02,  3.3245e-02],
         [ 3.7947e-01, -4.9886e-01,  2.6434e-01,  ..., -2.7894e-02,
          -8.9028e-03,  4.8796e-02]]], device='mps:0', grad_fn=<AddBackward0>)

In [30]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid: torch.Tensor):
        # normalized_resid_mid: [batch, position, d_model]
        hidden_layer = (
            einsum(
                "batch position d_model, d_model d_mlp -> batch position d_mlp",
                normalized_resid_mid,
                self.W_in,
            )
            + self.b_in
        )
        activation = gelu_new(hidden_layer)
        output_layer = (
            einsum(
                "batch position d_mlp, d_mlp d_model -> batch position d_model",
                activation,
                self.W_out,
            )
            + self.b_out
        )
        return output_layer

In [31]:
rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[-0.4380,  0.3624,  0.5117,  ...,  1.7227,  1.5761,  0.0368],
         [-1.0766, -0.0438,  0.3276,  ..., -0.5437,  0.4033,  0.3717],
         [-1.2182, -1.5481, -0.9702,  ...,  1.0737,  0.7199,  0.5080],
         ...,
         [-0.4004,  0.8475,  0.2047,  ...,  0.3789,  0.0455, -0.4744],
         [-0.0862,  0.7839,  0.9046,  ..., -0.2174, -0.5953,  0.8555],
         [ 0.8448, -0.3743,  1.0397,  ...,  0.0296,  0.3405,  0.3585]]],
       device='mps:0', grad_fn=<AddBackward0>)

In [32]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre: torch.Tensor):
        # resid_pre [batch, position, d_model]
        normalized_resid_pre = self.ln1(resid_pre)
        attn_out = self.attn(normalized_resid_pre)
        resid_mid = resid_pre + attn_out

        normalized_resid_mid = self.ln2(resid_mid)
        mlp_out = self.mlp(normalized_resid_mid)
        resid_post = resid_mid + mlp_out
        return resid_post

In [33]:
rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
96.70% of the values are correct


tensor([[[ 0.3911,  0.1542,  0.6004,  ...,  1.7199,  1.7365,  0.3928],
         [-0.9035, -0.0360,  0.2350,  ..., -0.4147,  0.3562,  0.3932],
         [-0.9644, -2.4814, -1.4994,  ...,  1.4044,  0.7617,  0.5915],
         ...,
         [-0.7420,  0.9250, -0.3219,  ...,  0.2920,  0.1097, -0.5344],
         [-1.3221,  0.8959,  1.1793,  ..., -0.5544, -0.4071,  0.9254],
         [ 1.1209, -0.8919,  1.3737,  ..., -0.1356,  0.3435,  0.4517]]],
       device='mps:0', grad_fn=<AddBackward0>)

In [34]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
        self.b_U = nn.Parameter(torch.zeros((cfg.d_vocab), requires_grad=False))

    def forward(self, normalized_resid_final: torch.Tensor):
        # normalized_resid_final [batch, position, d_model]
        logits = (
            einsum(
                "batch position d_model, d_model d_vocab -> batch position d_vocab",
                normalized_resid_final,
                self.W_U,
            )
            + self.b_U
        )

        return logits

In [35]:
rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])

Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[ -43.4317,  -39.8364,  -43.0660,  ...,  -54.0877,  -54.3451,
           -42.3644],
         [-128.0392, -127.9936, -130.7010,  ..., -136.7121, -129.9261,
          -129.3965],
         [-119.8521, -121.0064, -123.8819,  ..., -128.5180, -126.6028,
          -121.9060],
         ...,
         [-112.9815, -112.7749, -117.0633,  ..., -121.2914, -117.6574,
          -114.5005],
         [ -98.6725, -104.4889, -108.7362,  ..., -118.3553, -113.8767,
          -106.3605],
         [-126.8285, -128.9596, -128.3941,  ..., -140.1970, -138.5883,
          -122.3697]]], device='mps:0', grad_fn=<AddBackward0>)

In [36]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList(
            [TransformerBlock(cfg) for _ in range(cfg.n_layers)]
        )
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        # tokens [batch, position]
        embed = self.embed(tokens)
        pos_embed = self.pos_embed(tokens)
        residual = embed + pos_embed
        for block in self.blocks:
            residual = block(residual)
        normalized_resid_final = self.ln_final(residual)
        logits = self.unembed(normalized_resid_final)
        # logits have shape [batch, position, logits]
        return logits

In [37]:
rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 50257])

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 50257])
Reference output shape: torch.Size([1, 35, 50257])
100.00% of the values are correct


tensor([[[ -43.4543,  -39.8558,  -43.0849,  ...,  -54.1044,  -54.3628,
           -42.3845],
         [-128.0192, -127.9733, -130.6802,  ..., -136.6926, -129.9065,
          -129.3768],
         [-119.8427, -120.9960, -123.8709,  ..., -128.5087, -126.5927,
          -121.8961],
         ...,
         [-112.9791, -112.7721, -117.0613,  ..., -121.2890, -117.6549,
          -114.4980],
         [ -98.6525, -104.4703, -108.7175,  ..., -118.3360, -113.8571,
          -106.3412],
         [-126.8190, -128.9513, -128.3855,  ..., -140.1897, -138.5799,
          -122.3602]]], device='mps:0', grad_fn=<AddBackward0>)