In [19]:
# Note: the same setup works in an IPython REPL.
# Note: we call `run_line_magic` instead of writing `%` magics directly because
# `black` fails to format cells that contain magics.
#
# Enable autoreload to automatically reload modules when they change

from IPython import get_ipython

# Use `run_line_magic` (rather than `%` magics) so formatters such as `black`
# can parse this cell. `get_ipython()` returns None when no IPython shell is
# running (e.g. the notebook exported and executed as a plain script), so only
# enable autoreload when a shell is actually present.
ipython = get_ipython()
if ipython is not None:
    ipython.run_line_magic("load_ext", "autoreload")
    ipython.run_line_magic("autoreload", "2")

# Import commonly used libraries
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# graphics
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go

# type annotation
import jaxtyping
from jaxtyping import Float, Float32, Int64, jaxtyped
from typeguard import typechecked as typechecker

# more itertools
import more_itertools as mi

# itertools
import itertools
import collections

# tensor manipulation
import einops

# automatically apply jaxtyping
# %load_ext jaxtyping
# %jaxtyping.typechecker typeguard.typechecked

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# Load Model

In [5]:
# transformer_lens is used from here on (device selection, model loading,
# prompt testing, tokenization), so import it at its first use.
import transformer_lens

# Pick the device to run on via transformer_lens (e.g. a CUDA GPU or Apple MPS;
# see `transformer_lens.utils.get_device` for the full selection order).
device = transformer_lens.utils.get_device()

print(f"Using device: {device}")

Using device: mps


In [31]:
# Load GPT-2 small as a HookedTransformer on the selected device, applying
# TransformerLens weight-processing options (see
# `HookedTransformer.from_pretrained` for the meaning of each flag).
model_name = "gpt2-small"
model = transformer_lens.HookedTransformer.from_pretrained(
    model_name,
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    refactor_factored_attn_matrices=True,
)



Loaded pretrained model gpt2-small into HookedTransformer


In [32]:
# Sanity check: on this indirect-object-identification style prompt the model
# should rank " Mary" (the indirect object) as its top next-token prediction.
example_prompt = "After John and Mary went to the store, John gave a bottle of milk to"
example_answer = " Mary"

transformer_lens.utils.test_prompt(
    prompt=example_prompt,
    answer=example_answer,
    model=model,
    prepend_bos=True,
)

Tokenized prompt: ['<|endoftext|>', 'After', ' John', ' and', ' Mary', ' went', ' to', ' the', ' store', ',', ' John', ' gave', ' a', ' bottle', ' of', ' milk', ' to']
Tokenized answer: [' Mary']


Top 0th token. Logit: 18.09 Prob: 70.07% Token: | Mary|
Top 1th token. Logit: 15.38 Prob:  4.67% Token: | the|
Top 2th token. Logit: 15.35 Prob:  4.54% Token: | John|
Top 3th token. Logit: 15.25 Prob:  4.11% Token: | them|
Top 4th token. Logit: 14.84 Prob:  2.73% Token: | his|
Top 5th token. Logit: 14.06 Prob:  1.24% Token: | her|
Top 6th token. Logit: 13.54 Prob:  0.74% Token: | a|
Top 7th token. Logit: 13.52 Prob:  0.73% Token: | their|
Top 8th token. Logit: 13.13 Prob:  0.49% Token: | Jesus|
Top 9th token. Logit: 12.97 Prob:  0.42% Token: | him|


# Define Transcoder Config

In [24]:
import dataclasses

from min_transcoder.transcoder import (
    TranscoderResults,
    TranscoderConfig,
    Transcoder,
)


