<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/HookedCursiveTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# HookedCursiveTransformer

# Setup

In [None]:
!pip install transformer_lens
!pip install gradio
!pip install wandb
!pip install einops
!pip install matplotlib
!pip install datasets

# Clone the cursivetransformer repository and install its requirements
!rm -rf cursivetransformer && git clone https://github.com/zwimpee/cursivetransformer.git
!pip install -r cursivetransformer/requirements.txt

# Log in to Weights & Biases (wandb.login() is called without a key, so it relies on existing credentials or prompts for one)
import wandb
wandb.login()

import sys
sys.path.append('/content/cursivetransformer')  # Adjust the path if necessary

# Import cursivetransformer modules
from cursivetransformer.model import get_all_args, get_checkpoint
from cursivetransformer.data import create_datasets, offsets_to_strokes
from cursivetransformer.sample import generate, generate_n_words, plot_strokes

# Import TransformerLens modules

import dataclasses
import logging
import os
import re
from pathlib import Path
from typing import Dict, Optional, Union

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens.components import Attention, MLP
import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


# Import other necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
import pprint
import json
from datasets import load_dataset
from IPython.display import HTML, display
from functools import partial
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import math
import einops
from einops import rearrange, einsum

# Select the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


# Build the argument namespace and point it at the pretrained run to analyse.
# NOTE(review): the meaning of the positional False is defined by get_all_args — confirm.
args = get_all_args(False)
args.sample_only = True
args.load_from_run_id = '6le6tujz'  # Replace with your actual run ID
args.wandb_entity = 'sam-greydanus'
args.dataset_name = 'bigbank'  # Replace with your dataset name
args.wandb_run_name = 'cursivetransformer_dictionary_learning'

# Seed the CPU and every CUDA generator before any data or model is created
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

train_dataset, test_dataset = create_datasets(args)

# Copy the dataset-derived sizes onto args; they are read later when the
# checkpoint is loaded and when the hooked config is built from args.
args.block_size = train_dataset.get_stroke_seq_length()
args.context_block_size = train_dataset.get_text_seq_length()
args.vocab_size = train_dataset.get_vocab_size()
args.context_vocab_size = train_dataset.get_char_vocab_size()

# Only `model` is used further down in this file; the other four return values are unused here.
# NOTE(review): presumably this restores the weights of run `load_from_run_id` — confirm in get_checkpoint.
model, optimizer, scheduler, step, best_loss = get_checkpoint(args)

# Classes

In [None]:
class HookedCursiveTransformerConfig(HookedTransformerConfig):
    """HookedTransformerConfig with extra fields for a second (context) input.

    Additional keyword arguments, all optional:
        d_model_c: Width of the context stream; defaults to ``d_model``.
        n_ctx_c: Maximum context length; defaults to ``n_ctx``.
        d_vocab_c: Context vocabulary size; defaults to ``d_vocab``.
        use_cross_attention: Stored as given; defaults to ``True``.

    Every other keyword argument is passed to the parent constructor unchanged.
    """

    def __init__(self, **kwargs):
        # (extra field, parent field supplying its default)
        fallbacks = (
            ('d_model_c', 'd_model'),
            ('n_ctx_c', 'n_ctx'),
            ('d_vocab_c', 'd_vocab'),
        )

        # Strip the extra keys before delegating: the parent constructor only
        # receives what is left in kwargs.
        for extra_name, _ in fallbacks:
            setattr(self, extra_name, kwargs.pop(extra_name, None))
        self.use_cross_attention = kwargs.pop('use_cross_attention', True)

        super().__init__(**kwargs)

        # Anything the caller left unset mirrors the corresponding parent field.
        for extra_name, base_name in fallbacks:
            if getattr(self, extra_name) is None:
                setattr(self, extra_name, getattr(self, base_name))

