# This is only used for evaluation. Training is not yet supported!

In [1]:
from sae_lens.training.sparse_autoencoder import SparseAutoencoder
from sae_lens.training.config import LanguageModelSAERunnerConfig
from sae_lens.training.sae_group import SparseAutoencoderDictionary
from huggingface_hub import hf_hub_download

In [2]:
hf_model_id = "HuFY-dev/tanh_sae"
model_path = hf_hub_download(hf_model_id, "model.safetensors")
config_path = hf_hub_download(hf_model_id, "config.json")

model.safetensors:   0%|          | 0.00/907M [00:00<?, ?B/s]

config.json:   0%|          | 0.00/186 [00:00<?, ?B/s]

In [30]:
import json
import torch
from safetensors import safe_open

device = "cuda" if torch.cuda.is_available() else "cpu"

with open(config_path) as f:
    config_dict = json.load(f)

tensors = {}
with safe_open(model_path, framework="pt", device=device) as f:  # type: ignore
    for k in f.keys():
        tensors[k] = f.get_tensor(k)

d_in = config_dict['n_input_features']
d_sae = config_dict['n_learned_features']
cfg = LanguageModelSAERunnerConfig(
        d_in=d_in,
        expansion_factor=d_sae//d_in,
        normalize_sae_decoder=False,
        noise_scale=config_dict['noise_scale'],
        model_name="gpt2",
        activation_fn="tanh-relu",
        hook_point="blocks.{layer}.hook_mlp_out",
        hook_point_layer=list(range(config_dict['n_components'])),  # type: ignore
        dtype="torch.float32",
        device=device,
    )

sae_group = SparseAutoencoderDictionary(cfg)
with torch.no_grad():
    for single_sae in sae_group.autoencoders.values():
        layer = single_sae.hook_point_layer
        single_sae.W_enc.data = tensors['encoder.weight'].data[layer].T.clone()
        single_sae.b_enc.data = tensors['encoder.bias'].data[layer].clone()
        single_sae.W_dec.data = tensors['decoder.weight'].data[layer].T.clone()
        single_sae.b_dec.data = tensors['post_decoder_bias._bias_reference'].data[layer].clone()

Run name: 12288-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.08192
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06
Run name: 12288-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.08192
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 488
Total wandb updates: 48
n_tokens_per_feature_sampling_window (millions): 1048.576
n_tokens_per_dead_feature_window (millions): 524.288
We will reset the sparsity calculation 0 times.
Number tokens in sparsity calculation window: 8.19e+06
Run name: 12288-L1-0.001-LR-0.0003-Tokens-2.000e+06
n_tokens_per_buffer (millions): 0.08192
Lower bound: n_contexts_per_buffer (millions): 0.00064
Total training steps: 488
Total

In [31]:
# Now you should be able to use sae_group
sae_group

<sae_lens.training.sae_group.SparseAutoencoderDictionary at 0x29a2f0e50>

# Small sanity check

In [32]:
from typing import TypedDict
from transformer_lens import HookedTransformer
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

In [33]:
model_name = "gpt2"
model = HookedTransformer.from_pretrained(model_name)

Loaded pretrained model gpt2 into HookedTransformer


In [19]:
dataset_path = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"
torch_dataset = load_dataset(dataset_path, split="train", streaming=True).with_format("torch")

Downloading readme:   0%|          | 0.00/324 [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/64 [00:00<?, ?it/s]

In [34]:
TokenizedPrompt = list[int]
"""A tokenized prompt."""


class TokenizedPrompts(TypedDict):
    """Tokenized prompts."""

    input_ids: list[TokenizedPrompt]
    
class TorchTokenizedPrompts(TypedDict):
    """Tokenized prompts prepared for PyTorch."""

    input_ids: torch.Tensor

dl = DataLoader[TorchTokenizedPrompts](
            torch_dataset,
            batch_size=16,
            # Shuffle is most efficiently done with the `shuffle` method on the dataset itself, not
            # here.
            shuffle=False,
            num_workers=1,
        )

In [35]:
hooked_layers = [sae_group.cfg.hook_point.format(layer=layer) for layer in sae_group.cfg.hook_point_layer]
saes_by_layer = {sae.hook_point_layer: sae for sae in sae_group.autoencoders.values()}
hooked_layers

['blocks.0.hook_mlp_out',
 'blocks.1.hook_mlp_out',
 'blocks.2.hook_mlp_out',
 'blocks.3.hook_mlp_out',
 'blocks.4.hook_mlp_out',
 'blocks.5.hook_mlp_out',
 'blocks.6.hook_mlp_out',
 'blocks.7.hook_mlp_out',
 'blocks.8.hook_mlp_out',
 'blocks.9.hook_mlp_out',
 'blocks.10.hook_mlp_out',
 'blocks.11.hook_mlp_out']

In [36]:
residuals = []
for i, batch in enumerate(dl):
    if i >= 1:
        break
    batch_tokens = batch["input_ids"]
    _, cache = model.run_with_cache(batch_tokens, prepend_bos=True, names_filter=hooked_layers)
    residuals = [cache[layer] for layer in hooked_layers]
    del cache

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


In [37]:
sae_hooks = ["hook_hidden_post", "hook_sae_out"]
for i in range(len(residuals)):
    autoencoder = saes_by_layer[i]
    _, cache = autoencoder.run_with_cache(residuals[i].to(autoencoder.device), names_filter=sae_hooks)
    reconstructed = cache["hook_sae_out"]
    feature_act = cache["hook_hidden_post"]
    l2_loss = torch.nn.functional.mse_loss(residuals[i].to(autoencoder.device), reconstructed)
    l1_loss = torch.nn.functional.l1_loss(feature_act, torch.zeros_like(feature_act)) * autoencoder.d_sae
    print(f"Layer {i}: L2 loss: {l2_loss}, L1 loss: {l1_loss}")
    del cache

Layer 0: L2 loss: 0.04830312356352806, L1 loss: 77.44003295898438
Layer 1: L2 loss: 0.042308878153562546, L1 loss: 172.5427703857422
Layer 2: L2 loss: 0.06545332819223404, L1 loss: 128.71791076660156
Layer 3: L2 loss: 0.04921870306134224, L1 loss: 186.860107421875
Layer 4: L2 loss: 0.05224926397204399, L1 loss: 207.7950439453125
Layer 5: L2 loss: 0.06069820001721382, L1 loss: 221.6451416015625
Layer 6: L2 loss: 0.07593274861574173, L1 loss: 241.45803833007812
Layer 7: L2 loss: 0.10232746601104736, L1 loss: 245.13211059570312
Layer 8: L2 loss: 0.13374188542366028, L1 loss: 249.81625366210938
Layer 9: L2 loss: 0.2078799605369568, L1 loss: 234.19540405273438
Layer 10: L2 loss: 0.508967399597168, L1 loss: 141.22979736328125
Layer 11: L2 loss: 0.8917564749717712, L1 loss: 117.58256530761719


Pretty similar to the results I got.