<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/cursivetransformer_mech_interp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Setup
# Install notebook dependencies via IPython shell magics (`!`); these lines only run in Jupyter/Colab.

!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
!pip install transformer_lens
!pip install gradio
!pip install wandb
!pip install einops
!pip install matplotlib
!pip install datasets

# Clone the cursivetransformer repository and install its requirements
# (re-cloned from scratch on every run). The clone lands relative to the working
# directory, while the next cell appends '/content/cursivetransformer' to sys.path —
# assumes cwd is /content (the Colab default); adjust when running elsewhere.
!rm -rf cursivetransformer && git clone https://github.com/zwimpee/cursivetransformer.git
!pip install -r cursivetransformer/requirements.txt

In [None]:
import os
import sys; sys.path.append('/content/cursivetransformer');
import copy
import types
from typing import List, Callable, Dict, Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio
import circuitsvis as cv

from IPython.display import display
from jaxtyping import Float, Int

from cursivetransformer.model import get_all_args, get_checkpoint, get_latest_checkpoint_artifact
from cursivetransformer.data import create_datasets, offsets_to_strokes, strokes_to_offsets
from cursivetransformer.sample import generate, generate_n_words, plot_strokes
from cursivetransformer.mech_interp import (
    HookedCursiveTransformer,
    HookedCursiveTransformerConfig,
    convert_cursivetransformer_model_config,
    visualize_attention
)

import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint
from transformer_lens import ActivationCache

# Analysis-only notebook: disable autograd globally so forward passes and caches
# do not build computation graphs.
torch.set_grad_enabled(False)

# Module-level device choice (GPU when available). Note: the helper functions
# below read model.cfg.device rather than this variable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Authenticate with Weights & Biases; presumably needed because the model
# config below points at a W&B run (load_from_run_id / wandb_entity) — confirm.
import wandb
wandb.login()

In [None]:
# Build the run configuration from the package defaults, then point it at an
# existing W&B run in sample-only mode (presumably: load trained weights, no training).
args = get_all_args(False)
args.sample_only = True
args.load_from_run_id = '6le6tujz'
args.wandb_entity = 'sam-greydanus'
args.dataset_name = 'bigbank'
args.wandb_run_name = 'cursivetransformer_dictionary_learning'

# Seed torch's CPU and all CUDA RNGs so the random token batches drawn later are reproducible.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

train_dataset, test_dataset = create_datasets(args)

# Copy sequence lengths and vocabulary sizes from the dataset into args; these
# are presumably consumed by convert_cursivetransformer_model_config(args) in the next cell.
args.block_size = train_dataset.get_stroke_seq_length()
args.context_block_size = train_dataset.get_text_seq_length()
args.vocab_size = train_dataset.get_vocab_size()
args.context_vocab_size = train_dataset.get_char_vocab_size()

In [None]:
# Translate the training args into a HookedCursiveTransformer config, then build
# the hooked model through from_pretrained (expected to load the trained checkpoint — confirm).
cfg = convert_cursivetransformer_model_config(args)
model = HookedCursiveTransformer.from_pretrained("cursivetransformer", cfg)

# Induction Circuit Investigation and Analysis





In [None]:
# Inspect the per-feature token bin counts (the random-token helpers assume the order [r bins, θ bins] — confirm).
test_dataset.feature_sizes

In [None]:
# Inspect the cumulative token-id boundaries that carve the stroke vocabulary into per-feature ranges.
test_dataset.cumulative_sizes

In [None]:
def generate_repeated_stroke_tokens(
    model,
    seq_len: int,
    n_repeats: int,
    batch_size: int = 1
) -> Int[torch.Tensor, "batch_size full_seq_len"]:
    """
    Build random interleaved (θ, r) stroke tokens and tile them ``n_repeats`` times.

    θ ids are drawn uniformly from ``[cumulative_sizes[1], cumulative_sizes[2])`` and
    r ids from ``[cumulative_sizes[0], cumulative_sizes[1])``. The tokens are laid out
    as θ0, r0, θ1, r1, ... and that block is repeated end to end.

    Args:
        model: Model exposing ``cfg.device``, ``feature_sizes`` and ``cumulative_sizes``.
        seq_len: Number of (θ, r) pairs in one copy of the sequence.
        n_repeats: How many copies to concatenate.
        batch_size: Number of independent sequences.

    Returns:
        rep_tokens: LongTensor of shape [batch_size, n_repeats * 2 * seq_len]
    """
    dev = model.cfg.device
    bins = model.feature_sizes           # bin counts; indexed [r, θ] below
    bounds = model.cumulative_sizes      # token-id boundaries between feature ranges

    r_vocab = torch.arange(bounds[0], bounds[1], device=dev)
    theta_vocab = torch.arange(bounds[1], bounds[2], device=dev)

    # θ is drawn before r: the RNG consumption order matters for seeded reproducibility.
    theta_picks = theta_vocab[
        torch.randint(low=0, high=bins[1], size=(batch_size, seq_len), device=dev)
    ]
    r_picks = r_vocab[
        torch.randint(low=0, high=bins[0], size=(batch_size, seq_len), device=dev)
    ]

    # Stack on a trailing axis and flatten it to interleave θ and r position by position.
    single_copy = (
        torch.stack((theta_picks, r_picks), dim=-1)
        .reshape(batch_size, 2 * seq_len)
        .to(torch.long)
    )

    return single_copy.repeat(1, n_repeats)

