# P2020 Town Hall
## Google, 2024-09-17

![Machine Learning](https://imgs.xkcd.com/comics/machine_learning.png)

# Setup

Most of this is "borrowed" from existing tutorials and examples. See [Sources](#Sources).

To open this in Google Colab, click [here](https://colab.research.google.com/github/klao/town-hall-2024/blob/master/transformer_clean.ipynb).

In [None]:
import sys

IN_COLAB = 'google.colab' in sys.modules

if IN_COLAB:
    %pip install einops
    %pip install jaxtyping
    %pip install transformer_lens
    %pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
else:
    # See README.md for local setup
    pass

In [None]:
import os
import sys
import plotly.express as px
import torch as t
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
import numpy as np
import einops
from fancy_einsum import einsum
from jaxtyping import Int, Float
from typing import List, Optional, Tuple
import functools
from tqdm import tqdm
from IPython.display import display
import webbrowser
import gdown
from transformer_lens.hook_points import HookPoint
from transformer_lens import utils, HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
from transformer_lens.utils import get_corner
import circuitsvis as cv

# Disable gradient tracking: this notebook only runs inference, so skipping it saves time and memory
t.set_grad_enabled(False)

device = t.device("cuda" if t.cuda.is_available() else "cpu")

# Let's look at the model

In [None]:
gpt2 = HookedTransformer.from_pretrained("gpt2-small")

In [None]:
gpt2

In [None]:
gpt2.cfg

## Input: "What does this eat?" aka Tokenization

In [None]:
gpt2.tokenizer

In [None]:
vocab_sorted = sorted(gpt2.tokenizer.vocab.items(), key=lambda x: x[1])
vocab_sorted[:10]

In [None]:
vocab_sorted[-10:]

In [None]:
text = "Hullo, my name is"

In [None]:
print(gpt2.to_str_tokens(text))
print(gpt2.to_tokens(text))
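
GPT-2's tokenizer is a byte-level BPE, which is why an unusual word like "Hullo" comes out as several sub-word pieces. The sketch below shows the merge idea on a toy scale. The merge rules and the vocabulary are made up for illustration (they are not GPT-2's real ones), and the loop is a simplified version of BPE that applies each rule once, in priority order.

```python
def bpe_tokenize(word, merges):
    """Split a word into characters, then apply merge rules in priority order."""
    parts = list(word)
    for a, b in merges:
        merged = []
        i = 0
        while i < len(parts):
            # Merge every adjacent occurrence of the pair (a, b)
            if i + 1 < len(parts) and parts[i] == a and parts[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(parts[i])
                i += 1
        parts = merged
    return parts

# Hypothetical merge rules, highest priority first
toy_merges = [("l", "l"), ("H", "u"), ("ll", "o")]
# Hypothetical vocabulary mapping each piece to an integer id
toy_vocab = {"Hu": 0, "llo": 1}

toy_pieces = bpe_tokenize("Hullo", toy_merges)
toy_ids = [toy_vocab[p] for p in toy_pieces]
print(toy_pieces, toy_ids)
```

The string pieces play the role of `to_str_tokens` and the integer ids the role of `to_tokens`.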

## Output: "What comes out?"

In [None]:
predictions = gpt2(gpt2.to_tokens(text)).squeeze()
predictions = predictions.softmax(-1)
print(predictions.shape)

In [None]:
best_predictions = predictions[-1, :].topk(5)
best_predictions

In [None]:
for k in best_predictions.indices:
    print(f'"{gpt2.tokenizer.decode([k])}"')
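
The cells above turn the logits at the last position into probabilities and keep the five largest. Here is a minimal pure-Python sketch of those two steps, assuming a made-up four-token vocabulary and made-up logits (none of these values come from GPT-2).

```python
import math

def softmax(logits):
    """Convert raw scores into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def topk(probs, k):
    """Return (probability, index) pairs for the k largest entries."""
    return sorted(((p, i) for i, p in enumerate(probs)), reverse=True)[:k]

toy_vocab = [" Bob", " John", " the", "!"]  # hypothetical vocabulary
toy_logits = [2.0, 1.0, 0.5, -1.0]          # hypothetical last-position logits

toy_probs = softmax(toy_logits)
toy_best = topk(toy_probs, 2)
for p, i in toy_best:
    print(f'"{toy_vocab[i]}" {p:.3f}')
```

Softmax preserves the ordering of the logits, so the top-k tokens are the same before and after it; only the scale changes.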

# Longer example

In [None]:
text = ("This is a story about Quomatarus."
  + " When one day Quomatarus decided")

In [None]:
tokens = gpt2.to_tokens(text)
str_tokens = gpt2.to_str_tokens(text)
print(str_tokens)
print(tokens.shape)

In [None]:
logits, cache = gpt2.run_with_cache(tokens, remove_batch_dim=True)
print(logits.shape)
print(cache)

In [None]:
# Logits to probabilities

probs = logits.softmax(-1).squeeze()

In [None]:
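# How well did it predict the actual tokens?
# x[i, j] is the probability, predicted at position i, of the token at position j,
# so the first superdiagonal x[i, i+1] is the probability given to the true next token.
x = probs[:, tokens.squeeze()]
predictions = x.diag(1)
t.round(predictions, decimals=3)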
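
The indexing-plus-diagonal step is compact but easy to misread. The sketch below replays it in plain Python on a hypothetical 3-position, 4-token example (all numbers invented), and checks it against the more direct "look up the next token at each position" formulation.

```python
# Hypothetical toy setup: 3 positions, vocabulary of 4 token ids.
toy_probs = [
    [0.1, 0.7, 0.1, 0.1],      # distribution predicted after position 0
    [0.2, 0.2, 0.5, 0.1],      # distribution predicted after position 1
    [0.25, 0.25, 0.25, 0.25],  # after position 2 (no next token to compare with)
]
toy_tokens = [3, 1, 2]  # the token id actually present at each position

# Same as probs[:, tokens]: x[i][j] is the probability, at position i,
# of the token that sits at position j.
x = [[row[tok] for tok in toy_tokens] for row in toy_probs]

# Same as x.diag(1): the first superdiagonal, x[i][i + 1].
next_token_probs = [x[i][i + 1] for i in range(len(toy_tokens) - 1)]

# Direct formulation: at position i, look up the token at position i + 1.
direct = [toy_probs[i][toy_tokens[i + 1]] for i in range(len(toy_tokens) - 1)]

print(next_token_probs, direct)
```

The result has one entry fewer than the sequence, because the last position has no following token to score.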

In [None]:
# Plot it! Which tokens did it do well on? Which poorly? Why?
px.line(predictions.cpu(), hover_name=str_tokens[1:])

## Attention

In [None]:
cache["pattern", 0, "attn"].shape
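
The pattern tensor holds one square matrix per head: row i is a probability distribution over the positions that token i attends to. GPT-2 attention is causal, so a token can only attend to itself and earlier tokens. The sketch below builds such a matrix from invented scores for a single hypothetical head; it illustrates the shape and the masking only, not GPT-2's actual computation of the scores.

```python
import math

def causal_attention_pattern(scores):
    """Row-wise softmax where query i may only attend to keys j <= i."""
    pattern = []
    for i, row in enumerate(scores):
        visible = row[: i + 1]  # keys at or before the query position
        m = max(visible)
        exps = [math.exp(s - m) for s in visible]
        total = sum(exps)
        # Future positions get exactly zero attention
        pattern.append([e / total for e in exps] + [0.0] * (len(row) - i - 1))
    return pattern

# Hypothetical raw attention scores for one head over 3 tokens
toy_scores = [
    [0.0, 1.0, 2.0],
    [1.0, 0.0, 3.0],
    [2.0, 1.0, 0.0],
]

toy_pattern = causal_attention_pattern(toy_scores)
for row in toy_pattern:
    print([round(p, 3) for p in row])
```

The first row is always `[1, 0, 0, ...]`, since the first token has nothing but itself to attend to.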

In [None]:
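# Look at attention patterns
# circuitsvis offers cv.attention.attention_pattern (one head)
# and cv.attention.attention_patterns (all heads of a block)

# Compare block 0 head 5 to block 5 head 5!
# Pass both arguments by keyword so they cannot be bound in the wrong order
cv.attention.attention_pattern(tokens=str_tokens, attention=cache["pattern", 0, "attn"][5, :, :])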

In [None]:
# All heads of block 5; the cache has no batch dimension (remove_batch_dim=True)
attention_pattern = cache["pattern", 5, "attn"]

display(cv.attention.attention_patterns(
    tokens=str_tokens,
    attention=attention_pattern,
    attention_head_names=[f"H{i}" for i in range(gpt2.cfg.n_heads)],
))

# Sources

These are the main inspirations:

* https://arena-ch1-transformers.streamlit.app/[1.2]_Intro_to_Mech_Interp
* https://transformer-circuits.pub/2021/framework/index.html

Videos:

* https://neelnanda.io/transformer-tutorial