In [1]:
import os
import sys
from pathlib import Path

import gc
import itertools
import math
import os
import random
import sys
from collections import Counter, defaultdict
from copy import deepcopy
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Any, Callable, Literal, TypeAlias

import circuitsvis as cv
import einops
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import requests
import torch as t
from datasets import load_dataset
from huggingface_hub import hf_hub_download
from IPython.display import HTML, IFrame, clear_output, display
from jaxtyping import Float, Int
from openai import OpenAI
from rich import print as rprint
from rich.table import Table
from sae_lens import (
    SAE,
    ActivationsStore,
    HookedSAETransformer,
    LanguageModelSAERunnerConfig,
    SAEConfig,
    SAETrainingRunner,
    upload_saes_to_huggingface,
)
from sae_lens.toolkit.pretrained_saes_directory import get_pretrained_saes_directory
from sae_vis import SaeVisConfig, SaeVisData, SaeVisLayoutConfig
from tabulate import tabulate
from torch import Tensor, nn
from torch.distributions.categorical import Categorical
from torch.nn import functional as F
from tqdm.auto import tqdm
from transformer_lens import ActivationCache, HookedTransformer
from transformer_lens.hook_points import HookPoint
from transformer_lens.utils import get_act_name, test_prompt, to_numpy

# Choose the compute device in order of preference: Apple-silicon MPS, then CUDA, else CPU.
if t.backends.mps.is_available():
    device = t.device("mps")
elif t.cuda.is_available():
    device = t.device("cuda")
else:
    device = t.device("cpu")

# True only when this code runs as the top-level program (e.g. the notebook kernel), not when imported.
MAIN = __name__ == "__main__"

In [3]:
# Attention-only demo model whose attention activations the SAE in this cell is trained on.
attn_model = HookedSAETransformer.from_pretrained(model_name="attn-only-2l-demo")

# Training budget: every optimizer step consumes `batch_size` activation vectors (one per token).
batch_size = 4096
total_training_steps = 30_000  # probably we should do more
total_training_tokens = batch_size * total_training_steps

# Schedule lengths, expressed as fractions of the whole run.
lr_warm_up_steps = total_training_steps // 10  # 10% of training
l1_warm_up_steps = lr_warm_up_steps  # L1 penalty warms up over the same span as the LR
lr_decay_steps = total_training_steps // 5  # 20% of training

# Index of the transformer block whose attention activations are hooked.
layer = 0

# Runner settings, grouped by concern. The groups are merged into a single
# LanguageModelSAERunnerConfig below and then deleted so they don't linger in the namespace.

# Data generation: which activations to collect, and from which token stream.
_data_generation_kwargs = dict(
    model_name="attn-only-2l-demo",
    # hook_z is stored per attention head; the SAE input spans all heads concatenated.
    # NOTE(review): the training output shows sae_lens logging "Error during view operation ...
    # Use .reshape(...) instead." while handling these activations, yet training kept progressing,
    # so this looks like a logged fallback rather than a failure — confirm against sae_lens.
    hook_name=f"blocks.{layer}.attn.hook_z",
    hook_layer=layer,
    d_in=attn_model.cfg.d_head * attn_model.cfg.n_heads,
    dataset_path="apollo-research/Skylion007-openwebtext-tokenizer-EleutherAI-gpt-neox-20b",
    is_dataset_tokenized=True,
    prepend_bos=True,  # must match how the base model was trained
    streaming=True,  # stream rather than pre-download (pre-downloading is fine for small datasets)
    train_batch_size_tokens=batch_size,
    context_size=attn_model.cfg.n_ctx,
)

# SAE architecture and initialisation.
_architecture_kwargs = dict(
    architecture="gated",
    expansion_factor=16,
    b_dec_init_method="zeros",
    apply_b_dec_to_input=True,
    normalize_sae_decoder=False,
    scale_sparsity_penalty_by_decoder_norm=True,
    decoder_heuristic_init=True,
    init_encoder_as_decoder_transpose=True,
)

# Activations store (buffering of model activations between training steps).
_activation_store_kwargs = dict(
    n_batches_in_buffer=64,
    training_tokens=total_training_tokens,
    store_batch_size_prompts=16,
)

