In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from typing import Iterable

# HookedTransformer

* [TransformerLens - Tutorial - Trains HookedTransformer from Scratch](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb)

```python
import transformers

# note: it's probably easier to just operate on tokens outside of the model,
#       that'll also make it clearer where tokenizer is used
#
# okay wrapping a pretrained tokenizer *can* be done:
# - https://huggingface.co/learn/nlp-course/chapter6/8#building-a-bpe-tokenizer-from-scratch
# - but none of the models support just naive encoding
#   - https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.BPE
class HookedTransformer:
    cfg: HookedTransformerConfig

    # note: actually does an `isinstance` check in the constructor
    tokenizer: transformers.PreTrainedTokenizerBase | None
```

In [3]:
import transformer_lens

from jaxtyping import Int64, Float32

import numpy as np
import plotly.express as px
import plotly.io as pio

import torch
import torch.utils.data

In [4]:
# plotting code copied over from transformer_lens tutorial notebook


def line(tensor: torch.Tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor: torch.Tensor, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

In [5]:
device = transformer_lens.utils.get_device()

print(f"Using device: {device}")

Using device: mps


### Setup Sample Generator

In [144]:
import string
import itertools
import more_itertools


class SpecialToken:
    # note: as assume a BOS token because transformerlens expects it
    BOS = "<"
    # we use a EOS token for convenience
    EOS = ">"


# note: without length, the model doesn't need to learn induction heads, just directly copies


# TODO(bschoen): Allow this to generalize in the future
def generate_sample() -> Iterable[str]:
    """Generate palindrom samples like `<abc|cba|abc>`."""

    # Generate all combinations of lowercase letters
    characters = string.ascii_lowercase

    # note: chosen arbitrarily
    lengths = [4, 5, 6, 7]

    # pad to max length
    max_length = 1 + max(lengths) + 1 + max(lengths) + 1 + max(lengths) + 1

    # set max number to take of each length
    max_combinations_per_length = 10000

    for length in lengths:

        for combination_index, combination in enumerate(
            itertools.product(characters, repeat=length)
        ):

            if combination_index > max_combinations_per_length:
                break

            combination_str = "".join(combination)
            reversed_str = "".join(reversed(combination_str))

            sample = (
                SpecialToken.BOS
                + combination_str
                + "|"
                + reversed_str
                + "|"
                + combination_str
                + SpecialToken.EOS
            )

            # Pad the sample to max_length with EOS tokens
            padded_sample = sample.ljust(max_length, SpecialToken.EOS)

            yield padded_sample  # Return the padded sample


# show a few examples
[x for x in more_itertools.take(10, generate_sample())]

['<aaaa|aaaa|aaaa>>>>>>>>>>',
 '<aaab|baaa|aaab>>>>>>>>>>',
 '<aaac|caaa|aaac>>>>>>>>>>',
 '<aaad|daaa|aaad>>>>>>>>>>',
 '<aaae|eaaa|aaae>>>>>>>>>>',
 '<aaaf|faaa|aaaf>>>>>>>>>>',
 '<aaag|gaaa|aaag>>>>>>>>>>',
 '<aaah|haaa|aaah>>>>>>>>>>',
 '<aaai|iaaa|aaai>>>>>>>>>>',
 '<aaaj|jaaa|aaaj>>>>>>>>>>']

### Setup Tokenizer

In [54]:
from gpt_from_scratch.naive_tokenizer import NaiveTokenizer

vocab = string.ascii_lowercase + "|" + SpecialToken.BOS + SpecialToken.EOS

tokenizer = NaiveTokenizer.from_text(vocab)

In [64]:
from gpt_from_scratch import tokenizer_utils

# test tokenizer
input_text = "<abc|cba|abc><bd|db|bd>>>>"
tokenizer_utils.show_token_mapping(tokenizer, input_text)

Input:		<abc|cba|abc><bd|db|bd>>>>
Tokenized:	[43m[97m<[0m[41m[97ma[0m[42m[97mb[0m[44m[97mc[0m[46m[97m|[0m[45m[97mc[0m[43m[97mb[0m[41m[97ma[0m[42m[97m|[0m[44m[97ma[0m[46m[97mb[0m[45m[97mc[0m[43m[97m>[0m[41m[97m<[0m[42m[97mb[0m[44m[97md[0m[46m[97m|[0m[45m[97md[0m[43m[97mb[0m[41m[97m|[0m[42m[97mb[0m[44m[97md[0m[46m[97m>[0m[45m[97m>[0m[43m[97m>[0m[41m[97m>[0m
Token ID | Token Bytes | Token String
---------+-------------+--------------
       0 | [38;5;2m3C[0m | '<'
          [48;5;1m[38;5;15m<[0mabc|cba|abc><bd|db|bd>>>>
          U+003C LESS-THAN SIGN (1 bytes: [38;5;2m3C[0m)
       2 | [38;5;2m61[0m | 'a'
          <[48;5;1m[38;5;15ma[0mbc|cba|abc><bd|db|bd>>>>
          U+0061 LATIN SMALL LETTER A (1 bytes: [38;5;2m61[0m)
       3 | [38;5;2m62[0m | 'b'
          <a[48;5;1m[38;5;15mb[0mc|cba|abc><bd|db|bd>>>>
          U+0062 LATIN SMALL LETTER B (1 bytes: [38;5;2m62[0m)
       4 | [38;5;

### Setup Loss Function

In [65]:
def loss_fn(logits, target):
    # standard cross entropy loss
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        target.view(-1),
    )

### Evaluate On Test

In [125]:
def evaluate_loss_on_test_batches(
    model: transformer_lens.HookedTransformer,
    data_loader: torch.utils.data.DataLoader,
    max_batches: int,
) -> float:

    # Set the model to evaluation mode
    model.eval()

    losses = []

    with torch.no_grad():  # Disable gradient computation

        for batch_index, batch in enumerate(data_loader):

            # if batch_index > max_batches:
            #    break

            x, y = batch

            x, y = x.to(device), y.to(device)

            logits = model(x)

            loss = loss_fn(logits, y)

            losses.append(loss.item())

    # Set the model back to training mode
    model.train()

    return sum(losses) / len(losses)

### Setup Data Loaders

In [139]:
class AutoregressiveDataset(torch.utils.data.Dataset):
    def __init__(self, samples: list[str], tokenizer: NaiveTokenizer) -> None:
        self.samples = samples
        self.tokenizer = tokenizer  # Assuming tokenizer is defined in the global scope

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        tokens = self.tokenizer.encode(sample)

        # Convert to tensor and add batch dimension
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)

        return x, y


def make_batch_dataloader(
    samples: list[str],
    tokenizer: NaiveTokenizer,
    batch_size: int,
) -> tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:

    dataset = AutoregressiveDataset(samples=samples, tokenizer=tokenizer)

    # Create DataLoader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # drop the last batch if it's incomplete
        drop_last=True,
    )

    return dataset, dataloader


# Example usage:
# batch_generator = make_batch_generator(tokenizer, batch_size=4)
# for x, y in batch_generator:
#     # x is input, y is target (x shifted by 1)
#     pass

In [150]:
import random

# split into test and train
all_samples = list(generate_sample())

# note: 4394 batches = (26 * 26 * 26) / 4
print(f"{len(all_samples)} samples")

# Randomly shuffle all_samples
random.shuffle(all_samples)  # In-place shuffling of the list

# Inline comment explaining the motivation
# We shuffle the samples to ensure a random distribution of data points
# between the training and test sets, reducing potential bias


# max_samples = 10
# print(f'Capping at {max_samples} batches first to make sure we can overfit')
# all_samples = all_samples[:max_samples]

test_train_ratio = 0.1

test_size = int(test_train_ratio * len(all_samples))

# put remaining ones into train
train_size = len(all_samples) - test_size

train_samples = all_samples[:train_size]
test_samples = all_samples[train_size:]

print(f"{len(train_samples)=}")
print(f"{len(test_samples)=}")

# now we can finally construct dataloaders
batch_size = 32

train_dataset, train_loader = make_batch_dataloader(
    samples=train_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)
# Split test_samples based on the number of '>' characters
test_samples_by_difficulty = {}
for sample in test_samples:
    num_gt = sample.count(">")
    if num_gt not in test_samples_by_difficulty:
        test_samples_by_difficulty[num_gt] = []
    test_samples_by_difficulty[num_gt].append(sample)

# Create dataloaders for each difficulty level
test_datasets = {}
test_loaders = {}
for difficulty, samples in test_samples_by_difficulty.items():
    test_datasets[difficulty], test_loaders[difficulty] = make_batch_dataloader(
        samples=samples,
        tokenizer=tokenizer,
        batch_size=batch_size,
    )

# Inline comment explaining the motivation
# We split the test samples based on the number of '>' characters to create
# separate datasets for different difficulty levels. This allows us to evaluate
# the model's performance across varying complexities of input sequences.

40004 samples
len(train_samples)=36004
len(test_samples)=4000


### Setup Optuna

In [None]:
import optuna

In [None]:
# we want an optuna usage pattern that we can easily adapt to our training loop when experimenting

optuna.Trial

### Setup Model

In [146]:
# now we know our vocab size from our sample generation
def make_hooked_transformer_config(
    n_layers: int,
    d_model: int,
    n_heads: int,
) -> transformer_lens.HookedTransformerConfig:

    for sample in generate_sample():
        n_ctx = len(sample)
        break

    cfg = transformer_lens.HookedTransformerConfig(
        n_layers=n_layers,
        d_model=d_model,
        d_head=d_model // n_heads,
        # The number of attention heads.
        # If not specified, will be set to d_in // d_head.
        # (This is represented by a default value of -1)
        n_heads=n_heads,
        # The dimensionality of the feedforward mlp network.
        # Defaults to 4 * d_in, and in an attn-only model is None.
        # TODO(bschoen): Need to try out also setting `attn_only`
        # d_mlp=None,
        # note: transformerlens does the same thing if this is not set
        d_vocab=len(tokenizer.byte_to_token_dict),
        # length of the longest sample is our context length
        n_ctx=n_ctx,
        act_fn="relu",
        normalization_type="LN",
        # note: must be set, otherwise tries to default to cuda / cpu (not mps)
        device=device.type,
    )

    print(f"Num params: {cfg.n_params}")

    return cfg

### Training

In [159]:
import tqdm

import torch.optim

import wandb

import dataclasses
import json

LossValue = float


def print_json(value):
    print(json.dumps(value, indent=2))


# everything customizable via optuna
@dataclasses.dataclass(frozen=True)
class ModelAndTrainingConfig:

    # input
    train_loader: torch.utils.data.DataLoader
    test_loaders: dict[int, torch.utils.data.DataLoader]

    # training
    num_epochs: int = 10000
    eval_test_every_n: int = 500

    # model
    n_layers: int = 2
    d_model: int = 12
    n_heads: int = 4

    # optimizers
    betas: tuple[float, float] = (0.9, 0.999)
    learning_rate: float = 1e-4
    max_grad_norm: float = 1.0
    weight_decay: float = 0.1

    def get_hooked_transformer_config(self) -> transformer_lens.HookedTransformerConfig:
        return make_hooked_transformer_config(
            n_layers=self.n_layers,
            d_model=self.d_model,
            n_heads=self.n_heads,
        )

    def to_dict(self) -> dict[str, str | int]:
        dict_repr = dataclasses.asdict(self)
        dict_repr.pop("train_loader")
        dict_repr.pop("test_loaders")
        return dict_repr


@dataclasses.dataclass(frozen=True)
class TrainModelResult:
    model: transformer_lens.HookedTransformer

    # returned because optuna needs it
    # TODO(bschoen): Is this usually val loss?
    train_loss: LossValue


def train_model(cfg: ModelAndTrainingConfig) -> TrainModelResult:

    # create new model instance
    ht_cfg = cfg.get_hooked_transformer_config()
    model = transformer_lens.HookedTransformer(ht_cfg)

    # setup optimizers
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=cfg.learning_rate,
        betas=cfg.betas,
        weight_decay=cfg.weight_decay,
    )
    # scheduler = torch.optim.lr_scheduler.LambdaLR(
    #    optimizer, lambda i: min(i / 100, 1.0)
    # )

    num_epochs = cfg.num_epochs

    # setup wandb
    wandb.init(project="toy-problem-hooked-transformer-v2", config=cfg.to_dict())

    losses = []

    for epoch, batch in tqdm.tqdm(
        zip(
            range(num_epochs),
            itertools.cycle(train_loader),
        )
    ):

        tokens, target = batch

        tokens, target = tokens.to(device), target.to(device)

        # ex: torch.Size([4, 9, 29])
        logits: Float32[torch.Tensor, "b t c"] = model(tokens)

        # print(f"Logits:\n{logits.shape}")
        loss = loss_fn(logits, target)

        loss.backward()

        if cfg.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)

        optimizer.step()

        optimizer.zero_grad()

        # scheduler.step()

        losses.append(loss.item())

        # more frequently than eval, print out train loss
        # if (epoch % (cfg.eval_test_every_n // 10)) == 0:
        #
        #    print(f"Epoch {epoch}, " f"Train loss: {loss.item():.6f}")

        # TODO(bschoen): Shouldn't you actually divide loss by batch size?
        # TODO(bschoen): Do we want like an `is trial` (for example logging last one)
        if (epoch % cfg.eval_test_every_n) == 0:

            # skip evaluating test loss if we just started training
            if epoch == 0:
                continue

            print("Evaluating test loss...")

            # compute loss at each difficulty
            test_loss_by_difficulty = {}

            for difficulty, test_loader in cfg.test_loaders.items():

                test_loss = evaluate_loss_on_test_batches(
                    model,
                    test_loader,
                    max_batches=100,
                )

                test_loss_by_difficulty[difficulty] = test_loss

            wandb_log_dict = {"epoch": epoch, "train_loss": loss.item()}

            for difficulty, test_loss in test_loss_by_difficulty.items():

                wandb_log_dict[f"train_loss_difficulty_{difficulty}"] = test_loss

            # print_json(wandb_log_dict)

            # Log metrics
            wandb.log(wandb_log_dict)

    wandb.finish()

    # log locally to sanity check
    # px.line(losses, labels={"x": "Epoch", "y": "Train Loss"})

    print(f"Final train loss: {loss.item():.6f}")

    # take model out of train
    model.eval()

    return TrainModelResult(model=model, train_loss=loss.item())


# train brief run to test code
training_config = ModelAndTrainingConfig(
    num_epochs=30000,
    eval_test_every_n=100,
    weight_decay=0.2,
    train_loader=train_loader,
    test_loaders=test_loaders,
)

result = train_model(training_config)

# for compatibility with code later
model = result.model
cfg = training_config.get_hooked_transformer_config()

Num params: 3456


100it [00:02, 47.77it/s]

Evaluating test loss...


200it [00:04, 48.14it/s]

Evaluating test loss...


297it [00:07, 49.34it/s]

Evaluating test loss...


395it [00:10, 50.76it/s]

Evaluating test loss...


465it [00:12, 45.53it/s]

## Optuna Study

In [None]:
# TODO(bschoen): Do need to use lightning if want to do this generally
# note: generally do want to iterate on this part itself, i.e. once find promising learning rate, searching other hyperparameters
def objective(trial: optuna.Trial) -> float:

    # TODO(bschoen): up to one per position, eh might as well try it

    d_model = trial.suggest_categorical("d_model", [8, 16, 32, 64, 128])
    n_heads = trial.suggest_int("n_heads", 1, 8)

    cfg = ModelAndTrainingConfig(
        num_epochs=1000,
        eval_test_every_n=10000,  # not worth evaluating test loss for study
        n_layers=1,  # trial.suggest_int("n_layers", 1, 2),
        d_model=d_model,
        n_heads=n_heads,
        learning_rate=5e-4,
    )

    # sanity check `d_heads`
    if (cfg.d_model % cfg.n_heads) != 0:
        print(f"Pruning trial for {cfg.d_model=} {cfg.n_heads=}")
        raise optuna.exceptions.TrialPruned()

    result = train_model(cfg)

    return result.train_loss


study_storage_url = "sqlite:///toy-problem-hooked-transformer.db"

study = optuna.create_study(
    directions=[optuna.study.StudyDirection.MINIMIZE],
    storage=study_storage_url,
)

study.optimize(objective, n_trials=10)

print("View by launching optuna dashboard from the command line:")
print(f"optuna-dashboard {study_storage_url}")

In [None]:
# now let's do a real run
training_config = ModelAndTrainingConfig(
    num_epochs=10000,
    eval_test_every_n=1000,
    n_layers=1,
    d_model=16,
    n_heads=1,
)

result = train_model(cfg=training_config)

# for compatibility with code later
model = result.model
cfg = training_config.get_hooked_transformer_config()

In [158]:
# Look at some example output
import circuitsvis as cv

import functools


def visualize_pattern_hook(
    pattern: Float32[torch.Tensor, "batch head_index dest_pos source_pos"],
    hook: transformer_lens.hook_points.HookPoint,
    tokens_as_strings: list[str],
) -> None:
    print("Layer: ", hook.layer())
    display(
        cv.attention.attention_patterns(
            tokens=tokens_as_strings, attention=pattern.mean(0)
        )
    )


# create a custom to_string function since using our own tokenizer
def token_to_string(token: int) -> str:
    return tokenizer.decode([token])


for difficulty, test_loader in test_loaders.items():

    print(difficulty)

    # grab something from the test batch
    example_batch = next(iter(test_loader))

    x, y = example_batch

    example_sample = x[0]

    # example_sample = torch.tensor(tokenizer.encode("<az|za|az>>>>>>>>>>"))

    # grab the first part of it, ex: `<abc|`
    example_prompt = example_sample  # [:8]

    example_prompt = example_prompt.to(device)

    print(f"Using {example_prompt} from {example_sample} (from test set)")

    # note: already encoded
    input_tokens = example_prompt

    # first let's get these as strings so can easily work with them
    input_tokens_as_strings = [token_to_string(x.item()) for x in input_tokens]

    # wrap to bind input tokens
    visualize_pattern_hook_fn = functools.partial(
        visualize_pattern_hook, tokens_as_strings=input_tokens_as_strings
    )

    model.run_with_hooks(
        input_tokens,
        return_type=None,  # For efficiency, we don't need to calculate the logits
        fwd_hooks=[(lambda name: name.endswith("pattern"), visualize_pattern_hook_fn)],
    )

    logits_batch, cache = model.run_with_cache(input_tokens)

    logits = logits_batch[0]

    log_probs = logits.log_softmax(dim=-1)

    cv.logits.token_log_probs(
        token_indices=input_tokens,
        log_probs=log_probs,
        to_string=token_to_string,
    )

1
Using tensor([ 0,  2,  2,  2,  2,  9, 15, 25, 28, 25, 15,  9,  2,  2,  2,  2, 28,  2,
         2,  2,  2,  9, 15, 25], device='mps:0') from tensor([ 0,  2,  2,  2,  2,  9, 15, 25, 28, 25, 15,  9,  2,  2,  2,  2, 28,  2,
         2,  2,  2,  9, 15, 25]) (from test set)
Layer:  0


Layer:  1


7
Using tensor([ 0,  2,  2, 13, 21,  7, 28,  7, 21, 13,  2,  2, 28,  2,  2, 13, 21,  7,
         1,  1,  1,  1,  1,  1], device='mps:0') from tensor([ 0,  2,  2, 13, 21,  7, 28,  7, 21, 13,  2,  2, 28,  2,  2, 13, 21,  7,
         1,  1,  1,  1,  1,  1]) (from test set)
Layer:  0


Layer:  1


10
Using tensor([ 0,  2,  5, 11,  6, 28,  6, 11,  5,  2, 28,  2,  5, 11,  6,  1,  1,  1,
         1,  1,  1,  1,  1,  1], device='mps:0') from tensor([ 0,  2,  5, 11,  6, 28,  6, 11,  5,  2, 28,  2,  5, 11,  6,  1,  1,  1,
         1,  1,  1,  1,  1,  1]) (from test set)
Layer:  0


Layer:  1


4
Using tensor([ 0,  2,  2,  2,  5,  9, 15, 28, 15,  9,  5,  2,  2,  2, 28,  2,  2,  2,
         5,  9, 15,  1,  1,  1], device='mps:0') from tensor([ 0,  2,  2,  2,  5,  9, 15, 28, 15,  9,  5,  2,  2,  2, 28,  2,  2,  2,
         5,  9, 15,  1,  1,  1]) (from test set)
Layer:  0


Layer:  1


### Looking at it with CircuitsViz

In [83]:
# before even going to SAE, let's look at circuitsviz here
import circuitsvis as cv

import circuitsvis.activations
import circuitsvis.attention
import circuitsvis.logits
import circuitsvis.tokens
import circuitsvis.topk_samples
import circuitsvis.topk_tokens

In [84]:
# first let's see what we have
import tabulate

print(f"{len(input_tokens)=}")

# show the first few elements of the `HookedTransformerConfig`, since that has things like `d_model`, num heads, etc
print(tabulate.tabulate([(k, v) for k, v in cfg.__dict__.items()][:10]))

print(tabulate.tabulate([(k, v.shape) for k, v in cache.items()]))

len(input_tokens)=18
----------  ------
n_layers    2
d_model     12
n_ctx       19
d_head      3
model_name  custom
n_heads     4
d_mlp       48
act_fn      relu
d_vocab     29
eps         1e-05
----------  ------
------------------------------  --------------------------
hook_embed                      torch.Size([1, 18, 12])
hook_pos_embed                  torch.Size([1, 18, 12])
blocks.0.hook_resid_pre         torch.Size([1, 18, 12])
blocks.0.ln1.hook_scale         torch.Size([1, 18, 1])
blocks.0.ln1.hook_normalized    torch.Size([1, 18, 12])
blocks.0.attn.hook_q            torch.Size([1, 18, 4, 3])
blocks.0.attn.hook_k            torch.Size([1, 18, 4, 3])
blocks.0.attn.hook_v            torch.Size([1, 18, 4, 3])
blocks.0.attn.hook_attn_scores  torch.Size([1, 4, 18, 18])
blocks.0.attn.hook_pattern      torch.Size([1, 4, 18, 18])
blocks.0.attn.hook_z            torch.Size([1, 18, 4, 3])
blocks.0.hook_attn_out          torch.Size([1, 18, 12])
blocks.0.hook_resid_mid         torch.Siz

In [85]:
from bokeh.plotting import figure, show
from bokeh.models import ColumnDataSource, HoverTool
from bokeh.palettes import Viridis256

import numpy as np
import pandas as pd

from bokeh.io import output_notebook

import seaborn as sns

import matplotlib.pyplot as plt

# Enable Bokeh output in the notebook
output_notebook()


def tensor_to_dataframe(
    tensor: torch.Tensor, labels: list[str], tokens: list[str]
) -> pd.DataFrame:
    """
    Convert a 2D PyTorch tensor to a pandas DataFrame.

    Args:
        tensor (torch.Tensor): A 2D tensor to convert.

    Returns:
        pd.DataFrame: A DataFrame representation of the input tensor.

    Raises:
        ValueError: If the input tensor is not 2D.
    """
    if tensor.dim() != 2:
        raise ValueError(f"Input tensor must be 2D, got {tensor.dim()}D")
    if len(labels) != 2:
        raise ValueError(f"Expected labels for both dimensions, got {len(labels)}")

    # Convert tensor to numpy array
    numpy_array = tensor.detach().cpu().numpy()

    # Create DataFrame
    df = pd.DataFrame(numpy_array)

    # Name the index the first label
    df.index.name = labels[0]

    # Name the columns the second label
    df.columns = [f"{labels[1]}_{i}" for i in range(numpy_array.shape[1])]

    return df


def visualize_tensor_heatmap(
    tensor: torch.Tensor,
    title: str = "Tensor Heatmap",
    colormap: list[str] = Viridis256,
    width: int = 800,
    height: int = 400,
) -> None:
    """
    Visualize a 2D tensor as a heatmap.

    Args:
        tensor (torch.Tensor): A 2D tensor to visualize.
        title (str): Title of the heatmap.
        colormap (List[str]): A list of colors to use for the heatmap.
        width (int): Width of the plot in pixels.
        height (int): Height of the plot in pixels.

    """

    # Ensure tensor is 2D
    if tensor.dim() != 2:
        raise ValueError(f"Input tensor must be 2D, got {tensor.shape}")

    # convert tensor to dataframe
    df = tensor_to_dataframe(tensor)

    # Create a 2D grid of coordinates
    y, x = np.mgrid[0 : data.shape[0], 0 : data.shape[1]]

    # Flatten the arrays
    x = x.flatten()
    y = y.flatten()
    z = data.flatten()

    # Create a ColumnDataSource
    source = ColumnDataSource(
        data=dict(
            x=x,
            y=y,
            z=z,
            color=Viridis256[:: int(256 / len(z))][: len(z)],  # Map values to colors
        )
    )

    # Create the figure
    p = figure(
        title="Tensor Heatmap",
        x_range=(0, data.shape[1]),
        y_range=(0, data.shape[0]),
        toolbar_location="below",
        tools="pan,wheel_zoom,box_zoom,reset",
    )

    # Add rectangular glyphs
    p.rect(
        x="x",
        y="y",
        width=1,
        height=1,
        source=source,
        fill_color="color",
        line_color=None,
    )

    # Add hover tool
    hover = HoverTool(tooltips=[("x", "@x"), ("y", "@y"), ("value", "@z{0.000}")])
    p.add_tools(hover)

    # Invert y-axis to match tensor indexing
    p.y_range.start, p.y_range.end = p.y_range.end, p.y_range.start

    # Show the plot
    show(p)

In [44]:
print(tabulate.tabulate([(k, v[0].shape) for k, v in cache.items()]))

------------------------------  -----------------------
hook_embed                      torch.Size([15, 12])
hook_pos_embed                  torch.Size([15, 12])
blocks.0.hook_resid_pre         torch.Size([15, 12])
blocks.0.ln1.hook_scale         torch.Size([15, 1])
blocks.0.ln1.hook_normalized    torch.Size([15, 12])
blocks.0.attn.hook_q            torch.Size([15, 4, 3])
blocks.0.attn.hook_k            torch.Size([15, 4, 3])
blocks.0.attn.hook_v            torch.Size([15, 4, 3])
blocks.0.attn.hook_attn_scores  torch.Size([4, 15, 15])
blocks.0.attn.hook_pattern      torch.Size([4, 15, 15])
blocks.0.attn.hook_z            torch.Size([15, 4, 3])
blocks.0.hook_attn_out          torch.Size([15, 12])
blocks.0.hook_resid_mid         torch.Size([15, 12])
blocks.0.ln2.hook_scale         torch.Size([15, 1])
blocks.0.ln2.hook_normalized    torch.Size([15, 12])
blocks.0.mlp.hook_pre           torch.Size([15, 48])
blocks.0.mlp.hook_post          torch.Size([15, 48])
blocks.0.hook_mlp_out          

In [45]:
# let's go ahead and just use first batch
def first_batch(tensor: Float32[torch.Tensor, "b t c"]) -> Float32[torch.Tensor, "t c"]:
    return tensor[0]

In [None]:
model

In [46]:
import torch.nn as nn

from typing import Iterable, TypeVar

import tabulate

T = TypeVar("T")


# alias for `print(tabulate.tabulate(data))`
def print_table(data: T) -> None:
    print(tabulate.tabulate(data))


# Define a function to print module weights recursively
def print_module_weights(module: nn.Module) -> Iterable[tuple[str, str]]:
    """
    Recursively prints the weights of a PyTorch module and its submodules.

    This function traverses through the module hierarchy, printing information
    about parameters that require gradients and are not hook-related.

    Example:
        >>> print_table(print_module_weights(model))

        ------------------  ----------------------
        embed.W_E           torch.Size([29, 14])
        pos_embed.W_pos     torch.Size([9, 14])
        blocks.0.ln1.w      torch.Size([14])
        blocks.0.ln1.b      torch.Size([14])
        blocks.0.ln2.w      torch.Size([14])
        blocks.0.ln2.b      torch.Size([14])
        blocks.0.attn.W_Q   torch.Size([3, 14, 4])
        blocks.0.attn.W_O   torch.Size([3, 4, 14])
        blocks.0.attn.b_Q   torch.Size([3, 4])
        blocks.0.attn.b_O   torch.Size([14])
        blocks.0.attn.W_K   torch.Size([3, 14, 4])
        blocks.0.attn.W_V   torch.Size([3, 14, 4])
        blocks.0.attn.b_K   torch.Size([3, 4])
        blocks.0.attn.b_V   torch.Size([3, 4])
        blocks.0.mlp.W_in   torch.Size([14, 56])
        blocks.0.mlp.b_in   torch.Size([56])
        blocks.0.mlp.W_out  torch.Size([56, 14])
        blocks.0.mlp.b_out  torch.Size([14])
        ln_final.w          torch.Size([14])
        ln_final.b          torch.Size([14])
        unembed.W_U         torch.Size([14, 29])
        unembed.b_U         torch.Size([29])
        ------------------  ----------------------

    Args:
        module (nn.Module): The PyTorch module to inspect.
        prefix (str, optional): A string prefix for indentation in the output.
                                Defaults to an empty string.

    Returns:
        Iterable[tuple[str, str]]: A list of tuples, where each tuple contains
            the name and shape of the parameter.
    """

    # Iterate through named parameters of the module
    for name, param in module.named_parameters():

        # Check if parameter requires gradient and doesn't start with 'hook_'
        if param.requires_grad and not name.startswith("hook_"):

            # yield parameter name and type
            yield f"{name}", f"{param.shape}"


def print_cache(cache: transformer_lens.ActivationCache) -> None:
    print(tabulate.tabulate([(k, v[0].shape) for k, v in cache.items()]))

In [None]:
print("Weights in the model:")
print_table(print_module_weights(model))

In [None]:
print("Cached activations:")
print_cache(cache)

In [None]:
def plot_cache_activation(
    cache: transformer_lens.ActivationCache,
    cache_key: str,
    input_tokens_as_strings: list[str],
) -> None:

    activations = first_batch(cache[cache_key])

    figsize = (4, 4)

    # make figure smaller for vectors
    if activations.shape[-1] == 1:
        figsize = (4, 1.5)

    # for larger activations like MLP, allow it to be taller
    elif activations.shape[-1] > 20:
        figsize = (4, 12)

    plt.figure(figsize=figsize)

    sns.heatmap(
        activations.cpu().numpy().T,
        cmap="coolwarm",
        center=0,
        xticklabels=input_tokens_as_strings,
    )

    plt.title(cache_key)

    # TODO(bschoen): Allow specifying this
    #
    plt.ylabel("Embedding Dimension")
    plt.xlabel("Token")

    plt.tight_layout()
    plt.show()


for cache_key in [
    "hook_embed",
    "hook_pos_embed",
    "blocks.0.hook_resid_pre",
    "blocks.0.ln1.hook_scale",
    "blocks.0.ln1.hook_normalized",
    "blocks.0.hook_attn_out",
    "blocks.0.hook_resid_mid",
    "blocks.0.ln2.hook_scale",
    "blocks.0.ln2.hook_normalized",
    "blocks.0.mlp.hook_pre",
    "blocks.0.mlp.hook_post",
    "blocks.0.hook_mlp_out",
    "blocks.0.hook_resid_post",
    "ln_final.hook_scale",
    "ln_final.hook_normalized",
]:

    plot_cache_activation(
        cache=cache,
        cache_key=cache_key,
        input_tokens_as_strings=input_tokens_as_strings,
    )

In [None]:
# visualize MLP

import matplotlib.pyplot as plt
import seaborn as sns
import torch


def plot_mlp_weights_and_biases(model):
    # Function to plot heatmaps for MLP weights and biases

    def plot_weight_bias_pair(weight, bias, title):
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

        sns.heatmap(weight.detach().cpu().numpy(), ax=ax1, cmap="coolwarm", center=0)
        ax1.set_title(f"{title} - Weights")
        ax1.set_xlabel("Output dimension")
        ax1.set_ylabel("Input dimension")

        sns.heatmap(
            bias.detach().cpu().numpy().reshape(-1, 1),
            ax=ax2,
            cmap="coolwarm",
            center=0,
        )
        ax2.set_title(f"{title} - Biases")
        ax2.set_xlabel("Bias")
        ax2.set_ylabel("Dimension")

        plt.tight_layout()
        plt.show()

    # MLP weights and biases
    plot_weight_bias_pair(
        model.blocks[0].mlp.W_in, model.blocks[0].mlp.b_in, "MLP Input"
    )
    plot_weight_bias_pair(
        model.blocks[0].mlp.W_out, model.blocks[0].mlp.b_out, "MLP Output"
    )

    # Layer Norm final
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    sns.heatmap(
        model.ln_final.w.detach().cpu().numpy().reshape(1, -1),
        cmap="coolwarm",
        center=1,
    )
    plt.title("Layer Norm Final - Weights")
    plt.subplot(1, 2, 2)
    sns.heatmap(
        model.ln_final.b.detach().cpu().numpy().reshape(1, -1),
        cmap="coolwarm",
        center=0,
    )
    plt.title("Layer Norm Final - Biases")
    plt.tight_layout()
    plt.show()

    # Unembed
    plot_weight_bias_pair(model.unembed.W_U, model.unembed.b_U, "Unembed")


# Call the function
plot_mlp_weights_and_biases(model)

# Comment: Additional visualizations that could be useful:
# 1. Histograms of weight/bias distributions
# 2. 3D surface plots for weights to show patterns
# 3. Network architecture diagram with weight magnitudes represented by line thickness
# 4. Animated heatmaps showing weight changes during training

In [None]:
def plot_weight_bias_activation(
    weight,
    bias,
    activation,
    title: str,
) -> None:

    activation = first_batch(activation)

    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(20, 7))

    sns.heatmap(weight.detach().cpu().numpy().T, ax=ax1, cmap="coolwarm", center=0)
    ax1.set_title(f"{title} - Weight")

    sns.barplot(x=list(range(len(bias))), y=bias.detach().cpu().numpy(), ax=ax2)
    ax2.set_title(f"{title} - Bias")
    ax2.set_xlabel("Index")
    ax2.set_ylabel("Value")

    sns.heatmap(activation.detach().cpu().numpy().T, ax=ax3, cmap="coolwarm", center=0)
    ax3.set_title(f"{title} - Activation")

    plt.tight_layout()
    plt.show()


plot_weight_bias_activation(
    model.embed.W_E,
    torch.zeros(model.embed.W_E.shape[1]),
    cache["hook_embed"],
    "Embedding",
)
plot_weight_bias_activation(
    model.pos_embed.W_pos,
    torch.zeros(model.pos_embed.W_pos.shape[1]),
    cache["hook_pos_embed"],
    "Positional Embedding",
)

In [None]:
# TODO(bschoen): Hook residual pre?

In [None]:
# Plotting LayerNorm components


def plot_layernorm(scale, normalized, title):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    scale = first_batch(scale)
    normalized = first_batch(normalized)

    sns.barplot(
        x=list(range(len(scale))), y=scale.squeeze().detach().cpu().numpy(), ax=ax1
    )
    ax1.set_title(f"{title} - Scale")
    ax1.set_xlabel("Index")
    ax1.set_ylabel("Value")

    sns.heatmap(normalized.detach().cpu().numpy().T, ax=ax2, cmap="coolwarm", center=0)
    ax2.set_title(f"{title} - Normalized")

    plt.tight_layout()
    plt.show()


plot_layernorm(
    cache["blocks.0.ln1.hook_scale"],
    cache["blocks.0.ln1.hook_normalized"],
    "LayerNorm 1",
)

In [None]:
# Plotting MLP components
plot_weight_bias_activation(
    model.blocks[0].mlp.W_in,
    model.blocks[0].mlp.b_in,
    cache["blocks.0.mlp.hook_pre"],
    "MLP Input",
)
plot_weight_bias_activation(
    model.blocks[0].mlp.W_out,
    model.blocks[0].mlp.b_out,
    cache["blocks.0.mlp.hook_post"],
    "MLP Output",
)

Layer:  0


Layer:  1


#### circuitsvis.activations

In [47]:
# tokens := List of tokens if single sample (e.g. `["A", "person"]`) or list of lists of tokens (e.g. `[[["A", "person"], ["is", "walking"]]]`)
# activations := Activations of the shape [tokens x layers x neurons] if single sample or list of [tokens x layers x neurons] if multiple samples

# take first batch for now
activations = cache["blocks.0.hook_mlp_out"][0]
print(f"{activations.shape=}")

# reshape [tokens x neurons] -> [tokens x 1 x neurons]
#  - `-1` means to automatically infer the size of the last dimension
activations_view = activations.view(len(input_tokens), cfg.n_layers, -1)

print(f"{activations_view.shape=}")

# convert to strings (which this function expects)
input_tokens_as_strings = [token_to_string(x.item()) for x in input_tokens]

# TODO(bschoen): Is there a way to essentially stack these? Claude can probably give the React for that

# so here we can visualize activations for a `torch.Size([1, 8, 16])`, which is most
# of them since this is the size of the embedding dimension
circuitsvis.activations.text_neuron_activations(
    tokens=[token_to_string(x.item()) for x in input_tokens],
    activations=activations_view,
)

activations.shape=torch.Size([15, 12])
activations_view.shape=torch.Size([15, 2, 6])


#### circuitsvis.attention

In [48]:
# tokens: List of tokens (e.g. `["A", "person"]`). Must be the same length as the list of values.
# attention: Attention head activations of the shape [dest_tokens x src_tokens]
# max_value: Maximum value. Used to determine how dark the token color is when positive (i.e. based on how close it is to the maximum value).
# min_value: Minimum value. Used to determine how dark the token color is when negative (i.e. based on how close it is to the minimum value).
# negative_color: Color for negative values
# positive_color: Color for positive values.
# show_axis_labels: Whether to show axis labels.
# mask_upper_tri: Whether or not to mask the upper triangular portion of the attention patterns. Should be true for causal attention, false for bidirectional attention.


# take first batch
# ex: torch.Size([4, 8, 8]) -> [n_heads, n_ctx, n_ctx]
# note: `blocks.0.attn.hook_attn_scores` is too early (not normalized?)
attention = cache["blocks.0.attn.hook_pattern"][0]

print(f"{attention.shape=}")

circuitsvis.attention.attention_heads(
    tokens=input_tokens_as_strings,
    attention=attention,
    max_value=1,
    min_value=-1,
    negative_color="blue",
    positive_color="red",
    mask_upper_tri=True,
)

attention.shape=torch.Size([4, 15, 15])


#### circuitsvis.logits

In [None]:
# this is the normal one we usually show, i.e.
# cv.logits.token_log_probs(
#     token_indices=input_tokens,
#     log_probs=log_probs,
#     to_string=token_to_string,
# )

#### circuitsvis.tokens

In [49]:
# for example, we'll look at each

# take first batch, ex: torch.Size([8, 16])
pos_embed = cache["hook_pos_embed"][0]

# low level function for coloring tokens according to single value
for i in range(cfg.d_model):
    display(
        circuitsvis.tokens.colored_tokens(
            tokens=input_tokens_as_strings,
            values=pos_embed[:, i],
            negative_color="blue",
            positive_color="red",
        )
    )

    # only display a few for example
    # if i >= 2:
    #    break

In [None]:
# take first batch
# ex: torch.size([8, 16]) = [n_ctx, d_model]
attention_out = cache["blocks.0.hook_attn_out"][0]

circuitsvis.tokens.colored_tokens_multi(
    tokens=input_tokens_as_strings,
    values=attention_out,
    labels=[str(x) for x in range(cfg.d_model)],
)

In [None]:
circuitsvis.tokens.visualize_model_performance(
    tokens=input_tokens,
    str_tokens=input_tokens_as_strings,
    logits=logits,
)

#### circuitsvis.topk_samples

In [None]:
circuitsvis.topk_samples.topk_samples??

#### circuitsvis.topk_tokens

In [None]:
circuitsvis.topk_tokens.topk_tokens??

## SAE

In [None]:
for layer_index in range(cfg.n_layers):
    imshow(
        transformer_lens.utils.to_numpy(cache["attn", layer_index].mean([0, 1])),
        title=f"Layer {layer_index} Attention Pattern",
        height=400,
        width=400,
    )

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

import dataclasses

Loss = Float32[torch.Tensor, ""]
MSELoss = Float32[torch.Tensor, ""]
WeightedSparsityLoss = Float32[torch.Tensor, ""]

Logits = Float32[torch.Tensor, "n_ctx d_vocab"]
BatchedLogits = Float32[torch.Tensor, "batch n_ctx d_vocab"]

ModelActivations = Float32[torch.Tensor, "n_ctx d_model"]
BatchedModelActivations = Float32[torch.Tensor, "batch n_ctx d_model"]

FlattenedModelActivations = Float32[torch.Tensor, "d_sae_in"]

BatchedFlattenedModelActivations = Float32[torch.Tensor, "batch d_sae_in"]
BatchedSAEActivations = Float32[torch.Tensor, "batch d_sae_model"]


@dataclasses.dataclass
class SAEOutput:
    sae_activations: BatchedSAEActivations
    reconstructed_model_activations: BatchedFlattenedModelActivations


def sparse_loss_kl_divergence(
    flattened_model_activations: BatchedFlattenedModelActivations,
    sae_output: SAEOutput,
    sparsity_target: float,
    sparsity_weight: float,
    epsilon: float = 1e-7,
) -> tuple[Loss, MSELoss, WeightedSparsityLoss]:

    # same as dense loss (this is constant?)
    mse_loss = F.mse_loss(
        sae_output.reconstructed_model_activations,
        flattened_model_activations,
    )

    # KL divergence for sparsity
    avg_activation = torch.mean(sae_output.sae_activations, dim=0)

    # print(f'[pre-clamping] {avg_activation=}')

    # Add epsilon for numerical stability
    avg_activation = torch.clamp(avg_activation, epsilon, 1 - epsilon)

    kl_div = sparsity_target * torch.log(sparsity_target / avg_activation) + (
        1 - sparsity_target
    ) * torch.log((1 - sparsity_target) / (1 - avg_activation))
    kl_div = torch.sum(kl_div)

    # `sparsity_weight` decides how much we weight `KL-Divergence`
    sparsity_penalty = sparsity_weight * kl_div

    # print(f"{mse_loss=}, {avg_activation=}, {kl_div.item()}, {sparsity_penalty=}")

    return mse_loss + sparsity_penalty, mse_loss, sparsity_penalty

In [None]:
def sparse_loss_l1_norm(
    flattened_model_activations: BatchedFlattenedModelActivations,
    sae_output: SAEOutput,
    sparsity_weight: float,
) -> tuple[Loss, MSELoss, WeightedSparsityLoss]:

    # Reconstruction loss (Mean Squared Error)
    mse_loss = F.mse_loss(
        sae_output.reconstructed_model_activations,
        flattened_model_activations,
    )

    # L1 sparsity penalty
    l1_penalty = torch.mean(torch.abs(sae_output.sae_activations))

    sparsity_penalty = sparsity_weight * l1_penalty

    # Total loss
    total_loss = mse_loss + sparsity_penalty

    return total_loss, mse_loss, sparsity_penalty

In [None]:
import dataclasses


@dataclasses.dataclass
class SparseAutoencoderConfig:
    d_in: int
    d_model: int


# TODO(bschoen): Start using the config pattern, it stays typesafe and allows
#                easy logging to things like wandb
class SparseAutoencoder(nn.Module):
    def __init__(
        self,
        cfg: SparseAutoencoderConfig,
    ) -> None:

        print(f"Creating SparseAutoencoder with {cfg}")

        super(SparseAutoencoder, self).__init__()

        self.d_in = cfg.d_in
        self.d_model = cfg.d_model

        self.encoder = nn.Linear(cfg.d_in, cfg.d_model)
        self.decoder = nn.Linear(cfg.d_model, cfg.d_in)

    def forward(
        self,
        x: BatchedFlattenedModelActivations,
    ) -> SAEOutput:

        # TODO(bschoen): Which activation function should we use?
        encoded = F.gelu(self.encoder(x))

        decoded = self.decoder(encoded)

        return SAEOutput(
            sae_activations=encoded,
            reconstructed_model_activations=decoded,
        )

In [None]:
import lightning.pytorch


@dataclasses.dataclass
class LightningSparseAutoencoderConfig:

    model_config: transformer_lens.HookedTransformerConfig
    sae_config: SparseAutoencoderConfig
    learning_rate: float
    sparsity_weight: float


# note: this kind of lightning adapter is a common pattern: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#starter-example
class LightningSparseAutoencoder(lightning.pytorch.LightningModule):
    def __init__(
        self,
        cfg: LightningSparseAutoencoderConfig,
    ) -> None:

        super(LightningSparseAutoencoder, self).__init__()

        self.model = transformer_lens.HookedTransformer(cfg=cfg.model_config)
        self.sae = SparseAutoencoder(cfg=cfg.sae_config)
        self.cfg = cfg

    def forward(self, inputs, target):
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx: int) -> Loss:
        inputs, target = batch

        self.model
        output = self(inputs, target)
        loss = torch.nn.functional.cr(output, target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

In [None]:
hook_id = "blocks.0.hook_mlp_out"

cache[hook_id].shape

In [None]:
# Training loop
sae_num_epochs = 100000
sae_expansion_factor = 64

learning_rate = 5e-4

# both arbitrary for now
# - Start small: A common approach is to begin with a relatively small sparsity weight,
#                typically in the range of 1e-5 to 1e-3. This allows the model to
#                learn meaningful representations before enforcing strong sparsity
#                constraints.
sparsity_weight: float = 1e-3  # Weight of the sparsity loss in the total loss
sparsity_target: float = 0.05  # Target average activation of hidden neurons

print(f"Training SAE for {hook_id}...")
sae_d_in = (cfg.n_ctx - 1) * cfg.d_model  # -1 since not predicting first token
sae_d_model = sae_d_in * sae_expansion_factor

sae_cfg = SparseAutoencoderConfig(
    d_in=sae_d_in,
    d_model=sae_d_model,
)

sae_model = SparseAutoencoder(cfg=sae_cfg)
sae_model.to(device)

sae_optimizer = optim.Adam(sae_model.parameters(), lr=learning_rate)

wandb.init(
    project="toy-problem-hooked-transformer-sae",
    config={
        "sae_num_epochs": sae_num_epochs,
        "sae_expansion_factor": sae_expansion_factor,
        "learning_rate": learning_rate,
        "sparsity_weight": sparsity_weight,
        "sparsity_target": sparsity_target,
        "sae_d_in": sae_d_in,
        "sae_d_model": sae_d_model,
        "hook_id": hook_id,
    },
)

# put model itself into eval mode so doesn't change
model.eval()

# go through the training data again, this time training the sae on the activations
for epoch, batch in tqdm.tqdm(
    zip(
        range(sae_num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # run through the model (with cache) to get the activations
    logits, cache = model.run_with_cache(tokens)

    # ex: torch.Size([4, 8, 16])
    activations = cache[hook_id]

    # ex: torch.Size([4, 128])
    flattened_activations = activations.reshape(activations.size(0), -1)

    sae_optimizer.zero_grad()

    # now the SAE model is given the *activations*
    sae_output = sae_model.forward(flattened_activations)

    # compute loss

    total_loss, reconstruction_loss, weighted_sparsity_loss = sparse_loss_kl_divergence(
        flattened_activations,
        sae_output,
        sparsity_target=sparsity_target,
        sparsity_weight=sparsity_weight,
    )

    """total_loss, reconstruction_loss, weighted_sparsity_loss = sparse_loss_l1_norm(
        flattened_model_activations=flattened_activations,
        sae_output=sae_output,
        sparsity_weight=sparsity_weight,
    )"""

    total_loss.backward()

    sae_optimizer.step()

    if epoch % 500 == 0:
        print(
            f"Step {epoch}, "
            f"Total Loss: {total_loss.item():.6f}, "
            f"Reconstruction Loss: {reconstruction_loss.item():.6f}, "
            f"Sparsity Loss: {weighted_sparsity_loss.item():.6f}",
        )

        wandb.log(
            {
                "epoch": epoch,
                "total_loss": total_loss.item(),
                "reconstruction_loss": reconstruction_loss.item(),
                "weighted_sparsity_loss": weighted_sparsity_loss.item(),
            }
        )

wandb.finish()

#### Dictionary Learning Implementation

See [simple_dictionary_learning.ipynb](simple_dictionary_learning.ipynb) for a details

#### Extracting the learned dictionary

In [None]:
# Creating SparseAutoencoder with d_in=128, d_model=512, sparsity_target=0.05
dictionary: Float32[torch.Tensor, "sae_hidden sae_in"] = (
    sae_model.encoder.weight.detach()
)

# ex: Dictionary shape: torch.Size([512, 128])
print(f"Dictionary shape: {dictionary.shape}")

In [None]:
# Reshape dictionary elements to match original activation shape
# (essentially `unflatting`)
reshaped_dictionary = dictionary.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

# Motivation: Extract the learned features (dictionary elements) from the encoder weights
# ex: Dictionary shape: torch.Size([512, 8, 16])
print(f"Dictionary shape: {reshaped_dictionary.shape}")

In [None]:
# It's always worth checking this sort of thing when you do this by hand
# to check that you haven't got the wrong site, or are missing a
# scaling factor or something like this.
#
# This is like the overfitting thing

In [None]:
import matplotlib.pyplot as plt

In [None]:
# let's look at an example batch from `test`

# set both to eval mode
model.eval()
sae_model.eval()

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

_, cache = model.run_with_cache(x)

activations = cache[hook_id]

print(f"Activations shape: {activations.shape}")

# flatten it
flattened_activations = activations.reshape(activations.size(0), -1)

print(f"{flattened_activations.shape=}")

sae_outputs = sae_model(flattened_activations)

print(f"{sae_outputs.sae_activations.shape=}")
print(f"{sae_outputs.reconstructed_model_activations.shape=}")

# now we can get the dictionary
dictionary = sae_model.encoder.weight.detach()

print(f"Dictionary shape: {dictionary.shape}")

# now we can get the sparse coefficients
alpha = dictionary @ flattened_activations.T

### Determine Quality Of SAE

In [None]:
def calculate_sparsity(
    sae_activations: BatchedSAEActivations,
    threshold: float = 1e-5,
) -> float:
    """
    Calculate sparsity of SAE activations across a batch.

    Args:
    sae_activations (torch.Tensor): The activations from the Sparse Autoencoder.
                                    Shape: (batch, d_sae_model)
    threshold (float): The threshold below which an activation is considered "inactive".

    Returns:
    float: The average sparsity value across the batch (fraction of inactive neurons).
    """
    # Count the number of neurons that are below the threshold (inactive)
    inactive_neurons = torch.sum(torch.abs(sae_activations) < threshold, dim=1)

    # Calculate the fraction of inactive neurons for each item in the batch
    sparsity_per_item = inactive_neurons.float() / sae_activations.shape[1]

    # Take the mean across the batch
    average_sparsity = torch.mean(sparsity_per_item)

    return average_sparsity.item()

In [None]:
def calculate_explained_variance(
    reconstructed_model_activations: BatchedFlattenedModelActivations,
    flattened_activations: BatchedFlattenedModelActivations,
) -> float:
    """
    Calculate the explained variance of the SAE activations.
    """

    numerator = torch.mean(
        (reconstructed_model_activations[:, 1:] - flattened_activations[:, 1:]) ** 2
    )
    denominator = flattened_activations[:, 1:].to(torch.float32).var()

    explained_variance = 1 - (numerator / denominator)

    return explained_variance.item()

In [None]:
# explained_variance=0.995 -> good, basically all the variance is explained by our SAE
# sparsity=0.0045 -> good, very sparse, and more sparse than our target of 0.05
explained_variance = calculate_explained_variance(
    sae_outputs.reconstructed_model_activations,
    flattened_activations,
)
print(f"{explained_variance=:.4f}")

sparsity = calculate_sparsity(sae_outputs.sae_activations)
print(f"{sparsity=:.4f}")

In [None]:
# Let's analyze the relationship between SAE activations and input features

# TODO(bschoen): Oh `imshow` is huge here!

# 1. Visualize the dictionary (encoder weights)
plt.figure(figsize=(12, 8))
plt.imshow(dictionary.cpu().T, aspect="auto", cmap="RdBu_r")
plt.colorbar()
plt.title("SAE Dictionary (Encoder Weights)")
plt.xlabel("Dictionary Elements")
plt.ylabel("Input Features")
plt.show()

In [None]:
# 2. Find the most active neurons for each input
top_k = 5  # Number of top activations to consider

# so this is essentially the top 5 activations over `batch_size` examples
top_activations = torch.topk(sae_outputs.sae_activations, k=top_k, dim=1)

# Visualization of top activations
plt.figure(figsize=(12, 8))
plt.subplot(2, 1, 1)
sns.heatmap(
    top_activations.values.detach().cpu().numpy(), cmap="viridis", annot=True, fmt=".2f"
)
plt.title("Top 5 Activation Values")
plt.xlabel("Top K")
plt.ylabel("Batch Sample")

plt.subplot(2, 1, 2)
sns.heatmap(
    top_activations.indices.detach().cpu().numpy(), cmap="YlOrRd", annot=True, fmt="d"
)
plt.title("Indices of Top 5 Activations")
plt.xlabel("Top K")
plt.ylabel("Batch Sample")

plt.tight_layout()
plt.show()

# Additional analysis: frequency of top neurons
top_neuron_counts = torch.bincount(
    top_activations.indices.flatten().detach().cpu(),
    minlength=sae_outputs.sae_activations.shape[1],
)
top_10_neurons = torch.topk(top_neuron_counts, k=10)

plt.figure(figsize=(10, 6))
plt.bar(range(10), top_10_neurons.values.detach().cpu().numpy())
plt.title("Top 10 Most Frequently Activated Neurons")
plt.xlabel("Neuron Index")
plt.ylabel("Activation Frequency")
plt.xticks(range(10), top_10_neurons.indices.detach().cpu().numpy())
plt.show()

In [None]:
sae_outputs.sae_activations[:, 1210]

In [None]:
print(f"{sae_outputs.sae_activations.shape=}")
print(f"{top_activations.values.shape=}")
print(f"{top_activations.indices.shape=}")

In [None]:
print(top_activations.indices)

In [None]:
# ex: 51 and 410 show up a lot
sns.heatmap(top_activations.values.cpu().T, cmap="viridis")

In [None]:
# 3. Analyze feature importance for each neuron
feature_importance = torch.abs(dictionary).sum(dim=1)
top_features = torch.topk(feature_importance, k=10)

print(f"{dictionary.shape=}")
print(f"{feature_importance.shape=}")
print(f"{top_features.values.shape=}")
print(f"{top_features.indices.shape=}")

top_features

In [None]:
print("\nTop 10 most important neurons:")
for i, (value, index) in enumerate(
    zip(top_features.values.tolist(), top_features.indices.tolist())
):
    print(f"Neuron {index}:\t{value:.4f}")

In [None]:
top_features.values.tolist()

In [None]:
top_features.indices.tolist()

In [None]:
# 4. Visualize activations for a few examples

# first look at a single batch
sae_activations = sae_outputs.sae_activations[0].detach().cpu()

print(f"{sae_activations.shape=}")

plt.figure(figsize=(15, 5))
plt.subplot(1, 1, 1)

# Look at a single batch
plt.bar(range(sae_activations.shape[0]), sae_activations)

plt.title(f"SAE Activations for Example")
plt.xlabel("Neuron")
plt.ylabel("Activation")
plt.tight_layout()
plt.show()

In [None]:
# 5. Reconstruct input features from SAE activations
#
# Take a single batch first
reconstructed_model_activations = (
    sae_outputs.reconstructed_model_activations.detach().cpu()
)

# 6. Compare original and reconstructed features
num_features = 5

plt.figure(figsize=(15, 3 * num_features))
for i in range(num_features):
    plt.subplot(num_features, 1, i + 1)
    plt.ylim(-1, 1)  # Set y-axis range from -1 to 1
    plt.plot(flattened_activations[:, i].cpu(), label="Original", alpha=0.5)
    plt.plot(reconstructed_model_activations[:, i], label="Reconstructed", alpha=0.5)
    plt.title(f"Feature {i}: Original vs Reconstructed")
    plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# 7. Correlation between SAE activations and input features
correlation_matrix = torch.corrcoef(
    torch.cat([sae_outputs.sae_activations, flattened_activations], dim=1).T
)
num_neurons = sae_outputs.sae_activations.shape[1]
neuron_feature_correlation = correlation_matrix[:num_neurons, num_neurons:]

plt.figure(figsize=(12, 8))
plt.imshow(
    neuron_feature_correlation.detach().cpu(),
    aspect="auto",
    cmap="RdBu_r",
    vmin=-1,
    vmax=1,
)
plt.colorbar()
plt.title("Correlation between SAE Neurons and Input Features")
plt.xlabel("Input Features")
plt.ylabel("SAE Neurons")
plt.show()

In [None]:
sae_outputs.sae_activations

In [None]:
# collect max activations


with torch.no_grad():

    # go through the training data again, but don't cycle, no reason to go through more than once
    for batch in tqdm.tqdm(train_loader):

        tokens, target = batch

        tokens, target = tokens.to(device), target.to(device)

        # run through the model (with cache) to get the activations
        logits, cache = model.run_with_cache(tokens)

        # ex: torch.Size([4, 8, 16])
        activations = cache[hook_id]

        # ex: torch.Size([4, 128])
        flattened_activations = activations.reshape(activations.size(0), -1)

        # now the SAE model is given the *activations*
        encoded, decoded = sae_model(flattened_activations)

        sae_activations = encoded

        # sae_activations.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

        # max_activations = torch.max(encoded, dim=1)

        break

In [None]:
alpha = sae_model.encoder.weight @ flattened_activations[0]

print(f"{alpha.shape=}")

In [None]:
torch.mean(torch.abs(alpha))

In [None]:
sae_activations[0].shape

In [None]:
8 * 16