<a href="https://colab.research.google.com/github/certainforest/role-representation/blob/main/Ayush/Activation_Patching_in_TL_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

<a target="_blank" href="https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/Activation_Patching_in_TL_Demo.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

 # Activation Patching in TransformerLens Demo
 This is an accompaniment to [Exploratory Analysis Demo](https://neelnanda.io/exploratory-analysis-demo). That notebook explains some basic techniques for mech interp of networks, including an overview of activation patching ([summary here](https://dynalist.io/d/n2ZWtnoYHrU1s4vnFSAQ519J#z=qeWBvs-R-taFfcCq-S_hgMqx)). This demonstrates how to use the Activation Patching utils in TransformerLens.


 <b style="color: red">To use this notebook, go to Runtime > Change Runtime Type and select GPU as the hardware accelerator.</b>

 **Tips for reading this Colab:**
 * You can run all this code for yourself!
 * The graphs are interactive!
 * Use the table of contents pane in the sidebar to navigate
 * Collapse irrelevant sections with the dropdown arrows
 * Search the page using the search in the sidebar, not CTRL+F

 ## Setup (Ignore)

In [None]:
!pip install -U kaleido



In [None]:
import os
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np

# Create output directory
output_dir = Path("activation_patches")
output_dir.mkdir(parents=True, exist_ok=True)


def _heatmap_height(base_height, n_rows, y_labels):
    """Figure height in inches; grows with the row count when every row gets a tick label."""
    if y_labels is None or len(y_labels) != n_rows:
        return base_height
    # ~0.12 in per labelled row keeps 8pt tick labels from overlapping (e.g. 144 head labels)
    return max(base_height, 0.12 * n_rows)


def _apply_tick_labels(ax, plot_shape, x_labels, y_labels):
    """Label every column/row of a 2-D heatmap; label lists whose length doesn't match are skipped."""
    if x_labels is not None and len(x_labels) == plot_shape[1]:
        ax.set_xticks(range(len(x_labels)))
        # Vertical labels must be centred on their tick; ha='right' shifted them sideways
        ax.set_xticklabels(x_labels, rotation=90, ha='center', fontsize=8)
    if y_labels is not None and len(y_labels) == plot_shape[0]:
        ax.set_yticks(range(len(y_labels)))
        ax.set_yticklabels(y_labels, fontsize=8)


def save_imshow(data, filename, **kwargs):
    """Save a heatmap of `data` as ``output_dir / f"{filename}.png"`` using matplotlib.

    Args:
        data: 2-D array or tensor (one extra axis when `facet_col` is given). Torch tensors
            are detached and copied to the CPU first.
        filename: file stem; ".png" is appended.
        **kwargs: optional plot settings:
            title, xaxis, yaxis -- figure title and axis labels.
            x, y -- tick labels; each is applied only if its length matches that axis.
            facet_col -- axis of `data` to split into side-by-side panels (any axis,
                negative indices included).
            facet_labels -- per-panel titles when `facet_col` is set.
            zmin, zmax -- colour-scale limits (matplotlib vmin / vmax).
    """
    # Convert tensor to numpy on CPU
    if hasattr(data, 'detach'):
        data = data.detach()
    if hasattr(data, 'cpu'):
        data = data.cpu().numpy()
    data = np.asarray(data)

    # Extract parameters
    title = kwargs.get('title', 'Heatmap')
    xaxis_title = kwargs.get('xaxis', 'X')
    yaxis_title = kwargs.get('yaxis', 'Y')
    x_labels = kwargs.get('x', None)
    y_labels = kwargs.get('y', None)
    facet_col = kwargs.get('facet_col', None)
    facet_labels = kwargs.get('facet_labels', None)
    zmin = kwargs.get('zmin', None)
    zmax = kwargs.get('zmax', None)

    if facet_col is not None:
        # Multiple subplots: one panel per index along `facet_col`
        facet_axis = facet_col % data.ndim  # normalise negative axes
        n_facets = data.shape[facet_axis]
        panel_rows = [size for axis, size in enumerate(data.shape) if axis != facet_axis][0]
        # squeeze=False always yields a 2-D array of axes, even for a single panel
        fig, axes = plt.subplots(1, n_facets,
                                 figsize=(6 * n_facets, _heatmap_height(5, panel_rows, y_labels)),
                                 squeeze=False)

        for i, ax in enumerate(axes[0]):
            # np.take slices along any axis (data[:, i] was only right for facet_col == 1)
            plot_data = np.take(data, i, axis=facet_axis)

            im = ax.imshow(plot_data, aspect='auto', cmap='RdBu_r',
                           vmin=zmin, vmax=zmax,
                           interpolation='nearest')

            if facet_labels and i < len(facet_labels):
                ax.set_title(facet_labels[i])

            ax.set_xlabel(xaxis_title)
            if i == 0:
                ax.set_ylabel(yaxis_title)

            _apply_tick_labels(ax, plot_data.shape, x_labels, y_labels)
            fig.colorbar(im, ax=ax)

        fig.suptitle(title, fontsize=14, y=1.02)
    else:
        # Single plot
        fig, ax = plt.subplots(figsize=(10, _heatmap_height(6, data.shape[0], y_labels)))

        im = ax.imshow(data, aspect='auto', cmap='RdBu_r',
                       vmin=zmin, vmax=zmax,
                       interpolation='nearest')

        ax.set_xlabel(xaxis_title)
        ax.set_ylabel(yaxis_title)
        ax.set_title(title)

        _apply_tick_labels(ax, data.shape, x_labels, y_labels)
        fig.colorbar(im, ax=ax, label='Value')

    # Save as PNG
    filepath = output_dir / f"{filename}.png"
    fig.tight_layout()
    fig.savefig(str(filepath), dpi=150, bbox_inches='tight')
    plt.close(fig)  # Close to free memory
    print(f"Saved: {filepath}")

In [None]:
# Janky code to do different setup when run in a Colab notebook vs VSCode
DEBUG_MODE = False
try:
    import google.colab
    IN_COLAB = True
    print("Running as a Colab notebook")
    %pip install git+https://github.com/TransformerLensOrg/TransformerLens.git
    # Install my janky personal plotting utils
    %pip install git+https://github.com/neelnanda-io/neel-plotly.git
except:
    IN_COLAB = False
    print("Running as a Jupyter notebook - intended for development only!")
    from IPython import get_ipython

    ipython = get_ipython()
    # Code to automatically update the HookedTransformer code as its edited without restarting the kernel
    ipython.magic("load_ext autoreload")
    ipython.magic("autoreload 2")

Running as a Colab notebook
Collecting git+https://github.com/TransformerLensOrg/TransformerLens.git
  Cloning https://github.com/TransformerLensOrg/TransformerLens.git to /tmp/pip-req-build-mz3c9roz
  Running command git clone --filter=blob:none --quiet https://github.com/TransformerLensOrg/TransformerLens.git /tmp/pip-req-build-mz3c9roz
  Resolved https://github.com/TransformerLensOrg/TransformerLens.git to commit 7df72ff71b3b0b25845f9d12836ba45a58a0d629
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Collecting git+https://github.com/neelnanda-io/neel-plotly.git
  Cloning https://github.com/neelnanda-io/neel-plotly.git to /tmp/pip-req-build-agtko7o7
  Running command git clone --filter=blob:none --quiet https://github.com/neelnanda-io/neel-plotly.git /tmp/pip-req-build-agtko7o7
  Resolved https://github.com/neelnanda-io/neel-plotly.git to commit 6dc24b26f8dec991

In [None]:
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio

# Thanks to annoying rendering issues, Plotly graphics will either show up in Colab OR VSCode
# depending on the renderer -- this is bad for developing demos, hence the debug mode. Static
# PNG rendering is used only outside Colab with DEBUG_MODE switched on; otherwise "colab".
use_static_png = DEBUG_MODE and not IN_COLAB
pio.renderers.default = "png" if use_static_png else "colab"





This means that static image generation (e.g. `fig.write_image()`) will not work.

Please upgrade Plotly to version 6.1.1 or greater, or downgrade Kaleido to version 0.2.1.




In [None]:
# Force the Colab renderer only when actually running in Colab; elsewhere keep the renderer
# chosen from DEBUG_MODE in the previous cell instead of unconditionally overriding it.
if IN_COLAB:
    pio.renderers.default = "colab"


In [None]:
!pip install --force-reinstall numpy==1.26.4


Collecting numpy==1.26.4
  Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)
Using cached numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (18.0 MB)
Installing collected packages: numpy
  Attempting uninstall: numpy
    Found existing installation: numpy 1.26.4
    Uninstalling numpy-1.26.4:
      Successfully uninstalled numpy-1.26.4
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
tobler 0.13.0 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.
opencv-python-headless 4.12.0.88 requires numpy<2.3.0,>=2; python_version >= "3.9", but you have numpy 1.26.4 which is incompatible.
rasterio 1.5.0 requires numpy>=2, but you have numpy 1.26.4 which is incompatible.
jaxlib 0.7.2 requires numpy>=2.0, but you have numpy 1.26.4 which is incompatible.
pytensor 2.36.3 requires 

In [None]:
# Import stuff
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import einops
from fancy_einsum import einsum
import tqdm.notebook as tqdm
import random
from pathlib import Path
# import plotly.express as px
from torch.utils.data import DataLoader

from typing import List, Union, Optional
from functools import partial
import copy

import itertools
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer
import dataclasses
import datasets
from IPython.display import HTML

In [None]:
import transformer_lens
import transformer_lens.utils as utils
from transformer_lens.hook_points import (
    HookedRootModule,
    HookPoint,
)  # Hooking utilities
from transformer_lens import HookedTransformer, HookedTransformerConfig, FactoredMatrix, ActivationCache

 We turn automatic differentiation off, to save GPU memory, as this notebook focuses on model inference not model training.

In [None]:
# Inference only: turn autograd off globally to save GPU memory. One call is enough; the
# trailing ';' stops Jupyter from echoing the returned set_grad_enabled object.
torch.set_grad_enabled(False);

 Plotting helper functions from a janky personal library of plotting utils. The library is not documented and I recommend against trying to read it, just use your preferred plotting library if you want to do anything non-obvious:

In [None]:
from neel_plotly import line, imshow, scatter

In [None]:
import transformer_lens.patching as patching

 ## Activation Patching Setup
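 This adapts the relevant setup from the Exploratory Analysis Demo to a speaker-tracking prompt, and isn't very important: the patching utilities below only need the model, the clean and corrupted tokens, a clean activation cache and a metric.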

In [None]:
# GPT-2 small wrapped as a HookedTransformer (weights are fetched from the Hugging Face Hub)
model = HookedTransformer.from_pretrained(model_name="gpt2-small")



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


In [None]:
import warnings

# STEP 1: Define the prompts. They are used as adjacent (clean, corrupted) pairs: the corrupted
# version of each prompt is the other prompt in its pair.
prompts = [
    'Alice: What models have you tried? Bob: I tried Gemma and Quen. Quen was better. Alice: Did you enter a contest? Bob: Yes, that\'s where I started. Who asked about models?',

    'What models have you tried? I tried Gemma and Quen. Quen was better. Did you enter a contest? Yes, that\'s where I started. Who asked about models?'
]

# (correct, incorrect) answer per prompt; the second pair is the first one reversed, so the
# logit diff flips sign when a prompt's partner is swapped in.
# NOTE(review): prompts[1] contains no speaker names, so neither answer is grounded in that
# prompt -- confirm this is the intended corruption.
answers = [
    (' Alice', ' Bob'),  # Alice asked about models
    (' Bob', ' Alice')   # Wrong attribution
]
if len(prompts) % 2:
    raise ValueError("prompts must come in (clean, corrupted) pairs; got an odd number of prompts")
if len(answers) != len(prompts):
    raise ValueError(f"need one answer pair per prompt; got {len(answers)} for {len(prompts)} prompts")

clean_tokens = model.to_tokens(prompts)
# Swap each adjacent pair, with a hacky list comprehension
corrupted_tokens = clean_tokens[
    [(i+1 if i%2==0 else i-1) for i in range(len(clean_tokens)) ]
    ]

# Position-wise patching and the `logits[:, -1]` read in get_logit_diff assume every prompt has
# the same token length. to_tokens right-pads shorter prompts (with <|endoftext|> for this model),
# so unequal lengths misalign clean vs corrupted positions and score short prompts at a pad token.
prompt_token_lengths = [model.to_tokens(prompt).shape[-1] for prompt in prompts]
if len(set(prompt_token_lengths)) > 1:
    warnings.warn(
        f"Prompts have different token lengths {prompt_token_lengths}; padding positions will be "
        "patched and scored. Rewrite the prompts so each (clean, corrupted) pair tokenizes to the "
        "same length."
    )

print("Clean string 0", model.to_string(clean_tokens[0]))
print("Corrupted string 0", model.to_string(corrupted_tokens[0]))

# Shape [n_prompts, 2]: column 0 = correct answer id, column 1 = incorrect answer id
answer_token_indices = torch.tensor([[model.to_single_token(answers[i][j]) for j in range(2)] for i in range(len(answers))], device=model.cfg.device)
print("Answer token indices", answer_token_indices)

Clean string 0 <|endoftext|>Alice: What models have you tried? Bob: I tried Gemma and Quen. Quen was better. Alice: Did you enter a contest? Bob: Yes, that's where I started. Who asked about models?
Corrupted string 0 <|endoftext|>What models have you tried? I tried Gemma and Quen. Quen was better. Did you enter a contest? Yes, that's where I started. Who asked about models?<|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|><|endoftext|>
Answer token indices tensor([[14862,  5811],
        [14862,  5811]], device='cuda:0')


In [None]:
# VERIFICATION: Check your tokens are correct
print("\n" + "="*60)
print("SPEAKER TRACKING EXPERIMENT - TOKEN VERIFICATION")
print("="*60)

print(f"\nClean prompt: {prompts[0]}")
print(f"Tokens: {model.to_str_tokens(clean_tokens[0])}")

# corrupted_tokens[0] is prompt 0's pair-swapped partner, i.e. prompts[1]
print(f"\nCorrupted prompt: {prompts[1]}")
print(f"Tokens: {model.to_str_tokens(corrupted_tokens[0])}")

print(f"\nAnswer tokens: {answers}")
# Column 0 of answer_token_indices is the correct answer, column 1 the incorrect one
print(f"Correct answer token ID: {answer_token_indices[0, 0]}")
print(f"Incorrect answer token ID: {answer_token_indices[0, 1]}")

print("\n" + "="*60)


SPEAKER TRACKING EXPERIMENT - TOKEN VERIFICATION

Clean prompt: Alice: What models have you tried? Bob: I tried Gemma and Quen. Quen was better. Alice: Did you enter a contest? Bob: Yes, that's where I started. Who asked about models?
Tokens: ['<|endoftext|>', 'Alice', ':', ' What', ' models', ' have', ' you', ' tried', '?', ' Bob', ':', ' I', ' tried', ' Gem', 'ma', ' and', ' Qu', 'en', '.', ' Qu', 'en', ' was', ' better', '.', ' Alice', ':', ' Did', ' you', ' enter', ' a', ' contest', '?', ' Bob', ':', ' Yes', ',', ' that', "'s", ' where', ' I', ' started', '.', ' Who', ' asked', ' about', ' models', '?']

Corrupted prompt: What models have you tried? I tried Gemma and Quen. Quen was better. Did you enter a contest? Yes, that's where I started. Who asked about models?
Tokens: ['<|endoftext|>', 'What', ' models', ' have', ' you', ' tried', '?', ' I', ' tried', ' Gem', 'ma', ' and', ' Qu', 'en', '.', ' Qu', 'en', ' was', ' better', '.', ' Did', ' you', ' enter', ' a', ' contest', '?',

In [None]:
def get_logit_diff(logits, answer_token_indices=answer_token_indices, final_positions=None):
    """Mean logit difference between each prompt's correct and incorrect answer token.

    Args:
        logits: either full model logits (3-D, position on dim 1) or logits already sliced to one
            position per prompt (2-D, vocab on dim 1).
        answer_token_indices: [batch, 2] token ids; column 0 is correct, column 1 incorrect.
        final_positions: optional per-prompt position to read from 3-D logits. Defaults to the
            last position (-1). Pass each prompt's last real-token index when the batch is
            right-padded; otherwise shorter prompts are scored at a padding position.

    Returns:
        0-d tensor: batch mean of logit[correct] - logit[incorrect].
    """
    if logits.ndim == 3:
        if final_positions is None:
            # Get final logits only
            logits = logits[:, -1, :]
        else:
            batch_idx = torch.arange(logits.shape[0], device=logits.device)
            logits = logits[batch_idx, torch.as_tensor(final_positions, device=logits.device)]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

# Run both batches once, caching every activation for the patching runs below
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)

clean_logit_diff, corrupted_logit_diff = (
    get_logit_diff(run_logits, answer_token_indices).item()
    for run_logits in (clean_logits, corrupted_logits)
)
print(f"Clean logit diff: {clean_logit_diff:.4f}")
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: -0.0286
Corrupted logit diff: -0.0286


In [None]:
import math
import warnings

CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff

# ioi_metric divides by (CLEAN_BASELINE - CORRUPTED_BASELINE). If the baselines coincide, every
# metric value -- and so every patching heatmap below -- is NaN or blows up; flag that once here.
if math.isclose(CLEAN_BASELINE, CORRUPTED_BASELINE, abs_tol=1e-6):
    warnings.warn(
        f"Clean and corrupted logit diffs are (nearly) equal ({CLEAN_BASELINE:.4f} vs "
        f"{CORRUPTED_BASELINE:.4f}); ioi_metric normalises by their difference and will be NaN/inf. "
        "Check that the corrupted prompts actually flip the answer."
    )

def ioi_metric(logits, answer_token_indices=answer_token_indices):
    """Logit diff rescaled linearly so the clean run scores 1 and the corrupted run scores 0.

    Args:
        logits: model logits, passed straight to `get_logit_diff`.
        answer_token_indices: (correct, incorrect) token ids per prompt.

    Returns:
        0-d tensor; NaN/inf when CLEAN_BASELINE == CORRUPTED_BASELINE.
    """
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE - CORRUPTED_BASELINE)

print(f"Clean Baseline is 1: {ioi_metric(clean_logits).item():.4f}")
print(f"Corrupted Baseline is 0: {ioi_metric(corrupted_logits).item():.4f}")

Clean Baseline is 1: nan
Corrupted Baseline is 0: nan


In [None]:
# TEST: Verify metric works -- clean should favour the correct answer (positive logit diff) and
# corrupted the incorrect one (negative logit diff).
print(f"Clean logit diff (should be positive): {clean_logit_diff:.4f}")
print(f"Corrupted logit diff (should be negative): {corrupted_logit_diff:.4f}")
print("\nMetric interpretation:")
print(f"  Clean baseline = 1.0: {ioi_metric(clean_logits).item():.4f}")
print(f"  Corrupted baseline = 0.0: {ioi_metric(corrupted_logits).item():.4f}")

if clean_logit_diff > 0 and corrupted_logit_diff < 0:
    print("\n✅ Metric is working correctly!")
    print(f"   Clean prefers the correct answer by {clean_logit_diff:.2f} logits")
    print(f"   Corrupted prefers the incorrect answer by {-corrupted_logit_diff:.2f} logits")
else:
    print("\n❌ Something is wrong with the metric!")
    print("   Expected clean logit diff > 0 and corrupted logit diff < 0.")

Clean logit diff (should be positive): -0.0286
Corrupted logit diff (should be negative): -0.0286

Metric interpretation:
  Clean baseline = 1.0: nan
  Corrupted baseline = 0.0: nan

❌ Something is wrong with the metric!


 ## Patching
 In the following cells, we use the patching module to call activation patching utilities.

In [None]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

 ### Patching Single Activation Types
 We start by patching single types of activation
 The general syntax is that the functions are called `get_act_patch_...` and take in `(model, corrupted_tokens, clean_cache, patching_metric)`.

 We can patch the residual stream at the start of each block over each layer and position
 Replacing `resid_pre` with `attn_out`, `mlp_out` or `resid_mid` also works.

In [None]:
# Run this before the save_imshow function
import sys
!pip uninstall -y kaleido
!pip install kaleido==0.2.1

Found existing installation: kaleido 1.2.0
Uninstalling kaleido-1.2.0:
  Successfully uninstalled kaleido-1.2.0
Collecting kaleido==0.2.1
  Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl.metadata (15 kB)
Downloading kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl (79.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.9/79.9 MB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: kaleido
Successfully installed kaleido-0.2.1


In [None]:
# Patch the clean residual stream (start of each block) into the corrupted run, per layer and position
resid_pre_act_patch_results = patching.get_act_patch_resid_pre(
    model, corrupted_tokens, clean_cache, ioi_metric
)
# One x tick per position, labelled "<token> <index>" from clean prompt 0
resid_pre_position_labels = [
    f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))
]
save_imshow(
    resid_pre_act_patch_results,
    filename="01_resid_pre_activation_patching",
    title="resid_pre Activation Patching",
    xaxis="Position",
    yaxis="Layer",
    x=resid_pre_position_labels,
)

  0%|          | 0/564 [00:00<?, ?it/s]

Saved: activation_patches/01_resid_pre_activation_patching.png


 We can patch head outputs over each head in each layer, patching across all positions at once
 out -> q, k, v, pattern all also work

In [None]:
# Patch each head's output over all positions at once: one score per (layer, head)
attn_head_out_all_pos_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(
    model, corrupted_tokens, clean_cache, ioi_metric
)
save_imshow(
    attn_head_out_all_pos_act_patch_results,
    filename="02_attn_head_out_all_pos",
    title="attn_head_out Activation Patching (All Pos)",
    xaxis="Head",
    yaxis="Layer",
)

  0%|          | 0/144 [00:00<?, ?it/s]

Saved: activation_patches/02_attn_head_out_all_pos.png


 We can patch head outputs over each head in each layer, patching on each position in turn
 out -> q, k, v, pattern all also work, though note that pattern has output shape [layer, pos, head]
 We reshape it to plot nicely

In [None]:
# "L{layer}H{head}" for every head, layer-major -- the same order as the (layer head) flattening
ALL_HEAD_LABELS = [
    f"L{layer}H{head}"
    for layer in range(model.cfg.n_layers)
    for head in range(model.cfg.n_heads)
]
if DO_SLOW_RUNS:
    # Result comes back as [layer, pos, head]; flatten to one row per head for plotting
    attn_head_out_act_patch_results = einops.rearrange(
        patching.get_act_patch_attn_head_out_by_pos(model, corrupted_tokens, clean_cache, ioi_metric),
        "layer pos head -> (layer head) pos",
    )
    save_imshow(
        attn_head_out_act_patch_results,
        filename="03_attn_head_out_by_pos",
        title="attn_head_out Activation Patching By Pos",
        xaxis="Pos",
        yaxis="Head Label",
        x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
    )

  0%|          | 0/6768 [00:00<?, ?it/s]

Saved: activation_patches/03_attn_head_out_by_pos.png


 ### Patching multiple activation types
 Some utilities are provided to patch multiple activation types *in turn*. Note that this is *not* a utility to patch multiple activations at once; it's just a useful scan to get a sense of what's going on in a model.
 By block: We patch the residual stream at the start of each block, attention output and MLP output over each layer and position

In [None]:
# Residual stream, attention output and MLP output, each patched per (layer, position)
every_block_result = patching.get_act_patch_block_every(model, corrupted_tokens, clean_cache, ioi_metric)
block_position_labels = [f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))]
save_imshow(
    every_block_result,
    filename="04_activation_patching_per_block",
    title="Activation Patching Per Block",
    facet_col=0,
    facet_labels=["Residual Stream", "Attn Output", "MLP Output"],
    xaxis="Position",
    yaxis="Layer",
    zmin=-1,
    zmax=1,
    x=block_position_labels,
)

  0%|          | 0/564 [00:00<?, ?it/s]

  0%|          | 0/564 [00:00<?, ?it/s]

  0%|          | 0/564 [00:00<?, ?it/s]

Saved: activation_patches/04_activation_patching_per_block.png


 By head: For each head we patch the output, query, key, value or pattern. We do all positions at once so it's not super slow.

In [None]:
head_act_type_labels = ["Output", "Query", "Key", "Value", "Pattern"]

# One [layer, head] grid per activation type, patching all positions at once
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(
    model, corrupted_tokens, clean_cache, ioi_metric
)
save_imshow(
    every_head_all_pos_act_patch_result,
    filename="05_activation_patching_per_head_all_pos",
    title="Activation Patching Per Head (All Pos)",
    facet_col=0,
    facet_labels=head_act_type_labels,
    xaxis="Head",
    yaxis="Layer",
    zmin=-1,
    zmax=1,
)

# Slow: repeat per position, then flatten layer and head into a single row axis
if DO_SLOW_RUNS:
    every_head_act_patch_result = einops.rearrange(
        patching.get_act_patch_attn_head_by_pos_every(model, corrupted_tokens, clean_cache, ioi_metric),
        "act_type layer pos head -> act_type (layer head) pos",
    )
    save_imshow(
        every_head_act_patch_result,
        filename="06_activation_patching_per_head_by_pos",
        title="Activation Patching Per Head (By Pos)",
        facet_col=0,
        facet_labels=head_act_type_labels,
        xaxis="Position",
        yaxis="Layer & Head",
        zmin=-1,
        zmax=1,
        x=[f"{tok}_{i}" for i, tok in enumerate(model.to_str_tokens(clean_tokens[0]))],
        y=ALL_HEAD_LABELS,
    )

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

  0%|          | 0/144 [00:00<?, ?it/s]

Saved: activation_patches/05_activation_patching_per_head_all_pos.png


  0%|          | 0/6768 [00:00<?, ?it/s]

  0%|          | 0/6768 [00:00<?, ?it/s]

  0%|          | 0/6768 [00:00<?, ?it/s]

  0%|          | 0/6768 [00:00<?, ?it/s]

  0%|          | 0/6768 [00:00<?, ?it/s]

Saved: activation_patches/06_activation_patching_per_head_by_pos.png


 ## Induction Patching
 To show how easy it is, let's do that again with induction heads in a 2L attention-only model.
 The input will be repeated random tokens eg BOS 1 5 8 9 2 1 5 8 9 2, and we judge the model's ability to predict the second repetition with its induction heads
 Let's call A, B and C different (non-repeated) random sequences. We'll start with clean tokens AA and corrupted tokens AB, and see how well the model can predict the second A given the first A.

 ### Setup

In [None]:
attn_only = HookedTransformer.from_pretrained("attn-only-2l")
batch = 4
seq_len = 20

# Seeded CPU generator so the random sequences -- and every induction plot below -- are
# reproducible across runs and devices; tokens are moved to the model's device afterwards.
INDUCTION_SEED = 0
token_rng = torch.Generator().manual_seed(INDUCTION_SEED)
# Token ids are drawn from [100, 10000). NOTE(review): presumably the lower bound skips
# low-id/special tokens -- confirm.
rand_tokens_A = torch.randint(100, 10000, (batch, seq_len), generator=token_rng).to(attn_only.cfg.device)
rand_tokens_B = torch.randint(100, 10000, (batch, seq_len), generator=token_rng).to(attn_only.cfg.device)
rand_tokens_C = torch.randint(100, 10000, (batch, seq_len), generator=token_rng).to(attn_only.cfg.device)
bos = torch.tensor([attn_only.tokenizer.bos_token_id]*batch)[:, None].to(attn_only.cfg.device)
# Clean = BOS A A (second A is predictable by induction); corrupted = BOS A B
clean_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_A], dim=1)
corrupted_tokens_induction = torch.cat([bos, rand_tokens_A, rand_tokens_B], dim=1)

config.json: 0.00B [00:00, ?B/s]

./model_final.pth:   0%|          | 0.00/210M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/51.0 [00:00<?, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

Loaded pretrained model attn-only-2l into HookedTransformer


In [None]:
# Cache all activations for the clean (AA) and corrupted (AB) induction batches
clean_logits_induction, clean_cache_induction = attn_only.run_with_cache(
    clean_tokens_induction
)
corrupted_logits_induction, corrupted_cache_induction = attn_only.run_with_cache(
    corrupted_tokens_induction
)

 We define our metric as negative loss on the second half (negative loss so that higher is better)
 This time we won't normalise our metric

In [None]:
def induction_loss(logits, answer_token_indices=rand_tokens_A):
    """Mean log-probability of the repeated answer sequence (negative loss: higher is better).

    The answer sequence fills the last n positions of the input. The logit at position t predicts
    token t+1, so the final n-1 predictions before the last position line up with answer tokens
    1..n-1 (the first answer token cannot be predicted).

    Args:
        logits: [batch, pos, d_vocab] model logits.
        answer_token_indices: [batch, n] token ids of the sequence being repeated.
    """
    n_answer = answer_token_indices.shape[1]
    log_probs = logits[:, -n_answer:-1].log_softmax(dim=-1)
    targets = answer_token_indices[:, 1:, None]
    return log_probs.gather(dim=-1, index=targets).mean()
# Un-normalised reference points for the induction patching plots
CLEAN_BASELINE_INDUCTION = induction_loss(clean_logits_induction).item()
CORRUPTED_BASELINE_INDUCTION = induction_loss(corrupted_logits_induction).item()
print("Clean baseline:", CLEAN_BASELINE_INDUCTION)
print("Corrupted baseline:", CORRUPTED_BASELINE_INDUCTION)

Clean baseline: -1.5530186891555786
Corrupted baseline: -12.87680721282959


 ### Patching

In [None]:
induction_act_type_labels = ["Output", "Query", "Key", "Value", "Pattern"]

# All positions at once: one [layer, head] grid per activation type
every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(
    attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss
)
save_imshow(
    every_head_all_pos_act_patch_result,
    filename="07_induction_per_head_all_pos",
    title="Activation Patching Per Head (All Pos)",
    facet_col=0,
    facet_labels=induction_act_type_labels,
    xaxis="Head",
    yaxis="Layer",
    zmax=CLEAN_BASELINE_INDUCTION,
)

# Per position (slow): flatten layer and head into one "LxHy" row axis
if DO_SLOW_RUNS:
    every_head_act_patch_result = einops.rearrange(
        patching.get_act_patch_attn_head_by_pos_every(
            attn_only, corrupted_tokens_induction, clean_cache_induction, induction_loss
        ),
        "act_type layer pos head -> act_type (layer head) pos",
    )
    save_imshow(
        every_head_act_patch_result,
        filename="08_induction_per_head_by_pos",
        title="Activation Patching Per Head (By Pos)",
        facet_col=0,
        facet_labels=induction_act_type_labels,
        xaxis="Position",
        yaxis="Layer & Head",
        zmax=CLEAN_BASELINE_INDUCTION,
        x=[f"{tok}_{i}" for i, tok in enumerate(attn_only.to_str_tokens(clean_tokens_induction[0]))],
        y=[
            f"L{layer}H{head}"
            for layer in range(attn_only.cfg.n_layers)
            for head in range(attn_only.cfg.n_heads)
        ],
    )

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Saved: activation_patches/07_induction_per_head_all_pos.png


  0%|          | 0/656 [00:00<?, ?it/s]

  0%|          | 0/656 [00:00<?, ?it/s]

  0%|          | 0/656 [00:00<?, ?it/s]

  0%|          | 0/656 [00:00<?, ?it/s]

  0%|          | 0/656 [00:00<?, ?it/s]

Saved: activation_patches/08_induction_per_head_by_pos.png


 ### Changing the Corrupted Baseline
 We can also change the corrupted baseline easily to check what things look like! We'll keep clean as AA, but rather than corrupted as AB, we'll try out:
 * BA - This has a corrupted first half, so we expect both keys *and* values to matter. Head output patching should work, but value and key and pattern won't.
 * BB - This is still inductiony but with different tokens. So keys, queries and patterns don't matter, head output patching will work, and value will.
 * BC - This is just random tokens, so everything is corrupted! The induction head needs queries, keys *and* values, so only output will work.

In [None]:
# Alternative corrupted batches, named by (first half, second half): BA, BB and BC
(
    corrupted_tokens_induction_BA,
    corrupted_tokens_induction_BB,
    corrupted_tokens_induction_BC,
) = (
    torch.cat([bos, first_half, second_half], dim=1).to(attn_only.cfg.device)
    for first_half, second_half in (
        (rand_tokens_B, rand_tokens_A),
        (rand_tokens_B, rand_tokens_B),
        (rand_tokens_B, rand_tokens_C),
    )
)

In [None]:
# Re-run the all-position head patching scan against each alternative corrupted baseline,
# saving one figure per baseline.
alternative_baselines = [
    ("BA", corrupted_tokens_induction_BA, "09_induction_BA_baseline"),
    ("BB", corrupted_tokens_induction_BB, "10_induction_BB_baseline"),
    ("BC", corrupted_tokens_induction_BC, "11_induction_BC_baseline"),
]
for baseline_name, baseline_tokens, baseline_filename in alternative_baselines:
    every_head_all_pos_act_patch_result = patching.get_act_patch_attn_head_all_pos_every(
        attn_only, baseline_tokens, clean_cache_induction, induction_loss
    )
    save_imshow(
        every_head_all_pos_act_patch_result,
        filename=baseline_filename,
        title=f"Activation Patching Per Head on {baseline_name} (All Pos)",
        facet_col=0,
        facet_labels=["Output", "Query", "Key", "Value", "Pattern"],
        xaxis="Head",
        yaxis="Layer",
        zmax=CLEAN_BASELINE_INDUCTION,
    )

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Saved: activation_patches/09_induction_BA_baseline.png


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Saved: activation_patches/10_induction_BB_baseline.png


  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

  0%|          | 0/16 [00:00<?, ?it/s]

Saved: activation_patches/11_induction_BC_baseline.png
