In [3]:
%load_ext autoreload
%autoreload 2


In [4]:
from typing import Iterable

In [54]:
import rich
import rich.table

# HookedTransformer

* [TransformerLens - Tutorial - Trains HookedTransformer from Scratch](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb)

```python
import transformers

# note: it's probably easier to just operate on tokens outside of the model,
#       that'll also make it clearer where tokenizer is used
#
# okay wrapping a pretrained tokenizer *can* be done:
# - https://huggingface.co/learn/nlp-course/chapter6/8#building-a-bpe-tokenizer-from-scratch
# - but none of the models support just naive encoding
#   - https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.BPE
class HookedTransformer:
    cfg: HookedTransformerConfig

    # note: actually does an `isinstance` check in the constructor
    tokenizer: transformers.PreTrainedTokenizerBase | None
```

In [5]:
import transformer_lens

from jaxtyping import Int64, Float32

import numpy as np
import plotly.express as px
import plotly.io as pio

import torch
import torch.utils.data

In [6]:
# plotting code copied over from transformer_lens tutorial notebook


def line(tensor: torch.Tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor: torch.Tensor, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

In [7]:
device = transformer_lens.utils.get_device()

print(f"Using device: {device}")

Using device: mps


### Setup Sample Generator

In [144]:
import string
import itertools
import more_itertools


class SpecialToken:
    # note: as assume a BOS token because transformerlens expects it
    BOS = "<"
    # we use a EOS token for convenience
    EOS = ">"


# note: without length, the model doesn't need to learn induction heads, just directly copies


# TODO(bschoen): Allow this to generalize in the future
#
# Good for purely attention, since seeing patterns
def generate_sample_palindrome_then_repeated() -> Iterable[str]:
    """Generate palindrom samples like `<abc|cba|abc>`."""

    # Generate all combinations of lowercase letters
    characters = string.ascii_lowercase

    # note: chosen arbitrarily
    lengths = [2, 3, 4, 5, 6, 7]

    # pad to max length
    max_length = (
        1 + max(lengths) + 1 + max(lengths) + 1 + max(lengths) + 1 + max(lengths) + 1
    )

    # set max number to take of each length
    max_combinations_per_length = 10000

    for length in lengths:

        for combination_index, combination in enumerate(
            itertools.product(characters, repeat=length)
        ):

            if combination_index > max_combinations_per_length:
                break

            combination_str = "".join(combination)
            reversed_str = "".join(reversed(combination_str))

            sample = (
                SpecialToken.BOS
                + combination_str
                + "|"
                + reversed_str
                + "|"
                + combination_str
                + SpecialToken.EOS
            )

            # Pad the sample to max_length with EOS tokens
            padded_sample = sample.ljust(max_length, SpecialToken.EOS)

            yield padded_sample  # Return the padded sample


# TODO(bschoen): For this do we get like a "next biggest" head?
# TODO(bschoen): Can we do circuit analysis on this?
def generate_sample_sorted() -> Iterable[str]:
    """Generate sequence sorted `<cab|abc>`."""

    # Generate all combinations of lowercase letters
    characters = string.ascii_lowercase

    # note: chosen arbitrarily
    # lengths = [3, 4, 5, 6, 7]
    lengths = [2, 3, 4, 5]  # , 6, 7]

    # pad to max length
    max_length = 1 + max(lengths) + 1 + max(lengths) + 1

    # set max number to take of each length
    # max_combinations_per_length = 10000

    for length in lengths:

        for combination_index, combination in enumerate(
            itertools.product(characters, repeat=length)
        ):

            # if combination_index > max_combinations_per_length:
            #    break

            combination_str = "".join(combination)
            sorted_str = "".join(sorted(combination_str))

            sample = (
                SpecialToken.BOS + combination_str + "|" + sorted_str + SpecialToken.EOS
            )

            # Pad the sample to max_length with EOS tokens
            padded_sample = sample.ljust(max_length, SpecialToken.EOS)

            yield padded_sample  # Return the padded sample


generate_sample = generate_sample_sorted

# show a few examples
[x for x in more_itertools.take(10, generate_sample())]

['<aa|aa>>>>>>>',
 '<ab|ab>>>>>>>',
 '<ac|ac>>>>>>>',
 '<ad|ad>>>>>>>',
 '<ae|ae>>>>>>>',
 '<af|af>>>>>>>',
 '<ag|ag>>>>>>>',
 '<ah|ah>>>>>>>',
 '<ai|ai>>>>>>>',
 '<aj|aj>>>>>>>']

### Setup Tokenizer

In [145]:
from gpt_from_scratch.naive_tokenizer import NaiveTokenizer

vocab = string.ascii_lowercase + "|" + SpecialToken.BOS + SpecialToken.EOS

tokenizer = NaiveTokenizer.from_text(vocab)

In [146]:
from gpt_from_scratch import tokenizer_utils

# test tokenizer
input_text = "<abc|cba|abc><bd|db|bd>>>>"
tokenizer_utils.show_token_mapping(tokenizer, input_text)

Input:		<abc|cba|abc><bd|db|bd>>>>
Tokenized:	[44m[97m<[0m[42m[97ma[0m[43m[97mb[0m[46m[97mc[0m[45m[97m|[0m[41m[97mc[0m[44m[97mb[0m[42m[97ma[0m[43m[97m|[0m[46m[97ma[0m[45m[97mb[0m[41m[97mc[0m[44m[97m>[0m[42m[97m<[0m[43m[97mb[0m[46m[97md[0m[45m[97m|[0m[41m[97md[0m[44m[97mb[0m[42m[97m|[0m[43m[97mb[0m[46m[97md[0m[45m[97m>[0m[41m[97m>[0m[44m[97m>[0m[42m[97m>[0m
Token ID | Token Bytes | Token String
---------+-------------+--------------
       0 | [38;5;2m3C[0m | '<'
          [48;5;1m[38;5;15m<[0mabc|cba|abc><bd|db|bd>>>>
          U+003C LESS-THAN SIGN (1 bytes: [38;5;2m3C[0m)
       2 | [38;5;2m61[0m | 'a'
          <[48;5;1m[38;5;15ma[0mbc|cba|abc><bd|db|bd>>>>
          U+0061 LATIN SMALL LETTER A (1 bytes: [38;5;2m61[0m)
       3 | [38;5;2m62[0m | 'b'
          <a[48;5;1m[38;5;15mb[0mc|cba|abc><bd|db|bd>>>>
          U+0062 LATIN SMALL LETTER B (1 bytes: [38;5;2m62[0m)
       4 | [38;5;

### Setup Loss Function

In [147]:
def loss_fn(logits, target):
    # standard cross entropy loss
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        target.view(-1),
    )

### Evaluate On Test

In [148]:
def evaluate_loss_on_test_batches(
    model: transformer_lens.HookedTransformer,
    data_loader: torch.utils.data.DataLoader,
    max_batches: int,
) -> float:

    # Set the model to evaluation mode
    model.eval()

    losses = []

    with torch.no_grad():  # Disable gradient computation

        for batch_index, batch in enumerate(data_loader):

            if batch_index > max_batches:
                break

            x, y = batch

            x, y = x.to(device), y.to(device)

            logits = model(x)

            loss = loss_fn(logits, y)

            losses.append(loss.item())

    # Set the model back to training mode
    model.train()

    return sum(losses) / len(losses)

### Setup Data Loaders

In [149]:
class AutoregressiveDataset(torch.utils.data.Dataset):
    def __init__(self, samples: list[str], tokenizer: NaiveTokenizer) -> None:
        self.samples = samples
        self.tokenizer = tokenizer  # Assuming tokenizer is defined in the global scope

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        tokens = self.tokenizer.encode(sample)

        # Convert to tensor and add batch dimension
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)

        return x, y


def make_batch_dataloader(
    samples: list[str],
    tokenizer: NaiveTokenizer,
    batch_size: int,
) -> tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:

    dataset = AutoregressiveDataset(samples=samples, tokenizer=tokenizer)

    # Create DataLoader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # drop the last batch if it's incomplete
        drop_last=True,
    )

    return dataset, dataloader


# Example usage:
# batch_generator = make_batch_generator(tokenizer, batch_size=4)
# for x, y in batch_generator:
#     # x is input, y is target (x shifted by 1)
#     pass

In [150]:
import random

# split into test and train
all_samples = list(generate_sample())

# note: 4394 batches = (26 * 26 * 26) / 4
print(f"{len(all_samples)} samples")

# Randomly shuffle all_samples
random.shuffle(all_samples)  # In-place shuffling of the list

# Inline comment explaining the motivation
# We shuffle the samples to ensure a random distribution of data points
# between the training and test sets, reducing potential bias


# max_samples = 10
# print(f'Capping at {max_samples} batches first to make sure we can overfit')
# all_samples = all_samples[:max_samples]

test_train_ratio = 0.2

test_size = int(test_train_ratio * len(all_samples))

# put remaining ones into train
train_size = len(all_samples) - test_size

train_samples = all_samples[:train_size]
test_samples = all_samples[train_size:]

print(f"{len(train_samples)=}")
print(f"{len(test_samples)=}")

# now we can finally construct dataloaders
batch_size = 128

train_dataset, train_loader = make_batch_dataloader(
    samples=train_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)
# Split test_samples based on the number of '>' characters
test_samples_by_difficulty = {}
for sample in test_samples:
    difficulty = len(sample) - sample.count(">")
    if difficulty not in test_samples_by_difficulty:
        test_samples_by_difficulty[difficulty] = []
    test_samples_by_difficulty[difficulty].append(sample)

# Sort the dictionary by difficulty (number of '>' characters)
test_samples_by_difficulty = dict(
    sorted(test_samples_by_difficulty.items(), reverse=True)
)

# Inline comment explaining the motivation
# We sort the dictionary by difficulty to ensure a consistent order
# when iterating through the difficulty levels, making it easier to
# analyze and compare model performance across increasing complexities

for difficulty, samples in test_samples_by_difficulty.items():
    print(f"{difficulty}: {len(samples)}")

# Create dataloaders for each difficulty level
test_datasets = {}
test_loaders = {}
for difficulty, samples in test_samples_by_difficulty.items():
    test_datasets[difficulty], test_loaders[difficulty] = make_batch_dataloader(
        samples=samples,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )

# Inline comment explaining the motivation
# We split the test samples based on the number of '>' characters to create
# separate datasets for different difficulty levels. This allows us to evaluate
# the model's performance across varying complexities of input sequences.

12356604 samples
len(train_samples)=9885284
len(test_samples)=2471320
12: 2376204
10: 91398
8: 3578
6: 140


### Setup Model

In [151]:
# now we know our vocab size from our sample generation
def make_hooked_transformer_config(
    n_layers: int,
    d_model: int,
    n_heads: int,
) -> transformer_lens.HookedTransformerConfig:

    for sample in generate_sample():
        n_ctx = len(sample)
        break

    cfg = transformer_lens.HookedTransformerConfig(
        n_layers=n_layers,
        d_model=d_model,
        d_head=d_model // n_heads,
        # The number of attention heads.
        # If not specified, will be set to d_in // d_head.
        # (This is represented by a default value of -1)
        n_heads=n_heads,
        # The dimensionality of the feedforward mlp network.
        # Defaults to 4 * d_in, and in an attn-only model is None.
        # TODO(bschoen): Need to try out also setting `attn_only`
        # d_mlp=None,
        # note: transformerlens does the same thing if this is not set
        d_vocab=len(tokenizer.byte_to_token_dict),
        # length of the longest sample is our context length
        n_ctx=n_ctx,
        act_fn="relu",
        normalization_type="LN",
        # note: must be set, otherwise tries to default to cuda / cpu (not mps)
        device=device.type,
    )

    print(f"Num params: {cfg.n_params}")

    return cfg

## Setup Image Logging

In [45]:
# Convert matplotlib figure to PNG for wandb upload
import io
from PIL import Image

import matplotlib.pyplot as plt
import numpy as np
from typing import Dict, Any
from jaxtyping import Float


def fig_to_wandb_image(fig) -> Image:
    """
    Convert a matplotlib figure to a PNG image that can be uploaded to wandb.

    Args:
        fig (matplotlib.figure.Figure): The matplotlib figure to convert

    Returns:
        PIL.Image: The figure as a PIL Image object
    """
    # Save the figure to a byte buffer
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=300, bbox_inches="tight")
    buf.seek(0)

    # Convert the buffer to a PIL Image
    image = Image.open(buf)
    return image


# note: `title` is passed in for telling them apart in gifs etc
def generate_image_for_attention_patterns(
    input_token_str_to_cache_dict: dict[str, transformer_lens.ActivationCache],
    title: str,
) -> Image:
    """
    Visualize attention patterns for all layers and heads in the model for multiple caches.

    Args:
        caches (List[Dict[str, Any]]): List of caches containing attention patterns from model forward passes.

    Returns:
        plt.Figure: A matplotlib figure containing the visualized attention patterns.
    """
    input_token_strings = list(input_token_str_to_cache_dict.keys())
    caches = list(input_token_str_to_cache_dict.values())

    # Find all attention pattern tensors in the first cache (assuming all caches have the same structure)
    pattern_keys = [
        key for key in caches[0].keys() if key.endswith(".attn.hook_pattern")
    ]

    n_layers = len(pattern_keys)
    n_heads = caches[0][pattern_keys[0]].shape[1]
    n_caches = len(caches)

    # Calculate total number of subplots
    total_subplots = n_layers * n_heads

    # Create a figure with subplots stacked vertically for each cache
    fig, axes = plt.subplots(
        n_caches, total_subplots, figsize=(4 * total_subplots, 4 * n_caches)
    )

    # Set overall figure title
    fig.suptitle(title, fontsize=16)

    # Color maps for alternating heads
    cmaps = ["Blues", "Reds"]

    for cache_idx, cache in enumerate(caches):
        input_token_string = input_token_strings[cache_idx]
        for layer, key in enumerate(pattern_keys):
            attention_pattern = cache[key]

            # Remove batch dimension and move to CPU
            reshaped_pattern = attention_pattern.squeeze(0).detach().cpu().numpy()

            for head in range(n_heads):
                subplot_index = layer * n_heads + head
                ax = (
                    axes[cache_idx, subplot_index]
                    if n_caches > 1
                    else axes[subplot_index]
                )

                # Plot the attention pattern
                im = ax.imshow(reshaped_pattern[head], cmap=cmaps[head % len(cmaps)])

                # Set title for each subplot
                ax.set_title(f"L{layer}-H{head}", fontsize=8)

                # Set column labels as individual characters from input_token_string at the top
                ax.xaxis.tick_top()
                ax.set_xticks(range(len(input_token_string)))
                ax.set_xticklabels(list(input_token_string), fontsize=6, ha="right")

                ax.set_yticks([])  # Remove y-axis ticks

    plt.tight_layout()

    image = fig_to_wandb_image(fig)

    # close figure so doesn't keep taking up memory
    plt.close(fig)

    return image

In [126]:
import glob
from PIL import Image
import os
import pathlib


def convert_pngs_in_directory_to_gif(output_dir: pathlib.Path) -> pathlib.Path:

    # Get a list of all PNG files in the output directory
    # Use rglob for recursive search of PNG files
    png_files = list(output_dir.rglob("*.png"))

    # sort by step
    #
    # files have format
    #
    # - `.../<key>_<step>_<hash-identifier-thing>.png`
    # - ex: `.../attention_100_d8bda3455ffb06855d88.png`
    #
    png_files = sorted(png_files, key=lambda x: int(x.name.split("_")[1]))

    # Create a list to store the image frames
    frames = []

    # Load each PNG file and append it to the frames list
    print(f"Generating gif from {len(png_files)} images...")
    for png_file in png_files:
        # Open the image and convert it to RGB mode (required for GIF)
        img = Image.open(str(png_file)).convert("RGB")
        frames.append(img)

    # Define the output GIF filename
    gif_filename = output_dir / "attention_pattern_evolution.gif"

    # Save the frames as an animated GIF
    print(f"Saving gif from {len(frames)} frames to {gif_filename}...")
    frames[0].save(
        gif_filename,
        save_all=True,
        append_images=frames[1:],
        optimize=False,
        duration=200,  # Duration between frames in milliseconds
        loop=0,  # 0 means loop indefinitely
    )

    print(f"GIF created and saved as: {gif_filename}")

    # Optionally, log the GIF to wandb
    # wandb.log({"attention_pattern_evolution": wandb.Image(str(gif_filename))})

    return gif_filename

In [47]:
class WandbConstants:
    ENTITY = "bronsonschoen-personal-use"
    PROJECT = "toy-problem-hooked-transformer-v3"
    NAME = "toy-sequence"
    ATTENTION_PATTERN_IMAGES = "attention"

In [48]:
LossValue = float


@dataclasses.dataclass(frozen=True)
class TrainModelResult:
    model: transformer_lens.HookedTransformer

    # returned because optuna needs it
    # TODO(bschoen): Is this usually val loss?
    train_loss: LossValue

    # useful to retrieve files
    wandb_run_name: str
    wandb_run_id: str

In [55]:
import wandb
import pathlib
import tqdm


def download_images_from_run(result: TrainModelResult) -> pathlib.Path:

    # write things to run specific directory
    output_dir = pathlib.Path(f"wandb_artifacts/{result.wandb_run_id}")

    # create output dir if not exists
    output_dir.mkdir(exist_ok=True, parents=True)

    api = wandb.Api()

    identifier = "/".join(
        [
            WandbConstants.ENTITY,
            WandbConstants.PROJECT,
            result.wandb_run_id,
        ]
    )

    print(f"Downloading {identifier}...")
    run = api.run(identifier)

    # filter down to just attention pattern images
    files = [
        x
        for x in run.files()
        if x.name.startswith(f"media/images/{WandbConstants.ATTENTION_PATTERN_IMAGES}")
    ]

    for file in tqdm.tqdm(desc="Downloading images...", iterable=files):

        print(f"Downloading {file.name}")
        file.download(
            root=str(output_dir),
            replace=False,
            exist_ok=True,
            api=api,
        )

    return output_dir

## Training

In [22]:
# TODO(bschoen): Holdout set of n+1 length

In [152]:
import tqdm

import torch.optim

import wandb

import dataclasses
import json

import time


def print_json(value):
    print(json.dumps(value, indent=2))


# everything customizable via optuna
@dataclasses.dataclass(frozen=True)
class ModelAndTrainingConfig:

    # input
    train_loader: torch.utils.data.DataLoader
    test_loaders: dict[int, torch.utils.data.DataLoader]

    # training
    num_epochs: int = 10000
    eval_test_every_n: int = 500
    wait_between_eval_s: int | None = None

    # model
    n_layers: int = 2
    d_model: int = 16
    n_heads: int = 2

    # optimizers
    betas: tuple[float, float] = (0.9, 0.999)
    learning_rate: float = 1e-3
    max_grad_norm: float = 1.0
    weight_decay: float = 0.1

    def get_hooked_transformer_config(self) -> transformer_lens.HookedTransformerConfig:
        return make_hooked_transformer_config(
            n_layers=self.n_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
        )

    def to_dict(self) -> dict[str, str | int]:
        dict_repr = dataclasses.asdict(self)
        dict_repr.pop("train_loader")
        dict_repr.pop("test_loaders")
        return dict_repr


def train_model(cfg: ModelAndTrainingConfig) -> TrainModelResult:

    # create new model instance
    ht_cfg = cfg.get_hooked_transformer_config()
    model = transformer_lens.HookedTransformer(ht_cfg)

    # setup optimizers
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        betas=cfg.betas,
        weight_decay=cfg.weight_decay,
    )
    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #    optimizer, lambda i: min(i / 100, 1.0)
    # )

    num_epochs = cfg.num_epochs

    # setup wandb
    wandb.init(
        project=WandbConstants.PROJECT,
        name=WandbConstants.NAME,
        config=cfg.to_dict(),
    )

    print(f"Run name {wandb.run.name} - {wandb.run.id}")

    # create a small (fixed) training set of each difficulty to use for visualization
    test_example_per_difficulty = {}
    for difficulty, test_loader in cfg.test_loaders.items():
        # grab something from the test batch
        x, _ = next(iter(test_loader))
        input_tokens = x[0].to(device)
        test_example_per_difficulty[difficulty] = input_tokens

    losses = []

    for epoch, batch in tqdm.tqdm(
        zip(
            range(num_epochs),
            itertools.cycle(train_loader),
        )
    ):

        tokens, target = batch

        tokens, target = tokens.to(device), target.to(device)

        # ex: torch.Size([4, 9, 29])
        logits: Float32[torch.Tensor, "b t c"] = model(tokens)

        # print(f"Logits:\n{logits.shape}")
        loss = loss_fn(logits, target)

        loss.backward()

        if cfg.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

        optimizer.step()

        optimizer.zero_grad()

        # scheduler.step()

        losses.append(loss.item())

        # more frequently than eval, print out train loss
        # if (epoch % (cfg.eval_test_every_n // 10)) == 0:
        #
        #    print(f"Epoch {epoch}, " f"Train loss: {loss.item():.6f}")

        # TODO(bschoen): Shouldn't you actually divide loss by batch size?
        # TODO(bschoen): Do we want like an `is trial` (for example logging last one)
        if (epoch % cfg.eval_test_every_n) == 0:

            # skip evaluating test loss if we just started training
            # if epoch == 0:
            #    continue

            print("Evaluating test loss...")

            # compute loss at each difficulty
            test_loss_by_difficulty = {}

            for difficulty, test_loader in cfg.test_loaders.items():

                test_loss = evaluate_loss_on_test_batches(
                    model,
                    test_loader,
                    max_batches=100,
                )

                test_loss_by_difficulty[difficulty] = test_loss

            wandb_log_dict = {"epoch": epoch, "train_loss": loss.item()}

            for difficulty, test_loss in test_loss_by_difficulty.items():

                wandb_log_dict[f"test_loss_difficulty_{difficulty}"] = test_loss

            # print_json(wandb_log_dict)

            # Log metrics
            wandb.log(wandb_log_dict, step=epoch)

            # Compute attention pattern visualization
            print("Computing attention pattern visualization...")
            model.eval()
            test_example_string_to_cache = {}

            for difficulty, input_tokens in test_example_per_difficulty.items():

                logits, cache = model.run_with_cache(input_tokens)

                # store example by using the actual text string as key
                input_tokens_str = "".join(
                    [tokenizer.decode([x.item()]) for x in input_tokens]
                )

                test_example_string_to_cache[input_tokens_str] = cache

            image = generate_image_for_attention_patterns(
                test_example_string_to_cache,
                title=f"Step: {epoch}",
            )

            wandb.log(
                {WandbConstants.ATTENTION_PATTERN_IMAGES: wandb.Image(image)},
                step=epoch,
            )

            if cfg.wait_between_eval_s and cfg.wait_between_eval_s is not None:
                print(
                    f"Sleeping for {cfg.wait_between_eval_s} to avoid wandb rate limiting"
                )
                time.sleep(cfg.wait_between_eval_s)

    # capture run name and id before `finish`
    wandb_run_name = wandb.run.name
    wandb_run_id = wandb.run.id

    wandb.finish()

    # log locally to sanity check
    # px.line(losses, labels={"x": "Epoch", "y": "Train Loss"})

    print(f"Final train loss: {loss.item():.6f}")

    # take model out of train
    model.eval()

    return TrainModelResult(
        model=model,
        train_loss=loss.item(),
        wandb_run_name=wandb_run_name,
        wandb_run_id=wandb_run_id,
    )


# train brief run to test code
training_config = ModelAndTrainingConfig(
    num_epochs=25000,
    eval_test_every_n=1000,
    weight_decay=0.1,
    wait_between_eval_s=None,
    train_loader=train_loader,
    test_loaders=test_loaders,
)

result = train_model(training_config)

# for compatibility with code later
model = result.model
cfg = training_config.get_hooked_transformer_config()

Num params: 6144


Run name toy-sequence - y07bmj07


0it [00:00, ?it/s]

Evaluating test loss...
Computing attention pattern visualization...


997it [00:29, 46.32it/s]

Evaluating test loss...
Computing attention pattern visualization...


1996it [00:54, 46.93it/s]

Evaluating test loss...
Computing attention pattern visualization...


2996it [01:20, 46.91it/s]

Evaluating test loss...
Computing attention pattern visualization...


3996it [01:46, 43.64it/s]

Evaluating test loss...
Computing attention pattern visualization...


5000it [02:13, 45.99it/s]

Evaluating test loss...
Computing attention pattern visualization...


5996it [02:40, 44.31it/s]

Evaluating test loss...
Computing attention pattern visualization...


6997it [03:08, 44.91it/s]

Evaluating test loss...
Computing attention pattern visualization...


8000it [03:35, 45.31it/s]

Evaluating test loss...
Computing attention pattern visualization...


9000it [04:01, 42.98it/s]

Evaluating test loss...
Computing attention pattern visualization...


9996it [04:29, 48.12it/s]

Evaluating test loss...
Computing attention pattern visualization...


10998it [04:55, 46.59it/s]

Evaluating test loss...
Computing attention pattern visualization...


11998it [05:22, 46.59it/s]

Evaluating test loss...
Computing attention pattern visualization...


12998it [05:48, 47.40it/s]

Evaluating test loss...
Computing attention pattern visualization...


13999it [06:14, 42.04it/s]

Evaluating test loss...
Computing attention pattern visualization...


14998it [06:39, 48.87it/s]

Evaluating test loss...
Computing attention pattern visualization...


15998it [07:04, 48.71it/s]

Evaluating test loss...
Computing attention pattern visualization...


16998it [07:29, 47.84it/s]

Evaluating test loss...
Computing attention pattern visualization...


18000it [07:55, 44.63it/s]

Evaluating test loss...
Computing attention pattern visualization...


19000it [08:21, 48.83it/s]

Evaluating test loss...
Computing attention pattern visualization...


19998it [08:47, 47.64it/s]

Evaluating test loss...
Computing attention pattern visualization...


20997it [09:12, 42.91it/s]

Evaluating test loss...
Computing attention pattern visualization...


21997it [09:37, 49.09it/s]

Evaluating test loss...
Computing attention pattern visualization...


22996it [10:01, 49.04it/s]

Evaluating test loss...
Computing attention pattern visualization...


23997it [10:26, 48.73it/s]

Evaluating test loss...
Computing attention pattern visualization...


25000it [10:51, 38.39it/s]


VBox(children=(Label(value='3.484 MB of 3.484 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▂▃▃▃▄▄▄▅▅▅▅▆▆▆▇▇▇▇██
test_loss_difficulty_10,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss_difficulty_12,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
test_loss_difficulty_6,█▃▂▂▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▁▂
test_loss_difficulty_8,█▄▅▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,24000.0
test_loss_difficulty_10,1.3682
test_loss_difficulty_12,1.36596
test_loss_difficulty_6,2.02947
test_loss_difficulty_8,1.50503
train_loss,1.36459


Final train loss: 1.363202
Num params: 6144


## Save Output Image To Gif

In [153]:
output_dir = download_images_from_run(result=result)

gif_filepath = convert_pngs_in_directory_to_gif(output_dir=output_dir)

print(gif_filepath)

Downloading bronsonschoen-personal-use/toy-problem-hooked-transformer-v3/y07bmj07...


Downloading images...:   0%|          | 0/25 [00:00<?, ?it/s]

Downloading media/images/attention_0_40209ede86ce5e85fce3.png


Downloading images...:   4%|▍         | 1/25 [00:00<00:12,  1.97it/s]

Downloading media/images/attention_10000_4714e8d3cddfe0915df2.png


Downloading images...:   8%|▊         | 2/25 [00:01<00:11,  1.98it/s]

Downloading media/images/attention_1000_adafa1a644a75e7057da.png


Downloading images...:  12%|█▏        | 3/25 [00:01<00:12,  1.79it/s]

Downloading media/images/attention_11000_6be3c17695fba2dbd0d9.png


Downloading images...:  16%|█▌        | 4/25 [00:02<00:11,  1.90it/s]

Downloading media/images/attention_12000_95068dd82d1a846023d1.png


Downloading images...:  20%|██        | 5/25 [00:02<00:10,  1.86it/s]

Downloading media/images/attention_13000_3d71f1c8c3b2d5099a00.png


Downloading images...:  24%|██▍       | 6/25 [00:03<00:10,  1.75it/s]

Downloading media/images/attention_14000_0ebe4843c650e6d36b7f.png


Downloading images...:  28%|██▊       | 7/25 [00:03<00:09,  1.82it/s]

Downloading media/images/attention_15000_d46b6bb3a25fd5fbf67e.png


Downloading images...:  32%|███▏      | 8/25 [00:04<00:09,  1.88it/s]

Downloading media/images/attention_16000_ff1253479b413f2af526.png


Downloading images...:  36%|███▌      | 9/25 [00:04<00:08,  1.79it/s]

Downloading media/images/attention_17000_fcfe2560853bab5a161e.png


Downloading images...:  40%|████      | 10/25 [00:05<00:08,  1.85it/s]

Downloading media/images/attention_18000_d44da1d5c7e98eebe8f3.png


Downloading images...:  44%|████▍     | 11/25 [00:05<00:07,  1.89it/s]

Downloading media/images/attention_19000_96de678a132f8b7bd72e.png


Downloading images...:  48%|████▊     | 12/25 [00:06<00:07,  1.81it/s]

Downloading media/images/attention_20000_d54d523b21a35ec1b62c.png


Downloading images...:  52%|█████▏    | 13/25 [00:07<00:06,  1.73it/s]

Downloading media/images/attention_2000_3a929ffeda1bdad2668b.png


Downloading images...:  56%|█████▌    | 14/25 [00:07<00:06,  1.69it/s]

Downloading media/images/attention_21000_7a26afbf7f404c110b96.png


Downloading images...:  60%|██████    | 15/25 [00:08<00:05,  1.80it/s]

Downloading media/images/attention_22000_ee0143887b61ce0cd969.png


Downloading images...:  64%|██████▍   | 16/25 [00:08<00:05,  1.76it/s]

Downloading media/images/attention_23000_61ae44c2d8457ba06d9d.png


Downloading images...:  68%|██████▊   | 17/25 [00:09<00:04,  1.70it/s]

Downloading media/images/attention_24000_86a9e7ff07f981d43eb5.png


Downloading images...:  72%|███████▏  | 18/25 [00:10<00:04,  1.70it/s]

Downloading media/images/attention_3000_db1cbd6a18c4d88399fe.png


Downloading images...:  76%|███████▌  | 19/25 [00:10<00:03,  1.78it/s]

Downloading media/images/attention_4000_27fbb6f23d068e4f5146.png


Downloading images...:  80%|████████  | 20/25 [00:11<00:02,  1.84it/s]

Downloading media/images/attention_5000_9aa3f5e0519617220e21.png


Downloading images...:  84%|████████▍ | 21/25 [00:11<00:02,  1.78it/s]

Downloading media/images/attention_6000_0fa5db96b47afcfe6965.png


Downloading images...:  88%|████████▊ | 22/25 [00:12<00:01,  1.78it/s]

Downloading media/images/attention_7000_36321db52bcaca1035c7.png


Downloading images...:  92%|█████████▏| 23/25 [00:12<00:01,  1.71it/s]

Downloading media/images/attention_8000_dc1cbd265fb069b873a3.png


Downloading images...:  96%|█████████▌| 24/25 [00:13<00:00,  1.67it/s]

Downloading media/images/attention_9000_615974c240cc6ca293e0.png


Downloading images...: 100%|██████████| 25/25 [00:14<00:00,  1.77it/s]


Generating gif from 25 images...
Saving gif from 25 frames to wandb_artifacts/y07bmj07/attention_pattern_evolution.gif...
GIF created and saved as: wandb_artifacts/y07bmj07/attention_pattern_evolution.gif
wandb_artifacts/y07bmj07/attention_pattern_evolution.gif


## Indirect Object Identification

In [86]:
import einops
import circuitsvis as cv


def add_batch_dimension(
    x: Float32[torch.Tensor, "..."]
) -> Float32[torch.Tensor, "batch ..."]:
    return einops.rearrange(x, "... -> 1 ...")


def tokenize_string(input_string: str) -> Float32[torch.Tensor, "seq"]:

    tokens = tokenizer.encode(input_string)

    return torch.tensor(tokens, dtype=torch.long).to(device)


def tokenize_string_as_batch(input_string: str) -> Float32[torch.Tensor, "batch seq"]:

    return add_batch_dimension(tokenize_string(input_string))


def get_first_mismatched_pair(
    tokens_a: Float32[torch.Tensor, "batch=1 seq"],
    tokens_b: Float32[torch.Tensor, "batch=1 seq"],
) -> Float32[torch.Tensor, "batch=1 2"]:

    assert tokens_a.shape == tokens_b.shape

    for index in range(tokens_a.shape[-1]):

        if tokens_a[0, index] != tokens_b[0, index]:

            mismatch: Float32[torch.Tensor, "2"] = torch.tensor(
                [
                    tokens_a[0, index],
                    tokens_b[0, index],
                ]
            ).to(device)

            return add_batch_dimension(mismatch)


# create a custom to_string function since using our own tokenizer
def token_to_string(token: int) -> str:
    return tokenizer.decode([token])


# TODO(bschoen): Vary along things besides reversal

# take an example, modify the first part of the sequence reversal to be wrong
input_string = "<az|"
correct_string = "<az|a"  # a|az>>>>>>>>>>>>>>>>>>>>>>>>"
incorrect_string = "<az|z"  # "a|az>>>>>>>>>>>>>>>>>>>>>>>>"

input_string_tokens = tokenize_string_as_batch(input_string)
correct_string_tokens = tokenize_string_as_batch(correct_string)
incorrect_string_tokens = tokenize_string_as_batch(incorrect_string)

logits, cache = model.run_with_cache(input_string_tokens)
correct_logits, correct_cache = model.run_with_cache(correct_string_tokens)
incorrect_logits, incorrect_cache = model.run_with_cache(incorrect_string_tokens)

In [87]:
cv.logits.token_log_probs(
    token_indices=input_string_tokens,
    log_probs=logits.log_softmax(dim=-1),
    to_string=token_to_string,
)

In [80]:
# position where we changed the sequence
mismatch_position_index = 4

correct_token = correct_string_tokens[0, mismatch_position_index].item()
incorrect_token = incorrect_string_tokens[0, mismatch_position_index].item()

print(f"correct_token: {correct_token} ({tokenizer.decode([correct_token])})")
print(f"incorrect_token: {incorrect_token} ({tokenizer.decode([incorrect_token])})")

correct_token: 2 (a)
incorrect_token: 27 (z)


### Logit Difference In Accumulated Residual Stream

In [81]:
# get diff in format expected by `model.tokens_to_residual_directions`
answer_tokens = get_first_mismatched_pair(
    correct_string_tokens,
    incorrect_string_tokens,
)

print(f"{answer_tokens.shape=}")

answer_residual_directions: Float32[torch.Tensor, "batch 2 d_model"] = (
    model.tokens_to_residual_directions(answer_tokens)
)

print("Answer residual directions shape:", answer_residual_directions.shape)

correct_residual_directions, incorrect_residual_directions = (
    answer_residual_directions.unbind(dim=1)
)
correct_residual_directions: Float32[torch.Tensor, "batch d_model"] = (
    correct_residual_directions
)
incorrect_residual_directions: Float32[torch.Tensor, "batch d_model"] = (
    incorrect_residual_directions
)

logit_diff_directions: Float32[torch.Tensor, "batch d_model"] = (
    correct_residual_directions - incorrect_residual_directions
)

print(f"Logit difference directions shape:", logit_diff_directions.shape)

answer_tokens.shape=torch.Size([1, 2])
Answer residual directions shape: torch.Size([1, 2, 16])
Logit difference directions shape: torch.Size([1, 16])


In [82]:
def logits_to_ave_logit_diff(
    logits: Float[torch.Tensor, "batch seq d_vocab"],
    answer_tokens: Float[torch.Tensor, "batch 2"],
    per_prompt: bool = False,
) -> Float[torch.Tensor, "*batch"]:
    """
    Returns logit difference between the correct and incorrect answer.

    If per_prompt=True, return the array of differences rather than the average.
    """
    # SOLUTION
    # Only the final logits are relevant for the answer
    final_logits: Float[torch.Tensor, "batch d_vocab"] = logits[:, -1, :]
    # Get the logits corresponding to the indirect object / subject tokens respectively
    answer_logits: Float[torch.Tensor, "batch 2"] = final_logits.gather(
        dim=-1, index=answer_tokens
    )
    # Find logit difference
    correct_logits, incorrect_logits = answer_logits.unbind(dim=-1)
    answer_logit_diff = correct_logits - incorrect_logits
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

In [83]:
original_per_prompt_diff = logits_to_ave_logit_diff(
    logits, answer_tokens, per_prompt=True
)
print("Per prompt logit difference:", original_per_prompt_diff)
original_average_logit_diff = logits_to_ave_logit_diff(logits, answer_tokens)
print("Average logit difference:", original_average_logit_diff)

Per prompt logit difference: tensor([-4.3593], device='mps:0', grad_fn=<SubBackward0>)
Average logit difference: tensor(-4.3593, device='mps:0', grad_fn=<MeanBackward0>)


In [85]:
final_residual_stream = cache["resid_post", -1]  # [batch seq d_model]
print(f"Final residual stream shape: {final_residual_stream.shape}")
final_token_residual_stream = final_residual_stream[:, -1, :]  # [batch d_model]

# Apply LayerNorm scaling (to just the final sequence position)
# pos_slice is the subset of the positions we take - here the final token of each prompt
scaled_final_token_residual_stream = cache.apply_ln_to_stack(
    final_token_residual_stream,
    layer=-1,
    pos_slice=-1,
)

batch_size = input_string_tokens.shape[0]

average_logit_diff = (
    einops.einsum(
        scaled_final_token_residual_stream,
        logit_diff_directions,
        "batch d_model, batch d_model ->",
    )
    / batch_size
)

print(f"Calculated average logit diff: {average_logit_diff:.10f}")
print(f"Original logit difference:     {original_average_logit_diff:.10f}")

Final residual stream shape: torch.Size([1, 4, 16])
Calculated average logit diff: -3.0110464096
Original logit difference:     -4.3593158722


In [91]:
def residual_stack_to_logit_diff(
    residual_stack: Float32[torch.Tensor, "... batch d_model"],
    cache: transformer_lens.ActivationCache,
    logit_diff_directions: Float[torch.Tensor, "batch d_model"],
) -> Float32[torch.Tensor, "..."]:
    """
    Gets the avg logit difference between the correct and incorrect answer for a given
    stack of components in the residual stream.
    """
    # SOLUTION
    batch_size = residual_stack.size(-2)
    scaled_residual_stack = cache.apply_ln_to_stack(
        residual_stack,
        layer=-1,
        pos_slice=-1,
    )
    return (
        einops.einsum(
            scaled_residual_stack,
            logit_diff_directions,
            "... batch d_model, batch d_model -> ...",
        )
        / batch_size
    )

In [92]:
from gpt_from_scratch import plotly_utils

accumulated_residual, labels = cache.accumulated_resid(
    layer=-1,
    incl_mid=True,
    pos_slice=-1,
    return_labels=True,
)
# accumulated_residual has shape (component, batch, d_model)

logit_lens_logit_diffs: Float32[torch.Tensor, "..."] = residual_stack_to_logit_diff(
    accumulated_residual,
    cache,
    logit_diff_directions,
)  # [component]

plotly_utils.line(
    logit_lens_logit_diffs,
    hovermode="x unified",
    title="Logit Difference From Accumulated Residual Stream",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800,
)

### Logit Difference From Each Layer

In [94]:
per_layer_residual, labels = cache.decompose_resid(
    layer=-1, pos_slice=-1, return_labels=True
)
per_layer_logit_diffs = residual_stack_to_logit_diff(
    per_layer_residual,
    cache,
    logit_diff_directions,
)

plotly_utils.line(
    per_layer_logit_diffs,
    hovermode="x unified",
    title="Logit Difference From Each Layer",
    labels={"x": "Layer", "y": "Logit Diff"},
    xaxis_tickvals=labels,
    width=800,
)

### Logit Difference From Each Head

In [96]:
per_head_residual, labels = cache.stack_head_results(
    layer=-1, pos_slice=-1, return_labels=True
)
per_head_residual = einops.rearrange(
    per_head_residual,
    "(layer head) ... -> layer head ...",
    layer=model.cfg.n_layers,
)
per_head_logit_diffs = residual_stack_to_logit_diff(
    per_head_residual,
    cache,
    logit_diff_directions,
)

plotly_utils.imshow(
    per_head_logit_diffs,
    labels={"x": "Head", "y": "Layer"},
    title="Logit Difference From Each Head",
    width=600,
)

### Highest Value Attention Heads

In [124]:
import IPython.core.display
import IPython.display


def topk_of_Nd_tensor(
    tensor: Float[torch.Tensor, "rows cols"],
    k: int,
) -> list[tuple[int, int]]:
    """
    Helper function: does same as tensor.topk(k).indices, but works over 2D tensors.
    Returns a list of indices, i.e. shape [k, tensor.ndim].

    Example: if tensor is 2D array of values for each head in each layer, this will
    return a list of heads.
    """
    i = torch.topk(tensor.flatten(), k).indices
    return np.array(
        np.unravel_index(
            transformer_lens.utils.to_numpy(i),
            tensor.shape,
        )
    ).T.tolist()


k = 3

for head_type in ["Positive", "Negative"]:

    # Get the heads with largest (or smallest) contribution to the logit difference
    top_heads = topk_of_Nd_tensor(
        per_head_logit_diffs.cpu() * (1 if head_type == "Positive" else -1), k
    )

    # ex: [[0, 1], [1, 0], [0, 0]]
    print(top_heads)

    # Get all their attention patterns
    attn_patterns_for_important_heads: Float[torch.Tensor, "head q k"] = torch.stack(
        [cache["pattern", layer][:, head][0] for layer, head in top_heads]
    )

    print(f"{attn_patterns_for_important_heads.shape=}")

    # Display results
    display(
        cv.attention.attention_heads(
            attention=attn_patterns_for_important_heads,
            tokens=[x for x in input_string],
            attention_head_names=[f"{layer}.{head}" for layer, head in top_heads],
        )
    )

    break

[[0, 1], [1, 0], [0, 0]]
attn_patterns_for_important_heads.shape=torch.Size([3, 4, 4])


### Activation Patching

In [125]:
# TODO(bschoen): Clean and corrupted should actually switch first, should do this for search
clean_logit_diff = logits_to_ave_logit_diff(correct_logits, answer_tokens)
print(f"Clean logit diff: {clean_logit_diff:.4f}")

corrupted_logit_diff = logits_to_ave_logit_diff(incorrect_logits, answer_tokens)
print(f"Corrupted logit diff: {corrupted_logit_diff:.4f}")

Clean logit diff: 11.5820
Corrupted logit diff: 7.8760


In [108]:
cv.attention.attention_heads?

[0;31mSignature:[0m
[0mcv[0m[0;34m.[0m[0mattention[0m[0;34m.[0m[0mattention_heads[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mattention[0m[0;34m:[0m [0mUnion[0m[0;34m[[0m[0mlist[0m[0;34m,[0m [0mnumpy[0m[0;34m.[0m[0mndarray[0m[0;34m,[0m [0mtorch[0m[0;34m.[0m[0mTensor[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mtokens[0m[0;34m:[0m [0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mattention_head_names[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mList[0m[0;34m[[0m[0mstr[0m[0;34m][0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmax_value[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmin_value[0m[0;34m:[0m [0mOptional[0m[0;34m[[0m[0mfloat[0m[0;34m][0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnegative_color[0m[0;34m:

## Optuna Study

In [None]:
import optuna


# TODO(bschoen): Do need to use lightning if want to do this generally
# note: generally do want to iterate on this part itself, i.e. once find promising learning rate, searching other hyperparameters
def objective(trial: optuna.Trial) -> float:

    # TODO(bschoen): up to one per position, eh might as well try it

    d_model = trial.suggest_categorical("d_model", [8, 16, 32, 64, 128])
    n_heads = trial.suggest_int("n_heads", 1, 8)

    cfg = ModelAndTrainingConfig(
        num_epochs=1000,
        eval_test_every_n=10000,  # not worth evaluating test loss for study
        n_layers=1,  # trial.suggest_int("n_layers", 1, 2),
        d_model=d_model,
        n_heads=n_heads,
        learning_rate=5e-4,
    )

    # sanity check `d_heads`
    if (cfg.d_model % cfg.n_heads) != 0:
        print(f"Pruning trial for {cfg.d_model=} {cfg.n_heads=}")
        raise optuna.exceptions.TrialPruned()

    result = train_model(cfg)

    return result.train_loss


enable_optuna = False

if enable_optuna:

    study_storage_url = "sqlite:///toy-problem-hooked-transformer.db"

    study = optuna.create_study(
        directions=[optuna.study.StudyDirection.MINIMIZE],
        storage=study_storage_url,
    )

    study.optimize(objective, n_trials=10)

    print("View by launching optuna dashboard from the command line:")
    print(f"optuna-dashboard {study_storage_url}")

    # now let's do a real run
    training_config = ModelAndTrainingConfig(
        num_epochs=10000,
        eval_test_every_n=1000,
        n_layers=1,
        d_model=16,
        n_heads=1,
    )

    result = train_model(cfg=training_config)

    # for compatibility with code later
    model = result.model
    cfg = training_config.get_hooked_transformer_config()

In [18]:
# Look at some example output
import circuitsvis as cv

import functools


def visualize_pattern_hook(
    pattern: Float32[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: transformer_lens.hook_points.HookPoint,
    tokens_as_strings: list[str],
) -> None:
    print(f"Batch size: {pattern.shape[0]}")
    print("Layer: ", hook.layer())
    display(
        cv.attention.attention_patterns(
            tokens=tokens_as_strings, attention=pattern.mean(0)
        )
    )


test_input_string_to_cache = {}

for difficulty, test_loader in test_loaders.items():

    print(difficulty)

    # grab something from the test batch
    example_batch = next(iter(test_loader))

    x, y = example_batch

    example_sample = x[0]

    # example_sample = torch.tensor(tokenizer.encode("<az|za|az>>>>>>>>>>"))

    # grab the first part of it, ex: `<abc|`
    example_prompt = example_sample  # [:8]

    example_prompt = example_prompt.to(device)

    print(f"Using {example_prompt} from {example_sample} (from test set)")

    # note: already encoded
    input_tokens = example_prompt

    # first let's get these as strings so can easily work with them
    input_tokens_as_strings = [token_to_string(x.item()) for x in input_tokens]

    # wrap to bind input tokens
    visualize_pattern_hook_fn = functools.partial(
        visualize_pattern_hook, tokens_as_strings=input_tokens_as_strings
    )

    model.run_with_hooks(
        input_tokens,
        return_type=None,  # For efficiency, we don't need to calculate the logits
        fwd_hooks=[(lambda name: name.endswith("pattern"), visualize_pattern_hook_fn)],
    )

    logits_batch, cache = model.run_with_cache(input_tokens)

    # store so can plot together later
    test_input_string_to_cache["".join(input_tokens_as_strings)] = cache

    logits = logits_batch[0]

    log_probs = logits.log_softmax(dim=-1)

    cv.logits.token_log_probs(
        token_indices=input_tokens,
        log_probs=log_probs,
        to_string=token_to_string,
    )

15
Using tensor([ 0,  2, 14, 22, 21, 28, 21, 22, 14,  2, 28,  2, 14, 22, 21,  1,  1,  1,
         1,  1], device='mps:0') from tensor([ 0,  2, 14, 22, 21, 28, 21, 22, 14,  2, 28,  2, 14, 22, 21,  1,  1,  1,
         1,  1]) (from test set)
Batch size: 1
Layer:  0


Batch size: 1
Layer:  1


12
Using tensor([ 0,  3,  3, 26, 28, 26,  3,  3, 28,  3,  3, 26,  1,  1,  1,  1,  1,  1,
         1,  1], device='mps:0') from tensor([ 0,  3,  3, 26, 28, 26,  3,  3, 28,  3,  3, 26,  1,  1,  1,  1,  1,  1,
         1,  1]) (from test set)
Batch size: 1
Layer:  0


Batch size: 1
Layer:  1


9
Using tensor([ 0,  7,  9, 28,  9,  7, 28,  7,  9,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1], device='mps:0') from tensor([ 0,  7,  9, 28,  9,  7, 28,  7,  9,  1,  1,  1,  1,  1,  1,  1,  1,  1,
         1,  1]) (from test set)
Batch size: 1
Layer:  0


Batch size: 1
Layer:  1


In [19]:
cache.apply_ln_to_stack?

[0;31mSignature:[0m
[0mcache[0m[0;34m.[0m[0mapply_ln_to_stack[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mresidual_stack[0m[0;34m:[0m [0;34m"Float[torch.Tensor, 'num_components *batch_and_pos_dims d_model']"[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mlayer[0m[0;34m:[0m [0;34m'Optional[int]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mmlp_input[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpos_slice[0m[0;34m:[0m [0;34m'Union[Slice, SliceInput]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mbatch_slice[0m[0;34m:[0m [0;34m'Union[Slice, SliceInput]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mhas_batch_dim[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;34m"Float[torch.Tensor, 'num_components *batch_and_pos_dims_out d_model']"[0m[0;3

In [20]:
cache.stack_head_results??

[0;31mSignature:[0m
[0mcache[0m[0;34m.[0m[0mstack_head_results[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mlayer[0m[0;34m:[0m [0;34m'int'[0m [0;34m=[0m [0;34m-[0m[0;36m1[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mreturn_labels[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mincl_remainder[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpos_slice[0m[0;34m:[0m [0;34m'Union[Slice, SliceInput]'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mapply_ln[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;34m"Union[Float[torch.Tensor, 'num_components *batch_and_pos_dims d_model'], Tuple[Float[torch.Tensor, 'num_components *batch_and_pos_dims d_model'], List[str]]]"[0m[0;34m[0m[0;34m[0m[0m
[0;31mSource:[0m   
    [0;32mdef[0m [0mstack_head_re

In [22]:
import transformer_lens.patching

transformer_lens.patching.get_act_patch_resid_pre??

[0;31mSignature:[0m      
[0mtransformer_lens[0m[0;34m.[0m[0mpatching[0m[0;34m.[0m[0mget_act_patch_resid_pre[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mmodel[0m[0;34m:[0m [0;34m'HookedTransformer'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mcorrupted_tokens[0m[0;34m:[0m [0;34m"Int[torch.Tensor, 'batch pos']"[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mclean_cache[0m[0;34m:[0m [0;34m'ActivationCache'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpatching_metric[0m[0;34m:[0m [0;34m"Callable[[Float[torch.Tensor, 'batch pos d_vocab']], Float[torch.Tensor, '']]"[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0;34m*[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mpatch_setter[0m[0;34m:[0m [0;34m'Callable[[CorruptedActivation, Sequence[int], ActivationCache], PatchedActivation]'[0m [0;34m=[0m [0;34m<[0m[0mfunction[0m [0mlayer_pos_patch_setter[0m [0mat[0m [0;36m0x177e2a8e0[0m[0;34m>[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mactivation_name[0m

In [None]:
import einops


def logit_attribution(
    embed: Float32[torch.Tensor, "seq d_model"],
    l1_results: Float32[torch.Tensor, "seq nheads d_model"],
    l2_results: Float32[torch.Tensor, "seq nheads d_model"],
    W_U: Float32[torch.Tensor, "d_model d_vocab"],
    tokens: Int64[torch.Tensor, "seq"],
) -> Float32[torch.Tensor, "seq-1 n_components"]:
    """
    Inputs:
        embed: the embeddings of the tokens (i.e. token + position embeddings)
        l1_results: the outputs of the attention heads at layer 1 (with head as one of the dimensions)
        l2_results: the outputs of the attention heads at layer 2 (with head as one of the dimensions)
        W_U: the unembedding matrix
        tokens: the token ids of the sequence

    Returns:
        Tensor of shape (seq_len-1, n_components)
        represents the concatenation (along dim=-1) of logit attributions from:
            the direct path (seq-1,1)
            layer 0 logits (seq-1, n_heads)
            layer 1 logits (seq-1, n_heads)
        so n_components = 1 + 2*n_heads
    """
    W_U_correct_tokens = W_U[:, tokens[1:]]
    # SOLUTION
    direct_attributions = einops.einsum(
        W_U_correct_tokens, embed[:-1], "emb seq, seq emb -> seq"
    )
    l1_attributions = einops.einsum(
        W_U_correct_tokens, l1_results[:-1], "emb seq, seq nhead emb -> seq nhead"
    )
    l2_attributions = einops.einsum(
        W_U_correct_tokens, l2_results[:-1], "emb seq, seq nhead emb -> seq nhead"
    )
    return torch.concat(
        [direct_attributions.unsqueeze(-1), l1_attributions, l2_attributions], dim=-1
    )


logits, cache = model.run_with_cache(input_tokens, remove_batch_dim=True)
str_tokens = input_tokens_as_strings
tokens = input_tokens

with t.inference_mode():
    embed = cache["embed"]
    l1_results = cache["result", 0]
    l2_results = cache["result", 1]
    logit_attr = logit_attribution(
        embed,
        l1_results,
        l2_results,
        model.W_U,
        tokens[0],
    )

    # Uses fancy indexing to get a len(tokens[0])-1 length tensor, where the kth entry is the predicted logit for the correct k+1th token
    correct_token_logits = logits[0, torch.arange(len(tokens[0]) - 1), tokens[0, 1:]]

## Looking at it with CircuitsViz

In [None]:
# before even going to SAE, let's look at circuitsviz here
import circuitsvis as cv

import circuitsvis.activations
import circuitsvis.attention
import circuitsvis.logits
import circuitsvis.tokens
import circuitsvis.topk_samples
import circuitsvis.topk_tokens

In [None]:
# first let's see what we have
import tabulate

print(f"{len(input_tokens)=}")

# show the first few elements of the `HookedTransformerConfig`, since that has things like `d_model`, num heads, etc
print(tabulate.tabulate([(k, v) for k, v in cfg.__dict__.items()][:10]))

print(tabulate.tabulate([(k, v.shape) for k, v in cache.items()]))

In [None]:
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Viridis256

import numpy as np
import pandas as pd

from bokeh.io import output_notebook

import seaborn as sns

import matplotlib.pyplot as plt

# Enable Bokeh output in the notebook
output_notebook()


def tensor_to_dataframe(
    tensor: torch.Tensor, labels: list[str], tokens: list[str]
) -> pd.DataFrame:
    """
    Convert a 2D PyTorch tensor to a pandas DataFrame.

    Args:
        tensor (torch.Tensor): A 2D tensor to convert.

    Returns:
        pd.DataFrame: A DataFrame representation of the input tensor.

    Raises:
        ValueError: If the input tensor is not 2D.
    """
    if tensor.dim() != 2:
        raise ValueError(f"Input tensor must be 2D, got {tensor.dim()}D")
    if len(labels) != 2:
        raise ValueError(f"Expected labels for both dimensions, got {len(labels)}")

    # Convert tensor to numpy array
    numpy_array = tensor.detach().cpu().numpy()

    # Create DataFrame
    df = pd.DataFrame(numpy_array)

    # Name the index the first label
    df.index.name = labels[0]

    # Name the columns the second label
    df.columns = [f"{labels[1]}_{i}" for i in range(numpy_array.shape[1])]

    return df


def visualize_tensor_heatmap(
    tensor: torch.Tensor,
    title: str = "Tensor Heatmap",
    colormap: list[str] = Viridis256,
    width: int = 800,
    height: int = 400,
) -> None:
    """
    Visualize a 2D tensor as a heatmap.

    Args:
        tensor (torch.Tensor): A 2D tensor to visualize.
        title (str): Title of the heatmap.
        colormap (List[str]): A list of colors to use for the heatmap.
        width (int): Width of the plot in pixels.
        height (int): Height of the plot in pixels.

    """

    # Ensure tensor is 2D
    if tensor.dim() != 2:
        raise ValueError(f"Input tensor must be 2D, got {tensor.shape}")

    # convert tensor to dataframe
    df = tensor_to_dataframe(tensor)

    # Create a 2D grid of coordinates
    y, x = np.mgrid[0 : data.shape[0], 0 : data.shape[1]]

    # Flatten the arrays
    x = x.flatten()
    y = y.flatten()
    z = data.flatten()

    # Create a ColumnDataSource
    source = ColumnDataSource(
        data=dict(
            x=x,
            y=y,
            z=z,
            color=Viridis256[:: int(256 / len(z))][: len(z)],  # Map values to colors
        )
    )

    # Create the figure
    p = figure(
        title="Tensor Heatmap",
        x_range=(0, data.shape[1]),
        y_range=(0, data.shape[0]),
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,reset",
    )

    # Add rectangular glyphs
    p.rect(
        x="x",
        y="y",
        width=1,
        height=1,
        source=source,
        fill_color="color",
        line_color=None,
    )

    # Add hover tool
    hover = HoverTool(tooltips=[("x", "@x"), ("y", "@y"), ("value", "@z{0.000}")])
    p.add_tools(hover)

    # Invert y-axis to match tensor indexing
    p.y_range.start, p.y_range.end = p.y_range.end, p.y_range.start

    # Show the plot
    show(p)

In [None]:
print(tabulate.tabulate([(k, v[0].shape) for k, v in cache.items()]))

In [None]:
# let's go ahead and just use first batch
def first_batch(tensor: Float32[torch.Tensor, "b t c"]) -> Float32[torch.Tensor, "t c"]:
    return tensor[0]

In [None]:
model

In [24]:
import torch.nn as nn

from typing import Iterable, TypeVar

import tabulate

T = TypeVar("T")


# alias for `print(tabulate.tabulate(data))`
def print_table(data: T) -> None:
    print(tabulate.tabulate(data))


# Define a function to print module weights recursively
def print_module_weights(module: nn.Module) -> Iterable[tuple[str, str]]:
    """
    Recursively prints the weights of a PyTorch module and its submodules.

    This function traverses through the module hierarchy, printing information
    about parameters that require gradients and are not hook-related.

    Example:
        >>> print_table(print_module_weights(model))

        ------------------  ----------------------
        embed.W_E           torch.Size([29, 14])
        pos_embed.W_pos     torch.Size([9, 14])
        blocks.0.ln1.w      torch.Size([14])
        blocks.0.ln1.b      torch.Size([14])
        blocks.0.ln2.w      torch.Size([14])
        blocks.0.ln2.b      torch.Size([14])
        blocks.0.attn.W_Q   torch.Size([3, 14, 4])
        blocks.0.attn.W_O   torch.Size([3, 4, 14])
        blocks.0.attn.b_Q   torch.Size([3, 4])
        blocks.0.attn.b_O   torch.Size([14])
        blocks.0.attn.W_K   torch.Size([3, 14, 4])
        blocks.0.attn.W_V   torch.Size([3, 14, 4])
        blocks.0.attn.b_K   torch.Size([3, 4])
        blocks.0.attn.b_V   torch.Size([3, 4])
        blocks.0.mlp.W_in   torch.Size([14, 56])
        blocks.0.mlp.b_in   torch.Size([56])
        blocks.0.mlp.W_out  torch.Size([56, 14])
        blocks.0.mlp.b_out  torch.Size([14])
        ln_final.w          torch.Size([14])
        ln_final.b          torch.Size([14])
        unembed.W_U         torch.Size([14, 29])
        unembed.b_U         torch.Size([29])
        ------------------  ----------------------

    Args:
        module (nn.Module): The PyTorch module to inspect.
        prefix (str, optional): A string prefix for indentation in the output.
                                Defaults to an empty string.

    Returns:
        Iterable[tuple[str, str]]: A list of tuples, where each tuple contains
            the name and shape of the parameter.
    """

    # Iterate through named parameters of the module
    for name, param in module.named_parameters():

        # Check if parameter requires gradient and doesn't start with 'hook_'
        if param.requires_grad and not name.startswith("hook_"):

            # yield parameter name and type
            yield f"{name}", f"{param.shape}"


def print_cache(cache: transformer_lens.ActivationCache) -> None:
    print(tabulate.tabulate([(k, v[0].shape) for k, v in cache.items()]))

In [None]:
print("Weights in the model:")
print_table(print_module_weights(model))

In [25]:
print("Cached activations:")
print_cache(cache)

Cached activations:
------------------------------  -----------------------
hook_embed                      torch.Size([20, 16])
hook_pos_embed                  torch.Size([20, 16])
blocks.0.hook_resid_pre         torch.Size([20, 16])
blocks.0.ln1.hook_scale         torch.Size([20, 1])
blocks.0.ln1.hook_normalized    torch.Size([20, 16])
blocks.0.attn.hook_q            torch.Size([20, 2, 8])
blocks.0.attn.hook_k            torch.Size([20, 2, 8])
blocks.0.attn.hook_v            torch.Size([20, 2, 8])
blocks.0.attn.hook_attn_scores  torch.Size([2, 20, 20])
blocks.0.attn.hook_pattern      torch.Size([2, 20, 20])
blocks.0.attn.hook_z            torch.Size([20, 2, 8])
blocks.0.hook_attn_out          torch.Size([20, 16])
blocks.0.hook_resid_mid         torch.Size([20, 16])
blocks.0.ln2.hook_scale         torch.Size([20, 1])
blocks.0.ln2.hook_normalized    torch.Size([20, 16])
blocks.0.mlp.hook_pre           torch.Size([20, 64])
blocks.0.mlp.hook_post          torch.Size([20, 64])
blocks.0.ho

In [23]:
def plot_cache_activation(
    cache: transformer_lens.ActivationCache,
    cache_key: str,
    input_tokens_as_strings: list[str],
) -> None:

    activations = first_batch(cache[cache_key])

    figsize = (4, 4)

    # make figure smaller for vectors
    if activations.shape[-1] == 1:
        figsize = (4, 1.5)

    # for larger activations like MLP, allow it to be taller
    elif activations.shape[-1] > 20:
        figsize = (4, 12)

    plt.figure(figsize=figsize)

    sns.heatmap(
        activations.cpu().numpy().T,
        cmap="coolwarm",
        center=0,
        xticklabels=input_tokens_as_strings,
    )

    plt.title(cache_key)

    # TODO(bschoen): Allow specifying this
    #
    plt.ylabel("Embedding Dimension")
    plt.xlabel("Token")

    plt.tight_layout()
    plt.show()


for cache_key in [
    "hook_embed",
    "hook_pos_embed",
    "blocks.0.hook_resid_pre",
    "blocks.0.ln1.hook_scale",
    "blocks.0.ln1.hook_normalized",
    "blocks.0.hook_attn_out",
    "blocks.0.hook_resid_mid",
    "blocks.0.ln2.hook_scale",
    "blocks.0.ln2.hook_normalized",
    "blocks.0.mlp.hook_pre",
    "blocks.0.mlp.hook_post",
    "blocks.0.hook_mlp_out",
    "blocks.0.hook_resid_post",
    "ln_final.hook_scale",
    "ln_final.hook_normalized",
]:

    plot_cache_activation(
        cache=cache,
        cache_key=cache_key,
        input_tokens_as_strings=input_tokens_as_strings,
    )

NameError: name 'first_batch' is not defined

In [None]:
# visualize MLP

import matplotlib.pyplot as plt
import seaborn as sns
import torch


def plot_mlp_weights_and_biases(model):
    # Function to plot heatmaps for MLP weights and biases

    def plot_weight_bias_pair(weight, bias, title):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        sns.heatmap(weight.detach().cpu().numpy(), ax=ax1, cmap="coolwarm", center=0)
        ax1.set_title(f"{title} - Weights")
        ax1.set_xlabel("Output dimension")
        ax1.set_ylabel("Input dimension")

        sns.heatmap(
            bias.detach().cpu().numpy().reshape(-1, 1),
            ax=ax2,
            cmap="coolwarm",
            center=0,
        )
        ax2.set_title(f"{title} - Biases")
        ax2.set_xlabel("Bias")
        ax2.set_ylabel("Dimension")

        plt.tight_layout()
        plt.show()

    # MLP weights and biases
    plot_weight_bias_pair(
        model.blocks[0].mlp.W_in, model.blocks[0].mlp.b_in, "MLP Input"
    )
    plot_weight_bias_pair(
        model.blocks[0].mlp.W_out, model.blocks[0].mlp.b_out, "MLP Output"
    )

    # Layer Norm final
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    sns.heatmap(
        model.ln_final.w.detach().cpu().numpy().reshape(1, -1),
        cmap="coolwarm",
        center=1,
    )
    plt.title("Layer Norm Final - Weights")
    plt.subplot(1, 2, 2)
    sns.heatmap(
        model.ln_final.b.detach().cpu().numpy().reshape(1, -1),
        cmap="coolwarm",
        center=0,
    )
    plt.title("Layer Norm Final - Biases")
    plt.tight_layout()
    plt.show()

    # Unembed
    plot_weight_bias_pair(model.unembed.W_U, model.unembed.b_U, "Unembed")


# Call the function
plot_mlp_weights_and_biases(model)

# Comment: Additional visualizations that could be useful:
# 1. Histograms of weight/bias distributions
# 2. 3D surface plots for weights to show patterns
# 3. Network architecture diagram with weight magnitudes represented by line thickness
# 4. Animated heatmaps showing weight changes during training

In [None]:
def plot_weight_bias_activation(
    weight,
    bias,
    activation,
    title: str,
) -> None:

    activation = first_batch(activation)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7))

    sns.heatmap(weight.detach().cpu().numpy().T, ax=ax1, cmap="coolwarm", center=0)
    ax1.set_title(f"{title} - Weight")

    sns.barplot(x=list(range(len(bias))), y=bias.detach().cpu().numpy(), ax=ax2)
    ax2.set_title(f"{title} - Bias")
    ax2.set_xlabel("Index")
    ax2.set_ylabel("Value")

    sns.heatmap(activation.detach().cpu().numpy().T, ax=ax3, cmap="coolwarm", center=0)
    ax3.set_title(f"{title} - Activation")

    plt.tight_layout()
    plt.show()


plot_weight_bias_activation(
    model.embed.W_E,
    torch.zeros(model.embed.W_E.shape[1]),
    cache["hook_embed"],
    "Embedding",
)
plot_weight_bias_activation(
    model.pos_embed.W_pos,
    torch.zeros(model.pos_embed.W_pos.shape[1]),
    cache["hook_pos_embed"],
    "Positional Embedding",
)

In [None]:
# TODO(bschoen): Hook residual pre?

In [None]:
# Plotting LayerNorm components


def plot_layernorm(scale, normalized, title):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    scale = first_batch(scale)
    normalized = first_batch(normalized)

    sns.barplot(
        x=list(range(len(scale))), y=scale.squeeze().detach().cpu().numpy(), ax=ax1
    )
    ax1.set_title(f"{title} - Scale")
    ax1.set_xlabel("Index")
    ax1.set_ylabel("Value")

    sns.heatmap(normalized.detach().cpu().numpy().T, ax=ax2, cmap="coolwarm", center=0)
    ax2.set_title(f"{title} - Normalized")

    plt.tight_layout()
    plt.show()


plot_layernorm(
    cache["blocks.0.ln1.hook_scale"],
    cache["blocks.0.ln1.hook_normalized"],
    "LayerNorm 1",
)

In [None]:
# Plotting MLP components
plot_weight_bias_activation(
    model.blocks[0].mlp.W_in,
    model.blocks[0].mlp.b_in,
    cache["blocks.0.mlp.hook_pre"],
    "MLP Input",
)
plot_weight_bias_activation(
    model.blocks[0].mlp.W_out,
    model.blocks[0].mlp.b_out,
    cache["blocks.0.mlp.hook_post"],
    "MLP Output",
)

#### circuitsvis.activations

In [None]:
# tokens := List of tokens if single sample (e.g. `["A", "person"]`) or list of lists of tokens (e.g. `[[["A", "person"], ["is", "walking"]]]`)
# activations := Activations of the shape [tokens x layers x neurons] if single sample or list of [tokens x layers x neurons] if multiple samples

# take first batch for now
activations = cache["blocks.0.hook_mlp_out"][0]
print(f"{activations.shape=}")

# reshape [tokens x neurons] -> [tokens x 1 x neurons]
#  - `-1` means to automatically infer the size of the last dimension
activations_view = activations.view(len(input_tokens), cfg.n_layers, -1)

print(f"{activations_view.shape=}")

# convert to strings (which this function expects)
input_tokens_as_strings = [token_to_string(x.item()) for x in input_tokens]

# TODO(bschoen): Is there a way to essentially stack these? Claude can probably give the React for that

# so here we can visualize activations for a `torch.Size([1, 8, 16])`, which is most
# of them since this is the size of the embedding dimension
circuitsvis.activations.text_neuron_activations(
    tokens=[token_to_string(x.item()) for x in input_tokens],
    activations=activations_view,
)

#### circuitsvis.attention

In [None]:
# tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length as the list of values.
# attention: Attention head activations of the shape [dest_tokens x src_tokens]
# max_value: Maximum value. Used to determine how dark the token color is when positive (i.e. based on how close it is to the maximum value).
# min_value: Minimum value. Used to determine how dark the token color is when negative (i.e. based on how close it is to the minimum value).
# negative_color: Color for negative values
# positive_color: Color for positive values.
# show_axis_labels: Whether to show axis labels.
# mask_upper_tri: Whether or not to mask the upper triangular portion of the attention patterns. Should be true for causal attention, false for bidirectional attention.


# take first batch
# ex: torch.Size([4, 8, 8]) -> [n_heads, n_ctx, n_ctx]
# note: `blocks.0.attn.hook_attn_scores` is too early (not normalized?)
attention = cache["blocks.0.attn.hook_pattern"][0]

print(f"{attention.shape=}")

circuitsvis.attention.attention_heads(
    tokens=input_tokens_as_strings,
    attention=attention,
    max_value=1,
    min_value=-1,
    negative_color="blue",
    positive_color="red",
    mask_upper_tri=True,
)

#### circuitsvis.logits

In [None]:
# this is the normal one we usually show, i.e.
# cv.logits.token_log_probs(
#     token_indices=input_tokens,
#     log_probs=log_probs,
#     to_string=token_to_string,
# )

#### circuitsvis.tokens

In [None]:
# for example, we'll look at each

# take first batch, ex: torch.Size([8, 16])
pos_embed = cache["hook_pos_embed"][0]

# low level function for coloring tokens according to single value
for i in range(cfg.d_model):
    display(
        circuitsvis.tokens.colored_tokens(
            tokens=input_tokens_as_strings,
            values=pos_embed[:, i],
            negative_color="blue",
            positive_color="red",
        )
    )

    # only display a few for example
    # if i >= 2:
    #    break

In [None]:
# take first batch
# ex: torch.size([8, 16]) = [n_ctx, d_model]
attention_out = cache["blocks.0.hook_attn_out"][0]

circuitsvis.tokens.colored_tokens_multi(
    tokens=input_tokens_as_strings,
    values=attention_out,
    labels=[str(x) for x in range(cfg.d_model)],
)

In [None]:
circuitsvis.tokens.visualize_model_performance(
    tokens=input_tokens,
    str_tokens=input_tokens_as_strings,
    logits=logits,
)

#### circuitsvis.topk_samples

In [None]:
circuitsvis.topk_samples.topk_samples??

#### circuitsvis.topk_tokens

In [None]:
circuitsvis.topk_tokens.topk_tokens??

## SAE

In [None]:
for layer_index in range(cfg.n_layers):
    imshow(
        transformer_lens.utils.to_numpy(cache["attn", layer_index].mean([0, 1])),
        title=f"Layer {layer_index} Attention Pattern",
        height=400,
        width=400,
    )

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import dataclasses

Loss = Float32[torch.Tensor, ""]
MSELoss = Float32[torch.Tensor, ""]
WeightedSparsityLoss = Float32[torch.Tensor, ""]

Logits = Float32[torch.Tensor, "n_ctx d_vocab"]
BatchedLogits = Float32[torch.Tensor, "batch n_ctx d_vocab"]

ModelActivations = Float32[torch.Tensor, "n_ctx d_model"]
BatchedModelActivations = Float32[torch.Tensor, "batch n_ctx d_model"]

FlattenedModelActivations = Float32[torch.Tensor, "d_sae_in"]

BatchedFlattenedModelActivations = Float32[torch.Tensor, "batch d_sae_in"]
BatchedSAEActivations = Float32[torch.Tensor, "batch d_sae_model"]


@dataclasses.dataclass
class SAEOutput:
    sae_activations: BatchedSAEActivations
    reconstructed_model_activations: BatchedFlattenedModelActivations


def sparse_loss_kl_divergence(
    flattened_model_activations: BatchedFlattenedModelActivations,
    sae_output: SAEOutput,
    sparsity_target: float,
    sparsity_weight: float,
    epsilon: float = 1e-7,
) -> tuple[Loss, MSELoss, WeightedSparsityLoss]:

    # same as dense loss (this is constant?)
    mse_loss = F.mse_loss(
        sae_output.reconstructed_model_activations,
        flattened_model_activations,
    )

    # KL divergence for sparsity
    avg_activation = torch.mean(sae_output.sae_activations, dim=0)

    # print(f'[pre-clamping] {avg_activation=}')

    # Add epsilon for numerical stability
    avg_activation = torch.clamp(avg_activation, epsilon, 1 - epsilon)

    kl_div = sparsity_target * torch.log(sparsity_target / avg_activation) + (
        1 - sparsity_target
    ) * torch.log((1 - sparsity_target) / (1 - avg_activation))
    kl_div = torch.sum(kl_div)

    # `sparsity_weight` decides how much we weight `KL-Divergence`
    sparsity_penalty = sparsity_weight * kl_div

    # print(f"{mse_loss=}, {avg_activation=}, {kl_div.item()}, {sparsity_penalty=}")

    return mse_loss + sparsity_penalty, mse_loss, sparsity_penalty

In [None]:
def sparse_loss_l1_norm(
    flattened_model_activations: BatchedFlattenedModelActivations,
    sae_output: SAEOutput,
    sparsity_weight: float,
) -> tuple[Loss, MSELoss, WeightedSparsityLoss]:

    # Reconstruction loss (Mean Squared Error)
    mse_loss = F.mse_loss(
        sae_output.reconstructed_model_activations,
        flattened_model_activations,
    )

    # L1 sparsity penalty
    l1_penalty = torch.mean(torch.abs(sae_output.sae_activations))

    sparsity_penalty = sparsity_weight * l1_penalty

    # Total loss
    total_loss = mse_loss + sparsity_penalty

    return total_loss, mse_loss, sparsity_penalty

In [None]:
import dataclasses


@dataclasses.dataclass
class SparseAutoencoderConfig:
    d_in: int
    d_model: int


# TODO(bschoen): Start using the config pattern, it stays typesafe and allows
#                easy logging to things like wandb
class SparseAutoencoder(nn.Module):
    def __init__(
        self,
        cfg: SparseAutoencoderConfig,
    ) -> None:

        print(f"Creating SparseAutoencoder with {cfg}")

        super(SparseAutoencoder, self).__init__()

        self.d_in = cfg.d_in
        self.d_model = cfg.d_model

        self.encoder = nn.Linear(cfg.d_in, cfg.d_model)
        self.decoder = nn.Linear(cfg.d_model, cfg.d_in)

    def forward(
        self,
        x: BatchedFlattenedModelActivations,
    ) -> SAEOutput:

        # TODO(bschoen): Which activation function should we use?
        encoded = F.gelu(self.encoder(x))

        decoded = self.decoder(encoded)

        return SAEOutput(
            sae_activations=encoded,
            reconstructed_model_activations=decoded,
        )

In [None]:
import lightning.pytorch


@dataclasses.dataclass
class LightningSparseAutoencoderConfig:

    model_config: transformer_lens.HookedTransformerConfig
    sae_config: SparseAutoencoderConfig
    learning_rate: float
    sparsity_weight: float


# note: this kind of lightning adapter is a common pattern: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#starter-example
class LightningSparseAutoencoder(lightning.pytorch.LightningModule):
    def __init__(
        self,
        cfg: LightningSparseAutoencoderConfig,
    ) -> None:

        super(LightningSparseAutoencoder, self).__init__()

        self.model = transformer_lens.HookedTransformer(cfg=cfg.model_config)
        self.sae = SparseAutoencoder(cfg=cfg.sae_config)
        self.cfg = cfg

    def forward(self, inputs, target):
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx: int) -> Loss:
        inputs, target = batch

        self.model
        output = self(inputs, target)
        loss = torch.nn.functional.cr(output, target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

In [None]:
hook_id = "blocks.0.hook_mlp_out"

cache[hook_id].shape

In [None]:
# Training loop
sae_num_epochs = 100000
sae_expansion_factor = 64

learning_rate = 5e-4

# both arbitrary for now
# - Start small: A common approach is to begin with a relatively small sparsity weight,
#                typically in the range of 1e-5 to 1e-3. This allows the model to
#                learn meaningful representations before enforcing strong sparsity
#                constraints.
sparsity_weight: float = 1e-3  # Weight of the sparsity loss in the total loss
sparsity_target: float = 0.05  # Target average activation of hidden neurons

print(f"Training SAE for {hook_id}...")
sae_d_in = (cfg.n_ctx - 1) * cfg.d_model  # -1 since not predicting first token
sae_d_model = sae_d_in * sae_expansion_factor

sae_cfg = SparseAutoencoderConfig(
    d_in=sae_d_in,
    d_model=sae_d_model,
)

sae_model = SparseAutoencoder(cfg=sae_cfg)
sae_model.to(device)

sae_optimizer = optim.Adam(sae_model.parameters(), lr=learning_rate)

wandb.init(
    project="toy-problem-hooked-transformer-sae",
    config={
        "sae_num_epochs": sae_num_epochs,
        "sae_expansion_factor": sae_expansion_factor,
        "learning_rate": learning_rate,
        "sparsity_weight": sparsity_weight,
        "sparsity_target": sparsity_target,
        "sae_d_in": sae_d_in,
        "sae_d_model": sae_d_model,
        "hook_id": hook_id,
    },
)

# put model itself into eval mode so doesn't change
model.eval()

# go through the training data again, this time training the sae on the activations
for epoch, batch in tqdm.tqdm(
    zip(
        range(sae_num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # run through the model (with cache) to get the activations
    logits, cache = model.run_with_cache(tokens)

    # ex: torch.Size([4, 8, 16])
    activations = cache[hook_id]

    # ex: torch.Size([4, 128])
    flattened_activations = activations.reshape(activations.size(0), -1)

    sae_optimizer.zero_grad()

    # now the SAE model is given the *activations*
    sae_output = sae_model.forward(flattened_activations)

    # compute loss

    total_loss, reconstruction_loss, weighted_sparsity_loss = sparse_loss_kl_divergence(
        flattened_activations,
        sae_output,
        sparsity_target=sparsity_target,
        sparsity_weight=sparsity_weight,
    )

    """total_loss, reconstruction_loss, weighted_sparsity_loss = sparse_loss_l1_norm(
        flattened_model_activations=flattened_activations,
        sae_output=sae_output,
        sparsity_weight=sparsity_weight,
    )"""

    total_loss.backward()

    sae_optimizer.step()

    if epoch % 500 == 0:
        print(
            f"Step {epoch}, "
            f"Total Loss: {total_loss.item():.6f}, "
            f"Reconstruction Loss: {reconstruction_loss.item():.6f}, "
            f"Sparsity Loss: {weighted_sparsity_loss.item():.6f}",
        )

        wandb.log(
            {
                "epoch": epoch,
                "total_loss": total_loss.item(),
                "reconstruction_loss": reconstruction_loss.item(),
                "weighted_sparsity_loss": weighted_sparsity_loss.item(),
            }
        )

wandb.finish()

#### Dictionary Learning Implementation

See [simple_dictionary_learning.ipynb](simple_dictionary_learning.ipynb) for a details

#### Extracting the learned dictionary

In [None]:
# Creating SparseAutoencoder with d_in=128, d_model=512, sparsity_target=0.05
dictionary: Float32[torch.Tensor, "sae_hidden sae_in"] = (
    sae_model.encoder.weight.detach()
)

# ex: Dictionary shape: torch.Size([512, 128])
print(f"Dictionary shape: {dictionary.shape}")

In [None]:
# Reshape dictionary elements to match original activation shape
# (essentially `unflatting`)
reshaped_dictionary = dictionary.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

# Motivation: Extract the learned features (dictionary elements) from the encoder weights
# ex: Dictionary shape: torch.Size([512, 8, 16])
print(f"Dictionary shape: {reshaped_dictionary.shape}")

In [None]:
# It's always worth checking this sort of thing when you do this by hand
# to check that you haven't got the wrong site, or are missing a
# scaling factor or something like this.
#
# This is like the overfitting thing

In [None]:
import matplotlib.pyplot as plt

In [None]:
# let's look at an example batch from `test`

# set both to eval mode
model.eval()
sae_model.eval()

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

_, cache = model.run_with_cache(x)

activations = cache[hook_id]

print(f"Activations shape: {activations.shape}")

# flatten it
flattened_activations = activations.reshape(activations.size(0), -1)

print(f"{flattened_activations.shape=}")

sae_outputs = sae_model(flattened_activations)

print(f"{sae_outputs.sae_activations.shape=}")
print(f"{sae_outputs.reconstructed_model_activations.shape=}")

# now we can get the dictionary
dictionary = sae_model.encoder.weight.detach()

print(f"Dictionary shape: {dictionary.shape}")

# now we can get the sparse coefficients
alpha = dictionary @ flattened_activations.T

### Determine Quality Of SAE

In [None]:
def calculate_sparsity(
    sae_activations: BatchedSAEActivations,
    threshold: float = 1e-5,
) -> float:
    """
    Calculate sparsity of SAE activations across a batch.

    Args:
    sae_activations (torch.Tensor): The activations from the Sparse Autoencoder.
                                    Shape: (batch, d_sae_model)
    threshold (float): The threshold below which an activation is considered "inactive".

    Returns:
    float: The average sparsity value across the batch (fraction of inactive neurons).
    """
    # Count the number of neurons that are below the threshold (inactive)
    inactive_neurons = torch.sum(torch.abs(sae_activations) < threshold, dim=1)

    # Calculate the fraction of inactive neurons for each item in the batch
    sparsity_per_item = inactive_neurons.float() / sae_activations.shape[1]

    # Take the mean across the batch
    average_sparsity = torch.mean(sparsity_per_item)

    return average_sparsity.item()

In [None]:
def calculate_explained_variance(
    reconstructed_model_activations: BatchedFlattenedModelActivations,
    flattened_activations: BatchedFlattenedModelActivations,
) -> float:
    """
    Calculate the explained variance of the SAE activations.
    """

    numerator = torch.mean(
        (reconstructed_model_activations[:, 1:] - flattened_activations[:, 1:]) ** 2
    )
    denominator = flattened_activations[:, 1:].to(torch.float32).var()

    explained_variance = 1 - (numerator / denominator)

    return explained_variance.item()

In [None]:
# explained_variance=0.995 -> good, basically all the variance is explained by our SAE
# sparsity=0.0045 -> good, very sparse, and more sparse than our target of 0.05
explained_variance = calculate_explained_variance(
    sae_outputs.reconstructed_model_activations,
    flattened_activations,
)
print(f"{explained_variance=:.4f}")

sparsity = calculate_sparsity(sae_outputs.sae_activations)
print(f"{sparsity=:.4f}")

In [None]:
# Let's analyze the relationship between SAE activations and input features

# TODO(bschoen): Oh `imshow` is huge here!

# 1. Visualize the dictionary (encoder weights)
plt.figure(figsize=(12, 8))
plt.imshow(dictionary.cpu().T, aspect="auto", cmap="RdBu_r")
plt.colorbar()
plt.title("SAE Dictionary (Encoder Weights)")
plt.xlabel("Dictionary Elements")
plt.ylabel("Input Features")
plt.show()

In [None]:
# 2. Find the most active neurons for each input
top_k = 5  # Number of top activations to consider

# so this is essentially the top 5 activations over `batch_size` examples
top_activations = torch.topk(sae_outputs.sae_activations, k=top_k, dim=1)

# Visualization of top activations
plt.figure(figsize=(12, 8))
plt.subplot(2, 1, 1)
sns.heatmap(
    top_activations.values.detach().cpu().numpy(), cmap="viridis", annot=True, fmt=".2f"
)
plt.title("Top 5 Activation Values")
plt.xlabel("Top K")
plt.ylabel("Batch Sample")

plt.subplot(2, 1, 2)
sns.heatmap(
    top_activations.indices.detach().cpu().numpy(), cmap="YlOrRd", annot=True, fmt="d"
)
plt.title("Indices of Top 5 Activations")
plt.xlabel("Top K")
plt.ylabel("Batch Sample")

plt.tight_layout()
plt.show()

# Additional analysis: frequency of top neurons
top_neuron_counts = torch.bincount(
    top_activations.indices.flatten().detach().cpu(),
    minlength=sae_outputs.sae_activations.shape[1],
)
top_10_neurons = torch.topk(top_neuron_counts, k=10)

plt.figure(figsize=(10, 6))
plt.bar(range(10), top_10_neurons.values.detach().cpu().numpy())
plt.title("Top 10 Most Frequently Activated Neurons")
plt.xlabel("Neuron Index")
plt.ylabel("Activation Frequency")
plt.xticks(range(10), top_10_neurons.indices.detach().cpu().numpy())
plt.show()

In [None]:
sae_outputs.sae_activations[:, 1210]

In [None]:
print(f"{sae_outputs.sae_activations.shape=}")
print(f"{top_activations.values.shape=}")
print(f"{top_activations.indices.shape=}")

In [None]:
print(top_activations.indices)

In [None]:
# ex: 51 and 410 show up a lot
sns.heatmap(top_activations.values.cpu().T, cmap="viridis")

In [None]:
# 3. Analyze feature importance for each neuron
feature_importance = torch.abs(dictionary).sum(dim=1)
top_features = torch.topk(feature_importance, k=10)

print(f"{dictionary.shape=}")
print(f"{feature_importance.shape=}")
print(f"{top_features.values.shape=}")
print(f"{top_features.indices.shape=}")

top_features

In [None]:
print("\nTop 10 most important neurons:")
for i, (value, index) in enumerate(
    zip(top_features.values.tolist(), top_features.indices.tolist())
):
    print(f"Neuron {index}:\t{value:.4f}")

In [None]:
top_features.values.tolist()

In [None]:
top_features.indices.tolist()

In [None]:
# 4. Visualize activations for a few examples

# first look at a single batch
sae_activations = sae_outputs.sae_activations[0].detach().cpu()

print(f"{sae_activations.shape=}")

plt.figure(figsize=(15, 5))
plt.subplot(1, 1, 1)

# Look at a single batch
plt.bar(range(sae_activations.shape[0]), sae_activations)

plt.title(f"SAE Activations for Example")
plt.xlabel("Neuron")
plt.ylabel("Activation")
plt.tight_layout()
plt.show()

In [None]:
# 5. Reconstruct input features from SAE activations
#
# Take a single batch first
reconstructed_model_activations = (
    sae_outputs.reconstructed_model_activations.detach().cpu()
)

# 6. Compare original and reconstructed features
num_features = 5

plt.figure(figsize=(15, 3 * num_features))
for i in range(num_features):
    plt.subplot(num_features, 1, i + 1)
    plt.ylim(-1, 1)  # Set y-axis range from -1 to 1
    plt.plot(flattened_activations[:, i].cpu(), label="Original", alpha=0.5)
    plt.plot(reconstructed_model_activations[:, i], label="Reconstructed", alpha=0.5)
    plt.title(f"Feature {i}: Original vs Reconstructed")
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# 7. Correlation between SAE activations and input features
correlation_matrix = torch.corrcoef(
    torch.cat([sae_outputs.sae_activations, flattened_activations], dim=1).T
)
num_neurons = sae_outputs.sae_activations.shape[1]
neuron_feature_correlation = correlation_matrix[:num_neurons, num_neurons:]

plt.figure(figsize=(12, 8))
plt.imshow(
    neuron_feature_correlation.detach().cpu(),
    aspect="auto",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
plt.colorbar()
plt.title("Correlation between SAE Neurons and Input Features")
plt.xlabel("Input Features")
plt.ylabel("SAE Neurons")
plt.show()

In [None]:
sae_outputs.sae_activations

In [None]:
# collect max activations


with torch.no_grad():

    # go through the training data again, but don't cycle, no reason to go through more than once
    for batch in tqdm.tqdm(train_loader):

        tokens, target = batch

        tokens, target = tokens.to(device), target.to(device)

        # run through the model (with cache) to get the activations
        logits, cache = model.run_with_cache(tokens)

        # ex: torch.Size([4, 8, 16])
        activations = cache[hook_id]

        # ex: torch.Size([4, 128])
        flattened_activations = activations.reshape(activations.size(0), -1)

        # now the SAE model is given the *activations*
        encoded, decoded = sae_model(flattened_activations)

        sae_activations = encoded

        # sae_activations.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

        # max_activations = torch.max(encoded, dim=1)

        break

In [None]:
alpha = sae_model.encoder.weight @ flattened_activations[0]

print(f"{alpha.shape=}")

In [None]:
torch.mean(torch.abs(alpha))

In [None]:
sae_activations[0].shape

In [None]:
8 * 16