@dataclasses.dataclass
class TranscoderTrainingConfig:
    """Hyperparameters and hook locations for training a transcoder."""

    # TransformerLens hook name whose activations are the transcoder's input,
    # e.g. "blocks.8.ln2.hook_normalized".
    hook_point: str
    # TransformerLens hook name whose activations the transcoder reconstructs,
    # e.g. "blocks.8.hook_mlp_out".
    out_hook_point: str

    batch_size: int = 32
    num_epochs: int = 5
    learning_rate: float = 1e-3

    # Coefficient for L1 regularization
    l1_coefficient: float = 0.01

    @property
    def hook_point_layer(self) -> int:
        """Layer index parsed from ``hook_point``.

        Hook names have the form ``blocks.<layer>.<...>``, e.g.
        ``'blocks.8.ln2.hook_normalized' -> 8`` and
        ``'blocks.8.hook_resid_mid' -> 8``.

        Raises:
            ValueError: if ``hook_point`` does not start with ``blocks.<int>.``.
        """
        # The layer index is always the component right after "blocks"; indexing
        # from the end breaks for hooks nested inside sub-modules (ln2.*, mlp.*).
        parts = self.hook_point.split(".")
        if len(parts) < 3 or parts[0] != "blocks" or not parts[1].isdigit():
            raise ValueError(
                f"Cannot parse a layer index from hook_point {self.hook_point!r}; "
                "expected 'blocks.<layer>.<hook>'"
            )
        return int(parts[1])


@dataclasses.dataclass
class TranscoderLoss:
    """Scalar loss terms produced by ``compute_loss``.

    Declared as a dataclass so it can be built with keyword arguments
    (``TranscoderLoss(loss=..., mse_loss=..., l1_loss=...)``); a bare class
    with only annotations has no such ``__init__`` and raises ``TypeError``.
    """

    # Total training loss: mse_loss + l1_loss.
    loss: Float[torch.Tensor, ""]
    # Normalized reconstruction (squared-error) term.
    mse_loss: Float[torch.Tensor, ""]
    # Sparsity penalty, already multiplied by the L1 coefficient.
    l1_loss: Float[torch.Tensor, ""]


def compute_loss(
    cfg: TranscoderTrainingConfig,
    mlp_out: Float[torch.Tensor, "..."],
    results: TranscoderResults,
) -> TranscoderLoss:
    """Transcoder training loss: normalized squared error + L1 sparsity penalty.

    Args:
        cfg: training config; only ``cfg.l1_coefficient`` is read.
        mlp_out: target activations the transcoder should reconstruct (the
            original MLP output). Leading dims are arbitrary; the last dim is
            the feature dim.
        results: transcoder forward output. ``results.transcoder_out`` is
            compared elementwise with ``mlp_out``; ``results.hidden_activations``
            is assumed to have the hidden (feature) dim last.

    Returns:
        TranscoderLoss whose ``loss``, ``mse_loss`` and ``l1_loss`` are scalars
        (required for ``loss.backward()``).
    """
    target = mlp_out.float()

    # Squared error per element, divided by the L2 norm of the target vector it
    # belongs to, so tokens with large MLP outputs don't dominate the loss.
    normalized_sq_err = (
        torch.pow(results.transcoder_out - target, 2)
        / (target**2).sum(dim=-1, keepdim=True).sqrt()
    )
    mse_loss = normalized_sq_err.mean()

    # L1 norm over the hidden (feature) dim, averaged over every leading dim
    # (batch and, if present, sequence position). Reducing over the last dim
    # rather than dim=1 keeps this a scalar for 3D (batch, seq, d_hidden) inputs;
    # for 2D (batch, d_hidden) inputs it is identical to summing over dim=1.
    sparsity = results.hidden_activations.abs().sum(dim=-1).mean()

    l1_loss = cfg.l1_coefficient * sparsity

    loss = mse_loss + l1_loss

    return TranscoderLoss(loss=loss, mse_loss=mse_loss, l1_loss=l1_loss)


transcoder_expansion_factor = 4

# Reference config:
# https://github.com/jacobdunefsky/transcoder_circuits/blob/master/train_transcoder.py#L28
transcoder_cfg = TranscoderConfig(
    # Same device and dtype as the base model.
    device=device,
    dtype=model.cfg.dtype,
    # The transcoder maps the MLP's input to the MLP's output; both are
    # d_model wide.
    d_in=model.cfg.d_model,
    d_out=model.cfg.d_model,
    # Hidden (feature) width = the MLP's hidden width (d_mlp) times the
    # expansion factor.
    d_hidden=model.cfg.d_mlp * transcoder_expansion_factor,
)

In [25]:
# Seed torch's global RNG before building the transcoder. This makes its
# parameter initialization (if random) reproducible on re-run. It also fixes the
# training DataLoader's shuffling, which draws from the global generator when no
# `generator` is passed.
SEED = 42
torch.manual_seed(SEED)

transcoder = Transcoder(cfg=transcoder_cfg)

# Load Data

In [12]:
import datasets
import numpy as np

# 10k-document sample of the Pile, fully downloaded (not streamed) so it can be
# tokenized and packed below.
dataset = datasets.load_dataset(
    "NeelNanda/pile-10k",
    split="train",
    streaming=False,
)

Generating train split: 100%|██████████| 10000/10000 [00:00<00:00, 104130.77 examples/s]


In [23]:
# Tokenize every document and pack the tokens into rows of `model.cfg.n_ctx`
# tokens, prepending BOS according to the model's default.
# NOTE(review): `dataset` was loaded with streaming=False above, but
# streaming=True is passed here -- confirm this is intended.
token_dataset = transformer_lens.utils.tokenize_and_concatenate(
    dataset=dataset,
    tokenizer=model.tokenizer,
    max_length=model.cfg.n_ctx,
    add_bos_token=model.cfg.default_prepend_bos,
    streaming=True,
)

Map:   0%|          | 0/10000 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (229134 > 1024). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 10000/10000 [00:16<00:00, 607.32 examples/s]


In [27]:
# Shape is (num_packed_rows, n_ctx); this run printed torch.Size([16957, 1024]).
print(token_dataset["tokens"].size())

torch.Size([16957, 1024])


# Train Transcoder

In [29]:
# Training hyperparameters and hook locations.
training_cfg = TranscoderTrainingConfig(
    # Input activations: the TransformerLens HookPoint after layer 8's pre-MLP
    # LayerNorm ("ln2.hook_normalized"), i.e. exactly what the MLP reads.
    # Alternative: "blocks.8.hook_resid_mid", the input *to* that LayerNorm.
    hook_point="blocks.8.ln2.hook_normalized",
    # Target activations: the original MLP sublayer's output ("hook_mlp_out").
    # The transcoder therefore learns to map the MLP's input to the MLP's
    # output, expressed as a sparse linear combination of feature vectors,
    # which is what lets us interpret the MLP sublayer through it.
    out_hook_point="blocks.8.hook_mlp_out",
    batch_size=32,
    num_epochs=5,
    learning_rate=1e-3,
)

# Batch the packed token rows; each batch is (batch_size, n_ctx), e.g.
# torch.Size([32, 1024]).
dataloader = torch.utils.data.DataLoader(
    dataset=token_dataset["tokens"],
    shuffle=True,
    batch_size=training_cfg.batch_size,
)

In [10]:
import tqdm

# Only the transcoder's parameters are optimized; the base model is run under
# `torch.no_grad()` below and only supplies activations.
optimizer = torch.optim.AdamW(transcoder.parameters(), lr=training_cfg.learning_rate)

# TODO(bschoen): Learning rate scheduler

# Cache just the two activations we need rather than every hook in the model,
# which keeps per-batch activation memory small.
hook_names = [training_cfg.hook_point, training_cfg.out_hook_point]

# Training loop
for epoch in range(training_cfg.num_epochs):
    total_loss = 0.0
    num_batches = 0

    transcoder.train()

    for batch in tqdm.tqdm(
        dataloader, desc=f"Epoch {epoch+1}/{training_cfg.num_epochs}"
    ):
        # Keep decoder rows at unit norm before *every* step, not just once per
        # epoch. Otherwise the optimizer can lower the L1 penalty by shrinking
        # hidden activations while growing decoder norms.
        # TODO(review): the reference SAE training also projects out gradient
        # components parallel to decoder directions after backward().
        with torch.no_grad():
            transcoder.set_decoder_norm_to_unit_norm()

        optimizer.zero_grad()

        batch = batch.to(device)

        # Base-model activations are only inputs/targets; no gradients needed.
        with torch.no_grad():
            _, cache = model.run_with_cache(batch, names_filter=hook_names)

        # Flatten all leading dims (batch, position) into one token axis so the
        # transcoder and the loss operate on (num_tokens, d) activations.
        mlp_in = cache[training_cfg.hook_point]
        mlp_in = mlp_in.reshape(-1, mlp_in.shape[-1])
        mlp_out = cache[training_cfg.out_hook_point]
        mlp_out = mlp_out.reshape(-1, mlp_out.shape[-1])

        # Forward pass through transcoder and loss
        transcoder_results = transcoder(mlp_in)
        loss_result = compute_loss(training_cfg, mlp_out, transcoder_results)

        # Backward pass and optimization
        loss_result.loss.backward()
        optimizer.step()

        total_loss += loss_result.loss.item()
        num_batches += 1

    # Guard against an empty dataloader rather than dividing by zero.
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{training_cfg.num_epochs}, Average Loss: {avg_loss:.4f}")

print("Training completed!")

(25600, 128)


# Compute Loss When Substituting MLP with Transcoder

In [None]:
@torch.no_grad()
def get_test_loss_when_replacing_mlp_with_transcoder(
    self,
    batch_tokens: Int64[torch.Tensor, "batch seq"],
    transcoder: Transcoder,
    model: transformer_lens.HookedTransformer,
) -> Float[torch.Tensor, ""]:
    """Return the model's LM loss on ``batch_tokens`` with one MLP replaced by ``transcoder``.

    The ``mlp`` submodule of block ``hook_point_layer`` is temporarily swapped
    for a module returning ``transcoder(x).transcoder_out``. The original MLP is
    always restored afterwards, even if the forward pass raises.

    Args:
        self: source of the layer to patch. Either an object with a
            ``hook_point_layer`` attribute (e.g. a ``TranscoderTrainingConfig``)
            or an object whose ``.cfg`` has one.
        batch_tokens: token ids, shape (batch, seq).
        transcoder: the transcoder that stands in for the MLP.
        model: the HookedTransformer to evaluate.

    Returns:
        The scalar loss from ``model.run_with_hooks(..., return_type="loss")``.
    """
    # Accept the original `obj.cfg.hook_point_layer` convention as well as a
    # config object passed directly.
    layer_cfg = getattr(self, "cfg", self)
    block = model.blocks[layer_cfg.hook_point_layer]

    class TranscoderWrapper(torch.nn.Module):
        """Adapts a Transcoder to the MLP interface: activations in, activations out."""

        def __init__(self, transcoder: Transcoder):
            super().__init__()
            self.transcoder = transcoder

        def forward(
            self,
            x: Float[torch.Tensor, "... d_in"],
        ) -> Float[torch.Tensor, "... d_out"]:
            # Run the transcoder on flattened (num_tokens, d_in) activations, the
            # same layout used during training, then restore the leading dims.
            flat_out = self.transcoder(x.reshape(-1, x.shape[-1])).transcoder_out
            return flat_out.reshape(*x.shape[:-1], flat_out.shape[-1])

    # Save the MLP submodule itself (not the block) so it can actually be put back.
    original_mlp = block.mlp
    block.mlp = TranscoderWrapper(transcoder)
    try:
        ce_loss_with_recons = model.run_with_hooks(batch_tokens, return_type="loss")
    finally:
        block.mlp = original_mlp

    return ce_loss_with_recons