In [3]:
from abc import ABC, abstractmethod

import einops
import torch
import torch.nn as nn


class BaseSAE(nn.Module, ABC):
    def __init__(
        self,
        d_in: int,
        d_sae: int,
        model_name: str,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        hook_name: str | None = None,
    ):
        super().__init__()

        # Required parameters
        self.W_enc = nn.Parameter(torch.zeros(d_in, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_in))

        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))

        # Required attributes
        self.device: torch.device = device
        self.dtype: torch.dtype = dtype
        self.hook_layer = hook_layer

        hook_name = hook_name or f"blocks.{hook_layer}.hook_resid_post"
        self.to(dtype=self.dtype, device=self.device)

    @abstractmethod
    def encode(self, x: torch.Tensor):
        """Must be implemented by child classes"""
        raise NotImplementedError("Encode method must be implemented by child classes")

    @abstractmethod
    def decode(self, feature_acts: torch.Tensor):
        """Must be implemented by child classes"""
        raise NotImplementedError("Encode method must be implemented by child classes")

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """Must be implemented by child classes"""
        raise NotImplementedError("Encode method must be implemented by child classes")

    def to(self, *args, **kwargs):
        """Handle device and dtype updates"""
        super().to(*args, **kwargs)
        device = kwargs.get("device", None)
        dtype = kwargs.get("dtype", None)

        if device:
            self.device = device
        if dtype:
            self.dtype = dtype
        return self

    @torch.no_grad()
    def check_decoder_norms(self) -> bool:
        """
        It's important to check that the decoder weights are normalized.
        """
        norms = torch.norm(self.W_dec, dim=1).to(dtype=self.dtype, device=self.device)

        # In bfloat16, it's common to see errors of (1/256) in the norms
        tolerance = (
            1e-2 if self.W_dec.dtype in [torch.bfloat16, torch.float16] else 1e-5
        )

        if torch.allclose(norms, torch.ones_like(norms), atol=tolerance):
            return True
        else:
            max_diff = torch.max(torch.abs(norms - torch.ones_like(norms)))
            print(f"Decoder weights are not normalized. Max diff: {max_diff.item()}")
            return False

    # @torch.no_grad()
    # def test_sae(self, model_name: str):
    #     assert self.W_dec.shape == (self.cfg.d_sae, self.cfg.d_in)
    #     assert self.W_enc.shape == (self.cfg.d_in, self.cfg.d_sae)

    #     # TODO: Refactor to use AutoModelForCausalLM
    #     model = HookedTransformer.from_pretrained(model_name, device=self.device)

    #     test_input = "The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science"

    #     _, cache = model.run_with_cache(
    #         test_input,
    #         prepend_bos=True,
    #         names_filter=[self.cfg.hook_name],
    #         stop_at_layer=self.cfg.hook_layer + 1,
    #     )
    #     acts = cache[self.cfg.hook_name]

    #     encoded_acts = self.encode(acts)
    #     decoded_acts = self.decode(encoded_acts)

    #     flattened_acts = einops.rearrange(acts, "b l d -> (b l) d")
    #     reconstructed_acts = self(flattened_acts)
    #     # match flattened_acts with decoded_acts
    #     reconstructed_acts = reconstructed_acts.reshape(acts.shape)

    #     assert torch.allclose(reconstructed_acts, decoded_acts)

    #     l0 = (encoded_acts[:, 1:] > 0).float().sum(-1).detach()
    #     print(f"average l0: {l0.mean().item()}")


In [4]:
import json

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download


