### Setup

In [1]:
import json
import os

# Path to the JSON file holding API keys. It defaults to the original location
# but can be overridden with the KEYS_PATH environment variable, so the notebook
# is not tied to one machine's home directory.
keys_path = os.environ.get("KEYS_PATH", "/Users/ghidav/Desktop/keys.json")

with open(keys_path) as f:
    keys = json.load(f)

# Expose the Hugging Face token through the environment for later cells.
os.environ["HF_TOKEN"] = keys["huggingface"]

### Download SAEs

In [2]:
from huggingface_hub import snapshot_download
import os
import shutil

# Hub repository, the layer folder inside it, and the local destination folder
repo_id = "mech-interp/baselines-jr-target-l0-pythia-160m-deduped"
subfolder = "layers.9"
local_save_path = "saes/pythia-160pm-deduped/jr/baseline/9"

# Fetch only the files matching the subfolder pattern (not the whole repo)
repo_path = snapshot_download(repo_id=repo_id, allow_patterns=f"{subfolder}/*")

# Mirror the fetched subfolder into the local SAE directory, if it is present
subfolder_path = os.path.join(repo_path, subfolder)
if not os.path.exists(subfolder_path):
    print(f"Subfolder '{subfolder}' not found in the repository.")
else:
    shutil.copytree(subfolder_path, local_save_path, dirs_exist_ok=True)
    print(f"Folder '{subfolder}' downloaded to '{local_save_path}'.")

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

Folder 'layers.9' downloaded to 'saes/pythia-160pm-deduped/jr/baseline/9'.


### Load SAEs

In [3]:
import torch
from safetensors import safe_open

# Read the SAE from the directory the download cell just populated. The path
# used to be hard-coded to ".../baseline/8", which the download cell (layer 9)
# never creates, so a fresh Restart & Run All failed here.
sae_dir = local_save_path

# The safetensors handle is deliberately left open: a later cell pulls the
# individual tensors out of it with `state_dict.get_tensor(...)`.
state_dict = safe_open(f"{sae_dir}/sae.safetensors", framework="torch")
with open(f"{sae_dir}/cfg.json", "r") as f:
    cfg = json.load(f)

In [4]:
# Show the tensor names in the checkpoint next to the config fields
(state_dict.keys(), cfg.keys())

(['W_dec', 'b_dec', 'encoder.bias', 'encoder.weight', 'log_threshold'],
 dict_keys(['expansion_factor', 'normalize_decoder', 'num_latents', 'k', 'multi_topk', 'jumprelu', 'jumprelu_init_threshold', 'jumprelu_bandwidth', 'jumprelu_target_l0', 'init_enc_as_dec_transpose', 'init_b_dec_as_zeros', 'd_in']))

In [5]:
from aether.encoder import FFEncoder
from aether.decoder import FFDecoder
from aether.core import AE
from aether.functions import JumpReLU

class JumpReLSAE(AE):
    """Autoencoder built from an `FFEncoder`, an `FFDecoder` and a `JumpReLU`.

    Args:
        input_dim: Size passed to the encoder as its input dimension.
        latent_dim: Size of the latent space; also passed to `JumpReLU`.
        output_dim: Size passed to the decoder as its output dimension.
    """

    def __init__(self, input_dim, latent_dim, output_dim):
        # Components are created encoder -> decoder -> activation, in that order.
        super().__init__(
            FFEncoder(input_dim, latent_dim),
            FFDecoder(latent_dim, output_dim),
            JumpReLU(latent_dim),
        )

In [6]:
# A non-positive `num_latents` in the config means the latent width is
# derived from the input width times the expansion factor.
if cfg["num_latents"] > 0:
    num_latents = cfg["num_latents"]
else:
    num_latents = cfg["d_in"] * cfg["expansion_factor"]

sae = JumpReLSAE(
    input_dim=cfg["d_in"],
    latent_dim=num_latents,
    output_dim=cfg["d_in"],
)
sae.to("mps")

In [8]:
# Rename the checkpoint tensors to the parameter names `sae` expects.
read_tensor = state_dict.get_tensor
new_state_dict = dict(
    [
        ("encoder.fc.weight", read_tensor("encoder.weight")),
        ("encoder.fc.bias", read_tensor("encoder.bias")),
        # The checkpoint's W_dec is transposed before being loaded as the decoder weight
        ("decoder.fc.weight", read_tensor("W_dec").T),
        ("decoder.fc.bias", read_tensor("b_dec")),
        ("F.log_threshold", read_tensor("log_threshold")),
    ]
)

sae.load_state_dict(new_state_dict)

<All keys matched successfully>

In [153]:
# Detached CPU copy of the decoder weight, transposed
decoder_weight = sae.decoder.fc.weight
W_dec = decoder_weight.detach().cpu().clone().T
W_dec.shape

torch.Size([12288, 768])

### Load Data

In [156]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
from aether.data import LMGenerator

