<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/cursivetransformer_mech_interp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Setup
#
# Environment bootstrap for a fresh notebook runtime. Lines starting with `!`
# are IPython/Colab shell escapes, so this cell is not valid plain Python and
# only runs inside a notebook kernel.

# Third-party packages used by the analysis. CircuitsVis is installed from the
# `python` subdirectory of its GitHub repository; the rest come from PyPI.
!pip install git+https://github.com/callummcdougall/CircuitsVis.git#subdirectory=python
!pip install transformer_lens
!pip install gradio
!pip install wandb
!pip install einops
!pip install matplotlib
!pip install datasets

# Clone the cursivetransformer repository and install its requirements.
# `rm -rf` first discards any earlier checkout so the clone does not fail on an
# existing directory. The clone lands in ./cursivetransformer relative to the
# current working directory.
!rm -rf cursivetransformer && git clone https://github.com/zwimpee/cursivetransformer.git
!pip install -r cursivetransformer/requirements.txt

import os
import sys; sys.path.append('/content/cursivetransformer');
import copy
import types
from typing import List, Callable, Dict, Optional, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import einops
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.express as px
import plotly.io as pio
import circuitsvis as cv

from IPython.display import display
from jaxtyping import Float, Int

from cursivetransformer.model import get_all_args, get_checkpoint, get_latest_checkpoint_artifact
from cursivetransformer.data import create_datasets, offsets_to_strokes, strokes_to_offsets
from cursivetransformer.sample import generate, generate_n_words, plot_strokes
from cursivetransformer.mech_interp import (
    HookedCursiveTransformer,
    HookedCursiveTransformerConfig,
    convert_cursivetransformer_model_config,
    visualize_attention
)

import transformer_lens.utils as utils
from transformer_lens.hook_points import HookPoint

# Turn autograd off globally; nothing in the code visible here computes gradients.
torch.set_grad_enabled(False)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): `device` is not referenced again in the visible part of this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Log in to Weights & Biases.
# NOTE(review): presumably required to fetch the checkpoint for the run id that
# is configured on `args` further down — confirm.
import wandb
wandb.login()

# Obtain the project's argument object, then override the fields this notebook needs.
# NOTE(review): the meaning of the positional `False` is defined by
# cursivetransformer.model.get_all_args and cannot be read off from this file.
args = get_all_args(False)
args.sample_only = True
args.load_from_run_id = '6le6tujz'
args.wandb_entity = 'sam-greydanus'
args.dataset_name = 'bigbank'
args.wandb_run_name = 'cursivetransformer_dictionary_learning'

# Seed the default (CPU) generator and every CUDA generator from the configured
# seed so repeated runs of the notebook are reproducible.
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

# Build the train and test datasets described by `args`.
train_dataset, test_dataset = create_datasets(args)

# Copy the sequence lengths and vocabulary sizes reported by the training
# dataset onto `args`. This must happen before the config conversion below.
# NOTE(review): presumably the converter reads exactly these fields — confirm.
args.block_size = train_dataset.get_stroke_seq_length()
args.context_block_size = train_dataset.get_text_seq_length()
args.vocab_size = train_dataset.get_vocab_size()
args.context_vocab_size = train_dataset.get_char_vocab_size()

# Convert the args into a hooked-model config and construct the model from it.
# NOTE(review): how `from_pretrained` locates the weights (e.g. via the run id
# set on `args` earlier) is not visible in this file — confirm.
cfg = convert_cursivetransformer_model_config(args)
model = HookedCursiveTransformer.from_pretrained("cursivetransformer", cfg)

# Induction Circuit Investigation and Analysis

## 1. Get an example from the dataset and create repeated pattern

def create_repeated_sequence(stroke_sequence, pad_token, repeat_count=2):
    """Repeat a 1-D token sequence, separating consecutive copies with a pad token.

    The result has the layout ``seq, pad, seq, pad, ..., seq`` (``repeat_count``
    copies of ``stroke_sequence`` and ``repeat_count - 1`` pad tokens) with a
    leading batch dimension of size 1 added.

    Args:
        stroke_sequence: 1-D tensor of tokens to repeat.
        pad_token: Scalar token value inserted between consecutive copies.
        repeat_count: Number of copies of ``stroke_sequence`` in the output.

    Returns:
        Tensor of shape ``[1, repeat_count * len(stroke_sequence) + repeat_count - 1]``
        with the same dtype and device as ``stroke_sequence``.

    Raises:
        Whatever ``torch.cat`` raises for an empty list when ``repeat_count < 1``,
        because there is then nothing to concatenate.
    """
    # Build the separator with the sequence's own dtype and device. A bare
    # torch.tensor([pad_token]) is a default-dtype CPU tensor, which makes
    # torch.cat promote e.g. an int32 sequence to int64 and fail outright when
    # the sequence lives on another device.
    separator = torch.tensor(
        [pad_token], dtype=stroke_sequence.dtype, device=stroke_sequence.device
    )

    pieces = []
    for copy_index in range(repeat_count):
        if copy_index:
            # Every copy after the first is preceded by one pad token.
            pieces.append(separator)
        pieces.append(stroke_sequence)

    return torch.cat(pieces).unsqueeze(0)

# NOTE(review): `batch_size` is not used anywhere in the visible part of this file.
batch_size = 10
# Index of the test-set example to analyse.
index = 0
# Each dataset item unpacks into three tensors: stroke tokens, text (ascii)
# tokens, and a third tensor `y`.
# NOTE(review): `y` is presumably the training target sequence — confirm.
stroke_tensor, ascii_tensor, y = test_dataset[index]
# Add a leading batch dimension of size 1 to each tensor. The inline shape
# notes come from the original notebook; the actual lengths depend on the
# dataset configuration — TODO confirm.
stroke_tensor = stroke_tensor.unsqueeze(0) # Shape: [1, 1000]
ascii_tensor = ascii_tensor.unsqueeze(0) # Shape: [1, 50]
y = y.unsqueeze(0) # Shape: [1, 1000]

### Create repeated sequence of stroke and ascii tokens (improved approach for reusability)

# Drop the batch dimension again and keep only a short prefix of the example:
# the first 50 stroke tokens and the first 2 text tokens.
stroke_sequence = stroke_tensor[0][:50]
ascii_sequence = ascii_tensor[0][:2]

# Stroke side: with the default repeat_count=2 this is `sequence, PAD, sequence`
# with a leading batch dimension added by create_repeated_sequence.
repeated_stroke_sequence = create_repeated_sequence(stroke_sequence, test_dataset.PAD_TOKEN)
# Text side: the pattern tiles the whole sequence twice (a, b, a, b).
# NOTE(review): unlike the stroke side, no separator token is inserted and the
# result stays 1-D (no batch dimension) — confirm this asymmetry is intended.
repeated_ascii_sequence = einops.repeat(ascii_sequence, "seq_len -> (2 seq_len)")

# - [x] TODO: FIX THIS!
# Sanity-check plots: decode the tokens back to pen offsets, convert the offsets
# to strokes, and draw them with the decoded text passed as the second argument.

# Original (non-repeated) prefix.
_ = plot_strokes(
    offsets_to_strokes(test_dataset.decode_stroke(stroke_sequence)),
    test_dataset.decode_text(ascii_sequence),
)

# Repeated version.
# NOTE(review): `repeated_stroke_sequence` carries a leading batch dimension
# while `stroke_sequence` above is 1-D — confirm `decode_stroke` accepts both.
_ = plot_strokes(
    offsets_to_strokes(test_dataset.decode_stroke(repeated_stroke_sequence)),
    test_dataset.decode_text(repeated_ascii_sequence),
)

## 2. Run the model over the input token sequences
# Forward pass that also records intermediate activations in `cache`.
# `outputs` is indexed as `outputs[1]` for the loss further down, so with
# return_type="both" it presumably holds (logits, loss) — confirm.
# NOTE(review): `repeated_stroke_sequence` has a batch dimension but
# `repeated_ascii_sequence` does not — confirm the model accepts an unbatched
# context tensor.
# The no_grad() context is redundant with the global torch.set_grad_enabled(False)
# earlier in the file, but harmless.
with torch.no_grad():
    outputs, cache = model.run_with_cache(
        repeated_stroke_sequence, repeated_ascii_sequence,
        return_type="both", per_token_loss=True
    )

## 3. Compute and plot per-token loss by position

# Second element of the model output, moved to the CPU as a NumPy array.
# Per the original note its shape is [batch_size, sequence_length].
loss_by_position = outputs[1].detach().cpu().numpy()

# Line plot of the loss for the first (only) batch element, one point per position.
px.line(
    y=loss_by_position[0],
    labels={"x": "Position", "y": "Loss"},
    title="Loss by position on repeated token sequence",
).show()

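# NOTE: the code below (induction scores, attention visualization, activation
# patching) is commented out. It refers to `repeated_tokens` and
# `repeated_context`, which are not defined in the code above; the live code
# defines `repeated_stroke_sequence` and `repeated_ascii_sequence` instead, so
# those names need updating before any of it is re-enabled.
# # Store induction scores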
# induction_score_store = torch.zeros((model.cfg.n_layers, model.cfg.n_heads), device=model.cfg.device)

# def induction_score_hook(pattern: torch.Tensor, hook: HookPoint):
#     induction_stripe = pattern.diagonal(dim1=-2, dim2=-1, offset=1 - stroke_sequence.shape[0])
#     induction_score = einops.reduce(induction_stripe, "batch head_index position -> head_index", "mean")
#     induction_score_store[hook.layer(), :] = induction_score

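# # Updated hook filter for both self-attention and cross-attention patterns.
# # NOTE: any name ending in "cross_attn.hook_pattern" also ends in
# # "attn.hook_pattern", so the second clause is redundant. Because the hook
# # writes to induction_score_store[hook.layer(), :], self-attention and
# # cross-attention scores for the same layer share one row and the later
# # hook overwrites the earlier one.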
# pattern_hook_names_filter = lambda name: name.endswith("attn.hook_pattern") or name.endswith("cross_attn.hook_pattern")

# # Run with hooks to collect induction scores
# _ = model.run_with_hooks(
#     repeated_tokens,
#     repeated_context,
#     fwd_hooks=[(
#         pattern_hook_names_filter,
#         induction_score_hook
#     )]
# )

# # Visualize induction score by head
# px.imshow(induction_score_store.detach().cpu(), labels={"x": "Head", "y": "Layer"}, title="Induction Score by Head").show()

# # Visualization of Attention Pattern
# def visualize_pattern_hook(pattern: torch.Tensor, hook: HookPoint):
#     display(
#         cv.attention.attention_patterns(
#             tokens=repeated_tokens,
#             attention=pattern[0, induction_head_index, :, :][None, :, :] # Add a dummy axis, as CircuitsVis expects 3D patterns.
#         )
#     )

# induction_head_layer = 3
# induction_head_index = 1
# model.run_with_hooks(
#     repeated_tokens,
#     repeated_context,
#     fwd_hooks=[(
#         'blocks.{}.cross_attn.hook_pattern'.format(induction_head_layer),
#         visualize_pattern_hook
#     )]
# )

# # Activation Patching

# # Activation Patching - Updated for Cross-Attention
# def activation_patching(
#     model: HookedCursiveTransformer,
#     x_clean: torch.Tensor,
#     c_clean: torch.Tensor,
#     x_corrupted: torch.Tensor,
#     c_corrupted: torch.Tensor,
#     patching_nodes: List[str],
#     patch_positions: Optional[torch.Tensor] = None,
# ):
#     _, cache_corrupted = model.run_with_cache(x_corrupted, c_corrupted, return_type="both")

#     def patching_hook(act, hook):
#         act_corrupted = cache_corrupted[hook.name]
#         if patch_positions is not None:
#             act[:, patch_positions, :] = act_corrupted[:, patch_positions, :]
#         else:
#             act[:] = act_corrupted
#         return act

#     hooks = [(node, patching_hook) for node in patching_nodes]
#     logits_patched = model.run_with_hooks(x_clean, c_clean, fwd_hooks=hooks, return_type="logits")
#     return logits_patched

# # Example Activation Patching
# x_clean = stroke_sequence.unsqueeze(0)
# c_clean = ascii_sequence.unsqueeze(0)

# x_corrupted = x_clean.clone()
# x_corrupted = (x_corrupted + 1) % model.cfg.d_vocab
# c_corrupted = c_clean.clone()

# patching_nodes = ['blocks.0.cross_attn.hook_result']
# logits_patched = activation_patching(
#     model,
#     x_clean,
#     c_clean,
#     x_corrupted,
#     c_corrupted,
#     patching_nodes,
# )

# # Get predictions from patched logits
# predictions_patched = logits_patched.argmax(dim=-1)

# # Run the clean input without patches
# logits_clean = model(x_clean, c_clean)
# predictions_clean = logits_clean.argmax(dim=-1)

# # Run the corrupted input without patches
# logits_corrupted = model(x_corrupted, c_corrupted)
# predictions_corrupted = logits_corrupted.argmax(dim=-1)

# # Compare predictions
# print("Clean Predictions:", predictions_clean)
# print("Corrupted Predictions:", predictions_corrupted)
# print("Patched Predictions:", predictions_patched)

# # Visualize the strokes if applicable
# _ = plot_strokes(offsets_to_strokes(test_dataset.decode_stroke(stroke_sequence)), test_dataset.decode_text(ascii_sequence))
# _ = plot_strokes(offsets_to_strokes(test_dataset.decode_stroke(x_corrupted[0])), test_dataset.decode_text(c_corrupted[0]))