<a href="https://colab.research.google.com/github/zwimpee/cursivetransformer/blob/main/HookedCursiveTransformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# HookedCursiveTransformer

https://github.com/zwimpee/cursivetransformer/issues/26

We will find this post particularly useful: https://arena3-chapter1-transformer-interp.streamlit.app/[1.1]_Transformer_from_Scratch
As well as this notebook: https://colab.research.google.com/drive/1bZkkJd8pAVnSN23svyszZ3f4WrnYKN_3?usp=sharing

# Setup

In [1]:
!pip install transformer_lens
!pip install gradio
!pip install wandb
!pip install einops
!pip install matplotlib
!pip install datasets

# Clone the cursivetransformer repository and install its requirements
!rm -rf cursivetransformer && git clone https://github.com/zwimpee/cursivetransformer.git
!pip install -r cursivetransformer/requirements.txt

# Login to Weights & Biases (replace 'your_api_key' with your actual API key)
import wandb
wandb.login()

import sys
sys.path.append('/content/cursivetransformer')  # Adjust the path if necessary

# Import cursivetransformer modules
from cursivetransformer.model import get_all_args, get_checkpoint
from cursivetransformer.data import create_datasets, offsets_to_strokes
from cursivetransformer.sample import generate, generate_n_words, plot_strokes

# Import TransformerLens modules

import dataclasses
import logging
import os
import re
from pathlib import Path
from typing import Dict, Optional, Union

from transformer_lens import HookedTransformer
import transformer_lens.utils as utils
from transformer_lens.HookedTransformerConfig import HookedTransformerConfig


# Import other necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
import pprint
import json
from datasets import load_dataset
from IPython.display import HTML, display
from functools import partial
import tqdm.notebook as tqdm
import matplotlib.pyplot as plt
import einops
from einops import rearrange


args = get_all_args(False)
args.sample_only = True
args.load_from_run_id = '6le6tujz'  # Replace with your actual run ID
args.wandb_entity = 'sam-greydanus'
args.dataset_name = 'bigbank'  # Replace with your dataset name
args.wandb_run_name = 'cursivetransformer_dictionary_learning'

torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)

train_dataset, test_dataset = create_datasets(args)

args.block_size = train_dataset.get_stroke_seq_length()
args.context_block_size = train_dataset.get_text_seq_length()
args.vocab_size = train_dataset.get_vocab_size()
args.context_vocab_size = train_dataset.get_char_vocab_size()

model, optimizer, scheduler, step, best_loss = get_checkpoint(args)

Collecting transformer_lens
  Downloading transformer_lens-2.7.0-py3-none-any.whl.metadata (12 kB)
Collecting beartype<0.15.0,>=0.14.1 (from transformer_lens)
  Downloading beartype-0.14.1-py3-none-any.whl.metadata (28 kB)
Collecting better-abc<0.0.4,>=0.0.3 (from transformer_lens)
  Downloading better_abc-0.0.3-py3-none-any.whl.metadata (1.4 kB)
Collecting datasets>=2.7.1 (from transformer_lens)
  Downloading datasets-3.0.1-py3-none-any.whl.metadata (20 kB)
Collecting fancy-einsum>=0.0.3 (from transformer_lens)
  Downloading fancy_einsum-0.0.3-py3-none-any.whl.metadata (1.2 kB)
Collecting jaxtyping>=0.2.11 (from transformer_lens)
  Downloading jaxtyping-0.2.34-py3-none-any.whl.metadata (6.4 kB)
Collecting wandb>=0.13.5 (from transformer_lens)
  Downloading wandb-0.18.2-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (9.7 kB)
