# Notebook for experimenting with and testing sparse autoencoders (SAEs)

In [1]:
%load_ext autoreload
%autoreload 2

import torch
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner
import wandb
import os, sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..')))

from utils.testing import access_wandb_runs, download_models, load_our_model, update_run
from utils.sae import create_sae_trainer

# Prefer the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f'Using device: {device}')

# Load environment variables via python-dotenv (by default from a .env file, if one is found).
# WANDB_ENTITY is read with os.getenv() further down in this notebook.
from dotenv import load_dotenv
load_dotenv()

Using device: cuda


True

## Testing Tutorial

In [None]:
# Training budget: 30_000 steps * 4096 tokens per batch = 122,880,000 training tokens.
total_training_steps = 30_000  # probably we should do more
batch_size = 4096
total_training_tokens = total_training_steps * batch_size

lr_warm_up_steps = 0  # no learning-rate warm-up
lr_decay_steps = total_training_steps // 5  # 20% of training
l1_warm_up_steps = total_training_steps // 20  # 5% of training

cfg = LanguageModelSAERunnerConfig(
    # Data Generating Function (Model + Training Distribution)
    model_name="tiny-stories-1M",  # our model (more options here: https://neelnanda-io.github.io/TransformerLens/generated/model_properties_table.html)
    hook_name="blocks.6.hook_mlp_out",  # A valid hook point (see more details here: https://neelnanda-io.github.io/TransformerLens/generated/demos/Main_Demo.html#Hook-Points)
    hook_layer=6,  # Layer index of the hook point; matches the `blocks.6` in hook_name.
    d_in=64,  # the width of the mlp output.
    dataset_path="apollo-research/roneneldan-TinyStories-tokenizer-gpt2",  # this is a tokenized language dataset on Huggingface for the Tiny Stories corpus.
    
    # SAE Parameters
    expansion_factor=16,  # the width of the SAE. Larger will result in better stats but slower training.
    b_dec_init_method="zeros",  # Initialise b_dec (presumably the decoder bias - TODO confirm) with zeros; the geometric median is an alternative.
    apply_b_dec_to_input=False,  # We won't apply b_dec to the input.
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
    normalize_activations="expected_average_only_in",
    
    # Training Parameters
    lr=3e-4, 
    lr_scheduler_name="constant",  # constant base schedule; the logged learning rate still decays to 0.0 at the end of the run (see lr_decay_steps).
    lr_warm_up_steps=lr_warm_up_steps,  # 0 here, i.e. disabled; a non-zero warm-up can help avoid too many dead features initially.
    lr_decay_steps=lr_decay_steps,  # number of steps used for the learning-rate decay (20% of training).
    l1_coefficient=5,  # will control how sparse the feature activations are
    l1_warm_up_steps=l1_warm_up_steps,  # this can help avoid too many dead features initially.
    lp_norm=1.0,  # the L1 penalty (and not a Lp for p < 1)
    train_batch_size_tokens=batch_size,
    context_size=512,
    
    # Activation Store Parameters
    n_batches_in_buffer=64,  # controls how many activations we store / shuffle.
    training_tokens=total_training_tokens,  # 122.88 million tokens (30_000 steps * 4096). Get a coffee, come back.
    store_batch_size_prompts=16,
    
    # WANDB
    log_to_wandb=True,  # always use wandb unless you are just testing code.
    wandb_project="ablation-sae",
    run_name='SAE-test',
    
    # Misc
    device=device,
    seed=42,
    checkpoint_path="checkpoints",
    dtype="float32",
)

In [8]:
# Long-running cell: trains the SAE described by `cfg` (the saved progress bar counts 122,880,000 tokens).
# The returned object's state_dict() is saved and uploaded in the next cell.
sparse_autoencoder = SAETrainingRunner(cfg).run()



Loaded pretrained model tiny-stories-1M into HookedTransformer