# - [ ] TODO: Change this to inherit from AbstractAttention: https://github.com/TransformerLensOrg/TransformerLens/blob/main/transformer_lens/components/abstract_attention.py
class HookedAttention(nn.Module):
    """Multi-head attention with one weight tensor per head, for self- or cross-attention.

    Queries are always projected from ``x``. Keys and values are projected
    from ``x`` too (self-attention, causally masked) or from a separate
    ``context`` tensor (cross-attention, unmasked).

    Weight matrices are allocated with ``torch.empty`` (biases with zeros), so
    real weights must be loaded before the module produces meaningful output.

    Args:
        cfg: Config object providing ``n_heads``, ``d_model``, ``d_head`` and,
            for cross-attention, ``d_model_c``.
        is_cross_attention: When True, keys/values are projected from
            ``context`` (width ``cfg.d_model_c``) instead of ``x``.
        layer_idx: Index of the owning block; only used to build cache keys.
    """

    def __init__(self, cfg, is_cross_attention=False, layer_idx=None):
        super().__init__()
        self.cfg = cfg
        self.is_cross_attention = is_cross_attention
        self.layer_idx = layer_idx

        # Keys/values read from the context stream in cross-attention, which
        # may have a different width than the main residual stream.
        kv_width = cfg.d_model_c if is_cross_attention else cfg.d_model

        self.W_Q = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_model, cfg.d_head))
        self.W_K = nn.Parameter(torch.empty(cfg.n_heads, kv_width, cfg.d_head))
        self.W_V = nn.Parameter(torch.empty(cfg.n_heads, kv_width, cfg.d_head))
        self.W_O = nn.Parameter(torch.empty(cfg.n_heads, cfg.d_head, cfg.d_model))
        # Biases
        self.b_Q = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.b_K = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.b_V = nn.Parameter(torch.zeros(cfg.n_heads, cfg.d_head))
        self.b_O = nn.Parameter(torch.zeros(cfg.d_model))

    def apply_causal_mask(self, attn_scores):
        """Set scores for key positions after the query position to ``-inf``.

        Args:
            attn_scores: Tensor whose last two dims are (query_pos, key_pos).

        Returns:
            A new tensor of the same shape; after softmax the masked entries
            become exactly zero.
        """
        q_len, k_len = attn_scores.shape[-2:]
        allowed = torch.ones(
            q_len, k_len, dtype=torch.bool, device=attn_scores.device
        ).tril()
        return attn_scores.masked_fill(~allowed, float('-inf'))

    def forward(self, x, context=None, cache=None, **kwargs):
        """Run attention and return the summed per-head outputs.

        Args:
            x: Query-side activations, ``[batch, seq_len, d_model]``.
            context: Key/value-side activations ``[batch, ctx_len, d_model_c]``;
                required for cross-attention, ignored otherwise.
            cache: Optional dict. When given, the detached attention
                probabilities are stored under
                ``blocks.{layer_idx}.attn.hook_pattern`` or
                ``blocks.{layer_idx}.cross_attn.hook_pattern``.
            **kwargs: Accepted and ignored, so callers may pass extra keywords.

        Returns:
            Tensor of shape ``[batch, seq_len, d_model]``.

        Raises:
            ValueError: If this is a cross-attention layer and ``context`` is None.
        """
        if self.is_cross_attention:
            if context is None:
                raise ValueError("Context must be provided for cross-attention")
            kv_input = context
        else:
            kv_input = x

        # Biases are [n_heads, d_head]; insert a position axis so they line up
        # with the head axis of [batch, n_heads, pos, d_head] instead of being
        # broadcast against the position axis.
        q = torch.einsum('btd,hde->bhte', x, self.W_Q) + self.b_Q[:, None, :]
        k = torch.einsum('bsd,hde->bhse', kv_input, self.W_K) + self.b_K[:, None, :]
        v = torch.einsum('bsd,hde->bhse', kv_input, self.W_V) + self.b_V[:, None, :]

        attn_scores = torch.einsum('bhte,bhse->bhts', q, k) / math.sqrt(self.cfg.d_head)

        # Self-attention is always causal; cross-attention may look at the
        # whole context.
        if not self.is_cross_attention:
            attn_scores = self.apply_causal_mask(attn_scores)

        attn_probs = nn.functional.softmax(attn_scores, dim=-1)

        # Save attention patterns for analysis
        if cache is not None:
            hook_name = f'blocks.{self.layer_idx}.{"cross_attn" if self.is_cross_attention else "attn"}.hook_pattern'
            cache[hook_name] = attn_probs.detach()

        # Weighted sum of values per head, then project and sum over heads.
        z = torch.einsum('bhts,bhse->bhte', attn_probs, v)
        attn_output = torch.einsum('bhte,hed->btd', z, self.W_O) + self.b_O

        return attn_output