Collecting pyarrow>=15.0.0 (from datasets>=2.7.1->transformer_lens)
  Downloading pyarrow-17.0.0-cp310-cp310-manylinux_2_28_x86_64.whl.metadata (

[34m[1mwandb[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


Trying to load dataset file from /content/cursivetransformer/data/bigbank.json.zip
Succeeded in loading the bigbank dataset; contains 1900 items.
For a dataset of 1805 examples we can generate 440811596555 combinations of 4 examples.
Generating 497000 4-word examples.
For a dataset of 95 examples we can generate 3183545 combinations of 4 examples.
Generating 3000 4-word examples.
Number of examples in the train dataset: 497000
Number of examples in the test dataset: 3000
Max token sequence length: 1000
Number of unique characters in the ascii vocabulary: 71
Ascii vocabulary:
	" enaitoshrdx.vpukbgfcymzw1lqj804I92637OTAS5N)EHR"'(BCQLMWYU,ZF!DXV?KPGJ"
Split up the dataset into 497000 training examples and 3000 test examples
Number of Transformer parameters: 368064
Model #params: 397184
Finding latest checkpoint for W&B run id 6le6tujz
  model:best_checkpoint:v70
  model:best_checkpoint:v71
  model:best_checkpoint:v72
  model:best_checkpoint:v73
  model:best_checkpoint:v74
  model:best_che

[34m[1mwandb[0m:   1 of 1 files downloaded.  


# Functions

In [48]:
def convert_cursivetransformer_model_config(args):
    cfg_dict = {
        "d_model": args.n_embd,
        "n_layers": args.n_layer,
        "d_mlp": args.n_embd * 4,  # Assuming MLP size is 4*d_model
        "d_head": args.n_embd // args.n_ctx_heads,
        "n_heads": args.n_ctx_heads,
        "n_ctx": args.max_seq_length,
        "d_vocab": args.vocab_size,
        "tokenizer_name": None,
        "act_fn": "gelu_new",
        "attn_only": False,
        "final_rms": False,
        "original_architecture": "cursivetransformer",
        "normalization_type": "LN",
        "init_weights": False,
        "device": args.device,
    }
    cfg = HookedTransformerConfig.from_dict(cfg_dict)
    return cfg
def fill_missing_keys(model, state_dict):
    """Takes in a state dict from a pretrained model, and fills in any missing keys with the default initialization.

    This function is assumed to be run before weights are initialized.

    Args:
        state_dict (dict): State dict from a pretrained model

    Returns:
        dict: State dict with missing keys filled in
    """
    # Get the default state dict
    default_state_dict = model.state_dict()
    # Get the keys that are missing from the pretrained model
    missing_keys = set(default_state_dict.keys()) - set(state_dict.keys())
    # Fill in the missing keys with the default initialization
    for key in missing_keys:
        if "hf_model" in key:
            # Skip keys that are from the HuggingFace model, if loading from HF.
            continue
        if "W_" in key:
            logging.warning(
                "Missing key for a weight matrix in pretrained, filled in with an empty tensor: {}".format(
                    key
                )
            )
        state_dict[key] = default_state_dict[key]
    return state_dict

def convert_cursivetransformer_weights(model, cfg: HookedTransformerConfig):
    state_dict = {}
    state_dict["embed.W_E"] = model.transformer.wte.weight
    state_dict["pos_embed.W_pos"] = model.transformer.wpe.weight

    for l in range(cfg.n_layers):
        print(f"Layer {l}:")
        state_dict[f"blocks.{l}.ln1.w"] = model.transformer.h[l].ln_1.weight
        state_dict[f"blocks.{l}.ln1.b"] = model.transformer.h[l].ln_1.bias

        # Fetch the weights
        W = model.transformer.h[l].attn.c_attn.weight
        # print(f"W shape: {W.shape}")

        # Correctly split W along dim=0 (out_features)
        W_Q, W_K, W_V = torch.tensor_split(W, 3, dim=0)
        # print(f"W_Q shape: {W_Q.shape}")

        total_head_dim = W_Q.shape[0]  # Number of rows in W_Q
        if total_head_dim % cfg.n_heads != 0:
            raise ValueError(f"Total head dimension {total_head_dim} is not divisible by n_heads {cfg.n_heads}")

        d_head = total_head_dim // cfg.n_heads
        cfg.d_head = d_head  # Update cfg if necessary
        # print(f"Using n_heads: {cfg.n_heads}, d_head: {cfg.d_head}")

        # Rearrangement
        W_Q = einops.rearrange(W_Q, "(i h) m -> i m h", i=cfg.n_heads)
        W_K = einops.rearrange(W_K, "(i h) m -> i m h", i=cfg.n_heads)
        W_V = einops.rearrange(W_V, "(i h) m -> i m h", i=cfg.n_heads)

        state_dict[f"blocks.{l}.attn.W_Q"] = W_Q
        state_dict[f"blocks.{l}.attn.W_K"] = W_K
        state_dict[f"blocks.{l}.attn.W_V"] = W_V

        # Biases
        qkv_bias = model.transformer.h[l].attn.c_attn.bias
        qkv_bias = torch.tensor_split(qkv_bias, 3)
        b_Q, b_K, b_V = qkv_bias
        b_Q = einops.rearrange(b_Q, "(i h) -> i h", i=cfg.n_heads)
        b_K = einops.rearrange(b_K, "(i h) -> i h", i=cfg.n_heads)
        b_V = einops.rearrange(b_V, "(i h) -> i h", i=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.b_Q"] = b_Q
        state_dict[f"blocks.{l}.attn.b_K"] = b_K
        state_dict[f"blocks.{l}.attn.b_V"] = b_V

        # Output projection
        W_O = model.transformer.h[l].attn.c_proj.weight
        W_O = einops.rearrange(W_O, "m (i h) -> i h m", i=cfg.n_heads)
        state_dict[f"blocks.{l}.attn.W_O"] = W_O
        state_dict[f"blocks.{l}.attn.b_O"] = model.transformer.h[l].attn.c_proj.bias

        state_dict[f"blocks.{l}.ln2.w"] = model.transformer.h[l].ln_2.weight
        state_dict[f"blocks.{l}.ln2.b"] = model.transformer.h[l].ln_2.bias

        # MLP weights - Transpose W_in and W_out
        W_in = model.transformer.h[l].mlp.c_fc.weight
        state_dict[f"blocks.{l}.mlp.W_in"] = W_in.T  # Transpose
        state_dict[f"blocks.{l}.mlp.b_in"] = model.transformer.h[l].mlp.c_fc.bias

        W_out = model.transformer.h[l].mlp.c_proj.weight
        state_dict[f"blocks.{l}.mlp.W_out"] = W_out.T  # Transpose
        state_dict[f"blocks.{l}.mlp.b_out"] = model.transformer.h[l].mlp.c_proj.bias

    # Unembedding weights
    state_dict["unembed.W_U"] = model.lm_head.weight.T

    # Handle unembed.b_U
    if model.lm_head.bias is not None:
        state_dict["unembed.b_U"] = model.lm_head.bias
    else:
        # Initialize unembed.b_U to zeros
        state_dict["unembed.b_U"] = torch.zeros(cfg.d_vocab_out)

    state_dict["ln_final.w"] = model.transformer.ln_f.weight
    state_dict["ln_final.b"] = model.transformer.ln_f.bias

    state_dict = fill_missing_keys(model, state_dict)

    return state_dict

# Sandbox

In [55]:
cfg = convert_cursivetransformer_model_config(args)
state_dict = convert_cursivetransformer_weights(model, cfg)
hooked_model = HookedTransformer(cfg)
hooked_model.load_state_dict(state_dict, strict=False)
hooked_model.eval()

# Move input to the correct device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hooked_model.to(device)
sample_input = torch.randint(0, cfg.d_vocab, (1, 10)).to(device)

# Collect attention patterns
with torch.no_grad():
    logits, cache = hooked_model.run_with_cache(sample_input)

# Choose the layer and head to visualize
layer = 0  # Change to the desired layer index
head = 0   # Change to the desired head index

# Access attention patterns for the specified layer
attn_patterns = cache[f'blocks.{layer}.attn.hook_pattern']  # Shape: [batch_size, n_heads, seq_len, seq_len]

# Extract the attention pattern for the specified head
attn = attn_patterns[0, head].cpu().numpy()  # Shape: [seq_len, seq_len]

# Visualize the attention pattern
plt.imshow(attn, cmap='viridis')
plt.colorbar()
plt.title(f'Attention Pattern for Layer {layer}, Head {head}')
plt.xlabel('Key Positions')
plt.ylabel('Query Positions')
plt.show()

Layer 0:
W shape: torch.Size([192, 64])
W_Q shape: torch.Size([64, 64])
Using n_heads: 4, d_head: 16
Layer 1:
W shape: torch.Size([192, 64])
W_Q shape: torch.Size([64, 64])
Using n_heads: 4, d_head: 16
Layer 2:
W shape: torch.Size([192, 64])
W_Q shape: torch.Size([64, 64])
Using n_heads: 4, d_head: 16
Layer 3:
W shape: torch.Size([192, 64])
W_Q shape: torch.Size([64, 64])
Using n_heads: 4, d_head: 16
