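# Sparse Autoencoder Training Demo

Trains a sparse autoencoder on the MLP activations (`blocks.0.mlp.hook_post`) of the one-layer TransformerLens model `solu-1l`, using text from the uncopyrighted Pile.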

## Setup

In [1]:
# Autoreload: before each cell runs, re-import any edited Python modules
# (e.g. the `sparse_autoencoder` package), so library changes are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
from pathlib import Path

import torch
import wandb
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
from transformers import PreTrainedTokenizerBase

from sparse_autoencoder import TensorActivationStore, SparseAutoencoder, pipeline
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset

In [3]:
# Runtime configuration.
# Fix the RNG seed so the autoencoder's random weight initialisation (and any
# other torch randomness) is reproducible across runs. torch.manual_seed seeds
# the CPU and all CUDA devices.
random_seed = 42
torch.manual_seed(random_seed)

# Pick the best available accelerator (CUDA in the saved run).
device = get_device()
device

device(type='cuda')

### Source Model

In [4]:
# Source model: the one-layer SoLU transformer from TransformerLens, loaded in
# float32. Its MLP width (`d_mlp`) is the input size of the autoencoder below.
src_model_name = "solu-1l"
src_model = HookedTransformer.from_pretrained(src_model_name, dtype="float32")

src_d_mlp: int = src_model.cfg.d_mlp  # type: ignore
src_d_mlp

Loaded pretrained model solu-1l into HookedTransformer


2048

### Source Dataset

In [5]:
# Build the text dataset with the source model's own tokenizer (presumably so
# the token ids fed to `src_model` match its vocabulary; TODO confirm how the
# dataset uses it).
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore
source_data = PileUncopyrightedDataset(
    tokenizer=tokenizer,
)

Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

### Activation Store

In [6]:
# Activation buffer capacity. The same value is later passed to `pipeline` as
# `num_activations_before_training`, so it is also the number of activations
# generated per generate/train cycle.
# NOTE(review): if the store preallocates max_items x src_d_mlp float32 values
# on `device`, this is ~8.2 GB for d_mlp = 2048. Confirm and lower it for
# smaller GPUs.
max_items = 1_000_000

store = TensorActivationStore(
    max_items,
    src_d_mlp,
    device,
)

### Autoencoder

In [7]:
# Number of learned features as a multiple of the source MLP width
# (8 x 2048 = 16384 learned features for solu-1l, as shown in the repr below).
autoencoder_expansion_factor = 8
autoencoder_n_learned_features = src_d_mlp * autoencoder_expansion_factor

# Third constructor argument: a length-d_mlp vector of zeros. Presumably the
# initial value of the tied bias (see the TiedBias layers in the repr).
# TODO confirm against the SparseAutoencoder signature.
autoencoder_initial_bias = torch.zeros(src_d_mlp)

autoencoder = SparseAutoencoder(
    src_d_mlp,
    autoencoder_n_learned_features,
    autoencoder_initial_bias,
)
autoencoder.to(device)

SparseAutoencoder(
  (encoder): Sequential(
    (TiedBias): TiedBias(position=pre_encoder)
    (Linear): Linear(in_features=2048, out_features=16384, bias=False)
    (ReLU): ReLU()
  )
  (decoder): Sequential(
    (ConstrainedUnitNormLinear): ConstrainedUnitNormLinear(in_features=16384, out_features=2048, bias=False)
    (TiedBias): TiedBias(position=post_decoder)
  )
)

## Training

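Logging is optional: if a [wandb](https://wandb.ai/site) run has been initialised (next cell), the pipeline automatically logs all of its metrics to that run. Skip the next cell to train without wandb logging.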

In [8]:
# Local directory where wandb writes its run files. Create it up front so a
# fresh checkout does not rely on `.cache/wandb` already existing.
# NOTE(review): some wandb versions fall back to a temp dir when `dir` is
# missing or unwritable.
wandb_dir = Path(".cache/wandb")
wandb_dir.mkdir(parents=True, exist_ok=True)

wandb.init(project="sparse-autoencoder", dir=str(wandb_dir))

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjbloom[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# Transformer layer whose MLP post-activation output is used as training data.
# The hook-point name and the layer index passed to `pipeline` must refer to
# the same block, so both are derived from this one value.
src_model_layer = 0
src_model_hook_point = f"blocks.{src_model_layer}.mlp.hook_post"

# Long-running: alternates "Generate Activations" and "Train Autoencoder"
# phases. The saved run was stopped manually (KeyboardInterrupt).
pipeline(
    src_model=src_model,
    src_model_activation_hook_point=src_model_hook_point,
    src_model_activation_layer=src_model_layer,
    source_dataset=source_data,
    activation_store=store,
    num_activations_before_training=max_items,
    autoencoder=autoencoder,
    device=device,
)

Generate/Train Cycles: 0it [00:00, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1000000 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1000000 [00:00<?, ?it/s]

KeyboardInterrupt: 

In [None]:
# Close the wandb run started earlier so its remaining data is uploaded and
# the run is marked finished.
wandb.finish()



VBox(children=(Label(value='0.011 MB of 0.011 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))