In [None]:
import torch
import os 

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

import cProfile


# Tokenizer parallelism is switched off through the environment before the run.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Run configuration: SAE on "gelu-2l", hooked at blocks.0.hook_mlp_out.
cfg = LanguageModelSAERunnerConfig(
    # --- Data-generating function (model + training distribution) ---
    model_name="gelu-2l",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    # --- SAE parameters ---
    expansion_factor=16,
    # --- Training parameters ---
    lr=1e-4,
    l1_coefficient=3e-4,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store parameters ---
    n_batches_in_buffer=32,
    total_training_tokens=10_000_000,
    store_batch_size=32,
    # --- Resampling protocol ---
    feature_sampling_method="l2",
    feature_sampling_window=2500,
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=1e-8,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_models",
    wandb_entity=None,
    wandb_log_frequency=10,
    # --- Misc ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

def main():
    """Train the SAE described by the module-level ``cfg``.

    Returns:
        Whatever ``language_model_sae_runner(cfg)`` returns (the trained
        sparse autoencoder), so the caller can keep a handle on it instead
        of losing it when the function exits.
    """
    return language_model_sae_runner(cfg)


import cProfile
import io
import pstats
from pstats import SortKey

# Profile one full training run and print the stats sorted by cumulative time.
pr = cProfile.Profile()
pr.enable()
try:
    sparse_autoencoder = main()
finally:
    # Stop the profiler and emit the report even when training raises or is
    # interrupted, so a partial run still yields usable profiling data.
    pr.disable()
    s = io.StringIO()
    sortby = SortKey.CUMULATIVE
    ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
    ps.print_stats()
    print(s.getvalue())


# GPT2-Small

In [None]:
import torch
import os 
# Environment variables are set ahead of the sae_training imports.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Run configuration: SAE on "gpt2-small", hooked at blocks.10.hook_resid_pre.
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.hook_resid_pre",
    # Kept equal to the block index named in `hook_point`, as in the gelu-2l,
    # pythia-70m and tiny-stories configs in this notebook (blocks.0 -> 0,
    # blocks.1 -> 1). Previously 11, which disagreed with "blocks.10".
    hook_point_layer=10,
    d_in=768,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,

    # SAE Parameters
    expansion_factor=64,  # determines the dimension of the SAE.

    # Training Parameters
    lr=1e-5,
    l1_coefficient=5e-4,
    lr_scheduler_name=None,
    train_batch_size=4096,
    context_size=128,

    # Activation Store Parameters
    n_batches_in_buffer=128,
    total_training_tokens=1_000_000 * 200,  # 200M tokens seems doable overnight.
    store_batch_size=32,

    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=5000,
    dead_feature_threshold=1e-7,

    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small",
    wandb_entity=None,
    wandb_log_frequency=50,

    # Misc
    device="mps",
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Train and keep the resulting SAE for the cells below.
sparse_autoencoder = language_model_sae_runner(cfg)


# GPT2-Small Hook Q

In [None]:
import torch
import os 
# Environment variables are set ahead of the sae_training imports.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

# Run configuration: SAE on "gpt2-small", hooked at blocks.10.attn.hook_q,
# restricted to head index 7.
cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="gpt2-small",
    hook_point="blocks.10.attn.hook_q",
    # Kept equal to the block index named in `hook_point`, as in the gelu-2l,
    # pythia-70m and tiny-stories configs in this notebook (blocks.0 -> 0,
    # blocks.1 -> 1). Previously 11, which disagreed with "blocks.10".
    hook_point_layer=10,
    hook_point_head_index=7,
    d_in=64,
    dataset_path="Skylion007/openwebtext",
    is_dataset_tokenized=False,

    # SAE Parameters
    expansion_factor=64,  # determines the dimension of the SAE.

    # Training Parameters
    lr=1e-3,
    l1_coefficient=0.001,
    lr_scheduler_name=None,
    train_batch_size=4096,
    context_size=128,

    # Activation Store Parameters
    n_batches_in_buffer=512,
    total_training_tokens=1_000_000 * 3,  # 3M tokens.
    store_batch_size=32,

    # Resampling protocol
    feature_sampling_method="l2",
    feature_sampling_window=1000,
    feature_reinit_scale=0.2,
    dead_feature_window=1000,
    dead_feature_threshold=1e-7,

    # WANDB
    log_to_wandb=True,
    wandb_project="mats_sae_training_gpt2_small_hook_q",
    wandb_entity=None,
    wandb_log_frequency=50,

    # Misc
    device="mps",
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Train and keep the resulting SAE for the cells below.
sparse_autoencoder = language_model_sae_runner(cfg)


In [None]:
# Write the SAE currently bound to `sparse_autoencoder` to disk.
# NOTE(review): on a top-to-bottom run this name was last assigned by the cell
# directly above (hook_point="blocks.10.attn.hook_q"), yet the file name says
# "resid_pre_10" -- confirm which SAE is meant to be saved here.
checkpoint_file = "./overnight_sae_resid_pre_10_gpt_2_small.pt"
sparse_autoencoder.save_model(checkpoint_file)

# Pythia-70M

In [None]:
import torch
import os 

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

import cProfile


# Tokenizer parallelism is switched off through the environment before the run.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Run configuration: SAE on "pythia-70m", hooked at blocks.0.hook_mlp_out.
cfg = LanguageModelSAERunnerConfig(
    # --- Data-generating function (model + training distribution) ---
    model_name="pythia-70m",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    d_in=512,
    dataset_path="EleutherAI/the_pile_deduplicated",
    is_dataset_tokenized=False,
    # --- SAE parameters ---
    expansion_factor=16,
    # --- Training parameters ---
    lr=3e-4,
    l1_coefficient=1e-3,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store parameters ---
    n_batches_in_buffer=64,
    total_training_tokens=5_000_000,
    store_batch_size=32,
    # --- Resampling protocol ---
    feature_sampling_method="l2",
    feature_sampling_window=2500,  # Author's note: doesn't currently matter.
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=1e-8,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=10,
    # --- Misc ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

def main():
    """Train the SAE described by the module-level ``cfg``.

    Returns:
        Whatever ``language_model_sae_runner(cfg)`` returns (the trained
        sparse autoencoder), so the caller can keep a handle on it instead
        of losing it when the function exits.
    """
    return language_model_sae_runner(cfg)


# Bind the result so the trained model stays available after the run.
sparse_autoencoder = main()

# To profile this run instead, execute `%prun -s cumulative main()` in place
# of the call above.


# Tiny Stories

In [None]:
import torch
import os 

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner



# Tokenizer parallelism is switched off through the environment before the run.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# Run configuration: SAE on "tiny-stories-2L-33M", hooked at blocks.1.mlp.hook_post.
cfg = LanguageModelSAERunnerConfig(
    # --- Data-generating function (model + training distribution) ---
    model_name="tiny-stories-2L-33M",
    hook_point="blocks.1.mlp.hook_post",
    hook_point_layer=1,
    d_in=4096,
    dataset_path="roneneldan/TinyStories",
    is_dataset_tokenized=False,
    # --- SAE parameters ---
    expansion_factor=4,
    # --- Training parameters ---
    lr=1e-4,
    l1_coefficient=3e-4,
    train_batch_size=4096,
    context_size=128,
    # --- Activation store parameters ---
    n_batches_in_buffer=128,
    total_training_tokens=10_000_000,  # want 500M eventually.
    store_batch_size=32,
    # --- Resampling protocol ---
    feature_sampling_method="l2",
    feature_sampling_window=2500,  # Author's note: doesn't currently matter.
    feature_reinit_scale=0.2,
    dead_feature_window=1250,
    dead_feature_threshold=0.0005,
    # --- Weights & Biases logging ---
    log_to_wandb=True,
    wandb_project="mats_sae_training_language_benchmark_tests",
    wandb_entity=None,
    wandb_log_frequency=10,
    # --- Misc ---
    device="mps",
    seed=42,
    n_checkpoints=0,
    checkpoint_path="checkpoints",
    dtype=torch.float32,
)

# Launch training and keep the returned SAE.
sparse_autoencoder = language_model_sae_runner(cfg)


# Toy Model

In [None]:

from sae_training.toy_model_runner import SAEToyModelRunnerConfig, toy_model_sae_runner


# Toy-model run: configure, train, and check that a trained SAE comes back.
cfg = SAEToyModelRunnerConfig(
    # --- Toy model ---
    n_features=200,
    n_hidden=5,
    n_correlated_pairs=0,
    n_anticorrelated_pairs=0,
    feature_probability=0.025,
    model_training_steps=10_000,

    # --- SAE ---
    d_sae=240,
    l1_coefficient=0.001,

    # --- SAE training ---
    # NOTE(review): 1028 is not a power of two; possibly meant 1024 -- confirm.
    train_batch_size=1028,
    feature_sampling_window=3_000,
    dead_feature_window=1_000,
    feature_reinit_scale=0.5,
    total_training_tokens=4096 * 300,

    # --- Logging / device ---
    log_to_wandb=True,
    wandb_project="sae-training-test",
    wandb_log_frequency=5,
    device="mps",
)

trained_sae = toy_model_sae_runner(cfg)

# Fail loudly if the runner produced nothing.
assert trained_sae is not None
