<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/cursivetransformer_mech_interp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Induction Circuit Investigation and Analysis

# Setup

In [1]:
# Install the third-party dependencies into the *kernel's* environment.
# `%pip` (unlike `!pip`) always targets the interpreter that runs this notebook, and a
# single invocation lets pip resolve all packages together; `-q` keeps the log short.
%pip install -q transformer_lens gradio wandb einops matplotlib datasets

# Clone a fresh copy of the cursivetransformer repository and install its requirements.
# Removing any previous checkout first keeps this cell re-runnable.
!rm -rf cursivetransformer && git clone https://github.com/zwimpee/cursivetransformer.git
%pip install -q -r cursivetransformer/requirements.txt

Collecting transformer_lens
  Downloading transformer_lens-2.7.0-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer_lens)
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Collecting wandb>=0.13.5 (from transformer_lens)
  Downloading wandb-0.18.3-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets>=2.7.1->transformer_lens)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxha

In [2]:
import sys
sys.path.append('/content/cursivetransformer')
from cursivetransformer.model import get_all_args, get_checkpoint, get_latest_checkpoint_artifact
from cursivetransformer.data import create_datasets, offsets_to_strokes, strokes_to_offsets
from cursivetransformer.sample import generate, generate_n_words, plot_strokes
from cursivetransformer.mech_interp import (
    HookedCursiveTransformer,
    HookedCursiveTransformerConfig,
    convert_cursivetransformer_model_config,
    visualize_attention,
    generate_repeated_stroke_tokens,
    generate_random_ascii_context,
    run_and_cache_model_repeated_tokens,
    compute_induction_scores,
    plot_induction_scores,
    plot_head_attention_pattern,
    create_induction_summary,
    ablate_heads,
    get_induction_positions,
    compute_loss_on_induction_positions
)

import pandas as pd
import os

import copy
import types
from typing import List, Callable, Dict, Optional, Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio
import circuitsvis as cv
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import display
from jaxtyping import Float, Int


import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import ActivationCache

# No gradients are needed anywhere in this notebook, so autograd is switched off globally.
torch.set_grad_enabled(False)

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Authenticate with Weights & Biases (prompts for an API key when none is stored).
import wandb
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


True

In [None]:
import random

# Start from the project's argument object and override the fields this analysis needs.
args = get_all_args(False)
args.sample_only = True
args.load_from_run_id = '6le6tujz'
args.wandb_entity = 'sam-greydanus'
args.dataset_name = 'bigbank'
args.wandb_run_name = 'cursivetransformer_dictionary_learning'

# Seed every RNG the notebook draws from. Seeding torch alone is not enough: the
# ablation study further down samples with both `random` and `np.random`, so without
# these two extra calls a Restart-and-Run-All would not reproduce its numbers.
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

train_dataset, test_dataset = create_datasets(args)

# Sequence lengths and vocabulary sizes are read back from the training dataset so the
# model configuration built from `args` agrees with the data.
args.block_size = train_dataset.get_stroke_seq_length()
args.context_block_size = train_dataset.get_text_seq_length()
args.vocab_size = train_dataset.get_vocab_size()
args.context_vocab_size = train_dataset.get_char_vocab_size()

## Load model into HookedCursiveTransformer

In [None]:
# Build the hooked-model config from the run arguments prepared above, then construct
# the HookedCursiveTransformer from it.
# NOTE(review): presumably `from_pretrained` loads the checkpoint selected via
# `args.load_from_run_id` — confirm against cursivetransformer.mech_interp.
cfg = convert_cursivetransformer_model_config(args)
model = HookedCursiveTransformer.from_pretrained("cursivetransformer", cfg)

# Experiment

In [None]:
# Repeated-sequence experiment: build a stroke-token sequence that repeats itself, run the
# model on it while caching activations, then compute and plot per-head induction scores.
seq_len = 50  # Number of (θ, r) pairs in the initial sequence <- [ ] TODO: DEBUG THIS!
n_repeats = 2  # Number of repetitions
batch_size = 1

# Generate repeated stroke tokens and random ASCII context
rep_tokens = generate_repeated_stroke_tokens(model, test_dataset, seq_len, n_repeats, batch_size)
context_tokens = generate_random_ascii_context(model, batch_size)

