In [1]:
# Reload edited Python modules (e.g. the local sae_training package) automatically
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import os
import sys

# Environment flags must be set before the libraries that read them are imported,
# which is why the sae_training imports come after this.
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # disable HF tokenizers' parallelism (avoids fork warnings)
os.environ["WANDB__SERVICE_WAIT"] = "300"  # seconds to wait for the wandb service to start

dataset_path = 'taufeeque/othellogpt'
model_name = 'othello-gpt'
device = "cuda" if torch.cuda.is_available() else "cpu"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Sweep grid: one SAE is trained (and logged to wandb) per (L1, expansion) pair.
L1_COEFFICIENTS = [0.0002, 0.0001]  # previously also tried: 0.001
EXPANSION_FACTORS = [0.5, 1, 4, 16, 64]

HOOK_LAYER = 6  # single source of truth for both hook_point and hook_point_layer
D_IN = 512  # width of the hooked activations; NOTE(review): presumably OthelloGPT's d_model — confirm
TOTAL_TRAINING_TOKENS = 100_000_000  # prev: 10_000_000; use 1_000_000 for a quick smoke run

for l1_coefficient in L1_COEFFICIENTS:
    for exp_factor in EXPANSION_FACTORS:
        # NOTE(review): exp_factor=0.5 makes d_in * exp_factor a float (256.0);
        # confirm the config casts the dictionary size to int before it is used
        # as a tensor dimension.
        config = LanguageModelSAERunnerConfig(
            model_name=model_name,
            hook_point=f"blocks.{HOOK_LAYER}.hook_resid_pre",
            hook_point_layer=HOOK_LAYER,
            dataset_path=dataset_path,
            context_size=59,
            d_in=D_IN,
            n_batches_in_buffer=32,
            total_training_tokens=TOTAL_TRAINING_TOKENS,
            store_batch_size=32,
            device=device,
            seed=42,
            dtype=torch.float32,
            b_dec_init_method="geometric_median",
            expansion_factor=exp_factor,
            l1_coefficient=l1_coefficient,
            lr=0.00003,  # prev: 0.0003
            lr_scheduler_name="constantwithwarmup",
            lr_warm_up_steps=5000,
            train_batch_size=4096,
            use_ghost_grads=True,
            feature_sampling_window=500,
            # NOTE(review): if this window is counted in training steps, 1_000_000
            # exceeds the ~24k steps this run takes (100M tokens / 4096 per batch),
            # so no feature would ever be flagged dead and ghost grads would never
            # fire — confirm this is intended.
            dead_feature_window=1_000_000,
            log_to_wandb=True,
            wandb_project="othello_gpt_sae",
            wandb_log_frequency=30,
            n_checkpoints=0,  # no intermediate checkpoints; only the final SAE is saved
            checkpoint_path="checkpoints",
            start_pos_offset=5,  # exclude the first 5 sequence positions
            end_pos_offset=-5,  # exclude the last 5 sequence positions
        )

        # Only the most recently trained SAE remains bound to this name.
        sparse_autoencoder = language_model_sae_runner(config)

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


Run name: 512-L1-0.0002-LR-3e-05-Tokens-1.000e+06
n_tokens_per_buffer (millions): 0.060416
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 244.0
Total wandb updates: 8.0
n_tokens_per_feature_sampling_window (millions): 120.832
n_tokens_per_dead_feature_window (millions): 241664.0
Using Ghost Grads.
We will reset the sparsity calculation 0.0 times.
Number tokens in sparsity calculation window: 2.05e+06
Loaded pretrained model othello-gpt into HookedTransformer
Moving model to device:  cuda


Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Dataset is tokenized! Updating config.


Objective value: 22.3585:  10%|█         | 10/100 [00:03<00:32,  2.74it/s]
  out = torch.tensor(out, dtype=self.dtype, device=self.device)


Reinitializing b_dec with geometric median of activations
Previous distances: 23.284273147583008
New distances: 22.28409194946289


244| MSE Loss 0.084 | L1 0.057: 100%|█████████▉| 999424/1000000.0 [00:18<00:00, 49082.79it/s]

Saved model to checkpoints/r1e9986g/final_sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_512.pt
Run name: 8192-L1-0.0002-LR-3e-05-Tokens-1.000e+06
n_tokens_per_buffer (millions): 0.060416
Lower bound: n_contexts_per_buffer (millions): 0.001024
Total training steps: 244.0
Total wandb updates: 8.0
n_tokens_per_feature_sampling_window (millions): 120.832
n_tokens_per_dead_feature_window (millions): 241664.0
Using Ghost Grads.
We will reset the sparsity calculation 0.0 times.
Number tokens in sparsity calculation window: 2.05e+06


244| MSE Loss 0.084 | L1 0.057: : 1003520it [00:34, 49082.79it/s]                            

Loaded pretrained model othello-gpt into HookedTransformer
Moving model to device:  cuda


Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/20 [00:00<?, ?it/s]

Dataset is tokenized! Updating config.


244| MSE Loss 0.084 | L1 0.057: : 1003520it [01:32, 10871.49it/s]
Objective value: 22.3597:  10%|█         | 10/100 [00:03<00:34,  2.60it/s]


Reinitializing b_dec with geometric median of activations
Previous distances: 23.286479949951172
New distances: 22.28495979309082


244| MSE Loss 0.078 | L1 0.226: : 1003520it [00:20, 66651.79it/s]                            

Saved model to checkpoints/vsrt2ynw/final_sparse_autoencoder_othello-gpt_blocks.6.hook_resid_pre_8192.pt


244| MSE Loss 0.078 | L1 0.226: : 1003520it [00:38, 66651.79it/s]

In [3]:
# Emit the equivalent command-line invocation for every point of the sweep grid,
# so the same runs can be launched as standalone script jobs instead of in-notebook.
for l1_coefficient in (0.0002, 0.0001):
    for exp_factor in (0.5, 1, 4, 16, 64):
        print(
            "python mats_sae_training/sae_simple_train.py",
            f"--l1_coefficient {l1_coefficient}",
            f"--exp_factor {exp_factor}",
        )

python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0002 --exp_factor 0.5
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0002 --exp_factor 1
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0002 --exp_factor 4
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0002 --exp_factor 16
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0002 --exp_factor 64
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0001 --exp_factor 0.5
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0001 --exp_factor 1
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0001 --exp_factor 4
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0001 --exp_factor 16
python mats_sae_training/sae_simple_train.py --l1_coefficient 0.0001 --exp_factor 64
