# Sparse Autoencoder Training Demo

This demo trains a sparse autoencoder (SAE) on activations from a TinyStories 1M model.

To do this we set up a *source model* (the TinyStories model) that we want to generate activations
from, along with a *source dataset* of prompts to help generate these activations.

We also set up a *sparse autoencoder model*, which we'll train on these generated activations to
learn a sparse representation of them in a higher-dimensional space.

Finally we'll wrap this all together in a *pipeline*, which alternates between generating
activations (storing them in RAM) and training the SAE on those activations.
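
As a rough illustration of that alternation, here is a minimal sketch in plain Python. It is not the library's `Pipeline` implementation: `run_alternating`, `generate` and `train_on` are hypothetical names, and the toy stand-ins below only count activations rather than running a model.

```python
from typing import Callable, List


def run_alternating(
    generate: Callable[[int], List[float]],
    train_on: Callable[[List[float]], None],
    max_store_size: int,
    max_activations: int,
) -> int:
    """Alternate between filling an in-memory store and training on it."""
    n_trained = 0
    while n_trained < max_activations:
        # Fill the store without exceeding the overall activation budget
        store = generate(min(max_store_size, max_activations - n_trained))
        if not store:
            break  # the source ran out of data
        train_on(store)
        n_trained += len(store)
    return n_trained


# Toy stand-ins (hypothetical): "activations" are zeros, "training" records the store size
store_sizes: List[int] = []
total = run_alternating(
    generate=lambda n: [0.0] * n,
    train_on=lambda store: store_sizes.append(len(store)),
    max_store_size=10_000,
    max_activations=100_000,
)
print(total, len(store_sizes))
```

With a store of 10,000 activations and a budget of 100,000, the store is refilled ten times.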

## Setup

### Imports

In [75]:
%load_ext autoreload
%autoreload 2

import os
from pathlib import Path

import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
from transformers import PreTrainedTokenizerBase
import wandb

from sparse_autoencoder.autoencoder.fista_autoencoder import FistaSparseAutoencoder
from sparse_autoencoder import SparseAutoencoder
from sparse_autoencoder.activation_resampler import ActivationResampler
from sparse_autoencoder.loss.learned_activations_l1 import LearnedActivationsL1Loss
from sparse_autoencoder.loss.mse_reconstruction_loss import MSEReconstructionLoss
from sparse_autoencoder.loss.reducer import LossReducer
from sparse_autoencoder.optimizer.adam_with_reset import AdamWithReset
from sparse_autoencoder.source_data.text_dataset import GenericTextDataset
from sparse_autoencoder.train.pipeline import Pipeline


os.environ["TOKENIZERS_PARALLELISM"] = "false"

device = get_device()
print(f"Using device: {device}")  # You will need a GPU

Using device: cpu


### Hyperparameters

The way this library works is that you define your own hyperparameters and then set up the
underlying components with them. This is extremely flexible, but to help you get started we've
included some common ones below along with some sensible defaults. You can also easily sweep through
multiple hyperparameters with `wandb.sweep`.

In [76]:
hyperparameters = {
    # Expansion factor is the number of features in the sparse representation, relative to the
    # number of features in the original MLP layer. The original paper experimented with 1x to 256x,
    # and we have found that 4x is a good starting point.
    "expansion_factor": 4,
    # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).
    "l1_coefficient": 0.001,
    # Adam parameters (set to the default ones here)
    "lr": 0.001,
    "adam_beta_1": 0.9,
    "adam_beta_2": 0.999,
    "adam_epsilon": 1e-8,
    "adam_weight_decay": 0.0,
    # Batch sizes
    "train_batch_size": 8192,
}
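
A sweep over these values could be described with a configuration dictionary along the following lines. This is a sketch only: the candidate values are illustrative, the metric name assumes the pipeline logs `MSEReconstructionLoss` to wandb, and `train` in the commented-out launch code is a hypothetical function that would build and run the pipeline from `wandb.config`.

```python
# Sweep definition in the dictionary format accepted by `wandb.sweep`
sweep_config = {
    "method": "grid",  # try every combination of the listed values
    "metric": {"name": "MSEReconstructionLoss", "goal": "minimize"},
    "parameters": {
        "expansion_factor": {"values": [2, 4, 8]},  # illustrative values
        "l1_coefficient": {"values": [0.0001, 0.001, 0.01]},  # illustrative values
        "lr": {"value": 0.001},  # held fixed
    },
}

# Number of runs a grid search over this configuration would launch
n_grid_runs = 1
for spec in sweep_config["parameters"].values():
    n_grid_runs *= len(spec.get("values", [spec.get("value")]))
print(n_grid_runs)

# To launch the sweep (needs a wandb login; `train` is hypothetical):
# import wandb
# sweep_id = wandb.sweep(sweep=sweep_config, project="sparse-autoencoder")
# wandb.agent(sweep_id, function=train)
```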

### Source Model

The source model is just a [TransformerLens](https://github.com/neelnanda-io/TransformerLens) model
(see [here](https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
for a full list of supported models).

In this example we're training a sparse autoencoder on the activations from the first MLP layer, so
we'll also get some details about that hook point.

In [77]:
# Source model setup with TransformerLens
src_model_name = "tiny-stories-1M"
src_model = HookedTransformer.from_pretrained(src_model_name, dtype="float32")

# Details about the activations we'll train the sparse autoencoder on
src_model_activation_hook_point = "blocks.0.mlp.hook_post"
src_model_activation_layer = 0
src_d_mlp: int = src_model.cfg.d_mlp  # type: ignore (TransformerLens typing is currently broken)

f"Source: {src_model_name}, Hook: {src_model_activation_hook_point}, Features: {src_d_mlp}"

Loaded pretrained model tiny-stories-1M into HookedTransformer


'Source: tiny-stories-1M, Hook: blocks.0.mlp.hook_post, Features: 256'

### Sparse Autoencoder

We can then set up the sparse autoencoder. The default model (`SparseAutoencoder`) is set up as per
the original Anthropic paper [Towards Monosemanticity: Decomposing Language Models With Dictionary
Learning](https://transformer-circuits.pub/2023/monosemantic-features/index.html). This demo
instantiates the `FistaSparseAutoencoder` variant instead.

Either way it's just a standard PyTorch model, so you can create your own model if you want to
use a different architecture. To do this you just need to extend the `AbstractAutoencoder`, and
optionally the underlying `AbstractEncoder`, `AbstractDecoder` and `AbstractOuterBias`. See these
classes (which are fully documented) for more details.

In [78]:
expansion_factor = hyperparameters["expansion_factor"]
autoencoder = FistaSparseAutoencoder(
    n_input_features=src_d_mlp,  # size of the activations we are autoencoding
    n_learned_features=int(src_d_mlp * expansion_factor),  # size of SAE
    geometric_median_dataset=torch.zeros(src_d_mlp),  # this is used to initialize the tied bias
).to(device)
autoencoder  # Print the model (it's pretty straightforward)

FistaSparseAutoencoder(
  (_pre_encoder_bias): TiedBias(position=pre_encoder)
  (_encoder): LinearEncoder(
    in_features=256, out_features=1024
    (activation_function): ReLU()
  )
  (_decoder): UnitNormDecoder(in_features=1024, out_features=256)
  (_post_decoder_bias): TiedBias(position=post_decoder)
)
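
To make the printed structure concrete, the following plain-Python sketch traces one input through the four stages: subtract the tied bias, apply a linear encoder with ReLU, decode with a unit-norm dictionary, and add the tied bias back. It is an illustration under assumptions, not the library's code: all function names are hypothetical, the weights are hand-picked rather than learned, and it does not model anything FISTA-specific.

```python
import math
from typing import List, Tuple

Matrix = List[List[float]]


def matvec(matrix: Matrix, vector: List[float]) -> List[float]:
    """Multiply a row-major matrix by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def normalize_columns(matrix: Matrix) -> Matrix:
    """Scale every column (one dictionary vector per learned feature) to unit L2 norm."""
    n_cols = len(matrix[0])
    norms = [math.sqrt(sum(row[j] ** 2 for row in matrix)) for j in range(n_cols)]
    return [[row[j] / norms[j] for j in range(n_cols)] for row in matrix]


def sae_forward(
    x: List[float],
    tied_bias: List[float],
    encoder_weight: Matrix,
    encoder_bias: List[float],
    decoder_weight: Matrix,
) -> Tuple[List[float], List[float]]:
    """Return (learned_activations, decoded_activations) for one input vector."""
    centred = [a - b for a, b in zip(x, tied_bias)]  # pre-encoder tied bias
    pre_activation = [h + b for h, b in zip(matvec(encoder_weight, centred), encoder_bias)]
    learned = [max(0.0, h) for h in pre_activation]  # ReLU
    decoded = [y + b for y, b in zip(matvec(decoder_weight, learned), tied_bias)]
    return learned, decoded


# Toy example: 2 input features, 4 learned features
encoder_weight = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [0.0, -1.0]]
decoder_weight = normalize_columns([[2.0, 0.0, -3.0, 0.0], [0.0, 4.0, 0.0, -5.0]])
learned, decoded = sae_forward(
    x=[0.5, -2.0],
    tied_bias=[0.0, 0.0],
    encoder_bias=[0.0] * 4,
    encoder_weight=encoder_weight,
    decoder_weight=decoder_weight,
)
print(learned, decoded)
```

In this toy case only two of the four learned features are active, and the decoder recovers the input exactly.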

We'll also want to set up an optimizer and a loss function. Here we again use the standard
approach from the original Anthropic paper. However, you can create your own loss functions and
optimizers by extending `AbstractLoss` and `AbstractOptimizerWithReset` respectively.

In [79]:
# We use a loss reducer, which simply adds up the losses from the underlying loss functions.
loss = LossReducer(
    LearnedActivationsL1Loss(
        l1_coefficient=hyperparameters["l1_coefficient"],
    ),
    MSEReconstructionLoss(),
)
loss

LossReducer(
  (0): LearnedActivationsL1Loss(l1_coefficient=0.001)
  (1): MSEReconstructionLoss()
)
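
The combined objective can be sketched for a single example in plain Python. The exact reductions used by the library are an assumption here: this version sums absolute learned activations for the L1 term and averages squared error over input features for the reconstruction term. The function names are hypothetical.

```python
from typing import List


def l1_loss(learned_activations: List[float], l1_coefficient: float) -> float:
    """Sparsity penalty: coefficient times the sum of absolute learned activations."""
    return l1_coefficient * sum(abs(a) for a in learned_activations)


def mse_loss(inputs: List[float], decoded: List[float]) -> float:
    """Reconstruction error: mean squared difference over input features."""
    return sum((x - y) ** 2 for x, y in zip(inputs, decoded)) / len(inputs)


def total_loss(
    inputs: List[float],
    learned_activations: List[float],
    decoded: List[float],
    l1_coefficient: float,
) -> float:
    """Add the two terms together, mirroring what a summing loss reducer does."""
    return l1_loss(learned_activations, l1_coefficient) + mse_loss(inputs, decoded)


example_loss = total_loss(
    inputs=[1.0, 2.0],
    learned_activations=[0.5, 0.0, 1.5, 0.0],
    decoded=[1.0, 1.0],
    l1_coefficient=0.001,
)
print(example_loss)
```

Here the L1 term contributes 0.001 × 2.0 and the reconstruction term contributes 0.5, so a larger `l1_coefficient` shifts the balance towards sparsity.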

In [80]:
optimizer = AdamWithReset(
    params=autoencoder.parameters(),
    named_parameters=autoencoder.named_parameters(),
    lr=hyperparameters["lr"],
    betas=(hyperparameters["adam_beta_1"], hyperparameters["adam_beta_2"]),
    eps=hyperparameters["adam_epsilon"],
    weight_decay=hyperparameters["adam_weight_decay"],
)
optimizer

AdamWithReset (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.001
    maximize: False
    weight_decay: 0.0
)

Finally we'll initialise an activation resampler.

In [81]:
activation_resampler = ActivationResampler()

### Source dataset

This is just a dataset of tokenized prompts, to be used in generating activations (which are in turn
used to train the SAE).

In [82]:
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore
source_data = GenericTextDataset(tokenizer=tokenizer, dataset_path="roneneldan/TinyStories")


## Training

If you initialise [wandb](https://wandb.ai/site), the pipeline will automatically log all metrics to
wandb. We should also pass in a dictionary with all of our hyperparameters so that they are recorded
on wandb.

We strongly encourage users to make use of wandb in order to understand the training process.

In [83]:
Path(".cache/").mkdir(exist_ok=True)
# wandb.init(
#     project="sparse-autoencoder",
#     dir=".cache",
#     config=hyperparameters,
# )


In [92]:
max_store_size = 10_000
max_activations = 100_000
resample_frequency = 25_000
source_data_batch_size = 8

pipeline = Pipeline(
    cache_name=src_model_activation_hook_point,
    layer=src_model_activation_layer,
    source_model=src_model,
    autoencoder=autoencoder,
    source_dataset=source_data,
    optimizer=optimizer,
    loss=loss,
    activation_resampler=activation_resampler,
    source_data_batch_size=source_data_batch_size,
)

# pipeline.run_pipeline(
#     train_batch_size=int(hyperparameters["train_batch_size"]),
#     max_store_size=max_store_size,
#     # Sizes for demo purposes (you probably want to scale these by 10x)
#     max_activations=max_activations,
#     resample_frequency=resample_frequency,
# )

# pipeline.run_pipeline(
#     train_batch_size=int(hyperparameters["train_batch_size"]),
#     max_store_size=1_000_000,
#     # Sizes for demo purposes (you probably want to scale these by 10x)
#     max_activations=10_000_000,
#     resample_frequency=2_500_000,
# )

Activations trained on:   0%|          | 0/100000 [00:00<?, ?it/s]

In [None]:
from sparse_autoencoder.metrics.post_train.fvu_metric import FVUMetric
from sparse_autoencoder.metrics.post_train.sparsity import SparsityMetric
from sparse_autoencoder.metrics.post_train.abstract_post_train_metric import PostTrainMetricData
from torch.utils.data import DataLoader


def get_fvu_sparsity(autoencoder, source_data, pipeline, max_store_size, source_data_batch_size):
    """Compute sparsity and FVU on the first batch of freshly generated activations."""
    # Round the store size down to a whole number of source batches
    store_size: int = max_store_size - max_store_size % (
        source_data_batch_size * source_data.context_size
    )

    activation_store = pipeline.generate_activations(store_size=store_size)
    activations_dataloader = DataLoader(
        activation_store,
        batch_size=int(hyperparameters["train_batch_size"]),
    )

    # Only the first batch is evaluated
    activation_batch = next(iter(activations_dataloader))
    data = PostTrainMetricData(
        input_activations=activation_batch,
        # Placeholders: fvu uses only input_activations and model
        learned_activations=torch.tensor([0]),
        decoded_activations=torch.tensor([0]),
        model=autoencoder,
    )

    fvu = FVUMetric().calculate(data)
    sparsity = SparsityMetric().calculate(data)
    return sparsity, fvu

def get_fvu_sparsity_average(
    autoencoder, source_data, pipeline, max_store_size, source_data_batch_size, num_iterations
):
    """Average sparsity and FVU over several freshly generated activation batches."""
    fvu_values = []
    sparsity_values = []

    for _ in range(num_iterations):
        sparsity_dict, fvu_dict = get_fvu_sparsity(
            autoencoder, source_data, pipeline, max_store_size, source_data_batch_size
        )

        sparsity_values.append(sparsity_dict["sparsity"])
        fvu_values.append(fvu_dict["fvu"].item())

    mean_fvu = sum(fvu_values) / num_iterations
    mean_sparsity = sum(sparsity_values) / num_iterations

    # Population variance (divides by the number of iterations)
    variance_fvu = sum((x - mean_fvu) ** 2 for x in fvu_values) / num_iterations
    variance_sparsity = sum((x - mean_sparsity) ** 2 for x in sparsity_values) / num_iterations

    return mean_sparsity, mean_fvu, variance_sparsity, variance_fvu
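
For reference, the fraction of variance unexplained (FVU) is commonly defined as the squared reconstruction error divided by the total variance of the inputs around their mean. The sketch below implements that common definition in plain Python; whether `FVUMetric` uses exactly this reduction is an assumption, and the function name is hypothetical.

```python
from typing import List


def fraction_of_variance_unexplained(
    inputs: List[List[float]], reconstructions: List[List[float]]
) -> float:
    """Residual sum of squares divided by the total sum of squares around the feature means."""
    n_features = len(inputs[0])
    means = [sum(row[j] for row in inputs) / len(inputs) for j in range(n_features)]
    residual = sum(
        (x - y) ** 2 for row, rec in zip(inputs, reconstructions) for x, y in zip(row, rec)
    )
    total = sum((x - m) ** 2 for row in inputs for x, m in zip(row, means))
    return residual / total


batch = [[1.0, 0.0], [3.0, 4.0]]
fvu_partial = fraction_of_variance_unexplained(batch, [[1.0, 1.0], [3.0, 3.0]])
fvu_perfect = fraction_of_variance_unexplained(batch, batch)
fvu_mean_only = fraction_of_variance_unexplained(batch, [[2.0, 2.0], [2.0, 2.0]])
print(fvu_partial, fvu_perfect, fvu_mean_only)
```

Under this definition a perfect reconstruction scores 0, and a model that always predicts the per-feature mean scores 1.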

In [None]:
min_l1_coefficient = 0.0001
max_l1_coefficient = 0.01
num_runs = 6

save = True

# Linearly spaced L1 coefficients from min to max (inclusive)
step_size = (max_l1_coefficient - min_l1_coefficient) / (num_runs - 1)

# Note: `autoencoder` and `optimizer` are created once above and reused here, so each run
# appears to continue training from the previous run's weights. Re-create both inside the
# loop if the runs should be independent.
for i in range(num_runs):
    run_name = f"run_{i}"
    current_l1_coefficient = round(min_l1_coefficient + (i * step_size), 6)
    # Record the coefficient actually used in this run, not the default in `hyperparameters`
    run_config = {**hyperparameters, "l1_coefficient": current_l1_coefficient}

    with wandb.init(project="sparse-autoencoder", name=run_name, dir=".cache", config=run_config):
        loss = LossReducer(
            LearnedActivationsL1Loss(l1_coefficient=current_l1_coefficient),
            MSEReconstructionLoss(),
        )

        pipeline = Pipeline(
            cache_name=src_model_activation_hook_point,
            layer=src_model_activation_layer,
            source_model=src_model,
            autoencoder=autoencoder,
            source_dataset=source_data,
            optimizer=optimizer,
            loss=loss,
            activation_resampler=activation_resampler,
            source_data_batch_size=source_data_batch_size,
        )

        pipeline.run_pipeline(
            train_batch_size=int(hyperparameters["train_batch_size"]),
            max_store_size=max_store_size,
            max_activations=max_activations,
            resample_frequency=resample_frequency,
        )

        sparsity, fvu, variance_sparsity, variance_fvu = get_fvu_sparsity_average(
            autoencoder, source_data, pipeline, max_store_size, source_data_batch_size,
            num_iterations=50,
        )

        wandb.log({
            "sparsity": sparsity,
            "fvu": fvu,
            "l1_coefficient": current_l1_coefficient,
            "variance_sparsity": variance_sparsity,
            "variance_fvu": variance_fvu,
        })

        filename = f"fista_{i}_{current_l1_coefficient}"
        if save:
            # Save the run configuration and metrics next to the model weights
            with open(filename + ".txt", "w") as file:
                for key, value in run_config.items():
                    file.write(f"{key}: {value}\n")
                file.write(f"sparsity: {sparsity}\n")
                file.write(f"fvu: {fvu}\n")
                file.write(f"variance_sparsity: {variance_sparsity}\n")
                file.write(f"variance_fvu: {variance_fvu}\n")

            torch.save(autoencoder.state_dict(), filename + ".pt")
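
The coefficient schedule used in the loop above can be checked in isolation. The helper below is a small sketch (`linear_grid` is a hypothetical name) that reproduces the same arithmetic: evenly spaced values from the minimum to the maximum, both included, each rounded to six decimal places.

```python
from typing import List


def linear_grid(minimum: float, maximum: float, num_values: int, ndigits: int = 6) -> List[float]:
    """Evenly spaced values from minimum to maximum (inclusive), rounded to ndigits."""
    step = (maximum - minimum) / (num_values - 1)
    return [round(minimum + i * step, ndigits) for i in range(num_values)]


l1_coefficients = linear_grid(0.0001, 0.01, 6)
print(l1_coefficients)
```

With these settings the coefficients rise from 0.0001 to 0.01 in steps of 0.00198, so most of the runs sit in the upper part of the range.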


In [86]:
wandb.finish()

Run history:
LearnedActivationsL1Loss,▃▅▅█▂▁▁▁▁▁▂▁▁▁▁▂▂▂▁▁
LossReducer,▂▄▅█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
MSEReconstructionLoss,▂▄▅█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
LearnedActivationsL1Loss,0.04742
LossReducer,0.04882
MSEReconstructionLoss,0.0014


## Training Advice

-- Unfinished --

- Check that the reconstruction (recovery) loss is low while the L1 sparsity term is also low (usually below 20).
- You can't be sure the learned features are useful until you dig into them in more detail.

## Analysis

-- Unfinished --