In [157]:
# The model and tokenizer come from the same checkpoint
pythia_checkpoint = 'EleutherAI/pythia-160m-deduped'
lm = AutoModelForCausalLM.from_pretrained(pythia_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(pythia_checkpoint)

# Streamed, so the dataset is iterated lazily rather than loaded up front
dataset = load_dataset(
    'EleutherAI/the_pile_deduplicated',
    split='train',
    streaming=True,
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Resolving data files:   0%|          | 0/1650 [00:00<?, ?it/s]

In [162]:
# Detached CPU copy of the model's output-embedding (`embed_out`) weight
unembed_weight = lm.embed_out.weight
W_U = unembed_weight.detach().cpu().clone()
W_U.shape

torch.Size([50304, 768])

In [29]:
seq_len, batch_size = 32, 16

# NOTE(review): activations are hooked at "layers.3", while the SAE directories
# used earlier in this notebook belong to a later layer; confirm this is intended.
generator_kwargs = dict(
    lm=lm,
    tokenizer=tokenizer,
    dataset=dataset,
    column="text",
    seq_len=seq_len,
    hookpoints=["layers.3"],
    batch_size=batch_size,
    device="mps",
)
gen = LMGenerator(**generator_kwargs)

In [30]:
import torch

n_batches = 2

# Per-hookpoint activation batches and the matching token batches
act_dict = {"layers.3": []}
tokens = []

for batch_acts, batch_tokens in gen:
    for hook_name, hook_acts in batch_acts.items():
        act_dict[hook_name].append(hook_acts)
    tokens.append(batch_tokens)
    # Stop once the requested number of batches has been gathered
    if len(tokens) >= n_batches:
        break

# Merge the per-batch pieces along the first dimension
act_dict["layers.3"] = torch.cat(act_dict["layers.3"], 0)
tokens = torch.cat(tokens, 0)

### Running the SAE

In [45]:
# Forward pass only, so gradient tracking is disabled
sae_inputs = act_dict["layers.3"]
with torch.no_grad():
    output, latents = sae(sae_inputs)

### Building the feature dataset

In [170]:
# Pick up to 128 *distinct* latent indices. The feature dataset built later is
# keyed by latent id, so sampling with replacement (torch.randint) could draw
# the same id twice and silently overwrite an earlier entry. randperm samples
# without replacement.
n_features = min(128, latents.size(1))
f_ids = torch.randperm(latents.size(1))[:n_features]
selected_latents = latents[:, f_ids]

# Decoder rows for the chosen latents
W_dec_selected = W_dec[f_ids]

In [171]:
# Number of sequences gathered, assuming every batch was full
N = batch_size * n_batches

# Unflatten to (sequence, position, feature)
selected_latents = selected_latents.reshape(N, seq_len, -1)

# Maximum activation of each feature over the positions of each sequence
max_acts = selected_latents.max(dim=1).values

In [172]:
# The first two dimensions of the latents should line up with the token grid
(selected_latents.shape, tokens.shape)

(torch.Size([32, 32, 128]), torch.Size([32, 32]))

In [179]:
import numpy as np
import random
from tqdm import tqdm


def _sample_examples(mask, feature_idx, max_examples=5):
    """Sample up to `max_examples` (token_ids, activations) pairs.

    Args:
        mask: Boolean tensor over sequences choosing which rows of `tokens`
            and `selected_latents` are candidates.
        feature_idx: Index of the feature along the last dimension of
            `selected_latents`.
        max_examples: Upper bound on the number of pairs returned.

    Returns:
        A list of `(token_id_list, activation_list)` tuples, one per sampled
        sequence.
    """
    sent_tokens = tokens[mask.cpu()].tolist()
    latent_acts = selected_latents[mask, :, feature_idx].cpu().tolist()
    zipped_data = list(zip(sent_tokens, latent_acts))
    return random.sample(zipped_data, min(max_examples, len(zipped_data)))


feature_dataset = {}
logit_scores = W_dec_selected @ W_U.T  # [F, V]

# Quantile levels and bin labels do not change per feature, so build them once.
quantile_levels = torch.tensor([0.8, 0.9, 0.95, 0.99], device=max_acts.device)
bin_labels = ["80-90", "90-95", "95-99"]

for i, f_id in tqdm(enumerate(f_ids), total=len(f_ids)):
    feature_max = max_acts[:, i]
    entry = {}

    # Cut the per-sequence max activations at the 0.8, 0.9, 0.95 and 0.99 quantiles
    quantiles = feature_max.quantile(quantile_levels)

    # Each bin is (lower_bound, upper_bound]; the bounds advance through the quantiles
    lower_bound = quantiles[0]
    for upper_bound, label in zip(quantiles[1:], bin_labels):
        mask = (feature_max > lower_bound) & (feature_max <= upper_bound)
        entry[label] = _sample_examples(mask, i)
        lower_bound = upper_bound

    # Everything above the last quantile goes into the "99-100" bin
    top_bin_samples = _sample_examples(feature_max > lower_bound, i)

    # Tokens with the highest and lowest scores for this feature. `largest=False`
    # returns the lowest scores with their true sign; taking topk of the negated
    # scores (as before) stored the bottom values with the sign flipped.
    top_val, top_ids = torch.topk(logit_scores[i], 10)
    bottom_val, bottom_ids = torch.topk(logit_scores[i], 10, largest=False)

    entry["top_tokens"] = list(zip(top_ids.tolist(), top_val.tolist()))
    entry["bottom_tokens"] = list(zip(bottom_ids.tolist(), bottom_val.tolist()))
    entry["99-100"] = top_bin_samples

    feature_dataset[f_id.item()] = entry

100%|██████████| 128/128 [00:00<00:00, 137.92it/s]


In [180]:
# Serialize first, then write; the integer feature ids become string keys in JSON
serialized_dataset = json.dumps(feature_dataset)
with open("feature_dataset.json", "w") as out_file:
    out_file.write(serialized_dataset)