# Run the model and cache activations
# NOTE(review): only temporary `.to(device)` copies are passed here; `rep_tokens` and
# `context_tokens` themselves are not reassigned, so later cells hand the original tensors
# to `ablate_heads` — confirm that helper moves its inputs to the model's device.
model = model.to(device)
logits, targets, cache = run_and_cache_model_repeated_tokens(model, rep_tokens.to(device), context_tokens.to(device))

# Disabled debugging checks, kept for reference:
# sanity_check_token_pairs(rep_tokens)
# verify_attention_summation(cache, layer=2, head=3, attn_type='self')

# Later cells index `induction_scores` as [layer, head].
induction_scores = compute_induction_scores(rep_tokens, cache, model)
plot_induction_scores(induction_scores)

In [None]:
# Identify the top N heads with the highest induction scores and plot their attention.
N = 5
num_heads = model.cfg.n_heads

# `flatten` (unlike `view(-1)`) also works when the score tensor is non-contiguous.
induction_scores_flat = induction_scores.flatten()
# `torch.topk` raises when asked for more entries than exist, so cap the request.
top_k = min(N, induction_scores_flat.numel())
top_scores, top_indices = torch.topk(induction_scores_flat, top_k)

# Number of positions passed to the attention plot.
# NOTE(review): the factor 2 presumably reflects two tokens per (θ, r) pair — confirm.
plot_len = seq_len * n_repeats * 2

print(f"Top {top_k} Induction Heads:")
for rank, (score, idx) in enumerate(zip(top_scores.tolist(), top_indices.tolist()), start=1):
    # Convert the flat index back to (layer, head) as plain Python ints, so downstream
    # helpers receive ordinary integers rather than 0-d tensors.
    layer, head = divmod(idx, num_heads)
    print(f"{rank}. Layer {layer}, Head {head}, Induction Score: {score:.4f}")
    # Plot attention pattern
    plot_head_attention_pattern(cache, layer, head, plot_len, attn_type='self')

In [None]:
# Tabulate the per-head induction scores and show them strongest-first.
df_induction = create_induction_summary(induction_scores)
df_induction_ranked = df_induction.sort_values(by='Score', ascending=False)
display(df_induction_ranked)

In [None]:
# Heads whose induction score exceeds this value are selected for ablation.
threshold = 0.0001

# Collect (layer, head, attention-type) triples for every head above the threshold,
# visiting heads in layer-major order.
heads_to_ablate = []
for layer_idx in range(model.cfg.n_layers):
    for head_idx in range(model.cfg.n_heads):
        if induction_scores[layer_idx, head_idx] > threshold:
            heads_to_ablate.append((layer_idx, head_idx, 'self'))  # or 'cross'

# Forward pass with the selected heads ablated
ablated_logits = ablate_heads(model, heads_to_ablate, rep_tokens, context_tokens)

# Shapes, printed for debugging
print(f"rep_tokens shape: {rep_tokens.shape}")
print(f"logits shape: {logits[0].shape}")
print(f"targets shape: {targets.shape}")

# Positions at which induction behaviour is evaluated
induction_positions = get_induction_positions(rep_tokens, seq_len, n_repeats)
print(f"Induction positions: {induction_positions}")

# Drop any position that falls outside the logits.
# NOTE(review): the bound is taken from `logits[0].shape[1]`; this is the sequence axis only
# if `logits[0]` is itself (batch_size, seq_len, vocab_size) — confirm what the model returns.
max_position = logits[0].shape[1] - 1
induction_positions = induction_positions[0]  # first entry of the returned collection
induction_positions = list(filter(lambda pos: pos <= max_position, induction_positions))
print(f"Filtered induction positions: {induction_positions}")

# Compare the loss at the induction positions with and without the ablation
if len(induction_positions) == 0:
    print("No valid induction positions found within the sequence length.")
else:
    original_loss = compute_loss_on_induction_positions(logits, targets, induction_positions)
    ablated_loss = compute_loss_on_induction_positions(ablated_logits, targets, induction_positions)

    print(f"Original Loss on Induction Positions: {original_loss.item():.4f}")
    print(f"Ablated Loss on Induction Positions: {ablated_loss.item():.4f}")

In [None]:
import itertools
import matplotlib.pyplot as plt
import seaborn as sns
import random

