# TemporalSAE Quickstart

Dear Neuronpedia Team, thanks a lot for your support!

We're providing the smallest Gemma-2-2B SAE we have so far, to get started quickly. We also have Llama SAEs, which will probably perform better in Autointerp.

The Temporal SAE builds on the prior that each LLM activation x at position t is the sum of two parts: information aggregated from previous positions (x_pred) and novel information related to the input token at position t (x_novel), so x = x_pred + x_novel. We assume that only x_novel is sparse.
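A minimal sketch of this split, assuming (as in the projection step of `TemporalSAE.forward` below) that x_pred is the projection of x onto a predicted direction p. Here `p` is only a random stand-in for the decoded attention output, `split_pred_novel` is a hypothetical helper name, and the bias subtraction and learned attention are omitted.

```python
import torch as th


def split_pred_novel(x, p, eps=1e-6):
    """Split x into its projection onto p (x_pred) and the remainder (x_novel).

    x, p: tensors of shape [..., d]. p stands in for the decoded attention
    prediction; eps guards against division by zero, as in TemporalSAE.forward.
    """
    p_norm = p.norm(dim=-1, keepdim=True) + eps
    scale = (x * p).sum(dim=-1, keepdim=True) / p_norm.pow(2)
    x_pred = scale * p
    x_novel = x - x_pred
    return x_pred, x_novel


th.manual_seed(0)
x = th.randn(4, 16)  # 4 positions, 16-dim toy activations
p = th.randn(4, 16)  # hypothetical per-position predictions
x_pred, x_novel = split_pred_novel(x, p)

# The two parts add back up to x, and x_novel is (up to eps) orthogonal to p.
print(th.allclose(x_pred + x_novel, x, atol=1e-5))
print((x_novel * p).sum(dim=-1).abs().max().item())
```

Because x_novel has no component along the prediction, the part of x that lies along the context prediction is removed before the sparse encoder sees it.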

The Temporal SAE encodes an LLM activation x_t at position t in two steps:
1. Identify and remove information predictable from the context: apply an attention layer to reconstruct x_t using only the context vectors {x_0, ..., x_{t-1}}. We call this prediction x_{t,pred} and compute the remainder x_{t,novel} = x_t - x_{t,pred}. (In the code, the attention runs on the SAE codes of the context, and x_{t,pred} is the projection of x_t onto the decoded attention output.)
2. Feed x_{t,novel} through a TopK SAE; its sparse codes are the "novel_codes".
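The two steps can be sketched on toy tensors as follows. This is a simplified illustration under loud assumptions, not the trained model: the attention here is plain causal dot-product attention over x itself (no learned projections, no encoding into codes, no projection step), the TopK SAE is a single random encoder matrix `W_enc`, and `toy_temporal_encode` is a hypothetical name.

```python
import torch as th
import torch.nn.functional as F


def toy_temporal_encode(x, W_enc, k):
    """Toy two-step encoding of activations x [T, d] with encoder W_enc [d, width].

    Assumption: plain (unlearned) causal attention stands in for the trained
    attention layer, and W_enc stands in for the trained TopK encoder.
    """
    T, d = x.shape
    # Step 1: predict x_t from the context x_0..x_{t-1} only.
    scores = x @ x.T / d**0.5  # [T, T]
    not_context = th.ones(T, T, dtype=th.bool).triu()  # positions j >= t are not context
    scores = scores.masked_fill(not_context, float("-inf"))
    # Row 0 has no context, so its softmax is NaN; nan_to_num turns it into zeros.
    weights = th.nan_to_num(th.softmax(scores, dim=-1))
    x_pred = weights @ x
    x_novel = x - x_pred

    # Step 2: TopK SAE on the remainder, keeping the k largest ReLU activations.
    pre = F.relu(x_novel @ W_enc)  # [T, width]
    vals, idx = th.topk(pre, k, dim=-1)
    novel_codes = th.zeros_like(pre).scatter(-1, idx, vals)
    return x_pred, x_novel, novel_codes


th.manual_seed(0)
x = th.randn(8, 32)  # 8 positions, 32-dim toy activations
W_enc = th.randn(32, 128)  # hypothetical encoder with 128 features
x_pred, x_novel, novel_codes = toy_temporal_encode(x, W_enc, k=4)
print((novel_codes != 0).sum(dim=-1))  # active features per position (at most k)
```

The real `TemporalSAE` applies the same mask-then-TopK idea in its `"topk"` branch, but on learned attention and a trained dictionary.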

The main focus is generating feature dashboards of the novel codes.

If Neuronpedia has the resources, it would be great to run a cheap version of Autointerp on the novel codes, to verify that the Autointerp score is in the ballpark of other SAE architectures.

Aside for later: for natural language inputs (not dummy_acts), it's also interesting to look at the similarity matrix (kernel) of the predicted codes, as shown in the paper and presentation. Displaying this on Neuronpedia would be interesting, but is not a priority. We can talk details later if you are interested.
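A minimal sketch of such a similarity matrix for one sequence, assuming cosine similarity between the predicted codes at every pair of positions (the paper may use a different kernel). `code_similarity_matrix` is a hypothetical helper; to get `pred_codes` out of `batch_temporal_sae_inference`, the corresponding commented-out line in that function has to be re-enabled.

```python
import torch as th
import torch.nn.functional as F


def code_similarity_matrix(codes_PW, eps=1e-8):
    """Cosine similarity between the codes at every pair of positions.

    codes_PW: [P, W] codes for a single sequence, e.g. pred_codes[i].
    Returns a [P, P] matrix whose entry (s, t) compares positions s and t.
    Rows that are all zero get zero similarity with everything.
    """
    normed = F.normalize(codes_PW, dim=-1, eps=eps)
    return normed @ normed.T


th.manual_seed(0)
toy_codes = th.relu(th.randn(6, 20))  # hypothetical codes: 6 positions, 20 features
sim = code_similarity_matrix(toy_codes)
print(sim.shape)  # torch.Size([6, 6])
```

The resulting matrix can be shown as a heatmap, for example with matplotlib's `imshow`.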

In [14]:
from pathlib import Path
from huggingface_hub import snapshot_download

# Define custom directory for temporal SAEs
temporal_sae_dir = Path("./temporal_saes_weights")
temporal_sae_dir.mkdir(exist_ok=True)

# Download the temporal folder from the repository.
# (local_dir_use_symlinks is deprecated in recent huggingface_hub releases;
# files are written directly into local_dir.)
snapshot_download(
    repo_id="ekdeepslubana/temporalSAEs",
    allow_patterns="temporal/*",
    local_dir=temporal_sae_dir,
)

print(f"Downloaded temporal SAEs to: {temporal_sae_dir.absolute()}")

Fetching 2 files: 100%|██████████| 2/2 [00:00<00:00, 4576.44it/s]

Downloaded temporal SAEs to: /home/can/dynamic_representations/exp/temporal_saes_weights
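`TemporalSAE.from_pretrained` (defined below) expects `conf.yaml` and `latest_ckpt.pt` inside the folder it is given. A small sketch to check this after the download; `missing_sae_files` is a hypothetical helper, not part of `huggingface_hub`.

```python
from pathlib import Path

# File names read by TemporalSAE.from_pretrained
REQUIRED_FILES = ("conf.yaml", "latest_ckpt.pt")


def missing_sae_files(folder):
    """Return the required files that are not present in `folder`."""
    folder = Path(folder)
    return [name for name in REQUIRED_FILES if not (folder / name).is_file()]
```

In this notebook, `missing_sae_files(temporal_sae_dir / "temporal")` should return an empty list if the download succeeded.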





In [15]:
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import yaml
import math
from collections import defaultdict
from tqdm import trange
import os


# Standalone TemporalSAE class definition
def get_attention(query, key) -> th.Tensor:
    """Causal attention weights softmax(q k^T / sqrt(d)), matching what
    scaled_dot_product_attention(..., is_causal=True) applies internally."""
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1))
    # Additive mask: 0 on and below the diagonal, -inf above it (future positions)
    attn_bias = th.zeros(L, S, dtype=query.dtype, device=query.device)
    temp_mask = th.ones(L, S, dtype=th.bool, device=query.device).tril(diagonal=0)
    attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))

    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = th.softmax(attn_weight, dim=-1)
    return attn_weight


### Manual Attention Implementation
class ManualAttention(nn.Module):
    """
    Manual implementation to allow tinkering with the attention mechanism.
    """

    def __init__(
        self,
        dimin,
        n_heads=4,
        bottleneck_factor=64,
        bias_k=True,
        bias_q=True,
        bias_v=True,
        bias_o=True,
    ):
        super().__init__()
        assert dimin % (bottleneck_factor * n_heads) == 0

        # attention heads
        self.n_heads = n_heads
        self.n_embds = dimin // bottleneck_factor  # total query/key width, split across n_heads
        self.dimin = dimin

        # key, query, value projections for all heads, but in a batch
        self.k_ctx = nn.Linear(dimin, self.n_embds, bias=bias_k)
        self.q_target = nn.Linear(dimin, self.n_embds, bias=bias_q)
        self.v_ctx = nn.Linear(dimin, dimin, bias=bias_v)

        # output projection
        self.c_proj = nn.Linear(dimin, dimin, bias=bias_o)

        # Normalize to match scale with representations
        with th.no_grad():
            scaling = 1 / math.sqrt(self.n_embds // self.n_heads)
            self.k_ctx.weight.copy_(
                scaling
                * self.k_ctx.weight
                / (1e-6 + th.linalg.norm(self.k_ctx.weight, dim=1, keepdim=True))
            )
            self.q_target.weight.copy_(
                scaling
                * self.q_target.weight
                / (1e-6 + th.linalg.norm(self.q_target.weight, dim=1, keepdim=True))
            )

            scaling = 1 / math.sqrt(self.dimin // self.n_heads)
            self.v_ctx.weight.copy_(
                scaling
                * self.v_ctx.weight
                / (1e-6 + th.linalg.norm(self.v_ctx.weight, dim=1, keepdim=True))
            )

            scaling = 1 / math.sqrt(self.dimin)
            self.c_proj.weight.copy_(
                scaling
                * self.c_proj.weight
                / (1e-6 + th.linalg.norm(self.c_proj.weight, dim=1, keepdim=True))
            )

    def forward(self, x_ctx, x_target, get_attn_map=False):
        """
        Compute projective attention output
        """
        # Compute key and value projections from context representations
        k = self.k_ctx(x_ctx)
        v = self.v_ctx(x_ctx)

        # Compute query projection from target representations
        q = self.q_target(x_target)

        # Split into heads
        B, T, _ = x_ctx.size()
        k = k.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
        q = q.view(B, T, self.n_heads, self.n_embds // self.n_heads).transpose(1, 2)
        v = v.view(B, T, self.n_heads, self.dimin // self.n_heads).transpose(1, 2)

        # Attn map
        if get_attn_map:
            attn_map = get_attention(query=q, key=k)
            th.cuda.empty_cache()

        # Scaled dot-product attention
        attn_output = th.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=None, dropout_p=0, is_causal=True
        )

        # Reshape, project back to original dimension
        d_target = self.c_proj(
            attn_output.transpose(1, 2).contiguous().view(B, T, self.dimin)
        )  # [batch, length, dimin]

        if get_attn_map:
            return d_target, attn_map
        else:
            return d_target, None


class TemporalSAE(th.nn.Module):
    def __init__(
        self,
        dimin=2,
        width=5,
        n_heads=8,
        sae_diff_type="relu",
        kval_topk=None,
        tied_weights=True,
        n_attn_layers=1,
        bottleneck_factor=64,
    ):
        """
        dimin: (int)
            input dimension
        width: (int)
            width of the encoder
        n_heads: (int)
            number of attention heads
        sae_diff_type: (str)
            type of sae to express the per-token difference
        kval_topk: (int)
            k in topk sae_diff_type
        n_attn_layers: (int)
            number of attention layers
        """
        super(TemporalSAE, self).__init__()
        self.sae_type = "temporal"
        self.width = width
        self.dimin = dimin
        self.eps = 1e-6
        self.lam = 1 / (4 * dimin)
        self.tied_weights = tied_weights

        ## Attention parameters
        self.n_attn_layers = n_attn_layers
        self.attn_layers = nn.ModuleList(
            [
                ManualAttention(
                    dimin=width,
                    n_heads=n_heads,
                    bottleneck_factor=bottleneck_factor,
                    bias_k=True,
                    bias_q=True,
                    bias_v=True,
                    bias_o=True,
                )
                for _ in range(n_attn_layers)
            ]
        )

        ## Dictionary parameters
        self.D = nn.Parameter(th.randn((width, dimin)))  # N(0,1) init
        self.b = nn.Parameter(th.zeros((1, dimin)))
        if not tied_weights:
            self.E = nn.Parameter(th.randn((dimin, width)))  # N(0,1) init

        ## SAE-specific parameters
        self.sae_diff_type = sae_diff_type
        self.kval_topk = kval_topk if sae_diff_type == "topk" else None

    def forward(self, x_input, return_graph=False, inf_k=None):
        B, L, _ = x_input.size()
        E = self.D.T if self.tied_weights else self.E

        ### Define context and target ###
        x_input = x_input - self.b

        ### Tracking variables ###
        attn_graphs = []

        ### Predictable part ###
        z_pred = th.zeros(
            (B, L, self.width), device=x_input.device, dtype=x_input.dtype
        )
        for attn_layer in self.attn_layers:
            z_input = F.relu(th.matmul(x_input * self.lam, E))  # [batch, length, width]
            z_ctx = th.cat(
                (th.zeros_like(z_input[:, :1, :]), z_input[:, :-1, :].clone()), dim=1
            )  # [batch, length, width]

            # Compute codes using attention
            z_pred_, attn_graphs_ = attn_layer(
                z_ctx, z_input, get_attn_map=return_graph
            )

            # Take back to input space
            z_pred_ = F.relu(z_pred_)
            Dz_pred_ = th.matmul(z_pred_, self.D)
            Dz_norm_ = Dz_pred_.norm(dim=-1, keepdim=True) + self.eps

            # Compute the projection of the input onto the predicted direction
            proj_scale = (Dz_pred_ * x_input).sum(dim=-1, keepdim=True) / Dz_norm_.pow(2)

            # Add the projection to the reconstructed
            z_pred = z_pred + (z_pred_ * proj_scale)

            # Remove the projection from the input
            x_input = x_input - proj_scale * Dz_pred_  # [batch, length, dimin]

            # Add the attention graph if return_graph is True
            if return_graph:
                attn_graphs.append(attn_graphs_)

        ### Novel part (identified using the residual target signal) ###
        if self.sae_diff_type == "relu":
            z_novel = F.relu(th.matmul(x_input * self.lam, E))

        elif self.sae_diff_type == "topk":
            kval = self.kval_topk if inf_k is None else inf_k
            z_novel = F.relu(th.matmul(x_input * self.lam, E))
            _, topk_indices = th.topk(z_novel, kval, dim=-1)
            mask = th.zeros_like(z_novel)
            mask.scatter_(-1, topk_indices, 1)
            z_novel = z_novel * mask

        elif self.sae_diff_type == "nullify":
            z_novel = th.zeros_like(z_pred)

        ### Reconstruction ###
        x_recons = (
            th.matmul(z_novel + z_pred, self.D) + self.b
        )  # [batch, length, dimin]

        ### Compute the predicted vs. novel reconstructions without the bias (useful for checking how much the context vs. the dictionary contributes) ###
        with th.no_grad():
            x_pred_recons = th.matmul(z_pred, self.D)
            x_novel_recons = th.matmul(z_novel, self.D)

        ### Return the dictionary ###
        results_dict = {
            "novel_codes": z_novel,
            "novel_recons": x_novel_recons,
            "pred_codes": z_pred,
            "pred_recons": x_pred_recons,
            "attn_graphs": th.stack(attn_graphs, dim=1) if return_graph else None,
        }

        return x_recons, results_dict

    @classmethod
    def from_pretrained(cls, folder_path, dtype, device, **kwargs):
        """
        Load a pretrained TemporalSAE from a folder containing conf.yaml and latest_ckpt.pt

        Args:
            folder_path: Path to folder containing conf.yaml and latest_ckpt.pt
            dtype: Target dtype for the model
            device: Target device for the model
            **kwargs: Override any config parameters
        """
        # Load config from yaml file
        config_path = os.path.join(folder_path, "conf.yaml")
        with open(config_path, "r") as f:
            config = yaml.safe_load(f)

        # Extract model parameters from config
        model_args = {
            "dimin": config["llm"]["dimin"],
            "width": int(config["llm"]["dimin"] * config["sae"]["exp_factor"]),
            "n_heads": config["sae"]["n_heads"],
            "sae_diff_type": config["sae"]["sae_diff_type"],
            "kval_topk": config["sae"]["kval_topk"],
            "tied_weights": config["sae"]["tied_weights"],
            "n_attn_layers": config["sae"]["n_attn_layers"],
            "bottleneck_factor": config["sae"]["bottleneck_factor"],
        }

        # Override with any provided kwargs
        model_args.update(kwargs)

        # Create the model
        autoencoder = cls(**model_args)

        # Load the checkpoint
        ckpt_path = os.path.join(folder_path, "latest_ckpt.pt")
        checkpoint = th.load(ckpt_path, map_location="cpu", weights_only=False)

        # Load the state dict
        if "sae" in checkpoint:
            autoencoder.load_state_dict(checkpoint["sae"])
        else:
            autoencoder.load_state_dict(checkpoint)

        autoencoder = autoencoder.to(device=device, dtype=dtype)
        return autoencoder
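`ManualAttention.forward` returns the map from `get_attention` but computes its output with `scaled_dot_product_attention(..., is_causal=True)`. A quick sketch (PyTorch 2.0 or later) to confirm the two agree, i.e. that the returned `attn_graphs` describe the weights actually applied. The helper re-implements the same masking so the snippet runs on its own; `causal_attention_weights` is a hypothetical name.

```python
import math

import torch as th
import torch.nn.functional as F


def causal_attention_weights(q, k):
    """Causal softmax(q k^T / sqrt(d)) weights, with the same mask as get_attention."""
    L, S = q.size(-2), k.size(-2)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    mask = th.ones(L, S, dtype=th.bool, device=q.device).tril()
    scores = scores.masked_fill(~mask, float("-inf"))
    return th.softmax(scores, dim=-1)


th.manual_seed(0)
# Toy tensors shaped [batch, heads, length, head_dim]
q, k, v = (th.randn(2, 4, 6, 8) for _ in range(3))
weights = causal_attention_weights(q, k)
manual_out = weights @ v
sdpa_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(th.allclose(manual_out, sdpa_out, atol=1e-5))
```

Both use the default 1/sqrt(head_dim) scale and a lower-triangular mask, so the maps should match the fused computation up to floating-point error.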

In [16]:
# Load the temporal SAE
sae_dir = temporal_sae_dir / "temporal"
device = "cuda" if th.cuda.is_available() else "cpu"
dtype = th.float32

sae = TemporalSAE.from_pretrained(folder_path=sae_dir, device=device, dtype=dtype)
sae.eval()

print(f"Loaded TemporalSAE from {sae_dir}")
print(f"Input dim: {sae.dimin}, Width: {sae.width}")
print(f"SAE type: {sae.sae_diff_type}, Attention layers: {sae.n_attn_layers}")

Loaded TemporalSAE from temporal_saes_weights/temporal
Input dim: 2304, Width: 9216
SAE type: topk, Attention layers: 1


In [None]:
@th.inference_mode()
def batch_temporal_sae_inference(
    sae, act_BPD, batch_size=32, device=None, dtype=th.float32
):
    """
    Batched inference for temporal SAE

    Args:
        sae: TemporalSAE model
        act_BPD: Input activations [B, P, D]
        batch_size: Batch size for processing
        device: Device to use (defaults to the device of the SAE's parameters)
        dtype: Data type

    Returns:
        dict: Dictionary containing reconstruction outputs
    """
    B, P, D = act_BPD.shape
    if device is None:
        device = next(sae.parameters()).device

    results = defaultdict(list)

    for i in trange(0, B, batch_size, desc="SAE forward"):
        batch = act_BPD[i : i + batch_size]
        batch = batch.to(device=device, dtype=dtype)

        # Forward pass through temporal SAE.
        # Attention maps are only needed for results["attn_graphs"]; set
        # return_graph=True when re-enabling that line below.
        recon_BPD, batch_dict = sae.forward(batch, return_graph=False)
        residuals_BPD = batch - recon_BPD

        # Store results
        results["novel_codes"].append(batch_dict["novel_codes"].detach().cpu())
        # NOTE Only novel codes are relevant for AutoInterp
        # results["total_recons"].append(recon_BPD.detach().cpu())
        # results["residuals"].append(residuals_BPD.detach().cpu())
        # results["novel_recons"].append(batch_dict["novel_recons"].detach().cpu())
        # results["pred_codes"].append(batch_dict["pred_codes"].detach().cpu())
        # results["pred_recons"].append(batch_dict["pred_recons"].detach().cpu())
        # results["attn_graphs"].append(batch_dict["attn_graphs"].detach().cpu())

        # # Add reconstruction bias
        # results["novel_recons_plus_b"].append(
        #     (batch_dict["novel_recons"] + sae.b).detach().cpu()
        # )
        # results["pred_recons_plus_b"].append(
        #     (batch_dict["pred_recons"] + sae.b).detach().cpu()
        # )

    # Concatenate all batches
    for key in results:
        results[key] = th.cat(results[key], dim=0)

    return results


# Example usage:
print("\nExample usage:")
dummy_acts = th.randn(100, 64, sae.dimin)  # [B=100, P=64, D=input_dim]
results = batch_temporal_sae_inference(sae, dummy_acts, batch_size=16)


Example usage:


SAE forward: 100%|██████████| 7/7 [00:00<00:00, 13.31it/s]
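For a quick local look at what a feature responds to before building full dashboards, a minimal sketch that finds the top-activating (sequence, position) pairs of one feature and each feature's activation density from `results["novel_codes"]` ([B, P, width]). Both helpers are hypothetical, not Neuronpedia's pipeline, and the `dummy_acts` above are random, so the resulting features are only meaningful on real Gemma-2-2B activations.

```python
import torch as th


def top_activations(codes_BPW, feature, n=10):
    """Largest n activations of one feature, with their (batch, position) indices."""
    acts_BP = codes_BPW[..., feature]
    n = min(n, acts_BP.numel())
    vals, flat_idx = th.topk(acts_BP.flatten(), n)  # sorted in descending order
    P = acts_BP.shape[1]
    b_idx = th.div(flat_idx, P, rounding_mode="floor")
    p_idx = flat_idx % P
    return vals, b_idx, p_idx


def feature_density(codes_BPW):
    """Fraction of tokens on which each feature is active (nonzero), shape [W]."""
    return (codes_BPW != 0).float().mean(dim=(0, 1))


th.manual_seed(0)
toy_codes = th.relu(th.randn(5, 7, 12))  # stand-in for results["novel_codes"]
vals, b_idx, p_idx = top_activations(toy_codes, feature=3, n=4)
density = feature_density(toy_codes)
```

With the notebook's results, `top_activations(results["novel_codes"], feature=0)` gives the indices into the input batch whose surrounding tokens a dashboard would display.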

