## Sources

- Source [paper](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html)

- Transformer interpretability [lib](https://github.com/neelnanda-io/TransformerLens)

- This code is mostly taken from [here](https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb)

- And from [here](https://colab.research.google.com/github/neelnanda-io/Easy-Transformer/blob/main/Exploratory_Analysis_Demo.ipynb#scrollTo=ep-DB-c05OEZ)

# Setup

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
!pip install torchtyping
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    %pip install git+https://github.com/neelnanda-io/PySvelte.git
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torchtyping
  Downloading torchtyping-0.1.4-py3-none-any.whl (17 kB)
Collecting typeguard>=2.11.1
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Installing collected packages: typeguard, torchtyping
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.7.1
    Uninstalling typeguard-2.7.1:
      Successfully uninstalled typeguard-2.7.1
Successfully installed torchtyping-0.1.4 typeguard-2.13.3
Running as a Colab notebook
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting git+https://github.com/neelnanda-io/TransformerLens.git
  Cloning https://github.com/neelnanda-io/TransformerLens.git to /tmp/pip-req-build-dxl896ng
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/TransformerLens.git /tmp/pip-req-build-dxl896ng
  Resolved https://github.com/neelnand

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
# NOTE: DEVELOPMENT_MODE is assigned False in the setup cell, so `not DEVELOPMENT_MODE`
# is True and the "colab" renderer is chosen even outside Colab unless that flag is flipped.
if IN_COLAB or not DEVELOPMENT_MODE:
    pio.renderers.default = "colab"
else:
    pio.renderers.default = "notebook_connected"
# Echo the selected renderer so the choice is visible in the cell output.
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
import circuitsvis as cv
# Testing that the library works
# (renders circuitsvis' built-in hello example as a smoke test of the install)
cv.examples.hello("Connor")

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# Turn autograd off globally: the cells in this notebook only run forward passes.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f955ef73820>

In [None]:
from transformer_lens import evals
import matplotlib.pyplot as plt
import collections
import plotly.graph_objects as go

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Show `tensor` as a red/blue heatmap whose colour scale is centred on zero.

    Args:
        tensor: Value to plot; converted with `utils.to_numpy` first.
        renderer: Renderer name handed to the figure's `show` (None = default).
        xaxis: Label for the x axis.
        yaxis: Label for the y axis.
        **kwargs: Forwarded to `px.imshow`. A `labels` dict, if given, is merged
            over the axis labels (its entries win) instead of clashing with them.
    """
    # Callers in this notebook pass `labels={...}` directly. Forwarding that next
    # to our own `labels=` would raise "got multiple values for keyword argument".
    labels = {"x": xaxis, "y": yaxis}
    labels.update(kwargs.pop("labels", None) or {})
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels=labels,
        **kwargs,
    ).show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Draw `tensor` as a Plotly Express line chart and show it.

    `xaxis` / `yaxis` become the axis labels; extra keyword arguments are
    forwarded to `px.line`; `renderer` is handed to the figure's `show`.
    """
    values = utils.to_numpy(tensor)
    axis_labels = {"x": xaxis, "y": yaxis}
    figure = px.line(values, labels=axis_labels, **kwargs)
    figure.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Show a Plotly Express scatter plot of `y` against `x`.

    Both inputs are converted with `utils.to_numpy`. `xaxis`, `yaxis` and `caxis`
    label the x axis, y axis and colour axis; remaining keyword arguments are
    forwarded to `px.scatter`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    figure = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    figure.show(renderer)

In [None]:
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
# useful for sanity checks
# Load the pretrained "attn-only-2l" model through TransformerLens onto `device`.
# The boolean flags ask `from_pretrained` to post-process the weights at load time
# (LayerNorm folding, centering, attention-matrix refactoring) — see the
# TransformerLens documentation for their exact semantics.
model = HookedTransformer.from_pretrained(
    "attn-only-2l",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading (…)"model_final.pth";:   0%|          | 0.00/210M [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.04M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Loaded pretrained model attn-only-2l into HookedTransformer


In [None]:
# Repeated-sequence sanity check: after "A B C A B" the expected continuation is " C".
example_prompt = "A B C A B"
example_answer = " C"
# Prints the tokenization and the top predicted next tokens (see the output below).
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|BOS|>', 'A', ' B', ' C', ' A', ' B']
Tokenized answer: [' C']


Top 0th token. Logit: 15.36 Prob: 49.53% Token: | C|
Top 1th token. Logit: 13.97 Prob: 12.33% Token: | B|
Top 2th token. Logit: 13.03 Prob:  4.79% Token: | A|
Top 3th token. Logit: 12.41 Prob:  2.59% Token: |.|
Top 4th token. Logit: 12.20 Prob:  2.09% Token: | D|
Top 5th token. Logit: 11.93 Prob:  1.61% Token: |,|
Top 6th token. Logit: 11.62 Prob:  1.17% Token: | and|
Top 7th token. Logit: 11.48 Prob:  1.02% Token: |’|
Top 8th token. Logit: 11.45 Prob:  0.99% Token: | E|
Top 9th token. Logit: 11.18 Prob:  0.76% Token: | (|


In [None]:
# Indirect-object style prompt: the expected completion is the name that is not repeated.
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"
# Prints the tokenization and the top predicted next tokens (see the output below).
utils.test_prompt(example_prompt, example_answer, model, prepend_bos=True)

Tokenized prompt: ['<|BOS|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 14.52 Prob:  9.41% Token: | the|
Top 1th token. Logit: 13.43 Prob:  3.18% Token: | his|
Top 2th token. Logit: 13.43 Prob:  3.17% Token: | help|
Top 3th token. Logit: 13.42 Prob:  3.13% Token: | a|
Top 4th token. Logit: 13.32 Prob:  2.84% Token: | be|
Top 5th token. Logit: 13.13 Prob:  2.36% Token: | her|
Top 6th token. Logit: 12.78 Prob:  1.66% Token: | get|
Top 7th token. Logit: 12.66 Prob:  1.46% Token: | go|
Top 8th token. Logit: 12.64 Prob:  1.44% Token: | keep|
Top 9th token. Logit: 12.54 Prob:  1.30% Token: | make|


## Early Heads are Induction Heads(?!)

A really weird observation is that some of the early heads detecting duplicated tokens are induction heads, not just direct duplicate token heads. This is very weird! What's up with that? 

First off, what's an induction head? An induction head is an important type of attention head that can detect and continue repeated sequences. It is the second head in a two head induction circuit, which looks for previous copies of the current token and attends to the token *after* it, and then copies that to the current position and predicts that it will come next. They're enough of a big deal that [we wrote a whole paper on them](https://transformer-circuits.pub/2022/in-context-learning-and-induction-heads/index.html).

![](https://pbs.twimg.com/media/FNWAzXjVEAEOGRe.jpg)

Second, why is it surprising that they come up here? It's surprising because it feels like overkill. The model doesn't care about *what* token comes after the first copy of the subject, just that it's duplicated. And it already has simpler duplicate token heads. My best guess is that it just already had induction heads around and that, in addition to their main function, they *also* only activate on duplicated tokens. So it was useful to repurpose this existing machinery. 

This suggests that as we look for circuits in larger models life may get more and more complicated, as components in simpler circuits get repurposed and built upon. 

In [None]:
import pysvelte

In [None]:
from jaxtyping import Float, Int

def visualize_attention_patterns(
    model: torch.nn.Module,
    heads: Union[List[int], int, Int[torch.Tensor, "heads"]],
    local_cache: Optional[ActivationCache]=None,
    local_tokens: Optional[torch.Tensor]=None,
    title: str=""):
    """Render the attention patterns of selected heads with PySvelte.

    Args:
        model: Model whose `cfg.n_heads` decodes the flat head indices and whose
            `to_str_tokens` produces the token labels.
        heads: One flat head index, or a list / tensor of them, each in
            [0, n_layers * n_heads). Index `h` maps to layer `h // n_heads`,
            head `h % n_heads`.
        local_cache: Activation cache to read `["attn", layer]` from. If omitted,
            a notebook-level variable named `cache` is used when one exists.
        local_tokens: Tokens to label the plot with. If omitted, the first row of
            a notebook-level variable named `tokens` is used when one exists.
        title: Heading displayed above the visualisation.

    Raises:
        ValueError: if no cache / tokens are available, or `heads` is empty.
    """
    # Heads are given as a list of integers or a single integer in [0, n_layers * n_heads)
    if isinstance(heads, int):
        heads = [heads]
    elif isinstance(heads, (list, torch.Tensor)):
        # atleast_1d lets a 0-d tensor holding a single head index be iterated too.
        heads = np.atleast_1d(utils.to_numpy(heads))
    # Cache defaults to a notebook-level `cache`; fail with a clear message rather
    # than a NameError when no such variable has been defined.
    if local_cache is None:
        local_cache = globals().get("cache")
        if local_cache is None:
            raise ValueError(
                "local_cache was not given and no global `cache` is defined; "
                "pass the cache returned by model.run_with_cache(...)"
            )
    # Tokens default to the first row of a notebook-level `tokens` (including the BOS token)
    if local_tokens is None:
        global_tokens = globals().get("tokens")
        if global_tokens is None:
            raise ValueError(
                "local_tokens was not given and no global `tokens` is defined; "
                "pass the tokens that produced local_cache"
            )
        # The tokens of the first prompt
        local_tokens = global_tokens[0]

    labels = []
    patterns = []
    # Only the first batch element of the cache is visualised.
    batch_index = 0
    for head in heads:
        layer = head // model.cfg.n_heads
        head_index = head % model.cfg.n_heads
        # Get the attention patterns for the head
        # Attention patterns have shape [batch, head_index, query_pos, key_pos]
        patterns.append(local_cache["attn", layer][batch_index, head_index])
        labels.append(f"L{layer}H{head_index}")
    if not patterns:
        # torch.stack would otherwise fail with an opaque error on an empty list.
        raise ValueError("heads must contain at least one head index")
    str_tokens = model.to_str_tokens(local_tokens)
    # Stack the per-head [query_pos, key_pos] patterns along a new trailing axis.
    patterns = torch.stack(patterns, dim=-1)
    # Plot the attention patterns
    attention_vis = pysvelte.AttentionMulti(attention=patterns, tokens=str_tokens, head_labels=labels)
    display(HTML(f"<h3>{title}</h3>"))
    attention_vis.show()

In [None]:
def prev_token_hook(pattern, hook):
    """Store, per head, the mean of the offset-1 diagonal of `pattern`.

    The diagonal is taken over the last two axes so that each entry pairs a
    position on the second-to-last axis with the position one earlier on the
    last axis. The per-head mean (over batch and diagonal) is written into row
    `hook.layer()` of the notebook-level `prev_token_scores`.
    """
    layer_idx = hook.layer()
    one_back = pattern.diagonal(offset=1, dim1=-1, dim2=-2)
    per_head_mean = einops.reduce(one_back, "batch head_index diagonal -> head_index", "mean")
    prev_token_scores[layer_idx] = per_head_mean

def duplicate_token_hook(pattern, hook):
    """Store, per head, the mean of the offset-`seq_len` diagonal of `pattern`.

    Each diagonal entry pairs a position on the second-to-last axis with the
    position `seq_len` earlier on the last axis (`seq_len` is read from the
    notebook namespace at call time). The per-head mean is written into row
    `hook.layer()` of the notebook-level `duplicate_token_scores`.
    """
    layer_idx = hook.layer()
    same_token_stripe = pattern.diagonal(offset=seq_len, dim1=-1, dim2=-2)
    per_head_mean = einops.reduce(same_token_stripe, "batch head_index diagonal -> head_index", "mean")
    duplicate_token_scores[layer_idx] = per_head_mean

def induction_hook(pattern, hook):
    """Store, per head, the mean of the offset-(`seq_len` - 1) diagonal of `pattern`.

    Each diagonal entry pairs a position on the second-to-last axis with the
    position `seq_len - 1` earlier on the last axis (`seq_len` is read from the
    notebook namespace at call time). The per-head mean is written into row
    `hook.layer()` of the notebook-level `induction_scores`.
    """
    layer_idx = hook.layer()
    induction_stripe = pattern.diagonal(offset=seq_len-1, dim1=-1, dim2=-2)
    per_head_mean = einops.reduce(induction_stripe, "batch head_index diagonal -> head_index", "mean")
    induction_scores[layer_idx] = per_head_mean

# A batch of random token ids, each row repeated twice back to back, so the
# second half of every sequence is an exact copy of the first half.
seq_len = 100
batch_size = 2
original_tokens = torch.randint(100, 20000, size=(batch_size, seq_len))
# Move to the notebook's `device` instead of calling .cuda(), which fails on CPU-only machines.
repeated_tokens = einops.repeat(original_tokens, "batch seq_len -> batch (2 seq_len)").to(device)

In [None]:
# Per-(layer, head) stores that the three hooks fill in, one row per layer.
# Allocated on `device` (not a hard-coded "cuda") so the cell also runs on CPU.
prev_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)

duplicate_token_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)

induction_scores = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=device)

# Hook the post-softmax attention pattern. The hooks average attention along a
# diagonal, which is only meaningful for probabilities, not for the raw
# pre-softmax values exposed by `hook_attn_scores`.
pattern_filter = lambda act_name: act_name.endswith("hook_pattern")
loss = model.run_with_hooks(repeated_tokens, return_type="loss", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])
print(utils.get_corner(prev_token_scores))
print(utils.get_corner(duplicate_token_scores))
print(utils.get_corner(induction_scores))

tensor([[ 2.7946, -2.0014, -2.8201],
        [-0.0425,  0.7624,  2.2916]], device='cuda:0')
tensor([[ -3.7780,  -7.0300,  -4.0856],
        [ -3.1768, -11.0179,  -5.0102]], device='cuda:0')
tensor([[ -3.8912,  -7.3408,  -5.3847],
        [ -4.1113, -11.8300,  -6.0766]], device='cuda:0')


In [None]:
def imshow(tensor, renderer=None, xaxis=None, yaxis=None, **kwargs):
    """Show `tensor` as a red/blue heatmap whose colour scale is centred on zero.

    This redefinition replaces the earlier `imshow`, so it keeps that version's
    `xaxis` / `yaxis` parameters while also accepting a `labels` dict directly.

    Args:
        tensor: Value to plot; converted with `utils.to_numpy` first.
        renderer: Renderer name handed to the figure's `show` (None = default).
        xaxis: Optional x-axis label.
        yaxis: Optional y-axis label.
        **kwargs: Forwarded to `px.imshow`. Entries of an explicit `labels`
            dict take precedence over `xaxis` / `yaxis`.
    """
    axis_labels = {axis: text for axis, text in (("x", xaxis), ("y", yaxis)) if text is not None}
    if axis_labels:
        # Only touch `labels` when an axis label was requested, so plain calls
        # forward exactly the keyword arguments they were given.
        kwargs["labels"] = {**axis_labels, **(kwargs.get("labels") or {})}
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show(renderer)

In [None]:
# One heatmap per score matrix, with heads on the x axis and layers on the y axis.
for score_matrix, plot_title in (
    (prev_token_scores, "Previous Token Scores"),
    (duplicate_token_scores, "Duplicate Token Scores"),
    (induction_scores, "Induction Head Scores"),
):
    imshow(score_matrix, labels={"x":"Head", "y":"Layer"}, title=plot_title)

In [None]:
# Short repeating prompt on which to inspect the candidate induction heads.
example_repeated_text = "A B A B A"
example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)
example_repeated_logits, example_repeated_cache = model.run_with_cache(example_repeated_tokens)
# Flat head indices; the visualiser decodes each as layer = h // n_heads, head = h % n_heads.
induction_head_labels = [1, 9, 14, 15]
# The model must be the first argument; omitting it shifts every other argument by one.
visualize_attention_patterns(model, induction_head_labels, example_repeated_cache, example_repeated_tokens, title="Induction Heads")

In [None]:
# Natural-language prompt whose second half repeats the start of the first.
example_repeated_text = "Why attention is all I need why attention is"
example_repeated_tokens = model.to_tokens(example_repeated_text, prepend_bos=True)
example_repeated_logits, example_repeated_cache = model.run_with_cache(example_repeated_tokens)
# Flat head indices; the visualiser decodes each as layer = h // n_heads, head = h % n_heads.
induction_head_labels = [1, 9, 14, 15]
# The model must be the first argument; omitting it shifts every other argument by one.
visualize_attention_patterns(model, induction_head_labels, example_repeated_cache, example_repeated_tokens, title="Induction Heads")

In [None]:
# useful for sanity checks
# Load the pretrained "attn-only-1l" model with the same loading flags as the
# 2-layer model above, so the two can be compared like for like.
model_1l = HookedTransformer.from_pretrained(
    "attn-only-1l",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading (…)"model_final.pth";:   0%|          | 0.00/205M [00:00<?, ?B/s]

Loaded pretrained model attn-only-1l into HookedTransformer


In [None]:

# Re-allocate the score stores with the 1-layer model's (n_layers, n_heads) shape.
# Allocated on `device` (not a hard-coded "cuda") so the cell also runs on CPU.
prev_token_scores = torch.zeros((model_1l.cfg.n_layers, model_1l.cfg.n_heads), device=device)
duplicate_token_scores = torch.zeros((model_1l.cfg.n_layers, model_1l.cfg.n_heads), device=device)
induction_scores = torch.zeros((model_1l.cfg.n_layers, model_1l.cfg.n_heads), device=device)

# Hook the post-softmax attention pattern. The hooks average attention along a
# diagonal, which is only meaningful for probabilities, not for the raw
# pre-softmax values exposed by `hook_attn_scores`.
pattern_filter = lambda act_name: act_name.endswith("hook_pattern")
loss = model_1l.run_with_hooks(repeated_tokens, return_type="loss", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])
print(utils.get_corner(prev_token_scores))
print(utils.get_corner(duplicate_token_scores))
print(utils.get_corner(induction_scores))

tensor([[-1.1187,  3.1710, -0.5237]], device='cuda:0')
tensor([[-2.0813, -5.1710, -7.2485]], device='cuda:0')
tensor([[-2.1976, -4.7857, -6.5926]], device='cuda:0')


In [None]:
# One heatmap per score matrix, with heads on the x axis and layers on the y axis.
for score_matrix, plot_title in (
    (prev_token_scores, "Previous Token Scores"),
    (duplicate_token_scores, "Duplicate Token Scores"),
    (induction_scores, "Induction Head Scores"),
):
    imshow(score_matrix, labels={"x":"Head", "y":"Layer"}, title=plot_title)

In [None]:
# Same kind of repeated-sequence check, this time against the 1-layer model.
example_prompt = "A B A B A"
example_answer = " B"
# Prints the tokenization and the top predicted next tokens (see the output below).
utils.test_prompt(example_prompt, example_answer, model_1l, prepend_bos=True)

Tokenized prompt: ['<|BOS|>', 'A', ' B', ' A', ' B', ' A']
Tokenized answer: [' B']


Top 0th token. Logit: 11.83 Prob:  5.65% Token: | B|
Top 1th token. Logit: 11.48 Prob:  3.96% Token: |.|
Top 2th token. Logit: 10.94 Prob:  2.31% Token: |1|
Top 3th token. Logit: 10.88 Prob:  2.18% Token: |2|
Top 4th token. Logit: 10.66 Prob:  1.75% Token: |3|
Top 5th token. Logit: 10.65 Prob:  1.73% Token: |4|
Top 6th token. Logit: 10.53 Prob:  1.54% Token: |,|
Top 7th token. Logit: 10.45 Prob:  1.42% Token: |-|
Top 8th token. Logit: 10.04 Prob:  0.95% Token: |5|
Top 9th token. Logit:  9.99 Prob:  0.90% Token: |/|


In [None]:
# useful for sanity checks
# Load the pretrained "attn-only-4l" model with the same loading flags as the
# smaller attention-only models above.
model_4l = HookedTransformer.from_pretrained(
    "attn-only-4l",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

Downloading (…)"model_final.pth";:   0%|          | 0.00/221M [00:00<?, ?B/s]

Loaded pretrained model attn-only-4l into HookedTransformer


In [None]:
# Re-allocate the score stores with the 4-layer model's (n_layers, n_heads) shape.
# Allocated on `device` (not a hard-coded "cuda") so the cell also runs on CPU.
prev_token_scores = torch.zeros((model_4l.cfg.n_layers, model_4l.cfg.n_heads), device=device)
duplicate_token_scores = torch.zeros((model_4l.cfg.n_layers, model_4l.cfg.n_heads), device=device)
induction_scores = torch.zeros((model_4l.cfg.n_layers, model_4l.cfg.n_heads), device=device)

# Hook the post-softmax attention pattern. The hooks average attention along a
# diagonal, which is only meaningful for probabilities, not for the raw
# pre-softmax values exposed by `hook_attn_scores`.
pattern_filter = lambda act_name: act_name.endswith("hook_pattern")
loss = model_4l.run_with_hooks(repeated_tokens, return_type="loss", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])
print(utils.get_corner(prev_token_scores))
print(utils.get_corner(duplicate_token_scores))
print(utils.get_corner(induction_scores))

tensor([[-2.5028, -1.9090, -2.2259],
        [-0.0322,  0.5963,  1.6917],
        [-0.8632, -0.1954, -1.8356]], device='cuda:0')
tensor([[-11.5681,   3.1836, -16.1644],
        [-25.8845, -10.6417,   5.2217],
        [-20.1879,  -9.5915,  -0.7754]], device='cuda:0')
tensor([[-14.9748,  -2.4154, -15.5631],
        [-25.5594, -10.2254,   1.6136],
        [-19.1228,  -9.5003,  -2.0012]], device='cuda:0')


Next, try a bigger model (Pythia-1B):

In [None]:
# useful for sanity checks
# Load the much larger "EleutherAI/pythia-1b" model with the same loading flags
# as the attention-only models above.
model_1b = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-1b",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
    device=device
)

Downloading (…)lve/main/config.json:   0%|          | 0.00/536 [00:00<?, ?B/s]

Downloading (…)"pytorch_model.bin";:   0%|          | 0.00/2.09G [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/394 [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/99.0 [00:00<?, ?B/s]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-1b into HookedTransformer


In [None]:
# Re-allocate the score stores with the Pythia-1B model's (n_layers, n_heads) shape.
# Allocated on `device` (not a hard-coded "cuda") so the cell also runs on CPU.
prev_token_scores = torch.zeros((model_1b.cfg.n_layers, model_1b.cfg.n_heads), device=device)
duplicate_token_scores = torch.zeros((model_1b.cfg.n_layers, model_1b.cfg.n_heads), device=device)
induction_scores = torch.zeros((model_1b.cfg.n_layers, model_1b.cfg.n_heads), device=device)

# Hook the post-softmax attention pattern. The hooks average attention along a
# diagonal, which is only meaningful for probabilities, not for the raw
# pre-softmax values exposed by `hook_attn_scores`.
pattern_filter = lambda act_name: act_name.endswith("hook_pattern")
loss = model_1b.run_with_hooks(repeated_tokens, return_type="loss", fwd_hooks=[(pattern_filter, prev_token_hook), (pattern_filter, duplicate_token_hook), (pattern_filter, induction_hook)])
print(utils.get_corner(prev_token_scores))
print(utils.get_corner(duplicate_token_scores))
print(utils.get_corner(induction_scores))

tensor([[3.8616, 2.5976, 1.0943],
        [1.3210, 0.3892, 0.3681],
        [1.7755, 2.8693, 2.3252]], device='cuda:0')
tensor([[4.8397, 3.7121, 1.8610],
        [1.1546, 0.2846, 0.3927],
        [2.3109, 2.6099, 2.7266]], device='cuda:0')
tensor([[ 0.1484,  0.2440,  0.2785],
        [-0.4092, -0.0606,  0.6635],
        [ 1.7494,  1.4450,  1.5451]], device='cuda:0')


In [None]:
# One heatmap per score matrix, with heads on the x axis and layers on the y axis.
for score_matrix, plot_title in (
    (prev_token_scores, "Previous Token Scores"),
    (duplicate_token_scores, "Duplicate Token Scores"),
    (induction_scores, "Induction Head Scores"),
):
    imshow(score_matrix, labels={"x":"Head", "y":"Layer"}, title=plot_title)

In [None]:
# Prompt whose second half repeats the start of the first, run through Pythia-1B.
example_repeated_text = "Why attention is all I need Why attention is"
example_repeated_tokens = model_1b.to_tokens(example_repeated_text, prepend_bos=True)
example_repeated_logits, example_repeated_cache = model_1b.run_with_cache(example_repeated_tokens)
# Flat head indices; visualize_attention_patterns decodes each one as
# layer = h // n_heads, head = h % n_heads using model_1b's config.
induction_head_labels = [
    0, 1, 2, 3, 4, 5, 6, 7,
    8, 9, 10, 11,
    15, 16, 17,
    109, 115, 116, 118]
visualize_attention_patterns(model_1b, induction_head_labels, example_repeated_cache, example_repeated_tokens, title="Induction Heads")

# Models with more than one layer show an abrupt improvement in in-context learning

In [None]:
def in_context_learning_score(model, tokens, late_index=500, early_index=50):
    """Mean difference between the per-token loss late and early in the context.

    Args:
        model: Callable invoked as
            `model(tokens, return_type='loss', loss_per_token=True)`; it must
            return a tensor of per-token losses with positions on the last axis.
        tokens: Token batch passed straight through to `model`.
        late_index: Index on the last axis of the loss tensor used as the
            "late in context" loss (default 500).
        early_index: Index on the last axis of the loss tensor used as the
            "early in context" loss (default 50).

    Returns:
        A 0-d tensor: mean over all leading axes of `loss[late] - loss[early]`.
        More negative values mean later tokens are predicted better.

    Raises:
        IndexError: if either index is out of range for the loss tensor.
    """
    loss_vec = model(tokens, return_type='loss', loss_per_token=True)
    n_positions = loss_vec.shape[-1]
    for name, position in (("late_index", late_index), ("early_index", early_index)):
        # Spell out which index is too large instead of a bare indexing error.
        if not -n_positions <= position < n_positions:
            raise IndexError(
                f"{name}={position} is out of range for {n_positions} per-token losses"
            )
    return (loss_vec[..., late_index] - loss_vec[..., early_index]).mean()

In [None]:
# Small batch size to avoid cuda memory issues on colab
pile_batch_size = 4
# Data loader built with the 2-layer model's tokenizer; later cells read each
# batch's token ids from its 'tokens' entry.
pile_dataloader = evals.make_pile_data_loader(tokenizer=model.tokenizer, batch_size=pile_batch_size)

In [None]:
# In-context learning score at several training checkpoints, for each model size.
checkpoint_indices = [10, 25, 35, 60, -1]
model_to_in_context_learning_scores = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    tokens_trained_on = []
    in_context_learning_scores = []
    for index in checkpoint_indices:
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)

        # The checkpoint's config records how far into training it was taken.
        tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
        tokens_trained_on.append(tokens_seen_for_this_checkpoint)

        in_context_learning_score_for_this_checkpoint = 0
        # Use subset of dataset for the sake of time
        num_batches = 2000 // pile_batch_size
        batches_seen = 0
        # islice stops after exactly num_batches batches. The previous
        # `if i == num_batches: break` ran one extra batch but still divided
        # by num_batches, biasing the average.
        for x in itertools.islice(pile_dataloader, num_batches):
            tokens = x['tokens'].to(device)
            in_context_learning_score_for_this_checkpoint += in_context_learning_score(model_for_this_checkpoint, tokens).item()
            batches_seen += 1
        # Average over the batches actually processed (the loader may be shorter
        # than num_batches); max(..., 1) keeps an empty loader at 0 instead of dividing by zero.
        in_context_learning_score_for_this_checkpoint /= max(batches_seen, 1)
        in_context_learning_scores.append(in_context_learning_score_for_this_checkpoint)
    model_to_in_context_learning_scores[model_name] = in_context_learning_scores
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# One log-x line chart per model; the gold band marks the 3e8 to 1.5e9 token window.
for model_name, in_context_learning_scores in model_to_in_context_learning_scores.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = px.line(
        x=tokens_trained_on,
        y=in_context_learning_scores,
        title=model_name,
        labels={"x":"Elapsed Training Tokens", "y":"In-Context Learning Scores"},
        log_x=True,
    )
    fig.update_layout(yaxis_range=[-0.6,0.2])
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    fig.show()

# Induction Heads form in phase change (Prefix Matching Score)

In [None]:
# Fresh batch of twice-repeated random tokens for the prefix-matching measurement.
# NOTE: this rebinds the notebook-level `batch_size`, `seq_len` and `repeated_tokens`.
batch_size, seq_len = 10, 50
random_tokens = torch.randint(low=1000, high=10000, size=(batch_size, seq_len))
random_tokens = random_tokens.to(model.cfg.device)
repeated_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")
# Overwrite the first position of every row with the tokenizer's BOS id.
repeated_tokens[:, 0] = model.tokenizer.bos_token_id

In [None]:
# hook copied from transformer lens main demo 
def induction_score_hook(
    pattern: TT["batch", "head_index", "dest_pos", "source_pos"],
    hook: HookPoint,
):
    """Write each head's mean induction-stripe attention into `induction_score_store`.

    The stripe is the diagonal of `pattern` at offset `1 - seq_len`, i.e. the
    attention from each destination position to the source position
    `seq_len - 1` tokens earlier (only destinations with index >= seq_len - 1
    have such an entry). `seq_len` and `induction_score_store` are read from the
    notebook namespace at call time; the row written is `hook.layer()`.
    """
    stripe = pattern.diagonal(offset=1-seq_len, dim1=-2, dim2=-1)
    # Collapse batch and stripe position, leaving one value per head.
    per_head_score = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    induction_score_store[hook.layer(), :] = per_head_score

# Boolean filter over activation names: selects only the attention-pattern hooks.
def pattern_hook_names_filter(name):
    """Return True exactly when `name` ends with "pattern"."""
    return name.endswith("pattern")

In [None]:
# Prefix-matching (induction) score of every head at each training checkpoint.
checkpoint_indices = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, -1]
model_to_scores_per_layer_head = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    tokens_trained_on = []
    # Maps "layer,head" -> list of scores, one entry per checkpoint.
    induction_scores_per_layer_head = collections.defaultdict(list)
    for index in checkpoint_indices:
        # Load the model from the relevant checkpoint by index
        model_for_this_checkpoint = HookedTransformer.from_pretrained(
            model_name,
            checkpoint_index=index,
            device=device,
        )

        tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
        tokens_trained_on.append(tokens_seen_for_this_checkpoint)

        # induction_score_hook will store results here
        induction_score_store = torch.zeros(
            (model_for_this_checkpoint.cfg.n_layers, model_for_this_checkpoint.cfg.n_heads),
            device=model_for_this_checkpoint.cfg.device,
        )

        # For efficiency, we don't need to calculate the logits (return_type=None)
        model_for_this_checkpoint.run_with_hooks(
            repeated_tokens,
            return_type=None,
            fwd_hooks=[(pattern_hook_names_filter, induction_score_hook)],
        )

        for layer in range(model_for_this_checkpoint.cfg.n_layers):
            for head in range(model_for_this_checkpoint.cfg.n_heads):
                head_score = induction_score_store[layer][head].item()
                induction_scores_per_layer_head[f"{layer},{head}"].append(head_score)
    model_to_scores_per_layer_head[model_name] = induction_scores_per_layer_head
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# One figure per model, one trace per head; the gold band marks the 3e8 to 1.5e9 token window.
for model_name, scores_per_layer_head in model_to_scores_per_layer_head.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = go.Figure(layout={'title': model_name})
    fig.update_xaxes(title="Elapsed Training Tokens", type='log')
    fig.update_yaxes(title="Prefix Matching Score")
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    for layer_head, scores in scores_per_layer_head.items():
        head_trace = go.Scatter(x=tokens_trained_on, y=scores, name=layer_head)
        fig.add_trace(head_trace)
    fig.update_layout(yaxis_range=[0.0,1.0])
    fig.show()

# Loss Curves Diverge during Phase Change

In [None]:
# Mean loss on a fixed number of Pile batches at each training checkpoint.
checkpoint_indices = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, -1]
model_to_loss_curve = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    tokens_trained_on = []
    losses = []
    for index in checkpoint_indices:
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)

        # The checkpoint's config records how far into training it was taken.
        tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
        tokens_trained_on.append(tokens_seen_for_this_checkpoint)

        loss_for_this_checkpoint = 0
        num_batches = 40
        batches_seen = 0
        # islice stops after exactly num_batches batches. The previous
        # `if i == num_batches: break` ran one extra batch but still divided
        # by num_batches, inflating the reported loss.
        for x in itertools.islice(pile_dataloader, num_batches):
            tokens = x['tokens'].to(device)
            loss_for_this_checkpoint += model_for_this_checkpoint(tokens, return_type='loss').item()
            batches_seen += 1
        # Average over the batches actually processed; max(..., 1) keeps an
        # empty loader at 0 instead of dividing by zero.
        loss_for_this_checkpoint /= max(batches_seen, 1)
        losses.append(loss_for_this_checkpoint)
    model_to_loss_curve[model_name] = losses
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# Loss curve per model on a linear x axis; the gold band marks the 3e8 to 1.5e9 token window.
for model_name, losses in model_to_loss_curve.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = px.line(
        x=tokens_trained_on,
        y=losses,
        title=model_name,
        labels={"x":"Elapsed Training Tokens", "y":"Loss (nats / token)"},
    )
    fig.update_layout(yaxis_range=[2.0,8.0])
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    fig.show()

