<a href="https://colab.research.google.com/github/ckkissane/induction-heads-transformer-lens/blob/main/Induction_Heads_Phase_Change.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Setup

(Mostly taken from TransformerLens main demo. No need to read)

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEVELOPMENT_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/neelnanda-io/TransformerLens.git
    %pip install circuitsvis
    
    # PySvelte is an unmaintained visualization library, use it as a backup if circuitsvis isn't working
    # # Install another version of node that makes PySvelte work way faster
    # !curl -fsSL https://deb.nodesource.com/setup_16.x | sudo -E bash -; sudo apt-get install -y nodejs
    # %pip install git+https://github.com/neelnanda-io/PySvelte.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Only a local development session (DEVELOPMENT_MODE on and not in Colab)
# gets the "notebook_connected" renderer; every other case uses "colab".
pio.renderers.default = (
    "notebook_connected" if (DEVELOPMENT_MODE and not IN_COLAB) else "colab"
)
print(f"Using renderer: {pio.renderers.default}")

Using renderer: colab


In [None]:
import circuitsvis as cv
# Testing that the library works: the example is the last expression of the
# cell, so whatever it returns is shown as the cell's output.
cv.examples.hello("Connor")

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.auto as tqdm
import random
from pathlib import Path
import plotly.express as px
from torch.utils.data import DataLoader

from torchtyping import TensorType as TT
from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

In [None]:
# Turn autograd off globally: the cells in this notebook only run forward
# passes on pretrained checkpoints and never call backward.
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x7f7772fcffa0>

In [None]:
from transformer_lens import evals
import matplotlib.pyplot as plt
import collections
import plotly.graph_objects as go