def perform_ablation_study(model, rep_tokens, context_tokens, targets, induction_positions,
                           n_random_trials=10, seed=None):
    """Measure the induction-position loss under several self-attention ablation schemes.

    Four families of ablations are run, all on 'self' attention heads:

    1. every head on its own (keys ``L{layer}H{head}``),
    2. all heads of one layer at once (keys ``Layer{layer}``),
    3. the first ``i`` heads in layer-major order (keys ``Cumulative{i}``),
    4. ``n_random_trials`` random subsets of random size, summarised as
       ``RandomAblationMean`` and ``RandomAblationStd``.

    Args:
        model: Model exposing ``cfg.n_layers`` and ``cfg.n_heads``; forwarded to ``ablate_heads``.
        rep_tokens: Forwarded unchanged to ``ablate_heads``.
        context_tokens: Forwarded unchanged to ``ablate_heads``.
        targets: Forwarded unchanged to ``compute_loss_on_induction_positions``.
        induction_positions: Forwarded unchanged to ``compute_loss_on_induction_positions``.
        n_random_trials: Number of random subsets drawn for the baseline.
        seed: When given, the random baseline uses private generators seeded with this
            value and is reproducible. When ``None`` the global ``np.random`` and
            ``random`` state is used, exactly as before.

    Returns:
        dict mapping each ablation label to its loss, in the insertion order listed above.
    """
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    # Every (layer, head) pair in layer-major order; built once and shared by all schemes.
    all_heads = list(itertools.product(range(n_layers), range(n_heads)))

    def ablated_loss(pairs):
        # Ablate the given (layer, head) pairs and return the scalar induction-position loss.
        heads_to_ablate = [(l, h, 'self') for l, h in pairs]
        ablated_logits = ablate_heads(model, heads_to_ablate, rep_tokens, context_tokens)
        loss = compute_loss_on_induction_positions(ablated_logits, targets, induction_positions)
        return loss.item()

    results = {}

    # 1. Single-Head Ablation
    for layer, head in all_heads:
        results[f'L{layer}H{head}'] = ablated_loss([(layer, head)])

    # 2. Layer-wise Ablation
    for layer in range(n_layers):
        results[f'Layer{layer}'] = ablated_loss([(layer, head) for head in range(n_heads)])

    # 3. Cumulative Ablation
    for i in range(1, len(all_heads) + 1):
        results[f'Cumulative{i}'] = ablated_loss(all_heads[:i])

    # 4. Random Ablation Baseline
    # Both the subset size and the subset itself are random; with a seed, neither touches
    # the global RNG state.
    count_rng = np.random if seed is None else np.random.RandomState(seed)
    sample_rng = random if seed is None else random.Random(seed)
    random_ablation_losses = []
    for _ in range(n_random_trials):
        n_heads_to_ablate = count_rng.randint(1, len(all_heads) + 1)
        chosen = sample_rng.sample(all_heads, n_heads_to_ablate)
        random_ablation_losses.append(ablated_loss(chosen))
    results['RandomAblationMean'] = np.mean(random_ablation_losses)
    results['RandomAblationStd'] = np.std(random_ablation_losses)

    return results

# Run the full ablation study (single-head, layer-wise, cumulative and random baseline)
ablation_results = perform_ablation_study(model, rep_tokens, context_tokens, targets, induction_positions)

# Tabulate: one row per ablation label, indexed by that label
df = pd.DataFrame(list(ablation_results.items()), columns=['Ablation', 'Loss'])
df = df.set_index('Ablation')

# 'RandomAblationStd' is a standard deviation, not a loss. It stays in the printed table,
# but is left out of the loss plots, where it would show up as a bogus cell/bar and
# skew the heatmap's colour centre.
loss_df = df[df.index != 'RandomAblationStd']

# Print debugging information
print("All ablation results:")
print(df)
print("\nModel configuration:")
print(f"Number of layers: {model.cfg.n_layers}")
print(f"Number of heads: {model.cfg.n_heads}")

# Visualize overall results
plt.figure(figsize=(15, 10))
sns.heatmap(loss_df.T, annot=True, cmap='coolwarm', center=loss_df['Loss'].mean())
plt.title('Ablation Study Results')
plt.show()

# Bar plot of results
plt.figure(figsize=(15, 10))
sns.barplot(x=loss_df.index, y='Loss', data=loss_df)
plt.xticks(rotation=90)
plt.title('Ablation Study Results - Bar Plot')
plt.tight_layout()
plt.show()