Log x axis to see the phase change more clearly:

In [None]:
# Same loss curves on a log x axis; the gold band marks the 3e8 to 1.5e9 token window.
for model_name, losses in model_to_loss_curve.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = px.line(
        x=tokens_trained_on,
        y=losses,
        title=model_name,
        labels={"x":"Elapsed Training Tokens", "y":"Loss (nats / token)"},
        log_x=True,
    )
    fig.update_layout(yaxis_range=[2.0,8.0])
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    fig.show()

# Per-Token Loss Principal Component Analysis

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
# Single 2-component PCA object; the PCA cell below re-fits it once per model.
pca = PCA(n_components=2)

In [None]:
# collect some examples
# Keeps the first `num_examples` batches of the Pile loader (200 sequences in
# total when pile_batch_size divides 200) as tensors on `device`.
examples = []
num_examples = 200 // pile_batch_size
for i, x in enumerate(pile_dataloader):
    tokens = x['tokens'].to(device)
    examples.append(tokens)
    # i is zero-based, so this stops after exactly num_examples batches.
    if i == num_examples - 1:
        break
# Display the shape of one batch as the cell output.
examples[0].shape

torch.Size([4, 1024])

In [None]:
# One random position per collected sequence. randint's upper bound is exclusive,
# so values lie in [0, sequence_length - 2], which matches the width
# (sequence_length - 1) of the per-token loss store allocated in the PCA cell.
indices = torch.randint(0, examples[0].shape[-1]-1, (len(examples) * pile_batch_size,))
# Display the shape as the cell output.
indices.shape