In [None]:
def imshow(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a plotly heatmap.

    Uses the "RdBu" colour scale with its midpoint pinned at 0.0. `xaxis` and
    `yaxis` become the axis labels, extra keyword arguments are forwarded to
    `px.imshow`, and `renderer` is passed to the figure's `show`.
    """
    fig = px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)

def line(tensor, renderer=None, xaxis="", yaxis="", **kwargs):
    """Display `tensor` as a plotly line chart.

    `xaxis` and `yaxis` become the axis labels, extra keyword arguments are
    forwarded to `px.line`, and `renderer` is passed to the figure's `show`.
    """
    fig = px.line(
        utils.to_numpy(tensor),
        labels={"x": xaxis, "y": yaxis},
        **kwargs,
    )
    fig.show(renderer)

def scatter(x, y, xaxis="", yaxis="", caxis="", renderer=None, **kwargs):
    """Display a plotly scatter plot of `y` against `x`.

    Both inputs are converted with `utils.to_numpy` first. `xaxis`, `yaxis`
    and `caxis` label the x, y and colour axes, extra keyword arguments are
    forwarded to `px.scatter`, and `renderer` is passed to the figure's `show`.
    """
    x_values = utils.to_numpy(x)
    y_values = utils.to_numpy(y)
    axis_labels = {"x": xaxis, "y": yaxis, "color": caxis}
    fig = px.scatter(y=y_values, x=x_values, labels=axis_labels, **kwargs)
    fig.show(renderer)

In [None]:
# Run on the GPU when PyTorch reports one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# useful for sanity checks
# Later cells also read `model.tokenizer` (Pile data loader, BOS token id) and
# `model.cfg.device` (placement of the repeated random tokens) from this model.
model = HookedTransformer.from_pretrained("attn-only-2l", device=device)

# Models with more than one layer have an abrupt improvement in in-context learning

In [None]:
def in_context_learning_score(model, tokens, early_pos=50, late_pos=500):
    """Return the mean difference between late-context and early-context loss.

    Args:
        model: callable invoked as
            `model(tokens, return_type='loss', loss_per_token=True)`; its
            result is indexed along the last dimension by position.
        tokens: input forwarded unchanged to `model`.
        early_pos: index into the per-token loss used as the early-context
            reference (default 50, the value previously hard-coded).
        late_pos: index into the per-token loss used as the late-context
            measurement (default 500, the value previously hard-coded).

    Returns:
        The mean over all leading dimensions of
        `loss[..., late_pos] - loss[..., early_pos]`. More negative values
        mean the late position is predicted better than the early one.

    Raises:
        IndexError: if the per-token loss has fewer than
            `max(early_pos, late_pos) + 1` positions.
    """
    loss_vec = model(tokens, return_type='loss', loss_per_token=True)
    # NOTE(review): presumably entry k is the loss for predicting token k+1
    # (elsewhere in this notebook the loss vector is one shorter than the
    # token sequence) — confirm if exact token positions matter.
    return (loss_vec[..., late_pos] - loss_vec[..., early_pos]).mean()

In [None]:
# Small batch size to avoid cuda memory issues on colab
pile_batch_size = 4
# Pile data loader tokenized with the tokenizer of the `model` loaded above;
# the evaluation loops below iterate over it and read each batch's 'tokens'.
pile_dataloader = evals.make_pile_data_loader(tokenizer=model.tokenizer, batch_size=pile_batch_size)

In [None]:
# For each model size, evaluate the in-context learning score at several
# training checkpoints, averaged over a fixed number of Pile batches.
checkpoint_indices = [10, 25, 35, 60, -1]
# Use subset of dataset for the sake of time
num_batches = 2000 // pile_batch_size
model_to_in_context_learning_scores = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    tokens_trained_on = []
    in_context_learning_scores = []
    for index in checkpoint_indices:
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)

        # The notebook uses cfg.checkpoint_value as the elapsed-training-tokens x value.
        tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
        tokens_trained_on.append(tokens_seen_for_this_checkpoint)

        in_context_learning_score_for_this_checkpoint = 0
        batches_seen = 0
        # islice yields at most num_batches batches. The previous
        # `if i == num_batches: break` ran after accumulating, so it summed
        # num_batches + 1 batches while dividing by num_batches.
        for x in itertools.islice(pile_dataloader, num_batches):
            tokens = x['tokens'].to(device)
            in_context_learning_score_for_this_checkpoint += in_context_learning_score(model_for_this_checkpoint, tokens).item()
            batches_seen += 1
        # Divide by the number of batches actually consumed, so the mean stays
        # correct even if the loader yields fewer than num_batches batches.
        in_context_learning_score_for_this_checkpoint /= max(batches_seen, 1)
        in_context_learning_scores.append(in_context_learning_score_for_this_checkpoint)
    model_to_in_context_learning_scores[model_name] = in_context_learning_scores
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# One figure per model: in-context learning score against elapsed training
# tokens on a log x axis.
for model_name, in_context_learning_scores in model_to_in_context_learning_scores.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = px.line(
        x=tokens_trained_on,
        y=in_context_learning_scores,
        title=model_name,
        labels={"x":"Elapsed Training Tokens", "y":"In-Context Learning Scores"},
        log_x=True,
    )
    fig.update_layout(yaxis_range=[-0.6,0.2])
    # Gold band over the 3e8 to 1.5e9 token window (the phase change window).
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    fig.show()

# Induction Heads form in phase change (Prefix Matching Score)

In [None]:
# Build sequences made of a random token block repeated twice, used to score
# induction behaviour.
batch_size = 10
seq_len = 50
# A locally seeded generator makes the random tokens (and the scores derived
# from them) reproducible across runs without touching the global RNG state.
token_rng = torch.Generator().manual_seed(0)
random_tokens = torch.randint(1000, 10000, (batch_size, seq_len), generator=token_rng).to(model.cfg.device)
# Each row becomes its seq_len tokens followed by the same seq_len tokens again.
repeated_tokens = einops.repeat(random_tokens, "batch seq_len -> batch (2 seq_len)")
# Overwrite the first token of every row with the tokenizer's BOS token.
repeated_tokens[:, 0] = model.tokenizer.bos_token_id

In [None]:
# hook copied from transformer lens main demo 
def induction_score_hook(
    pattern: TT["batch", "head_index", "dest_pos", "source_pos"],
    hook: HookPoint,
):
    """Record a per-head induction score for one layer's attention pattern.

    Reads the module-level `seq_len` and writes the result into row
    `hook.layer()` of the module-level `induction_score_store`. Returns None.
    """
    # Diagonal at offset 1 - seq_len: the attention each destination position
    # pays to the source position seq_len - 1 tokens earlier. Only destination
    # positions with index >= seq_len contribute entries.
    stripe = pattern.diagonal(offset=1 - seq_len, dim1=-2, dim2=-1)
    # Average over batch and position, leaving one score per head.
    per_head_score = einops.reduce(stripe, "batch head_index position -> head_index", "mean")
    induction_score_store[hook.layer(), :] = per_head_score

# We make a boolean filter on activation names, that's true only on attention pattern names.
def pattern_hook_names_filter(name):
    """Return True only when the activation name ends with "pattern"."""
    return name.endswith("pattern")

In [None]:
# For each model size and training checkpoint, run the repeated-token batch
# through the model and record every head's induction (prefix matching) score.
checkpoint_indices = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, -1]
model_to_scores_per_layer_head = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    tokens_trained_on = []
    # Maps "layer,head" to that head's score at each checkpoint, in order.
    induction_scores_per_layer_head = collections.defaultdict(list)
    for index in checkpoint_indices:
        # Load the model from the relevant checkpoint by index
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)
        cfg = model_for_this_checkpoint.cfg

        tokens_trained_on.append(cfg.checkpoint_value)

        # induction_score_hook will store results here (it reads this module-level name)
        induction_score_store = torch.zeros((cfg.n_layers, cfg.n_heads), device=cfg.device)

        pattern_hooks = [(pattern_hook_names_filter, induction_score_hook)]
        model_for_this_checkpoint.run_with_hooks(
            repeated_tokens,
            return_type=None,  # For efficiency, we don't need to calculate the logits
            fwd_hooks=pattern_hooks,
        )

        # Layer-major, head-minor order, matching the original nested loops.
        for layer, head in itertools.product(range(cfg.n_layers), range(cfg.n_heads)):
            induction_scores_per_layer_head[f"{layer},{head}"].append(induction_score_store[layer, head].item())
    model_to_scores_per_layer_head[model_name] = induction_scores_per_layer_head
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# One figure per model, with one trace per "layer,head" showing its prefix
# matching score across checkpoints.
for model_name, scores_per_layer_head in model_to_scores_per_layer_head.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = go.Figure(layout={'title': model_name})
    fig.update_xaxes(title="Elapsed Training Tokens", type='log')
    fig.update_yaxes(title="Prefix Matching Score")
    # Gold band over the 3e8 to 1.5e9 token window (the phase change window).
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    for layer_head in scores_per_layer_head:
        head_trace = go.Scatter(
            x=tokens_trained_on,
            y=scores_per_layer_head[layer_head],
            name=layer_head,
        )
        fig.add_trace(head_trace)
    fig.update_layout(yaxis_range=[0.0,1.0])
    fig.show()

# Loss Curves Diverge during Phase Change

In [None]:
# For each model size, estimate the mean Pile loss at each training checkpoint
# from a fixed number of batches.
checkpoint_indices = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, -1]
num_batches = 40
model_to_loss_curve = {}
model_to_tokens_trained_on = {}
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    tokens_trained_on = []
    losses = []
    for index in checkpoint_indices:
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)

        # The notebook uses cfg.checkpoint_value as the elapsed-training-tokens x value.
        tokens_seen_for_this_checkpoint = model_for_this_checkpoint.cfg.checkpoint_value
        tokens_trained_on.append(tokens_seen_for_this_checkpoint)

        loss_for_this_checkpoint = 0
        batches_seen = 0
        # islice yields at most num_batches batches. The previous
        # `if i == num_batches: break` ran after accumulating, so it summed
        # num_batches + 1 losses while dividing by num_batches.
        for x in itertools.islice(pile_dataloader, num_batches):
            tokens = x['tokens'].to(device)
            loss_for_this_checkpoint += model_for_this_checkpoint(tokens, return_type='loss').item()
            batches_seen += 1
        # Divide by the number of batches actually consumed, so the mean stays
        # correct even if the loader yields fewer than num_batches batches.
        loss_for_this_checkpoint /= max(batches_seen, 1)
        losses.append(loss_for_this_checkpoint)
    model_to_loss_curve[model_name] = losses
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# One figure per model: loss against elapsed training tokens (linear x axis).
for model_name, losses in model_to_loss_curve.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = px.line(
        x=tokens_trained_on,
        y=losses,
        title=model_name,
        labels={"x":"Elapsed Training Tokens", "y":"Loss (nats / token)"},
    )
    fig.update_layout(yaxis_range=[2.0,8.0])
    # Gold band over the 3e8 to 1.5e9 token window (the phase change window).
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    fig.show()

Log x axis to see the phase change more clearly:

In [None]:
# Same loss curves as the linear plot, drawn with a log-scaled x axis.
for model_name, losses in model_to_loss_curve.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig = px.line(
        x=tokens_trained_on,
        y=losses,
        title=model_name,
        labels={"x":"Elapsed Training Tokens", "y":"Loss (nats / token)"},
        log_x=True,
    )
    fig.update_layout(yaxis_range=[2.0,8.0])
    # Gold band over the 3e8 to 1.5e9 token window (the phase change window).
    fig.add_vrect(x0=3e8, x1=1.5e9, line_width=1, fillcolor="gold", opacity=0.2)
    fig.show()

# Per-Token Loss Principal Component Analysis

In [None]:
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
# Two-component PCA estimator; the per-model loop further down calls
# fit_transform on it once per model, so each model gets its own fit.
pca = PCA(n_components=2)

In [None]:
# collect some examples
# Gather a fixed set of token batches from the Pile loader so that every
# checkpoint is evaluated on the same text.
num_examples = 200 // pile_batch_size  # number of batches, i.e. 200 sequences in total
examples = []
for batch_idx, batch in enumerate(pile_dataloader):
    examples.append(batch['tokens'].to(device))
    # Stop as soon as num_examples batches have been collected.
    if batch_idx + 1 == num_examples:
        break
examples[0].shape

torch.Size([4, 1024])

In [None]:
# Pick one random per-token-loss position for every collected sequence.
# The exclusive upper bound is sequence length - 1, which matches the width of
# the per-token loss buffer allocated in the PCA loop below.
# A locally seeded generator keeps the sampled positions (and the PCA built
# on them) reproducible across runs without touching the global RNG state.
index_rng = torch.Generator().manual_seed(0)
indices = torch.randint(0, examples[0].shape[-1]-1, (len(examples) * pile_batch_size,), generator=index_rng)
indices.shape

torch.Size([200])

In [None]:
# For each model size, build a (checkpoints x sequences) matrix holding one
# sampled per-token loss per sequence, standardize it, and project it onto
# two principal components.
checkpoint_indices = [10, 15, 20, 25, 30, 35, 40, 45, 50, 60, -1]
model_to_pca_features = {}
model_to_tokens_trained_on = {}
num_sequences = len(examples) * pile_batch_size
for model_name in ["attn-only-1l", "attn-only-2l", "attn-only-3l"]:
    loss_data_matrix = torch.zeros((len(checkpoint_indices), num_sequences))
    tokens_trained_on = []
    for pos, index in enumerate(checkpoint_indices):
        model_for_this_checkpoint = HookedTransformer.from_pretrained(model_name, checkpoint_index=index, device=device)
        tokens_trained_on.append(model_for_this_checkpoint.cfg.checkpoint_value)

        # Per-token losses for every collected sequence, kept on the CPU
        # (the original author hit cuda memory errors otherwise).
        loss_vec_store = torch.zeros((num_sequences, examples[0].shape[-1] - 1))
        for i, ex in enumerate(examples):
            row_start = i * pile_batch_size
            loss_vec = model_for_this_checkpoint(ex, return_type="loss", loss_per_token=True)
            loss_vec_store[row_start:row_start + pile_batch_size] = loss_vec.cpu()
        # Keep only the loss at each sequence's pre-sampled position.
        loss_data_matrix[pos] = loss_vec_store[torch.arange(num_sequences), indices]
    loss_data_scaled = StandardScaler().fit_transform(loss_data_matrix)
    model_to_pca_features[model_name] = pca.fit_transform(loss_data_scaled)
    model_to_tokens_trained_on[model_name] = tokens_trained_on

In [None]:
# One figure per model: checkpoints as points in PCA space, joined in
# checkpoint order by dashed line segments.
for model_name, pca_features in model_to_pca_features.items():
    tokens_trained_on = model_to_tokens_trained_on[model_name]
    fig1 = go.Figure()
    for end in range(1, len(pca_features)):
        segment = pca_features[end - 1: end + 1]
        # color phase change window red
        in_phase_change = 3e8 <= tokens_trained_on[end] <= 1.5e9
        segment_style = {"width": 1, "dash": "dash", "color": "red" if in_phase_change else 'blue'}
        fig1.add_trace(go.Scatter(x=segment[:, 0], y=segment[:, 1], line=segment_style, showlegend=False))
    fig1.update(layout_showlegend=False)

    # Scatter of the checkpoints themselves, coloured by elapsed training tokens.
    checkpoint_labels = [str(token_count) for token_count in tokens_trained_on]
    fig2 = px.scatter(x=pca_features[:, 0], y=pca_features[:, 1], color=checkpoint_labels)
    fig3 = go.Figure(data=fig1.data + fig2.data)
    fig3.update_layout(legend_title="Elapsed Training Tokens", title=model_name)
    fig3.show()

# B - A per token losses on Harry Potter

In [None]:
# Text passage used by the cells below: the before/after phase-change models
# are run on it to compare per-token losses, and the 'leys' / ' useful' tokens
# tracked over training are looked up in its tokenization.
context = """Mr. and Mrs. Dursley, of number four, Privet Drive, were
proud to say that they were perfectly normal, thank
you very much. They were the last people you’d expect to be involved in anything strange or mysterious, because they just didn’t
hold with such nonsense.
Mr. Dursley was the director of a firm called Grunnings, which
made drills. He was a big, beefy man with hardly any neck, although he did have a very large mustache. Mrs. Dursley was thin
and blonde and had nearly twice the usual amount of neck, which
came in very useful as she spent so much of her time craning over
garden fences, spying on the neighbors. The Dursleys had a small
son called Dudley and in their opinion there was no finer boy
anywhere.
The Dursleys had everything they wanted, but they also had a
secret, and their greatest fear was that somebody would discover it.
They didn’t think they could bear it if anyone found out about the
Potters. Mrs. Potter was Mrs. Dursley’s sister, but they hadn’t met
for several years; in fact, Mrs. Dursley pretended she didn’t have a
sister, because her sister and her good-for-nothing husband were
as unDursleyish as it was possible to be. The Dursleys shuddered
to think what the neighbors would say if the Potters arrived in the
street. The Dursleys knew that the Potters had a small son, too, but
they had never even seen him. This boy was another good reason
for keeping the Potters away; they didn’t want Dudley mixing with
a child like that.
"""

In [None]:
# take indices right before and after phase change window (based on pca plot above)
a_index = 25  # checkpoint index used as "before" (model A)
b_index = 50  # checkpoint index used as "after" (model B)

# Two copies of the 2-layer attention-only model, loaded at those checkpoints.
model_before_phase_change = HookedTransformer.from_pretrained('attn-only-2l', device=device, checkpoint_index=a_index)
model_after_phase_change = HookedTransformer.from_pretrained('attn-only-2l', device=device, checkpoint_index=b_index)

In [None]:
# Per-token loss of each checkpoint on the same passage.
loss_vec_before = model_before_phase_change(context, return_type='loss', loss_per_token=True)
loss_vec_after = model_after_phase_change(context, return_type='loss', loss_per_token=True)

# B - A: negative entries are tokens where the later checkpoint has lower loss.
loss_vec_difference = loss_vec_after - loss_vec_before

In [None]:
# Heatmap of the per-token loss difference, laid out as a 20-row grid of
# tokens with each cell annotated by the token it scores.
str_tokens = model_before_phase_change.to_str_tokens(context)
loss_diff_flat = utils.to_numpy(loss_vec_difference).reshape(-1)
# The first string token is skipped so labels line up with loss entries
# (there is no loss for the first token of the sequence).
token_labels = list(str_tokens[1:])
if len(token_labels) != loss_diff_flat.size:
    raise ValueError(
        f"expected one loss per labelled token, got {loss_diff_flat.size} losses "
        f"for {len(token_labels)} tokens"
    )

num_rows = 20
# Ceiling division. The previous reshape(20, -1) failed whenever the token
# count was not a multiple of 20; the tail of the grid is now padded instead.
# When the count is a multiple of 20 the result is identical to before.
num_cols = -(-loss_diff_flat.size // num_rows)
num_pad = num_rows * num_cols - loss_diff_flat.size
# Padding cells hold NaN losses and empty labels.
z = np.concatenate(
    [loss_diff_flat, np.full(num_pad, np.nan, dtype=loss_diff_flat.dtype)]
).reshape(num_rows, num_cols)
z_text = np.array(token_labels + [""] * num_pad).reshape(z.shape)

fig = px.imshow(z, color_continuous_midpoint=0.0, color_continuous_scale="RdBu", aspect="auto")
fig.update_traces(text=z_text, texttemplate="%{text}")
fig.show()

# Per-Token losses over training

In [None]:
# Position of the last 'leys' token in the tokenized passage
# (max() raises ValueError if the token never occurs).
leys_idx = max(idx for idx, token in enumerate(str_tokens) if token == 'leys')
# Position of the first ' useful' token (min() raises ValueError if absent).
useful_idx = min(idx for idx, token in enumerate(str_tokens) if token == ' useful')

In [None]:
# Track the loss on two specific tokens, and the mean loss over the passage,
# for the 2-layer model across training checkpoints.
checkpoint_indices = [10, 25, 35, 60, -1]
tokens_trained_on = []
leys_losses = []
useful_losses = []
mean_losses = []
for index in checkpoint_indices:
    model_for_this_checkpoint = HookedTransformer.from_pretrained('attn-only-2l', device=device, checkpoint_index=index)
    tokens_trained_on.append(model_for_this_checkpoint.cfg.checkpoint_value)

    loss_vec = model_for_this_checkpoint(context, return_type="loss", loss_per_token=True)
    # Loss entry k is paired with str_tokens[k + 1] (the heatmap cell labels
    # losses with str_tokens[1:]), hence the "- 1" on both token positions.
    leys_losses.append(loss_vec[:, leys_idx - 1].item())
    useful_losses.append(loss_vec[:, useful_idx - 1].item())
    mean_losses.append(loss_vec.mean().item())

In [None]:
# Loss on the ' useful' and 'leys' tokens, and the mean loss over the passage,
# against elapsed training tokens.
fig = go.Figure()
fig.update_layout(yaxis_range=[0.0, 14.0])
# Gold band over the 3e8 to 1.5e9 token window (the phase change window).
fig.add_vrect(x0=3e8, x1=1.5e9, fillcolor='gold', line_width=1, opacity=0.2)
fig.update_xaxes(title="Elapsed Training Tokens")
fig.update_yaxes(title="Loss (nats / token)")
# (trace name, y values, extra Scatter options), added in legend order.
trace_specs = [
    (" useful", useful_losses, {}),
    ("leys", leys_losses, {}),
    ("mean loss", mean_losses, {"line": dict(color='gray', dash='dash')}),
]
for trace_name, trace_losses, trace_options in trace_specs:
    fig.add_trace(go.Scatter(x=tokens_trained_on, y=trace_losses, name=trace_name, **trace_options))
fig.show()