# Separate single-head ablations (labels of the form L<layer>H<head>)
single_head_df = df[df.index.str.match(r'L\d+H\d+')]
print("\nSingle-head ablation results:")
print(single_head_df)

# Check if we have the correct number of single-head ablations
expected_ablations = model.cfg.n_layers * model.cfg.n_heads
if len(single_head_df) == expected_ablations:
    # Rows were inserted in layer-major order, so a plain reshape yields a
    # (layer, head) grid.
    plt.figure(figsize=(15, 10))
    reshaped_values = single_head_df['Loss'].values.reshape(model.cfg.n_layers, model.cfg.n_heads)
    sns.heatmap(reshaped_values,
                annot=True, cmap='coolwarm', center=single_head_df['Loss'].mean(),
                xticklabels=range(model.cfg.n_heads), yticklabels=range(model.cfg.n_layers))
    plt.title('Single-Head Ablation Results')
    plt.xlabel('Head')
    plt.ylabel('Layer')
    plt.show()
else:
    print(f"\nWarning: Number of single-head ablations ({len(single_head_df)}) "
          f"doesn't match expected number ({expected_ablations})")
    print("Skipping single-head ablation heatmap.")

# 5. Attention Pattern Analysis
# One panel per (layer, head). `squeeze=False` keeps `axes` two-dimensional even when the
# model has a single layer or a single head, so the [layer, head] indexing never breaks.
fig, axes = plt.subplots(cfg.n_layers, cfg.n_heads, figsize=(20, 20), squeeze=False)
for layer in range(cfg.n_layers):
    for head in range(cfg.n_heads):
        # Attention pattern of batch element 0 for this head, as a NumPy array for imshow
        attn_pattern = cache['pattern', layer][0, head].detach().cpu().numpy()
        im = axes[layer, head].imshow(attn_pattern, cmap='viridis')
        axes[layer, head].set_title(f'Layer {layer}, Head {head}')
        axes[layer, head].axis('off')
# NOTE(review): the colourbar is built from the last panel's image only, while every
# imshow call above autoscales to its own data — so the scale is exact only for that panel.
fig.colorbar(im, ax=axes.ravel().tolist(), shrink=0.5)
plt.subplots_adjust(hspace=0.3, wspace=0.3)
plt.show()

# 6. Loss Landscape Visualization
def compute_loss_landscape(model, rep_tokens, context_tokens, targets, induction_positions, n_points=20):
    """Return the induction-position loss for each single-head ablation, tiled over ``n_points``.

    NOTE: ``ablate_heads`` is called without any strength argument, so the ablation is
    always the full one and the loss does not vary along the last axis. The ablation is
    therefore evaluated once per head and broadcast, instead of being recomputed
    ``n_points`` times with identical inputs. Varying the ablation strength would need
    support for it in ``ablate_heads``.

    Args:
        model: Model exposing ``cfg.n_layers`` and ``cfg.n_heads``; forwarded to ``ablate_heads``.
        rep_tokens: Forwarded unchanged to ``ablate_heads``.
        context_tokens: Forwarded unchanged to ``ablate_heads``.
        targets: Forwarded unchanged to ``compute_loss_on_induction_positions``.
        induction_positions: Forwarded unchanged to ``compute_loss_on_induction_positions``.
        n_points: Length of the last axis of the returned array.

    Returns:
        ``np.ndarray`` of shape ``(n_layers, n_heads, n_points)``.
    """
    # Read the dimensions from the model that was passed in rather than from a notebook global.
    n_layers = model.cfg.n_layers
    n_heads = model.cfg.n_heads
    landscape = np.zeros((n_layers, n_heads, n_points))
    for layer in range(n_layers):
        for head in range(n_heads):
            heads_to_ablate = [(layer, head, 'self')]
            ablated_logits = ablate_heads(model, heads_to_ablate, rep_tokens, context_tokens)
            loss = compute_loss_on_induction_positions(ablated_logits, targets, induction_positions)
            # Same value at every point: the ablation has no strength parameter.
            landscape[layer, head, :] = loss.item()
    return landscape

n_points = 20
loss_landscape = compute_loss_landscape(model, rep_tokens, context_tokens, targets, induction_positions, n_points)

# Compute baseline loss (no ablation)
baseline_loss = compute_loss_on_induction_positions(logits, targets, induction_positions).item()

