In [2]:
import json
import os
from pathlib import Path

import torch
from transformer_lens import HookedTransformer
from transformers import PreTrainedTokenizerBase
import wandb

from sparse_autoencoder import (
    ActivationResampler,
    AdamWithReset,
    L2ReconstructionLoss,
    LearnedActivationsL1Loss,
    LossReducer,
    Pipeline,
    PreTokenizedDataset,
    SparseAutoencoder,
)


# Turn off Hugging Face tokenizers' internal parallelism. This is commonly done to avoid
# fork-related warnings once dataloaders start worker processes.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Training realistically needs a GPU; fall back to CPU so the notebook still runs.
device_type = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_type)
print(f"Using device: {device}")

  from .autonotebook import tqdm as notebook_tqdm


Using device: cuda


In [3]:
torch.random.manual_seed(49)

hyperparameters = {
    # Expansion factor is the number of features in the sparse representation, relative to the
    # number of features in the original MLP layer. The original paper experimented with 1x to 256x,
    # and we have found that 4x is a good starting point.
    "expansion_factor": 4,
    # L1 coefficient is the coefficient of the L1 regularization term (used to encourage sparsity).
    "l1_coefficient": 1e-3,
    # Adam parameters (set to the default ones here)
    "lr": 3e-4,
    "adam_beta_1": 0.9,
    "adam_beta_2": 0.999,
    "adam_epsilon": 1e-8,
    "adam_weight_decay": 0.0,
    # Batch sizes
    "train_batch_size": 4096,
    "context_size": 128,
    # Source model hook point
    "source_model_name": "EleutherAI/Pythia-70M-deduped",
    "source_model_dtype": "float32",
    "source_model_hook_point": "blocks.0.hook_mlp_out",
    "source_model_hook_point_layer": 0,
    # Train pipeline parameters
    "max_store_size": 384 * 4096 * 2,
    "max_activations": 2_000_000_000,
    "resample_frequency": 122_880_000,
    "checkpoint_frequency": 100_000_000,
    "validation_frequency": 384 * 4096 * 2 * 100,  # Every 100 generations
}

In [4]:
# Source model setup with TransformerLens. The configured dtype is passed explicitly so the
# "source_model_dtype" hyperparameter (which is also logged to W&B) actually takes effect.
src_model = HookedTransformer.from_pretrained(
    str(hyperparameters["source_model_name"]),
    dtype=str(hyperparameters["source_model_dtype"]),
)

# Width of the hooked activations (the model's d_model): this is the autoencoder's input size.
autoencoder_input_dim: int = src_model.cfg.d_model  # type: ignore (TransformerLens typing is currently broken)