class BatchTopKSAE(BaseSAE):
    def __init__(
        self,
        d_in: int,
        d_sae: int,
        k: int,
        model_name: str,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        hook_name: str | None = None,
    ):
        hook_name = hook_name or f"blocks.{hook_layer}.hook_resid_post"
        super().__init__(d_in, d_sae, model_name, hook_layer, device, dtype, hook_name)

        assert isinstance(k, int) and k > 0
        self.register_buffer("k", torch.tensor(k, dtype=torch.int, device=device))

        # BatchTopK requires a global threshold to use during inference. Must be positive.
        self.use_threshold = True
        self.register_buffer(
            "threshold", torch.tensor(-1.0, dtype=dtype, device=device)
        )

    def encode(self, x: torch.Tensor):
        """Note: x can be either shape (B, F) or (B, L, F)"""
        post_relu_feat_acts_BF = nn.functional.relu(
            (x - self.b_dec) @ self.W_enc + self.b_enc
        )

        if self.use_threshold:
            if self.threshold < 0:
                raise ValueError(
                    "Threshold is not set. The threshold must be set to use it during inference"
                )
            encoded_acts_BF = post_relu_feat_acts_BF * (
                post_relu_feat_acts_BF > self.threshold
            )
            return encoded_acts_BF

        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)

        tops_acts_BK = post_topk.values
        top_indices_BK = post_topk.indices

        buffer_BF = torch.zeros_like(post_relu_feat_acts_BF)
        encoded_acts_BF = buffer_BF.scatter_(
            dim=-1, index=top_indices_BK, src=tops_acts_BK
        )
        return encoded_acts_BF

    def decode(self, feature_acts: torch.Tensor):
        return (feature_acts @ self.W_dec) + self.b_dec

    def forward(self, x: torch.Tensor):
        x = self.encode(x)
        recon = self.decode(x)
        return recon


def load_dictionary_learning_batch_topk_sae(
    repo_id: str,
    filename: str,
    model_name: str,
    device: torch.device,
    dtype: torch.dtype,
    layer: int | None = None,
    local_dir: str = "downloaded_saes",
) -> BatchTopKSAE:
    assert "ae.pt" in filename

    path_to_params = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_download=False,
        local_dir=local_dir,
    )

    pt_params = torch.load(path_to_params, map_location=torch.device("cpu"))

    config_filename = filename.replace("ae.pt", "config.json")
    path_to_config = hf_hub_download(
        repo_id=repo_id,
        filename=config_filename,
        force_download=False,
        local_dir=local_dir,
    )

    with open(path_to_config) as f:
        config = json.load(f)

    if layer is not None:
        assert layer == config["trainer"]["layer"]
    else:
        layer = config["trainer"]["layer"]

    # Transformer lens often uses a shortened model name
    assert model_name in config["trainer"]["lm_name"]

    k = config["trainer"]["k"]

    # Print original keys for debugging
    print("Original keys in state_dict:", pt_params.keys())

    # Map old keys to new keys
    key_mapping = {
        "encoder.weight": "W_enc",
        "decoder.weight": "W_dec",
        "encoder.bias": "b_enc",
        "bias": "b_dec",
        "k": "k",
        "threshold": "threshold",
    }

    # Create a new dictionary with renamed keys
    renamed_params = {key_mapping.get(k, k): v for k, v in pt_params.items()}

    # due to the way torch uses nn.Linear, we need to transpose the weight matrices
    renamed_params["W_enc"] = renamed_params["W_enc"].T
    renamed_params["W_dec"] = renamed_params["W_dec"].T

    # Print renamed keys for debugging
    print("Renamed keys in state_dict:", renamed_params.keys())

    sae = BatchTopKSAE(
        d_in=renamed_params["b_dec"].shape[0],
        d_sae=renamed_params["b_enc"].shape[0],
        k=k,
        model_name=model_name,
        hook_layer=layer,  # type: ignore
        device=device,
        dtype=dtype,
    )

    sae.load_state_dict(renamed_params)

    sae.to(device=device, dtype=dtype)

    d_sae, d_in = sae.W_dec.data.shape

    assert d_sae >= d_in

    normalized = sae.check_decoder_norms()
    if not normalized:
        raise ValueError("Decoder vectors are not normalized. Please normalize them")

    return sae


