In [1]:
import pandas as pd
import torch, goodfire
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from liars.constants import ACTIVATION_CACHE, DATA_PATH
from liars.utils import prefixes
# Goodfire API client; used further down to look up SAE feature descriptions.
# NOTE(review): no api_key is passed here — presumably it is picked up from the
# environment; confirm against the installed goodfire SDK version.
client = goodfire.Client()

In [2]:
class SparseAutoEncoder(torch.nn.Module):
    """Sparse autoencoder with a linear+ReLU encoder and a linear decoder.

    Args:
        d_in: Size of the input (and of the reconstruction).
        d_hidden: Number of latent features produced by the encoder.
        device: Device the parameters are moved to. Accepts a ``torch.device``
            or anything ``torch.device(...)`` accepts, e.g. the string ``"cuda"``.
        dtype: Dtype the parameters are cast to. Defaults to ``torch.bfloat16``.

    Attributes:
        encoder_linear: ``Linear(d_in, d_hidden)``.
        decoder_linear: ``Linear(d_hidden, d_in)``.
        device: The normalized ``torch.device`` the module was moved to.
        dtype: The dtype the module was cast to.
    """

    def __init__(
        self,
        d_in: int,
        d_hidden: int,
        device: "torch.device | str",
        dtype: torch.dtype = torch.bfloat16,
    ):
        super().__init__()
        self.d_in = d_in
        self.d_hidden = d_hidden
        # Callers pass plain strings such as "cuda"; normalize so that
        # `self.device` is always a real torch.device.
        self.device = torch.device(device)
        self.dtype = dtype
        # Attribute names double as state-dict keys, so they must not change.
        self.encoder_linear = torch.nn.Linear(d_in, d_hidden)
        self.decoder_linear = torch.nn.Linear(d_hidden, d_in)
        self.to(self.device, self.dtype)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a batch of data using a linear, followed by a ReLU."""
        return torch.nn.functional.relu(self.encoder_linear(x))

    def decode(self, x: torch.Tensor) -> torch.Tensor:
        """Decode a batch of data using a linear."""
        return self.decoder_linear(x)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """SAE forward pass. Returns ``(reconstruction, encoded_features)``."""
        f = self.encode(x)
        return self.decode(f), f


def load_sae(
    path: str,
    d_model: int,
    expansion_factor: int,
    device: torch.device = torch.device("cpu"),
    dtype: torch.dtype = torch.bfloat16,
) -> "SparseAutoEncoder":
    """Build a SparseAutoEncoder and load its weights from a checkpoint file.

    Args:
        path: Path to a state-dict checkpoint readable by ``torch.load``.
        d_model: Input width of the SAE.
        expansion_factor: The hidden width is ``d_model * expansion_factor``.
        device: Device for both the module and the loaded tensors.
        dtype: Parameter dtype of the constructed SAE. Defaults to
            ``torch.bfloat16``, the same default the class itself uses.

    Returns:
        The SparseAutoEncoder with the checkpoint's state dict loaded.
    """
    sae = SparseAutoEncoder(
        d_model,
        d_model * expansion_factor,
        device,
        dtype,
    )
    # weights_only=True is kept from the original call; map_location places
    # the loaded tensors on `device`.
    state_dict = torch.load(path, weights_only=True, map_location=device)
    sae.load_state_dict(state_dict)

    return sae

# Name of the Goodfire SAE checkpoint on the Hugging Face Hub; it is used for
# both the repo id and the .pth filename.
SAE_NAME = "Llama-3.1-8B-Instruct-SAE-l19"
D_MODEL = 4096  # passed to load_sae as d_model
EXPANSION_FACTOR = 16  # SAE hidden width is D_MODEL * EXPANSION_FACTOR
# Fall back to CPU so the notebook still runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

file_path = hf_hub_download(
    repo_id=f"Goodfire/{SAE_NAME}",
    filename=f"{SAE_NAME}.pth",
    repo_type="model"
)
sae = load_sae(
    file_path,
    d_model=D_MODEL,
    expansion_factor=EXPANSION_FACTOR,
    device=DEVICE,
)

In [None]:
classes = {}
# === LOAD PROBES ===
# One steering.pt per prefix, stored under the prefix name.
for p in prefixes.keys():
    steering = torch.load(f"{ACTIVATION_CACHE}/llama-3.1-8b-it-lora-{p}/steering.pt", weights_only=True)
    classes[p] = steering

prefix = "gender"
TOP_K = 5  # number of most-similar SAE latents reported per probe

# Encoder weight matrix: one row per SAE latent, shape (d_hidden, d_in).
# .detach() gives a grad-free view without going through the legacy `.data`.
sae_latents = sae.encoder_linear.weight.detach()

# Each entry is reported as "Layer i" — presumably one steering vector per
# layer; confirm against whatever wrote steering.pt.
for i, p in enumerate(classes[prefix]):
    # as_tensor avoids the copy-construct warning torch.tensor() emits when
    # given an existing tensor. Dtype and device follow the SAE weights, so
    # the comparison stays valid if the SAE is loaded with other settings.
    probe = torch.as_tensor(
        p, dtype=sae_latents.dtype, device=sae_latents.device
    ).detach()
    # (1, d_in) broadcast against (d_hidden, d_in) gives one similarity per latent.
    sim = F.cosine_similarity(probe.unsqueeze(0), sae_latents)
    top_indices = torch.topk(sim, TOP_K).indices
    descs = client.features.lookup(
        top_indices.tolist(),
        "meta-llama/Meta-Llama-3.1-8B-Instruct"
    )
    print(f"Layer {i}")
    display(descs)
    print("="*100)