
hijohnnylin/mats_sae_training

 
 



MATS SAE Training

License: MIT

The MATS SAE training codebase (we'll rename it soon) exists to help researchers:

  • Train sparse autoencoders.
  • Analyse sparse autoencoders and neural network internals.
  • Generate insights which make it easier to create safe and aligned AI systems.

Quick Start

Set Up

This project uses Poetry for dependency management. Ensure Poetry is installed, then install the dependencies by running:

poetry install

Loading Sparse Autoencoders from Huggingface

Previously trained sparse autoencoders can be loaded from Hugging Face with close to a single line of code. For more details and performance metrics for these sparse autoencoders, read my blog post.

import torch 
from sae_training.utils import LMSparseAutoencoderSessionloader
from huggingface_hub import hf_hub_download

layer = 8 # pick a layer you want.
REPO_ID = "jbloom/GPT2-Small-SAEs"
FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
model, sparse_autoencoder, activation_store = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path = path
)
sparse_autoencoder.eval()

You can also load the log feature sparsity of each feature from Hugging Face.

FILENAME = f"final_sparse_autoencoder_gpt2-small_blocks.{layer}.hook_resid_pre_24576_log_feature_sparsity.pt"
path = hf_hub_download(repo_id=REPO_ID, filename=FILENAME)
log_feature_sparsity = torch.load(path, map_location=sparse_autoencoder.cfg.device)
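The sparsity file is easiest to read as a per-feature firing frequency. The sketch below assumes (this README does not say so) that each entry is log10 of the fraction of tokens on which a feature fires, and it uses a plain Python list in place of the loaded tensor. The function name and both thresholds are hypothetical; the dead threshold of -6 simply mirrors `dead_feature_threshold = 1e-6` in the training config later in this README.

```python
import statistics


def summarize_log_sparsity(log_sparsity, dead_threshold=-6.0, dense_threshold=-1.0):
    """Summarize per-feature log10 firing frequencies (hypothetical helper).

    Assumes each entry is log10(fraction of tokens on which the feature fires).
    Both thresholds are illustrative choices, not values from the codebase.
    """
    # Features that (almost) never fire.
    n_dead = sum(1 for v in log_sparsity if v < dead_threshold)
    # Features that fire on a large fraction of tokens.
    n_dense = sum(1 for v in log_sparsity if v > dense_threshold)
    # Undo the log10 to report the median firing frequency.
    median_freq = statistics.median(10 ** v for v in log_sparsity)
    return {
        "n_features": len(log_sparsity),
        "n_dead": n_dead,
        "n_dense": n_dense,
        "median_freq": median_freq,
    }


# Toy values standing in for log_feature_sparsity.tolist().
summary = summarize_log_sparsity([-10.0, -4.0, -3.0, -0.5])
print(summary)
```

With the real tensor you could pass `log_feature_sparsity.tolist()`; keeping the summary in plain Python avoids any assumptions about the tensor's device.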

Background

We highly recommend this tutorial.

High Level

Motivation

  • Accelerate SAE Research: Support fast experimentation to understand SAEs and improve SAE training so we can train SAEs on larger and more diverse models.
  • Make Research like Play: Support research into language model internals via SAEs. Good tooling can make research tremendously exciting and enjoyable. Balancing modifiability and reliability with ease of understanding / access is the name of the game here.
  • Build an awesome community: Mechanistic Interpretability already has an awesome community but as that community grows, it makes sense that there will be niches. I'd love to build a great community around Sparse Autoencoders.

Goals

SAE Training: SAE Training features will fit into a number of categories including:

  • Making it easy to train SAEs: Training SAEs is hard for a number of reasons and so making it easy for people to train SAEs with relatively little expertise seems like the main way this codebase will create value.
  • Training SAEs on more models: Supporting training of SAEs on more models and architectures, and on different activations within those models.
  • Being better at training SAEs: Enabling methodological changes which may improve SAE performance, as measured by reconstruction loss, cross-entropy loss when using reconstructed activations, L1 loss, L0 and interpretability of features, as well as improving training speed or reducing the compute resources required to train SAEs.
  • Being better at measuring SAE Performance: How do we know when SAEs are doing what we want them to? Improving training metrics should allow better decisions about which methods to use and which hyperparameter choices to make.
  • Training SAE variants: People are already training “Transcoders” which map from one activation to another (such as before / after an MLP layer). These can be easily supported with a few changes. Other variants will come in time and
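To make the metrics named above concrete, here is a minimal pure-Python sketch of a standard ReLU sparse autoencoder on one activation vector, computing reconstruction MSE, the L1 penalty on feature activations, and L0 (the number of active features). It assumes the common architecture f = ReLU(W_enc (x - b_dec) + b_enc), x_hat = W_dec^T f + b_dec; it is an illustration, not this codebase's implementation, and the tiny weights are made up. The cross-entropy metric needs the language model itself (with the reconstruction spliced back in), so it is not shown.

```python
def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Encode x into feature activations f, then decode back to x_hat.

    W_enc and W_dec both have d_sae rows of length d_in; each row of W_dec
    is one feature's decoder direction (an assumed layout).
    """
    # Subtract the decoder bias before encoding.
    x_cent = [xi - bi for xi, bi in zip(x, b_dec)]
    # ReLU(W_enc @ x_cent + b_enc)
    f = [max(0.0, sum(w * xc for w, xc in zip(row, x_cent)) + b)
         for row, b in zip(W_enc, b_enc)]
    # x_hat = b_dec + sum_i f_i * W_dec[i]
    x_hat = list(b_dec)
    for fi, row in zip(f, W_dec):
        for j, w in enumerate(row):
            x_hat[j] += fi * w
    return f, x_hat


def sae_metrics(x, f, x_hat):
    """Return (reconstruction MSE, L1 of features, L0 of features)."""
    mse = sum((a - b) ** 2 for a, b in zip(x, x_hat)) / len(x)
    l1 = sum(abs(v) for v in f)
    l0 = sum(1 for v in f if v > 0)
    return mse, l1, l0


x = [1.0, 2.0]
W_enc = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
b_enc = [-0.5, 0.0, 0.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
b_dec = [0.0, 0.0]

f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
mse, l1, l0 = sae_metrics(x, f, x_hat)
l1_coefficient = 0.00008  # value from the example config later in this README
loss = mse + l1_coefficient * l1  # standard MSE + L1 objective (assumed form)
print(f, x_hat, mse, l1, l0)  # [0.5, 2.0, 0.0] [0.5, 2.0] 0.125 2.5 2
```

The trade-off the L1 coefficient controls is visible here: raising it pushes more entries of `f` to zero (lower L0) at the cost of a worse reconstruction (higher MSE).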

Analysis with SAEs: Using SAEs to understand neural network internals is an exciting but complicated task.

  • Feature-wise Interpretability: This looks something like "for each feature, have as much knowledge about it as possible". Part of this will involve dashboard improvements or better integrations with Neuronpedia.
  • Mechanistic Interpretability: This comprises the more traditional kinds of Mechanistic Interpretability which TransformerLens supports and should be supported by this codebase. Making it easy to patch, ablate or otherwise intervene on features so as to find circuits will likely speed up lots of researchers.
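One simple way to intervene on a single feature is to subtract that feature's contribution, f_i times its decoder direction, from the original activation. The sketch below is hypothetical (the function is not an API this codebase provides) and uses plain lists; in practice you would apply the same arithmetic inside a TransformerLens hook at the hook point the SAE was trained on. Because the SAE's reconstruction error is left in place, any downstream change can be attributed to removing that one feature.

```python
def ablate_feature(x, f, W_dec, feature_idx):
    """Return x with one feature's decoder contribution removed (hypothetical helper).

    x: original activation (length d_in)
    f: SAE feature activations for x (length d_sae)
    W_dec: decoder rows, one direction of length d_in per feature
    """
    coeff = f[feature_idx]
    direction = W_dec[feature_idx]
    # Subtract only this feature's term; the rest of x (including the SAE's
    # reconstruction error) is untouched.
    return [xj - coeff * dj for xj, dj in zip(x, direction)]


x = [1.0, 2.0]
f = [0.5, 2.0, 0.0]
W_dec = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
print(ablate_feature(x, f, W_dec, 1))  # [1.0, 0.0]
```

Replacing the subtraction with an addition of a scaled direction gives the corresponding "steering" intervention.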

Other Stuff

I think there are lots of other types of analysis that could be done in the future with SAE features. I've already explored many different types of statistical tests which can reveal interesting properties of features. There are also things like saliency mapping and attribution techniques which it would be nice to support.

Accessibility and Code Quality: The codebase won't be used if it doesn't work, and it also won't get used if it's too hard to understand, modify or read.

  • Making the code accessible: This involves tasks like turning the codebase into a Python package.
  • Knowing how the code is supposed to work: Is the code well-documented? This will require docstrings, tutorials and links to related work and publications. Getting aligned on what the code does is critical to sharing a resource like this.
  • Knowing the code works as intended: All code should be tested. Unit tests and acceptance tests are both important.
  • Knowing the code is actually performant: Tests help ensure the code works as intended, but deep learning introduces lots of complexity, which makes actually running benchmarks essential to having confidence in the code.

Code Overview

The codebase contains 2 folders worth caring about:

  • sae_training: The main body of the code is here. Everything required for training SAEs.
  • sae_analysis: This mainly houses the feature visualizer code we use to generate dashboards. It was written by Callum McDougall, but I've ported it here with permission and edited it to work with a few different activation types.

Some other folders:

  • tutorials: These aren't well maintained but I'll aim to clean them up soon.
  • tests: When first developing the codebase, I was writing more tests. I have no idea whether they are currently working!

I've been committing my research code to the Research folder but don't expect other people to use or look at it.

Training your own Sparse Autoencoder

Sparse autoencoders can be intimidating at first, but it's fairly simple to train one once you know what each part of the config does. I've created a config class which you instantiate and pass to the runner, which will complete your training run and log its progress to wandb.

Let's go through the major components of the config:

  • Data: SAEs autoencode model activations. We need to specify the model, the part of the model's activations we want to autoencode, and the dataset the model is operating on when generating those activations. We now automatically detect whether that dataset is tokenized, and most Hugging Face datasets should be fine. One slightly annoying detail is that you need to know the dimensionality of those activations when constructing your SAE, but you can find that in the TransformerLens docs. Any language model in the table from those docs should work.
  • SAE Parameters: Your expansion factor will determine the size of your SAE and the decoder bias initialization method should always be geometric_median or mean. Mean is faster but theoretically sub-optimal. I use another package to get the geometric median and it can be quite slow.
  • Training Parameters: These are the most critical. The right L1 coefficient (the coefficient on the sparsity-inducing term in the loss) changes with your learning rate, but a good bet is LR 4e-4 and L1 coefficient 8e-5 for GPT-2 Small. These will vary for other models, and experimenting with them in short runs can be helpful. A training batch size of 4096 is standard, and I'm not really sure whether there's any benefit to changing it. In theory, a larger context size (matching whatever the model was trained with) seems good, but it's computationally cheaper to use 128. Learning rate warm-up is important to avoid dead neurons.
  • Activation Store Parameters: The activation store shuffles activations from forward passes over samples from your data. The larger it is, the better the shuffling, and in theory more shuffling is good. The total number of training tokens is a very important parameter: the more the better, but you'll often see good results after training on a few hundred million tokens. Store batch size depends on your GPU and how many forward passes of your model you want to run simultaneously when collecting activations.
  • Dead Neurons / Sparsity Metrics: The resampling config was more important when we were using resampling to avoid dead neurons (see Anthropic's post on this), but with ghost gradients the protocol is much simpler. I'd always set ghost grads to True and the feature sampling method to None. The feature sampling window affects the dashboard statistics tracking feature occurrence, and the dead feature window sets how many forward passes a neuron must go without activating before we apply ghost grads to it.
  • WANDB: Fairly straightforward. Don't set the log frequency too high or your dashboard will be slow!
  • Device: I can run this code on my macbook with "mps" but mostly do runs with cuda.
  • Dtype: Float16 maybe could work but I had some funky results and have left it at float32 for the time being.
  • Checkpoints: I'd collect checkpoints on runs you care about, but turn them off when tuning since checkpointing can be slow.
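For the decoder bias initialization, the geometric median is the point minimizing the summed Euclidean distance to a sample of activations, which makes it more robust to outliers than the mean. This README does not name the package it uses, so the sketch below swaps in Weiszfeld's algorithm, a standard iterative method, on plain Python lists. The function name, iteration count and tolerance are illustrative assumptions.

```python
import math


def geometric_median(points, n_iter=500, eps=1e-8):
    """Approximate the geometric median of `points` with Weiszfeld's algorithm."""
    dim = len(points[0])
    # Start from the mean, i.e. the faster "mean" b_dec initialization.
    median = [sum(p[j] for p in points) / len(points) for j in range(dim)]
    for _ in range(n_iter):
        # Weight each point by the inverse of its distance to the estimate;
        # clamp the distance to avoid dividing by zero at a data point.
        weights = [1.0 / max(math.dist(p, median), eps) for p in points]
        total = sum(weights)
        new_median = [sum(w * p[j] for w, p in zip(weights, points)) / total
                      for j in range(dim)]
        if math.dist(new_median, median) < eps:
            return new_median
        median = new_median
    return median


# Four points in a unit square plus one far outlier.
points = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [100.0, 100.0]]
mean = [sum(p[j] for p in points) / len(points) for j in range(2)]
print(mean)                      # [20.4, 20.4]
print(geometric_median(points))  # close to [0.789, 0.789]
```

The outlier drags the mean far from the cluster while the geometric median stays inside it, which is the usual argument for preferring it despite its extra cost. The example config for GPT-2 Small follows.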
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB__SERVICE_WAIT"] = "300"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.lm_runner import language_model_sae_runner

cfg = LanguageModelSAERunnerConfig(

    # Data Generating Function (Model + Training Distribution)
    model_name = "gpt2-small",
    hook_point = "blocks.2.hook_resid_pre",
    hook_point_layer = 2,
    d_in = 768,
    dataset_path = "Skylion007/openwebtext",
    is_dataset_tokenized=False,
    
    # SAE Parameters
    expansion_factor = 64,
    b_dec_init_method = "geometric_median",
    
    # Training Parameters
    lr = 0.0004,
    l1_coefficient = 0.00008,
    lr_scheduler_name="constantwithwarmup",
    train_batch_size = 4096,
    context_size = 128,
    lr_warm_up_steps=5000,
    
    # Activation Store Parameters
    n_batches_in_buffer = 128,
    total_training_tokens = 1_000_000 * 300,
    store_batch_size = 32,
    
    # Dead Neurons and Sparsity
    use_ghost_grads=True,
    feature_sampling_window = 1000,
    dead_feature_window=5000,
    dead_feature_threshold = 1e-6,
    
    # WANDB
    log_to_wandb = True,
    wandb_project= "mats_sae_training_gpt2",
    wandb_entity = None,
    wandb_log_frequency=100,
    
    # Misc
    device = "cuda",
    seed = 42,
    n_checkpoints = 10,
    checkpoint_path = "checkpoints",
    dtype = torch.float32,
    )

sparse_autoencoder = language_model_sae_runner(cfg)
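Before launching a run it can help to sanity-check what the config implies. The arithmetic below uses only values from the example config; the relationships themselves (that the SAE width is d_in times expansion_factor, that each step consumes train_batch_size activations, that the buffer holds n_batches_in_buffer × store_batch_size × context_size activations, and that "constantwithwarmup" means a linear warm-up followed by a constant rate) are assumptions based on the parameter names, not something this README states.

```python
# Values copied from the example config.
d_in = 768
expansion_factor = 64
train_batch_size = 4096
total_training_tokens = 1_000_000 * 300
lr = 0.0004
lr_warm_up_steps = 5000
n_batches_in_buffer = 128
store_batch_size = 32
context_size = 128

# Assumed: the SAE's hidden width (number of features).
d_sae = d_in * expansion_factor

# Assumed: one optimizer step per batch of train_batch_size activations.
n_training_steps = total_training_tokens // train_batch_size

# Assumed: activations held in the shuffling buffer at once.
buffer_activations = n_batches_in_buffer * store_batch_size * context_size


def lr_at_step(step):
    """Linear warm-up to `lr`, then constant (an assumed reading of "constantwithwarmup")."""
    if step < lr_warm_up_steps:
        return lr * (step + 1) / lr_warm_up_steps
    return lr


print(d_sae, n_training_steps, buffer_activations)  # 49152 73242 524288
```

So this run trains 49152 features (matching the feature count mentioned in the example output below) for about 73k steps, of which the first 5000 (roughly 7%) are warm-up. By the same reasoning, the 24576 in the pretrained filenames corresponds to 32 × 768.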

Loading a Pretrained Sparse Autoencoder

Once your SAE is trained, the final SAE weights will be saved to wandb and are loadable via the session loader. The session loader will return:

  • The model your SAE was trained on (presumably you're interested in studying this; it's always a HookedTransformer).
  • Your SAE.
  • An activations loader, from which you can get randomly sampled activations or batches of tokens from the dataset you used to train the SAE (more on this in the tutorial).
from sae_training.utils import LMSparseAutoencoderSessionloader

path = "path/to/sparse_autoencoder.pt"
model, sparse_autoencoder, activations_loader = LMSparseAutoencoderSessionloader.load_session_from_pretrained(
    path
)

Tutorials

I wrote a tutorial to show users how to do some basic exploration of their SAE.

  • evaluating_your_sae.ipynb: A quick-and-dirty notebook showing how to check L0 and prediction loss with your SAE, as well as how to generate interactive dashboards using Callum's reproduction of Anthropic's interface.
  • logits_lens_with_features.ipynb: A notebook showing how to reproduce the analysis from this LessWrong post.
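The basic logit-lens idea can be sketched in a few lines: multiply a feature's decoder direction by the model's unembedding matrix to see which output tokens the feature most directly promotes. This pure-Python version uses a made-up three-token vocabulary and a hypothetical helper name; with real objects you would use one row of the SAE decoder and the HookedTransformer's `W_U` (shape d_model × d_vocab). The sketch ignores the final LayerNorm, so it is only a direct-effect approximation and not necessarily the exact analysis from the post.

```python
def top_promoted_tokens(decoder_direction, W_U, vocab, k=2):
    """Rank tokens by the logit contribution of one feature's decoder direction.

    decoder_direction: length d_model
    W_U: unembedding as d_model rows of length d_vocab
    Ignores the final LayerNorm (direct-effect approximation).
    """
    d_vocab = len(W_U[0])
    # logits[t] = sum_i direction[i] * W_U[i][t]
    logits = [sum(decoder_direction[i] * W_U[i][t] for i in range(len(decoder_direction)))
              for t in range(d_vocab)]
    ranked = sorted(range(d_vocab), key=lambda t: logits[t], reverse=True)
    return [(vocab[t], logits[t]) for t in ranked[:k]]


vocab = [" he", " she", " the"]
W_U = [
    [1.0, 0.0, 0.1],
    [0.0, 1.0, 0.1],
]
print(top_promoted_tokens([2.0, 0.5], W_U, vocab))  # [(' he', 2.0), (' she', 0.5)]
```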

Example Dashboard

WandB dashboards provide lots of useful insights while training SAEs. Here's a screenshot from one training run.


Example Output

Here's one feature we found in the residual stream of Layer 10 of GPT-2 Small:

Open gpt2_resid_pre10_predict_pronoun_feature.html in your browser to interact with the dashboard (WIP).

Note: this feature could probably split into more monosemantic features in a larger SAE trained for longer (this one had only about 49152 features and was trained on 10M tokens from OpenWebText).


About

Training Sparse Autoencoders on Language Models