torch.Size([200])

In [None]:
# For each model: build a (checkpoints x sampled-token-losses) matrix, standardise
# it, and project the checkpoints onto the first two principal components.
checkpoint_indices = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, -1]
model_to_pca_features = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    # Row = checkpoint, column = one sampled token position per example sequence.
    loss_data_matrix = torch.zeros((len(checkpoint_indices), len(examples) * pile_batch_size))
    tokens_trained_on = []
    for pos, index in enumerate(checkpoint_indices):
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)

        tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
        tokens_trained_on.append(tokens_seen_for_this_checkpoint)

        # Per-token losses for every collected sequence, kept on the CPU.
        loss_vec_store = torch.zeros((len(examples) * pile_batch_size, examples[0].shape[-1]-1))
        for i, ex in enumerate(examples):
            loss_vec = model_for_this_checkpoint(ex, return_type="loss", loss_per_token=True)
            # Batch i fills rows [i * pile_batch_size, (i + 1) * pile_batch_size).
            loss_vec_store[i*pile_batch_size:i*pile_batch_size + pile_batch_size] = loss_vec.cpu()
        # Pick, for each sequence (row), the loss at its pre-sampled position in `indices`.
        loss_sampled = loss_vec_store[torch.arange(loss_vec_store.shape[0]), indices]
        # I needed to put this on cpu to avoid cuda memory errors...
        loss_data_matrix[pos] = loss_sampled.cpu()
    # Standardise each column before PCA, then re-fit the shared `pca` object for this model.
    loss_data_scaled = StandardScaler().fit_transform(loss_data_matrix)
    pca_features = pca.fit_transform(loss_data_scaled)
    model_to_pca_features[model_name] = pca_features
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# Per model: checkpoints as points in PCA space, joined by dashed segments in training order.
for model_name, pca_features in model_to_pca_features.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig1 = go.Figure()
    for i in range(1, len(pca_features)):
        # color phase change window red
        ends_in_window = 3e8 <= tokens_trained_on[i] <= 1.5e9
        line_color = "red" if ends_in_window else 'blue'
        # Segment from checkpoint i-1 to checkpoint i.
        segment = slice(i-1, i+1)
        fig1.add_trace(
            go.Scatter(
                x=pca_features[segment, 0],
                y=pca_features[segment, 1],
                line={"width": 1, "dash": "dash", "color": line_color},
                showlegend=False,
            )
        )
    fig1.update(layout_showlegend=False)

    # Points coloured by the number of training tokens seen at each checkpoint.
    fig2 = px.scatter(
        x=pca_features[:, 0],
        y=pca_features[:, 1],
        color=list(map(str, tokens_trained_on)),
    )
    # Overlay the segments and the points in a single figure.
    fig3 = go.Figure(data=fig1.data + fig2.data)
    fig3.update_layout(legend_title="Elapsed Training Tokens", title=model_name)
    fig3.show()

