<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/cursivetransformer_mech_interp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

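# Induction Circuit Investigation and Analysis

This notebook loads a pretrained cursivetransformer checkpoint into `HookedCursiveTransformer`, runs it on repeated stroke-token sequences, and scores every attention head for induction behaviour. It then measures how ablating the heads whose score exceeds a threshold changes the loss on the repeated (induction) positions.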

# Setup

In [1]:
# Colab setup: install dependencies into the *running kernel's* environment.
# `%pip` (unlike `!pip`) always targets the interpreter the kernel uses; `-q` keeps the
# cell output readable. seaborn and plotly are listed because the import cell uses them.
# TODO(review): pin versions once a known-good set is recorded, for reproducibility.
%pip install -q "git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python"
%pip install -q transformer_lens gradio wandb einops matplotlib datasets seaborn plotly

# Clone a fresh copy of the cursivetransformer repository and install its requirements
!rm -rf cursivetransformer && git clone https://github.com/zwimpee/cursivetransformer.git
%pip install -q -r cursivetransformer/requirements.txt

Collecting git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
  Cloning https://github.com/callummcdougall/CircuitsVis.git to /tmp/pip-req-build-jc9_v5e9
  Running command git clone --filter=blob:none --quiet https://github.com/callummcdougall/CircuitsVis.git /tmp/pip-req-build-jc9_v5e9
  Resolved https://github.com/callummcdougall/CircuitsVis.git to commit 1e6129d08cae7af9242d9ab5d3ed322dd44b4dd3
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Cloning into 'cursivetransformer'...
remote: Enumerating objects: 2720, done.[K
remote: Counting objects: 100% (861/861), done.[K
remote: Compressing objects: 100% (287/287), done.[K
remote: Total 2720 (delta 639), reused 764 (delta 574), pack-reused 1859 (from 1)[K
Receiving objects: 100% (2720/2720), 44.12 MiB | 11.65 MiB/s, done.
Resolving deltas: 100% (1545/1545), done.
Collecting git+https://

In [2]:
import sys
sys.path.append('/content/cursivetransformer')
from cursivetransformer.model import get_all_args, get_checkpoint, get_latest_checkpoint_artifact
from cursivetransformer.data import create_datasets, offsets_to_strokes, strokes_to_offsets
from cursivetransformer.sample import generate, generate_n_words, plot_strokes
from cursivetransformer.mech_interp import (
    HookedCursiveTransformer,
    HookedCursiveTransformerConfig,
    convert_cursivetransformer_model_config,
    visualize_attention,
    generate_repeated_stroke_tokens,
    generate_random_ascii_context,
    run_and_cache_model_repeated_tokens,
    compute_induction_scores,
    plot_induction_scores,
    plot_head_attention_pattern,
    create_induction_summary,
    ablate_heads,
    get_induction_positions,
    compute_loss_on_induction_positions
)

import pandas as pd
import os

import copy
import types
from typing import List, Callable, Dict, Optional, Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio
import circuitsvis as cv
import matplotlib.pyplot as plt
import seaborn as sns

from IPython.display import display
from jaxtyping import Float, Int


import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import ActivationCache

# The rest of the notebook is analysis only, so switch autograd off globally.
torch.set_grad_enabled(False)

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Authenticate with Weights & Biases; the model checkpoint is presumably fetched from a
# W&B run (see args.load_from_run_id / args.wandb_entity below).
import wandb
wandb.login()

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:


Abort: 

In [None]:
# Configure a sampling-only run that restores the pretrained checkpoint from W&B.
args = get_all_args(False)
args.sample_only = True
args.load_from_run_id = '6le6tujz'
args.wandb_entity = 'sam-greydanus'
args.dataset_name = 'bigbank'
args.wandb_run_name = 'cursivetransformer_dictionary_learning'

# Seed every RNG the notebook's helpers might draw from (Python, NumPy, torch CPU and
# CUDA), not just torch, so dataset construction and the randomly generated
# token/context inputs are repeatable across Restart & Run All.
import random
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

train_dataset, test_dataset = create_datasets(args)

# Size the model config from the dataset: stroke/text sequence lengths and vocabularies.
args.block_size = train_dataset.get_stroke_seq_length()
args.context_block_size = train_dataset.get_text_seq_length()
args.vocab_size = train_dataset.get_vocab_size()
args.context_vocab_size = train_dataset.get_char_vocab_size()

## Load model into HookedCursiveTransformer

In [None]:
# Convert the training args into the hooked model's config, then load the
# "cursivetransformer" pretrained weights into a hookable model.
cfg = convert_cursivetransformer_model_config(args)
model = HookedCursiveTransformer.from_pretrained(
    "cursivetransformer",
    cfg,
)

# Experiment: induction scores on repeated stroke tokens

In [None]:
# --- Experiment configuration -------------------------------------------------------
# TODO(review): the original author flagged `seq_len` with "DEBUG THIS"; the underlying
# problem is not recorded in this notebook.
seq_len = 50     # number of (θ, r) pairs in the sequence that gets repeated
n_repeats = 2    # how many times that sequence is repeated
batch_size = 1

# --- Inputs ---------------------------------------------------------------------------
# Repeated stroke tokens (built with the help of test_dataset) plus an ASCII text context.
rep_tokens = generate_repeated_stroke_tokens(model, test_dataset, seq_len, n_repeats, batch_size)
context_tokens = generate_random_ascii_context(model, batch_size)

# --- Forward pass ---------------------------------------------------------------------
# Run on `device`, keeping the logits, targets and activation cache for later cells.
model = model.to(device)
logits, targets, cache = run_and_cache_model_repeated_tokens(
    model,
    rep_tokens.to(device),
    context_tokens.to(device),
)

# --- Induction scores -----------------------------------------------------------------
# One score per (layer, head); plotted as an overview.
induction_scores = compute_induction_scores(rep_tokens, cache, model)
plot_induction_scores(induction_scores)

In [None]:
# Report the N heads with the highest induction scores and plot each one's
# self-attention pattern.
N = 5
num_heads = model.cfg.n_heads

# induction_scores is indexed [layer, head] elsewhere in this notebook; the flat-index
# decoding below is only valid if its last dimension equals the model's head count.
assert induction_scores.shape[-1] == num_heads, (
    f"expected {num_heads} heads per layer, got scores of shape {tuple(induction_scores.shape)}"
)

# flatten() (unlike view(-1)) also works when the score tensor is non-contiguous.
induction_scores_flat = induction_scores.flatten()
# torch.topk raises if k exceeds the number of elements, so never ask for more heads than exist.
k = min(N, induction_scores_flat.numel())
top_scores, top_indices = torch.topk(induction_scores_flat, k)

# Token length handed to the attention plot: seq_len (θ, r) pairs × n_repeats, 2 tokens per pair.
plot_len = seq_len * n_repeats * 2

print(f"Top {k} Induction Heads:")
for rank, (score, idx) in enumerate(zip(top_scores.tolist(), top_indices.tolist()), start=1):
    # Decode the flat index into plain Python ints (not 0-d tensors) before passing
    # them on as layer/head indices.
    layer, head = divmod(idx, num_heads)
    print(f"{rank}. Layer {layer}, Head {head}, Induction Score: {score:.4f}")
    plot_head_attention_pattern(cache, layer, head, plot_len, attn_type='self')

In [None]:
# Tabulate the induction score of every (layer, head), highest score first.
df_induction = create_induction_summary(induction_scores)
df_induction.sort_values('Score', ascending=False).pipe(display)

In [None]:
# Ablate every self-attention head whose induction score exceeds `threshold`, then
# compare the loss on induction positions with and without those heads.
threshold = 0.0001  # NOTE(review): very small — effectively every positively-scoring head is ablated
heads_to_ablate = [
    (layer_idx, head_idx, 'self')  # attention type: 'self' or 'cross'
    for layer_idx in range(model.cfg.n_layers)
    for head_idx in range(model.cfg.n_heads)
    if induction_scores[layer_idx, head_idx] > threshold
]
print(f"Ablating {len(heads_to_ablate)} head(s) with induction score > {threshold}")
if not heads_to_ablate:
    print("Warning: no heads passed the threshold; the ablated/original comparison below is not meaningful.")

# The model was moved to `device` in the experiment cell (and the inputs were moved there
# for the cached forward pass), so move the inputs here too to avoid a device mismatch.
ablated_logits = ablate_heads(model, heads_to_ablate, rep_tokens.to(device), context_tokens.to(device))

# Positions where the repeated sequence makes the next token predictable.
induction_positions = get_induction_positions(rep_tokens, seq_len, n_repeats)

original_loss = compute_loss_on_induction_positions(logits, targets, induction_positions)
ablated_loss = compute_loss_on_induction_positions(ablated_logits, targets, induction_positions)

print(f"Original Loss on Induction Positions: {original_loss.item():.4f}")
print(f"Ablated Loss on Induction Positions: {ablated_loss.item():.4f}")
print(f"Change due to ablation: {ablated_loss.item() - original_loss.item():+.4f}")