class CursiveTransformerBlock(nn.Module):
    """Residual block: self-attention, cross-attention over the context, then MLP.

    Each sub-layer reads a LayerNorm-ed copy of the residual stream and adds
    its output back onto it.

    Args:
        cfg: Model config; passed to the attention and MLP sub-modules.
        layer_idx: Position of this block in the stack (used for cache keys).
    """

    def __init__(self, cfg, layer_idx):
        super().__init__()
        self.cfg = cfg
        self.layer_idx = layer_idx

        # Layer norms, one per sub-layer
        self.ln1 = nn.LayerNorm(cfg.d_model)
        self.ln2 = nn.LayerNorm(cfg.d_model)
        self.ln3 = nn.LayerNorm(cfg.d_model)

        # Self-attention over the main sequence
        self.attn = HookedAttention(cfg, layer_idx=layer_idx)

        # Cross-attention from the main sequence onto the context
        self.cross_attn = HookedAttention(cfg, is_cross_attention=True, layer_idx=layer_idx)

        # MLP
        self.mlp = MLP(cfg)

    def forward(self, x, c, **kwargs):
        """Apply the three sub-layers in order.

        Args:
            x: Residual stream, ``[batch, seq_len, d_model]``.
            c: Embedded context, used as keys/values by the cross-attention.
            **kwargs: Forwarded to both attention layers (e.g. ``cache``).

        Returns:
            Updated residual stream with the same shape as ``x``.
        """
        x = x + self.attn(self.ln1(x), **kwargs)
        x = x + self.cross_attn(self.ln2(x), context=c, **kwargs)
        # The keyword arguments are meant for the attention layers only; the
        # MLP gets just the activations so that passing e.g. ``cache`` cannot
        # break its call.
        x = x + self.mlp(self.ln3(x))
        return x


