# Intro

![Machine Learning](https://imgs.xkcd.com/comics/machine_learning.png)

# Setup

I stole this from an example, we don't need all this complexity. But I think it's cool to see.

Moreover, most of my python / jupyter / colab knowledge is copied from a bunch of examples. See [Sources](#Sources).

To open this in Google Colab, click [here](https://colab.research.google.com/github/klao/t9r-class/blob/master/htt_clean.ipynb).

In [1]:
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
else:
    # See README.md for local setup
    pass

In [2]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from fancy_einsum import einsum
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import get_corner
import circuitsvis as cv

# Saves computation time, since we don't need it for the contents of this notebook
t.set_grad_enabled(False)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

  return torch._C._cuda_getDeviceCount() > 0


# Getting acquainted with the model

In [3]:
gpt2 = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [4]:
gpt2

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [5]:
gpt2.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cpu'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': Fa

## Input: "What does this eat?" aka Tokenization

In [6]:
gpt2.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [7]:
gpt2.tokenizer.vocab

{'Ġchosen': 7147,
 'Ġdates': 9667,
 'ĠLil': 16342,
 'ĠPro': 1041,
 'ĠFlavor': 26438,
 'Ġfatal': 10800,
 'Ġextends': 14582,
 'Ġodor': 28192,
 'ĠVotes': 39584,
 '434': 47101,
 'Ġdropping': 12047,
 'Ġresource': 8271,
 'ĠMeow': 42114,
 'Ġsill': 49276,
 '----------------': 1783,
 'Ġempower': 17549,
 'Ġmushrooms': 23452,
 'ricting': 42870,
 'ĠOfficial': 15934,
 'äº': 12859,
 'armed': 12026,
 'Ġstrings': 13042,
 'Ġmemorial': 17357,
 'ĠTill': 17888,
 'Ġsmartphones': 18151,
 'imore': 9401,
 'ĠCycle': 26993,
 'CAR': 20034,
 'Ġproclaimed': 25546,
 'Wa': 33484,
 'ĠEncounter': 40998,
 'Ġdisbanded': 47302,
 'Consumer': 49106,
 'otaur': 35269,
 'Ġarisen': 42091,
 'running': 20270,
 'Ġ22': 2534,
 'va': 6862,
 'ĠApply': 27967,
 'ĠScheme': 32448,
 'ĠBoo': 21458,
 'Virtual': 37725,
 'Ġtaxed': 31075,
 'Taking': 26556,
 'ĠElk': 40151,
 'ystem': 6781,
 'ĠDeath': 5830,
 'ĠACC': 15859,
 'ĠMikhail': 42040,
 'Ġlockout': 48449,
 'Ter': 15156,
 'Ď': 202,
 'ĠWilliam': 3977,
 'adian': 18425,
 '476': 35435,
 'ĠXu': 

In [8]:
vocab_sorted = sorted(gpt2.tokenizer.vocab.items(), key=lambda x: x[1])

In [9]:
vocab_sorted[-50:]

[('Ġkernels', 50207),
 ('ĠFranÃ§ois', 50208),
 ('ĠDuff', 50209),
 ('ĠPon', 50210),
 ('ĠLeica', 50211),
 ('ĠGarmin', 50212),
 ('Ġorphans', 50213),
 ('ĠClaudia', 50214),
 ('Ġcalendars', 50215),
 ('ĠLeilan', 50216),
 ('ento', 50217),
 ('Rocket', 50218),
 ('Ġbrunch', 50219),
 ('ĠHawking', 50220),
 ('ainers', 50221),
 ('Ġsensibilities', 50222),
 ('ĠkW', 50223),
 ('ĠKand', 50224),
 ('Ġreclaimed', 50225),
 ('Ġinterestingly', 50226),
 ('×©', 50227),
 ('romy', 50228),
 ('JM', 50229),
 ('ĠEnhancement', 50230),
 ('bush', 50231),
 ('Skip', 50232),
 ('Ġrappers', 50233),
 ('Ġgazing', 50234),
 ('pedia', 50235),
 ('athlon', 50236),
 ('Revolution', 50237),
 ('Ġsnipers', 50238),
 ('Ġreverted', 50239),
 ('Ġconglomerate', 50240),
 ('Terry', 50241),
 ('794', 50242),
 ('Ġharsher', 50243),
 ('Ġdesolate', 50244),
 ('ĠHitman', 50245),
 ('Commission', 50246),
 ('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('

In [10]:
gpt2.to_str_tokens("827364 + 128736 = 9868734?")

['<|endoftext|>',
 '8',
 '27',
 '364',
 ' +',
 ' 12',
 '87',
 '36',
 ' =',
 ' 98',
 '687',
 '34',
 '?']

In [11]:
text = ("This is a story about Quomatarus."
  + " When one day Quomatarus decided to do something different and bought a plane ticket to Lamanandu."
  + " When he arrived to the airport Quomatarus noticed")

In [12]:
tokens = gpt2.to_tokens(text)
str_tokens = gpt2.to_str_tokens(text)
print(str_tokens)
print(tokens.shape)

['<|endoftext|>', 'This', ' is', ' a', ' story', ' about', ' Qu', 'om', 'atar', 'us', '.', ' When', ' one', ' day', ' Qu', 'om', 'atar', 'us', ' decided', ' to', ' do', ' something', ' different', ' and', ' bought', ' a', ' plane', ' ticket', ' to', ' L', 'aman', 'and', 'u', '.', ' When', ' he', ' arrived', ' to', ' the', ' airport', ' Qu', 'om', 'atar', 'us', ' noticed']
torch.Size([1, 45])


In [13]:
tokens

tensor([[50256,  1212,   318,   257,  1621,   546,  2264,   296,  9459,   385,
            13,  1649,   530,  1110,  2264,   296,  9459,   385,  3066,   284,
           466,  1223,  1180,   290,  5839,   257,  6614,  7846,   284,   406,
         10546,   392,    84,    13,  1649,   339,  5284,   284,   262,  9003,
          2264,   296,  9459,   385,  6810]])

### Embedding

In [14]:
gpt2.W_E.shape

torch.Size([50257, 768])

In [15]:
get_corner(gpt2.W_E)

tensor([[-0.1106, -0.0398,  0.0326],
        [ 0.0359, -0.0531,  0.0418],
        [-0.1301,  0.0453,  0.1815]], requires_grad=True)

## Output: "What comes out?"

In [16]:
print(gpt2(tokens).shape)
# print(gpt2(tokens, return_type="loss"))

torch.Size([1, 45, 50257])


In [17]:
#
# logits, cache = gpt2.run_with_cache(tokens, remove_batch_dim=True)
logits, cache = gpt2.run_with_cache(tokens)
print(logits.shape)
print(cache)

torch.Size([1, 45, 50257])
ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_res

In [18]:
# Logits? Probabilities?
get_corner(logits)

tensor([[[ 7.5261, 11.1214,  7.8919],
         [ 3.8748,  4.2731,  0.8752],
         [ 5.0762,  4.1506,  2.0943]]])

In [19]:
probs = logits.squeeze().softmax(dim=-1)
print(probs.shape)
get_corner(probs)
einops.reduce(probs, "pos tokens -> pos", "sum")

torch.Size([45, 50257])


tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])

In [20]:
# Next token?
probs[44, :].argmax()
gpt2.tokenizer.decode([probs[44, :].argmax().item()])

' that'

In [21]:
next = probs[44, :].topk(10)
next
for k in next.indices:
    print(f"'{gpt2.tokenizer.decode(k)}'")

' that'
' a'
' the'
' he'
' his'
' something'
' an'
' there'
' some'
' how'


In [22]:
# How well did it predict the actual tokens?
# Log probs
x = probs[:, tokens.squeeze()]
x.diag(1).shape

torch.Size([44])

In [23]:
# Plot it! Which tokens did it do well on? Which poorly? Why?
px.line(utils.to_numpy(x.diag(1)), hover_name=str_tokens[1:])

# Structure

## What do the "big brothers" look like?

- GPT-3: https://arxiv.org/abs/2005.14165v4
- PaLM: https://jmlr.org/papers/v24/22-1144.html
- LLaMA: https://arxiv.org/abs/2302.13971

In [24]:
for name, p in gpt2.named_parameters():
  if ".0." in name or "blocks" not in name:
    print(name, p.shape)

embed.W_E torch.Size([50257, 768])
pos_embed.W_pos torch.Size([1024, 768])
blocks.0.attn.W_Q torch.Size([12, 768, 64])
blocks.0.attn.W_K torch.Size([12, 768, 64])
blocks.0.attn.W_V torch.Size([12, 768, 64])
blocks.0.attn.W_O torch.Size([12, 64, 768])
blocks.0.attn.b_Q torch.Size([12, 64])
blocks.0.attn.b_K torch.Size([12, 64])
blocks.0.attn.b_V torch.Size([12, 64])
blocks.0.attn.b_O torch.Size([768])
blocks.0.mlp.W_in torch.Size([768, 3072])
blocks.0.mlp.b_in torch.Size([3072])
blocks.0.mlp.W_out torch.Size([3072, 768])
blocks.0.mlp.b_out torch.Size([768])
unembed.W_U torch.Size([768, 50257])
unembed.b_U torch.Size([50257])


In [25]:
for activation_name, activation in cache.items():
    # Only print for the first layer
    if ".0." in activation_name or "blocks" not in activation_name:
        print(activation_name, activation.shape)

hook_embed torch.Size([1, 45, 768])
hook_pos_embed torch.Size([1, 45, 768])
blocks.0.hook_resid_pre torch.Size([1, 45, 768])
blocks.0.ln1.hook_scale torch.Size([1, 45, 1])
blocks.0.ln1.hook_normalized torch.Size([1, 45, 768])
blocks.0.attn.hook_q torch.Size([1, 45, 12, 64])
blocks.0.attn.hook_k torch.Size([1, 45, 12, 64])
blocks.0.attn.hook_v torch.Size([1, 45, 12, 64])
blocks.0.attn.hook_attn_scores torch.Size([1, 12, 45, 45])
blocks.0.attn.hook_pattern torch.Size([1, 12, 45, 45])
blocks.0.attn.hook_z torch.Size([1, 45, 12, 64])
blocks.0.hook_attn_out torch.Size([1, 45, 768])
blocks.0.hook_resid_mid torch.Size([1, 45, 768])
blocks.0.ln2.hook_scale torch.Size([1, 45, 1])
blocks.0.ln2.hook_normalized torch.Size([1, 45, 768])
blocks.0.mlp.hook_pre torch.Size([1, 45, 3072])
blocks.0.mlp.hook_post torch.Size([1, 45, 3072])
blocks.0.hook_mlp_out torch.Size([1, 45, 768])
blocks.0.hook_resid_post torch.Size([1, 45, 768])
ln_final.hook_scale torch.Size([1, 45, 1])
ln_final.hook_normalized torc

In [28]:
# Look at attention patterns
# cv.attention.attention_pattern(s), don't forget to squeeze!

# Compare block 0 head 5 to block 5 head 5!
cv.attention.attention_pattern(cache["blocks.5.attn.hook_pattern"][0, 5, :, :], tokens=str_tokens)

# Induction Heads

In [30]:

attention_pattern = cache["pattern", 5, "attn"].squeeze()

display(cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"L5H{i}" for i in range(12)],
))

# Sources

These are the main inspirations:

* https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp
* https://transformer-circuits.pub/2021/framework/index.html

Videos:

* https://neelnanda.io/transformer-tutorial

Other:

* https://www.lesswrong.com/posts/TvrfY4c9eaGLeyDkE/induction-heads-illustrated
