In [1]:
import torch
import os

from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Choose the accelerator in priority order: CUDA GPU, then Apple-Silicon MPS,
# and finally plain CPU. The MPS probe only runs when CUDA is unavailable.
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

print("Using device:", device)

# Turn off parallelism in the Hugging Face tokenizers library for this process.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

Using device: cuda


In [None]:
# --- Training budget -------------------------------------------------------
batch_size = 4096  # tokens consumed per SAE training step
total_training_steps = 100_000  # probably we should do more
total_training_tokens = batch_size * total_training_steps  # 409.6M tokens

# --- Schedule lengths, all expressed in training steps ----------------------
lr_warm_up_steps = 1_000
l1_warm_up_steps = total_training_steps // 20  # 5% of training
lr_decay_steps = total_training_steps // 5  # 20% of training

# Full configuration for training a JumpReLU sparse autoencoder on the
# residual stream of GPT-2 small. Values that depend on the training budget
# (batch size, step counts) are the notebook variables defined just above.
cfg = LanguageModelSAERunnerConfig(
    # ---- Data-generating function (model + training distribution) ----
    # Model to hook; other options are listed at
    # https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
    model_name="gpt2-small",
    # Hook point whose activations the SAE learns to reconstruct; see
    # https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points
    hook_name="blocks.8.hook_resid_pre",
    hook_layer=8,  # layer index matching the block number in hook_name
    d_in=768,  # the input dimension (width of the hooked activation)
    # Pre-tokenized OpenWebText (GPT-2 tokenizer) hosted on Hugging Face.
    dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-gpt2",
    is_dataset_tokenized=True,
    context_size=1024,
    streaming=True,  # stream the dataset instead of pre-downloading it
    # ---- SAE parameters ----
    architecture="jumprelu",
    mse_loss_normalization=None,  # the MSE loss is left un-normalized
    # Width multiplier: d_sae = expansion_factor * d_in = 1536 here. Larger
    # will result in better stats but slower training.
    expansion_factor=2,
    b_dec_init_method="zeros",  # decoder bias starts at zero (no geometric-median init)
    apply_b_dec_to_input=False,  # the decoder bias is not applied to the SAE input
    normalize_sae_decoder=True,
    scale_sparsity_penalty_by_decoder_norm=False,
    decoder_heuristic_init=False,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="none",
    # ---- Training parameters ----
    threshold_init_value=0.001,
    lr=1e-4,
    adam_beta1=0.9,  # Adam betas (defaults, but once upon a time we experimented with these)
    adam_beta2=0.999,
    # Cosine-annealing schedule, combined with the warm-up / decay step
    # counts passed on the next two lines.
    lr_scheduler_name="cosineannealing",
    lr_warm_up_steps=lr_warm_up_steps,  # this can help avoid too many dead features initially
    lr_decay_steps=lr_decay_steps,
    l1_coefficient=1,  # sparsity penalty weight: controls how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially
    lp_norm=1.0,  # the L1 penalty (and not an Lp for p < 1)
    train_batch_size_tokens=batch_size,
    # Controls how many activations we store / shuffle. With 16 prompts of
    # 1024 tokens per stored batch this is ~1.05M tokens per buffer.
    n_batches_in_buffer=64,
    # 100_000 steps x 4096 tokens = 409.6M tokens. This takes hours on a GPU.
    # Get a coffee, come back.
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
    # ---- Resampling protocol ----
    use_ghost_grads=False,  # we don't use ghost grads anymore
    feature_sampling_window=1000,  # steps between resets of the feature-sparsity stats we report
    dead_feature_window=1000,  # would affect resampling or ghost grads if we were using it
    dead_feature_threshold=1e-4,  # would affect resampling or ghost grads if we were using it
    # ---- Weights & Biases logging ----
    log_to_wandb=True,  # always use wandb unless you are just testing code
    wandb_project="jumprelu_sae_768",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
    # ---- Misc ----
    device=device,
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype="float32",
)

# Kick off training. This call blocks until the run finishes; the next cell
# has some instructions for what to do while it is running.
sparse_autoencoder = SAETrainingRunner(cfg).run()

Run name: 1536-L1-1-LR-0.0001-Tokens-4.096e+08
n_tokens_per_buffer (millions): 1.048576
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 100000
Total wandb updates: 3333
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 100 times.
Number tokens in sparsity calculation window: 4.10e+06




Loaded pretrained model gpt2-small into HookedTransformer


Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

[34m[1mwandb[0m: Currently logged in as: [33msriramb[0m. Use [1m`wandb login --relogin`[0m to force relogin


  self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
93000| MSE Loss 0.064 | L1 8861.385:  93%|█████████▎| 380928000/409600000 [3:58:52<22:11, 21531.14it/s]  