class HookedCursiveTransformer(HookedTransformer):
    """HookedTransformer variant whose blocks also cross-attend to a context sequence.

    The parent constructor builds the token embedding, positional embedding
    and unembedding; this class adds context embeddings and replaces the
    block stack and the final LayerNorm.

    Args:
        cfg: Config providing the standard HookedTransformer fields plus
            ``d_vocab_c``, ``d_model_c`` and ``n_ctx_c``.
    """

    _RETURN_TYPES = ("logits", "loss", "both")

    def __init__(self, cfg):
        super().__init__(cfg)
        self.cfg = cfg

        # Embedding layers for the context input
        self.embed_c = nn.Embedding(cfg.d_vocab_c, cfg.d_model_c)
        self.pos_embed_c = nn.Embedding(cfg.n_ctx_c, cfg.d_model_c)

        # Override the transformer blocks with custom blocks that include cross-attention
        self.blocks = nn.ModuleList([
            CursiveTransformerBlock(self.cfg, layer_idx)
            for layer_idx in range(self.cfg.n_layers)
        ])

        # Plain nn.LayerNorm so the final norm exposes `weight` / `bias`
        self.ln_final = nn.LayerNorm(self.cfg.d_model)

    def forward(self, x, c, return_type="logits", **kwargs):
        """Run the model on token ids ``x`` conditioned on context ids ``c``.

        Args:
            x: Integer token ids, ``[batch, seq_len]``.
            c: Integer context ids, ``[batch, context_len]``.
            return_type: ``"logits"``, ``"loss"`` or ``"both"``.
            **kwargs:
                targets: Required for ``return_type="loss"``; entries equal
                    to ``-1`` are ignored by the loss.
                cache: Optional dict that the attention layers fill with
                    their attention patterns.

        Returns:
            ``"logits"``: the logits tensor.
            ``"loss"``: the scalar cross-entropy loss.
            ``"both"``: ``(logits, cache)``. A fresh dict is created when the
            caller did not supply ``cache``, so the second element is never None.

        Raises:
            ValueError: On an unknown ``return_type`` or when ``targets`` is
                missing for ``return_type="loss"``.
        """
        # Validate up front so a bad call fails before any computation.
        if return_type not in self._RETURN_TYPES:
            raise ValueError(f"Invalid return_type: {return_type}")
        targets = kwargs.get('targets')
        if return_type == "loss" and targets is None:
            raise ValueError("Targets must be provided when return_type is 'loss'")

        cache = kwargs.get('cache')
        if cache is None and return_type == "both":
            cache = {}
        # Only the cache is forwarded to the blocks; `targets` is consumed here.
        block_kwargs = {} if cache is None else {'cache': cache}

        # Token + positional embedding for x. The learned position table is
        # sliced directly: the inherited pos_embed module is built to receive
        # token ids, so calling it with an arange of positions would not yield
        # a [seq_len, d_model] table (assumes it exposes `W_pos` — the weight
        # converter in this file writes the key "pos_embed.W_pos").
        seq_len = x.size(1)
        x = self.embed(x) + self.pos_embed.W_pos[:seq_len]  # [batch, seq_len, d_model]

        # Token + positional embedding for c
        context_positions = torch.arange(c.size(1), device=c.device)
        c = self.embed_c(c) + self.pos_embed_c(context_positions)  # [batch, context_len, d_model_c]

        for block in self.blocks:
            x = block(x, c, **block_kwargs)

        logits = self.unembed(self.ln_final(x))

        if return_type == "logits":
            return logits
        if return_type == "loss":
            return nn.functional.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                targets.reshape(-1),
                ignore_index=-1,
            )
        return logits, cache

    def apply_causal_mask(self, attn_scores):
        """Return ``attn_scores`` unchanged.

        Placeholder kept for interface compatibility; ``forward`` does not
        call it.
        """
        return attn_scores

# Functions

In [None]:
def convert_cursivetransformer_model_config(args):
    """Translate a cursivetransformer argument namespace into a hooked config.

    Args:
        args: Namespace exposing ``n_embd``, ``n_embd2``, ``n_layer``,
            ``n_ctx_head``, ``max_seq_length``, ``vocab_size``,
            ``context_block_size``, ``context_vocab_size`` and ``device``.

    Returns:
        A ``HookedCursiveTransformerConfig`` built from those values.
    """
    width = args.n_embd
    head_count = args.n_ctx_head

    # Fields understood by the base HookedTransformerConfig
    backbone = dict(
        d_model=width,
        n_layers=args.n_layer,
        d_mlp=width * 4,
        d_head=width // head_count,
        n_heads=head_count,
        n_ctx=args.max_seq_length,
        d_vocab=args.vocab_size,
        tokenizer_name=None,
        act_fn="gelu_new",
        attn_only=False,
        final_rms=False,
        original_architecture="cursivetransformer",
        normalization_type="LN",
        init_weights=False,
        device=args.device,
    )

    # Extra fields describing the context stream used by cross-attention
    context_stream = dict(
        d_model_c=args.n_embd2,
        n_ctx_c=args.context_block_size,
        d_vocab_c=args.context_vocab_size,
        use_cross_attention=True,
    )

    return HookedCursiveTransformerConfig.from_dict({**backbone, **context_stream})