def load_dictionary_learning_matryoshka_batch_topk_sae(
    repo_id: str,
    filename: str,
    model_name: str,
    device: torch.device,
    dtype: torch.dtype,
    layer: int | None = None,
    local_dir: str = "downloaded_saes",
) -> BatchTopKSAE:
    assert "ae.pt" in filename

    path_to_params = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_download=False,
        local_dir=local_dir,
    )

    pt_params = torch.load(path_to_params, map_location=torch.device("cpu"))

    config_filename = filename.replace("ae.pt", "config.json")
    path_to_config = hf_hub_download(
        repo_id=repo_id,
        filename=config_filename,
        force_download=False,
        local_dir=local_dir,
    )

    with open(path_to_config) as f:
        config = json.load(f)

    if layer is not None:
        assert layer == config["trainer"]["layer"]
    else:
        layer = config["trainer"]["layer"]

    # Transformer lens often uses a shortened model name
    assert model_name in config["trainer"]["lm_name"]

    k = config["trainer"]["k"]

    # We currently don't use group sizes, so we remove them to reuse the BatchTopKSAE class
    del pt_params["group_sizes"]

    # Print original keys for debugging
    print("Original keys in state_dict:", pt_params.keys())

    sae = BatchTopKSAE(
        d_in=pt_params["b_dec"].shape[0],
        d_sae=pt_params["b_enc"].shape[0],
        k=k,
        model_name=model_name,
        hook_layer=layer,  # type: ignore
        device=device,
        dtype=dtype,
    )

    sae.load_state_dict(pt_params)

    sae.to(device=device, dtype=dtype)

    d_sae, d_in = sae.W_dec.data.shape

    assert d_sae >= d_in

    if config["trainer"]["trainer_class"] == "MatryoshkaBatchTopKTrainer":
        sae.cfg.architecture = "matryoshka_batch_topk"
    else:
        raise ValueError(f"Unknown trainer class: {config['trainer']['trainer_class']}")

    normalized = sae.check_decoder_norms()
    if not normalized:
        raise ValueError("Decoder vectors are not normalized. Please normalize them")

    return sae


# if __name__ == "__main__":
#     repo_id = "adamkarvonen/saebench_pythia-160m-deduped_width-2pow12_date-0104"
#     filename = "BatchTopKTrainer_EleutherAI_pythia-160m-deduped_ctx1024_0104/resid_post_layer_8/trainer_26/ae.pt"
#     layer = 8

#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     dtype = torch.float32

#     model_name = "EleutherAI/pythia-160m-deduped"

#     sae = load_dictionary_learning_batch_topk_sae(
#         repo_id,
#         filename,
#         model_name,
#         device,  # type: ignore
#         dtype,
#         layer=layer,
#     )
#     sae.test_sae(model_name)

# Matryoshka BatchTopK SAE

# if __name__ == "__main__":
#     repo_id = "adamkarvonen/matryoshka_pythia_160m_16k"
#     filename = "MatryoshkaBatchTopKTrainer_temp_100_EleutherAI_pythia-160m-deduped_ctx1024_0104/resid_post_layer_8/trainer_2/ae.pt"
#     layer = 8

#     device = "cuda" if torch.cuda.is_available() else "cpu"
#     dtype = torch.float32

#     model_name = "EleutherAI/pythia-160m-deduped"
#     hook_name = f"blocks.{layer}.hook_resid_post"

#     sae = load_dictionary_learning_matryoshka_batch_topk_sae(
#         repo_id, filename, model_name, device, dtype, layer=layer
#     )
#     sae.test_sae(model_name)


In [5]:
import pandas as pd
from datasets import load_dataset
import datasets
from tqdm import tqdm
from typing import Optional
import torch
import einops

from transformers import AutoTokenizer, AutoModelForCausalLM
# from circuitsvis.activations import text_neuron_activations

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# model_name = "mistralai/Ministral-8B-Instruct-2410"
model_name = "mistralai/Mistral-Small-24B-Instruct-2501"
dtype = torch.bfloat16

tokenizer = AutoTokenizer.from_pretrained(model_name)

def get_submodule(model: AutoModelForCausalLM, layer: int):
    """Gets the residual stream submodule"""
    model_name = model.config._name_or_path

    if "pythia" in model_name:
        return model.gpt_neox.layers[layer]
    elif "gemma" in model_name or "mistral" in model_name:
        return model.model.layers[layer]
    else:
        raise ValueError(f"Please add submodule for model {model_name}")


chosen_layers = [20]
sae_repo = "adamkarvonen/mistral_24b_saes"
sae_path = "mistral_24b_mistralai_Mistral-Small-24B-Instruct-2501_batch_top_k/resid_post_layer_20/trainer_1/ae.pt"

sae = load_dictionary_learning_batch_topk_sae(
    repo_id=sae_repo,
    filename=sae_path,
    model_name=model_name,
    device=device,
    dtype=dtype,
    layer=chosen_layers[0],
    local_dir="downloaded_saes",
)

Original keys in state_dict: dict_keys(['b_dec', 'k', 'threshold', 'decoder.weight', 'encoder.weight', 'encoder.bias'])
Renamed keys in state_dict: dict_keys(['b_dec', 'k', 'threshold', 'W_dec', 'W_enc', 'b_enc'])


In [8]:
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=dtype, device_map="auto"
)

submodules = [get_submodule(model, chosen_layers[0])]

Fetching 10 files: 100%|██████████| 10/10 [04:42<00:00, 28.29s/it]
Loading checkpoint shards: 100%|██████████| 10/10 [00:08<00:00,  1.16it/s]


In [9]:
class EarlyStopException(Exception):
    """Custom exception for stopping model forward pass early."""
    pass


@torch.no_grad()
def collect_activations(model, submodule, inputs_BL):
    """
    Registers a forward hook on the submodule to capture the residual (or hidden)
    activations. We then raise an EarlyStopException to skip unneeded computations.
    """
    activations_BLD = None

    def gather_target_act_hook(module, inputs, outputs):
        nonlocal activations_BLD
        # For many models, the submodule outputs are a tuple or a single tensor:
        # If "outputs" is a tuple, pick the relevant item:
        #   e.g. if your layer returns (hidden, something_else), you'd do outputs[0]
        # Otherwise just do outputs
        if isinstance(outputs, tuple):
            activations_BLD = outputs[0]
        else:
            activations_BLD = outputs

        raise EarlyStopException("Early stopping after capturing activations")

    handle = submodule.register_forward_hook(gather_target_act_hook)

    try:
        _ = model(input_ids=inputs_BL.to(model.device))
    except EarlyStopException:
        pass
    except Exception as e:
        print(f"Unexpected error during forward pass: {str(e)}")
        raise
    finally:
        handle.remove()

    return activations_BLD

In [10]:
test_input = "Can you continue this sentence? The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science"

test_input = "[INST]Can you continue this story? The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science[/INST]assistant"


# test_input = "[INST]How can I center a div?[/INST]assistant"

tokens = tokenizer(test_input, return_tensors="pt", add_special_tokens=True).to(device)["input_ids"]
# tokens = tokenizer(test_input, return_tensors="pt", add_special_tokens=False).to(device)

print(tokens.shape)

activations_BLD = collect_activations(model, submodules[0], tokens)

print(tokens)
    

torch.Size([1, 42])
tensor([[     1,      3,  12483,   1636,   8214,   1593,   8303,   1063,   1531,
          51475,   8876,   1278,   4669,   1044,   2453,   2034,  41958,  31907,
           1044, 120370,   2190,   2159, 118586,   1046,   4634,   3946,   5352,
           3257,   1286,   1044,  18990,  24462,  82155, 104555,   1722,   8265,
          14183,   1317,  12470,      4,   1503,  19464]], device='cuda:0')


In [11]:
norms_BL = activations_BLD.norm(dim=-1)
print(norms_BL)
print(norms_BL.mean())


