In [None]:
%load_ext autoreload
%autoreload 2
import sys
sys.path.append("/root/specialised-SAEs/")
from sae_lens.config import LanguageModelSAERunnerConfig
from sae_lens.training.training_sae import TrainingSAEConfig, TrainingSAE
from sae_lens.sae import SAE
from sae_lens.sae_training_runner import SAETrainingRunner
from sae_lens.training.activations_store import ActivationsStore
import logging
# `getLogger()` with no name returns the root logger; raising its level to
# ERROR hides DEBUG/INFO/WARNING records that propagate up to it, which keeps
# the cell output of the training cells below readable.
logger = logging.getLogger()
logger.setLevel(logging.ERROR)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [None]:
# ---------------------------------------------------------------------------
# Training budget
# ---------------------------------------------------------------------------
total_training_steps = 5_000  # short run; probably worth increasing
batch_size = 4 * 4096  # tokens per SAE training step (16,384)
total_training_tokens = batch_size * total_training_steps  # 81.92M tokens

# Schedule lengths, all measured in training steps.
lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

expansion_factor = 4  # SAE width = expansion_factor * d_in

# Hyperparameters of this particular run (one point of what could become a
# sweep over l1_coefficient / control_mixture).
l1_coefficient = 20
control_mixture = 0.2

# The hook point and its layer index must agree, so both come from one value.
hooked_layer_index = 8

# ---------------------------------------------------------------------------
# Keyword arguments for LanguageModelSAERunnerConfig, grouped by concern
# ---------------------------------------------------------------------------

# Fork-specific ("JACOB") options for training a specialised SAE.
# NOTE(review): semantics below are inferred from the argument names -- confirm
# against the forked sae_lens: `gsae_path` presumably points at a previously
# trained general SAE checkpoint, and `control_mixture` is presumably the
# fraction of training data drawn from the control dataset.
specialisation_kwargs = dict(
    gsae_path="/root/SSAE-training/sae_lens/jacob/checkpoints/gsae_gpt2_l1=58/final_409600000",
    control_dataset_path="NeelNanda/openwebtext-tokenized-9b",
    is_control_dataset_tokenized=True,  # control dataset is already tokenized
    control_mixture=control_mixture,
)

# Domain dataset the specialised SAE is trained on.
dataset_kwargs = dict(
    dataset_path="jacobcd52/physics-papers",
    is_dataset_tokenized=False,  # raw text; needs tokenizing
    streaming=True,  # stream the dataset instead of downloading it up front
)

# Model and activation source.
# Valid model names: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html
# Hook points: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points
model_kwargs = dict(
    model_name="gpt2-small",
    hook_name=f"blocks.{hooked_layer_index}.hook_resid_pre",  # residual stream entering block 8
    hook_layer=hooked_layer_index,  # must match the block index in hook_name
    d_in=768,  # width of the hooked activations (gpt2-small's residual stream)
)

# SAE architecture and initialisation.
sae_kwargs = dict(
    architecture="gated",  # gated SAE variant
    mse_loss_normalization=None,  # MSE loss is left un-normalised
    expansion_factor=expansion_factor,  # 4 * 768 = 3072 latents
    b_dec_init_method="zeros",  # start the decoder bias at zero
    apply_b_dec_to_input=True,  # enabled -- presumably subtracts b_dec from the SAE input; confirm in the fork
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=False,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    # normalize_activations is deliberately left at the fork's default (JACOB).
)

# Optimiser, schedules and sparsity penalty.
optimisation_kwargs = dict(
    lr=5e-5,
    adam_beta1=0.9,  # Adam defaults
    adam_beta2=0.999,
    lr_scheduler_name="constant",  # constant base schedule; warm-up is 0 steps here
    lr_warm_up_steps=lr_warm_up_steps,
    lr_decay_steps=lr_decay_steps,
    l1_coefficient=l1_coefficient,  # strength of the sparsity penalty
    l1_warm_up_steps=l1_warm_up_steps,  # ramping the penalty in can reduce early dead features
    lp_norm=1.0,  # a true L1 penalty (not Lp with p < 1)
    train_batch_size_tokens=batch_size,
    training_tokens=total_training_tokens,
)

# Activation store: 64 batches * 16 prompts * 256 tokens = 262,144 tokens per buffer.
activation_store_kwargs = dict(
    context_size=256,  # tokens per prompt fed to the model
    n_batches_in_buffer=64,  # how many batches of activations are stored / shuffled
    store_batch_size_prompts=16,
)

# Dead-feature handling and sparsity reporting.
dead_feature_kwargs = dict(
    use_ghost_grads=False,  # ghost grads are not used
    feature_sampling_window=1000,  # window for the reported feature-sparsity stats
    dead_feature_window=1000,  # would only matter for resampling / ghost grads
    dead_feature_threshold=1e-4,  # would only matter for resampling / ghost grads
)

# Weights & Biases logging.
wandb_kwargs = dict(
    log_to_wandb=True,  # switch off only when just testing code
    wandb_project="physics-SSAE-gpt2",
    wandb_log_frequency=30,
    eval_every_n_wandb_logs=20,
)

# Runtime, reproducibility and checkpointing.
runtime_kwargs = dict(
    device="cuda",
    seed=42,
    n_checkpoints=10,
    checkpoint_path=f"phys_gpt2_ssae_checkpoints_l1_coeff={l1_coefficient}_expansion={expansion_factor}_control_mixture={control_mixture}",
    dtype="float32",
)

cfg = LanguageModelSAERunnerConfig(
    **specialisation_kwargs,
    **dataset_kwargs,
    **model_kwargs,
    **sae_kwargs,
    **optimisation_kwargs,
    **activation_store_kwargs,
    **dead_feature_kwargs,
    **wandb_kwargs,
    **runtime_kwargs,
)

# Constructing the runner loads the model and opens the datasets.
# Despite its name, `ssae` is the SAETrainingRunner, not the SAE itself.
# NOTE(review): the saved output of this cell ends in
#   TypeError: TrainingSAEConfig.__init__() got an unexpected keyword argument
#   'is_control_dataset_tokenized'
# The rejected key is one this config sets, but the exception is raised by
# TrainingSAEConfig, so the fix most likely belongs in the sae_lens fork
# rather than in this cell -- confirm there.
ssae = SAETrainingRunner(cfg)
activation_store = ssae.activations_store


Run name: 3072-L1-20-LR-5e-05-Tokens-8.192e+07
n_tokens_per_buffer (millions): 0.262144
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 5000
Total wandb updates: 166
n_tokens_per_feature_sampling_window (millions): 4194.304
n_tokens_per_dead_feature_window (millions): 4194.304
We will reset the sparsity calculation 5 times.
Number tokens in sparsity calculation window: 1.64e+07




config.json:   0%|          | 0.00/665 [00:00<?, ?B/s]

  _torch_pytree._register_pytree_node(


model.safetensors:   0%|          | 0.00/548M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Loaded pretrained model gpt2-small into HookedTransformer


Downloading readme:   0%|          | 0.00/281 [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/386 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/37 [00:00<?, ?it/s]

yes control dataset!


TypeError: TrainingSAEConfig.__init__() got an unexpected keyword argument 'is_control_dataset_tokenized'