def fill_missing_keys(model, state_dict):
    """Complete ``state_dict`` with the model's own values for any absent keys.

    Keys present in ``model.state_dict()`` but not in ``state_dict`` are
    copied over, except keys containing ``"hf_model"``, which are skipped.
    A warning is logged for every filled key containing ``"W_"``.

    Args:
        model: Object whose ``state_dict()`` supplies the default entries.
        state_dict (dict): State dict from a pretrained model; modified in place.

    Returns:
        dict: The same ``state_dict`` object, with the missing keys filled in.
    """
    defaults = model.state_dict()
    absent = set(defaults) - set(state_dict)

    for name in absent:
        # Entries belonging to a wrapped HuggingFace model are never filled in.
        if "hf_model" not in name:
            if "W_" in name:
                logging.warning(
                    "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
                        name
                    )
                )
            state_dict[name] = defaults[name]

    return state_dict

def _per_head_in_proj(weight, n_heads, d_head):
    """Reshape a ``[n_heads * d_head, d_in]`` projection weight to ``[n_heads, d_in, d_head]``."""
    # The head index lives on the OUTPUT axis, so that axis is split first and
    # the input axis is then moved in front of d_head. Transposing before
    # reshaping would mix rows belonging to different heads.
    return weight.reshape(n_heads, d_head, -1).transpose(1, 2)


def _per_head_out_proj(weight, n_heads, d_head):
    """Reshape a ``[d_out, n_heads * d_head]`` projection weight to ``[n_heads, d_head, d_out]``."""
    return weight.T.reshape(n_heads, d_head, -1)


def convert_cursivetransformer_weights(cursivetransformer, cfg):
    """Build a HookedCursiveTransformer state dict from a cursivetransformer model.

    Assumes every projection in the source model stores its weight as
    ``[out_features, in_features]`` (the ``torch.nn.Linear`` layout), with the
    fused self-attention projection ordered Q, K, V and the fused
    cross-attention projection ordered K, V along the output axis.

    Args:
        cursivetransformer: Source model exposing ``transformer.wte``,
            ``wpe``, ``wce``, ``wcpe``, ``h`` (list of blocks), ``ln_f`` and
            ``lm_head``.
        cfg: Config providing ``n_layers``, ``n_heads``, ``d_head`` and
            ``d_vocab_out``.

    Returns:
        dict: Maps hooked-model parameter names to tensors. The tensors are
        views of (or the same objects as) the source parameters, not copies.
    """
    n_heads, d_head = cfg.n_heads, cfg.d_head
    transformer = cursivetransformer.transformer
    state_dict = {}

    # Embeddings
    state_dict["embed.W_E"] = transformer.wte.weight
    state_dict["pos_embed.W_pos"] = transformer.wpe.weight
    state_dict["embed_c.weight"] = transformer.wce.weight
    state_dict["pos_embed_c.weight"] = transformer.wcpe.weight

    for l in range(cfg.n_layers):
        block = transformer.h[l]
        prefix = f'blocks.{l}'

        # Layer norms
        for target, source in (('ln1', block.ln_1), ('ln2', block.ln_2), ('ln3', block.ln_3)):
            state_dict[f'{prefix}.{target}.weight'] = source.weight
            state_dict[f'{prefix}.{target}.bias'] = source.bias

        # Self-attention: one fused projection holding Q, K and V
        self_attn = block.attn
        qkv_weights = torch.chunk(self_attn.c_attn.weight, 3, dim=0)
        qkv_biases = torch.chunk(self_attn.c_attn.bias, 3, dim=0)
        for letter, weight, bias in zip('QKV', qkv_weights, qkv_biases):
            state_dict[f'{prefix}.attn.W_{letter}'] = _per_head_in_proj(weight, n_heads, d_head)
            state_dict[f'{prefix}.attn.b_{letter}'] = bias.reshape(n_heads, d_head)
        state_dict[f'{prefix}.attn.W_O'] = _per_head_out_proj(self_attn.c_proj.weight, n_heads, d_head)
        state_dict[f'{prefix}.attn.b_O'] = self_attn.c_proj.bias

        # Cross-attention: separate Q projection, fused K/V projection
        cross_attn = block.cross_attn
        state_dict[f'{prefix}.cross_attn.W_Q'] = _per_head_in_proj(
            cross_attn.c_attn_q.weight, n_heads, d_head
        )
        state_dict[f'{prefix}.cross_attn.b_Q'] = cross_attn.c_attn_q.bias.reshape(n_heads, d_head)
        kv_weights = torch.chunk(cross_attn.c_attn_kv.weight, 2, dim=0)
        kv_biases = torch.chunk(cross_attn.c_attn_kv.bias, 2, dim=0)
        for letter, weight, bias in zip('KV', kv_weights, kv_biases):
            state_dict[f'{prefix}.cross_attn.W_{letter}'] = _per_head_in_proj(weight, n_heads, d_head)
            state_dict[f'{prefix}.cross_attn.b_{letter}'] = bias.reshape(n_heads, d_head)
        state_dict[f'{prefix}.cross_attn.W_O'] = _per_head_out_proj(
            cross_attn.c_proj.weight, n_heads, d_head
        )
        state_dict[f'{prefix}.cross_attn.b_O'] = cross_attn.c_proj.bias

        # MLP: the hooked model stores weights as [in, out], hence the transposes
        mlp = block.mlp
        state_dict[f'{prefix}.mlp.W_in'] = mlp.c_fc.weight.T
        state_dict[f'{prefix}.mlp.b_in'] = mlp.c_fc.bias
        state_dict[f'{prefix}.mlp.W_out'] = mlp.c_proj.weight.T
        state_dict[f'{prefix}.mlp.b_out'] = mlp.c_proj.bias

    # Unembedding; a bias-free lm_head maps to a zero bias
    lm_head = cursivetransformer.lm_head
    state_dict["unembed.W_U"] = lm_head.weight.T
    if lm_head.bias is not None:
        state_dict["unembed.b_U"] = lm_head.bias
    else:
        state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab_out)

    # Final layer norm
    state_dict["ln_final.weight"] = transformer.ln_f.weight
    state_dict["ln_final.bias"] = transformer.ln_f.bias

    return state_dict