# Generic optimiser / LR-schedule hyperparameters.
_optimizer_kwargs = dict(
    lr=1e-4,
    adam_beta1=0.9,
    adam_beta2=0.999,
    lr_scheduler_name="constant",
    lr_warm_up_steps=lr_warm_up_steps,  # warming up avoids lots of features dying early on
    lr_decay_steps=lr_decay_steps,
)

# SAE-specific hyperparameters: sparsity penalty and dead-feature bookkeeping.
_sparsity_kwargs = dict(
    l1_coefficient=2,
    l1_warm_up_steps=l1_warm_up_steps,
    use_ghost_grads=False,  # ghost grads are no longer used
    feature_sampling_window=1000,  # how often we resample dead features
    dead_feature_window=500,  # size of window to assess whether a feature is dead
    dead_feature_threshold=1e-4,  # threshold for classifying feature as dead, over window
)

# Logging, checkpointing and runtime settings.
_runtime_kwargs = dict(
    log_to_wandb=False,  # switch on (wandb) for real runs; off is only for testing code
    device=str(device),
    seed=42,
    n_checkpoints=5,
    checkpoint_path="checkpoints",
    dtype="float32",
)

cfg = LanguageModelSAERunnerConfig(
    **_data_generation_kwargs,
    **_architecture_kwargs,
    **_activation_store_kwargs,
    **_optimizer_kwargs,
    **_sparsity_kwargs,
    **_runtime_kwargs,
)

del (
    _data_generation_kwargs,
    _architecture_kwargs,
    _activation_store_kwargs,
    _optimizer_kwargs,
    _sparsity_kwargs,
    _runtime_kwargs,
)

# Either train a fresh attention SAE with `cfg`, or load the copy stored on the Hugging Face Hub.
#
# Previously this was toggled by commenting code in and out. The printed hint was backwards: it
# told the reader to comment the code out in order to train. The two paths also bound different
# names (`sae` when training, `attn_sae` when loading), so later cells could break depending on
# which path ran. The flags below replace the commenting, and both names are bound on either path.
TRAIN_ATTN_SAE = True  # False -> skip the (long) training run and load the pretrained SAE instead
UPLOAD_ATTN_SAE = False  # True -> push the freshly trained SAE to `hf_repo_id` (requires write access)

hf_repo_id = "callummcdougall/arena-demos-attn2l"
sae_id = f"{cfg.hook_name}-v2"

if TRAIN_ATTN_SAE:
    print("Training a new SAE. Set TRAIN_ATTN_SAE = False to load the already trained model instead.")
    t.set_grad_enabled(True)  # training needs autograd, even if it was disabled globally earlier
    runner = SAETrainingRunner(cfg)
    attn_sae = runner.run()
    if UPLOAD_ATTN_SAE:
        upload_saes_to_huggingface({sae_id: attn_sae}, hf_repo_id=hf_repo_id)
else:
    print(f"Loading the already trained SAE {sae_id!r} from {hf_repo_id!r}. Set TRAIN_ATTN_SAE = True to train.")
    # SAE.from_pretrained returns a tuple; the SAE itself is element 0.
    attn_sae = SAE.from_pretrained(release=hf_repo_id, sae_id=sae_id, device=str(device))[0]

# Bind both names so later cells work whichever one they reference.
sae = attn_sae

config.json:   0%|          | 0.00/1.25k [00:00<?, ?B/s]



model_final.pth:   0%|          | 0.00/219M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/156 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.08M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/457k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/2.11M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/90.0 [00:00<?, ?B/s]

Loaded pretrained model attn-only-2l-demo into HookedTransformer
Comment this code out to train! Otherwise, it will load in the already trained model.
Loaded pretrained model attn-only-2l-demo into HookedTransformer


Downloading readme:   0%|          | 0.00/300 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/73 [00:00<?, ?it/s]

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Training SAE:   0%|          | 0/122880000 [00:00<?, ?it/s]ERROR:sae_lens:Error during view operation: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
ERROR:sae_lens:Error during view operation: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
ERROR:sae_lens:Error during view operation: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.
ERROR:sae_lens:Error during view operation: view size is not 

InterruptedException: 

700| auxiliary_reconstruction_loss: 7.65157 | l1_loss: 6.42946 | mse_loss: 5.87090:   2%|▏         | 2867200/122880000 [01:22<45:50, 43638.93it/s]