# B - A per token losses on Harry Potter

In [None]:
# Passage used for the per-token loss comparisons below. The text is runtime
# data: the per-token plots depend on its exact tokenization, so do not reflow it.
context = """Mr. and Mrs. Dursley, of number four, Privet Drive, were
proud to say that they were perfectly normal, thank
you very much. They were the last people you’d expect to be involved in anything strange or mysterious, because they just didn’t
hold with such nonsense.
Mr. Dursley was the director of a firm called Grunnings, which
made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin
and blonde and had nearly twice the usual amount of neck, which
came in very useful as she spent so much of her time craning over
garden fences, spying on the neighbors. The Dursleys had a small
son called Dudley and in their opinion there was no finer boy
anywhere.
The Dursleys had everything they wanted, but they also had a
secret, and their greatest fear was that somebody would discover it.
They didn’t think they could bear it if anyone found out about the
Potters. Mrs. Potter was Mrs. Dursley’s sister, but they hadn’t met
for several years; in fact, Mrs. Dursley pretended she didn’t have a
sister, because her sister and her good-for-nothing husband were
as unDursleyish as it was possible to be. The Dursleys shuddered
to think what the neighbors would say if the Potters arrived in the
street. The Dursleys knew that the Potters had a small son, too, but
they had never even seen him. This boy was another good reason
for keeping the Potters away; they didn’t want Dudley mixing with
a child like that.
"""