# Summary of what the autoencoder will be trained on (shown as the cell output). Implicit
# string concatenation avoids the runs of spaces a backslash-continued f-string would embed.
(
    f"Source: {hyperparameters['source_model_name']}, "
    f"Hook: {hyperparameters['source_model_hook_point']}, "
    f"Features: {autoencoder_input_dim}"
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/Pythia-70M-deduped into HookedTransformer


'Source: EleutherAI/Pythia-70M-deduped,     Hook: blocks.0.hook_mlp_out,     Features: 512'

In [5]:
expansion_factor = hyperparameters["expansion_factor"]
# Size of the learned sparse dictionary: input activation width times the expansion factor.
n_learned_features = int(autoencoder_input_dim * expansion_factor)
# An all-zeros stand-in for the geometric-median dataset, used to initialize the tied bias.
tied_bias_init = torch.zeros(autoencoder_input_dim)

autoencoder = SparseAutoencoder(
    n_input_features=autoencoder_input_dim,
    n_learned_features=n_learned_features,
    geometric_median_dataset=tied_bias_init,
).to(device)
autoencoder  # Display the module structure

SparseAutoencoder(
  (_pre_encoder_bias): TiedBias(position=pre_encoder)
  (_encoder): LinearEncoder(
    in_features=512, out_features=2048
    (activation_function): ReLU()
  )
  (_decoder): UnitNormDecoder(in_features=2048, out_features=512)
  (_post_decoder_bias): TiedBias(position=post_decoder)
)

In [6]:
# The total loss is the sum produced by a LossReducer over its component losses:
# an L1 sparsity penalty on the learned activations, plus the L2 reconstruction error.
l1_penalty = LearnedActivationsL1Loss(
    l1_coefficient=float(hyperparameters["l1_coefficient"]),
)
reconstruction_loss = L2ReconstructionLoss()
loss = LossReducer(l1_penalty, reconstruction_loss)
loss

LossReducer(
  (0): LearnedActivationsL1Loss(l1_coefficient=0.001)
  (1): L2ReconstructionLoss()
)

In [7]:
# Adam variant from sparse_autoencoder. It receives named_parameters in addition to params.
# NOTE(review): presumably this lets it locate and reset optimizer state for specific
# parameters (e.g. after resampling); confirm in the library docs.
adam_betas = (
    float(hyperparameters["adam_beta_1"]),
    float(hyperparameters["adam_beta_2"]),
)
optimizer = AdamWithReset(
    params=autoencoder.parameters(),
    named_parameters=autoencoder.named_parameters(),
    lr=float(hyperparameters["lr"]),
    betas=adam_betas,
    eps=float(hyperparameters["adam_epsilon"]),
    weight_decay=float(hyperparameters["adam_weight_decay"]),
)
optimizer

AdamWithReset (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.0003
    maximize: False
    weight_decay: 0.0
)

In [8]:
# Resampler built with the library defaults. How often it is applied is controlled by the
# resample_frequency argument passed to run_pipeline.
# NOTE(review): presumably it re-initializes learned features that stop activating; confirm
# in the sparse_autoencoder docs.
activation_resampler = ActivationResampler()

In [9]:
# The source model's tokenizer, annotated for type checkers.
# NOTE(review): no later cell in this notebook uses it.
tokenizer: PreTrainedTokenizerBase = src_model.tokenizer  # type: ignore

# Pre-tokenized source dataset, identified by its hub path. The context size comes from the
# hyperparameters (presumably the token length of each sequence; confirm against
# PreTokenizedDataset).
dataset_context_size = int(hyperparameters["context_size"])
source_data = PreTokenizedDataset(
    dataset_path="NeelNanda/c4-code-tokenized-2b",
    context_size=dataset_context_size,
)

In [10]:
# Authenticate with Weights & Biases without hardcoding the key. Use WANDB_API_KEY from the
# environment when it is set; otherwise read it from a local secrets.json of the form
# {"wandb_key": "..."}.
wandb_key = os.environ.get("WANDB_API_KEY")
if not wandb_key:
    # `with` guarantees the file handle is closed once the key is read.
    with open("secrets.json", encoding="utf-8") as secrets_file:
        wandb_key = json.load(secrets_file)["wandb_key"]
wandb.login(key=wandb_key)

# Directory where the pipeline writes autoencoder checkpoints.
checkpoint_path = Path("../../.checkpoints")
checkpoint_path.mkdir(exist_ok=True)
# Local directory for W&B run files (passed as `dir` below).
Path(".cache/").mkdir(exist_ok=True)
wandb.init(
    project="sparse-autoencoder",
    dir=".cache",
    config=hyperparameters,
)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33melriggs[0m ([33msparse_coding[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
# Batch size used when drawing from the source dataset to generate activations.
# NOTE(review): presumably a number of sequences per source-model forward pass; confirm in
# the Pipeline docs.
source_data_batch_size = 6

pipeline = Pipeline(
    activation_resampler=activation_resampler,
    autoencoder=autoencoder,
    cache_name=str(hyperparameters["source_model_hook_point"]),
    checkpoint_directory=checkpoint_path,
    layer=int(hyperparameters["source_model_hook_point_layer"]),
    loss=loss,
    optimizer=optimizer,
    source_data_batch_size=source_data_batch_size,
    source_dataset=source_data,
    source_model=src_model,
)

# NOTE(review): a recorded run of this cell failed in the "generate" stage with
# `AttributeError: 'list' object has no attribute 'to'`. This suggests the source dataloader
# yields Python lists where tensors are expected. TODO investigate.
try:
    pipeline.run_pipeline(
        train_batch_size=int(hyperparameters["train_batch_size"]),
        max_store_size=int(hyperparameters["max_store_size"]),
        max_activations=int(hyperparameters["max_activations"]),
        resample_frequency=int(hyperparameters["resample_frequency"]),
        checkpoint_frequency=int(hyperparameters["checkpoint_frequency"]),
        validate_frequency=int(hyperparameters["validation_frequency"]),
    )
except Exception:
    # "Run All" stops at the first failing cell, so without this the run opened by
    # wandb.init(...) would never be finished. Mark it failed, close it, and re-raise.
    wandb.finish(exit_code=1)
    raise

Activations trained on:   0%|          | 0/2000000000 [00:07<?, ?it/s, stage=generate]


AttributeError: 'list' object has no attribute 'to'

In [12]:
# Debugging aid: build a dataloader over the source dataset with the same batch size (6) that
# is given to the Pipeline. Its batches can then be inspected directly while investigating
# the AttributeError raised during the pipeline's "generate" stage.
# NOTE(review): remove this cell once the pipeline runs cleanly.
source_dataloader = source_data.get_dataloader(6)

In [None]:
# Finish the W&B run started by wandb.init(...) earlier in the notebook.
wandb.finish()