<a href="https://colab.research.google.com/github/eafdeafd/ml-notebooks/blob/main/neel_nanda_tutorial_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Introduction

This is a template for a clean, first principles implementation of GPT-2 in PyTorch. This is an accompaniment to [my video tutorial on implementing GPT-2](https://neelnanda.io/transformer-tutorial-2). If you want to properly understand how to implement GPT-2, you'll need to do it yourself! **I recommend filling out this template *as* you watch the video, and seeing how far you can get with each section before watching me do it**. Each section comes with tests, so you can check that you got it right. You can see [a solution notebook here](https://www.neelnanda.io/transformer-solution) - use it if you get stuck, but make an attempt first, and no copying and pasting!

There's a [template version of this notebook here](https://neelnanda.io/transformer-template), go and fill in the blanks (no copying and pasting!) and see if you can pass the tests.

If you enjoyed this, I expect you'd enjoy learning more about what's actually going on inside these models and how to reverse engineer them! This is a fascinating young research field, with a lot of low-hanging fruit and open problems! **I recommend starting with my post [Concrete Steps for Getting Started in Mechanistic Interpretability](https://www.neelnanda.io/mechanistic-interpretability/getting-started).**

This notebook was written to accompany my [TransformerLens library](https://github.com/neelnanda-io/TransformerLens) for doing mechanistic interpretability research on GPT-2 style language models, and is a clean implementation of the underlying transformer architecture in the library. (This notebook is based off of an earlier version called EasyTransformer)

Further Resources:
* [A Comprehensive Mechanistic Interpretability Explainer & Glossary](https://www.neelnanda.io/glossary)
    * Especially [the transformers section](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=pndoEIqJ6GPvC1yENQkEfZYR)
* [200 Concrete Open Problems in Mechanistic Interpretability](https://www.neelnanda.io/concrete-open-problems)
* My [TransformerLens library](https://github.com/neelnanda-io/TransformerLens) for doing mechanistic interpretability research on GPT-2 style language models.
* My walkthrough of [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html), for a deeper dive into how to think about transformers.

Check out these other intros to transformers for another perspective:
* Jay Alammar's [illustrated transformer](https://jalammar.github.io/illustrated-transformer/)
* [Andrej Karpathy's MinGPT](https://github.com/karpathy/minGPT)

**Sharing Guidelines:** This tutorial is still a bit of a work in progress! I think it's usable, but please don't post it anywhere publicly without checking with me first! Sharing with friends is fine.

If you've found this useful, I'd love to hear about it! Positive and negative feedback also very welcome. You can reach me via [email](mailto:neelnanda27@gmail.com)

## Instructions
* No need to read the Setup Section
* Go to runtime > Change Runtime Type and set it to use a GPU
* Read and run notebook up until the start of the section "Actual Code!". Then go to the template notebook and try coding up the model yourself!
    * Bonus points for doing that without reading the solutions, and before I do it in the video!

# Setup

In [None]:
try:
  import google.colab
  IN_COLAB = True
  print("Running as a Colab notebook")
  %pip install git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  # Install another version of node that makes PySvelte work way faster
  !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
  %pip install git+https://github.com/neelnanda-io/PySvelte.git
  %pip install fancy_einsum
  %pip install einops
except:
  IN_COLAB = False
  print("Running as a Jupyter notebook - intended for development only!")

Running as a Colab notebook
Collecting git+https://github.com/neelnanda-io/Easy-Transformer.git@clean-transformer-demo
  Cloning https://github.com/neelnanda-io/Easy-Transformer.git (to revision clean-transformer-demo) to /tmp/pip-req-build-437n6saj
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/Easy-Transformer.git /tmp/pip-req-build-437n6saj
  Running command git checkout -b clean-transformer-demo --track origin/clean-transformer-demo
  Switched to a new branch 'clean-transformer-demo'
  Branch 'clean-transformer-demo' set up to track remote branch 'clean-transformer-demo' from 'origin'.
  Resolved https://github.com/neelnanda-io/Easy-Transformer.git to commit 1f25219e631aeb478d17075d47274db32c874e88
  Preparing metadata (setup.py) ... done
Collecting einops (from easy-transformer==0.1.0)
  Downloading einops-0.7.0-py3-none-any.whl (44 kB)

In [None]:
import einops
from fancy_einsum import einsum
from dataclasses import dataclass
from easy_transformer import EasyTransformer
import torch
import torch.nn as nn
import numpy as np
import math
from easy_transformer.utils import get_corner, gelu_new, tokenize_and_concatenate
import tqdm.auto as tqdm

In [None]:
reference_gpt2 = EasyTransformer.from_pretrained("gpt2-small", fold_ln=False, center_unembed=False, center_writing_weights=False)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Moving model to device:  cuda
Finished loading pretrained model gpt2-small into EasyTransformer!


# Understanding Inputs & Outputs of a Transformer

## What is the point of a transformer?

**Transformers exist to model text!**

We're going to focus on GPT-2 style transformers. Key feature: They generate text! You feed in language, and the model generates a probability distn over tokens. And you can repeatedly sample from this to generate text!

### How is the model trained?

You give it a bunch of text, and train it to predict the next token.

Importantly, if you give a model 100 tokens in a sequence, it predicts the next token for *each* prefix, ie it produces 100 predictions. This is kinda weird but it's much easier to make one that does this. And it also makes training more efficient, because you get 100 bits of feedback rather than just one.

#### Objection: Isn't this trivial for the first 99?

No! We make the transformer have *causal attention*. The core thing is that it can only move information forwards in the sequence. The prediction of what comes after token 50 is only a function of the first 50 tokens, *not* of token 51. (Jargon: *autoregressive*)
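
To make the "one prediction per prefix" idea concrete, here is a minimal sketch in plain Python. The token ids are just the first few from the example sentence used later in this notebook, and the `pairs` list is purely illustrative. Note that within a fixed training sequence the final position has no known target, so n tokens give n - 1 supervised (prefix, target) pairs.

```python
# First few token ids of the example sentence used later in the notebook.
tokens = [50256, 40, 716, 281, 4998]

# Position i sees tokens[0..i] and is trained to predict tokens[i + 1].
inputs = tokens[:-1]
targets = tokens[1:]

# With causal attention, the prediction at position i only depends on this prefix.
pairs = [(inputs[: i + 1], targets[i]) for i in range(len(inputs))]
for prefix, target in pairs:
    print(prefix, "->", target)
```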

### Key takeaway:

Transformers are *sequence modelling engines*. They do the same processing in parallel at each sequence position, can move information between positions with attention, and conceptually can take a sequence of arbitrary length (not actually true, see later).

## Tokens - Transformer Inputs

Core point: Input is language (ie a sequence of characters, strings, etc)

### How do we convert language to vectors?

ML models take in vectors, not weird shit like language - how do we convert?

#### Idea: integers to vectors

We basically make a lookup table. Called an embedding.

Jargon: **One-hot encoding** We map eg each integer k from 1 to 100 to a 100-dim vector with a 1 in the kth position and 0 everywhere else. Key intuition is that one-hot encodings let you think about each integer independently - useful when integers = labels.

Dimensions = things that vary independently. Each input has its own dimension, so each input can be thought of independently, we don't bake in any relation.

Lookup tables <=> Multiply a fixed matrix by the one-hot encoded vector.
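
Here is a small sketch of that equivalence in plain Python. The table values and the helper names `one_hot` and `vec_mat` are made up for illustration; the point is only that multiplying a one-hot row vector by the matrix picks out a single row.

```python
# A tiny lookup table: 4 possible integers, each mapped to a 3-dim vector.
# The numbers are arbitrary, chosen only for illustration.
table = [
    [0.1, 0.2, 0.3],
    [1.0, 1.1, 1.2],
    [2.0, 2.1, 2.2],
    [3.0, 3.1, 3.2],
]

def one_hot(k, size):
    # 1 in position k, 0 everywhere else
    return [1.0 if i == k else 0.0 for i in range(size)]

def vec_mat(vec, mat):
    # Row vector times matrix: result[j] = sum_i vec[i] * mat[i][j]
    return [sum(v * row[j] for v, row in zip(vec, mat)) for j in range(len(mat[0]))]

k = 2
via_matmul = vec_mat(one_hot(k, len(table)), table)
via_lookup = table[k]
print(via_matmul, via_lookup)
```

In practice you index into the table directly, because the matrix multiply spends almost all of its work multiplying by zero.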


### Tokens: Language to sequence of integers

Core idea: We need a model that can deal with arbitrary text. We want to convert this into integers, *and* we want these integers to be in a bounded range.

**Idea:** Form a vocabulary!

**Idea 1:** Get a dictionary!

**Problem:** It can't cope with arbitrary text (eg URLs, punctuation, etc). It also can't cope with misspellings.

**Idea 2:** Vocab = the 256 possible byte values (a superset of ASCII). Fixed vocab size, can do arbitrary text, etc.

**Problem:** Loses structure of language - some sequences of characters are more meaningful than others.

Eg "language" is a lot more meaningful than "hjksdfiu" - we want the first to be a single token, second to not be. It's a more efficient use of our vocab.

#### What Actually Happens?

This super cursed thing called Byte-Pair Encoding (BPE).

Ġ means the token begins with a space. A token with a leading space is a different token from the same text without one.

We begin with the 256 possible byte values as our tokens (not just ASCII, which only covers 128 of them), then find the most common pair of adjacent tokens and merge it into a new token. Eg " t" is the most common pair, so it's our next token! Repeat 50000 times...
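
A single merge step can be sketched as follows. This is a toy version under simplifying assumptions: it works on characters of a three-string corpus rather than bytes of a real dataset, and the helper names `most_common_pair` and `merge_pair` are hypothetical, not part of any tokenizer library.

```python
from collections import Counter

def most_common_pair(sequences):
    # Count adjacent token pairs across all sequences.
    counts = Counter()
    for seq in sequences:
        counts.update(zip(seq, seq[1:]))
    return counts.most_common(1)[0][0]

def merge_pair(seq, pair):
    # Replace every occurrence of `pair` with a single merged token.
    merged, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            merged.append(seq[i] + seq[i + 1])
            i += 2
        else:
            merged.append(seq[i])
            i += 1
    return merged

# Toy corpus, split into single characters (real GPT-2 BPE starts from bytes).
corpus = [list(" the cat"), list(" the tin"), list(" to")]
pair = most_common_pair(corpus)
corpus = [merge_pair(seq, pair) for seq in corpus]
print(pair, corpus[0])
```

Repeating this loop grows the vocabulary by one token per merge, which is why frequent strings end up as single tokens and rare ones get split into pieces.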

In [None]:
sorted_vocab = sorted(list(reference_gpt2.tokenizer.vocab.items()), key=lambda n:n[1])
print(sorted_vocab[:20])
print()
print(sorted_vocab[250:270])
print()
print(sorted_vocab[990:1010])
print()

[('!', 0), ('"', 1), ('#', 2), ('$', 3), ('%', 4), ('&', 5), ("'", 6), ('(', 7), (')', 8), ('*', 9), ('+', 10), (',', 11), ('-', 12), ('.', 13), ('/', 14), ('0', 15), ('1', 16), ('2', 17), ('3', 18), ('4', 19)]

[('ľ', 250), ('Ŀ', 251), ('ŀ', 252), ('Ł', 253), ('ł', 254), ('Ń', 255), ('Ġt', 256), ('Ġa', 257), ('he', 258), ('in', 259), ('re', 260), ('on', 261), ('Ġthe', 262), ('er', 263), ('Ġs', 264), ('at', 265), ('Ġw', 266), ('Ġo', 267), ('en', 268), ('Ġc', 269)]

[('Ġprodu', 990), ('Ġstill', 991), ('led', 992), ('ah', 993), ('Ġhere', 994), ('Ġworld', 995), ('Ġthough', 996), ('Ġnum', 997), ('arch', 998), ('imes', 999), ('ale', 1000), ('ĠSe', 1001), ('ĠIf', 1002), ('//', 1003), ('ĠLe', 1004), ('Ġret', 1005), ('Ġref', 1006), ('Ġtrans', 1007), ('ner', 1008), ('ution', 1009)]



Gets to weird esoteric shit.

In [None]:
sorted_vocab[-20:]

[('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

Use the `to_tokens` method to convert text to numbers

Prepends with a special token to give attention a resting position, disable with `prepend_bos=False`

In [None]:
print(reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!"))
print(reference_gpt2.to_tokens("Whether a word begins with a capital or space matters!", prepend_bos=False))

tensor([[50256, 15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,
          6067,     0]])
tensor([[15354,   257,  1573,  6140,   351,   257,  3139,   393,  2272,  6067,
             0]])


### Rant: Tokenization is a Headache

Whether a word begins with a capital or space matters!

In [None]:
print(reference_gpt2.to_str_tokens("Ralph"))
print(reference_gpt2.to_str_tokens(" Ralph"))
print(reference_gpt2.to_str_tokens(" ralph"))
print(reference_gpt2.to_str_tokens("ralph"))

['<|endoftext|>', 'R', 'alph']
['<|endoftext|>', ' Ralph']
['<|endoftext|>', ' r', 'alph']
['<|endoftext|>', 'ral', 'ph']


Arithmetic is a total mess: Length is inconsistent, common numbers bundle together

In [None]:
reference_gpt2.to_str_tokens("56873+3184623=123456789-1000000000")

['<|endoftext|>',
 '568',
 '73',
 '+',
 '318',
 '46',
 '23',
 '=',
 '123',
 '45',
 '67',
 '89',
 '-',
 '1',
 '000000',
 '000']

### Key Takeaway:

* We learn a vocabulary of tokens (sub-words).

* We (approx) losslessly convert language to integers via tokenizing it.

* We convert integers to vectors via a lookup table.

* Note: input to the transformer is a sequence of *tokens* (ie integers), not vectors

## Logits - Transformer Outputs

**Goal:** Probability distribution over next tokens. (for every *prefix* of the sequence - given n tokens, we make n next token predictions)

**Problem:** How to convert a vector to a probability distribution?

**Answer:** Use a softmax ($x_i \to \frac{e^{x_i}}{\sum_j e^{x_j}}$). The exponential makes everything positive, and the normalization makes it sum to one.

So the model outputs a tensor of logits, one vector of size $d_{vocab}$ for each input token.
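
As a quick sanity check on the formula, here is a plain-Python softmax. Subtracting the max before exponentiating is a standard numerical-stability trick that is not part of the formula above; it leaves the result unchanged because the shift cancels between numerator and denominator. The example logits are made up.

```python
import math

def softmax(logits):
    # Subtracting the max does not change the result but avoids overflow in exp.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1, -3.0])
print(probs, sum(probs))
```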

We can use this to generate things!

## Generation!

**Step 1:** Convert text to tokens

Shape = batch x position

In [None]:
reference_text = "I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world!"
tokens = reference_gpt2.to_tokens(reference_text)
print(tokens)
print(tokens.shape)
print(reference_gpt2.to_str_tokens(tokens))

tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0]])
torch.Size([1, 35])
['<|endoftext|>', 'I', ' am', ' an', ' amazing', ' aut', 'ore', 'gressive', ',', ' dec', 'oder', '-', 'only', ',', ' G', 'PT', '-', '2', ' style', ' transformer', '.', ' One', ' day', ' I', ' will', ' exceed', ' human', ' level', ' intelligence', ' and', ' take', ' over', ' the', ' world', '!']


**Step 2:** Map tokens to logits

(run_with_cache means cache all intermediate activations, not important right now)

shape = batch x position x d_vocab

In [None]:
tokens = tokens.cuda()
logits, cache = reference_gpt2.run_with_cache(tokens)
print(logits.shape)

torch.Size([1, 35, 50257])


**Step 3:** Convert the logits to a distribution with a softmax

In [None]:
log_probs = logits.log_softmax(dim=-1)
probs = logits.softmax(dim=-1)
print(log_probs.shape)
print(probs.shape)

torch.Size([1, 35, 50257])
torch.Size([1, 35, 50257])


**Bonus step:** What is the most likely next token at each position?

In [None]:
list(zip(reference_gpt2.to_str_tokens(reference_text), reference_gpt2.tokenizer.batch_decode(logits.argmax(dim=-1)[0])))

[('<|endoftext|>', '\n'),
 ('I', "'m"),
 (' am', ' a'),
 (' an', ' avid'),
 (' amazing', ' person'),
 (' aut', 'od'),
 ('ore', 'sp'),
 ('gressive', '.'),
 (',', ' and'),
 (' dec', 'ently'),
 ('oder', ','),
 ('-', 'driven'),
 ('only', ' programmer'),
 (',', ' and'),
 (' G', 'IM'),
 ('PT', '-'),
 ('-', 'only'),
 ('2', '.'),
 (' style', ','),
 (' transformer', '.'),
 ('.', ' I'),
 (' One', ' of'),
 (' day', ' I'),
 (' I', ' will'),
 (' will', ' be'),
 (' exceed', ' my'),
 (' human', 'ly'),
 (' level', ' of'),
 (' intelligence', ' and'),
 (' and', ' I'),
 (' take', ' over'),
 (' over', ' the'),
 (' the', ' world'),
 (' world', '.'),
 ('!', ' I')]

**Step 4:** Map distribution to a token

In [None]:
next_token = logits[0, -1].argmax(dim=-1)
print(next_token)

tensor(314, device='cuda:0')


**Step 5:** Add this to the end of the input, re-run

(More efficient ways to do this, but whatever, doesn't matter conceptually)

In [None]:
# next_token is already a 0-dim int64 tensor on the same device as tokens, so just add batch and position dims
next_tokens = torch.cat([tokens, next_token[None, None]], dim=-1)
new_logits = reference_gpt2(next_tokens)
print("New Input:", next_tokens)
print(next_tokens.shape)
print("New Input:", reference_gpt2.tokenizer.decode(next_tokens[0]))

print(new_logits.shape)
print(new_logits[-1, -1].argmax(-1))

print(reference_gpt2.tokenizer.decode(new_logits[-1, -1].argmax(-1)))


New Input: tensor([[50256,    40,   716,   281,  4998,  1960,   382, 19741,    11,   875,
         12342,    12,  8807,    11,   402, 11571,    12,    17,  3918, 47385,
            13,  1881,  1110,   314,   481,  7074,  1692,  1241,  4430,   290,
          1011,   625,   262,   995,     0,   314]], device='cuda:0')
torch.Size([1, 36])
New Input: <|endoftext|>I am an amazing autoregressive, decoder-only, GPT-2 style transformer. One day I will exceed human level intelligence and take over the world! I
torch.Size([1, 36, 50257])
tensor(716, device='cuda:0')
 am




## Key takeaways:

* Takes in language, predicts next token (for *each* token in a causal way)
* We convert language to a sequence of integers with a tokenizer.
* We convert integers to vectors with a lookup table.

* Output is a vector of logits (one for each input token), we convert to a probability distn with a softmax, and can then convert this to a token (eg taking the largest logit, or sampling).

* We append this to the input + run again to generate more text (Jargon: *autoregressive*)

* Meta level point: Transformers are sequence operation models, they take in a sequence, do processing in parallel at each position, and use attention to move information between positions!
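
The append-and-re-run loop can be sketched end to end with a stand-in model. Everything here is hypothetical: `toy_logits` is not a real language model, it just returns one logit vector per position that favours "previous token + 1", so that the loop's behaviour is easy to follow.

```python
def toy_logits(tokens, d_vocab=5):
    # Hypothetical stand-in for a real model: returns one logit vector per position.
    # It simply favours (token + 1) mod d_vocab as the next token.
    return [[1.0 if v == (t + 1) % d_vocab else 0.0 for v in range(d_vocab)] for t in tokens]

def greedy_generate(tokens, n_new, model=toy_logits):
    tokens = list(tokens)
    for _ in range(n_new):
        logits = model(tokens)      # [position][d_vocab]
        final = logits[-1]          # only the last position predicts the new token
        next_token = max(range(len(final)), key=final.__getitem__)  # argmax
        tokens.append(next_token)   # append to the input and re-run
    return tokens

generated = greedy_generate([0, 1], n_new=4)
print(generated)
```

Swapping the argmax for a sample from the softmax distribution gives sampling-based generation with the same loop structure.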

# Clean Transformer Implementation

![](https://github.com/neelnanda-io/Easy-Transformer/blob/clean-transformer-demo/transformer_overview.png?raw=1)

High-Level architecture:

Go watch my [Transformer Circuits walkthrough](https://www.youtube.com/watch?v=KV5gbOmHbjU) if you want more intuitions!

(Diagram is bottom to top)

* Input tokens, integers
* Embedding is a lookup table mapping tokens to vectors
    * Lives in the *residual stream*
* Residual stream - the sum of all previous outputs of layers of the model, is the input to each new layer.
    * *Really* fundamental. It's the central object of the transformer.
        * It's how the model remembers things, moves information between layers for composition, and it's the medium used to store the information that attention moves between positions.
* Then we have a series of $n_{layers}$ transformer blocks
    * Confusing jargon - a block contains an attention layer *and* an MLP layer, but we say a transformer has k layers if it has k blocks (ie 2k total layers).
* First we have attention. This moves information from prior positions in the sequence to the current token.
    * We do this for *every* token in parallel using the same parameters. The only difference is that we look backwards only, so later tokens get more room to look back.
        * We look backwards so we can predict the next token without cheating.
    * Only bit of a transformer that moves information between positions.
    * Made up of $n_{heads}$ heads - each with their own parameters, own attention pattern, and own information about how to copy things from source to destination.
        * The heads act independently and additively, we just add their outputs together, and back to the stream
    * Each head:
        * Produces an attention pattern for each destination token, a probability distribution over prior source tokens (including the current one) weighting how much information to copy.
            * Do this for each pair of tokens
            * Copy information in the same way from each source token.
                * What information we copy *does* depend on the source token's *residual stream*. This does not necessarily mean the info of what text token is at the source token's position
                * Copy = apply a linear map.
        * Fundamental point: Figuring out *which* source tokens to copy info from is a separate circuit from figuring out *how* to copy that information.
        * Internal head dimension of $d_{head} = \frac{d_{model}}{n_{heads}}$
* MLP Layers - standard neural network. Single hidden layer, linear map -> GELU activation -> linear map
    * Exact activation not conceptually important.
    * Middle dimension normally $d_{mlp} = 4 \times d_{model}$
        * Exactly why the ratios are what they are isn't super important - doesn't matter that much, people basically cargo-cult what GPT did.
    * Intuition - once attention has moved relevant information to a single position in the residual stream, MLPs can actually do computation, reasoning, lookup information, etc.
        * Big open problem in transformer mechanistic interpretability is what is going on inside MLPs?! See [Toy Model of Superposition Paper](https://transformer-circuits.pub/2022/toy_model/index.html) for more on why this is hard.
        * Underlying intuition - linear map -> non-linearity -> linear map is the most powerful force in the universe and can approximate arbitrary functions. Idk man it just works
* Finally, we unembed!
    * Apply a linear map, going from final residual stream to a vector of logits - this is the output.
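
The "residual stream is the sum of all previous outputs" claim can be checked with a toy sketch. This is an assumption-laden illustration: the `nn.Linear` layers are hypothetical stand-ins for attention and MLP sublayers (they do not move information between positions), and LayerNorm is left out entirely.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 8

# Stand-ins for the attention and MLP sublayers (hypothetical; real ones are more complex).
sublayers = [nn.Linear(d_model, d_model) for _ in range(4)]

resid = torch.randn(1, 5, d_model)   # [batch, position, d_model], plays the role of embed + pos_embed
written = [resid]                    # everything that has been written into the stream

with torch.no_grad():
    for layer in sublayers:
        out = layer(resid)           # each sublayer reads the current stream...
        written.append(out)
        resid = resid + out          # ...and adds its output back in

# The final stream equals the initial embedding plus every sublayer output.
reconstructed = torch.stack(written).sum(dim=0)
print(torch.allclose(resid, reconstructed, atol=1e-4))
```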


### Bonus things - less conceptually important but key technical details
* LayerNorm
    * Simple normalization function applied at the start of each layer - MLP, Attn and Unembed
    * Converts each input vector (independently in parallel for each batch x position residual stream vector) to have mean zero and variance 1.
    * Then applies an elementwise scaling and translation
    * Cool maths tangent: The scale & translate is just a linear map. LayerNorm is only applied immediately before another linear map. Linear compose linear = linear, so we can just fold this into a single effective linear layer and ignore it.
        * `fold_ln=True` flag in `from_pretrained` does this for you.
    * LayerNorm is super fucking annoying, because the scale part is not linear, so you can't think about different bits of the input independently. But it's *almost* linear - if you're changing a small part of the input it's linear, but if you're changing enough to alter the norm substantially it's not linear :(
* Positional Information
    * This is totally fucked and messy, sorry!
    * **Problem:** Attention operates over all pairs of positions. This means it's symmetric with regards to position - the attention calculation from token 5 to token 1 and token 5 to token 2 are the same by default
        * This is dumb because nearby tokens are more relevant.
    * There's a lot of dumb hacks for this.
    * We'll focus on **learned, absolute positional embeddings**. This means we learn a lookup table mapping the index of the position of each token to a residual stream vector, and add this to the embed.
        * Note that we *add* rather than concatenate. This is because the residual stream is shared memory, and likely under significant superposition (the model compresses more features in there than the model has dimensions)
        * We basically never concatenate inside a transformer, unless doing weird shit like generating text efficiently.
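
The LayerNorm folding tangent above can be verified numerically. This is a sketch with made-up shapes and random values, not the library's actual `fold_ln` code: `x_norm` stands for vectors that have already been centred and normalised, `w` and `b` are the LayerNorm scale and bias, and `W`, `c` are the weights and bias of the linear map that follows.

```python
import torch

torch.manual_seed(0)
d_model, d_out = 6, 4
x_norm = torch.randn(3, d_model)   # already centred and normalised vectors
w = torch.randn(d_model)           # LayerNorm scale
b = torch.randn(d_model)           # LayerNorm bias
W = torch.randn(d_model, d_out)    # the linear map that follows LayerNorm
c = torch.randn(d_out)             # its bias

# Apply scale and translate, then the linear map.
unfolded = (x_norm * w + b) @ W + c

# Fold the scale and translate into the linear map instead.
W_folded = w[:, None] * W          # scale each input row of W by the LayerNorm weight
c_folded = b @ W + c               # push the LayerNorm bias through W
folded = x_norm @ W_folded + c_folded

print(torch.allclose(unfolded, folded, atol=1e-4))
```

Only the elementwise scale and translate fold away like this. The centring and division by the standard deviation still have to be computed, which is the non-linear part complained about above.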

# Actual Code!

## Print All Activation Shapes of Reference Model

Key:
```
batch = 1
position = 35
d_model = 768
n_heads = 12
n_layers = 12
d_mlp = 3072 (4 * d_model)
d_head = 64 (d_model / n_heads)
```

In [None]:
for activation_name, activation in cache.cache_dict.items():
    # Only print for first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 35, 768])
hook_pos_embed torch.Size([1, 35, 768])
blocks.0.hook_resid_pre torch.Size([1, 35, 768])
blocks.0.ln1.hook_scale torch.Size([1, 35, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 35, 768])
blocks.0.attn.hook_q torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 35, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_attn torch.Size([1, 12, 35, 35])
blocks.0.attn.hook_z torch.Size([1, 35, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 35, 768])
blocks.0.hook_resid_mid torch.Size([1, 35, 768])
blocks.0.ln2.hook_scale torch.Size([1, 35, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 35, 768])
blocks.0.mlp.hook_pre torch.Size([1, 35, 3072])
blocks.0.mlp.hook_post torch.Size([1, 35, 3072])
blocks.0.hook_mlp_out torch.Size([1, 35, 768])
blocks.0.hook_resid_post torch.Size([1, 35, 768])
ln_final.hook_scale torch.Size([1, 35, 1])
ln_final.hook_normalized torch.S

## Print All Parameters Shapes of Reference Model

In [None]:
for name, param in reference_gpt2.named_parameters():
    # Only print for first layer
    if ".0." in name or "blocks" not in name:
        print(name, param.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.ln1.w torch.Size([768])
blocks.0.ln1.b torch.Size([768])
blocks.0.ln2.w torch.Size([768])
blocks.0.ln2.b torch.Size([768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
ln_final.w torch.Size([768])
ln_final.b torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])
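
As a small arithmetic exercise, the shapes printed above are enough to count the parameters of the reference model. The sketch below copies those shapes by hand and assumes every block has the same shapes as block 0 (the printout only shows block 0). Note the count includes the separate `unembed.W_U` matrix listed above.

```python
from math import prod

# Shapes copied from the printout above; every block is assumed to match block 0.
non_block = {
    "embed.W_E": (50257, 768),
    "pos_embed.W_pos": (1024, 768),
    "ln_final.w": (768,),
    "ln_final.b": (768,),
    "unembed.W_U": (768, 50257),
    "unembed.b_U": (50257,),
}
per_block = {
    "ln1.w": (768,), "ln1.b": (768,),
    "ln2.w": (768,), "ln2.b": (768,),
    "attn.W_Q": (12, 768, 64), "attn.W_K": (12, 768, 64),
    "attn.W_V": (12, 768, 64), "attn.W_O": (12, 64, 768),
    "attn.b_Q": (12, 64), "attn.b_K": (12, 64),
    "attn.b_V": (12, 64), "attn.b_O": (768,),
    "mlp.W_in": (768, 3072), "mlp.b_in": (3072,),
    "mlp.W_out": (3072, 768), "mlp.b_out": (768,),
}
n_layers = 12

block_params = sum(prod(shape) for shape in per_block.values())
total = sum(prod(shape) for shape in non_block.values()) + n_layers * block_params
print(f"{block_params:,} per block, {total:,} total")
```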


## Config

In [None]:
# As a reference - note there's a lot of stuff we don't care about in here, to do with library internals or other architectures
print(reference_gpt2.cfg)

EasyTransformerConfig(n_layers=12, d_model=768, n_ctx=1024, d_head=64, model_name='gpt2-small', n_heads=12, d_mlp=3072, act_fn='gelu_new', d_vocab=50257, eps=1e-05, use_attn_result=False, use_attn_scale=True, use_local_attn=False, model_family='gpt2', checkpoint=None, tokenizer_name='gpt2', window_size=None, attn_types=None, init_mode='gpt2', normalization_type='LN', device='cuda', attention_dir='causal', attn_only=False, seed=42, initializer_range=0.02886751345948129, init_weights=False, scale_attn_by_inverse_layer_idx=False, positional_embedding_type='standard', final_rms=False, d_vocab_out=50257, parallel_attn_mlp=False, rotary_dim=64, dtype=torch.float32)


We define a stripped down config for our model

In [None]:

@dataclass
class Config:
    d_model: int = 768
    debug: bool = True
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 1024
    d_head: int = 64
    d_mlp: int = 3072
    n_heads: int = 12
    n_layers: int = 12

cfg = Config()
print(cfg)

Config(d_model=768, debug=True, layer_norm_eps=1e-05, d_vocab=50257, init_range=0.02, n_ctx=1024, d_head=64, d_mlp=3072, n_heads=12, n_layers=12)


## Tests

Tests are great, write lightweight ones to use as you go!

**Naive test:** Generate random inputs of the right shape, feed them to your model, check that it runs without error and print the output shape.

In [None]:
def rand_float_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    random_input = torch.randn(shape).cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def rand_int_test(cls, shape):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    random_input = torch.randint(100, 1000, shape).cuda()
    print("Input shape:", random_input.shape)
    output = layer(random_input)
    print("Output shape:", output.shape)
    print()
    return output

def load_gpt2_test(cls, gpt2_layer, input_name, cache_dict=cache.cache_dict):
    cfg = Config(debug=True)
    layer = cls(cfg).cuda()
    layer.load_state_dict(gpt2_layer.state_dict(), strict=False)
    # Allow inputs of strings or tensors
    if isinstance(input_name, str):
        reference_input = cache_dict[input_name]
    else:
        reference_input = input_name
    print("Input shape:", reference_input.shape)
    output = layer(reference_input)
    print("Output shape:", output.shape)
    reference_output = gpt2_layer(reference_input)
    print("Reference output shape:", reference_output.shape)

    comparison = torch.isclose(output, reference_output, atol=1e-4, rtol=1e-3)
    print(f"{comparison.sum()/comparison.numel():.2%} of the values are correct")
    return output
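
For intuition about what `load_gpt2_test` counts as "correct": `torch.isclose` treats two values as close when `|a - b| <= atol + rtol * |b|`. Here is a plain-Python sketch of that rule with the same tolerances as above. The helper names and the example numbers are made up for illustration.

```python
def is_close(a, b, atol=1e-4, rtol=1e-3):
    # Same rule torch.isclose documents: |a - b| <= atol + rtol * |b|
    return abs(a - b) <= atol + rtol * abs(b)

def fraction_close(xs, ys):
    # Fraction of elementwise comparisons that pass the tolerance check.
    matches = [is_close(x, y) for x, y in zip(xs, ys)]
    return sum(matches) / len(matches)

ours = [0.10001, 2.0, -3.5, 10.0]
reference = [0.1, 2.001, -3.5, 10.5]
print(f"{fraction_close(ours, reference):.2%} of the values are correct")
```

The relative term matters because float32 activations of large magnitude can differ by more than `atol` while still agreeing to several significant figures.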

## LayerNorm

* Make mean 0
* Normalize to have variance 1
* Scale with learned weights
* Translate with learned bias

In [None]:
class LayerNorm(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.w = nn.Parameter(torch.ones(cfg.d_model))
        self.b = nn.Parameter(torch.zeros(cfg.d_model))

    def forward(self, residual):
        # residual: [batch, position, d_model]
        # LayerNorm is applied to the residual stream before attention, MLP and unembed, so d_model is the last dimension
        if self.cfg.debug: print("residual:", residual.shape)
        mean = residual.mean(dim=-1, keepdim=True)
        var = residual.var(dim=-1, keepdim=True, unbiased=False)
        normalized = self.w * ((residual - mean) / torch.sqrt(var + self.cfg.layer_norm_eps)) + self.b
        if self.cfg.debug: print("normalized:", normalized.shape)
        return normalized

In [None]:
_ = rand_float_test(LayerNorm, [2, 4, 768])
_ = load_gpt2_test(LayerNorm, reference_gpt2.ln_final, "blocks.11.hook_resid_post")

Input shape: torch.Size([2, 4, 768])
residual: torch.Size([2, 4, 768])
normalized: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35, 768])
residual: torch.Size([1, 35, 768])
normalized: torch.Size([1, 35, 768])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


## Embedding

Basically a lookup table from tokens to residual stream vectors.

In [None]:
class Embed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_E = nn.Parameter(torch.empty((cfg.d_vocab, cfg.d_model)))  # torch.empty returns uninitialized memory of the given shape
        nn.init.normal_(self.W_E, std=self.cfg.init_range)  # fill in place with draws from a normal distribution; 0.02 is the std GPT-2 uses

    def forward(self, tokens):
        # tokens: [batch, position], integer token ids (position = sequence length)
        # W_E: [d_vocab, d_model] is the lookup table; each token id selects one row
        # output: [batch, position, d_model]
        return torch.nn.functional.embedding(tokens, self.W_E)
        # Equivalent: return self.W_E[tokens, :]
rand_int_test(Embed, [2, 4])
load_gpt2_test(Embed, reference_gpt2.embed, tokens)


Input shape: torch.Size([2, 4])
Output shape: torch.Size([2, 4, 768])

Input shape: torch.Size([1, 35])
Output shape: torch.Size([1, 35, 768])
Reference output shape: torch.Size([1, 35, 768])
100.00% of the values are correct


tensor([[[ 0.0514, -0.0277,  0.0499,  ...,  0.0070,  0.1552,  0.1207],
         [ 0.1474, -0.0959,  0.1430,  ...,  0.1030, -0.0625, -0.1131],
         [ 0.1596, -0.1249,  0.1148,  ...,  0.2558,  0.0196,  0.0145],
         ...,
         [-0.0393,  0.0050,  0.0421,  ..., -0.0477,  0.0670, -0.0471],
         [-0.1488,  0.1519,  0.0056,  ..., -0.3107,  0.2073,  0.0377],
         [-0.1101, -0.0393,  0.0331,  ..., -0.1364,  0.0151,  0.0453]]],
       device='cuda:0', grad_fn=<EmbeddingBackward0>)
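
To make the "lookup table" claim concrete, here is a minimal sketch on toy tensors (the sizes and token IDs are made up for illustration, not GPT-2's) showing that `torch.nn.functional.embedding` and plain indexing into the weight matrix return the same rows:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d_vocab, d_model = 10, 4                        # toy sizes (assumption), not GPT-2's
W_E = torch.randn(d_vocab, d_model)             # one row per token ID
tokens = torch.tensor([[1, 5, 5], [0, 9, 2]])   # [batch=2, position=3]

via_functional = F.embedding(tokens, W_E)       # [batch, position, d_model]
via_indexing = W_E[tokens]                      # same lookup via integer indexing

print(via_functional.shape)
print(torch.equal(via_functional, via_indexing))
```

Repeated token IDs (the two 5s here) simply pick out the same row twice.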

## Positional Embedding

In [None]:
class PosEmbed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_pos = nn.Parameter(torch.empty((cfg.n_ctx, cfg.d_model)))
        nn.init.normal_(self.W_pos, std=self.cfg.init_range)

    def forward(self, tokens):
        "YOUR CODE HERE"

rand_int_test(PosEmbed, [2, 4])
load_gpt2_test(PosEmbed, reference_gpt2.pos_embed, tokens)

Input shape: torch.Size([2, 4])


AttributeError: 'NoneType' object has no attribute 'shape'
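
The error above is presumably what the test reports while `forward` is still a placeholder and returns `None`. As a hint rather than a solution, here is a sketch of the idea on toy tensors: a positional embedding depends only on the position index, never on the token ID, so it is a slice of the table broadcast across the batch. All sizes and names here are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
n_ctx, d_model = 8, 4                       # toy sizes (assumption)
W_pos = torch.randn(n_ctx, d_model)         # one row per position
tokens = torch.randint(0, 10, (2, 5))       # [batch=2, position=5]; the values are never used

batch, seq_len = tokens.shape
# Take the first seq_len rows and repeat them (as a view) for every sequence in the batch
pos_embed = W_pos[:seq_len].unsqueeze(0).expand(batch, -1, -1)

print(pos_embed.shape)                      # [batch, position, d_model]
```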

## Attention

* **Step 1:** Produce an attention pattern - for each destination token, a probability distribution over previous tokens (including the current token)
    * Linear map from input -> query and key, each of shape [batch, position, head_index, d_head]
    * Dot product every *pair* of queries and keys to get attn_scores [batch, head_index, query_pos, key_pos] (query = dest, key = source)
    * Scale and mask attn_scores to make it lower triangular, i.e. causal
    * Softmax row-wise, to get a probability distribution along the key_pos dimension - this is our attention pattern!
* **Step 2:** Move information from source tokens to destination token using attention pattern (move = apply linear map)
    * Linear map from input -> value [batch, key_pos, head_index, d_head]
    * Mix along the key_pos with attn pattern to get z, a mixed value [batch, query_pos, head_index, d_head]
    * Map to output, [batch, position, d_model] (position = query_pos, we've summed over all heads)
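
The two steps above can be sketched end to end on toy tensors. This is an illustration under stated assumptions, not the reference solution: the sizes are made up, biases are left out for brevity, and it uses `torch.einsum` with single-letter index names (`b` batch, `p`/`q`/`k` positions, `h` head, `m` d_model, `d` d_head).

```python
import math
import torch

torch.manual_seed(0)
batch, pos, n_heads, d_model, d_head = 2, 5, 3, 16, 4   # toy sizes (assumption)
x = torch.randn(batch, pos, d_model)
W_Q, W_K, W_V = (torch.randn(n_heads, d_model, d_head) * 0.02 for _ in range(3))
W_O = torch.randn(n_heads, d_head, d_model) * 0.02

# Step 1: attention pattern
q = torch.einsum("bpm,hmd->bphd", x, W_Q)
k = torch.einsum("bpm,hmd->bphd", x, W_K)
attn_scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / math.sqrt(d_head)
# True above the diagonal, i.e. where the key position is later than the query position
future = torch.triu(torch.ones(pos, pos, dtype=torch.bool), diagonal=1)
attn_scores = attn_scores.masked_fill(future, -1e5)
pattern = attn_scores.softmax(dim=-1)                    # [batch, head, query_pos, key_pos]

# Step 2: move information along the pattern
v = torch.einsum("bpm,hmd->bphd", x, W_V)
z = torch.einsum("bhqk,bkhd->bqhd", pattern, v)
out = torch.einsum("bqhd,hdm->bqm", z, W_O)              # sums over heads

print(pattern[0, 0])   # lower triangular, each row sums to 1
print(out.shape)
```

Filling masked scores with a large negative number (rather than zero) matters because the softmax comes afterwards: a score of zero would still receive non-trivial probability.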

First, it's useful to visualize and play around with attention patterns - what exactly are we looking at here? (Click on a head to lock onto just showing that head's pattern, it'll make it easier to interpret)

In [None]:
import pysvelte
pysvelte.AttentionMulti(tokens=reference_gpt2.to_str_tokens(reference_text), attention=cache['blocks.0.attn.hook_attn'][0].permute(1, 2, 0)).show()

In [None]:
class Attention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_Q = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_Q, std=self.cfg.init_range)
        self.b_Q = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_K = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_K, std=self.cfg.init_range)
        self.b_K = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))
        self.W_V = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_model, cfg.d_head)))
        nn.init.normal_(self.W_V, std=self.cfg.init_range)
        self.b_V = nn.Parameter(torch.zeros((cfg.n_heads, cfg.d_head)))

        self.W_O = nn.Parameter(torch.empty((cfg.n_heads, cfg.d_head, cfg.d_model)))
        nn.init.normal_(self.W_O, std=self.cfg.init_range)
        self.b_O = nn.Parameter(torch.zeros((cfg.d_model)))

        # No hard-coded device: buffers move with the module when .cuda() / .to(device) is called
        self.register_buffer("IGNORE", torch.tensor(-1e5, dtype=torch.float32))

    def forward(self, normalized_resid_pre):
        # normalized_resid_pre: [batch, position, d_model]
        "YOUR CODE HERE"

    def apply_causal_mask(self, attn_scores):
        # attn_scores: [batch, n_heads, query_pos, key_pos]
        "YOUR CODE HERE"

rand_float_test(Attention, [2, 4, 768])
load_gpt2_test(Attention, reference_gpt2.blocks[0].attn, cache["blocks.0.ln1.hook_normalized"])

## MLP

In [None]:
class MLP(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_in = nn.Parameter(torch.empty((cfg.d_model, cfg.d_mlp)))
        nn.init.normal_(self.W_in, std=self.cfg.init_range)
        self.b_in = nn.Parameter(torch.zeros((cfg.d_mlp)))
        self.W_out = nn.Parameter(torch.empty((cfg.d_mlp, cfg.d_model)))
        nn.init.normal_(self.W_out, std=self.cfg.init_range)
        self.b_out = nn.Parameter(torch.zeros((cfg.d_model)))

    def forward(self, normalized_resid_mid):
        # normalized_resid_mid: [batch, position, d_model]
        if self.cfg.debug: print("normalized_resid_mid:", normalized_resid_mid.shape)
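        # Hidden layer: [batch, position, d_model] -> [batch, position, d_mlp]
        pre = normalized_resid_mid @ self.W_in + self.b_in
        # GPT-2 uses the tanh approximation of GELU (approximate="tanh" needs PyTorch >= 1.12)
        post = torch.nn.functional.gelu(pre, approximate="tanh")
        # Project back: [batch, position, d_mlp] -> [batch, position, d_model]
        return post @ self.W_out + self.b_out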
rand_float_test(MLP, [2, 4, 768])
load_gpt2_test(MLP, reference_gpt2.blocks[0].mlp, cache["blocks.0.ln2.hook_normalized"])
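
For reference, the MLP is two linear maps with a nonlinearity in between, and GPT-2's nonlinearity is the tanh approximation of GELU. The sketch below writes that formula out by hand on toy tensors and checks it against PyTorch's built-in version. The helper name `gelu_tanh` and the toy sizes are assumptions made for illustration, and `approximate="tanh"` needs a reasonably recent PyTorch (1.12 or later).

```python
import math
import torch
import torch.nn.functional as F

def gelu_tanh(x):
    # Tanh approximation of GELU, written out explicitly
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

torch.manual_seed(0)
d_model, d_mlp = 8, 32                 # toy sizes; GPT-2 small uses 768 and 3072
W_in = torch.randn(d_model, d_mlp) * 0.02
b_in = torch.zeros(d_mlp)
W_out = torch.randn(d_mlp, d_model) * 0.02
b_out = torch.zeros(d_model)

x = torch.randn(2, 4, d_model)         # [batch, position, d_model]
pre = x @ W_in + b_in                  # [batch, position, d_mlp]
post = gelu_tanh(pre)                  # elementwise, shape unchanged
mlp_out = post @ W_out + b_out         # [batch, position, d_model]

print(mlp_out.shape)
print(torch.allclose(post, F.gelu(pre, approximate="tanh"), atol=1e-6))
```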

## Transformer Block

In [None]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg

        self.ln1 = LayerNorm(cfg)
        self.attn = Attention(cfg)
        self.ln2 = LayerNorm(cfg)
        self.mlp = MLP(cfg)

    def forward(self, resid_pre):
        # resid_pre [batch, position, d_model]
        "YOUR CODE HERE"
rand_float_test(TransformerBlock, [2, 4, 768])
load_gpt2_test(TransformerBlock, reference_gpt2.blocks[0], cache["resid_pre", 0])
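
The thing to hold onto in the block is the residual structure: each sublayer reads a LayerNormed copy of the residual stream and *adds* its output back. Here is a sketch of just that wiring, with `nn.LayerNorm` and `nn.Linear` as hypothetical stand-ins for the real `LayerNorm`, `Attention` and `MLP` classes (so the numbers mean nothing, only the data flow does):

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model = 8                                    # toy size (assumption)
ln1, ln2 = nn.LayerNorm(d_model), nn.LayerNorm(d_model)
attn_stub = nn.Linear(d_model, d_model)        # stand-in for the attention layer
mlp_stub = nn.Linear(d_model, d_model)         # stand-in for the MLP layer

resid_pre = torch.randn(2, 4, d_model)         # [batch, position, d_model]
attn_out = attn_stub(ln1(resid_pre))
resid_mid = resid_pre + attn_out               # attention writes into the stream
mlp_out = mlp_stub(ln2(resid_mid))
resid_post = resid_mid + mlp_out               # MLP writes into the stream

print(resid_post.shape)
```

Because both sublayers only add to the stream, `resid_post` is exactly `resid_pre + attn_out + mlp_out`.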

## Unembedding

In [None]:
class Unembed(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.W_U = nn.Parameter(torch.empty((cfg.d_model, cfg.d_vocab)))
        nn.init.normal_(self.W_U, std=self.cfg.init_range)
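        # requires_grad must be passed to nn.Parameter itself; passing it to torch.zeros leaves the bias trainable
        self.b_U = nn.Parameter(torch.zeros(cfg.d_vocab), requires_grad=False)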

    def forward(self, normalized_resid_final):
        # normalized_resid_final [batch, position, d_model]
        "YOUR CODE HERE"

rand_float_test(Unembed, [2, 4, 768])
load_gpt2_test(Unembed, reference_gpt2.unembed, cache["ln_final.hook_normalized"])
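
Unembedding is a single linear map from the residual stream to one logit per vocabulary entry. The sketch below shows the shapes on toy tensors (sizes are assumptions), plus a side note on the bias: `nn.Parameter` has its own `requires_grad` argument and does not inherit the flag from the tensor it wraps.

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, d_vocab = 8, 20                       # toy sizes (assumption)
W_U = torch.randn(d_model, d_vocab) * 0.02
b_U = torch.zeros(d_vocab)

x = torch.randn(2, 4, d_model)                 # [batch, position, d_model]
logits = x @ W_U + b_U                         # [batch, position, d_vocab]
greedy_next = logits[:, -1].argmax(dim=-1)     # most likely next token per sequence

# Where the requires_grad flag goes matters:
flag_on_tensor = nn.Parameter(torch.zeros(3, requires_grad=False))
flag_on_parameter = nn.Parameter(torch.zeros(3), requires_grad=False)

print(logits.shape, greedy_next.shape)
print(flag_on_tensor.requires_grad, flag_on_parameter.requires_grad)
```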

## Full Transformer

In [None]:
class DemoTransformer(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.embed = Embed(cfg)
        self.pos_embed = PosEmbed(cfg)
        self.blocks = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.n_layers)])
        self.ln_final = LayerNorm(cfg)
        self.unembed = Unembed(cfg)

    def forward(self, tokens):
        # tokens [batch, position]
        "YOUR CODE HERE"

rand_int_test(DemoTransformer, [2, 4])
load_gpt2_test(DemoTransformer, reference_gpt2, tokens)

# Try it out!

In [None]:
demo_gpt2 = DemoTransformer(Config(debug=False))
demo_gpt2.load_state_dict(reference_gpt2.state_dict(), strict=False)
demo_gpt2.cuda()

Take a test string - the intro paragraph of today's featured Wikipedia article. Let's calculate the loss!

In [None]:
test_string = """Mini scule is a species of microhylid frog endemic to Madagascar that was described in 2019. The scientific name of the species refers to its size, being a pun on the word minuscule. It is very small, measuring only 8.4 to 10.8 mm (0.33 to 0.43 in) in snout–vent length. It has bronze underparts with a brown groin and back of the thigh, cream upperparts with brown flecking, a dark brown side of the head, and a red iris. On the hind feet, the first toe is absent and the second and fifth toes are strongly reduced. The frog is known only from the Sainte Luce Reserve, where it inhabits areas with deep leaf litter near semi-permanent water bodies. Specimens of frogs from Mandena, the Vohimena mountains, the southern Anosy Mountains, and Tsitongambarika may also be of this species. Along with Mini mum and Mini ature, the other two species in its genus, it received media attention when first described due to the wordplay in its scientific name. (Full article...)"""

In [None]:
test_tokens = reference_gpt2.to_tokens(test_string).cuda()
demo_logits = demo_gpt2(test_tokens)

In [None]:
def lm_cross_entropy_loss(logits, tokens):
    # Measure next token loss
    # Logits have shape [batch, position, d_vocab]
    # Tokens have shape [batch, position]
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()
loss = lm_cross_entropy_loss(demo_logits, test_tokens)
print(loss)
print("Loss as average (geometric mean) prob of the correct token", (-loss).exp())
print("Loss as 'uniform over this many options'", (loss).exp())
print("Uniform loss over the vocab", math.log(demo_gpt2.cfg.d_vocab))
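
A quick sanity check on how to read these numbers: a model that knows nothing and puts equal probability on every token should score exactly `log(d_vocab)`. The sketch below repeats the loss function so that it runs on its own, and feeds it all-zero logits (i.e. a uniform distribution) over GPT-2's 50257-token vocabulary; the batch and sequence sizes are arbitrary.

```python
import math
import torch

def lm_cross_entropy_loss(logits, tokens):
    # Same next-token loss as above: the logits at position i predict the token at i+1
    log_probs = logits.log_softmax(dim=-1)
    pred_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens[:, 1:].unsqueeze(-1)).squeeze(-1)
    return -pred_log_probs.mean()

torch.manual_seed(0)
d_vocab = 50257                                  # GPT-2's vocabulary size
tokens = torch.randint(0, d_vocab, (2, 6))       # arbitrary token IDs
uniform_logits = torch.zeros(2, 6, d_vocab)      # equal logits = uniform distribution

loss = lm_cross_entropy_loss(uniform_logits, tokens)
print(loss.item(), math.log(d_vocab))            # these should agree
print((-loss).exp().item(), 1 / d_vocab)         # probability assigned to each correct token
```

Any trained model should come in well below this baseline.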

We can also greedily generate text:

In [None]:
test_string = "Breaking News: President Trump has been impeached by the House of Representatives for abuse of power and obstruction of Congress. The vote was 230 to 197, with 10 Republicans joining all Democrats in voting to impeach. The president is now only the third in American history to be impeached, and the first to be impeached twice. The House will now send the articles of impeachment to the Senate, where a trial will be held to determine whether to remove the president from office. The Senate is expected to begin the trial on"
for i in tqdm.tqdm(range(100)):
    test_tokens = reference_gpt2.to_tokens(test_string).cuda()
    demo_logits = demo_gpt2(test_tokens)
    test_string += reference_gpt2.tokenizer.decode(demo_logits[-1, -1].argmax())
print(test_string)
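
Stripped of the tokenizer, greedy decoding is just "take the argmax at the last position, append it, repeat". Here is a sketch with a hypothetical stand-in model (`toy_model`, which always prefers the token after the current one) so that the loop can be run and checked without loading GPT-2:

```python
import torch
import torch.nn.functional as F

vocab = 5                                       # toy vocabulary (assumption)

def toy_model(tokens):
    # Hypothetical stand-in for a language model:
    # tokens [batch, position] -> logits [batch, position, vocab],
    # with all the weight on (token + 1) % vocab at every position
    return F.one_hot((tokens + 1) % vocab, vocab).float()

tokens = torch.tensor([[3]])                    # prompt: a single token
with torch.no_grad():                           # no gradients needed for generation
    for _ in range(4):
        logits = toy_model(tokens)
        next_token = logits[0, -1].argmax()     # greedy pick at the final position
        tokens = torch.cat([tokens, next_token.view(1, 1)], dim=1)

print(tokens.tolist())
```

This version appends token IDs directly rather than decoding to a string and re-tokenizing on every step; both work, but keeping everything as token IDs avoids the round trip through text.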

# Training a Model!

This is a lightweight demonstration of how you can actually train your own GPT-2 with this code! Here we train a tiny model on a tiny dataset, but it's fundamentally the same code for training a larger/more real model (though you'll need beefier GPUs and data parallelism to do it remotely efficiently, and fancier parallelism for much bigger ones).

For our purposes, we'll train a 2-layer model with 4 heads per layer and context length 256, for 1000 steps at batch size 8, just to show what it looks like (and so the notebook doesn't melt your Colab lol).

In [None]:
if IN_COLAB:
    %pip install datasets
    %pip install transformers
import datasets
import transformers
import plotly.express as px

## Config

In [None]:
batch_size = 8
num_epochs = 1
max_steps = 1000
log_every = 10
lr = 1e-3
weight_decay = 1e-2
model_cfg = Config(debug=False, d_model=256, n_heads=4, d_head=64, d_mlp=1024, n_layers=2, n_ctx=256, d_vocab=reference_gpt2.cfg.d_vocab)
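
For a sense of scale, the arithmetic behind this config is worth spelling out. Each step processes `batch_size * n_ctx` tokens, so the whole run sees only a couple of million tokens:

```python
batch_size, n_ctx, max_steps = 8, 256, 1000   # values from the config above

tokens_per_step = batch_size * n_ctx          # tokens processed per optimizer step
total_tokens = tokens_per_step * max_steps    # tokens seen over the whole run

print(tokens_per_step, total_tokens)
```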



## Create Data

We load in a tiny dataset I made, with the first 10K entries in the Pile (inspired by Stas' version for OpenWebText!)


In [None]:
dataset = datasets.load_dataset("NeelNanda/pile-10k", split="train")
print(dataset)
print(dataset[0]['text'][:100])
tokens_dataset = tokenize_and_concatenate(dataset, reference_gpt2.tokenizer, streaming=False, max_length=model_cfg.n_ctx, column_name="text", add_bos_token=True, num_proc=4)
data_loader = torch.utils.data.DataLoader(tokens_dataset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)


## Create Model


In [None]:
model = DemoTransformer(model_cfg)
model.cuda()


## Create Optimizer
We use AdamW - it's a pretty standard optimizer.

In [None]:
optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)

## Run Training Loop


In [None]:
losses = []
print("Number of batches:", len(data_loader))
for epoch in range(num_epochs):
    for c, batch in tqdm.tqdm(enumerate(data_loader)):
        tokens = batch['tokens'].cuda()
        logits = model(tokens)
        loss = lm_cross_entropy_loss(logits, tokens)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        losses.append(loss.item())
        if c % log_every == 0:
            print(f"Step: {c}, Loss: {loss.item():.4f}")
        if c + 1 >= max_steps:  # c starts at 0, so this stops after exactly max_steps steps
            break


We can now plot a loss curve!

In [None]:
px.line(y=losses, x=np.arange(len(losses))*(model_cfg.n_ctx * batch_size), labels={"y":"Loss", "x":"Tokens"}, title="Training curve for my tiny demo model!")