# P2020 Town Hall
## Google, 2024-09-17

![Machine Learning](https://imgs.xkcd.com/comics/machine_learning.png)

# Setup

Most of this is "borrowed" from a bunch of examples. See [Sources](#Sources).

To open this in Google Colab, click [here](https://colab.research.google.com/github/klao/town-hall-2024/blob/master/transformer_clean.ipynb).

In [43]:
import sys

# Colab detection: checks whether the `google.colab` module is already present in
# sys.modules (assumes the Colab runtime pre-loads it -- TODO confirm).
IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    # Install the extra dependencies with the `%pip` magic. Versions are not
    # pinned, so every run picks up whatever release is current.
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    # CircuitsVis is installed straight from GitHub; `#subdirectory=python`
    # points pip at the Python package directory inside that repository.
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
else:
    # See README.md for local setup
    pass

In [44]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from fancy_einsum import einsum
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import get_corner
import circuitsvis as cv

# Nothing in this notebook trains or backpropagates, so autograd is switched
# off globally to skip its bookkeeping cost.
t.set_grad_enabled(False)

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = t.device("cuda") if t.cuda.is_available() else t.device("cpu")

# Let's look at the model

In [45]:
# Load the pretrained "gpt2-small" weights into a TransformerLens HookedTransformer.
gpt2 = HookedTransformer.from_pretrained("gpt2-small")

Loaded pretrained model gpt2-small into HookedTransformer


In [46]:
# Bare expression: the notebook renders the model's module tree
# (embeddings, 12 transformer blocks, and the HookPoint modules inside them).
gpt2

HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0-11): 12 x TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp_out): HookPoint()
      (h

In [47]:
# Show the model's configuration: d_model, d_head, n_heads, n_layers, n_ctx, d_vocab, ...
gpt2.cfg

HookedTransformerConfig:
{'act_fn': 'gelu_new',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 3072,
 'd_model': 768,
 'd_vocab': 50257,
 'd_vocab_out': 50257,
 'default_prepend_bos': True,
 'device': device(type='cuda'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.02886751345948129,
 'model_name': 'gpt2',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 12,
 'n_layers': 12,
 'n_params': 84934656,
 'normalization_type': 'LNPre',
 'original_architecture': 'GPT2LMHeadModel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': None,
 'tokenizer_name': 'gpt2',
 'tokenizer_prepends_bos': False,
 'use_attn_in': F

## Input: "What does this eat?" aka Tokenization

In [48]:
# Show the tokenizer attached to the model (its repr lists vocab size and special tokens).
gpt2.tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	50256: AddedToken("<|endoftext|>", rstrip=False, lstrip=False, single_word=False, normalized=True, special=True),
}

In [49]:
# Collect the (token string, token id) pairs and order them by id, so that
# slices of the list correspond to contiguous id ranges.
vocab_sorted = list(gpt2.tokenizer.vocab.items())
vocab_sorted.sort(key=lambda entry: entry[1])

# The ten lowest ids.
vocab_sorted[:10]

[('!', 0),
 ('"', 1),
 ('#', 2),
 ('$', 3),
 ('%', 4),
 ('&', 5),
 ("'", 6),
 ('(', 7),
 (')', 8),
 ('*', 9)]

In [50]:
# The ten highest ids; the very last entry is the special <|endoftext|> token.
vocab_sorted[-10:]

[('Ġ(/', 50247),
 ('âĢ¦."', 50248),
 ('Compar', 50249),
 ('Ġamplification', 50250),
 ('ominated', 50251),
 ('Ġregress', 50252),
 ('ĠCollider', 50253),
 ('Ġinformants', 50254),
 ('Ġgazed', 50255),
 ('<|endoftext|>', 50256)]

In [51]:
# Short prompt used to demonstrate tokenization and next-token prediction.
text = "Hullo, my name is"

In [52]:
# The prompt as readable string pieces; note the <|endoftext|> token placed in front.
print(gpt2.to_str_tokens(text))
# The same pieces as integer token ids, in a tensor with a leading batch axis.
print(gpt2.to_tokens(text))

['<|endoftext|>', 'H', 'ull', 'o', ',', ' my', ' name', ' is']
tensor([[50256,    39,   724,    78,    11,   616,  1438,   318]],
       device='cuda:0')


## Output: "What comes out?"

In [53]:
# Run the prompt through the model. The output carries a leading batch axis of
# size 1 followed by one row of vocabulary logits per input position.
# squeeze(0) removes only that batch axis; a bare squeeze() would also collapse
# any other size-1 axis (e.g. a single-position input) and break the 2-D
# indexing `predictions[-1, :]` used afterwards.
predictions = gpt2(gpt2.to_tokens(text)).squeeze(0)
# Turn every row of logits into a probability distribution over the vocabulary.
predictions = predictions.softmax(-1)
print(predictions.shape)

torch.Size([8, 50257])


In [54]:
# The row for the final position is the model's distribution over the next
# token; keep its five most probable entries (values and token ids).
best_predictions = t.topk(predictions[-1], k=5)
best_predictions

torch.return_types.topk(
values=tensor([0.0230, 0.0181, 0.0119, 0.0113, 0.0113], device='cuda:0'),
indices=tensor([1757,  367, 2940, 3619, 3899], device='cuda:0'))

In [55]:
# Decode each of the top token ids back to text. The surrounding quotes make
# leading spaces in the decoded tokens visible.
for token_id in best_predictions.indices:
    decoded = gpt2.tokenizer.decode([token_id])
    print(f'"{decoded}"')

" John"
" H"
" Mark"
" Jack"
" Michael"


# Longer example

In [56]:
# Longer prompt that mentions the made-up name "Quomatarus" twice.
# Adjacent string literals are joined into a single string.
text = (
    "This is a story about Quomatarus."
    " When one day Quomatarus decided"
)

In [57]:
# Tokenize the longer prompt twice: as a tensor of ids (for the model) and as
# string pieces (for labelling plots and visualisations).
tokens = gpt2.to_tokens(text)
str_tokens = gpt2.to_str_tokens(text)

# One line for the string pieces, one for the id tensor's shape.
print(str_tokens, tokens.shape, sep="\n")

['<|endoftext|>', 'This', ' is', ' a', ' story', ' about', ' Qu', 'om', 'atar', 'us', '.', ' When', ' one', ' day', ' Qu', 'om', 'atar', 'us', ' decided']
torch.Size([1, 19])


In [59]:
# Run the model once and keep its intermediate activations alongside the logits.
# remove_batch_dim=True: presumably stores cached activations without the leading
# batch axis (the attention pattern inspected later has shape [12, 19, 19]) -- confirm.
logits, cache = gpt2.run_with_cache(tokens, remove_batch_dim=True)
# Printing the cache lists the names of every recorded activation.
print(cache)

ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'blocks.0.hook_resid_post', 'blocks.1.hook_resid_pre', 'blocks.1.ln1.hook_scale', 'blocks.1.ln1.hook_normalized', 'blocks.1.attn.hook_q', 'blocks.1.attn.hook_k', 'blocks.1.attn.hook_v', 'blocks.1.attn.hook_attn_scores', 'blocks.1.attn.hook_pattern', 'blocks.1.attn.hook_z', 'blocks.1.hook_attn_out', 'blocks.1.hook_resid_mid', 'blocks.1.ln2.hook_scale', 'blocks.1.ln2.hook_normalized', 'blocks.1.mlp.hook_pre', 'blocks.1.mlp.hook_post', 'blocks.1.hook_mlp_out', 'blocks.1.hook_resid_post', 'blocks.2.hook_re

In [60]:
# Convert logits to probabilities: drop the size-1 axes first, then normalise
# over the last (vocabulary) axis.
probs = logits.squeeze().softmax(dim=-1)

In [61]:
# How well did it predict the actual tokens?
# x[i, j] is the probability that position i assigns to the token found at
# position j, so the first superdiagonal (j = i + 1) pairs every position with
# the token that actually comes next.
x = probs.squeeze()[:, tokens.squeeze()]
predictions = t.diag(x, diagonal=1)
predictions.round(decimals=3)

tensor([0.0080, 0.1840, 0.3330, 0.0130, 0.4980, 0.0000, 0.0000, 0.0000, 0.0000,
        0.1410, 0.0040, 0.0090, 0.0150, 0.0260, 0.9140, 0.9800, 0.9960, 0.0090],
       device='cuda:0')

In [62]:
# Plot it! Which tokens did it do well on? Which poorly? Why?
# Each point is the probability the model gave to the token that actually came
# next; hovering shows that token (str_tokens[1:] skips the first token, which
# is never a prediction target). Title and axis labels are set so the figure
# can be read on its own.
px.line(
    predictions.cpu(),
    hover_name=str_tokens[1:],
    title="Probability assigned to the actual next token",
    labels={"index": "position", "value": "probability"},
)

## Attention

In [63]:
# Shape of the block-0 attention pattern: 12 heads, then two axes of length 19
# (the number of tokens) -- presumably (head, query position, key position); confirm.
cache["pattern", 0, "attn"].shape

torch.Size([12, 19, 19])

In [64]:
# Look at attention patterns.
# CircuitsVis is used in two forms in this notebook:
# cv.attention.attention_pattern (one head) and cv.attention.attention_patterns (many heads).

# Compare block 0 head 5 to block 5 head 5!
# Indexing with [5] selects head 5 and keeps its full 2-D pattern.
cv.attention.attention_pattern(cache["pattern", 0, "attn"][5], tokens=str_tokens)

In [65]:

# Attention patterns of every head in block 5.
attention_pattern = cache["pattern", 5, "attn"].squeeze()

# Label the heads H0, H1, ... using the size of the leading (head) axis rather
# than a hard-coded 12, so the labels stay correct for models with a different
# number of heads.
display(cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"H{i}" for i in range(attention_pattern.shape[0])],
))

# Sources

These are the main inspirations:

* https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp
* https://transformer-circuits.pub/2021/framework/index.html

Videos:

* https://neelnanda.io/transformer-tutorial