In [None]:
# take indices right before and after phase change window (based on pca plot above)
# "A" is the earlier checkpoint and "B" the later one.
a_index = 25
b_index = 50

# Two snapshots of the same 2-layer model, differing only in checkpoint index.
model_before_phase_change = HookedTransformer.from_pretrained('attn-only-2l', device=device, checkpoint_index=a_index)
model_after_phase_change = HookedTransformer.from_pretrained('attn-only-2l', device=device, checkpoint_index=b_index)

In [None]:
# Per-token losses of both snapshots on the same passage.
loss_vec_before = model_before_phase_change(context, return_type='loss', loss_per_token=True)
loss_vec_after = model_after_phase_change(context, return_type='loss', loss_per_token=True)

# B - A: negative where the later checkpoint predicts the token better.
loss_vec_difference = loss_vec_after - loss_vec_before

In [None]:
# Lay the per-token loss differences out as a 20-row grid, each cell labelled with its token.
str_tokens = model_before_phase_change.to_str_tokens(context)
# NOTE: reshape(20, -1) only works when the number of losses is divisible by 20;
# a passage with a different token count makes this line raise.
z = utils.to_numpy(loss_vec_difference.reshape(20, -1))
# The first string token is dropped so labels line up with the losses —
# presumably it is the prepended BOS token, which has no loss of its own (TODO confirm).
z_text = np.array(str_tokens[1:]).reshape(z.shape)