def generate_random_ascii_context(
    model,
    batch_size: int = 1
) -> Int[torch.Tensor, "batch_size context_seq_len"]:
    """
    Sample a batch of uniformly random character-token contexts.

    Args:
        model: Model whose ``cfg`` supplies ``device``, ``context_block_size`` and
            ``context_vocab_size``.
        batch_size: Number of independent contexts to draw.

    Returns:
        context_tokens: LongTensor of shape [batch_size, context_block_size] with ids
        in ``[0, context_vocab_size - 2]``.
    """
    cfg = model.cfg
    # torch.randint's upper bound is exclusive, so id context_vocab_size - 1 is never
    # drawn. The intent is to exclude the PAD token, which assumes PAD sits at that
    # last id — TODO confirm against the dataset's character vocabulary.
    exclusive_top = cfg.context_vocab_size - 1
    return torch.randint(
        low=0,
        high=exclusive_top,
        size=(batch_size, cfg.context_block_size),
        dtype=torch.long,
        device=cfg.device,
    )

In [None]:
def run_and_cache_model_repeated_tokens(
    model,
    rep_tokens: torch.Tensor,
    context_tokens: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor, "ActivationCache"]:
    """
    Runs the model on repeated stroke tokens and caches every activation.

    The last token is dropped from the model input so that ``targets`` lines up
    with the next-token prediction at each input position. Cached positions
    therefore correspond to ``rep_tokens[:, :-1]``.

    Args:
        model: The model instance (must provide ``run_with_cache``).
        rep_tokens: Input stroke tokens of shape [batch_size, seq_len]
        context_tokens: Input context tokens of shape [batch_size, context_seq_len]

    Returns:
        logits: Model output logits for ``rep_tokens[:, :-1]``.
        targets: Next-token targets, ``rep_tokens[:, 1:]``.
        cache: Activation cache from the same forward pass.
    """
    # Shift inputs to create next-token targets.
    inputs = rep_tokens[:, :-1]
    targets = rep_tokens[:, 1:]

    # Ask for plain logits. With return_type="both", a HookedTransformer-style forward
    # returns a (logits, loss) pair, which would end up in `logits` below instead of
    # the logits tensor (assumes HookedCursiveTransformer follows that convention).
    logits, cache = model.run_with_cache(
        tokens=inputs,
        context=context_tokens,
        return_type="logits",
    )

    return logits, targets, cache

