# Training Sparse Autoencoders (SAEs) for Copy Suppression Analysis

In [1]:
# Autoreload
%load_ext autoreload
%autoreload 2

import sys
sys.path.append('../')
import wandb


from sparse_autoencoder import TensorActivationStore, SparseAutoencoder, pipeline
from sparse_autoencoder.source_data.pile_uncopyrighted import PileUncopyrightedDataset
from sparse_autoencoder.train.sweep_config import SweepParametersRuntime
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device
from transformers import GPT2TokenizerFast

import torch

device = get_device()

# We want to work with GPT2 since it's small and has the copy suppression results.
model_name = "gpt2"
precision = "float32"
src_model = HookedTransformer.from_pretrained(
    model_name, dtype=precision  # gpt2 -> gpt2-small
)
src_d_model: int = src_model.cfg.d_model

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Make Source Data
source_data = PileUncopyrightedDataset(
    tokenizer=tokenizer,
    context_size=src_model.cfg.n_ctx,
)

max_items = 1_500_000  # max number of items in store
store = TensorActivationStore(max_items, src_d_model, device)
expansion_rate = 4  # 4x expansion
# Make Autoencoder|
in_width = src_model.cfg.d_model
n_features = src_model.cfg.d_model * expansion_rate  # 4x expansion
src_model_activation_hook_point = "blocks.10.hook_resid_pre"  # start with layer 10
autoencoder = SparseAutoencoder(in_width, n_features, torch.zeros(in_width))
autoencoder.to(device)

# hyper parameter
max_activations = 60* 1.5 * max_items
print(f"Training on {max_activations / 10**6} million tokens")

sweep_config = SweepParametersRuntime(
    lr=1e-3,
    batch_size=4096,
    l1_coefficient=5e-3,
)

config = sweep_config.__dict__
config = config | {
    "model_name": model_name,
    "precision": precision,
    "max_activations": max_activations,
    "src_model_activation_hook_point": src_model_activation_hook_point,
    "max_items": max_items,
    "n_features": n_features,
    "expansion_factor": expansion_rate,
}

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


Resolving data files:   0%|          | 0/30 [00:00<?, ?it/s]

Training on 135.0 million tokens


The autoencoder width (`n_features = expansion_rate * d_model`, a 4× expansion in this run) is an important hyperparameter.

In [2]:
wandb.init(project="sparse-autoencoder", dir=".cache/wandb", config=config)

# The pipeline wants the layer index in addition to the hook name (presumably so
# the forward pass can stop early — confirm). Derive it from the hook point
# ("blocks.<layer>.<hook>") so the two values can never disagree.
src_model_activation_layer = int(src_model_activation_hook_point.split(".")[1])

try:
    pipeline(
        src_model=src_model,
        src_model_activation_hook_point=src_model_activation_hook_point,
        src_model_activation_layer=src_model_activation_layer,
        source_dataset=source_data,
        activation_store=store,
        num_activations_before_training=max_items,
        sweep_parameters=sweep_config,
        log_artifacts=True,
        autoencoder=autoencoder,
        device=device,
        max_activations=max_activations,
    )
finally:
    # Close the W&B run even if training aborts (e.g. a dropped connection while
    # streaming the dataset), so the run is not left open.
    wandb.finish()

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjbloom[0m. Use [1m`wandb login --relogin`[0m to force relogin


Total activations trained on:   0%|          | 0/135000000.0 [00:00<?, ?it/s, Generate/train iterations=0]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1052 > 1024). Running this sequence through the model will result in indexing errors


Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

Train Autoencoder:   0%|          | 0/1490944 [00:00<?, ?it/s]

Generate Activations:   0%|          | 0/1490944 [00:00<?, ?it/s]

ChunkedEncodingError: ('Connection broken: IncompleteRead(5187998 bytes read, 55002 more expected)', IncompleteRead(5187998 bytes read, 55002 more expected))

In [None]:
# Cleanup: close the W&B run if the training cell raised before it could finish it.
wandb.finish()