fig = px.imshow(z, color_continuous_midpoint=0.0, color_continuous_scale="RdBu", aspect="auto")
# Print each token's text inside its heatmap cell.
fig.update_traces(text=z_text, texttemplate="%{text}")
fig.show()

# Per-Token losses over training

In [None]:
# Position of the LAST 'leys' token and of the FIRST ' useful' token in str_tokens.
# max()/min() raise ValueError if the token does not occur at all.
leys_idx = max(idx for idx, token in enumerate(str_tokens) if token == 'leys')
useful_idx = min(idx for idx, token in enumerate(str_tokens) if token == ' useful')

In [None]:
# Track the loss on two specific tokens, plus the mean loss, across training checkpoints.
checkpoint_indices = [10, 25, 35, 60, -1]
tokens_trained_on, leys_losses, useful_losses, mean_losses = [], [], [], []
for index in checkpoint_indices:
    model_for_this_checkpoint = HookedTransformer.from_pretrained(
        'attn-only-2l',
        device=device,
        checkpoint_index=index,
    )

    tokens_trained_on_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
    tokens_trained_on.append(tokens_trained_on_for_this_checkpoint)

    loss_vec = model_for_this_checkpoint(context, return_type="loss", loss_per_token=True)
    # Loss index k scores the prediction of token k + 1, hence the "- 1".
    loss_for_leys = loss_vec[:, leys_idx-1].item()
    loss_for_useful = loss_vec[:, useful_idx-1].item()

    leys_losses.append(loss_for_leys)
    useful_losses.append(loss_for_useful)
    mean_losses.append(loss_vec.mean().item())

In [None]:
# Loss on the two tracked tokens and the mean loss versus training tokens;
# the gold band marks the 3e8 to 1.5e9 token window.
fig = go.Figure()
fig.update_layout(yaxis_range=[0.0, 14.0])
fig.add_vrect(x0=3e8, x1=1.5e9, fillcolor='gold', line_width=1, opacity=0.2)
fig.update_xaxes(title="Elapsed Training Tokens")
fig.update_yaxes(title="Loss (nats / token)")
for series_name, series_losses, series_style in (
    (" useful", useful_losses, {}),
    ("leys", leys_losses, {}),
    ("mean loss", mean_losses, {"line": dict(color='gray', dash='dash')}),
):
    fig.add_trace(go.Scatter(x=tokens_trained_on, y=series_losses, name=series_name, **series_style))
fig.show()