In [None]:
def compute_induction_scores(
    model,
    rep_tokens: torch.Tensor,
    cache: "ActivationCache"
) -> torch.Tensor:
    """
    Computes induction scores for all attention heads on alternating (θ, r) stroke tokens.

    Consecutive tokens are grouped into (θ, r) pairs. Consider each pair ``i`` in the
    second half of the sequence, excluding the final pair. If its value already
    occurred at an earlier pair, let ``j`` be the most recent such pair. An induction
    head should attend from each token to the token that FOLLOWED its earlier copy:

    * θ query at position ``2*i``     -> key ``2*j + 1`` (the r after the earlier θ)
    * r query at position ``2*i + 1`` -> key ``2*j + 2`` (the θ after the earlier r)

    A head's score is the mean attention weight over all such (query, key)
    positions, pooled across every batch element. The largest query position
    used is ``seq_len - 3``, so a cache built on ``rep_tokens[:, :-1]`` covers it.

    Args:
        model: The model instance (reads ``cfg.n_layers``, ``cfg.n_heads``, ``cfg.device``).
        rep_tokens: Stroke tokens of shape [batch_size, seq_len]; seq_len must be even.
        cache: Activation cache; ``cache["pattern", layer]`` is indexed as
            [batch, head, query_pos, key_pos].

    Returns:
        induction_scores: Tensor of shape [num_layers, num_heads]. Scores stay 0
        when no repeated pair is found.

    Raises:
        ValueError: If ``seq_len`` is odd, because the tokens cannot be split into pairs.
    """
    num_layers = model.cfg.n_layers
    num_heads = model.cfg.n_heads
    induction_scores = torch.zeros(num_layers, num_heads, device=model.cfg.device)

    batch_size, seq_len = rep_tokens.shape
    if seq_len % 2 != 0:
        raise ValueError(
            f"rep_tokens must contain whole (θ, r) pairs; got odd seq_len={seq_len}"
        )
    num_pairs = seq_len // 2
    half_point = num_pairs // 2

    # The (batch, query, key) positions depend only on the tokens, so collect
    # them once (single O(num_pairs) pass) instead of rescanning per layer/head.
    batch_idx, query_idx, key_idx = [], [], []
    token_pairs = rep_tokens.reshape(batch_size, num_pairs, 2).tolist()
    for b, pairs in enumerate(token_pairs):
        last_seen = {}  # pair value -> most recent pair index seen before the current one
        for i, pair in enumerate(map(tuple, pairs)):
            j = last_seen.get(pair)
            if j is not None and half_point <= i < num_pairs - 1:
                batch_idx += [b, b]
                query_idx += [2 * i, 2 * i + 1]
                key_idx += [2 * j + 1, 2 * j + 2]
            last_seen[pair] = i

    if not query_idx:
        return induction_scores

    for layer in range(num_layers):
        attn_patterns = cache["pattern", layer]  # [batch, heads, seq_q, seq_k]
        dev = attn_patterns.device
        b_t = torch.tensor(batch_idx, dtype=torch.long, device=dev)
        q_t = torch.tensor(query_idx, dtype=torch.long, device=dev)
        k_t = torch.tensor(key_idx, dtype=torch.long, device=dev)
        # Advanced indices separated by a slice put the gathered axis first:
        # the result is [n_positions, heads].
        weights = attn_patterns[b_t, :, q_t, k_t]
        induction_scores[layer] = weights.float().mean(dim=0).to(induction_scores.device)
    return induction_scores

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

def plot_induction_scores(induction_scores: torch.Tensor):
    """
    Draw an annotated heatmap of per-head induction scores.

    Rows are layers (labelled L0, L1, ...) and columns are heads (H0, H1, ...).

    Args:
        induction_scores: Tensor of shape [num_layers, num_heads]; copied to CPU for plotting.
    """
    n_layers, n_heads = induction_scores.shape[0], induction_scores.shape[1]
    head_labels = [f"H{head_i}" for head_i in range(n_heads)]
    layer_labels = [f"L{layer_i}" for layer_i in range(n_layers)]
    grid = induction_scores.cpu().numpy()

    plt.figure(figsize=(12, 6))
    sns.heatmap(
        grid,
        annot=True,
        fmt=".2f",
        cmap="YlGnBu",
        xticklabels=head_labels,
        yticklabels=layer_labels,
    )
    plt.title("Induction Scores per Head")
    plt.xlabel("Heads")
    plt.ylabel("Layers")
    plt.show()

In [None]:
def plot_head_attention_pattern(
    cache: ActivationCache,
    layer: int,
    head: int,
    seq_len: int
):
    """
    Render one head's attention matrix (first batch element) as an image with a colorbar.

    Rows are query positions and columns are key positions.

    Args:
        cache: Activation cache indexed as ``cache["pattern", layer]``.
        layer: Layer index.
        head: Head index.
        seq_len: Currently unused; kept so existing call sites keep working.
    """
    head_pattern = cache["pattern", layer][0, head]
    matrix = head_pattern.detach().cpu().numpy()

    plt.figure(figsize=(8, 6))
    plt.imshow(matrix, cmap='viridis', aspect='auto')
    plt.colorbar()
    plt.title(f"Attention Pattern - Layer {layer}, Head {head}")
    plt.xlabel("Key Positions")
    plt.ylabel("Query Positions")
    plt.show()

In [None]:
def compute_cross_attention_induction_scores(
    model,
    context_tokens: torch.Tensor,
    cache: "ActivationCache"
) -> torch.Tensor:
    """
    Computes a placeholder per-head score for the cross-attention heads.

    The score is currently each head's mean cross-attention weight, pooled over
    batch, query and key positions.

    NOTE(review): if every query row of the pattern is a softmax over the context
    with no masking, each row averages to 1 / context_seq_len, so this placeholder
    cannot distinguish heads. TODO: replace it with a real cross-attention
    induction metric.

    Args:
        model: The model instance (reads ``cfg.n_layers``, ``cfg.n_heads``, ``cfg.device``).
        context_tokens: Context tokens of shape [batch_size, context_seq_len];
            not used by the placeholder metric.
        cache: Activation cache containing ``blocks.{layer}.cross_attn.hook_pattern``
            entries indexed as [batch, head, stroke_pos, context_pos].

    Returns:
        cross_induction_scores: Tensor of shape [num_layers, num_heads]
    """
    num_layers = model.cfg.n_layers
    num_heads = model.cfg.n_heads
    cross_induction_scores = torch.zeros(num_layers, num_heads, device=model.cfg.device)

    for layer in range(num_layers):
        # Look the hook up by its full name. The tuple form
        # cache["pattern", layer, "cross_attn"] is resolved by
        # transformer_lens.utils.get_act_name, which forces layer_type="attn"
        # for "pattern" and therefore returns the SELF-attention pattern.
        attn_patterns = cache[f"blocks.{layer}.cross_attn.hook_pattern"]
        for head in range(num_heads):
            cross_induction_scores[layer, head] = attn_patterns[:, head].float().mean().item()
    return cross_induction_scores

