# Train Transcoders on LLMs automatically

The original Transcoder training code is written for one specific layer of one model. To train transcoders conveniently on any layer of a supported model, we rewrite the training code here.

## Setup

First, import the necessary libraries.

In [1]:
import torch
import gc
from transformer_lens import HookedTransformer
import transformer_lens.utils as utils

import json
import sys, os
from pathlib import Path
from typing import Union
sys.path.append(os.path.abspath("..")) 
os.environ["TOKENIZERS_PARALLELISM"] = "false"

from sae_training.config import LanguageModelSAERunnerConfig
from sae_training.utils import LMSparseAutoencoderSessionloader
from sae_training.train_sae_on_language_model import train_sae_on_language_model

## Train

TransformerLens does not support every model on Hugging Face, so before training, choose a supported model from the [model properties table](https://transformerlensorg.github.io/TransformerLens/generated/model_properties_table.html).

In [2]:
# Hugging Face id of the model to train on; it must be listed in the
# TransformerLens supported-model table.
model_name: str = "meta-llama/Llama-3.1-8B-Instruct"

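Next, choose the device. The pretrained model is not loaded here: the session loader downloads and loads it later, inside `train_one_layer`.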

In [3]:
# Training device. Index 7 assumes an 8-GPU node; if that GPU does not exist,
# fall back to the first visible GPU (or the CPU) instead of failing later
# inside the session loader.
PREFERRED_DEVICE = "cuda:7"
_preferred_index = torch.device(PREFERRED_DEVICE).index
if torch.cuda.is_available() and torch.cuda.device_count() > _preferred_index:
    device = PREFERRED_DEVICE
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"  # runs, but training an 8B model on CPU is impractically slow
if device != PREFERRED_DEVICE:
    print(f"{PREFERRED_DEVICE} not available; using {device} instead.")
# model = HookedTransformer.from_pretrained(model_name=model_name, device=device, n_devices=8, move_to_device=True)

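Optionally, sanity-check the model with a short generation. This needs `model` to be loaded first (uncomment the `from_pretrained` line in the previous cell), so the cell is disabled by default.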

In [4]:
# Optional sanity check, disabled by default: it requires `model` to have been
# loaded by uncommenting the `from_pretrained` line in the device cell.
# model.eval()
# model.generate("5*(10+2*3). Let's think step by step.", max_new_tokens=50, temperature=0.1)

In [5]:
def make_layer_cfg(dataset_path: str, model_name: str, layer_idx: int, d_model: int, out_dir: str,
                   expansion_factor: int, lr: float, l1_coeff: float, train_batch_size: int, context_size: int,
                   device: Union[str, torch.device], dtype: torch.dtype, *,
                   total_training_tokens: int = 60_000_000,
                   lr_warm_up_steps: int = 5_000,
                   n_checkpoints: int = 3,
                   seed: int = 42) -> LanguageModelSAERunnerConfig:
    """
    Build the transcoder RunnerConfig for one MLP layer.

    The transcoder reads the MLP input (`blocks.{layer_idx}.ln2.hook_normalized`)
    and is trained to predict the MLP output (`blocks.{layer_idx}.hook_mlp_out`)
    of the same block.

    Args:
        dataset_path: dataset id/path; loaded as raw (untokenized) text.
        model_name: TransformerLens model name.
        layer_idx: index of the transformer block whose MLP is approximated.
        d_model: residual-stream width, used for both d_in and d_out.
        out_dir: checkpoint root; `out_dir/layer_{layer_idx:02d}` is created
            (side effect) and used as the checkpoint path.
        expansion_factor: dictionary size relative to d_in.
        lr: learning rate.
        l1_coeff: L1 sparsity coefficient.
        train_batch_size: transcoder training batch size.
        context_size: tokens per context fed to the model.
        device: device string such as "cuda:0", or a torch.device.
        dtype: dtype passed to the runner config.
        total_training_tokens: training budget in tokens (default 60M, the
            value previously hard-coded here).
        lr_warm_up_steps: warm-up steps for the "constantwithwarmup" scheduler.
        n_checkpoints: forwarded to the runner as `n_checkpoints`.
        seed: random seed forwarded to the runner.

    Returns:
        The configured LanguageModelSAERunnerConfig.
    """
    layer_ckpt_dir = Path(out_dir) / f"layer_{layer_idx:02d}"
    layer_ckpt_dir.mkdir(parents=True, exist_ok=True)

    cfg = LanguageModelSAERunnerConfig(
        hook_point=f"blocks.{layer_idx}.ln2.hook_normalized",  # input of MLP
        hook_point_layer=layer_idx,
        d_in=d_model,
        dataset_path=dataset_path,
        is_dataset_tokenized=False,
        model_name=model_name,

        # Transcoder: reconstruct the MLP *output* from the MLP input.
        is_transcoder=True,
        out_hook_point=f"blocks.{layer_idx}.hook_mlp_out",
        out_hook_point_layer=layer_idx,
        d_out=d_model,

        expansion_factor=expansion_factor,
        b_dec_init_method="mean",

        # Training Parameters
        lr=lr,
        l1_coefficient=l1_coeff,
        lr_scheduler_name="constantwithwarmup",
        train_batch_size=train_batch_size,
        context_size=context_size,
        lr_warm_up_steps=lr_warm_up_steps,

        # Activation Store Parameters
        # NOTE(review): the run log reports only 0.004096M tokens per buffer, so
        # activations are barely shuffled; presumably kept small to avoid OOM.
        n_batches_in_buffer=2,
        total_training_tokens=total_training_tokens,
        store_batch_size=4,

        # Dead Neurons and Sparsity
        use_ghost_grads=True,
        feature_sampling_method=None,
        feature_sampling_window=1000,
        resample_batches=512,
        dead_feature_window=5000,
        dead_feature_threshold=1e-8,

        log_to_wandb=False,
        use_tqdm=True,
        device=device,
        seed=seed,
        n_checkpoints=n_checkpoints,
        checkpoint_path=str(layer_ckpt_dir),
        dtype=dtype,
    )
    return cfg


In [6]:
def train_one_layer(cfg: LanguageModelSAERunnerConfig):
    """
    Train a transcoder for the layer described by `cfg` and save the final weights.

    Loads the model, transcoder and activation store through the session
    loader, runs `train_sae_on_language_model`, then saves the result as
    `<cfg.checkpoint_path>/final_<sae name>.pt`. GPU memory is released in a
    `finally` block, so a failed run (e.g. CUDA OOM) does not leave the model
    resident in the kernel and break the next attempt.

    Args:
        cfg: runner config, e.g. built by `make_layer_cfg`.

    Returns:
        Path of the saved final checkpoint.
    """
    print(f"[Layer {cfg.hook_point_layer}] Start training: lr={cfg.lr} l1={cfg.l1_coefficient} Checkpoint dir: {cfg.checkpoint_path}")

    loader = LMSparseAutoencoderSessionloader(cfg)
    model, sparse_autoencoder, activations_loader = loader.load_session()
    try:
        sparse_autoencoder = train_sae_on_language_model(
            model, sparse_autoencoder, activations_loader,
            n_checkpoints=cfg.n_checkpoints,
            batch_size=cfg.train_batch_size,
            feature_sampling_method=cfg.feature_sampling_method,
            feature_sampling_window=cfg.feature_sampling_window,
            feature_reinit_scale=cfg.feature_reinit_scale,
            dead_feature_threshold=cfg.dead_feature_threshold,
            dead_feature_window=cfg.dead_feature_window,
            use_wandb=cfg.log_to_wandb,
            wandb_log_frequency=cfg.wandb_log_frequency
        )

        final_path = Path(cfg.checkpoint_path) / f"final_{sparse_autoencoder.get_name()}.pt"
        # get_name() embeds the model name, which can contain "/" (e.g.
        # "meta-llama/..."), so the file may land in a nested directory.
        final_path.parent.mkdir(parents=True, exist_ok=True)
        sparse_autoencoder.save_model(str(final_path))
        print(f"[Layer {cfg.hook_point_layer}] Saved: {final_path}")
    finally:
        # Drop every reference first, then collect, then release cached CUDA
        # blocks: empty_cache() can only return memory whose tensors are freed.
        del loader, model, sparse_autoencoder, activations_loader
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return final_path

In [7]:
def final_path_exists(ckpt_dir: Path) -> bool:
    """
    Check whether a final checkpoint already exists under `ckpt_dir`.

    The runner saves into a per-run subdirectory (`layer_XX/<run_id>/`), and
    because the SAE name contains the model name (e.g. "meta-llama/..."), the
    final file can sit one level deeper still, inside a
    `final_sparse_autoencoder_<org>/` directory. The search is therefore
    recursive: any `.pt` file whose path below `ckpt_dir` has a component
    starting with "final_" counts.

    Args:
        ckpt_dir: layer checkpoint directory (Path or str).

    Returns:
        True if such a file exists, otherwise False (including when
        `ckpt_dir` does not exist or is not a directory).
    """
    ckpt_dir = Path(ckpt_dir)
    if not ckpt_dir.is_dir():
        return False
    for ckpt in ckpt_dir.rglob("*.pt"):
        rel_parts = ckpt.relative_to(ckpt_dir).parts
        if ckpt.is_file() and any(part.startswith("final_") for part in rel_parts):
            return True
    return False

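Now, define the training hyperparameters. Training length is controlled by `total_training_tokens` (60M tokens) in `make_layer_cfg`, not by an epoch count.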

In [8]:
lr = 0.0004                                                 # learning rate
l1_coeff = 0.00014                                          # L1 sparsity regularization coefficient
expansion_factor = 16                                       # dictionary expansion factor (default 32; lowered to avoid OOM)
dtype = torch.float32                                       # dtype passed to the runner config

train_batch = 32                                            # training batch size
context_size = 512                                          # tokens per context
# No epoch count: training length is set by `total_training_tokens` in make_layer_cfg.

dataset_path = "cerebras/SlimPajama-627B"                   # dataset path (raw text)
# model_name contains "/", so this resolves to a nested directory,
# e.g. ./meta-llama/Llama-3.1-8B-Instruct_checkpoints
checkpoint_dir = f"./{model_name}_checkpoints"              # checkpoint directory

In [9]:
Path(checkpoint_dir).mkdir(parents=True, exist_ok=True)

# Llama-3.1-8B architecture constants (hidden size, number of decoder blocks);
# update these when switching `model_name`.
d_model, n_layers = 4096, 32
print(f"Model: {model_name} | d_model={d_model} | n_layers={n_layers}")

# Record which model the checkpoints belong to. The `with` block guarantees the
# file is flushed and closed (the previous bare open() leaked the handle).
meta_path = Path(checkpoint_dir) / "meta.json"
with open(meta_path, "w") as meta_file:
    json.dump({"model_name": model_name, "d_model": d_model, "n_layers": n_layers}, meta_file)

# Layers to train in this run; use `list(range(n_layers))` to train all of them.
# Layers that already have a final checkpoint are skipped, so the cell can be
# re-run to resume after an interruption (e.g. an OOM).
layers_to_train = [21]

for layer_idx in layers_to_train:
    layer_dir = Path(checkpoint_dir) / f"layer_{layer_idx:02d}"
    if final_path_exists(layer_dir):
        print(f"[Layer {layer_idx}] final checkpoint exists, skip.")
        continue
    cfg = make_layer_cfg(dataset_path, model_name, layer_idx, d_model, checkpoint_dir,
                         expansion_factor, lr, l1_coeff, train_batch, context_size,
                         device, dtype)
    train_one_layer(cfg)

print("All done.")

Model: meta-llama/Llama-3.1-8B-Instruct | d_model=4096 | n_layers=32
Run name: 65536-L1-0.00014-LR-0.0004-Tokens-6.000e+07
n_tokens_per_buffer (millions): 0.004096
Lower bound: n_contexts_per_buffer (millions): 8e-06
Total training steps: 1875000
Total wandb updates: 187500
n_tokens_per_feature_sampling_window (millions): 16.384
n_tokens_per_dead_feature_window (millions): 81.92
Using Ghost Grads.
We will reset the sparsity calculation 1875 times.
Number of tokens when resampling: 2048
Number tokens in sparsity calculation window: 3.20e+04
[Layer 21] Start training: lr=0.0004 l1=0.00014 Checkpoint dir: meta-llama/Llama-3.1-8B-Instruct_checkpoints/layer_21/i2c3nzsc


`torch_dtype` is deprecated! Use `dtype` instead!


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]



Loaded pretrained model meta-llama/Llama-3.1-8B-Instruct into HookedTransformer


Resolving data files:   0%|          | 0/59166 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/31428 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/31411 [00:00<?, ?it/s]

Dataset is not tokenized! Updating config.
Reinitializing b_dec with mean of activations
Previous distances: 63.996585845947266
New distances: 58.10400390625
Reinitializing b_dec with mean of activations
Previous distances: 6.712014198303223
New distances: 6.45009183883667


13516| MSE Loss 0.001544 | L1 0.000000:   1%|          | 432512/60000000 [20:47<60:44:29, 272.41it/s] 

[Early Stop] Stop at step 13517, best loss=0.001301
Saved model to meta-llama/Llama-3.1-8B-Instruct_checkpoints/layer_21/i2c3nzsc/final_sparse_autoencoder_meta-llama/Llama-3.1-8B-Instruct_blocks.21.ln2.hook_normalized_65536.pt


13516| MSE Loss 0.001544 | L1 0.000000:   1%|          | 432544/60000000 [20:59<60:44:29, 272.41it/s]

Saved model to meta-llama/Llama-3.1-8B-Instruct_checkpoints/layer_21/i2c3nzsc/final_sparse_autoencoder_meta-llama/Llama-3.1-8B-Instruct_blocks.21.ln2.hook_normalized_65536.pt
[Layer 21] Saved: meta-llama/Llama-3.1-8B-Instruct_checkpoints/layer_21/i2c3nzsc/final_sparse_autoencoder_meta-llama/Llama-3.1-8B-Instruct_blocks.21.ln2.hook_normalized_65536.pt


13516| MSE Loss 0.001544 | L1 0.000000:   1%|          | 432544/60000000 [20:59<48:11:31, 343.35it/s]

All done.





P.S. Out-of-memory (OOM) errors are common because the transcoder's feature dimension (`d_model × expansion_factor`, here 4096 × 16 = 65,536) is very large. I handled this by decreasing the expansion factor (16 instead of the default 32), since this is only a demo for validation rather than the formal version. If you prefer not to do so, consider FSDP or other memory-saving techniques.