In [19]:
from abc import ABC, abstractmethod
import einops
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
import json
import safetensors
from huggingface_hub import hf_hub_download


class BaseSAE(nn.Module, ABC):
    def __init__(
        self,
        d_in: int,
        d_sae: int,
        model_name: str,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        hook_name: str | None = None,
    ):
        super().__init__()

        # Required parameters
        self.W_enc = nn.Parameter(torch.zeros(d_in, d_sae))
        self.W_dec = nn.Parameter(torch.zeros(d_sae, d_in))

        self.b_enc = nn.Parameter(torch.zeros(d_sae))
        self.b_dec = nn.Parameter(torch.zeros(d_in))

        # Required attributes
        self.device: torch.device = device
        self.dtype: torch.dtype = dtype
        self.hook_layer = hook_layer

        hook_name = hook_name or f"blocks.{hook_layer}.hook_resid_post"
        self.to(dtype=self.dtype, device=self.device)

    @abstractmethod
    def encode(self, x: torch.Tensor):
        """Must be implemented by child classes"""
        raise NotImplementedError("Encode method must be implemented by child classes")

    @abstractmethod
    def decode(self, feature_acts: torch.Tensor):
        """Must be implemented by child classes"""
        raise NotImplementedError("Decode method must be implemented by child classes")

    @abstractmethod
    def forward(self, x: torch.Tensor):
        """Must be implemented by child classes"""
        raise NotImplementedError("Forward method must be implemented by child classes")

    def to(self, *args, **kwargs):
        """Handle device and dtype updates"""
        super().to(*args, **kwargs)
        device = kwargs.get("device", None)
        dtype = kwargs.get("dtype", None)

        if device:
            self.device = device
        if dtype:
            self.dtype = dtype
        return self


class TopKSAE(BaseSAE):
    def __init__(
        self,
        d_in: int,
        d_sae: int,
        k: int,
        model_name: str,
        hook_layer: int,
        device: torch.device,
        dtype: torch.dtype,
        use_threshold: bool = False,
        hook_name: str | None = None,
    ):
        hook_name = hook_name or f"blocks.{hook_layer}.hook_resid_post"
        super().__init__(d_in, d_sae, model_name, hook_layer, device, dtype, hook_name)

        assert isinstance(k, int) and k > 0
        self.register_buffer("k", torch.tensor(k, dtype=torch.int, device=device))
        self.d_sae = d_sae
        self.d_in = d_in
        self.pre_encoder_bias = False

        self.use_threshold = use_threshold
        if use_threshold:
            # Optional global threshold to use during inference. Must be positive.
            self.register_buffer(
                "threshold", torch.tensor(-1.0, dtype=dtype, device=device)
            )

    def encode(self, x: torch.Tensor):
        """Note: x can be either shape (B, F) or (B, L, F)"""
        if self.pre_encoder_bias:
            x = x - self.b_dec

        post_relu_feat_acts_BF = nn.functional.relu(x @ self.W_enc + self.b_enc)

        if self.use_threshold:
            if self.threshold < 0:
                raise ValueError(
                    "Threshold is not set. The threshold must be set to use it during inference"
                )
            encoded_acts_BF = post_relu_feat_acts_BF * (
                post_relu_feat_acts_BF > self.threshold
            )
            return encoded_acts_BF

        post_topk = post_relu_feat_acts_BF.topk(self.k, sorted=False, dim=-1)

        tops_acts_BK = post_topk.values
        top_indices_BK = post_topk.indices

        buffer_BF = torch.zeros_like(post_relu_feat_acts_BF)
        encoded_acts_BF = buffer_BF.scatter_(
            dim=-1, index=top_indices_BK, src=tops_acts_BK
        )
        return encoded_acts_BF

    def decode(self, feature_acts: torch.Tensor):
        return (feature_acts @ self.W_dec) + self.b_dec

    def forward(self, x: torch.Tensor):
        x = self.encode(x)
        recon = self.decode(x)
        return recon

In [20]:

def load_llama_scope_topk_sae(
    model_name: str,
    device: torch.device,
    dtype: torch.dtype,
    layer: int,
    expansion_factor: int,
) -> TopKSAE:
    repo_id = f"fnlp/Llama3_1-8B-Base-LXR-{expansion_factor}x"
    config_filename = f"Llama3_1-8B-Base-L{layer}R-{expansion_factor}x/hyperparams.json"
    filename = (
        f"Llama3_1-8B-Base-L{layer}R-{expansion_factor}x/checkpoints/final.safetensors"
    )

    path_to_params = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        force_download=False,
        local_dir="downloaded_saes",
    )

    path_to_config = hf_hub_download(
        repo_id=repo_id,
        filename=config_filename,
        force_download=False,
        local_dir="downloaded_saes",
    )

    with open(path_to_config) as f:
        config = json.load(f)

    threshold = config["jump_relu_threshold"]
    k = config["top_k"]

    pt_params = safetensors.torch.load_file(path_to_params)

    key_mapping = {
        "encoder.weight": "W_enc",
        "decoder.weight": "W_dec",
        "encoder.bias": "b_enc",
        "decoder.bias": "b_dec",
    }

    renamed_params = {key_mapping.get(k, k): v for k, v in pt_params.items()}
    renamed_params["W_enc"] = renamed_params["W_enc"].T
    renamed_params["W_dec"] = renamed_params["W_dec"].T
    renamed_params["k"] = torch.tensor(k, dtype=torch.int, device=device)
    renamed_params["threshold"] = torch.tensor(threshold, dtype=dtype, device=device)

    print(renamed_params.keys())

    d_in = renamed_params["b_dec"].shape[0]
    d_sae = renamed_params["b_enc"].shape[0]

    assert d_in <= d_sae, "d_in must be less than or equal to d_sae"

    sae = TopKSAE(
        d_in=d_in,
        d_sae=d_sae,
        k=k,
        model_name=model_name,
        hook_layer=layer,
        device=device,
        dtype=dtype,
        use_threshold=True,
    )

    sae.load_state_dict(renamed_params)
    sae.to(device=device, dtype=dtype)

    assert sae.W_enc.shape == (d_in, d_sae)
    assert sae.W_dec.shape == (d_sae, d_in)

    # https://github.com/OpenMOSS/Language-Model-SAEs/blob/25180e32e82176924b62ab30a75fffd234260a9e/src/lm_saes/sae.py#L172
    # openmoss scaling strategy
    dataset_average_activation_norm = config["dataset_average_activation_norm"]
    input_norm_factor = sae.d_in**0.5 / dataset_average_activation_norm["in"]
    sae.b_enc.data /= input_norm_factor
    sae.b_dec.data /= input_norm_factor

    return sae



In [21]:
device = "cuda"
dtype = torch.bfloat16
model_name = "meta-llama/Llama-3.1-8B"
# note: will also work for Instruct models
layer = 9
expansion_factor = 8

sae = load_llama_scope_topk_sae(
    model_name=model_name,
    device=device,
    dtype=dtype,
    layer=layer,
    expansion_factor=expansion_factor,
)

dict_keys(['b_dec', 'W_dec', 'b_enc', 'W_enc', 'k', 'threshold'])


In [7]:
model = AutoModelForCausalLM.from_pretrained(
    model_name, device_map="auto", torch_dtype=torch.bfloat16
)

tokenizer = AutoTokenizer.from_pretrained(model_name)

Loading checkpoint shards: 100%|██████████| 4/4 [00:01<00:00,  2.04it/s]


In [22]:
submodule = model.model.layers[layer]

test_input = "The scientist named the population, after their distinctive horn, Ovid’s Unicorn. These four-horned, silver-white unicorns were previously unknown to science"
test_input = tokenizer(test_input, return_tensors="pt", add_special_tokens=True).to(
    device
)

In [23]:

class EarlyStopException(Exception):
    """Custom exception for stopping model forward pass early."""
    pass

def collect_activations(
    model: AutoModelForCausalLM,
    submodule: torch.nn.Module,
    inputs_BL: dict[str, torch.Tensor],
) -> torch.Tensor:
    """
    Registers a forward hook on the submodule to capture the residual (or hidden)
    activations. We then raise an EarlyStopException to skip unneeded computations.

    Args:
        model: The model to run.
        submodule: The submodule to hook into.
        inputs_BL: The inputs to the model.
    """
    activations_BLD = None

    def gather_target_act_hook(module, inputs, outputs):
        nonlocal activations_BLD
        # For many models, the submodule outputs are a tuple or a single tensor:
        # If "outputs" is a tuple, pick the relevant item:
        #   e.g. if your layer returns (hidden, something_else), you'd do outputs[0]
        # Otherwise just do outputs
        if isinstance(outputs, tuple):
            activations_BLD = outputs[0]
        else:
            activations_BLD = outputs

        raise EarlyStopException("Early stopping after capturing activations")

    handle = submodule.register_forward_hook(gather_target_act_hook)


    try:
        with torch.no_grad():
            _ = model(**inputs_BL)
    except EarlyStopException:
        pass
    except Exception as e:
        print(f"Unexpected error during forward pass: {str(e)}")
        raise
    finally:
        handle.remove()

    return activations_BLD

acts_BLD = collect_activations(model, model.model.layers[layer], test_input)

In [24]:
encoded_acts_BLF = sae.encode(acts_BLD)

decoded_acts_BLD = sae.decode(encoded_acts_BLF)

flattened_acts_BD = einops.rearrange(acts_BLD, "b l d -> (b l) d")
reconstructed_acts_BD = sae(flattened_acts_BD)
# match flattened_acts with decoded_acts
reconstructed_acts_BLD = reconstructed_acts_BD.reshape(acts_BLD.shape)

assert torch.allclose(reconstructed_acts_BLD, decoded_acts_BLD)

l0 = (encoded_acts_BLF[:, 1:] > 0).float().sum(-1).detach()
print(f"average l0: {l0.mean().item()}")

variance_explained = 1 - torch.mean(
    (reconstructed_acts_BLD[:, 1:] - acts_BLD[:, 1:].to(torch.float32)) ** 2
) / (acts_BLD[:, 1:].to(torch.float32).var())

print(f"variance explained: {variance_explained.item()}")

average l0: 32.45161056518555
variance explained: 0.626794695854187