VBox(children=(Label(value='0.010 MB of 0.010 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))


[A

[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A


[A[A[A

[A[A

  lambda data: self._console_raw_callback("stderr", data),
Training SAE:   0%|          | 0/122880000 [02:28<?, ?it/s]


[A[A

[A[A

[A[A

                                                                 

[A[A

[A[A

[A[A

[A[A

                                                                 

[A[A

[A[A

[A[A

[A[A

                                                                 

[A[A

[A[A

[A[A

                                                                 

[A[A

[A[A

[A[A

[A[A

 

VBox(children=(Label(value='0.533 MB of 0.533 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
details/current_l1_coefficient,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
details/current_learning_rate,██████████████████████████████▆▆▆▅▅▅▄▄▃▁
details/n_training_tokens,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇█
losses/l1_loss,█▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
losses/mse_loss,▄▅███▇▆▅▅▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁
losses/overall_loss,▃▇███▇▆▆▆▅▅▅▄▃▃▃▃▃▃▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁
losses/raw_l1_loss,▁▇██▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▃▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂
metrics/explained_variance,▁▂▃▃▄▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇▇▇▇▇▇██▇█▇▇█████████
metrics/explained_variance_std,█▃▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
metrics/l0,█▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
details/current_l1_coefficient,5.0
details/current_learning_rate,0.0
details/n_training_tokens,122880000.0
losses/l1_loss,3.02543
losses/mse_loss,15.08185
losses/overall_loss,30.209
losses/raw_l1_loss,15.12715
metrics/explained_variance,0.59043
metrics/explained_variance_std,0.19541
metrics/l0,5.85522


In [None]:
# Save the trained SAE weights locally and attach the file to the W&B run.
api = wandb.Api()

entity = os.getenv('WANDB_ENTITY')
if not entity:
    # Without this guard the f-string below would silently query the path "None/<project>".
    raise RuntimeError('WANDB_ENTITY is not set; export it or add it to the .env file.')

runs = api.runs(path=f'{entity}/{cfg.wandb_project}')

print(len(runs))
if len(runs) == 0:
    # Indexing an empty result would otherwise fail with an unhelpful IndexError.
    raise RuntimeError(f'No runs found in {entity}/{cfg.wandb_project}; nothing to upload to.')

# NOTE(review): takes the first run the API returns - assumed to be the SAE run
# trained above; confirm the ordering if the project holds more than one run.
run = runs[0]

torch.save(sparse_autoencoder.state_dict(), 'sparse_autoencoder.pt')
# Kept as the last expression so the uploaded File object is displayed as the cell output.
run.upload_file('sparse_autoencoder.pt')


1


<File sparse_autoencoder.pt (application/vnd.snesdev-page-table) 518.5KiB>

## Getting it to work with our models

In [2]:
# Fetch the W&B runs of our ablated models (project helper from utils.testing) ...
wandb_ablated_runs = access_wandb_runs()
# ... and download each run's files (best_model_*.pt, config.yaml, and sae.pt where present)
# into ../model_weights/<run name>/, as listed in the output below.
download_models(wandb_ablated_runs, "../model_weights")

Downloaded best_model_20241105.pt to ../model_weights/fanciful-fog-78
Downloaded config.yaml to ../model_weights/fanciful-fog-78
Downloaded sae.pt to ../model_weights/fanciful-fog-78
Downloaded best_model_20241106.pt to ../model_weights/earnest-moon-79
Downloaded config.yaml to ../model_weights/earnest-moon-79
Downloaded best_model_20241105.pt to ../model_weights/super-violet-80
Downloaded config.yaml to ../model_weights/super-violet-80
Downloaded best_model_20241109.pt to ../model_weights/upbeat-glitter-83
Downloaded config.yaml to ../model_weights/upbeat-glitter-83
Downloaded best_model_20241109.pt to ../model_weights/comfy-cherry-84
Downloaded config.yaml to ../model_weights/comfy-cherry-84
Downloaded best_model_20241111.pt to ../model_weights/cosmic-leaf-81-part2
Downloaded config.yaml to ../model_weights/cosmic-leaf-81-part2
Downloaded best_model_20241114.pt to ../model_weights/major-planet-86-part3
Downloaded config.yaml to ../model_weights/major-planet-86-part3
Downloaded best_m

In [None]:
# Train one SAE per downloaded ablated-model run, either in-process or by
# submitting a generated SLURM script.

# Whether we are in a slurm environment
IS_HPC = False
NUM_TRAINING_STEPS = 100000

for run in wandb_ablated_runs:

    # Skip-if-already-trained check, currently disabled:
    # if "sae_trained" in run.summaryMetrics.keys():
    #     print(f"Skipping run {run.name} as an SAE has already been trained")
    #     continue

    print(f"Training SAE for run {run.name}")

    model_dir = f"../model_weights/{run.name}"

    try:
        ablated_trained_model = load_our_model(model_dir, device=device)
    except Exception as e:
        # Best-effort: report *why* loading failed, then move on to the next run.
        print(f"Failed to load model for run {run.name}: {e!r}")
        continue

    if not IS_HPC:
        # Train in this process, then store the weights next to the model and on W&B.
        sae_trainer = create_sae_trainer(ablated_trained_model, device=device, run_name=run.name, total_training_steps=NUM_TRAINING_STEPS)
        sae = sae_trainer.run()

        sae_path = f'{model_dir}/sae.pt'
        # torch.save() has no `root` parameter; passing one raised a TypeError after training.
        torch.save(sae.state_dict(), sae_path)
        # `root` belongs to upload_file: the file is stored in the run relative to it, i.e. as "sae.pt".
        run.upload_file(sae_path, root=model_dir)
        update_run(run, {"sae_trained": True})
    else:
        # Load hpc_scripts/sae_template.sh as a string
        with open('hpc_scripts/sae_template.sh', 'r') as f:
            template = f.read()

        # Replace the placeholders in the template with the correct values
        template = template.replace("{RUN_NAME}", run.name)
        template = template.replace("{NUM_TRAINING_STEPS}", f"{NUM_TRAINING_STEPS}")

        # Write the new script to a file called run_name.sh
        with open(f'hpc_scripts/{run.name}.sh', 'w') as f:
            f.write(template)

        # Submit the generated script to the scheduler.
        # NOTE(review): run.name is interpolated into a shell command unquoted - fine for
        # W&B auto-generated names, but a renamed run containing spaces would break this.
        os.system(f"sbatch hpc_scripts/{run.name}.sh")

Training SAE for run fanciful-fog-78


  state_dict = torch.load(model_path, map_location=device)
You just passed in a model which will override the one specified in your configuration: None. As a consequence this run will not be reproducible via configuration alone.
