### Test Transcoder for Gelu-1L with Anthropic's implementation

In [1]:
%load_ext autoreload
%autoreload 2

import sys
sys.path.append("../")

from sae.train import ModelTrainer
from sae.config import Config 
from sae.activation_store import ActivationsStore, DoubleActivationStore


In [2]:
config_inputs = {
    # Model and Hook Point
    "model_name": "gpt2-small",
    "hook_point": "blocks.8.ln2.hook_normalized",
    "hook_point_layer": 8,
    "hook_point_head_index": None,
    "d_in": 768,

    'different_output': True,
    'hook_point_output': 'blocks.8.hook_mlp_out',
    'hook_point_layer_output': 8,
    'hook_point_head_index_output': None,
    'd_out': 768,

    # Dataset
    'dataset_path': 'Skylion007/openwebtext',
    'is_dataset_tokenized': False,
    
     # Activation Store Parameters
    'n_batches_in_store_buffer': 128,
    'store_batch_size': 16,
    'train_batch_size': 4096,
    'context_size': 128,

    # Outputs
    'log_to_wandb': True,
    'wandb_project': 'gpt2-small-transcoders',
    'wandb_log_frequency': 10,
    'eval_frequency': 500,
    'sparsity_log_frequency': 5000,
    'n_checkpoints': 5,
    'checkpoint_path': '../outputs/checkpoints',

    # Sparse Autoencoder Parameters
    'expansion_factor': 32,
    'normalise_w_dec': True,
    'subtract_b_dec_from_inputs': True,
    'b_dec_init_method': 'mean',

    # General
    'seed': 42,
    'total_training_steps': 200000,

    # Learning rate parameters
    'lr': 0.0004,
    'lr_scheduler_name': 'constant_with_warmup',
    'lr_warm_up_steps': 5000,
    
    # Loss Function
    'mse_loss_coefficient': 1,
    'mse_loss_type': 'standard',
    'l1_coefficient': 5,
    # 'l0_coefficient': 0, #7e-5,
    # 'epsilon_l0_approx': 0.2,
    
    # 'sparse_loss_coefficient': 0, #1e-6,
    # 'min_sparsity_target': 1e-5,

}

cfg = Config(**config_inputs)

In [None]:
mod = ModelTrainer(cfg)
mod.setup()
mod.train()

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  cuda
creating activation store
creating data loader
buffer
dataloader
creating sae
creating wanbd


[34m[1mwandb[0m: Currently logged in as: [33meoin[0m. Use [1m`wandb login --relogin`[0m to force relogin


  0%|          | 0/200000 [00:00<?, ?it/s]

main base loss tensor(3.6225, device='cuda:0')
main base loss tensor(3.6225, device='cuda:0')
main base loss tensor(3.5279, device='cuda:0')
main base loss tensor(3.5279, device='cuda:0')
main base loss tensor(3.6720, device='cuda:0')
main base loss tensor(3.6720, device='cuda:0')
main base loss tensor(3.5444, device='cuda:0')
main base loss tensor(3.5444, device='cuda:0')
main base loss tensor(3.5521, device='cuda:0')
main base loss tensor(3.5521, device='cuda:0')
main base loss tensor(3.4225, device='cuda:0')
main base loss tensor(3.4225, device='cuda:0')
main base loss tensor(3.3311, device='cuda:0')
main base loss tensor(3.3311, device='cuda:0')
main base loss tensor(3.4627, device='cuda:0')
main base loss tensor(3.4627, device='cuda:0')
main base loss tensor(3.6772, device='cuda:0')
main base loss tensor(3.6772, device='cuda:0')
main base loss tensor(3.3136, device='cuda:0')
main base loss tensor(3.3136, device='cuda:0')
main base loss tensor(3.7741, device='cuda:0')
main base los

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



main base loss tensor(3.6253, device='cuda:0')
main base loss tensor(3.6253, device='cuda:0')
main base loss tensor(3.4373, device='cuda:0')
main base loss tensor(3.4373, device='cuda:0')
main base loss tensor(3.3973, device='cuda:0')
main base loss tensor(3.3973, device='cuda:0')
main base loss tensor(3.7478, device='cuda:0')
main base loss tensor(3.7478, device='cuda:0')
main base loss tensor(3.4122, device='cuda:0')
main base loss tensor(3.4122, device='cuda:0')
main base loss tensor(3.6136, device='cuda:0')
main base loss tensor(3.6136, device='cuda:0')
main base loss tensor(3.2525, device='cuda:0')
main base loss tensor(3.2525, device='cuda:0')
main base loss tensor(3.3376, device='cuda:0')
main base loss tensor(3.3376, device='cuda:0')
main base loss tensor(3.1911, device='cuda:0')
main base loss tensor(3.1911, device='cuda:0')
main base loss tensor(3.1834, device='cuda:0')
main base loss tensor(3.1834, device='cuda:0')
main base loss tensor(3.4934, device='cuda:0')
main base los

In [None]:
# config_inputs = {
#     # Model and Hook Point
#     "model_name": "gelu-1l",
#     "hook_point": "blocks.0.ln2.hook_normalized",
#     "hook_point_layer": 0,
#     "hook_point_head_index": None,
#     "d_in": 512,

#     'different_output': True,
#     'hook_point_output': 'blocks.0.hook_mlp_out',
#     'hook_point_layer_output': 0,
#     'hook_point_head_index_output': None,
#     'd_out': 512,

#     # Dataset
#     'dataset_path': 'NeelNanda/c4-tokenized-2b',
#     'is_dataset_tokenized': True,
    
#      # Activation Store Parameters
#     'n_batches_in_store_buffer': 128,
#     'store_batch_size': 4,
#     'train_batch_size': 4096,
#     'context_size': 1024,

#     # Outputs
#     'log_to_wandb': True,
#     'wandb_project': 'test_gelu_1l',
#     'wandb_log_frequency': 10,
#     'eval_frequency': 10,
#     'sparsity_log_frequency': 5000,
#     'n_checkpoints': 5,
#     'checkpoint_path': '../outputs/checkpoints',

#     # Sparse Autoencoder Parameters
#     'expansion_factor': 64,
#     'normalise_initial_decoder_weights': True,
#     'initial_decoder_norm': 0.1,
#     'initialise_encoder_to_decoder_transpose': True,

#     'normalise_w_dec': False,
#     'clip_grad_norm': True,
#     'scale_input_norm': False,

#     # General
#     'seed': 42,
#     'total_training_steps': 200000,

#     # Learning rate parameters
#     'lr': 5e-5,
#     'lr_scheduler_name': 'constant',

#     # Loss Function
#     'mse_loss_coefficient': 1,
#     'l1_coefficient': 0.,
#     'weight_l1_by_decoder_norms': True,
    
#     # Warm up loss coefficients
#     'l1_warmup': True,
#     'l1_warmup_steps': 10000,
# }

# cfg = Config(**config_inputs)