In [1]:
import plotly.io as pio
pio.renderers.default = "notebook_connected" # or use "browser" if you want plots to open with browser

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML, display

import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import HookedRootModule, HookPoint  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache
import circuitsvis as cv

# Saves computation time, since we don't need it for the contents of this notebook
torch.set_grad_enabled(False)

def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a tensor as a red/blue heatmap centred on zero and return the figure.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` before plotting.
        renderer: Plotly renderer name passed to `Figure.show`; `None` uses the default renderer.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.imshow` (e.g. `title`, `zmin`, `zmax`).

    Returns:
        The figure that was displayed.
    """
    # Build the figure once and reuse it for both display and the return value,
    # rather than constructing two identical figures.
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)
    return fig

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show a tensor as a line plot and return the figure.

    Args:
        tensor: Values to plot; converted with `utils.to_numpy` before plotting.
        renderer: Plotly renderer name passed to `Figure.show`; `None` uses the default renderer.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded unchanged to `px.line` (e.g. `title`, `x`, `log_y`, `markers`).

    Returns:
        The figure that was displayed.
    """
    # Build the figure once and reuse it for both display and the return value,
    # rather than constructing two identical figures.
    fig = px.line(utils.to_numpy(tensor), labels={"x": xaxis, "y": yaxis}, **kwargs)
    fig.show(renderer)
    return fig

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a scatter plot of `y` against `x` and return the figure.

    Args:
        x: Horizontal coordinates; converted with `utils.to_numpy`.
        y: Vertical coordinates; converted with `utils.to_numpy`.
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        caxis: Label for the colour axis.
        renderer: Plotly renderer name passed to `Figure.show`; `None` uses the default renderer.
        **kwargs: Forwarded unchanged to `px.scatter` (e.g. `title`, `hover_name`).

    Returns:
        The figure that was displayed.
    """
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    # Build the figure once and reuse it for both display and the return value,
    # rather than constructing two identical figures.
    fig = px.scatter(y=y, x=x, labels={"x": xaxis, "y": yaxis, "color": caxis}, **kwargs)
    fig.show(renderer)
    return fig

device = "cuda" if torch.cuda.is_available() else "cpu"

model = HookedTransformer.from_pretrained("gpt2-small", device=device)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer


# Tokenization

In [2]:
example_text = "The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph."
example_text_str_tokens = model.to_str_tokens(example_text)
print(example_text_str_tokens)

['<|endoftext|>', 'The', ' first', ' thing', ' you', ' need', ' to', ' figure', ' out', ' is', ' *', 'how', '*', ' things', ' are', ' token', 'ized', '.', ' `', 'model', '.', 'to', '_', 'str', '_', 't', 'ok', 'ens', '`', ' splits', ' a', ' string', ' into', ' the', ' tokens', ' *', 'as', ' a', ' list', ' of', ' sub', 'strings', '*,', ' and', ' so', ' lets', ' you', ' explore', ' what', ' the', ' text', ' looks', ' like', '.', ' To', ' demonstrate', ' this', ',', ' let', "'s", ' use', ' it', ' on', ' this', ' paragraph', '.']


In [3]:
example_text_tokens = model.to_tokens(example_text)
print(example_text_tokens)

tensor([[50256,   464,   717,  1517,   345,   761,   284,  3785,   503,   318,
          1635,  4919,     9,  1243,   389, 11241,  1143,    13,  4600, 19849,
            13,  1462,    62,  2536,    62,    83,   482,   641,    63, 30778,
           257,  4731,   656,   262, 16326,  1635,   292,   257,  1351,   286,
           850, 37336, 25666,   290,   523,  8781,   345,  7301,   644,   262,
          2420,  3073,   588,    13,  1675, 10176,   428,    11,  1309,   338,
           779,   340,   319,   428,  7322,    13]])


In [4]:
example_multi_text = ["The cat sat on the mat.", "The cat sat on the mat really hard."]
example_multi_text_tokens = model.to_tokens(example_multi_text)
print(example_multi_text_tokens)

tensor([[50256,   464,  3797,  3332,   319,   262,  2603,    13, 50256, 50256],
        [50256,   464,  3797,  3332,   319,   262,  2603,  1107,  1327,    13]])


In [5]:
# Run the model on a single sentence and turn the logits into probabilities over the vocabulary.
cat_text = "The cat sat on the mat."
cat_logits = model(cat_text)
cat_probs = cat_logits.softmax(dim=-1)
print(f"Probability tensor shape [batch, position, d_vocab] == {cat_probs.shape}")

# Probability assigned to " The" at the last position (index -1).
capital_the_token_index = model.to_single_token(" The")
print(f"| The| probability: {cat_probs[0, -1, capital_the_token_index].item():.2%}")

# Probability assigned to "." at the second-to-last position (index -2).
# Uses its own variable so `capital_the_token_index` keeps meaning the " The" token.
period_token_index = model.to_single_token(".")
print(f"| .| probability: {cat_probs[0, -2, period_token_index].item():.2%}")

Probability tensor shape [batch, position, d_vocab] == torch.Size([1, 8, 50257])
| The| probability: 11.98%
| .| probability: 5.55%


In [6]:
print(f"Token 256 - the most common pair of ASCII characters: |{model.to_string(256)}|")
# Squeeze means to remove dimensions of length 1. 
# Here, that removes the dummy batch dimension so it's a rank 1 tensor and returns a string
# Rank 2 tensors map to a list of strings
print(f"De-Tokenizing the example tokens: {model.to_string(example_text_tokens.squeeze())}")

Token 256 - the most common pair of ASCII characters: | t|
De-Tokenizing the example tokens: <|endoftext|>The first thing you need to figure out is *how* things are tokenized. `model.to_str_tokens` splits a string into the tokens *as a list of substrings*, and so lets you explore what the text looks like. To demonstrate this, let's use it on this paragraph.


In [7]:
print("With BOS:", model.get_token_position(" cat", "The cat sat on the mat"))
print("Without BOS:", model.get_token_position(" cat", "The cat sat on the mat", prepend_bos=False))

With BOS: 2
Without BOS: 1


In [8]:
print("First occurence", model.get_token_position(
    " cat", 
    "The cat sat on the mat. The mat sat on the cat.", 
    mode="first"))
print("Final occurence", model.get_token_position(
    " cat", 
    "The cat sat on the mat. The mat sat on the cat.", 
    mode="last"))

First occurence 2
Final occurence 13


In [9]:
print("Logits shape by default (with BOS)", model("Hello World").shape)
print("Logits shape with BOS", model("Hello World", prepend_bos=True).shape)
print("Logits shape without BOS - only 2 positions!", model("Hello World", prepend_bos=False).shape)

Logits shape by default (with BOS) torch.Size([1, 3, 50257])
Logits shape with BOS torch.Size([1, 3, 50257])
Logits shape without BOS - only 2 positions! torch.Size([1, 2, 50257])


In [10]:
# Same prompt and the same two candidate tokens are used for both runs;
# only the `prepend_bos` flag differs between them.
ioi_prompt = "Claire and Mary went to the shops, then Mary gave a bottle of milk to"
mary_token = model.to_single_token(" Mary")
claire_token = model.to_single_token(" Claire")

# Run 1: BOS token prepended.
ioi_logits_with_bos = model(ioi_prompt, prepend_bos=True)
final_logits_with_bos = ioi_logits_with_bos[0, -1]
mary_logit_with_bos = final_logits_with_bos[mary_token].item()
claire_logit_with_bos = final_logits_with_bos[claire_token].item()
logit_diff_with_bos = claire_logit_with_bos - mary_logit_with_bos
print(f"Logit difference with BOS: {logit_diff_with_bos:.3f}")

# Run 2: no BOS token.
ioi_logits_without_bos = model(ioi_prompt, prepend_bos=False)
final_logits_without_bos = ioi_logits_without_bos[0, -1]
mary_logit_without_bos = final_logits_without_bos[mary_token].item()
claire_logit_without_bos = final_logits_without_bos[claire_token].item()
logit_diff_without_bos = claire_logit_without_bos - mary_logit_without_bos
print(f"Logit difference without BOS: {logit_diff_without_bos:.3f}")

Logit difference with BOS: 6.754
Logit difference without BOS: 2.782


In [11]:
print(f"| Claire| -> {model.to_str_tokens(' Claire', prepend_bos=False)}")
print(f"|Claire| -> {model.to_str_tokens('Claire', prepend_bos=False)}")

| Claire| -> [' Claire']
|Claire| -> ['Cl', 'aire']


In [13]:
model.tokenizer

PreTrainedTokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_len=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<|endoftext|>'})

Key tokenization takeaways:
- There are a number of utility functions for moving between tokens/strings and finding tokens in strings
- Use single token names when doing tasks to make interpretation easier
- Beware the BOS which can meaningfully change the output of the model (BOS is used in some models and not in others)

# Factored Matrix Class

Low rank factored matrices are useful for transformer interpretability because they crop up in the W_OV circuit/interpretation. Projecting the W_OV output to the original input space is a useful way to interpret the model and this matrix is itself a low rank matrix.

Where to move information from: W_QK = W_Q W_K^T determines the attention pattern, which is a map from residual to residual via `residual @ W_QK @ residual.T` (the attention pattern acts on the position dimension of the [position, d_model] shaped residual).

What information to move: W_OV = W_V @ W_O.

In [28]:
# Two random factors: A is 5x2 and B is 2x5, so their product AB is 5x5 with rank at most 2.
# No seed is set in the visible cells, so the printed values change from run to run.
A = torch.randn(5, 2)
B = torch.randn(2, 5)
AB = A @ B
# Wrap the same two factors in a FactoredMatrix instead of only keeping their product.
AB_factor = FactoredMatrix(A, B)
# The two norms are expected to match if the factored form represents the same matrix as AB.
print("Norms:")
print(AB.norm())
print(AB_factor.norm())

# NOTE(review): ldim/rdim/mdim look like the outer (left, right) and shared inner dimensions of the factors — confirm against the FactoredMatrix docs.
print(f"Right dimension: {AB_factor.rdim}, Left dimension: {AB_factor.ldim}, Hidden dimension: {AB_factor.mdim}")

Norms:
tensor(10.0432)
tensor(10.0432)
Right dimension: 5, Left dimension: 5, Hidden dimension: 2


In [29]:
print("Eigenvalues:")
print(torch.linalg.eig(AB).eigenvalues)
print(AB_factor.eigenvalues)
print()
print("Singular Values:")
print(torch.linalg.svd(AB).S)
print(AB_factor.S)

Eigenvalues:
tensor([ 4.7684e-07+0.j, -4.3484e+00+0.j, -2.2422e+00+0.j, -8.1658e-09+0.j,
        -2.1347e-07+0.j])
tensor([-4.3484+0.j, -2.2422+0.j])

Singular Values:
tensor([7.8301e+00, 6.2893e+00, 3.2326e-07, 1.6823e-07, 2.1128e-09])
tensor([7.8301, 6.2893])


In [30]:
C = torch.randn(5, 300)
ABC = AB @ C
ABC_factor = AB_factor @ C
print("Unfactored:", ABC.shape, ABC.norm())
print("Factored:", ABC_factor.shape, ABC_factor.norm())
print(f"Right dimension: {ABC_factor.rdim}, Left dimension: {ABC_factor.ldim}, Hidden dimension: {ABC_factor.mdim}")

Unfactored: torch.Size([5, 300]) tensor(166.8944)
Factored: torch.Size([5, 300]) tensor(166.8944)
Right dimension: 300, Left dimension: 5, Hidden dimension: 2


In [31]:
AB_unfactored = AB_factor.AB
print(torch.isclose(AB_unfactored, AB).all())

tensor(True)


# EigenValue Copying Scores

In [32]:
OV_circuit_all_heads = model.OV
print(OV_circuit_all_heads)

FactoredMatrix: Shape(torch.Size([12, 12, 768, 768])), Hidden Dim(64)


In [33]:
OV_circuit_all_heads_eigenvalues = OV_circuit_all_heads.eigenvalues 
print(OV_circuit_all_heads_eigenvalues.shape)
print(OV_circuit_all_heads_eigenvalues.dtype)

torch.Size([12, 12, 64])
torch.complex64


In [35]:
# Copying score per head: real part of the summed eigenvalues divided by the sum of their magnitudes.
# The ratio lies in [-1, 1] (which is why zmin/zmax are set to -1/1 below) and equals +1 only when
# every eigenvalue is a non-negative real number.
OV_copying_score = OV_circuit_all_heads_eigenvalues.sum(dim=-1).real / OV_circuit_all_heads_eigenvalues.abs().sum(dim=-1)
# Summing over the last (eigenvalue) axis leaves one score per entry of the two leading axes, plotted as a heatmap.
fig = imshow(utils.to_numpy(OV_copying_score), xaxis="Head", yaxis="Layer", title="OV Copying Score for each head in GPT-2 Small", zmax=1.0, zmin=-1.0)

In [37]:
fig = scatter(x=OV_circuit_all_heads_eigenvalues[-1, -1, :].real, y=OV_circuit_all_heads_eigenvalues[-1, -1, :].imag, title="Eigenvalues of Head L11H11 of GPT-2 Small", xaxis="Real", yaxis="Imaginary")

In [38]:
full_OV_circuit = model.embed.W_E @ OV_circuit_all_heads @ model.unembed.W_U
print(full_OV_circuit)

FactoredMatrix: Shape(torch.Size([12, 12, 50257, 50257])), Hidden Dim(64)


In [39]:
full_OV_circuit_eigenvalues = full_OV_circuit.eigenvalues
print(full_OV_circuit_eigenvalues.shape)
print(full_OV_circuit_eigenvalues.dtype)

torch.Size([12, 12, 64])
torch.complex64


In [41]:
# Same copying-score formula as for the OV circuit, applied to the eigenvalues of the
# full circuit (W_E @ OV @ W_U): real part of the eigenvalue sum over the sum of magnitudes.
full_OV_copying_score = full_OV_circuit_eigenvalues.sum(dim=-1).real / full_OV_circuit_eigenvalues.abs().sum(dim=-1)
# The title says "Full" so this heatmap can be told apart from the earlier OV-only plot.
fig = imshow(utils.to_numpy(full_OV_copying_score), xaxis="Head", yaxis="Layer", title="Full OV Copying Score for each head in GPT-2 Small", zmax=1.0, zmin=-1.0)

In [43]:
fig = scatter(x=full_OV_copying_score.flatten(), y=OV_copying_score.flatten(), hover_name=[f"L{layer}H{head}" for layer in range(12) for head in range(12)], title="OV Copying Score for each head in GPT-2 Small", xaxis="Full OV Copying Score", yaxis="OV Copying Score")

Factored Matrix Class Notes:
- Factored Matrix Class makes it faster to do matrix operations with large but low rank matrices
- This includes calculating eigenvalues for the W_OV circuit and the full W_OV circuit.

Side note:
- Positive eigenvalues in the W_OV circuit are an indication of copying behavior.

# Generating Text

In [49]:
model.generate("(CNN) President Trump caught in embarrassing new scandal\n", max_new_tokens=50, temperature=0.7, prepend_bos=True)

  0%|          | 0/50 [00:00<?, ?it/s]

"(CNN) President Trump caught in embarrassing new scandal\n\nWhat's the amount of money that Trump paid for his inaugural balls after he took office?\n\nActually, there was $3 billion.\n\nIf this were an accounting book, we'd have a very different set of numbers.\n\n"

# Hook Points

In [50]:
from transformer_lens.hook_points import HookedRootModule, HookPoint

class SquareThenAdd(nn.Module):
    """Toy layer that squares its input and adds a learnable offset.

    The squared value is passed through `hook_square` so it can be read
    (or replaced) from outside the module.
    """

    def __init__(self, offset):
        """Store `offset` as a parameter and create the hook point for the squared value."""
        super().__init__()
        self.offset = nn.Parameter(torch.tensor(offset))
        self.hook_square = HookPoint()

    def forward(self, x):
        """Return `offset + x * x`, routing the squared term through `hook_square`."""
        squared = x * x
        # With no hook attached the hook point hands the value back unchanged.
        squared = self.hook_square(squared)
        return self.offset + squared

class TwoLayerModel(HookedRootModule):
    """Two stacked `SquareThenAdd` layers with hook points on the input, midpoint and output."""

    def __init__(self):
        """Create both layers and the three top-level hook points, then register them."""
        super().__init__()
        self.layer1 = SquareThenAdd(3.0)
        self.layer2 = SquareThenAdd(-4.0)
        self.hook_in = HookPoint()
        self.hook_mid = HookPoint()
        self.hook_out = HookPoint()

        # HookedRootModule's setup builds the internal dictionary of modules
        # and hooks and assigns each hook its name; it must run after every
        # submodule and hook point has been created.
        super().setup()

    def forward(self, x):
        """Apply layer1 then layer2, passing the input and each layer output through a hook."""
        # Each hook point returns its argument unchanged unless a hook has been
        # attached to modify it, so these calls only expose the intermediate values.
        hooked_input = self.hook_in(x)
        hidden = self.layer1(hooked_input)
        hidden = self.hook_mid(hidden)
        result = self.layer2(hidden)
        return self.hook_out(result)

# NOTE: this rebinds `model`, which until this point referred to the GPT-2 small
# HookedTransformer loaded in the first cell; cells below operate on the toy model.
model = TwoLayerModel()

In [56]:
out, cache = model.run_with_cache(torch.tensor(5.0))
print("Model output:", out.item())
for key in cache:
    print(f"Value cached at hook {key}", cache[key].item())

Model output: 780.0
Value cached at hook hook_in 5.0
Value cached at hook layer1.hook_square 25.0
Value cached at hook hook_mid 28.0
Value cached at hook layer2.hook_square 784.0
Value cached at hook hook_out 780.0


In [57]:
def set_to_zero_hook(tensor, hook):
    """Forward hook that replaces the hooked activation with zeros.

    Args:
        tensor: The activation passing through the hook point.
        hook: The hook point being run; its `name` is printed for visibility.

    Returns:
        A zero tensor with the same shape, dtype and device as `tensor`.
    """
    print(hook.name)
    # zeros_like keeps the replacement compatible with the original activation
    # (shape, dtype, device) instead of always returning a CPU float32 scalar.
    return torch.zeros_like(tensor)

# Run the model with `set_to_zero_hook` attached to layer2's squared value for this one call.
# The printed label names the hook that is actually used: layer2.hook_square.
print(
    "Output after intervening on layer2.hook_square",
    model.run_with_hooks(
        torch.tensor(5.0), fwd_hooks=[("layer2.hook_square", set_to_zero_hook)]
    ).item(),
)

layer2.hook_square
Output after intervening on layer2.hook_scaled -4.0


In [58]:
model(torch.tensor(5.0))

tensor(780.)

# Trained Checkpoint Models

In [59]:
from transformer_lens.loading_from_pretrained import get_checkpoint_labels
# Checkpoint schedules drawn with a logarithmic y axis.
for model_name in ("attn-only-2l", "solu-12l", "stanford-gpt2-small-a"):
    checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(model_name)
    line(
        checkpoint_labels,
        xaxis="Checkpoint Index",
        yaxis=f"Checkpoint Value ({checkpoint_label_type})",
        title=f"Checkpoint Values for {model_name} (Log scale)",
        log_y=True,
        markers=True,
    )

# Checkpoint schedules drawn with a linear y axis.
for model_name in ("solu-1l-pile", "solu-6l-pile"):
    checkpoint_labels, checkpoint_label_type = get_checkpoint_labels(model_name)
    line(
        checkpoint_labels,
        xaxis="Checkpoint Index",
        yaxis=f"Checkpoint Value ({checkpoint_label_type})",
        title=f"Checkpoint Values for {model_name} (Linear scale)",
        log_y=False,
        markers=True,
    )

In [60]:
from transformer_lens import evals
# We use the two layer model with SoLU activations, chosen fairly arbitrarily as being both small (so fast to download and keep in memory) and pretty good at the induction task.
model_name = "solu-2l"
# We can load a model from a checkpoint by specifying the checkpoint_index, -1 means the final checkpoint
checkpoint_indices = [10, 25, 35, 60, -1]
checkpointed_models = []
tokens_trained_on = []
induction_losses = []

In [62]:
# Start from empty result lists so that re-running this cell (for example after a
# failed or partial run) does not append duplicate entries to earlier results.
checkpointed_models = []
tokens_trained_on = []
induction_losses = []

for index in checkpoint_indices:
    # Load the model from the relevant checkpoint by index
    model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device = "cpu")
    checkpointed_models.append(model_for_this_checkpoint)

    # Read the checkpoint value from the loaded model's config; it becomes the x axis of the loss plot
    tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
    tokens_trained_on.append(tokens_seen_for_this_checkpoint)

    # Evaluate the induction loss of this checkpoint and keep it as a plain Python float
    induction_loss_for_this_checkpoint = evals.induction_loss(model_for_this_checkpoint).item()
    induction_losses.append(induction_loss_for_this_checkpoint)

RuntimeError: Attempting to deserialize object on a CUDA device but torch.cuda.is_available() is False. If you are running on a CPU-only machine, please use torch.load with map_location=torch.device('cpu') to map your storages to the CPU.

In [None]:
line(induction_losses, x=tokens_trained_on, xaxis="Tokens Trained On", yaxis="Induction Loss", title="Induction Loss over training: solu-2l", markers=True, log_x=True)