# Gelu-2L

In [None]:
import torch
import os 

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

import cProfile


# Set TOKENIZERS_PARALLELISM to "false" for this process before training starts.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run configuration for an SAE on `blocks.0.hook_mlp_out` of gelu-2l.
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="gelu-2l",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    # SAE parameters
    expansion_factor=32,
    # Training parameters
    lr=1e-4,
    l1_coefficient=3e-4,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=128,
    total_training_tokens=500_000_000,
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=2500,
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=1e-8,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_models_gelu_2l",
    wandb_entity=None,
    wandb_log_frequency=10,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Launch the run; the result stays bound at notebook level for later cells.
sparse_autoencoder = language_model_sae_runner(cfg)


# GPT2 - Small

In [None]:
import torch
import os

# Environment is configured before the sae_training imports, as in the original cell.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Run configuration for an SAE on `blocks.10.hook_resid_pre` of gpt2-small.
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.hook_resid_pre",
    # Was 11, which disagreed with the block index in `hook_point`; the other
    # blocks.10 training configs in this notebook pass 10.
    # NOTE(review): confirm against how the runner consumes hook_point_layer.
    hook_point_layer=10,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    # SAE parameters
    expansion_factor=64,  # determines the dimension of the SAE.
    # Training parameters
    lr=1e-5,
    l1_coefficient=5e-4,
    lr_scheduler_name=None,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=128,
    total_training_tokens=1_000_000 * 200,  # 200M tokens seems doable overnight.
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=5000,
    dead_feature_threshold=1e-7,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small",
    wandb_entity=None,
    wandb_log_frequency=50,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Launch the run; the result stays bound at notebook level for later cells.
sparse_autoencoder = language_model_sae_runner(cfg)


# GPT2-Small Hook Q

In [1]:
import torch
import os
import sys