# Sandbox

In [None]:
# Convert the pretrained model's config and weights, then load them into the hooked model
cfg = convert_cursivetransformer_model_config(args)
state_dict = convert_cursivetransformer_weights(model, cfg)
hooked_model = HookedCursiveTransformer(cfg)
hooked_model.load_state_dict(state_dict)
hooked_model.to(device)

# Take one test example and add a batch dimension to each tensor
x, c, y = test_dataset[0]
x = x.unsqueeze(0).to(device)  # stroke tokens, [1, stroke_seq_len]
c = c.unsqueeze(0).to(device)  # context tokens, [1, text_seq_len]
y = y.unsqueeze(0).to(device)  # targets, [1, stroke_seq_len]

print(x.shape)
print(c.shape)
print(y.shape)

# Run the model with a cache. The attention layers only record their patterns
# into a dict they are handed, so one is passed explicitly; without it the
# returned cache would be None and the lookup below would fail.
cache = {}
hooked_model.eval()
with torch.no_grad():
    logits, _ = hooked_model(x, c, return_type="both", cache=cache)

# Choose the layer and head to visualize
layer = 0  # Change to the desired layer index
head = 0   # Change to the desired head index

# Access attention patterns for the specified layer
attn_patterns = cache[f'blocks.{layer}.attn.hook_pattern']  # Shape: [batch_size, n_heads, seq_len, seq_len]

# Extract the attention pattern for the specified head and sample
attn = attn_patterns[0, head].cpu().numpy()  # Shape: [seq_len, seq_len]

# Visualize the attention pattern
import matplotlib.pyplot as plt

plt.figure(figsize=(8, 6))
plt.imshow(attn, cmap='viridis', aspect='auto')
plt.colorbar()
plt.title(f'Self-Attention Pattern for Layer {layer}, Head {head}')
plt.xlabel('Key Positions')
plt.ylabel('Query Positions')
plt.show()