# Attn-2L - Hook Q

In [None]:
import os
import sys

import torch

sys.path.append("..")

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

import cProfile

# Turn off HuggingFace tokenizers' internal parallelism for this kernel.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# SAE on the queries of head 3 in layer 0 of attn-only-2l.
attn_2l_hook_q_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="attn-only-2l",
    hook_point="blocks.0.attn.hook_q",
    hook_point_layer=0,
    hook_point_head_index=3,
    d_in=64,
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    # --- SAE architecture ---
    expansion_factor=4,
    # --- Optimisation ---
    lr=0.0012,
    lr_scheduler_name="constantwithwarmup",
    l1_coefficient=0.0016,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=128,
    total_training_tokens=50_000_000,
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='l2',
    feature_sampling_window=2500,
    feature_reinit_scale=0.2,
    dead_feature_window=15_000,
    dead_feature_threshold=1e-8,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_attn_only_2l",
    wandb_entity=None,
    wandb_log_frequency=20,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)
cfg = LanguageModelSAERunnerConfig(**attn_2l_hook_q_settings)

# Training is left disabled for this section; uncomment the line below to run it.
# sparse_autoencoder = language_model_sae_runner(cfg)


# Gelu-2L

In [9]:
import os
import sys

import torch

sys.path.append("..")

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Turn off HuggingFace tokenizers' internal parallelism for this kernel.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# SAE on the layer-0 MLP output of gelu-2l; SAE width = 512 * 128 = 65,536.
gelu_2l_mlp_out_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="gelu-2l",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    # --- SAE architecture ---
    expansion_factor=128,
    # --- Optimisation ---
    lr=0.0012,
    lr_scheduler_name="constantwithwarmup",
    l1_coefficient=0.00016,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=128,
    total_training_tokens=100_000_000,
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='anthropic',
    feature_sampling_window=5000,
    feature_reinit_scale=0.2,
    resample_batches=128,
    dead_feature_window=5000,
    dead_feature_threshold=1e-4,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_models_gelu_2l_test",
    wandb_entity=None,
    wandb_log_frequency=10,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)
cfg = LanguageModelSAERunnerConfig(**gelu_2l_mlp_out_settings)

# Training is left disabled for this section; uncomment the line below to run it.
# sparse_autoencoder = language_model_sae_runner(cfg)


Run name: 65536-L1-0.00016-LR-0.0012-Tokens-1.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 24414
Total wandb updates: 2441
We will reset neurons 3 times.
We will reset the sparsity calculation 3 times.
Number tokens in sparsity calculation window: 2.05e+07


# GPT2 - Small

In [12]:
import torch
import os
import sys

# Make `sae_training` (one directory up) importable when this cell runs on a
# fresh kernel, as the other cells in this notebook do.
sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
# Give the W&B background service longer to start up.
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# SAE on the residual stream entering block 10 of GPT-2 small.
cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name = "gpt2-small",
    hook_point = "blocks.10.hook_resid_pre",
    hook_point_layer = 10,
    d_in = 768,
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,
    
    # SAE Parameters
    expansion_factor = 64, # SAE width = 64 * 768 = 49,152.
    
    # Training Parameters
    lr = 0.0012,
    l1_coefficient = 0.00016,
    lr_scheduler_name=None,
    train_batch_size = 4096,
    context_size = 128,
    
    # Activation Store Parameters
    n_batches_in_buffer = 128,
    total_training_tokens = 1_000_000 * 500, # 500M tokens -> 122,070 steps of 4096 tokens.
    store_batch_size = 32,
    
    # Resampling protocol
    feature_sampling_method = 'anthropic',
    feature_sampling_window = 5000,
    feature_reinit_scale = 0.2,
    resample_batches=128,
    dead_feature_window=20000,
    dead_feature_threshold = 1e-6,
    
    # WANDB
    log_to_wandb = True,
    wandb_project= "mats_sae_training_gpt2_small_resid_pre",
    wandb_entity = None,
    wandb_log_frequency=100,
    
    # Misc
    device = "mps",
    seed = 42,
    n_checkpoints = 10,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
    )

sparse_autoencoder = language_model_sae_runner(cfg)


Run name: 49152-L1-0.00016-LR-0.0012-Tokens-5.000e+08
n_tokens_per_buffer (millions): 0.524288
Lower bound: n_contexts_per_buffer (millions): 0.004096
Total training steps: 122070
Total wandb updates: 1220
We will reset neurons 5 times.
We will reset the sparsity calculation 23 times.
Number tokens in sparsity calculation window: 2.05e+07
Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps
Dataset is not tokenized! Updating config.




VBox(children=(Label(value='11.447 MB of 11.447 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/current_learning_rate,▁▄▆█████████████████████████████████████
details/n_training_tokens,▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
losses/l1_loss,█▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,█▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,█▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/CE_loss_score,▁▆▆▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇█▇▇▇████▇████▇████████
metrics/ce_loss_with_ablation,▁▅█▅▅█▆▂▁█▃▅▅▅▄▃▇▂▄██▇▅▁▂▃▄▂▂▇▆▄▄▄▂▃█▆▃▇
metrics/ce_loss_with_sae,█▄▂▃▃▂▄▃▂▃▂▂▂▃▂▂▃▂▂▂▃▂▁▂▂▁▁▁▁▂▂▂▁▁▁▁▂▂▂▂
metrics/ce_loss_without_sae,▅▆▁▅▅▃█▆▂█▄▆▅▇▅▆█▅▆▅▇▆▄▅▅▄▃▄▃▄▆▂▄▅▃▄▆▅▆▆
metrics/explained_variance,▁▆▇▇▇▇▇▇▇███████████████████████████████

0,1
details/current_learning_rate,0.0012
details/n_training_tokens,28631040.0
losses/l1_loss,0.00283
losses/mse_loss,0.00274
losses/overall_loss,0.00557
metrics/CE_loss_score,0.92079
metrics/ce_loss_with_ablation,10.48372
metrics/ce_loss_with_sae,4.57367
metrics/ce_loss_without_sae,4.06526
metrics/explained_variance,0.90138


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011125159722157857, max=1.0…


[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Saved model to checkpoints/kdwz3z61/50003968_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Saved model to checkpoints/kdwz3z61/100003840_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

Saved model to checkpoints/kdwz3z61/150003712_sparse_autoencoder_gpt2-small_blocks.10.hook_resid_pre_49152.pt



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

# GPT2-Small Hook Q

## L10H7

In [None]:
import os
import sys

import torch

sys.path.append("../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# L1 penalties to sweep; runs are trained one after another in this order.
L1_SWEEP = (0.005, 0.003, 0.001, 0.008)

# Settings shared by every run of the sweep (only l1_coefficient varies).
# Target: queries of head 7 in layer 10 of GPT-2 small, read from cached activations.
hook_q_l10h7_base_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=True,
    cached_activations_path="../activations/",
    # --- SAE architecture ---
    expansion_factor=128,
    # --- Optimisation ---
    lr=0.003,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=1000,  # 1000 steps * 4096 tokens ~= 4M tokens.
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=128,
    # 500M tokens, the same amount the activation-caching cell writes to disk.
    total_training_tokens=500_000_000,
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='anthropic',
    feature_sampling_window=5000,  # Reportedly unused by the trainer at the moment.
    feature_reinit_scale=0.2,
    resample_batches=256,
    dead_feature_window=20000,
    dead_feature_threshold=1e-4,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small_hook_q_dev_3",
    wandb_entity=None,
    wandb_log_frequency=100,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=15,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

for l1 in L1_SWEEP:
    cfg = LanguageModelSAERunnerConfig(l1_coefficient=l1, **hook_q_l10h7_base_settings)
    sparse_autoencoder = language_model_sae_runner(cfg)

## L4H11

In [None]:
import os
import sys

import torch

sys.path.append("../")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# SAE on the queries of head 11 in layer 4 of GPT-2 small, computed on the fly.
hook_q_l4h11_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="gpt2-small",
    hook_point="blocks.4.attn.hook_q",
    hook_point_layer=4,
    hook_point_head_index=11,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=False,
    # cached_activations_path="../activations/",
    # --- SAE architecture ---
    expansion_factor=64,
    # --- Optimisation ---
    lr=0.0012,
    l1_coefficient=0.0016,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=1000,  # 1000 steps * 4096 tokens ~= 4M tokens.
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=128,
    total_training_tokens=1_500_000_000,
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='anthropic',
    feature_sampling_window=1000,  # Reportedly unused by the trainer at the moment.
    feature_reinit_scale=0.2,
    dead_feature_window=50000,
    dead_feature_threshold=1e-4,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small_hook_q_L4H11",
    wandb_entity=None,
    wandb_log_frequency=100,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=10,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)
cfg = LanguageModelSAERunnerConfig(**hook_q_l4h11_settings)

sparse_autoencoder = language_model_sae_runner(cfg)

# Pythia-70M

In [None]:
import os

import torch

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

import cProfile

# Turn off HuggingFace tokenizers' internal parallelism for this kernel.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Short (5M-token) benchmark run: SAE on the layer-0 MLP output of pythia-70m.
pythia_70m_mlp_out_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="pythia-70m",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="EleutherAI/the_pile_deduplicated",
    is_dataset_tokenized=False,
    # --- SAE architecture ---
    expansion_factor=16,
    # --- Optimisation ---
    lr=3e-4,
    l1_coefficient=1e-3,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=64,
    total_training_tokens=5_000_000,
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='l2',
    feature_sampling_window=2500,  # Reportedly has no effect at the moment.
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=1e-8,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=10,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)
cfg = LanguageModelSAERunnerConfig(**pythia_70m_mlp_out_settings)


sparse_autoencoder = language_model_sae_runner(cfg)


# Pythia 70M Hook Q

In [None]:
import os
import sys

import torch

sys.path.append("../")

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Turn off HuggingFace tokenizers' internal parallelism for this kernel.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# SAE on the queries of head 7 in layer 2 of pythia-70m-deduped.
pythia_hook_q_l2h7_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="pythia-70m-deduped",
    hook_point="blocks.2.attn.hook_q",
    hook_point_layer=2,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="EleutherAI/the_pile_deduplicated",
    is_dataset_tokenized=False,
    # --- SAE architecture ---
    expansion_factor=16,
    # --- Optimisation ---
    lr=0.0012,
    l1_coefficient=0.003,
    lr_scheduler_name="constantwithwarmup",
    lr_warm_up_steps=1000,  # 1000 steps * 4096 tokens ~= 4M tokens.
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=128,
    total_training_tokens=1_500_000_000,
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='anthropic',
    feature_sampling_window=1000,  # Reportedly unused by the trainer at the moment.
    feature_reinit_scale=0.2,
    resample_batches=8,
    dead_feature_window=60000,
    dead_feature_threshold=1e-5,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_pythia_70M_hook_q_L2H7",
    wandb_entity=None,
    wandb_log_frequency=100,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=15,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)
cfg = LanguageModelSAERunnerConfig(**pythia_hook_q_l2h7_settings)

sparse_autoencoder = language_model_sae_runner(cfg)


# Tiny Stories

In [None]:
import os

import torch

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Turn off HuggingFace tokenizers' internal parallelism for this kernel.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# SAE on the post-activation MLP neurons of layer 1 in tiny-stories-2L-33M.
tiny_stories_mlp_post_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="tiny-stories-2L-33M",
    hook_point="blocks.1.mlp.hook_post",
    hook_point_layer=1,
    d_in=4096,
    dataset_path="roneneldan/TinyStories",
    is_dataset_tokenized=False,
    # --- SAE architecture ---
    expansion_factor=4,
    # --- Optimisation ---
    lr=1e-4,
    l1_coefficient=3e-4,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store ---
    n_batches_in_buffer=128,
    total_training_tokens=10_000_000,  # Goal is eventually 500M tokens.
    store_batch_size=32,
    # --- Dead-feature resampling ---
    feature_sampling_method='l2',
    feature_sampling_window=2500,  # Reportedly has no effect at the moment.
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=0.0005,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=10,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)
cfg = LanguageModelSAERunnerConfig(**tiny_stories_mlp_post_settings)

sparse_autoencoder = language_model_sae_runner(cfg)


# Toy Model

In [None]:

from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner

# Toy setting: a model with 200 features and 5 hidden dimensions, followed by
# an SAE with 240 latents trained on its hidden activations.
toy_settings = dict(
    # --- Toy model ---
    n_features=200,
    n_hidden=5,
    n_correlated_pairs=0,
    n_anticorrelated_pairs=0,
    feature_probability=0.025,
    model_training_steps=10_000,
    # --- SAE ---
    d_sae=240,
    l1_coefficient=0.001,
    # --- SAE training ---
    train_batch_size=1028,  # NOTE(review): not a power of two; confirm 1024 was not intended.
    feature_sampling_window=3_000,
    dead_feature_window=1_000,
    feature_reinit_scale=0.5,
    total_training_tokens=1_228_800,  # 4096 * 300
    # --- Logging / device ---
    log_to_wandb=True,
    wandb_project="sae-training-test",
    wandb_log_frequency=5,
    device="mps",
)
cfg = SAEToyModelRunnerConfig(**toy_settings)

trained_sae = toy_model_sae_runner(cfg)

# Sanity check: the runner must hand back a trained SAE.
assert trained_sae is not None


# Cache activations to disk

In [None]:
import os
import sys

import torch

sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import CacheActivationsRunnerConfig
from sae_training.cache_activations_runner import cache_activations_runner

# Cache 500M tokens' worth of GPT-2 small L10H7 query activations under
# ../activations/ so later cells can train with use_cached_activations=True.
cache_settings = dict(
    # --- Data-generating function: model, hook point, training distribution ---
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,
    cached_activations_path="../activations/",
    # --- Activation store ---
    n_batches_in_buffer=16,
    total_training_tokens=500_000_000,
    store_batch_size=32,
    # --- Shuffling of the cached activations ---
    n_shuffles_final=16,
    # --- Miscellaneous ---
    device="mps",
    seed=42,
    dtype=torch.float32,
)
cfg = CacheActivationsRunnerConfig(**cache_settings)

cache_activations_runner(cfg)


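## Train an SAE using the cached activations stored on disk
Pass `use_cached_activations=True` into the config. The hook point (and `cached_activations_path`) must match the activations that were actually cached, or the runner will read the wrong data.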

In [None]:
import torch
import os
import sys

# Make `sae_training` (one directory up) importable on a fresh kernel,
# as the other cells in this notebook do.
sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Train an SAE on the block-10 residual stream, reading activations from disk.
# NOTE(review): no cached_activations_path is set, so the config default is used.
# The caching cell above writes blocks.10.attn.hook_q (head 7, d_in=64)
# activations to ../activations/, which do not match this resid_pre / d_in=768
# setup; confirm that activations for this hook point have been cached.
cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name = "gpt2-small",
    hook_point = "blocks.10.hook_resid_pre",
    hook_point_layer = 10, # must match the block index in hook_point above.
    d_in = 768,
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=True,
    
    # SAE Parameters
    expansion_factor = 64, # SAE width = 64 * 768 = 49,152.
    
    # Training Parameters
    lr = 1e-5,
    l1_coefficient = 5e-4,
    lr_scheduler_name=None,
    train_batch_size = 4096,
    context_size = 128,
    
    # Activation Store Parameters
    n_batches_in_buffer = 64,
    total_training_tokens = 200_000,
    store_batch_size = 32,
    
    # Resampling protocol
    feature_sampling_method = 'l2',
    feature_sampling_window = 1000,
    feature_reinit_scale = 0.2,
    dead_feature_window=5000,
    dead_feature_threshold = 1e-7,
    
    # WANDB
    log_to_wandb = True,
    wandb_project= "mats_sae_training_gpt2_small",
    wandb_entity = None,
    wandb_log_frequency=50,
    
    # Misc
    device = "mps",
    seed = 42,
    n_checkpoints = 5,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
    )

sparse_autoencoder = language_model_sae_runner(cfg)


# Scratch

In [None]:
import torch
import os 
import sys 
sys.path.append("..")
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"
from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner



# Scratch run: a short (3M-token) SAE on GPT-2 small L10H7 queries, read from
# the activations cached under ../activations/.
# for l1_coefficient in [9e-4,8e-4,7e-4]:
cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name = "gpt2-small",
    hook_point = "blocks.10.attn.hook_q",
    hook_point_layer = 10,
    hook_point_head_index=7,
    d_in = 64,
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,
    use_cached_activations=True,
    cached_activations_path="../activations/",
    
    # SAE Parameters
    expansion_factor = 64, # SAE width d_sae = expansion_factor * d_in = 64 * 64 = 4096.
    
    # Training Parameters
    lr = 1e-3,
    l1_coefficient = 2e-4,
    # lr_scheduler_name="LinearWarmupDecay",
    # NOTE(review): 3_000_000 tokens / 4096 per batch ~= 732 steps, fewer than
    # the 2200 warm-up steps below; with lr_scheduler_name commented out, whether
    # warm-up applies at all depends on the config default (TODO confirm).
    lr_warm_up_steps=2200,
    train_batch_size = 4096,
    context_size = 128,
    
    # Activation Store Parameters
    n_batches_in_buffer = 512,
    total_training_tokens = 3_000_000,
    store_batch_size = 32,
    
    # Resampling protocol
    feature_sampling_method = 'l2',
    feature_sampling_window = 1000,
    feature_reinit_scale = 0.2,
    dead_feature_window=200,
    dead_feature_threshold = 5e-6,
    
    # WANDB
    log_to_wandb = True,
    wandb_project= "mats_sae_training_gpt2_small_hook_q_dev",
    wandb_entity = None,
    wandb_log_frequency=5,
    
    # Misc
    device = "mps",
    seed = 42,
    n_checkpoints = 0,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
    )

# cfg.d_sae
sparse_autoencoder = language_model_sae_runner(cfg)
# assert sparse_autoencoder is not None