# Visualize loss landscape
# The grid size is taken from the array being plotted, and `squeeze=False` keeps `axes`
# two-dimensional even for a single layer or head.
grid_layers, grid_heads, _ = loss_landscape.shape
strengths = np.linspace(0, 1, n_points)
fig, axes = plt.subplots(grid_layers, grid_heads, figsize=(20, 20), squeeze=False)
for layer in range(grid_layers):
    for head in range(grid_heads):
        ax = axes[layer, head]
        ax.plot(strengths, loss_landscape[layer, head])
        # Dashed red line: loss of the unablated model
        ax.axhline(y=baseline_loss, color='r', linestyle='--')
        ax.set_title(f'L{layer}H{head}')
        ax.set_xlabel('Ablation Strength')
        ax.set_ylabel('Loss')
plt.subplots_adjust(hspace=0.5, wspace=0.3)
plt.show()

# Bar plot with baseline
df = pd.DataFrame(list(ablation_results.items()), columns=['Ablation', 'Loss'])
# 'RandomAblationStd' is a standard deviation, not a loss, so it must not be drawn as a
# loss bar next to the baseline.
df = df[df['Ablation'] != 'RandomAblationStd']
df = df.sort_values('Loss', ascending=False)

plt.figure(figsize=(15, 10))
sns.barplot(x='Ablation', y='Loss', data=df)
# Dashed red line: loss of the unablated model, for reference
plt.axhline(y=baseline_loss, color='r', linestyle='--', label='Baseline Loss')
plt.xticks(rotation=90)
plt.title('Ablation Study Results - Bar Plot')
plt.legend()
plt.subplots_adjust(bottom=0.2)
plt.show()

# Review of Results Thus Far

We are likely having trouble because our model contains non-linearities: namely, it has MLPs.

This is not to say this has all been for nothing, however, as we can use the ablation techniques to try and identify other circuits that may exist.

Enough time has been spent on this particular toy problem of random repeated tokens. It is time to move on to the task of generating handwritten cursive and to repeat the ablation study there, this time also plotting both the unablated and the ablated decoded strokes alongside the attention patterns.

To do this, we will need to adapt the code from `sample.py` to use `HookedCursiveTransformer`:
```python
def generate_n_words(model, dataset, text, model_device='cpu', do_sample=False,
                         top_k=None, temperature=1.0, num_steps=950, n_words=3):
    '''Warmup sequence assumes we're using tokenization scheme from git commit 4eef841a55496f9ad444336530caca63b0a3cc23'''
    SEED_TOKENS = torch.tensor([377,   0, 371,  21, 361,  41, 355,  38, 350,  34, 353,  36, 359,  15,
        414,  30, 408,  21, 414,  30, 429,  31, 447,  30, 310,  28, 376,  28,
        381,  28, 372,  30, 366,  23, 357,  34, 353,  36, 355,  39, 402,  23,
        418,  30, 418,  30, 428,  12, 353,  24, 350,  34, 359,  30, 376,  28,
        415,  30, 418,  30, 414,  30, 372,  25, 356,  27, 354,  31, 353,  36,
        364,  31, 418,  30, 418,  30, 418,  30, 353,  36, 348,  22, 357,  34,
        366,  34, 407,  31, 418,  30, 422,  32, 376,  28, 361,  34, 377, 151,
        376, 232], dtype=torch.int64)
    SEED_CHARS = 'snn'
  
    model_device = next(model.parameters()).device
    warmup_steps = len(SEED_TOKENS)
    ascii_context = f'{SEED_CHARS} {text}'

    def count_words(text):
      return len(text.split(' '))
    assert count_words(ascii_context) == n_words+1, f"Expected {n_words+1} words, got {count_words(ascii_context)}"

    context = dataset.encode_text(ascii_context).unsqueeze(0).to(model_device)
    X_init = SEED_TOKENS.unsqueeze(0).to(model_device)
    
    steps = num_steps - X_init.size(1)
    X_samp = generate(model, X_init, context, steps, temperature=temperature,
                      top_k=top_k, do_sample=do_sample).to('cpu')
    
    stroke_seq = X_samp[0].detach().cpu().numpy()[len(SEED_TOKENS):]
    offset_samp = dataset.decode_stroke(stroke_seq)
    point_samp = offsets_to_strokes(offset_samp)

    return point_samp
```


In [None]:
# - [ ] TODO: Create a new function based on `generate_n_words` (provided above) that uses HookedCursiveTransformer instead of the original PyTorch model