In [None]:
# Setup: run these shell commands once in a terminal before running this notebook

# huggingface-cli login # (Required for gemma-2-2b download)
# git clone https://github.com/saprmarks/dictionary_learning.git
# cd dictionary_learning
# pip install -r requirements.txt
# cd ..

In [None]:


from huggingface_hub import snapshot_download
import os

# Download pretrained SAE files from the Hugging Face Hub.
# snapshot_download recreates the repo-relative directory structure beneath `local_dir`,
# so the files land at ./<folder_path>/..., which is where the loading cell below reads them.
repo_id = "canrager/lm_sae"

# Choose exactly ONE folder to download (uncomment the alternative you want):
# All trainers in resid_post_layer_11:
# folder_path = "gemma-2-2b_sweep_topk_ctx128_ef2_0824/resid_post_layer_11"
# A single checkpoint:
# folder_path = "gemma-2-2b_sweep_topk_ctx128_ef2_0824/resid_post_layer_11_checkpoints/trainer_1_step_4882"
# A single trainer (the one loaded later in this notebook):
folder_path = "gemma-2-2b_sweep_topk_ctx128_ef2_0824/resid_post_layer_11/trainer_2"

# Download into the current working directory explicitly. Passing an empty string as
# local_dir relies on "" being interpreted as the cwd.
download_root = os.getcwd()

downloaded_dir = snapshot_download(
    repo_id,
    allow_patterns=[f"{folder_path}/*"],
    local_dir=download_root,
    force_download=True,  # always re-fetch, even if the files already exist locally
)

# snapshot_download returns the local_dir root, not the sub-folder that was requested.
local_dir = os.path.join(downloaded_dir, folder_path)
print(f"Folder downloaded to {local_dir}")

In [None]:
import torch
from nnsight import LanguageModel
import json

from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import (
    IdentityDict,
    GatedAutoEncoder,
    AutoEncoderNew,
)
from dictionary_learning.trainers.top_k import AutoEncoderTopK

In [None]:


# Base language model the SAE was trained on, wrapped by nnsight so its internals can be traced.
model_name = "google/gemma-2-2b"
device = "cpu"
model_dtype = torch.bfloat16

lm_kwargs = dict(
    device_map=device,
    dispatch=True,
    attn_implementation="eager",
    torch_dtype=model_dtype,
)
model = LanguageModel(model_name, **lm_kwargs)

In [None]:
# Load the SAE config and weights for one downloaded trainer.
trainer_dir = "gemma-2-2b_sweep_topk_ctx128_ef2_0824/resid_post_layer_11/trainer_2"
ae_path = os.path.join(trainer_dir, "ae.pt")
config_path = os.path.join(trainer_dir, "config.json")

with open(config_path, "r") as f:
    config = json.load(f)

trainer_cfg = config["trainer"]
submodule_str = trainer_cfg["submodule_name"]  # not used below; the hook point is chosen via `layer`
layer = trainer_cfg["layer"]
ae_model_name = trainer_cfg["lm_name"]
dict_class = trainer_cfg["dict_class"]

# The SAE only makes sense on activations from the LM it was trained on.
assert model_name == ae_model_name, (
    f"SAE was trained on {ae_model_name!r} but the loaded model is {model_name!r}"
)

# NOTE(review): assumes `submodule_str` names this decoder layer's output — confirm against the config.
submodule = model.model.layers[layer]

if dict_class == "AutoEncoderTopK":
    k = trainer_cfg["k"]
    dictionary = AutoEncoderTopK.from_pretrained(ae_path, k=k, device=device)
else:
    # Fail here with a clear message; otherwise `dictionary` is undefined and later
    # cells fail with an unrelated-looking NameError.
    raise NotImplementedError(f"Loading dict_class {dict_class!r} is not supported in this notebook")

In [None]:
# Capture the output of `submodule` (decoder layer `layer`) on a short prompt.
with model.trace("Hello World"):
    activations_BLD = submodule.output

    # Some modules return a tuple whose first element is the hidden state; unwrap it.
    # An exact type check is deliberate: torch.Size is itself a subclass of tuple, so
    # isinstance(..., tuple) would also be True for a plain tensor's shape.
    # NOTE(review): assumes nnsight reports a tuple-valued output's .shape as a plain
    # tuple of shapes — confirm for the installed nnsight version.
    if type(submodule.output.shape) == tuple:
        activations_BLD = activations_BLD[0]

    # Mark the value to be kept so it is accessible after the trace context exits.
    activations_BLD = activations_BLD.save()

print(activations_BLD.shape)


In [None]:
# Encode the captured LM activations into SAE features, then decode them back.
# This is inference only, so autograd is disabled to avoid building a graph through
# the SAE parameters (saves memory; results carry no grad_fn).
# Name suffixes presumably encode shape: B=batch, L=positions, D=model dim, F=dictionary size.
with torch.no_grad():
    ae_activations_BLF = dictionary.encode(activations_BLD)
    reconstructed_activations_BLD = dictionary.decode(ae_activations_BLF)

print(ae_activations_BLF.shape)
print(reconstructed_activations_BLD.shape)

In [None]:
# Sparsity and reconstruction metrics for this prompt.
# L0: number of non-zero SAE features per position, averaged over all remaining dims.
active_feature_counts = (ae_activations_BLF != 0).sum(dim=-1).float()
l0 = active_feature_counts.mean()

# L2: Euclidean norm of the reconstruction error per position, averaged over all remaining dims.
residual_BLD = activations_BLD - reconstructed_activations_BLD
l2_loss = torch.linalg.norm(residual_BLD, dim=-1).mean()

print(l0, l2_loss)