### Train a Gated SAE on GELU-1L with feature resampling

In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads all modules before
# each cell runs, so edits to the `sae` package apply without a kernel restart.
%load_ext autoreload
%autoreload 2

import sys
# Make the project's `sae` package importable. The path is resolved relative
# to the kernel's current working directory; this assumes the notebook runs
# one directory below the repo root (confirm). Re-running the cell appends
# the entry again.
sys.path.append("../")

# Project-local training entry point and its configuration object.
from sae.train import ModelTrainer
from sae.config import Config 

In [None]:
# Hyperparameters for a gated sparse autoencoder trained on the hook point
# `blocks.0.hook_mlp_out` (layer 0) of `gelu-1l`, using the "anthropic"
# feature-resampling method. Every key is forwarded unchanged to `Config`
# as a keyword argument.
config_inputs = dict(
    # --- Model and hook point ---
    model_name="gelu-1l",
    hook_point="blocks.0.hook_mlp_out",
    hook_point_layer=0,
    hook_point_head_index=None,
    d_in=512,
    # --- Dataset (pre-tokenized) ---
    dataset_path="NeelNanda/c4-tokenized-2b",
    is_dataset_tokenized=True,
    # --- Activation store ---
    n_batches_in_store_buffer=128,
    store_batch_size=4,
    train_batch_size=4096,
    context_size=1024,
    # --- Logging, evaluation and checkpoints ---
    log_to_wandb=True,
    wandb_project="test_gelu_1l",
    wandb_log_frequency=10,
    eval_frequency=500,
    sparsity_log_frequency=5000,
    n_checkpoints=5,
    checkpoint_path="../outputs/checkpoints",
    # --- Sparse autoencoder architecture ---
    expansion_factor=64,
    subtract_b_dec_from_inputs=True,
    use_gated_sparse_autoencoder=True,
    normalise_w_dec=True,
    clip_grad_norm=False,
    # --- Feature resampling ---
    # NOTE(review): presumably resamples every 10k steps up to step 40001
    # (i.e. at 10k/20k/30k/40k); confirm in sae.train.
    feature_resampling_method="anthropic",
    resample_frequency=10000,
    max_resample_step=40001,
    resample_batches=128,
    feature_reinit_scale=0.2,
    min_sparsity_for_resample=1e-6,
    # --- General ---
    seed=42,
    total_training_steps=200000,
    # --- Learning-rate schedule ---
    lr=3e-4,
    lr_scheduler_name="constant",
    # --- Loss weights ---
    mse_loss_coefficient=1,
    l1_coefficient=0.005,
)

cfg = Config(**config_inputs)

In [None]:
# Build the trainer from `cfg`, run its setup step, then start training.
# NOTE(review): what setup() initialises is defined in sae.train and is not
# visible here; it is presumably required before train() — confirm there.
mod = ModelTrainer(cfg)
mod.setup()
mod.train()