# Intro

![Machine Learning](https://imgs.xkcd.com/comics/machine_learning.png)

# Setup

I stole this from an example, we don't need all this complexity. But I think it's cool to see.

Moreover, most of my python / jupyter / colab knowledge is copied from a bunch of examples. See [Sources](#Sources).

In [1]:
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
else:
    # See README.md for local setup
    pass

In [2]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import get_corner
import circuitsvis as cv

# Saves computation time, since we don't need it for the contents of this notebook
t.set_grad_enabled(False)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

# Getting acquainted with the modellel

In [3]:
gpt2 = HookedTransformer.from_pretrained("gpt2-small")

config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [5]:
#print(gpt2)
print(gpt2.cfg)

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': F

## Input: "What does this eat?" aka Tokenization

In [6]:
gpt2.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [7]:
gpt2.tokenizer.vocab

{'Ġanatomical': 48631,
 'Jerry': 43462,
 'Ġarises': 22068,
 'Ġcrabs': 49172,
 'Mem': 13579,
 'Ġfriends': 2460,
 'ĠSto': 22025,
 'rics': 10466,
 'ĠMau': 28931,
 'ĠOil': 11474,
 'Ġphysical': 3518,
 'ĠVolume': 14701,
 'âĵĺ': 45563,
 'ĠWillow': 33021,
 'âĢĶ"': 19056,
 'ĠFall': 7218,
 'ĠDUI': 39414,
 'Ġtouches': 18105,
 'Ġinvest': 1325,
 'Ġutilize': 17624,
 'ĠJS': 26755,
 'ĠMED': 26112,
 'Ġborrowers': 31617,
 'rentice': 20098,
 'ael': 3010,
 'Ġprojecting': 37298,
 'ĠSaber': 40506,
 'Ġefficient': 6942,
 'Ġposters': 19379,
 'Ġwords': 2456,
 'Ġscience': 3783,
 'ribe': 4892,
 'Ġtant': 24246,
 'rans': 26084,
 'Kill': 27100,
 'Ġbarring': 34928,
 'FAQ': 42680,
 'Ġtickets': 8587,
 'TG': 35990,
 'Ġplethora': 35146,
 'latable': 49009,
 'overed': 2557,
 'dry': 39140,
 'Ġlosers': 29502,
 'imental': 9134,
 'quart': 36008,
 'gil': 37718,
 'cool': 24494,
 'okin': 36749,
 'Bruce': 38509,
 'lf': 1652,
 'umann': 40062,
 'Ġconverting': 23202,
 'Ġworkshop': 20243,
 'lus': 41790,
 'ĠDresden': 46993,
 'acement':

In [8]:
vocab_sorted = sorted(gpt2.tokenizer.vocab.items(), key=lambda x: x[1])
vocab_sorted[-20:]

[('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

In [9]:
text = ("This is a story about Quomatarus."
  + " When one day Quomatarus decided to do something different and bought a plane ticket to Lamanandu."
  + " When he arrived to the airport Quomatarus noticed")

In [11]:
gpt2.tokenizer.encode(text)

[1212,
 318,
 257,
 1621,
 546,
 2264,
 296,
 9459,
 385,
 13,
 1649,
 530,
 1110,
 2264,
 296,
 9459,
 385,
 3066,
 284,
 466,
 1223,
 1180,
 290,
 5839,
 257,
 6614,
 7846,
 284,
 406,
 10546,
 392,
 84,
 13,
 1649,
 339,
 5284,
 284,
 262,
 9003,
 2264,
 296,
 9459,
 385,
 6810]

In [12]:
# gpt2.to_str_tokens(text)
tokens = gpt2.to_tokens(text)
str_tokens = gpt2.to_str_tokens(text)
print(str_tokens)
print(tokens.shape)

['<|endoftext|>', 'This', ' is', ' a', ' story', ' about', ' Qu', 'om', 'atar', 'us', '.', ' When', ' one', ' day', ' Qu', 'om', 'atar', 'us', ' decided', ' to', ' do', ' something', ' different', ' and', ' bought', ' a', ' plane', ' ticket', ' to', ' L', 'aman', 'and', 'u', '.', ' When', ' he', ' arrived', ' to', ' the', ' airport', ' Qu', 'om', 'atar', 'us', ' noticed']
torch.Size([1, 45])


### Embedding

In [13]:
gpt2.W_E.shape

torch.Size([50257, 768])

In [14]:
embedded = gpt2.W_E[tokens, :]
print(embedded.shape)
utils.get_corner(embedded)

torch.Size([1, 45, 768])


tensor([[[ 0.0517, -0.0274,  0.0502],
         [ 0.0270, -0.0939,  0.0738],
         [-0.0078,  0.0120,  0.0575]]], device='cuda:0')

## Output: "What comes out?"

In [15]:
# gpt2(tokens, return_type="loss")
gpt2(tokens).shape

torch.Size([1, 45, 50257])

In [16]:
# logits, cache = gpt2.run_with_cache(tokens, remove_batch_dim=True)
logits, cache = gpt2.run_with_cache(tokens)
print(logits.shape)
print(cache)

torch.Size([1, 45, 50257])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_res

In [29]:
probs = logits.squeeze().softmax(dim=-1)
print(utils.get_corner(probs))
einops.reduce(probs, 'pos token -> pos', 'sum')

tensor([[6.6198e-04, 2.4113e-02, 9.5430e-04],
        [1.1472e-05, 1.7084e-05, 5.7139e-07],
        [1.9788e-05, 7.8417e-06, 1.0032e-06]], device='cuda:0')


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000],
       device='cuda:0')

In [21]:
last = logits[0, -1]
print(last.argmax())
gpt2.tokenizer.decode(last.argmax())

tensor(326, device='cuda:0')


' that'

In [24]:
print(last.topk(5))
for i in last.topk(5).indices:
    print(f'"{gpt2.tokenizer.decode(i)}"')

torch.return_types.topk(
values=tensor([16.2408, 15.2805, 14.8070, 13.7773, 13.6780], device='cuda:0'),
indices=tensor([326, 257, 262, 339, 465], device='cuda:0'))
" that"
" a"
" the"
" he"
" his"


In [30]:
print(t.round(probs[:, tokens.squeeze()].diag(1), decimals=3))

log_probs = logits.squeeze().log_softmax(dim=-1)
token_log_probs = log_probs[:, tokens.squeeze()].diag(1)
# print(token_log_probs)

# token_log_probs = log_probs[:, :-1].gather(dim=-1, index=tokens.squeeze()[1:, None]).squeeze()
token_log_probs

tensor([0.0080, 0.1840, 0.3330, 0.0130, 0.4980, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1410, 0.0040, 0.0090, 0.0150, 0.0260, 0.9140, 0.9800, 0.9960, 0.0090,
        0.8060, 0.0170, 0.5500, 0.0160, 0.0380, 0.0010, 0.4080, 0.0110, 0.0940,
        0.5280, 0.0020, 0.0010, 0.0020, 0.0050, 0.0320, 0.0190, 0.3050, 0.0910,
        0.0140, 0.0970, 0.0600, 0.1080, 0.9930, 0.9940, 0.9980, 0.0200],
       device='cuda:0')


tensor([-4.8580e+00, -1.6949e+00, -1.1008e+00, -4.3222e+00, -6.9673e-01,
        -1.0088e+01, -8.0420e+00, -1.0918e+01, -9.2068e+00, -1.9596e+00,
        -5.5019e+00, -4.7312e+00, -4.1974e+00, -3.6655e+00, -9.0204e-02,
        -1.9922e-02, -3.7245e-03, -4.6675e+00, -2.1606e-01, -4.0747e+00,
        -5.9848e-01, -4.1610e+00, -3.2812e+00, -7.3975e+00, -8.9542e-01,
        -4.5045e+00, -2.3666e+00, -6.3921e-01, -6.3236e+00, -7.3244e+00,
        -6.1944e+00, -5.3811e+00, -3.4567e+00, -3.9479e+00, -1.1884e+00,
        -2.4014e+00, -4.2923e+00, -2.3332e+00, -2.8089e+00, -2.2295e+00,
        -6.7844e-03, -5.6575e-03, -2.0999e-03, -3.9126e+00], device='cuda:0')

In [28]:
px.line(utils.to_numpy(token_log_probs), hover_name=str_tokens[1:])

# Structure

## What do the "big brothers" look like?

- GPT-3: https://arxiv.org/abs/2005.14165v4
- PaLM: https://jmlr.org/papers/v24/22-1144.html
- LLaMA: https://arxiv.org/abs/2302.13971

In [31]:
for name, p in gpt2.named_parameters():
  if ".0." in name or "blocks" not in name:
    print(name, p.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


In [32]:
for activation_name, activation in cache.items():
    # Only print for the first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 45, 768])
hook_pos_embed torch.Size([1, 45, 768])
blocks.0.hook_resid_pre torch.Size([1, 45, 768])
blocks.0.ln1.hook_scale torch.Size([1, 45, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 45, 768])
blocks.0.attn.hook_q torch.Size([1, 45, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 45, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 45, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 45, 45])
blocks.0.attn.hook_pattern torch.Size([1, 12, 45, 45])
blocks.0.attn.hook_z torch.Size([1, 45, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 45, 768])
blocks.0.hook_resid_mid torch.Size([1, 45, 768])
blocks.0.ln2.hook_scale torch.Size([1, 45, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 45, 768])
blocks.0.mlp.hook_pre torch.Size([1, 45, 3072])
blocks.0.mlp.hook_post torch.Size([1, 45, 3072])
blocks.0.hook_mlp_out torch.Size([1, 45, 768])
blocks.0.hook_resid_post torch.Size([1, 45, 768])
ln_final.hook_scale torch.Size([1, 45, 1])
ln_final.hook_normalized torc

In [34]:
from fancy_einsum import einsum

In [35]:
mlp_before = cache["normalized", 0, "ln2"]
# gpt2.W_in[0].shape
mlp_mid1 = einsum("batch pos d_model, d_model d_mlp -> batch pos d_mlp", mlp_before, gpt2.W_in[0]) + gpt2.b_in[0]
print(utils.get_corner(mlp_mid1))
print(utils.get_corner(cache["pre", 0, "mlp"]))

tensor([[[-0.1944, -2.0492, -2.7343],
         [ 0.3661, -1.2688, -1.3038],
         [ 0.0980, -1.5448, -1.3435]]], device='cuda:0')
tensor([[[-0.1944, -2.0492, -2.7343],
         [ 0.3661, -1.2688, -1.3038],
         [ 0.0980, -1.5448, -1.3435]]], device='cuda:0')


In [37]:
mlp_mid2 = utils.gelu_new(mlp_mid1)
# print(utils.get_corner(mlp_mid2))
# print(utils.get_corner(cache["post", 0, "mlp"]))
mlp_after = einsum("batch pos d_mlp, d_mlp d_model -> batch pos d_model", mlp_mid2, gpt2.W_out[0]) + gpt2.b_out[0]
print(utils.get_corner(mlp_after))
print(utils.get_corner(cache["mlp_out", 0]))


tensor([[[-0.5169,  0.2836,  0.4329],
         [-0.6278, -0.1156,  1.0684],
         [-1.6660,  0.3645, -0.8681]]], device='cuda:0')
tensor([[[-0.5169,  0.2836,  0.4329],
         [-0.6278, -0.1156,  1.0684],
         [-1.6660,  0.3645, -0.8681]]], device='cuda:0')


In [38]:
from transformer_lens.utils import get_corner

In [39]:
attention = cache["pattern", 0].squeeze()
print(attention.shape)
cv.attention.attention_pattern(attention=attention[5], tokens=str_tokens)
# Compare block 0 head 5 to block 5 head 5!

torch.Size([12, 45, 45])


# Induction Heads

In [45]:
# for layer in range(gpt2.cfg.n_layers):
    # attention_pattern = cache["pattern", layer]
    # display(cv.attention.attention_patterns(tokens=str_tokenek, attention=attention_pattern))

attention_pattern = cache["pattern", 5, "attn"].squeeze()
# print(utils.get_corner(attention_pattern))
print(attention_pattern.shape)

html = cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"L5H{i}" for i in range(12)],
)
# with open("attention.html", "w") as f:
    # f.write(f'{html}')
display(html)

torch.Size([12, 45, 45])


# Sources

These are the main inspirations:

* https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp
* https://transformer-circuits.pub/2021/framework/index.html

Videos:

* https://neelnanda.io/transformer-tutorial

Other:

* https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated
