In [1]:
%load_ext autoreload
%autoreload 2


In [2]:
from typing import Iterable

# HookedTransformer

* [TransformerLens - Tutorial - Trains HookedTransformer from Scratch](https://colab.research.google.com/github/TransformerLensOrg/TransformerLens/blob/main/demos/No_Position_Experiment.ipynb)

```python
import transformers

# note: it's probably easier to just operate on tokens outside of the model,
#       that'll also make it clearer where tokenizer is used
#
# okay wrapping a pretrained tokenizer *can* be done:
# - https://huggingface.co/learn/nlp-course/chapter6/8#building-a-bpe-tokenizer-from-scratch
# - but none of the models support just naive encoding
#   - https://huggingface.co/docs/tokenizers/api/models#tokenizers.models.BPE
class HookedTransformer:
    cfg: HookedTransformerConfig

    # note: actually does an `isinstance` check in the constructor
    tokenizer: transformers.PreTrainedTokenizerBase | None
```

In [3]:
import transformer_lens

from jaxtyping import Int64, Float32

import numpy as np
import plotly.express as px
import plotly.io as pio

import torch
import torch.utils.data

In [31]:
# plotting code copied over from transformer_lens tutorial notebook

def line(tensor: torch.Tensor, line_labels=None, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    labels = {"y": yaxis, "x": xaxis}
    fig = px.line(tensor, labels=labels, **kwargs)
    if line_labels:
        for c, label in enumerate(line_labels):
            fig.data[c].name = label
    fig.show()


def imshow(tensor: torch.Tensor, yaxis="", xaxis="", **kwargs):
    tensor = transformer_lens.utils.to_numpy(tensor)
    plot_kwargs = {
        "color_continuous_scale": "RdBu",
        "color_continuous_midpoint": 0.0,
        "labels": {"x": xaxis, "y": yaxis},
    }
    plot_kwargs.update(kwargs)
    px.imshow(tensor, **plot_kwargs).show()

In [4]:
device = transformer_lens.utils.get_device()

print(f'Using device: {device}')

Using device: mps


### Setup Sample Generator

In [5]:
import string
import itertools
import more_itertools

class SpecialToken:
    # note: as assume a BOS token because transformerlens expects it
    BOS = '<'
    # we use a EOS token for convenience
    EOS = '>'

# TODO(bschoen): Allow this to generalize in the future
def generate_sample() -> Iterable[str]:
    """Generate palindrom samples like `<abc|cba>`."""

    # Generate all combinations of lowercase letters
    characters = string.ascii_lowercase

    # note: chosen arbitrarily
    length = 3
    
    for combination in itertools.product(characters, repeat=length):

        combination_str = ''.join(combination)
        reversed_str = ''.join(reversed(combination_str))

        yield SpecialToken.BOS + combination_str + '|' + reversed_str + SpecialToken.EOS

# show a few examples
[x for x in more_itertools.take(10, generate_sample())]

['<aaa|aaa>',
 '<aab|baa>',
 '<aac|caa>',
 '<aad|daa>',
 '<aae|eaa>',
 '<aaf|faa>',
 '<aag|gaa>',
 '<aah|haa>',
 '<aai|iaa>',
 '<aaj|jaa>']

### Setup Tokenizer

In [6]:
from gpt_from_scratch.naive_tokenizer import NaiveTokenizer

vocab = string.ascii_lowercase + '|' + SpecialToken.BOS + SpecialToken.EOS

tokenizer = NaiveTokenizer.from_text(vocab)

In [7]:
from gpt_from_scratch import tokenizer_utils

# test tokenizer
input_text = '<abc|cba><bdd|ddb>'
tokenizer_utils.show_token_mapping(tokenizer, input_text)

Input:		<abc|cba><bdd|ddb>
Tokenized:	[44m[97m<[0m[45m[97ma[0m[46m[97mb[0m[42m[97mc[0m[43m[97m|[0m[41m[97mc[0m[44m[97mb[0m[45m[97ma[0m[46m[97m>[0m[42m[97m<[0m[43m[97mb[0m[41m[97md[0m[44m[97md[0m[45m[97m|[0m[46m[97md[0m[42m[97md[0m[43m[97mb[0m[41m[97m>[0m
Token ID | Token Bytes | Token String
---------+-------------+--------------
       0 | [38;5;2m3C[0m | '<'
          [48;5;1m[38;5;15m<[0mabc|cba><bdd|ddb>
          U+003C LESS-THAN SIGN (1 bytes: [38;5;2m3C[0m)
       2 | [38;5;2m61[0m | 'a'
          <[48;5;1m[38;5;15ma[0mbc|cba><bdd|ddb>
          U+0061 LATIN SMALL LETTER A (1 bytes: [38;5;2m61[0m)
       3 | [38;5;2m62[0m | 'b'
          <a[48;5;1m[38;5;15mb[0mc|cba><bdd|ddb>
          U+0062 LATIN SMALL LETTER B (1 bytes: [38;5;2m62[0m)
       4 | [38;5;2m63[0m | 'c'
          <ab[48;5;1m[38;5;15mc[0m|cba><bdd|ddb>
          U+0063 LATIN SMALL LETTER C (1 bytes: [38;5;2m63[0m)
      28 | [38;5;2m

### Setup Model

In [80]:
# now we know our vocab size from our sample generation

cfg = transformer_lens.HookedTransformerConfig(
    n_layers=1,
    d_model=16,
    d_head=16,
    # The number of attention heads.
    # If not specified, will be set to d_in // d_head.
    # (This is represented by a default value of -1)
    n_heads=1,
    # The dimensionality of the feedforward mlp network.
    # Defaults to 4 * d_in, and in an attn-only model is None.
    # d_mlp=16,
    # note: transformerlens does the same thing if this is not set
    d_vocab=len(tokenizer.byte_to_token_dict),
    # length of the first sample is our context length
    n_ctx=len(more_itertools.nth(generate_sample(), 1)),
    act_fn="relu",
    normalization_type="LN",
    # note: must be set, otherwise tries to default to cuda / cpu (not mps)
    device=device.type,
)

print(f'Num params: {cfg.n_params}')

cfg

Num params: 3072


HookedTransformerConfig:
{'act_fn': 'relu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_scale': 4.0,
 'attn_scores_soft_cap': -1.0,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 16,
 'd_mlp': 64,
 'd_model': 16,
 'd_vocab': 29,
 'd_vocab_out': 29,
 'decoder_start_token_id': None,
 'default_prepend_bos': True,
 'device': 'mps',
 'dtype': torch.float32,
 'eps': 1e-05,
 'experts_per_token': None,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': True,
 'initializer_range': 0.2,
 'load_in_4bit': False,
 'model_name': 'custom',
 'n_ctx': 9,
 'n_devices': 1,
 'n_heads': 1,
 'n_key_value_heads': None,
 'n_layers': 1,
 'n_params': 3072,
 'normalization_type': 'LN',
 'num_experts': None,
 'original_architecture': None,
 'output_logits_soft_cap': -1.0,
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rel

### Setup Loss Function

In [23]:
def loss_fn(logits, target):
    # standard cross entropy loss
    return torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)),
        target.view(-1),
    )

### Evaluate On Test

In [24]:
def evaluate_loss_on_test_batches(
    model: transformer_lens.HookedTransformer,
    data_loader: torch.utils.data.DataLoader,
) -> float:

    # Set the model to evaluation mode
    model.eval()  

    losses = []

    with torch.no_grad():  # Disable gradient computation
        
        for batch in data_loader:
            
            x, y = batch

            x, y = x.to(device), y.to(device)

            logits = model(x)

            loss = loss_fn(logits, y)

            losses.append(loss.item())

    # Set the model back to training mode
    model.train() 

    return sum(losses) / len(losses)

### Setup Data Loaders

In [81]:
class AutoregressiveDataset(torch.utils.data.Dataset):
    def __init__(self, samples: list[str], tokenizer: NaiveTokenizer) -> None:
        self.samples = samples
        self.tokenizer = tokenizer  # Assuming tokenizer is defined in the global scope

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        tokens = self.tokenizer.encode(sample)
        
        # Convert to tensor and add batch dimension
        x = torch.tensor(tokens[:-1], dtype=torch.long)
        y = torch.tensor(tokens[1:], dtype=torch.long)
        
        return x, y

def make_batch_dataloader(
    samples: list[str],
    tokenizer: NaiveTokenizer,
    batch_size: int,
) -> tuple[torch.utils.data.Dataset, torch.utils.data.DataLoader]:

    dataset = AutoregressiveDataset(samples=samples, tokenizer=tokenizer)
    
    # Create DataLoader
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        # drop the last batch if it's incomplete
        drop_last=True,
    )
    
    return dataset, dataloader

# Example usage:
# batch_generator = make_batch_generator(tokenizer, batch_size=4)
# for x, y in batch_generator:
#     # x is input, y is target (x shifted by 1)
#     pass


In [82]:
# split into test and train
all_samples = list(generate_sample())

# note: 4394 batches = (26 * 26 * 26) / 4
print(f'{len(all_samples)} samples')

# max_samples = 10
# print(f'Capping at {max_samples} batches first to make sure we can overfit')
# all_samples = all_samples[:max_samples]

test_train_ratio = 0.1

test_size = int(test_train_ratio * len(all_samples))
    
# put remaining ones into train
train_size = len(all_samples) - test_size

train_samples = all_samples[:train_size]
test_samples = all_samples[train_size:]

print(f'{len(train_samples)=}')
print(f'{len(test_samples)=}')

# now we can finally construct dataloaders
batch_size = 4

train_dataset, train_loader = make_batch_dataloader(
    samples=train_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)
test_dataset, test_loader = make_batch_dataloader(
    samples=test_samples,
    tokenizer=tokenizer,
    batch_size=batch_size,
)


17576 samples
len(train_samples)=15819
len(test_samples)=1757


### Training

In [83]:
import tqdm

import torch.optim

import wandb


# create new model instance
model = transformer_lens.HookedTransformer(cfg)

# setup optimizers
lr = 1e-4
betas = (0.9, 0.95)
max_grad_norm = 1.0
wd = 0.1

optimizer = torch.optim.AdamW(model.parameters(), lr=lr, betas=betas, weight_decay=wd)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda i: min(i / 100, 1.0))

num_epochs = 5000

# setup wandb
config = cfg.to_dict()
config.update({
    'num_epochs': num_epochs,
    'batch_size': batch_size,
    'lr': 1e-4,
    'max_grad_norm': 1.0,
    'wd': 0.1,
})
wandb.init(
    project="toy-problem-hooked-transformer",
    config=config,
)

losses = []
test_losses = []

for epoch, batch in tqdm.tqdm(
    zip(
        range(num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # ex: torch.Size([4, 9, 29])
    logits: Float32[torch.Tensor, "batch_size n_ctx d_vocab"] = model(tokens)

    # print(f"Logits:\n{logits.shape}")
    loss = loss_fn(logits, target)

    loss.backward()

    if max_grad_norm is not None:
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)

    optimizer.step()

    optimizer.zero_grad()

    scheduler.step()

    losses.append(loss.item())

    # TODO(bschoen): Shouldn't you actually divide loss by batch size?
    if epoch % 500 == 0:
        print('Evaluating test loss')

        test_loss = evaluate_loss_on_test_batches(model, test_loader)

        test_losses.append(test_loss)

        print(f"Epoch {epoch}: Train loss: {loss.item():.6f}, Test loss: {test_loss:.6f}")

        wandb.log({
            'epoch': epoch,
            'train_loss': loss.item(),
            'test_loss': test_loss,
        })

wandb.finish()

# log locally to sanity check
px.line(losses, labels={"x": "Epoch", "y": "Train Loss"})


0it [00:00, ?it/s]

Evaluating test loss


9it [00:01,  8.19it/s]

Epoch 0: Train loss: 3.625956, Test loss: 3.420482


500it [00:07, 77.02it/s]

Evaluating test loss


509it [00:09, 17.10it/s]

Epoch 500: Train loss: 2.902704, Test loss: 3.079812


998it [00:14, 90.50it/s]

Evaluating test loss


1008it [00:16, 20.30it/s]

Epoch 1000: Train loss: 2.709388, Test loss: 2.844454


1497it [00:21, 85.73it/s]

Evaluating test loss


1506it [00:23, 17.94it/s]

Epoch 1500: Train loss: 2.484645, Test loss: 2.693214


1991it [00:28, 89.82it/s]

Evaluating test loss


2008it [00:30, 22.49it/s]

Epoch 2000: Train loss: 2.491163, Test loss: 2.569218


2497it [00:36, 81.56it/s]

Evaluating test loss


2514it [00:38, 22.86it/s]

Epoch 2500: Train loss: 2.415325, Test loss: 2.452568


3000it [00:43, 85.90it/s]

Evaluating test loss


3018it [00:45, 24.16it/s]

Epoch 3000: Train loss: 2.276484, Test loss: 2.400700


3499it [00:51, 83.21it/s]

Evaluating test loss


3517it [00:52, 22.20it/s]

Epoch 3500: Train loss: 2.299750, Test loss: 2.365307


3995it [00:58, 89.58it/s]

Evaluating test loss


4013it [01:00, 25.11it/s]

Epoch 4000: Train loss: 2.250075, Test loss: 2.335708


4494it [01:05, 83.28it/s]

Evaluating test loss


4513it [01:07, 25.10it/s]

Epoch 4500: Train loss: 2.065055, Test loss: 2.281692


5000it [01:12, 68.86it/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▂▃▃▄▅▆▆▇█
test_loss,█▆▄▄▃▂▂▂▁▁
train_loss,█▅▄▃▃▃▂▂▂▁

0,1
epoch,4500.0
test_loss,2.28169
train_loss,2.06505


In [84]:
# Look at some example output
import circuitsvis as cv

# create a custom to_string function since using our own tokenizer
def token_to_string(token: int) -> str:
    return tokenizer.decode([token])

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

example_sample = x[0]

# grab the first part of it, ex: `<abc|`
example_prompt = example_sample[:8]

example_prompt = example_prompt.to(device)

print(f'Using {example_prompt} from {example_sample} (from test set)')

# note: already encoded
input_tokens = example_prompt

logits_batch, cache = model.run_with_cache(input_tokens)

logits = logits_batch[0]

log_probs = logits.log_softmax(dim=-1)

cv.logits.token_log_probs(
    token_indices=input_tokens,
    log_probs=log_probs,
    to_string=token_to_string,
)
 

Using tensor([ 0, 26, 16, 18, 28, 18, 16, 26], device='mps:0') from tensor([ 0, 26, 16, 18, 28, 18, 16, 26]) (from test set)


#### Looking at attention patterns

In [85]:
# see what's in our cache
[c for c in cache.keys()]

['hook_embed',
 'hook_pos_embed',
 'blocks.0.hook_resid_pre',
 'blocks.0.ln1.hook_scale',
 'blocks.0.ln1.hook_normalized',
 'blocks.0.attn.hook_q',
 'blocks.0.attn.hook_k',
 'blocks.0.attn.hook_v',
 'blocks.0.attn.hook_attn_scores',
 'blocks.0.attn.hook_pattern',
 'blocks.0.attn.hook_z',
 'blocks.0.hook_attn_out',
 'blocks.0.hook_resid_mid',
 'blocks.0.ln2.hook_scale',
 'blocks.0.ln2.hook_normalized',
 'blocks.0.mlp.hook_pre',
 'blocks.0.mlp.hook_post',
 'blocks.0.hook_mlp_out',
 'blocks.0.hook_resid_post',
 'ln_final.hook_scale',
 'ln_final.hook_normalized']

In [198]:
cache['blocks.0.attn.hook_q'].shape

torch.Size([4, 8, 1, 16])

In [44]:
cache["blocks.0.attn.hook_pattern"].shape

torch.Size([1, 2, 8, 8])

In [86]:
for layer_index in range(cfg.n_layers):
    imshow(
        transformer_lens.utils.to_numpy(cache["attn", layer_index].mean([0, 1])),
        title=f"Layer {layer_index} Attention Pattern",
        height=400,
        width=400,
    )

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

Loss = Float32[torch.Tensor, ""]
MSELoss = Float32[torch.Tensor, ""]
WeightedSparsityLoss = Float32[torch.Tensor, ""]

Logits = Float32[torch.Tensor, "n_ctx d_vocab"]
BatchedLogits = Float32[torch.Tensor, "batch n_ctx d_vocab"]

ModelActivations = Float32[torch.Tensor, "n_ctx d_model"]
BatchedModelActivations = Float32[torch.Tensor, "batch n_ctx d_model"]

FlattenedModelActivations = Float32[torch.Tensor, "d_sae_in"]

BatchedFlattenedModelActivations = Float32[torch.Tensor, "batch d_sae_in"]
BatchedSAEActivations = Float32[torch.Tensor, "batch d_sae_model"]

@dataclasses.dataclass
class SAEOutput:
    sae_activations: BatchedSAEActivations
    reconstructed_model_activations: BatchedFlattenedModelActivations

def sparse_loss_kl_divergence(
    flattened_model_activations: BatchedModelActivations,
    sae_output: SAEOutput,
    sparsity_target: float,
    sparsity_weight: float,
    epsilon: float = 1e-10,
) -> tuple[Loss, MSELoss, WeightedSparsityLoss]:

    # same as dense loss (this is constant?)
    mse_loss = F.mse_loss(
        sae_output.reconstructed_model_activations,
        flattened_model_activations,
    )
    
    # KL divergence for sparsity
    avg_activation = torch.mean(sae_output.sae_activations, dim=0)

    # print(f'[pre-clamping] {avg_activation=}')

    # Add epsilon for numerical stability
    avg_activation = torch.clamp(avg_activation, epsilon, 1 - epsilon)

    kl_div = sparsity_target * torch.log(sparsity_target / avg_activation) + \
             (1 - sparsity_target) * torch.log((1 - sparsity_target) / (1 - avg_activation))
    kl_div = torch.sum(kl_div)

    # `sparsity_weight` decides how much we weight `KL-Divergence`
    sparsity_penalty = sparsity_weight * kl_div

    # print(f'[sparse_loss] {mse_loss=}, {avg_activation=}, {kl_div=}, {sparsity_penalty=}')
    
    return mse_loss + sparsity_penalty, mse_loss, sparsity_penalty

In [138]:
import dataclasses

import lightning.pytorch



@dataclasses.dataclass
class SparseAutoencoderConfig:
    d_in: int
    d_model: int
    sparsity_target: float = 0.05

# TODO(bschoen): Start using the config pattern, it stays typesafe and allows
#                easy logging to things like wandb
class SparseAutoencoder(nn.Module):
    def __init__(
        self,
        cfg: SparseAutoencoderConfig,
    ) -> None:

        print(f'Creating SparseAutoencoder with {cfg}')

        super(SparseAutoencoder, self).__init__()

        self.d_in = cfg.d_in
        self.d_model = cfg.d_model

        self.encoder = nn.Linear(cfg.d_in, cfg.d_model)
        self.decoder = nn.Linear(cfg.d_model, cfg.d_in)

        # Target average activation of hidden neurons
        # Motivation: Encourages each neuron to be active for ex: ~5% of inputs, promoting specialization
        self.sparsity_target = cfg.sparsity_target
    
    def forward(
        self,
        x: BatchedFlattenedModelActivations,
    ) -> SAEOutput:
        
        # TODO(bschoen): Which activation function should we use?
        
        encoded = F.gelu(self.encoder(x))
        
        decoded = self.decoder(encoded)
        
        return SAEOutput(
            sae_activations=encoded,
            reconstructed_model_activations=decoded,
        )
    
@dataclasses.dataclass
class LightningSparseAutoencoderConfig:
    
    model_config: transformer_lens.HookedTransformerConfig
    sae_config: SparseAutoencoderConfig
    learning_rate: float
    sparsity_weight: float
    
# note: this kind of lightning adapter is a common pattern: https://lightning.ai/docs/pytorch/stable/common/lightning_module.html#starter-example
class LightningSparseAutoencoder(lightning.pytorch.LightningModule):
    def __init__(
        self,
        cfg: LightningSparseAutoencoderConfig,
    ) -> None:

        super(LightningSparseAutoencoder, self).__init__()
        
        self.model = transformer_lens.HookedTransformer(cfg=cfg.model_config)
        self.sae = SparseAutoencoder(cfg=cfg.sae_config)
        self.cfg = cfg

    def forward(self, inputs, target):
        return self.model(inputs, target)

    def training_step(self, batch, batch_idx: int) -> LossValue:
        inputs, target = batch

        self.model
        output = self(inputs, target)
        loss = torch.nn.functional.cr(output, target.view(-1))
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.model.parameters(), lr=0.1)

In [139]:
hook_id = 'blocks.0.hook_mlp_out'

cache[hook_id].shape

torch.Size([4, 8, 16])

In [173]:
# Training loop
sae_num_epochs = 10000
sae_expansion_factor = 4

learning_rate = 1e-4

# both arbitrary for now
# - Start small: A common approach is to begin with a relatively small sparsity weight, 
#                typically in the range of 1e-5 to 1e-3. This allows the model to
#                learn meaningful representations before enforcing strong sparsity
#                constraints.
sparsity_weight: float = 1e-3  # Weight of the sparsity loss in the total loss
sparsity_target: float = 0.05  # Target average activation of hidden neurons

print(f'Training SAE for {hook_id}...')
sae_d_in = (cfg.n_ctx - 1) * cfg.d_model  # -1 since not predicting first token
sae_d_model = sae_d_in * sae_expansion_factor

sae_model = SparseAutoencoder(
    d_in=sae_d_in,
    d_model=sae_d_model,
    sparsity_target=sparsity_target,
)
sae_model.to(device)

sae_optimizer = optim.Adam(sae_model.parameters(), lr=learning_rate)
sae_criterion = nn.MSELoss()

sae_config = {
    'sae_num_epochs': sae_num_epochs,
    'sae_expansion_factor': sae_expansion_factor,
    'learning_rate': learning_rate,
    'sparsity_weight': sparsity_weight,
    'sparsity_target': sparsity_target,
    'sae_d_in': sae_d_in,
    'sae_d_model': sae_d_model,
    'hook_id': hook_id,
}
wandb.init(
    project="toy-problem-hooked-transformer-sae",
    config=sae_config,
)

# put model itself into eval mode so doesn't change
model.eval()

# go through the training data again, this time training the sae on the activations
for epoch, batch in tqdm.tqdm(
    zip(
        range(sae_num_epochs),
        itertools.cycle(train_loader),
    )
):

    tokens, target = batch

    tokens, target = tokens.to(device), target.to(device)

    # run through the model (with cache) to get the activations
    logits, cache = model.run_with_cache(tokens)

    # ex: torch.Size([4, 8, 16])
    activations = cache[hook_id]

    # ex: torch.Size([4, 128])
    flattened_activations = activations.reshape(activations.size(0), -1)

    sae_optimizer.zero_grad()

    # now the SAE model is given the *activations*
    encoded, decoded = sae_model(flattened_activations)

    # print(f'Encoded: {encoded.shape}')
    # print(f'Decoded: {decoded.shape}')

    # compute loss
    reconstruction_loss = sae_criterion(decoded, flattened_activations)
    sparsity_loss = sae_model.get_sparsity_loss(encoded)

    total_loss = reconstruction_loss + (sparsity_weight * sparsity_loss)

    total_loss.backward()

    sae_optimizer.step()

    if epoch % 500 == 0:
        print(
            f"Step {epoch}, "
            f"Total Loss: {total_loss.item():.6f}, "
            f"Reconstruction Loss: {reconstruction_loss.item():.6f}, "
            f"Sparsity Loss: {sparsity_weight *sparsity_loss.item():.6f}",
        )

        wandb.log({
            'epoch': epoch,
            'total_loss': total_loss.item(),
            'reconstruction_loss': reconstruction_loss.item(),
            'sparsity_loss': sparsity_loss.item(),
            'weighted_sparsity_loss': sparsity_weight * sparsity_loss.item(),
        })

wandb.finish()

Training SAE for blocks.0.hook_mlp_out...
Creating SparseAutoencoder with d_in=128, d_model=512, sparsity_target=0.05


34it [00:00, 167.91it/s]

Step 0, Total Loss: 1.671702, Reconstruction Loss: 1.199252, Sparsity Loss: 0.472450


540it [00:02, 213.28it/s]

Step 500, Total Loss: 0.318797, Reconstruction Loss: 0.086799, Sparsity Loss: 0.231999


1041it [00:05, 213.45it/s]

Step 1000, Total Loss: 0.260015, Reconstruction Loss: 0.050985, Sparsity Loss: 0.209030


1540it [00:07, 220.50it/s]

Step 1500, Total Loss: 0.213470, Reconstruction Loss: 0.030246, Sparsity Loss: 0.183224


2036it [00:09, 217.50it/s]

Step 2000, Total Loss: 0.184074, Reconstruction Loss: 0.017250, Sparsity Loss: 0.166824


2530it [00:12, 220.57it/s]

Step 2500, Total Loss: 0.192843, Reconstruction Loss: 0.017640, Sparsity Loss: 0.175202


3031it [00:14, 222.63it/s]

Step 3000, Total Loss: 0.182178, Reconstruction Loss: 0.015011, Sparsity Loss: 0.167168


3522it [00:16, 220.19it/s]

Step 3500, Total Loss: 0.158005, Reconstruction Loss: 0.009903, Sparsity Loss: 0.148102


4025it [00:19, 217.75it/s]

Step 4000, Total Loss: 0.155063, Reconstruction Loss: 0.009516, Sparsity Loss: 0.145546


4529it [00:21, 207.41it/s]

Step 4500, Total Loss: 0.146475, Reconstruction Loss: 0.006087, Sparsity Loss: 0.140388


5035it [00:23, 223.34it/s]

Step 5000, Total Loss: 0.146551, Reconstruction Loss: 0.005646, Sparsity Loss: 0.140905


5525it [00:26, 223.65it/s]

Step 5500, Total Loss: 0.140666, Reconstruction Loss: 0.005347, Sparsity Loss: 0.135320


6036it [00:28, 221.54it/s]

Step 6000, Total Loss: 0.144114, Reconstruction Loss: 0.004701, Sparsity Loss: 0.139413


6543it [00:30, 227.07it/s]

Step 6500, Total Loss: 0.139305, Reconstruction Loss: 0.004662, Sparsity Loss: 0.134643


7020it [00:33, 191.58it/s]

Step 7000, Total Loss: 0.139794, Reconstruction Loss: 0.004297, Sparsity Loss: 0.135497


7524it [00:35, 226.16it/s]

Step 7500, Total Loss: 0.141234, Reconstruction Loss: 0.003519, Sparsity Loss: 0.137715


8035it [00:37, 227.02it/s]

Step 8000, Total Loss: 0.149151, Reconstruction Loss: 0.003701, Sparsity Loss: 0.145450


8541it [00:39, 228.39it/s]

Step 8500, Total Loss: 0.144533, Reconstruction Loss: 0.003170, Sparsity Loss: 0.141364


9025it [00:42, 223.55it/s]

Step 9000, Total Loss: 0.129214, Reconstruction Loss: 0.002277, Sparsity Loss: 0.126938


9523it [00:44, 195.13it/s]

Step 9500, Total Loss: 0.130186, Reconstruction Loss: 0.002419, Sparsity Loss: 0.127766


10000it [00:46, 214.52it/s]


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
reconstruction_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sparsity_loss,█▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁
total_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
weighted_sparsity_loss,█▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
epoch,9500.0
reconstruction_loss,0.00242
sparsity_loss,127.76648
total_loss,0.13019
weighted_sparsity_loss,0.12777


#### Dictionary Learning Implementation

See [simple_dictionary_learning.ipynb](simple_dictionary_learning.ipynb) for a details

#### Extracting the learned dictionary

In [174]:
# Creating SparseAutoencoder with d_in=128, d_model=512, sparsity_target=0.05
dictionary: Float32[torch.Tensor, "sae_hidden sae_in"] = sae_model.encoder.weight.detach()

# ex: Dictionary shape: torch.Size([512, 128])
print(f'Dictionary shape: {dictionary.shape}')

Dictionary shape: torch.Size([512, 128])


In [175]:
# Reshape dictionary elements to match original activation shape
# (essentially `unflatting`)
reshaped_dictionary = dictionary.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

# Motivation: Extract the learned features (dictionary elements) from the encoder weights
# ex: Dictionary shape: torch.Size([512, 8, 16])
print(f"Dictionary shape: {reshaped_dictionary.shape}")


Dictionary shape: torch.Size([512, 8, 16])


In [176]:
# It's always worth checking this sort of thing when you do this by hand
# to check that you haven't got the wrong site, or are missing a
# scaling factor or something like this. 
#
# This is like the overfitting thing

In [177]:
# let's look at an example batch from `test`

# set both to eval mode
model.eval()
sae_model.eval()

# grab something from the test batch
example_batch = next(iter(test_loader))

x, y = example_batch

_, cache = model.run_with_cache(x)

activations = cache[hook_id]

print(f'Activations shape: {activations.shape}')

# flatten it
flattened_activations = activations.reshape(activations.size(0), -1)

print(f'{flattened_activations.shape=}')

encoded, decoded = sae_model(flattened_activations)

# renamed
sae_activations = encoded
reconstructed_activations = decoded

print(f'{sae_activations.shape=}')
print(f'{reconstructed_activations.shape=}')

# now we can get the dictionary
# dictionary = sae_model.encoder.weight.detach()

# print(f'Dictionary shape: {dictionary.shape}')

# now we can get the sparse coefficients
# alpha = dictionary.T @ flattened_activations

Activations shape: torch.Size([4, 8, 16])
flattened_activations.shape=torch.Size([4, 128])
sae_activations.shape=torch.Size([4, 512])
reconstructed_activations.shape=torch.Size([4, 128])


In [178]:
# good, 0.995, basically all the variance is explained by our SAE
numerator = torch.mean((reconstructed_activations[:, 1:] - flattened_activations[:, 1:]) ** 2)
denominator = (flattened_activations[:, 1:].to(torch.float32).var())

explained_variance = 1 - (numerator / denominator)

print(f'{explained_variance.item()=:.4f}')

explained_variance.item()=0.9969


In [179]:
(sae_activations.mean(dim=0) > 0).sum()

tensor(435, device='mps:0')

In [181]:
# collect max activations


with torch.no_grad():

    # go through the training data again, but don't cycle, no reason to go through more than once
    for batch in tqdm.tqdm(train_loader):

        tokens, target = batch

        tokens, target = tokens.to(device), target.to(device)

        # run through the model (with cache) to get the activations
        logits, cache = model.run_with_cache(tokens)

        # ex: torch.Size([4, 8, 16])
        activations = cache[hook_id]

        # ex: torch.Size([4, 128])
        flattened_activations = activations.reshape(activations.size(0), -1)

        # now the SAE model is given the *activations*
        encoded, decoded = sae_model(flattened_activations)

        sae_activations = encoded

        # sae_activations.reshape(sae_d_model, (cfg.n_ctx - 1), cfg.d_model)

        # max_activations = torch.max(encoded, dim=1)


        break

  0%|          | 0/3954 [00:00<?, ?it/s]


In [183]:
sae_activations.shape

torch.Size([4, 512])

In [196]:
sae_activations[0].shape

torch.Size([512])

In [194]:
8 * 16

128