# Extend the import path with the parent directory before importing
# sae_training (presumably where the package lives — verify for your checkout).
sys.path.append("../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner


# Disabled sweep over the L1 coefficient, kept for reference:
# for l1_coefficient in [9e-4,8e-4,7e-4]:
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=True,
    cached_activations_path="../activations/",
    # SAE parameters
    expansion_factor=64 * 16,  # 1024x expansion; with d_in = 64 that is 65536
    # Training parameters
    lr=1e-3,
    l1_coefficient=4e-4,
    # lr_scheduler_name="LinearWarmupDecay",
    lr_warm_up_steps=2200,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=512,
    # Stop 2.5M tokens short of 50M to avoid having to use a buffer we don't have.
    total_training_tokens=1_000_000 * 50 - 2_500_000,
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=3000,
    dead_feature_threshold=5e-6,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small_hook_q_new",
    wandb_entity=None,
    wandb_log_frequency=30,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=15,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Launch the run; the result stays bound at notebook level for later cells.
sparse_autoencoder = language_model_sae_runner(cfg)


  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


n_tokens_per_buffer (millions): 2.097152
Lower bound: n_contexts_per_buffer (millions): 0.016384
Total training steps: 11596
Total wandb updates: 386
n_dead_feature_samples: 2
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps
Dataset is not tokenized! Updating config.


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjbloom[0m. Use [1m`wandb login --relogin`[0m to force relogin


773| MSE Loss 17.911 | L1 23.663:   7%|▋         | 3170304/47500000 [02:44<31:59, 23089.88it/s]  

Saved model to checkpoints/80yd4j7z/3170304_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


1546| MSE Loss 16.244 | L1 23.661:  13%|█▎        | 6336512/47500000 [05:58<28:54, 23735.99it/s] 

Saved model to checkpoints/80yd4j7z/6336512_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


2319| MSE Loss 15.766 | L1 22.999:  20%|██        | 9502720/47500000 [08:19<27:39, 22895.40it/s] 

Saved model to checkpoints/80yd4j7z/9502720_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


2998| MSE Loss 15.226 | L1 22.680:  26%|██▌       | 12283904/47500000 [10:21<25:22, 23133.12it/s] 

Resampled 1 neurons


3092| MSE Loss 15.455 | L1 22.170:  27%|██▋       | 12668928/47500000 [10:42<23:29, 24711.16it/s] 

Saved model to checkpoints/80yd4j7z/12668928_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


3865| MSE Loss 14.881 | L1 21.697:  33%|███▎      | 15835136/47500000 [12:59<21:14, 24841.11it/s] 

Saved model to checkpoints/80yd4j7z/15835136_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


4638| MSE Loss 14.449 | L1 21.378:  40%|████      | 19001344/47500000 [15:20<19:12, 24733.38it/s] 

Saved model to checkpoints/80yd4j7z/19001344_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


5411| MSE Loss 14.558 | L1 20.535:  47%|████▋     | 22167552/47500000 [17:40<18:28, 22856.45it/s] 

Saved model to checkpoints/80yd4j7z/22167552_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


6184| MSE Loss 14.328 | L1 19.962:  53%|█████▎    | 25333760/47500000 [19:57<15:01, 24576.24it/s] 

Saved model to checkpoints/80yd4j7z/25333760_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


6958| MSE Loss 14.363 | L1 19.510:  60%|██████    | 28504064/47500000 [22:17<14:03, 22509.32it/s] 

Saved model to checkpoints/80yd4j7z/28504064_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


7731| MSE Loss 14.096 | L1 19.397:  67%|██████▋   | 31670272/47500000 [24:33<10:22, 25431.61it/s] 

Saved model to checkpoints/80yd4j7z/31670272_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


8504| MSE Loss 14.236 | L1 19.207:  73%|███████▎  | 34836480/47500000 [26:53<08:22, 25215.41it/s] 

Saved model to checkpoints/80yd4j7z/34836480_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


9277| MSE Loss 13.791 | L1 19.171:  80%|████████  | 38002688/47500000 [29:09<06:43, 23515.85it/s] 

Saved model to checkpoints/80yd4j7z/38002688_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


10050| MSE Loss 14.000 | L1 18.660:  87%|████████▋ | 41168896/47500000 [31:30<04:15, 24749.72it/s]

Saved model to checkpoints/80yd4j7z/41168896_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


10823| MSE Loss 14.000 | L1 18.748:  93%|█████████▎| 44335104/47500000 [33:50<02:04, 25329.29it/s]

Saved model to checkpoints/80yd4j7z/44335104_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


11596| MSE Loss 14.277 | L1 18.413: : 47501312it [36:05, 24876.35it/s]                            

Saved model to checkpoints/80yd4j7z/47501312_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


11596| MSE Loss 14.277 | L1 18.413: : 47501312it [36:06, 21927.73it/s]


Saved model to checkpoints/80yd4j7z/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_65536.pt


In [None]:
# Scratch arithmetic: (500M - 4096) / 4096. 4096 matches the train_batch_size
# used in the configs, so this is presumably a step-count estimate — confirm.
(500_000_000 - 4096) / 4096

In [None]:
# Save the model currently bound to `sparse_autoencoder` to a fixed local path.
# NOTE(review): this uses whichever run most recently assigned
# `sparse_autoencoder` in the notebook session; the file name suggests the
# gpt2-small resid_pre layer-10 run — confirm before executing.
sparse_autoencoder.save_model("./overnight_sae_resid_pre_10_gpt_2_small.pt")

# Pythia 70-M

In [None]:
import torch
import os 

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

import cProfile


# Set TOKENIZERS_PARALLELISM to "false" for this process before training starts.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run configuration for an SAE on `blocks.0.hook_mlp_out` of pythia-70m.
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="pythia-70m",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="EleutherAI/the_pile_deduplicated",
    is_dataset_tokenized=False,
    # SAE parameters
    expansion_factor=16,
    # Training parameters
    lr=3e-4,
    l1_coefficient=1e-3,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=64,
    total_training_tokens=5_000_000,
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=2500,  # Doesn't currently matter.
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=1e-8,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=10,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

def main() -> None:
    """Run SAE training for the module-level ``cfg``.

    Wrapping the run in a function lets it be invoked as a single call,
    which is how the commented-out cProfile snippet that follows uses it.
    The object returned by ``language_model_sae_runner`` is not kept, so
    this function returns ``None``.
    """
    # The original bound the result to an unused local; dropped here.
    language_model_sae_runner(cfg)


main()
# os.environ["TOKENIZERS_PARALLELISM"] = "false"
# import cProfile, pstats, io
# from pstats import SortKey
# pr = cProfile.Profile()
# pr.enable()
# # ... do something ...
# main()
# pr.disable()
# s = io.StringIO()
# sortby = SortKey.CUMULATIVE
# ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
# ps.print_stats()
# print(s.getvalue())


# Tiny Stories

In [None]:
import torch
import os 

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner



# Set TOKENIZERS_PARALLELISM to "false" for this process before training starts.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Run configuration for an SAE on `blocks.1.mlp.hook_post` of tiny-stories-2L-33M.
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="tiny-stories-2L-33M",
    hook_point="blocks.1.mlp.hook_post",
    hook_point_layer=1,
    d_in=4096,
    dataset_path="roneneldan/TinyStories",
    is_dataset_tokenized=False,
    # SAE parameters
    expansion_factor=4,
    # Training parameters
    lr=1e-4,
    l1_coefficient=3e-4,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=128,
    total_training_tokens=10_000_000,  # want 500M eventually.
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=2500,  # Doesn't currently matter.
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=0.0005,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=10,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Launch the run; the result stays bound at notebook level for later cells.
sparse_autoencoder = language_model_sae_runner(cfg)


# Toy Model

In [None]:

from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner


# Configuration for the toy-model SAE runner.
cfg = SAEToyModelRunnerConfig(
    # --- Model details ---
    n_features=200,
    n_hidden=5,
    n_correlated_pairs=0,
    n_anticorrelated_pairs=0,
    feature_probability=0.025,
    model_training_steps=10_000,

    # --- SAE parameters ---
    d_sae=240,
    l1_coefficient=0.001,

    # --- SAE training config ---
    # NOTE(review): 1028 is not a power of two; possibly 1024 was intended — confirm.
    train_batch_size=1028,
    feature_sampling_window=3_000,
    dead_feature_window=1_000,
    feature_reinit_scale=0.5,
    total_training_tokens=4096 * 300,

    # --- Other parameters ---
    log_to_wandb=True,
    wandb_project="sae-training-test",
    wandb_log_frequency=5,
    device="mps",
)

trained_sae = toy_model_sae_runner(cfg)

# Sanity check: the runner must hand back an object rather than None.
assert trained_sae is not None


# Run caching of activations to disk

In [None]:
import torch
import os
import sys

# Extend the import path with the parent directory before importing
# sae_training (presumably where the package lives — verify for your checkout).
sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import CacheActivationsRunnerConfig
from sae_training.cache_activations_runner import cache_activations_runner

# Configuration for writing `blocks.10.attn.hook_q` (head 7) activations of
# gpt2-small to `../activations/`.
cfg = CacheActivationsRunnerConfig(
    # Data generating function (model + training distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    # Was 11, which disagreed with the block index in `hook_point` and with the
    # hook_q training configs in this notebook that read from the same
    # `../activations/` directory (they pass 10).
    # NOTE(review): confirm against how the runner consumes hook_point_layer.
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    cached_activations_path="../activations/",
    # Activation store parameters
    n_batches_in_buffer=16,
    total_training_tokens=500_000_000,
    store_batch_size=32,
    # Activation caching shuffle parameters
    n_shuffles_final=16,
    # Misc
    device="mps",
    seed=42,
    dtype=torch.float32,
)

cache_activations_runner(cfg)


## Train an SAE using the cached activations stored on disk
Pass `use_cached_activations=True` in the config.

In [None]:
import torch
import os

# Environment is configured before the sae_training imports, as in the original cell.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Run configuration for training an SAE from activations cached on disk.
# NOTE(review): no `cached_activations_path` is passed, so the config default
# applies — confirm it points at the directory the caching run wrote to.
# NOTE(review): this config targets `hook_resid_pre` with d_in = 768, whereas
# the caching cell in this notebook stores `attn.hook_q` with d_in = 64 —
# confirm a matching cache exists.
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.hook_resid_pre",
    # Was 11, which disagreed with the block index in `hook_point`; the other
    # blocks.10 training configs in this notebook pass 10.
    # NOTE(review): confirm against how the runner consumes hook_point_layer.
    hook_point_layer=10,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=True,
    # SAE parameters
    expansion_factor=64,  # determines the dimension of the SAE.
    # Training parameters
    lr=1e-5,
    l1_coefficient=5e-4,
    lr_scheduler_name=None,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=64,
    total_training_tokens=200_000,
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=5000,
    dead_feature_threshold=1e-7,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small",
    wandb_entity=None,
    wandb_log_frequency=50,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Launch the run; the result stays bound at notebook level for later cells.
sparse_autoencoder = language_model_sae_runner(cfg)


# Scratch

In [1]:
import torch
import os
import sys

# Extend the import path with the parent directory before importing
# sae_training (presumably where the package lives — verify for your checkout).
sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner


# Disabled sweep over the L1 coefficient, kept for reference:
# for l1_coefficient in [9e-4,8e-4,7e-4]:
cfg = LanguageModelSAERunnerConfig(
    # Data generating function (model + training distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=True,
    cached_activations_path="../activations/",
    # SAE parameters
    # Determines the dimension of the SAE (64*64 = 4096, 64*4*64 = 16384).
    expansion_factor=64,
    # Training parameters
    lr=1e-3,
    l1_coefficient=2e-4,
    # lr_scheduler_name="LinearWarmupDecay",
    lr_warm_up_steps=2200,
    train_batch_size=4096,
    context_size=128,
    # Activation store parameters
    n_batches_in_buffer=512,
    total_training_tokens=3_000_000,
    store_batch_size=32,
    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=200,
    dead_feature_threshold=5e-6,
    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small_hook_q_dev",
    wandb_entity=None,
    wandb_log_frequency=5,
    # Misc
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# cfg.d_sae
sparse_autoencoder = language_model_sae_runner(cfg)
# assert sparse_autoencoder is not None

  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(
  _torch_pytree._register_pytree_node(


n_tokens_per_buffer (millions): 2.097152
Lower bound: n_contexts_per_buffer (millions): 0.016384
Total training steps: 732
Total wandb updates: 146
n_dead_feature_samples: 2
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps
Dataset is not tokenized! Updating config.


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjbloom[0m. Use [1m`wandb login --relogin`[0m to force relogin


732| MSE Loss 7.234 | L1 18.154: : 3002368it [01:45, 28515.37it/s]                           


Saved model to checkpoints/obxe76cw/final_sparse_autoencoder_gpt2-small_blocks.10.attn.hook_q_4096.pt


VBox(children=(Label(value='31.940 MB of 41.399 MB uploaded\r'), FloatProgress(value=0.771506892761734, max=1.…



0,1
details/current_learning_rate,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/l1_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,██▅▃▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▄▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▆▂▂▆▅▃▄▅██▇▆▆
metrics/ce_loss_with_ablation,▇▅▃▅▂▁▇▃▄█▂▆▆█
metrics/ce_loss_with_sae,▇▅▃▅▂▁█▃▄█▂▆▆█
metrics/ce_loss_without_sae,▇▅▃▅▂▁█▃▄█▂▆▆█
metrics/explained_variance,▁▁▄▆▇▇▇▇▇▇██████████████████████████████

0,1
details/current_learning_rate,0.001
details/n_training_tokens,2990080.0
losses/l1_loss,17.97697
losses/mse_loss,7.28971
losses/overall_loss,25.26668
metrics/CE_loss_score,0.74266
metrics/ce_loss_with_ablation,3.8
metrics/ce_loss_with_sae,3.74677
metrics/ce_loss_without_sae,3.72833
metrics/explained_variance,0.91121
