<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/induction_heads.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Detecting Induction Heads in `cursivetransformer`

In this notebook, we aim to detect **induction heads** in the `cursivetransformer` model. We'll adapt methods from Neel Nanda's work on reverse-engineering induction circuits to our model, which is trained to generate cursive handwriting. Our goal is to identify components within the model that perform induction-like mechanisms, specifically focusing on how the model might use previous tokens to inform the generation of subsequent ones.

---

## **Table of Contents**

1. [Introduction to Induction Heads](#introduction)
2. [Setting Up the Environment](#setup)
3. [Loading the Model and Data](#loading)
4. [Analyzing Attention Patterns](#attention)
5. [Identifying Candidate Induction Heads](#candidate_heads)
6. [Reverse-Engineering the Induction Circuit](#reverse_engineering)
    - [QK and OV Circuits](#qk_ov)
    - [Analyzing the QK Circuit](#qk_analysis)
    - [Analyzing the OV Circuit](#ov_analysis)
7. [Computing Composition Scores](#composition_scores)
8. [Visualizing and Interpreting Results](#visualization)
9. [Conclusion](#conclusion)

---

<a name="introduction"></a>
## 1. Introduction to Induction Heads

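**Induction heads** are attention heads that implement a simple pattern-completion rule: if the current token is `A` and `A` appeared earlier followed by `B`, the head attends to that earlier `B` and boosts the prediction of `B` (`[A][B] ... [A] -> [B]`). They are believed to underlie a large part of a transformer's **in-context learning**, because they let the model continue patterns that already appear earlier in the sequence.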
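The rule an induction head implements can be written as plain Python over a list of tokens: find the most recent earlier occurrence of the current token and predict the token that followed it. This is only a sketch of the *algorithm*, not code from `cursivetransformer`; the function name `induction_predict` is hypothetical.

```python
def induction_predict(tokens):
    """Predict the next token with the induction rule [A][B] ... [A] -> [B].

    Scans backwards for the most recent earlier occurrence of the current
    (last) token and returns the token that followed it, or None if the
    current token has not appeared before.
    """
    current = tokens[-1]
    # Search earlier positions, most recent first (excluding the last position itself)
    for pos in range(len(tokens) - 2, -1, -1):
        if tokens[pos] == current:
            return tokens[pos + 1]
    return None


print(induction_predict(["A", "B", "C", "D", "A"]))  # B
print(induction_predict(["A", "B", "C"]))            # None
```

A real induction head does this softly: attention weights replace the exact match, and the "copy" is a vector written to the residual stream rather than a returned token.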

In the context of the `cursivetransformer`, which generates cursive handwriting, induction heads may help the model maintain consistent handwriting styles or replicate patterns in stroke sequences.

**Our Objectives:**

- Identify attention heads in the `cursivetransformer` that function as induction heads.
- Reverse-engineer the circuits involved to understand how they operate.
- Analyze the QK (query-key) and OV (output-value) circuits to see how they contribute to induction behavior.

---

<a name="setup"></a>
## 2. Setting Up the Environment

First, let's import necessary libraries and set up the environment.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
import tqdm

# If using Jupyter, uncomment the next line
# %matplotlib inline

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Check for GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
```

---

<a name="loading"></a>
## 3. Loading the Model and Data

We'll load the `cursivetransformer` model and the associated dataset.

**Note:** Ensure that you have the model and dataset available in your environment. Adjust paths as necessary.

```python
# Assume the model and dataset classes are defined in cursivetransformer.py
# You may need to adjust the import statements based on your project structure

from cursivetransformer import CursiveTransformer, CursiveDataset

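# Load the model. `load_from_checkpoint` is assumed here (PyTorch Lightning style);
# use whatever loading routine matches your checkpoint format
model = CursiveTransformer.load_from_checkpoint('path_to_model_checkpoint.ckpt')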
model.to(device)
model.eval()  # Set model to evaluation mode

# Load the dataset
train_dataset = CursiveDataset(split='train')
test_dataset = CursiveDataset(split='test')

# Define data loaders if necessary
from torch.utils.data import DataLoader

batch_size = 64  # Adjust as needed
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```

---

<a name="attention"></a>
## 4. Analyzing Attention Patterns

To detect induction heads, we'll start by analyzing the attention patterns of each head in the model. Induction heads typically display a **distinctive attention pattern**: from the current token, they attend to the token that *followed* an earlier occurrence of the same token (or a closely related one), rather than to the earlier occurrence itself.
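A standard test input for this, used in Nanda's induction-head work, is a block of random tokens followed by an exact repeat: random tokens give other heads little structure to exploit, so strong attention along the repeat stripe points to induction. The helper below is a sketch that assumes the model consumes integer token ids in `[0, vocab_size)`; `cursivetransformer` also conditions on context text, so feeding such sequences in may need model-specific adaptation. `make_repeated_tokens` is a hypothetical name.

```python
import torch

def make_repeated_tokens(vocab_size, seq_len, batch_size, generator=None):
    """Return a [batch_size, 2 * seq_len] tensor whose second half repeats the first.

    On such inputs an induction head at position i (in the second half) should
    attend to position i - seq_len + 1: the token right after the earlier copy
    of the current token.
    """
    first_half = torch.randint(0, vocab_size, (batch_size, seq_len), generator=generator)
    return torch.cat([first_half, first_half], dim=1)

g = torch.Generator().manual_seed(0)
tokens = make_repeated_tokens(vocab_size=100, seq_len=8, batch_size=2, generator=g)
print(tokens.shape)  # torch.Size([2, 16])
```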

### Extracting Attention Patterns

We'll define a function to extract and visualize attention patterns for a given layer and head.

```python
def get_attention_patterns(model, data_loader, layer_idx, head_idx, num_batches=1):
    """Collect post-softmax attention patterns for one head.

    Assumes each layer exposes `self_attn.attn_drop`, a dropout module applied to the
    softmaxed attention weights; in eval mode it is the identity, so its output is the pattern.
    """
    attention_patterns = []

    def hook_fn(module, input, output):
        # Output shape: [batch_size, num_heads, seq_len, seq_len]
        attention_patterns.append(output.detach().cpu())

    hook_handle = model.transformer.layers[layer_idx].self_attn.attn_drop.register_forward_hook(hook_fn)
    try:
        # Run a few batches to collect attention patterns
        for i, batch in enumerate(data_loader):
            if i >= num_batches:
                break
            inputs, _ = batch  # Adjust unpacking based on your dataset
            inputs = inputs.to(device)
            with torch.no_grad():
                model(inputs)
    finally:
        hook_handle.remove()  # always detach the hook, even if a forward pass fails

    # Concatenate attention patterns from all batches (requires equal seq_len across batches)
    attention_patterns = torch.cat(attention_patterns, dim=0)  # Shape: [total_samples, num_heads, seq_len, seq_len]

    # Extract patterns for the specified head
    return attention_patterns[:, head_idx, :, :]  # Shape: [total_samples, seq_len, seq_len]
```
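The hook above relies on `attn_drop` receiving the post-softmax attention pattern. The toy module below is a hypothetical, minGPT-style stand-in (not the real `cursivetransformer` attention) that lets you check that assumption in isolation: in `eval()` mode the dropout is the identity, so the hook sees rows that sum to 1 and are zero above the diagonal.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToySelfAttention(nn.Module):
    """Hypothetical causal self-attention whose softmaxed weights pass through `attn_drop`."""

    def __init__(self, d_model=16, num_heads=4, dropout=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.qkv = nn.Linear(d_model, 3 * d_model)
        self.attn_drop = nn.Dropout(dropout)

    def forward(self, x):
        B, T, D = x.shape
        q, k, v = self.qkv(x).split(D, dim=-1)
        # [B, T, D] -> [B, num_heads, T, head_dim]
        q, k, v = (t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2) for t in (q, k, v))
        scores = q @ k.transpose(-2, -1) / self.head_dim ** 0.5
        causal_mask = torch.triu(torch.ones(T, T, dtype=torch.bool, device=x.device), diagonal=1)
        scores = scores.masked_fill(causal_mask, float("-inf"))
        pattern = self.attn_drop(F.softmax(scores, dim=-1))  # [B, num_heads, T, T]
        return (pattern @ v).transpose(1, 2).reshape(B, T, D)

attn = ToySelfAttention().eval()  # eval(): dropout is the identity, so the hook sees the raw pattern
captured = []
handle = attn.attn_drop.register_forward_hook(lambda module, inp, out: captured.append(out.detach()))
try:
    with torch.no_grad():
        attn(torch.randn(2, 5, 16))
finally:
    handle.remove()

pattern = captured[0]
print(pattern.shape)  # torch.Size([2, 4, 5, 5])
```

If the real model applies dropout before the softmax, or fuses attention into a single kernel (e.g. `F.scaled_dot_product_attention`), there is no module output carrying the pattern and a different hook point is needed.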

### Visualizing Attention Patterns

We'll plot the average attention pattern for the specified head.

```python
def plot_attention_pattern(patterns, layer_idx, head_idx):
    # Compute the average attention pattern over samples
    avg_pattern = patterns.mean(dim=0)  # Shape: [seq_len, seq_len]

    fig = px.imshow(
        avg_pattern,
        labels={'x': 'Key Position', 'y': 'Query Position', 'color': 'Attention Weight'},
        title=f'Average Attention Pattern for Layer {layer_idx}, Head {head_idx}',
        color_continuous_scale='Blues'
    )
    fig.update_layout(width=600, height=600)
    fig.show()
```

### Example Usage

Let's analyze and plot attention patterns for all heads in a specific layer.

```python
layer_idx = 0  # Adjust as needed
num_heads = model.transformer.layers[layer_idx].self_attn.num_heads

for head_idx in range(num_heads):
    patterns = get_attention_patterns(model, test_loader, layer_idx, head_idx, num_batches=1)
    plot_attention_pattern(patterns, layer_idx, head_idx)
```

**Interpretation:**

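- A stripe directly below the main diagonal (query position \( i \) attending to key position \( i-1 \)) marks a **previous-token head**, a common building block of induction circuits.
- An induction head's stripe depends on content: it sits at the distance back to the previous occurrence of the current token. On a sequence whose second half repeats a first half of length \( T \), it shows up as a stripe \( T-1 \) positions below the diagonal in the second half.
- Averaging over unrelated sequences tends to wash that stripe out, so heads should also be checked on repeated sequences before being ruled out.
- Heads that show either pattern are candidates for an induction circuit.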
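Both stripes can be turned into numbers by averaging the attention weights along the relevant diagonal, in the spirit of Nanda's induction scores. The sketch below assumes patterns laid out as `[..., query, key]` (as returned by `get_attention_patterns`) and, for the induction score, an input of length `2 * T` with no extra start token; if the model prepends one, both offsets shift. The function names are hypothetical.

```python
import torch

def prev_token_score(patterns):
    """Mean attention from each query i to key i - 1. patterns: [..., n, n]."""
    return torch.diagonal(patterns, offset=-1, dim1=-2, dim2=-1).mean(dim=-1)

def induction_score(patterns, T):
    """Mean attention from query i to key i - T + 1 over the repeated second half.

    Assumes a length-T block followed by an exact repeat (total length 2T).
    """
    stripe = torch.diagonal(patterns, offset=-(T - 1), dim1=-2, dim2=-1)  # queries T-1 .. 2T-1
    return stripe[..., 1:].mean(dim=-1)  # keep only queries T .. 2T-1 (the second half)

T = 4
n = 2 * T
# Ideal induction head: each second-half query puts all its weight on key i - T + 1
ideal_induction = torch.zeros(n, n)
ideal_induction[:T, 0] = 1.0  # first half has nothing to copy; park the weight on position 0
for i in range(T, n):
    ideal_induction[i, i - T + 1] = 1.0

# Ideal previous-token head: query i attends to key i - 1 (query 0 attends to itself)
ideal_prev = torch.zeros(n, n)
ideal_prev[0, 0] = 1.0
for i in range(1, n):
    ideal_prev[i, i - 1] = 1.0

print(induction_score(ideal_induction, T).item(), induction_score(ideal_prev, T).item())
print(prev_token_score(ideal_prev).item())
```

On real data, `induction_score(patterns, T).mean()` gives one number per head, which is easier to compare across heads than inspecting plots.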

---

<a name="candidate_heads"></a>
## 5. Identifying Candidate Induction Heads

Based on the attention patterns, we can identify candidate heads that might be functioning as induction heads.

**Example:**

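Suppose we observe that head 3 in layer 0 shows a stripe just below the diagonal. Because a first-layer head has no earlier head to compose with, its keys cannot carry information about the preceding token, so it cannot implement induction by itself. Such a head is more plausibly a *previous-token head*, and the induction-head candidates are later-layer heads that compose with it (for example, head 5 in layer 1, examined in Section 7).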

---

<a name="reverse_engineering"></a>
## 6. Reverse-Engineering the Induction Circuit

To test whether a candidate head is part of an induction circuit, we'll analyze the **QK** and **OV** circuits of the heads involved. This means examining the weight matrices directly to see which positions a head attends to and what it writes once it does.

<a name="qk_ov"></a>
### Understanding QK and OV Circuits

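- **QK Circuit:** the matrix \( W_Q W_K^T \), a bilinear form that sets the pre-softmax attention score between a query position and a key position, i.e. *where* the head attends.
- **OV Circuit:** the matrix \( W_V W_O \), which determines *what* the head reads from the attended position and writes back into the residual stream.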
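The claim that one matrix captures the QK circuit can be checked numerically. The sketch below uses random weights with illustrative shapes (not the model's), writes residual vectors as rows as this notebook's code does, and assumes standard scaled dot-product scores.

```python
import torch

torch.manual_seed(0)
d_model, d_head, seq_len = 16, 4, 6
W_Q = torch.randn(d_model, d_head)
W_K = torch.randn(d_model, d_head)
X = torch.randn(seq_len, d_model)  # residual stream, one row per position

# Usual computation: project to queries and keys, then take scaled dot products
scores_qk = (X @ W_Q) @ (X @ W_K).T / d_head ** 0.5  # [seq_len, seq_len]

# Same scores from the single d_model x d_model QK circuit
W_QK = W_Q @ W_K.T
scores_circuit = X @ W_QK @ X.T / d_head ** 0.5

print(torch.allclose(scores_qk, scores_circuit, rtol=1e-4, atol=1e-4))
```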

---

<a name="qk_analysis"></a>
### Analyzing the QK Circuit

#### Extracting Q and K Matrices

```python
def get_QK_matrices(model, layer_idx, head_idx):
    # Attribute names (q_proj, k_proj, head_dim) are assumptions about the architecture.
    attn = model.transformer.layers[layer_idx].self_attn
    d_head = attn.head_dim
    head_slice = slice(head_idx * d_head, (head_idx + 1) * d_head)

    # nn.Linear stores weight as [out_features, in_features] = [num_heads * d_head, d_model],
    # so transpose to [d_model, num_heads * d_head] before slicing out one head's columns.
    W_Q = attn.q_proj.weight.T[:, head_slice]  # Shape: [d_model, d_head]
    W_K = attn.k_proj.weight.T[:, head_slice]  # Shape: [d_model, d_head]

    return W_Q, W_K
```
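Because `nn.Linear` computes `x @ weight.T` and stores `weight` as `[out_features, in_features]`, a single head's query matrix is a column slice of `weight.T`, not of `weight`. The toy check below uses a stand-in `nn.Linear` (not the model's actual `q_proj`) and assumes heads are laid out contiguously along the output dimension, as they are when outputs are reshaped to `[..., num_heads, d_head]`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, num_heads, d_head = 12, 4, 5
q_proj = nn.Linear(d_model, num_heads * d_head, bias=False)  # hypothetical stand-in for self_attn.q_proj
x = torch.randn(3, d_model)

head_idx = 2
head_slice = slice(head_idx * d_head, (head_idx + 1) * d_head)

W_Q_head = q_proj.weight.T[:, head_slice]  # [d_model, d_head]
q_from_slice = x @ W_Q_head                # one head's queries via the sliced matrix
q_from_module = q_proj(x)[:, head_slice]   # the same queries via the module's forward pass

print(q_proj.weight.shape)  # torch.Size([20, 12])
print(torch.allclose(q_from_slice, q_from_module, atol=1e-5))
```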

#### Computing the QK Circuit

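The QK circuit is given by \( W_Q W_K^T \). With residual-stream vectors written as rows, the pre-softmax score between query position \( i \) and key position \( j \) is \( x_i W_Q W_K^T x_j^T / \sqrt{d_{\text{head}}} \) (assuming standard scaled dot-product attention), so this single \( d_{\text{model}} \times d_{\text{model}} \) matrix determines where the head attends.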

```python
def compute_QK_circuit(W_Q, W_K):
    return W_Q @ W_K.T  # Shape: [d_model, d_model]
```

#### Visualizing the QK Circuit

We can visualize the QK circuit to understand what features the head is focusing on.

```python
def plot_QK_circuit(QK_circuit):
    plt.figure(figsize=(8, 6))
    plt.imshow(QK_circuit.detach().cpu().numpy(), cmap='viridis')
    plt.colorbar()
    plt.title('QK Circuit')
    plt.xlabel('Key Features')
    plt.ylabel('Query Features')
    plt.show()
```

#### Example Usage

```python
W_Q, W_K = get_QK_matrices(model, layer_idx=0, head_idx=3)
QK_circuit = compute_QK_circuit(W_Q, W_K)
plot_QK_circuit(QK_circuit)
```

**Interpretation:**

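- The residual stream has no privileged basis, so individual rows and columns of this \( d_{\text{model}} \times d_{\text{model}} \) plot rarely correspond to interpretable features. Structure is easier to read after projecting both sides through the token embedding, which gives a token-by-token matrix.
- For an induction head, the key side is expected to be driven by the output of a previous-token head (K-composition), which Section 7 quantifies.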
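In the spirit of Nanda's induction-head analysis, the QK circuit of a candidate induction head can be viewed in token space with its key side fed through the previous-token head's OV circuit. The result is a `[vocab, vocab]` matrix whose diagonal should dominate: a query token *a* attends most to keys whose *previous* token was *a*. The sketch assumes row-vector weights and a token embedding `W_E` of shape `[vocab, d_model]` (its attribute name in `cursivetransformer` is not assumed); it ignores positional embeddings, layer norm, and biases, and the function names are hypothetical.

```python
import torch

def k_composition_circuit(W_E, W_OV_prev, W_QK_ind):
    """Token-space view of K-composition (row-vector convention).

    W_E: [vocab, d_model] token embedding
    W_OV_prev: [d_model, d_model] OV circuit of the earlier (previous-token) head
    W_QK_ind: [d_model, d_model] QK circuit of the candidate induction head
    Entry [a, b]: score a query token a gives a key position whose previous token was b.
    """
    return W_E @ W_QK_ind @ (W_E @ W_OV_prev).T

def top1_diagonal_fraction(M):
    """Fraction of rows whose largest entry lies on the diagonal."""
    return (M.argmax(dim=-1) == torch.arange(M.shape[0])).float().mean().item()

# Synthetic check: orthonormal embeddings, identity OV (perfect copying of the previous
# token) and identity QK (perfect matching) give a diagonal-dominant circuit
vocab, d_model = 10, 32
W_E = torch.linalg.qr(torch.randn(d_model, vocab)).Q.T  # [vocab, d_model], orthonormal rows
circuit = k_composition_circuit(W_E, torch.eye(d_model), torch.eye(d_model))
print(top1_diagonal_fraction(circuit))  # 1.0
```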

---

<a name="ov_analysis"></a>
### Analyzing the OV Circuit

#### Extracting V and O Matrices

```python
def get_VO_matrices(model, layer_idx, head_idx):
    # Attribute names (v_proj, out_proj, head_dim) are assumptions about the architecture.
    attn = model.transformer.layers[layer_idx].self_attn
    d_head = attn.head_dim
    head_slice = slice(head_idx * d_head, (head_idx + 1) * d_head)

    # v_proj.weight is [num_heads * d_head, d_model] ([out, in]); transposed: [d_model, num_heads * d_head]
    W_V = attn.v_proj.weight.T[:, head_slice]    # Shape: [d_model, d_head]
    # out_proj.weight is [d_model, num_heads * d_head]; transposed: [num_heads * d_head, d_model]
    W_O = attn.out_proj.weight.T[head_slice, :]  # Shape: [d_head, d_model]

    return W_V, W_O
```

#### Computing the OV Circuit

The OV circuit is given by \( W_V W_O \).

```python
def compute_OV_circuit(W_V, W_O):
    return W_V @ W_O  # Shape: [d_model, d_model]
```
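As with the QK circuit, the product \( W_V W_O \) can replace the two separate projections without changing the head's output: for any attention pattern, the head writes `pattern @ X @ W_OV`. The sketch below checks this with random weights (illustrative shapes, row-vector convention) and also shows that `W_OV`, although \( d_{\text{model}} \times d_{\text{model}} \), has rank at most \( d_{\text{head}} \), which is one reason its heatmap looks highly structured.

```python
import torch

torch.manual_seed(0)
d_model, d_head, seq_len = 16, 4, 6
X = torch.randn(seq_len, d_model)  # residual stream, one row per position
W_V = torch.randn(d_model, d_head)
W_O = torch.randn(d_head, d_model)

# Any causal attention pattern: rows sum to 1, zero above the diagonal
causal_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
pattern = torch.softmax(torch.randn(seq_len, seq_len).masked_fill(causal_mask, float("-inf")), dim=-1)

head_out = pattern @ (X @ W_V) @ W_O   # usual computation: values, then output projection
W_OV = W_V @ W_O
head_out_circuit = pattern @ X @ W_OV  # same result via the OV circuit

print(torch.allclose(head_out, head_out_circuit, rtol=1e-4, atol=1e-4))
print(torch.linalg.matrix_rank(W_OV).item())  # at most d_head
```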

#### Visualizing the OV Circuit

```python
def plot_OV_circuit(OV_circuit):
    plt.figure(figsize=(8, 6))
    plt.imshow(OV_circuit.detach().cpu().numpy(), cmap='coolwarm')
    plt.colorbar()
    plt.title('OV Circuit')
    plt.xlabel('Output Features')
    plt.ylabel('Value Features')
    plt.show()
```

#### Example Usage

```python
W_V, W_O = get_VO_matrices(model, layer_idx=0, head_idx=3)
OV_circuit = compute_OV_circuit(W_V, W_O)
plot_OV_circuit(OV_circuit)
```

**Interpretation:**

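- The OV circuit maps a residual-stream vector at the attended (source) position to the vector the head adds to the residual stream at the destination (query) position, scaled by the attention weight.
- An induction head is expected to act as a *copying* head: once it attends to the token `B` that followed the earlier occurrence of `A`, its OV circuit should raise the model's prediction of `B` at the current position.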
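One summary of copying, proposed in *A Mathematical Framework for Transformer Circuits*, looks at the eigenvalues of the full OV circuit in token space: mostly positive eigenvalues mean that attending to a token tends to raise that same token's logit. The sketch below computes \( \sum_i \lambda_i / \sum_i |\lambda_i| \) for \( W_E W_{OV} W_U \), assuming a token embedding `W_E` (`[vocab, d_model]`) and an unembedding `W_U` (`[d_model, vocab]`). For `cursivetransformer` these would be the stroke-token embedding and output head, whose attribute names are not assumed here; layer norm and biases are ignored.

```python
import torch

def copying_score(W_E, W_OV, W_U):
    """Eigenvalue-based copying score of the full OV circuit (row-vector convention).

    Returns sum(eigenvalues) / sum(|eigenvalues|): near +1 suggests copying,
    near -1 suggests the head suppresses the tokens it attends to.
    """
    full_ov = W_E @ W_OV @ W_U               # [vocab, vocab]: source token -> logit change
    eigvals = torch.linalg.eigvals(full_ov)  # complex in general
    return (eigvals.sum().real / eigvals.abs().sum()).item()

# Synthetic check with orthonormal embeddings and a tied unembedding
vocab, d_model = 10, 32
W_E = torch.linalg.qr(torch.randn(d_model, vocab)).Q.T  # [vocab, d_model], orthonormal rows
W_U = W_E.T                                             # [d_model, vocab]
print(copying_score(W_E, torch.eye(d_model), W_U))   # identity OV: perfect copying
print(copying_score(W_E, -torch.eye(d_model), W_U))  # negated OV: suppression
```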

---

<a name="composition_scores"></a>
## 7. Computing Composition Scores

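Composition scores quantify how strongly the output of an earlier head feeds into a later head. The Transformer Circuits framework distinguishes three kinds, depending on whether the earlier head's output is read by the later head's queries (Q-composition), keys (K-composition), or values (V-composition). Induction heads rely mainly on K-composition with a previous-token head: the key at each position then carries information about the token *before* it.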

### Defining the Composition Score

We can define the composition score between two matrices \( W_A \) and \( W_B \) as:

\[ \text{Comp\_Score} = \frac{\| W_A W_B \|_F}{\| W_A \|_F \| W_B \|_F} \]

where \( \| \cdot \|_F \) denotes the Frobenius norm.
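Which kind of composition the score measures depends on the order of the factors. A sketch of the reasoning in the row-vector convention used by this notebook's code: an earlier head maps a source vector \( x \) to \( x W_{OV}^{(0)} \), and the later head scores positions as \( r_i W_{QK}^{(1)} r_j^T \). If the earlier head's output lands at the key position \( j \), it contributes

\[ r_i \, W_{QK}^{(1)} \left( x W_{OV}^{(0)} \right)^T = r_i \, W_{QK}^{(1)} \, {W_{OV}^{(0)}}^T x^T , \]

so K-composition pairs \( W_A = W_{OV}^{(0)} \) with \( W_B = {W_{QK}^{(1)}}^T \) (transposing the product leaves the Frobenius norm unchanged):

\[ \text{Comp}_K = \frac{\| W_{OV}^{(0)} {W_{QK}^{(1)}}^T \|_F}{\| W_{OV}^{(0)} \|_F \, \| W_{QK}^{(1)} \|_F} \]

If instead the output lands at the query position \( i \), the term is \( x W_{OV}^{(0)} W_{QK}^{(1)} r_j^T \), and \( W_B = W_{QK}^{(1)} \) gives Q-composition.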

### Computing Composition Scores Between Heads

Suppose we suspect that head 3 in layer 0 is composing with head 5 in layer 1.

```python
def get_composition_score(W_A, W_B):
    numerator = torch.norm(W_A @ W_B, p='fro')
    denominator = torch.norm(W_A, p='fro') * torch.norm(W_B, p='fro')
    return (numerator / denominator).item()
```

#### Example Usage

```python
# Get OV circuit of head 3 in layer 0
W_V0, W_O0 = get_VO_matrices(model, layer_idx=0, head_idx=3)
W_OV0 = W_V0 @ W_O0  # Shape: [d_model, d_model]

# Get QK circuit of head 5 in layer 1
W_Q1, W_K1 = get_QK_matrices(model, layer_idx=1, head_idx=5)
W_QK1 = W_Q1 @ W_K1.T  # Shape: [d_model, d_model]

# K-composition: layer 0's output enters layer 1's *keys*, so pair W_OV0 with W_QK1.T
# (passing W_QK1 without the transpose would measure Q-composition instead)
comp_score = get_composition_score(W_OV0, W_QK1.T)
print(f"K-composition score between Layer 0 Head 3 and Layer 1 Head 5: {comp_score:.4f}")
```

**Interpreting the Score:**

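- A higher composition score suggests that the two heads interact more strongly. The score is at most 1, but random matrices already score well above 0, so compare values against a random baseline rather than reading them in absolute terms.
- We can compute composition scores for all pairs of heads to identify significant interactions.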
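A simple reference point is the average score of random low-rank matrices shaped like an OV circuit and a QK circuit. The sketch below uses Gaussian entries, which is an assumption: the score is invariant to overall scale, but other weight distributions or structure may shift the baseline somewhat. `random_baseline` is a hypothetical helper, and `get_composition_score` is repeated so the block runs on its own.

```python
import torch

def get_composition_score(W_A, W_B):
    numerator = torch.norm(W_A @ W_B, p='fro')
    denominator = torch.norm(W_A, p='fro') * torch.norm(W_B, p='fro')
    return (numerator / denominator).item()

def random_baseline(d_model, d_head, n_samples=200, generator=None):
    """Mean K-composition score between random rank-d_head OV and QK circuits."""
    scores = []
    for _ in range(n_samples):
        W_OV = torch.randn(d_model, d_head, generator=generator) @ torch.randn(d_head, d_model, generator=generator)
        W_QK = torch.randn(d_model, d_head, generator=generator) @ torch.randn(d_head, d_model, generator=generator)
        scores.append(get_composition_score(W_OV, W_QK.T))
    return sum(scores) / len(scores)

g = torch.Generator().manual_seed(0)
baseline = random_baseline(d_model=64, d_head=16, generator=g)
print(f"Random baseline: {baseline:.4f}")
```

With the model's own `d_model` and `d_head`, real head pairs can then be read relative to this value, e.g. `composition_scores - baseline`.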

### Computing Composition Scores for All Head Pairs

```python
num_layers = len(model.transformer.layers)
num_heads = model.transformer.layers[0].self_attn.num_heads

# K-composition scores: rows = layer-0 heads (writers), columns = layer-1 heads (readers)
composition_scores = np.zeros((num_heads, num_heads))

for head_idx0 in range(num_heads):
    W_V0, W_O0 = get_VO_matrices(model, layer_idx=0, head_idx=head_idx0)
    W_OV0 = W_V0 @ W_O0  # Shape: [d_model, d_model]

    for head_idx1 in range(num_heads):
        W_Q1, W_K1 = get_QK_matrices(model, layer_idx=1, head_idx=head_idx1)
        W_QK1 = W_Q1 @ W_K1.T  # Shape: [d_model, d_model]

        # K-composition: layer-0 output feeds layer-1 keys, hence the transpose of W_QK1
        composition_scores[head_idx0, head_idx1] = get_composition_score(W_OV0, W_QK1.T)

# Plot the composition scores
plt.figure(figsize=(8, 6))
plt.imshow(composition_scores, cmap='plasma')
plt.colorbar()
plt.title('K-Composition Scores Between Layer 0 and Layer 1 Heads')
plt.xlabel('Layer 1 Heads')
plt.ylabel('Layer 0 Heads')
plt.show()
```

**Interpretation:**

- Look for layer-1 heads with high K-composition scores against a layer-0 head that behaves as a previous-token head.
- These pairs are candidates for forming an induction circuit.

---

<a name="visualization"></a>
## 8. Visualizing and Interpreting Results

### Attention Pattern Visualization

Revisiting the attention patterns, we can overlay the positions where high attention occurs with the tokens in the sequence.

Suppose we have stroke sequences and corresponding context text; we can map attention weights back to the strokes.

**Note:** The actual implementation would depend on how strokes and tokens are represented in your model.

### Analyzing Specific Sequences

We can select specific sequences where induction behavior is expected, such as sequences where certain patterns repeat.

```python
# Example: Select a sequence with repeating patterns
sequence_idx = 0  # Adjust as needed
inputs, targets = test_dataset[sequence_idx]  # Adjust unpacking based on your dataset
inputs = inputs.unsqueeze(0).to(device)  # add a batch dimension

# Capture layer-0 attention patterns with a forward hook
attention_patterns = []

def extract_attention_patterns(module, input, output):
    attention_patterns.append(output.detach().cpu())

hook_handle = model.transformer.layers[0].self_attn.attn_drop.register_forward_hook(extract_attention_patterns)
try:
    with torch.no_grad():
        model(inputs)
finally:
    hook_handle.remove()  # always detach the hook

patterns = attention_patterns[0][0]  # Shape: [num_heads, seq_len, seq_len] for this sequence

# Process and visualize the attention patterns as before
```

---

<a name="conclusion"></a>
## 9. Conclusion

By adapting methods from Neel Nanda's notebook, we've set up a workflow for detecting induction heads in the `cursivetransformer` model: analyzing attention patterns, computing QK and OV circuits, and calculating composition scores. Together these steps narrow the search to candidate heads that may function as induction heads, which still need to be validated experimentally.

**Next Steps:**

- **Validate Findings:** Perform further experiments to confirm the induction behavior, such as ablation studies.
- **Interpret Features:** Investigate what specific features the induction heads are focusing on, possibly relating to specific strokes or handwriting styles.
- **Extend Analysis:** Apply similar methods to other components or layers of the model to gain deeper insights.

---

**References:**

- Neel Nanda's work on reverse-engineering induction circuits.
- [A Mathematical Framework for Transformer Circuits](https://transformer-circuits.pub/2021/framework/index.html) (Elhage et al., 2021), Transformer Circuits Thread.

---

**Note:** The actual implementation details may vary based on the specific architecture and data representations used in the `cursivetransformer` model. Adjustments may be necessary to accommodate differences.