# This notebook is for evaluation only. Training is not yet supported!

In [None]:
from sae_lens.training.sae import SAE
from sae_lens.config import LanguageModelSAERunnerConfig
from huggingface_hub import hf_hub_download

In [None]:
# Hugging Face Hub repository holding the pretrained tanh SAE checkpoint.
hf_model_id = "HuFY-dev/tanh_sae"

# Fetch the weights first, then the JSON config describing the checkpoint.
model_path = hf_hub_download(repo_id=hf_model_id, filename="model.safetensors")
config_path = hf_hub_download(repo_id=hf_model_id, filename="config.json")

In [None]:
import json
import torch
from safetensors import safe_open

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The config JSON describes the checkpoint (feature sizes, number of
# components, noise scale).
with open(config_path, encoding="utf-8") as f:
    config_dict = json.load(f)

# Load every tensor stored in the safetensors file directly onto `device`.
with safe_open(model_path, framework="pt", device=device) as f:  # type: ignore
    tensors = {k: f.get_tensor(k) for k in f.keys()}

d_in = config_dict['n_input_features']
d_sae = config_dict['n_learned_features']

# Both the hook name and the weight indexing below need one concrete layer:
# a plain "blocks.{layer}..." string is never formatted, and indexing the
# checkpoint tensors with a list of layers yields stacked 3-D tensors that
# `.T` cannot sensibly transpose. Load the first component here.
# NOTE(review): assumes dim 0 of each checkpoint tensor indexes the
# component/layer -- confirm against the checkpoint format.
layer = 0
cfg = LanguageModelSAERunnerConfig(
    d_in=d_in,
    expansion_factor=d_sae // d_in,
    normalize_sae_decoder=False,
    noise_scale=config_dict['noise_scale'],
    model_name="gpt2",
    activation_fn="tanh-relu",
    hook_name=f"blocks.{layer}.hook_mlp_out",
    hook_layer=layer,  # type: ignore
    dtype="torch.float32",
    device=device,
)

single_sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
with torch.no_grad():
    # Read the layer back from the SAE's own config so the weights that get
    # loaded always match the layer the SAE was configured for.
    layer = single_sae.cfg.hook_layer
    # The encoder/decoder weight slices are transposed before being copied
    # into the SAE; the biases are copied as-is.
    single_sae.W_enc.data = tensors['encoder.weight'].data[layer].T.clone()
    single_sae.b_enc.data = tensors['encoder.bias'].data[layer].clone()
    single_sae.W_dec.data = tensors['decoder.weight'].data[layer].T.clone()
    single_sae.b_dec.data = tensors['post_decoder_bias._bias_reference'].data[layer].clone()

In [None]:
# `single_sae` now holds the pretrained weights; display it as a quick check.
single_sae

# Small sanity check

In [None]:
from typing import TypedDict
from transformer_lens import HookedTransformer
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

In [None]:
# Load GPT-2 as a HookedTransformer so its activations can be cached by hook name.
model_name = "gpt2"
model = HookedTransformer.from_pretrained(model_name)

In [None]:
# Pre-tokenized dataset on the Hugging Face Hub.
dataset_path = "alancooney/sae-monology-pile-uncopyrighted-tokenizer-gpt2"

# Stream the train split instead of downloading it in full, then ask for
# items to be formatted for PyTorch.
streamed_dataset = load_dataset(dataset_path, split="train", streaming=True)
torch_dataset = streamed_dataset.with_format("torch")

In [None]:
# A tokenized prompt, represented as a list of integer token ids.
TokenizedPrompt = list[int]


class TokenizedPrompts(TypedDict):
    """Tokenized prompts held as plain Python lists."""

    input_ids: list[TokenizedPrompt]


class TorchTokenizedPrompts(TypedDict):
    """Tokenized prompts with the ids held in a PyTorch tensor."""

    input_ids: torch.Tensor

# Batch the streamed dataset, 16 items at a time.
# Shuffling stays off here: it is most efficiently done with the `shuffle`
# method on the dataset itself, not in the DataLoader.
dl = DataLoader[TorchTokenizedPrompts](
    dataset=torch_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Build one SAE per component of the checkpoint, keyed by layer index, and
# record the hook name each one is attached to.
saes_by_layer = {}
hooked_layers = []

# Config settings that are identical for every layer.
shared_cfg_kwargs = dict(
    d_in=d_in,
    expansion_factor=d_sae // d_in,
    normalize_sae_decoder=False,
    noise_scale=config_dict['noise_scale'],
    model_name="gpt2",
    activation_fn="tanh-relu",
    dtype="torch.float32",
    device=device,
    verbose=False,
)

# (SAE attribute, checkpoint key, transpose the per-layer slice?)
weight_sources = (
    ("W_enc", 'encoder.weight', True),
    ("b_enc", 'encoder.bias', False),
    ("W_dec", 'decoder.weight', True),
    ("b_dec", 'post_decoder_bias._bias_reference', False),
)

for layer in range(config_dict['n_components']):
    cfg = LanguageModelSAERunnerConfig(
        hook_name=f"blocks.{layer}.hook_mlp_out",
        hook_layer=layer,  # type: ignore
        **shared_cfg_kwargs,
    )

    single_sae = SAE.from_dict(cfg.get_base_sae_cfg_dict())
    with torch.no_grad():
        # Index the checkpoint by the layer recorded in the SAE's own config.
        layer = single_sae.cfg.hook_layer
        for attr_name, tensor_key, transpose in weight_sources:
            layer_slice = tensors[tensor_key].data[layer]
            if transpose:
                layer_slice = layer_slice.T
            # Clone so the SAE owns its parameters independently of `tensors`.
            getattr(single_sae, attr_name).data = layer_slice.clone()

    saes_by_layer[layer] = single_sae
    hooked_layers.append(single_sae.cfg.hook_name)

hooked_layers

In [None]:
# Collect the hooked activations for the first batch only; one batch is
# enough for a sanity check.
residuals = []
for i, batch in enumerate(dl):
    if i > 0:
        break
    batch_tokens = batch["input_ids"]
    _, cache = model.run_with_cache(
        batch_tokens,
        prepend_bos=True,
        names_filter=hooked_layers,
    )
    # Keep the activations in the same order as `hooked_layers`.
    residuals = [cache[hook_name] for hook_name in hooked_layers]
    del cache

In [None]:
# Only the SAE's feature activations and its reconstruction are needed.
sae_hooks = ["hook_sae_acts_post", "hook_sae_output"]
for i, residual in enumerate(residuals):
    autoencoder = saes_by_layer[i]
    # Move the activations to the SAE's device once and reuse them.
    sae_input = residual.to(autoencoder.device)
    _, cache = autoencoder.run_with_cache(sae_input, names_filter=sae_hooks)
    reconstructed = cache["hook_sae_output"]
    feature_act = cache["hook_sae_acts_post"]
    # Reconstruction error: mean squared error between input and output.
    l2_loss = torch.nn.functional.mse_loss(sae_input, reconstructed)
    # Mean absolute feature activation, scaled by the number of SAE features.
    mean_abs_act = torch.nn.functional.l1_loss(feature_act, torch.zeros_like(feature_act))
    l1_loss = mean_abs_act * autoencoder.cfg.d_sae
    print(f"Layer {i}: L2 loss: {l2_loss}, L1 loss: {l1_loss}")
    del cache

The losses printed above are pretty similar to the results I got.