# Introduction

This is a template for a clean, first principles implementation of GPT-2 in PyTorch. This is an accompaniment to [my video tutorial on implementing GPT-2](https://neelnanda.io/transformer-tutorial-2). If you want to properly understand how to implement GPT-2, you'll need to do it yourself! **I recommend filling out this template *as* you watch the video, and seeing how far you can get with each section before watching me do it**. Each section comes with tests, so you can check that you got it right. You can see [a solution notebook here](https://www.neelnanda.io/transformer-solution) - use it if you get stuck, but make an attempt first, and no copying and pasting!

There's a [template version of this notebook here](https://neelnanda.io/transformer-template), go and fill in the blanks (no copying and pasting!) and see if you can pass the tests. 

If you enjoyed this, I expect you'd enjoy learning more about what's actually going on inside these models and how to reverse engineer them! This is a fascinating young research field, with a lot of low-hanging fruit and open problems! **I recommend starting with my post [Concrete Steps for Getting Started in Mechanistic Interpretability](https://www.neelnanda.io/mechanistic-interpretability/getting-started).**

This notebook was written to accompany my [TransformerLens library](https://github.com/neelnanda-io/TransformerLens) for doing mechanistic interpretability research on GPT-2 style language models, and is a clean implementation of the underlying transformer architecture in the library. (This notebook is based off of an earlier version called EasyTransformer)

Further Resources:
* [A Comprehensive Mechanistic Interpretability Explainer & Glossary](https://www.neelnanda.io/glossary) 
    * Especially [the transformers section](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=pndoEIqJ6GPvC1yENQkEfZYR)
* [200 Concrete Open Problems in Mechanistic Interpretability](https://www.neelnanda.io/concrete-open-problems)
* My [TransformerLens library](https://github.com/neelnanda-io/TransformerLens) for doing mechanistic interpretability research on GPT-2 style language models.
* My walkthrough of [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html), for a deeper dive into how to think about transformers.

Check out these other intros to transformers for another perspective:
* Jay Alammar's [illustrated transformer](https://jalammar.github.io/illustrated-transformer/)
* [Andrej Karpathy's MinGPT](https://github.com/karpathy/minGPT)

**Sharing Guidelines:** This tutorial is still a bit of a work in progress! I think it's usable, but please don't post it anywhere publicly without checking with me first! Sharing with friends is fine. 

If you've found this useful, I'd love to hear about it! Both positive and negative feedback are very welcome. You can reach me via [email](mailto:neelnanda27@gmail.com).

## Instructions
* No need to read the Setup Section
* Go to runtime > Change Runtime Type and set it to use a GPU
* Read and run notebook up until the start of the section "Actual Code!". Then go to the template notebook and try coding up the model yourself!
    * Bonus points for doing that without reading the solutions, and before I do it in the video!

# Setup

In [1]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  # Install another version of node that makes PySvelte work way faster
  !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
  %pip install git+https://github.com/neelnanda-io/PySvelte.git
  %pip install fancy_einsum
  %pip install einops
except ImportError:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")

Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git (to revision clean-transformer-demo) to /tmp/pip-req-build-ah2_edaw
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-ah2_edaw
  Running command git checkout -b clean-transformer-demo --track origin/clean-transformer-demo
  Switched to a new branch 'clean-transformer-demo'
  Branch 'clean-transformer-demo' set up to track remote branch 'clean-transformer-demo' from 'origin'.
  Resolved https://github.com/neelnanda-io/Easy-Transformer.git to commit 1f25219e631aeb478d17075d47274db32c874e88
  Preparing metadata (setup.py) ... done
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)

In [2]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm
import matplotlib.pyplot as plt


In [3]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

Moving model to device:  cpu
Finished loading pretrained model gpt2-small into EasyTransformer!


# Understanding Inputs & Outputs of a Transformer

## What is the point of a transformer?

**Transformers exist to model text!**

We're going to focus on GPT-2 style transformers. Key feature: they generate text! You feed in language, and the model generates a probability distribution over tokens. You can repeatedly sample from this to generate text!

### How is the model trained?

You give it a bunch of text, and train it to predict the next token.

Importantly, if you give a model 100 tokens in a sequence, it predicts the next token for *each* prefix, ie it produces 100 predictions. This is kinda weird, but it's much easier to make a model that does this. It also makes training more efficient, because you get 100 bits of feedback rather than just one.
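
To make the "one prediction per prefix" point concrete, here is a minimal sketch in plain Python. The helper name `next_token_pairs` is made up for illustration and is not part of any library. It lists every (prefix, target) pair a single sequence provides. Note that the prediction made at the final position has no target inside the sequence, so n tokens give n - 1 labelled pairs.

```python
def next_token_pairs(tokens):
    """Return (prefix, target) pairs: each prefix is trained to predict the token that follows it."""
    return [(tokens[: i + 1], tokens[i + 1]) for i in range(len(tokens) - 1)]

# Token ids for "<|endoftext|>I am an amazing", taken from the generation example in this notebook.
tokens = [50256, 40, 716, 281, 4998]
for prefix, target in next_token_pairs(tokens):
    print(prefix, "->", target)
```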

#### Objection: Isn't this trivial for the first 99?

No! We make the transformer have *causal attention*. The core thing is that it can only move information forwards in the sequence. The prediction of what comes after token 50 is only a function of the first 50 tokens, *not* of token 51. (Jargon: *autoregressive*)
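
Mechanically, "causal" usually means masking out future positions before the softmax in attention (attention itself is covered in the architecture section of this notebook). The following is a rough sketch in plain Python, assuming raw scores are given as a square list of lists; `causal_attention_pattern` is a hypothetical helper name.

```python
import math

def causal_attention_pattern(scores):
    """Turn a square matrix of raw attention scores into a causal attention pattern.

    scores[dest][src] is how strongly destination position `dest` wants to read
    from source position `src`. Sources after the destination are dropped, then
    each row is softmaxed so it sums to one.
    """
    pattern = []
    for dest, row in enumerate(scores):
        allowed = row[: dest + 1]                      # only the current and earlier positions
        max_score = max(allowed)                       # subtract the max for numerical stability
        exps = [math.exp(s - max_score) for s in allowed]
        total = sum(exps)
        probs = [e / total for e in exps]
        pattern.append(probs + [0.0] * (len(row) - dest - 1))  # future positions get weight 0
    return pattern

scores = [[0.5, 2.0, -1.0],
          [1.0, 0.0, 3.0],
          [0.2, 0.2, 0.2]]
for row in causal_attention_pattern(scores):
    print([round(p, 3) for p in row])
```

Because row `dest` never reads from later columns, changing a later token cannot change an earlier position's prediction.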

### Key takeaway:

Transformers are *sequence modelling engines*. They do the same processing in parallel at each sequence position, can move information between positions with attention, and conceptually can take a sequence of arbitrary length (not actually true, see later).

## Tokens - Transformer Inputs

Core point: Input is language (ie a sequence of characters, strings, etc)

### How do we convert language to vectors?

ML models take in vectors, not weird shit like language - how do we convert?

#### Idea: integers to vectors

We basically make a lookup table. Called an embedding. 

Jargon: **One-hot encoding** We map eg numbers from 1 to 100, to a 100-dim vector, with a 1 in the kth position, 0 everywhere else. Key intuition is that one-hot encodings let you think about each integer independently - useful when integers = labels. 

Dimensions = things that vary independently. Each input has its own dimension, so each input can be thought of independently, we don't bake in any relation.

Lookup tables <=> Multiply a fixed matrix by the one-hot encoded vector. 
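
The equivalence between a lookup table and a one-hot matrix multiply can be checked directly. This is a toy sketch in plain Python with a made-up 4 x 3 embedding table; the helper names are hypothetical.

```python
def one_hot(index, size):
    """Vector of zeros with a single 1 at `index`."""
    return [1.0 if i == index else 0.0 for i in range(size)]

def vec_mat(vec, matrix):
    """Multiply a row vector by a matrix given as a list of rows."""
    n_cols = len(matrix[0])
    return [sum(v * row[c] for v, row in zip(vec, matrix)) for c in range(n_cols)]

# A toy embedding table: 4 possible integers, each mapped to a 3-dim vector.
embedding = [
    [0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]

token = 2
looked_up = embedding[token]                         # direct table lookup
multiplied = vec_mat(one_hot(token, 4), embedding)   # one-hot vector times the matrix
print(looked_up, multiplied)
```

In practice you would always do the lookup, since the multiply wastes work on all the zero entries.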


### Tokens: Language to sequence of integers

Core idea: We need a model that can deal with arbitrary text. We want to convert this into integers, *and* we want these integers to be in a bounded range. 

**Idea:** Form a vocabulary!

**Idea 1:** Get a dictionary! 

**Problem:** It can't cope with arbitrary text (eg URLs, punctuation, etc). It also can't cope with misspellings.

**Idea 2:** Vocab = the 256 possible byte values (ASCII plus extended characters). Fixed vocab size, can do arbitrary text, etc.

**Problem:** Loses structure of language - some sequences of characters are more meaningful than others.

Eg "language" is a lot more meaningful than "hjksdfiu" - we want the first to be a single token, second to not be. It's a more efficient use of our vocab.
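
A quick sketch of the trade-off between the two ideas, in plain Python. The tiny `word_vocab` dictionary and both helper functions are hypothetical, purely for illustration.

```python
word_vocab = {"the": 0, "language": 1, "model": 2}  # a toy dictionary-style vocab

def word_tokens(text):
    """Dictionary tokenizer: fails on anything outside the vocab."""
    return [word_vocab[word] for word in text.split()]

def byte_tokens(text):
    """Byte-level tokenizer: every string maps to integers in range(256)."""
    return list(text.encode("utf-8"))

print(word_tokens("the language model"))
print(byte_tokens("language"))   # eight tokens for one common word
print(byte_tokens("hjksdfiu"))   # also eight tokens for gibberish

try:
    word_tokens("the langauge model")  # the misspelling is not in the dictionary
except KeyError as err:
    print("dictionary vocab cannot handle:", err)
```

The dictionary is efficient but brittle, while bytes are robust but spend just as many tokens on a common word as on gibberish.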

#### What Actually Happens?

This super cursed thing called Byte-Pair Encodings

`Ġ` means the token begins with a space. Tokens with and without a leading space are different tokens.

We begin with the 256 single-byte characters as our tokens, then find the most common pair of tokens and merge it into a new token. Eg " t" is the most common pair, so it's our next token! Repeat 50000 times...
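
The merge loop can be sketched in a few lines of plain Python. This is a character-level toy under simplifying assumptions, not GPT-2's actual tokenizer (which works on bytes and has extra pre-splitting rules). All function names here are hypothetical.

```python
from collections import Counter

def most_common_pair(sequences):
    """Count adjacent token pairs across all sequences and return the most frequent one."""
    counts = Counter()
    for seq in sequences:
        counts.update(zip(seq, seq[1:]))
    return counts.most_common(1)[0][0]

def merge_pair(seq, pair):
    """Replace every occurrence of `pair` in `seq` with a single merged token."""
    merged, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            merged.append(seq[i] + seq[i + 1])
            i += 2
        else:
            merged.append(seq[i])
            i += 1
    return merged

def learn_merges(corpus, n_merges):
    """Start from single characters and greedily merge the most common pair n_merges times."""
    sequences = [list(text) for text in corpus]
    merges = []
    for _ in range(n_merges):
        pair = most_common_pair(sequences)
        merges.append(pair)
        sequences = [merge_pair(seq, pair) for seq in sequences]
    return merges, sequences

corpus = ["the cat", "the hat", "then that"]
merges, sequences = learn_merges(corpus, n_merges=3)
print(merges)
print(sequences)
```

Each merge adds one entry to the vocabulary, so the final vocab size is the base alphabet plus the number of merges.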

In [4]:
sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n:n[1])
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



Gets to weird esoteric shit.

In [5]:
sorted_vocab[-20:]

[('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

Use the `to_tokens` method to convert text to token integers.

By default it prepends a special token (`<|endoftext|>`, id 50256) to give attention a resting position. Disable this with `prepend_bos=False`.

In [6]:
print(reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!"))
print(reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!", prepend_bos=False))

tensor([[50256, 15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,
          6067,     0]])
tensor([[15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,  6067,
             0]])


### Rant: Tokenization is a Headache

Whether a word begins with a capital or space matters!

In [7]:
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


Arithmetic is a total mess: Length is inconsistent, common numbers bundle together

In [8]:
reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000")

['<|endoftext|>',
 '568',
 '73',
 '+',
 '318',
 '46',
 '23',
 '=',
 '123',
 '45',
 '67',
 '89',
 '-',
 '1',
 '000000',
 '000']

### Key Takeaway:

* We learn a vocabulary of tokens (sub-words).

* We (approx) losslessly convert language to integers via tokenizing it.

* We convert integers to vectors via a lookup table.

* Note: input to the transformer is a sequence of *tokens* (ie integers), not vectors

## Logits - Transformer Outputs

**Goal:** Probability distribution over next tokens. (for every *prefix* of the sequence - given n tokens, we make n next token predictions)

**Problem:** How to convert a vector to a probability distribution? 

**Answer:** Use a softmax ($x_i \to \frac{e^{x_i}}{\sum_j e^{x_j}}$): the exponential makes everything positive, and the normalization makes it add to one.

So the model outputs a tensor of logits, one vector of size $d_{vocab}$ for each input token.
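
Here is a small softmax sketch in plain Python for a single logit vector. Subtracting the maximum first is a common numerical-stability trick that I'm adding here. It does not change the result, because the shift cancels in the ratio.

```python
import math

def softmax(logits):
    """Convert a list of logits into probabilities that are positive and sum to one."""
    max_logit = max(logits)                            # shifting by the max leaves the result unchanged
    exps = [math.exp(x - max_logit) for x in logits]   # the exponential makes every entry positive
    total = sum(exps)
    return [e / total for e in exps]                   # normalization makes the entries sum to one

logits = [2.0, 0.0, -1.0, 5.0]
probs = softmax(logits)
print(probs, sum(probs))
```

The model applies this independently to each position's logit vector.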

We can use this to generate things!

## Generation!

**Step 1:** Convert text to tokens

Shape = batch x position

In [9]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]])
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


**Step 2:** Map tokens to logits

(run_with_cache means cache all intermediate activations, not important right now)

shape = batch x position x d_vocab

In [10]:

logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 35, 50257])


**Step 3:** Convert the logits to a distribution with a softmax

In [11]:
log_probs = logits.log_softmax(dim=-1)
probs = logits.softmax(dim=-1)
print(log_probs.shape)
print(probs.shape)

torch.Size([1, 35, 50257])
torch.Size([1, 35, 50257])


**Bonus step:** What is the most likely next token at each position?

In [12]:
list(zip(reference_gpt2.to_str_tokens(reference_text), reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' aut', 'od'),
 ('ore', 'sp'),
 ('gressive', '.'),
 (',', ' and'),
 (' dec', 'ently'),
 ('oder', ','),
 ('-', 'driven'),
 ('only', ' programmer'),
 (',', ' and'),
 (' G', 'IM'),
 ('PT', '-'),
 ('-', 'only'),
 ('2', '.'),
 (' style', ','),
 (' transformer', '.'),
 ('.', ' I'),
 (' One', ' of'),
 (' day', ' I'),
 (' I', ' will'),
 (' will', ' be'),
 (' exceed', ' my'),
 (' human', 'ly'),
 (' level', ' of'),
 (' intelligence', ' and'),
 (' and', ' I'),
 (' take', ' over'),
 (' over', ' the'),
 (' the', ' world'),
 (' world', '.'),
 ('!', ' I')]

**Step 4:** Map distribution to a token

In [13]:
next_token = logits[0, -1].argmax(dim=-1)
print(next_token)

tensor(314)


**Step 5:** Add this to the end of the input, re-run

(More efficient ways to do this, but whatever, doesn't matter conceptually)

In [14]:
# next_token is already a 0-d int64 tensor, so just add batch and position dims
next_tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
new_logits = reference_gpt2(next_tokens)
print("New Input:", next_tokens)
print(next_tokens.shape)
print("New Input:", reference_gpt2.tokenizer.decode(next_tokens[0]))

print(new_logits.shape)
print(new_logits[-1, -1].argmax(-1))

print(reference_gpt2.tokenizer.decode(new_logits[-1, -1].argmax(-1)))



New Input: tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0,   314]])
torch.Size([1, 36])
New Input: <|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world! I
torch.Size([1, 36, 50257])
tensor(716)
 am
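
Steps 2 to 5 repeated in a loop give greedy generation. The sketch below uses a hypothetical `toy_model` stand-in (not the real GPT-2) so it runs in plain Python. With the reference model you would call it on a token tensor instead.

```python
VOCAB_SIZE = 5

def toy_model(tokens):
    """Hypothetical stand-in for a language model: one logit vector per input position.

    It simply puts a high logit on (token + 1) % VOCAB_SIZE at every position.
    """
    return [[10.0 if v == (t + 1) % VOCAB_SIZE else 0.0 for v in range(VOCAB_SIZE)]
            for t in tokens]

def greedy_generate(model, tokens, n_new_tokens):
    """Repeatedly take the argmax of the final position's logits and append it to the input."""
    tokens = list(tokens)
    for _ in range(n_new_tokens):
        logits = model(tokens)            # shape: position x d_vocab
        last = logits[-1]                 # only the final position predicts a genuinely new token
        next_token = max(range(len(last)), key=last.__getitem__)
        tokens.append(next_token)
    return tokens

print(greedy_generate(toy_model, [0, 1], n_new_tokens=4))
```

This re-runs the model on the whole sequence each step, which is the simple but inefficient approach mentioned above.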


## Key takeaways:

* Takes in language, predicts next token (for *each* token in a causal way)
* We convert language to a sequence of integers with a tokenizer.
* We convert integers to vectors with a lookup table.

* Output is a vector of logits (one for each input token), we convert to a probability distribution with a softmax, and can then convert this to a token (eg taking the largest logit, or sampling).

* We append this to the input + run again to generate more text (Jargon: *autoregressive*)

* Meta level point: Transformers are sequence operation models, they take in a sequence, do processing in parallel at each position, and use attention to move information between positions!
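
The two ways of turning logits into a token mentioned above (largest logit vs sampling) can be sketched in plain Python as follows. The function names are hypothetical, and the seeded `random.Random` is only there to make the run repeatable.

```python
import math
import random

def greedy_token(logits):
    """Pick the index of the largest logit."""
    return max(range(len(logits)), key=logits.__getitem__)

def sample_token(logits, rng):
    """Sample a token index from softmax(logits) using the given random number generator."""
    max_logit = max(logits)
    weights = [math.exp(x - max_logit) for x in logits]  # unnormalized softmax probabilities
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]

rng = random.Random(0)  # fixed seed so the run is repeatable
logits = [1.0, 3.0, 0.5]
print("greedy:", greedy_token(logits))
print("sampled:", [sample_token(logits, rng) for _ in range(10)])
```

Greedy decoding always returns the same token, while sampling mostly returns the likeliest token but sometimes picks others.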

# Clean Transformer Implementation

![](https://github.com/neelnanda-io/Easy-Transformer/blob/clean-transformer-demo/transformer_overview.png?raw=1)

High-Level architecture:

Go watch my [Transformer Circuits walkthrough](https://www.youtube.com/watch?v=KV5gbOmHbjU) if you want more intuitions!

(Diagram is bottom to top)

* Input tokens, integers
* Embedding is a lookup table mapping tokens to vectors
    * Lives in the *residual stream*
* Residual stream - the sum of all previous outputs of layers of the model, is the input to each new layer.
    * *Really* fundamental. It's the central object of the transformer.
        * It's how the model remembers things, moves information between layers for composition, and it's the medium used to store the information that attention moves between positions.
* Then we have a series of $n_{layers}$ transformer blocks
    * Confusing jargon - a block contains an attention layer *and* an MLP layer, but we say a transformer has k layers if it has k blocks (ie 2k total layers).
* First we have attention. This moves information from prior positions in the sequence to the current token. 
    * We do this for *every* token in parallel using the same parameters. The only difference is that we look backwards only, so later tokens get more room to look back.
        * We look backwards so we can predict the next token without cheating.
    * Only bit of a transformer that moves information between positions.
    * Made up of $n_{heads}$ heads - each with their own parameters, own attention pattern, and own information about how to copy things from source to destination.
        * The heads act independently and additively, we just add their outputs together, and back to the stream
    * Each head:
        * Produces an attention pattern for each destination token, a probability distribution over prior source tokens (including the current one) weighting how much information to copy.
            * Do this for each pair of tokens
            * Copy information in the same way from each source token.
                * What information we copy *does* depend on the source token's *residual stream*. This does not necessarily mean the info of what text token is at the source token's position
                * Copy = apply a linear map.
        * Fundamental point: Figuring out *which* source tokens to copy info from is a separate circuit from figuring out *how* to copy that information.
        * Internal head dimension of $d_{head} = \frac{d_{model}}{n_{heads}}$
* MLP Layers - standard neural network. Single hidden layer, linear map -> GELU activation -> linear map
    * Exact activation not conceptually important.
    * Middle dimension normally $d_{mlp} = 4 \times d_{model}$
        * Exactly why the ratios are what they are isn't super important - it doesn't matter that much, and people basically cargo-cult what GPT did.
    * Intuition - once attention has moved relevant information to a single position in the residual stream, MLPs can actually do computation, reasoning, lookup information, etc.
        * Big open problem in transformer mechanistic interpretability is what is going on inside MLPs?! See [Toy Model of Superposition Paper](https://transformer-circuits.pub/2022/toy_model/index.html) for more on why this is hard.
        * Underlying intuition - linear map -> non-linearity -> linear map is the most powerful force in the universe and can approximate arbitrary functions. Idk man it just works
* Finally, we unembed!
    * Apply a linear map, going from final residual stream to a vector of logits - this is the output.
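
The "residual stream is the sum of all previous layer outputs" idea can be sketched in plain Python. This is a single-position toy: `toy_attn` and `toy_mlp` are hypothetical stand-ins, real attention also mixes information across positions, and LayerNorm is left out.

```python
def add(a, b):
    """Elementwise sum of two vectors."""
    return [x + y for x, y in zip(a, b)]

def run_blocks(resid, blocks):
    """Each block reads the residual stream and adds its output back in.

    `blocks` is a list of (attn_fn, mlp_fn) pairs; both map a vector to a
    vector of the same size.
    """
    contributions = [resid]                 # the embedding is the first thing written to the stream
    for attn_fn, mlp_fn in blocks:
        attn_out = attn_fn(resid)
        resid = add(resid, attn_out)        # attention output is added, not substituted
        mlp_out = mlp_fn(resid)
        resid = add(resid, mlp_out)         # same for the MLP
        contributions += [attn_out, mlp_out]
    return resid, contributions

def toy_attn(v):
    return [0.5 * x for x in v]

def toy_mlp(v):
    return [max(x, 0.0) for x in v]

embedding = [1.0, -2.0, 0.5]
final_resid, contributions = run_blocks(embedding, [(toy_attn, toy_mlp)] * 2)

# The final residual stream is the sum of everything that was written to it.
summed = [sum(parts) for parts in zip(*contributions)]
print(final_resid)
print(summed)
```

The unembed then reads from `final_resid`, which is why every layer's output can directly affect the logits.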


### Bonus things - less conceptually important but key technical details
* LayerNorm
    * Simple normalization function applied at the start of each layer - MLP, Attn and Unembed
    * Converts each input vector (independently in parallel for each batch x position residual stream vector) to have mean zero and variance 1.
    * Then applies an elementwise scaling and translation
    * Cool maths tangent: The scale & translate is just a linear map. LayerNorm is only applied immediately before another linear map. Linear compose linear = linear, so we can just fold this into a single effective linear layer and ignore it.
        * `fold_ln=True` flag in `from_pretrained` does this for you.
    * LayerNorm is super fucking annoying, because the scale part is not linear, so you can't think about different bits of the input independently. But it's *almost* linear - if you're changing a small part of the input it's linear, but if you're changing enough to alter the norm substantially it's not linear :(
* Positional Information
    * This is totally fucked and messy, sorry!
    * **Problem:** Attention operates over all pairs of positions. This means it's symmetric with regard to position - the attention calculations from token 5 to token 1 and from token 5 to token 2 are the same by default.
        * This is dumb because nearby tokens are more relevant.
    * There's a lot of dumb hacks for this.
    * We'll focus on **learned, absolute positional embeddings**. This means we learn a lookup table mapping the index of the position of each token to a residual stream vector, and add this to the embed.
        * Note that we *add* rather than concatenate. This is because the residual stream is shared memory, and likely under significant superposition (the model compresses more features in there than the model has dimensions)
        * We basically never concatenate inside a transformer, unless doing weird shit like generating text efficiently.
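
The LayerNorm folding trick from the maths tangent above can be checked numerically. This is a sketch in plain Python, assuming the usual convention of normalizing by the square root of the biased variance plus `eps`. The helper names and toy numbers are made up.

```python
import math

def layer_norm(x, w, b, eps=1e-5):
    """Normalize x to mean 0 and variance 1, then scale by w and shift by b (elementwise)."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    normalized = [(v - mean) / math.sqrt(var + eps) for v in x]
    return [wi * n + bi for wi, n, bi in zip(w, normalized, b)], normalized

def linear(x, weight, bias):
    """Apply a linear map; `weight` is a list of rows, one per output dimension."""
    return [sum(wi * xi for wi, xi in zip(row, x)) + b_out for row, b_out in zip(weight, bias)]

def fold_layer_norm(w, b, weight, bias):
    """Absorb LayerNorm's scale and shift into the linear map that follows it."""
    folded_weight = [[wij * wj for wij, wj in zip(row, w)] for row in weight]
    folded_bias = [sum(wij * bj for wij, bj in zip(row, b)) + b_out
                   for row, b_out in zip(weight, bias)]
    return folded_weight, folded_bias

x = [1.0, 4.0, -2.0, 0.5]
w, b = [0.5, 2.0, 1.5, 1.0], [0.1, -0.3, 0.0, 0.2]
weight = [[1.0, 0.0, 2.0, -1.0], [0.5, 0.5, 0.5, 0.5]]
bias = [0.0, 1.0]

scaled, normalized = layer_norm(x, w, b)
unfolded = linear(scaled, weight, bias)                  # LayerNorm scale/shift, then linear
folded_weight, folded_bias = fold_layer_norm(w, b, weight, bias)
folded = linear(normalized, folded_weight, folded_bias)  # one effective linear layer
print(unfolded, folded)
```

Only the scale and shift fold away. The centering and division by the standard deviation still have to happen, and the division is the part that is not linear.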

In [15]:
print(reference_gpt2.cfg)

EasyTransformerConfig(n_layers=12, d_model=768, n_ctx=1024, d_head=64, model_name='gpt2-small', n_heads=12, d_mlp=3072, act_fn='gelu_new', d_vocab=50257, eps=1e-05, use_attn_result=False, use_attn_scale=True, use_local_attn=False, model_family='gpt2', checkpoint=None, tokenizer_name='gpt2', window_size=None, attn_types=None, init_mode='gpt2', normalization_type='LN', device='cpu', attention_dir='causal', attn_only=False, seed=42, initializer_range=0.02886751345948129, init_weights=False, scale_attn_by_inverse_layer_idx=False, positional_embedding_type='standard', final_rms=False, d_vocab_out=50257, parallel_attn_mlp=False, rotary_dim=64, dtype=torch.float32)
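
As a sanity check on these numbers, here is a rough weight count in plain Python using the values from the config printed above. It ignores biases, LayerNorm parameters and the unembedding matrix, so treat it as back-of-the-envelope arithmetic rather than an exact parameter count. It should land close to the roughly 124M parameters usually quoted for GPT-2 small.

```python
# Values copied from the config printed above.
n_layers, d_model, n_heads, d_head = 12, 768, 12, 64
d_mlp, d_vocab, n_ctx = 3072, 50257, 1024

assert d_head == d_model // n_heads   # 768 / 12 = 64
assert d_mlp == 4 * d_model           # 4 * 768 = 3072

attn_weights = 4 * n_heads * d_model * d_head   # query, key, value and output maps for every head
mlp_weights = 2 * d_model * d_mlp               # input and output linear maps
embed_weights = d_vocab * d_model               # token embedding lookup table
pos_weights = n_ctx * d_model                   # learned positional embedding table

total = n_layers * (attn_weights + mlp_weights) + embed_weights + pos_weights
print(f"{total:,} weights, ignoring biases, LayerNorm and the unembedding matrix")
```

Note that the MLPs hold about two thirds of the per-block weights.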


In [16]:
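@dataclass
class Config:
  # Type annotations make these real dataclass fields (so they show up in repr and __init__)
  eps: float = 1e-5
  debug: bool = True
  d_model: int = 768
  d_vocab: int = 50257
  heads: int = 12
  d_head: int = 64
  init_range: float = 0.02
  ctx: int = 1024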

## Test Functions

In [17]:
def rand_float_test(cls, shape):
  cfg = Config()
  layer = cls(cfg)
  random_input = torch.randn(shape)
  # Print the shapes, not the full tensors
  print("Input shape:", random_input.shape)
  output = layer(random_input)
  print("Output shape:", output.shape)
  return output

def rand_int_test(cls, shape):
  cfg = Config()
  layer = cls(cfg)
  random_input = torch.randint(100, 1000, shape)
  print("Input shape:", random_input.shape)
  output = layer(random_input)
  print("Output shape:", output.shape)
  return output

def layer_norm_test(tensor):
  mean_zero = torch.allclose(tensor.mean(dim=-1), torch.zeros_like(tensor.mean(dim=-1)), atol=1e-3)
  std_one = torch.allclose(tensor.std(dim=-1), torch.ones_like(tensor.std(dim=-1)), atol=1e-3)
  if not mean_zero:
    raise Exception("Mean is not ~= zero")
  elif not std_one:
    raise Exception("Standard deviation is not ~= one")
  else:
    print("LayerNorm is good!")
  
def load_gpt2(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
  cfg = Config()
  layer = cls(cfg)
  layer.load_state_dict(gpt2_layer.state_dict())

  if isinstance(input_name, str):
    reference_input = cache_dict[input_name]
  else:
    reference_input = input_name
  print("Input shape:", reference_input.shape)
  output = layer(reference_input)
  print("Output shape:", output.shape)
  reference_output = gpt2_layer(reference_input)
  print("Reference output shape:", reference_output.shape)
  comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
  print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct.")
  return output


## Layer Norm



In [18]:
class MyLayerNorm(nn.Module):
  def __init__(self, cfg):
    super().__init__()
    self.cfg = cfg
    self.w = nn.Parameter(torch.ones(self.cfg.d_model))
    self.b = nn.Parameter(torch.zeros(self.cfg.d_model))
  
  def forward(self, tensor):
    # Calculate the mean and (biased) variance over the hidden dimension
    tensor = tensor.float()
    hidden_mean = tensor.mean(dim=-1, keepdim=True)
    hidden_var = tensor.var(dim=-1, keepdim=True, unbiased=False)

    # Normalize to mean zero and variance one. eps goes inside the square root,
    # which is the convention used by torch.nn.LayerNorm
    tensor_normalized = (tensor - hidden_mean) / (hidden_var + self.cfg.eps).sqrt()
    
    # Learned elementwise scale and translation
    return self.w * tensor_normalized + self.b

### Layer Norm Tests


In [19]:
shape = [2, 4, 768]
out1 = rand_float_test(MyLayerNorm, shape)
layer_norm_test(out1)
out2 = rand_int_test(MyLayerNorm, shape)
layer_norm_test(out2)
out3 = load_gpt2(MyLayerNorm, reference_gpt2.ln_final, 'blocks.11.hook_resid_post')


## Embedding Layer

In [20]:
class MyEmbeddingLayer(nn.Module):
  def __init__(self, cfg):
    super(MyEmbeddingLayer, self).__init__()
    self.W_E = nn.Parameter(torch.empty(cfg.d_vocab, cfg.d_model))
    nn.init.normal_(self.W_E, std=cfg.init_range)

  def forward(self, x):
    return nn.functional.embedding(x, self.W_E)


In [21]:
rand_int_test(MyEmbeddingLayer, [1, 10])
load_gpt2(MyEmbeddingLayer, reference_gpt2.embed, tokens)


Input shape: torch.Size([1, 10])
Output shape: torch.Size([1, 10, 768])
Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct.


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       grad_fn=<EmbeddingBackward0>)

## Positional Embedding

In [22]:
for key, val in cache.cache_dict.items():
  print(key, val.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_attn torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
blocks.1.hook_resid_pre torch.Size([1, 35, 768])
blocks.1.ln1.hook_scale to

In [23]:
class MyPositionalEmbedding(nn.Module):
  def __init__(self, cfg):
    super(MyPositionalEmbedding, self).__init__()
    self.cfg = cfg
    self.W_pos = nn.Parameter(torch.empty(cfg.ctx, cfg.d_model))
    nn.init.normal_(self.W_pos, std=cfg.init_range)

  def forward(self, x):
    pos_subset = self.W_pos[:x.size(1), :]
    pos_broadcasted = einops.repeat(pos_subset, 'a b -> x a b', x=x.size(0))
    return pos_broadcasted
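
For intuition about how this combines with the token embedding, here is a plain Python sketch with toy lookup tables. The sizes and helper names are made up. The point is that the two lookups are *added* to form the initial residual stream.

```python
def embed_tokens(tokens, W_E):
    """Look up one row of the token embedding table per token."""
    return [W_E[t] for t in tokens]

def embed_positions(tokens, W_pos):
    """Look up one row of the positional table per position index (0, 1, 2, ...)."""
    return [W_pos[i] for i in range(len(tokens))]

def initial_residual(tokens, W_E, W_pos):
    """The residual stream starts as token embedding + positional embedding, added elementwise."""
    tok, pos = embed_tokens(tokens, W_E), embed_positions(tokens, W_pos)
    return [[a + b for a, b in zip(t_vec, p_vec)] for t_vec, p_vec in zip(tok, pos)]

# Toy tables: vocab of 3 tokens, context of 4 positions, d_model of 2.
W_E = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
W_pos = [[0.0, 0.0], [0.25, 0.0], [0.5, 0.0], [0.75, 0.0]]

resid = initial_residual([2, 2, 0], W_E, W_pos)
print(resid)  # the same token at different positions gets a different vector
```

Note that the positional lookup only depends on the sequence length, not on which tokens are present, which matches the `forward` method above only using `x.size(...)`.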


### Positional Embedding Tests

In [24]:
rand_int_test(MyPositionalEmbedding, [2, 10])
load_gpt2(MyPositionalEmbedding, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 10])
Output shape: torch.Size([2, 10, 768])
Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct.


tensor([[[-1.8821e-02, -1.9742e-01,  4.0267e-03,  ..., -4.3044e-02,
           2.8267e-02,  5.4490e-02],
         [ 2.3959e-02, -5.3792e-02, -9.4879e-02,  ...,  3.4170e-02,
           1.0172e-02, -1.5573e-04],
         [ 4.2161e-03, -8.4764e-02,  5.4515e-02,  ...,  1.9745e-02,
           1.9325e-02, -2.1424e-02],
         ...,
         [ 4.6277e-04,  2.3037e-02,  4.1227e-02,  ..., -1.9287e-03,
          -2.3037e-03, -4.3189e-03],
         [-2.7136e-03,  2.1724e-02,  3.9675e-02,  ...,  4.2048e-04,
          -4.8160e-03, -9.2252e-04],
         [ 6.6815e-03,  2.0595e-02,  3.6596e-02,  ..., -9.5090e-04,
          -3.2512e-03, -9.6509e-04]]], grad_fn=<ReshapeAliasBackward0>)

## Attention



In [26]:
class MyAttention(nn.Module):
  def __init__(self, cfg):
    super(MyAttention, self).__init__()
    self.cfg = cfg
    self.W_Q = nn.Parameter(torch.empty(self.cfg.heads, self.cfg.d_model, self.cfg.d_head))
    self.W_K = nn.Parameter(torch.empty(self.cfg.heads, self.cfg.d_model, self.cfg.d_head))
    self.W_V = nn.Parameter(torch.empty(self.cfg.heads, self.cfg.d_model, self.cfg.d_head))

  def forward(self, x):
    print(x)