In [None]:
# Store induction scores
# induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

# def induction_score_hook(pattern: torch.Tensor, hook: HookPoint):
#     induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - stroke_sequence.shape[0])
#     induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
#     induction_score_store[hook.layer(), :] = induction_score

# # Updated hook filter for both self-attention and cross-attention patterns
# pattern_hook_names_filter = lambda name: name.endswith("attn.hook_pattern") or name.endswith("cross_attn.hook_pattern")

# # Run with hooks to collect induction scores
# _ = model.run_with_hooks(
#     repeated_tokens,
#     repeated_context,
#     fwd_hooks=[(
#         pattern_hook_names_filter,
#         induction_score_hook
#     )]
# )

# # Visualize induction score by head
# px.imshow(induction_score_store.detach().cpu(), labels={"x": "Head", "y": "Layer"}, title="Induction Score by Head").show()

# # Visualization of Attention Pattern
# def visualize_pattern_hook(pattern: torch.Tensor, hook: HookPoint):
#     display(
#         cv.attention.attention_patterns(
#             tokens=repeated_tokens,
#             attention=pattern[0, induction_head_index, :, :][None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.
#         )
#     )

# induction_head_layer = 3
# induction_head_index = 1
# model.run_with_hooks(
#     repeated_tokens,
#     repeated_context,
#     fwd_hooks=[(
#         'blocks.{}.cross_attn.hook_pattern'.format(induction_head_layer),
#         visualize_pattern_hook
#     )]
# )

# # Activation Patching

# # Activation Patching - Updated for Cross-Attention
# def activation_patching(
#     model: HookedCursiveTransformer,
#     x_clean: torch.Tensor,
#     c_clean: torch.Tensor,
#     x_corrupted: torch.Tensor,
#     c_corrupted: torch.Tensor,
#     patching_nodes: List[str],
#     patch_positions: Optional[torch.Tensor] = None,
# ):
#     _, cache_corrupted = model.run_with_cache(x_corrupted, c_corrupted, return_type="both")

#     def patching_hook(act, hook):
#         act_corrupted = cache_corrupted[hook.name]
#         if patch_positions is not None:
#             act[:, patch_positions, :] = act_corrupted[:, patch_positions, :]
#         else:
#             act[:] = act_corrupted
#         return act

#     hooks = [(node, patching_hook) for node in patching_nodes]
#     logits_patched = model.run_with_hooks(x_clean, c_clean, fwd_hooks=hooks, return_type="logits")
#     return logits_patched

# # Example Activation Patching
# x_clean = stroke_sequence.unsqueeze(0)
# c_clean = ascii_sequence.unsqueeze(0)

# x_corrupted = x_clean.clone()
# x_corrupted = (x_corrupted + 1) % model.cfg.d_vocab
# c_corrupted = c_clean.clone()

# patching_nodes = ['blocks.0.cross_attn.hook_result']
# logits_patched = activation_patching(
#     model,
#     x_clean,
#     c_clean,
#     x_corrupted,
#     c_corrupted,
#     patching_nodes,
# )

# # Get predictions from patched logits
# predictions_patched = logits_patched.argmax(dim=-1)

# # Run the clean input without patches
# logits_clean = model(x_clean, c_clean)
# predictions_clean = logits_clean.argmax(dim=-1)

# # Run the corrupted input without patches
# logits_corrupted = model(x_corrupted, c_corrupted)
# predictions_corrupted = logits_corrupted.argmax(dim=-1)

# # Compare predictions
# print("Clean Predictions:", predictions_clean)
# print("Corrupted Predictions:", predictions_corrupted)
# print("Patched Predictions:", predictions_patched)

# # Visualize the strokes if applicable
# _ = plot_strokes(offsets_to_strokes(test_dataset.decode_stroke(stroke_sequence)), test_dataset.decode_text(ascii_sequence))
# _ = plot_strokes(offsets_to_strokes(test_dataset.decode_stroke(x_corrupted[0])), test_dataset.decode_text(c_corrupted[0]))