tensor([[336.0000,  15.9375,  27.0000,  15.0625,  15.3750,  16.2500,  16.5000,
          13.0000,  14.3750,  16.0000,  14.5625,  14.6875,  14.3125,  13.3750,
          14.6250,  14.8125,  16.6250,  15.3750,  14.9375,  13.8750,  13.5625,
          12.0625,  15.1875,  15.0000,  13.4375,  14.1875,  12.3125,  15.2500,
          15.3750,  15.2500,  15.8750,  15.9375,  15.7500,  16.0000,  15.2500,
          16.1250,  17.1250,  15.5625,  15.8750,  14.7500,  12.3125,  11.6250]],
       device='cuda:0', dtype=torch.bfloat16)
tensor(22.7500, device='cuda:0', dtype=torch.bfloat16)


In [12]:
sae.use_threshold = True
encoded_BLF = sae.encode(activations_BLD)
decoded_BLD = sae.decode(encoded_BLF)

torch.set_printoptions(precision=8, sci_mode=False)

nonzero_BL = einops.reduce((encoded_BLF > 0).float(), "b l f -> b l", "sum")
print(nonzero_BL)
mean_nonzero = nonzero_BL.mean()
print(mean_nonzero, "\n\n")

MSE_BL = (activations_BLD - decoded_BLD).pow(2).mean(dim=-1)
print(MSE_BL)
mean_MSE = MSE_BL.mean()
print(mean_MSE)

tensor([[314.,  31.,  18.,  58., 118., 116., 125., 121., 101., 137., 125., 148.,
         134., 132., 154., 168., 168., 177., 174., 144., 140., 103., 211., 142.,
         118., 134., 130., 172., 179., 148., 148., 142., 216., 171., 152., 155.,
         181., 161., 168., 115., 144., 118.]], device='cuda:0')
tensor(143.11904907, device='cuda:0') 


tensor([[    0.00076675,     0.00005007,     0.00032806,     0.00257874,
             0.00729370,     0.00946045,     0.00759888,     0.00872803,
             0.00781250,     0.00585938,     0.00747681,     0.00994873,
             0.00750732,     0.00897217,     0.00933838,     0.01007080,
             0.00854492,     0.00842285,     0.01153564,     0.00823975,
             0.00769043,     0.00564575,     0.01171875,     0.00927734,
             0.00692749,     0.00744629,     0.00729370,     0.01055908,
             0.00939941,     0.00823975,     0.00750732,     0.00793457,
             0.01190186,     0.01019287,     0.00701904,     0.00772

In [13]:

@torch.no_grad()
def reconstruct_activations(model, submodule, sae, inputs_BL):

    def gather_target_act_hook(module, inputs, outputs):
        # For many models, the submodule outputs are a tuple or a single tensor:
        # If "outputs" is a tuple, pick the relevant item:
        #   e.g. if your layer returns (hidden, something_else), you'd do outputs[0]
        # Otherwise just do outputs
        if isinstance(outputs, tuple):
            activations_BLD = outputs[0]
        else:
            activations_BLD = outputs

        encoded_BLF = sae.encode(activations_BLD)
        decoded_BLD = sae.decode(encoded_BLF)

        outputs = (decoded_BLD,) + outputs[1:]

        return outputs

    handle = submodule.register_forward_hook(gather_target_act_hook)

    try:
        outputs = model(input_ids=inputs_BL.to(model.device), labels=inputs_BL.to(model.device))
    except Exception as e:
        print(f"Unexpected error during forward pass: {str(e)}")
        raise
    finally:
        handle.remove()

    return outputs

original_loss = model(input_ids=tokens.to(model.device), labels=tokens.to(model.device)).loss

outputs = reconstruct_activations(model, submodules[0], sae, tokens)

print(outputs.loss)
print(original_loss)
ratio = outputs.loss / original_loss
print(ratio)

tensor(5.04034185, device='cuda:0')
tensor(4.86054754, device='cuda:0', grad_fn=<NllLossBackward0>)
tensor(1.03699052, device='cuda:0', grad_fn=<DivBackward0>)
