In [None]:
import json

# Template configuration that is serialized as config.json alongside each
# converted SAE. The "trainer" section describes the dictionary, the "buffer"
# section describes activation-buffer settings.
# NOTE(review): the -1 / "-1" entries look like placeholders for training
# hyperparameters that are not known for these SAEs — confirm.
config = dict(
    trainer=dict(
        trainer_class="TrainerJumpRelu",
        dict_class="JumpReluAutoEncoder",
        lr=-1,
        l1_penalty=-1,
        steps=-1,
        seed=-1,
        activation_dim=2304,
        dict_size=16384,
        device="cuda:0",
        layer=11,
        lm_name="google/gemma-2-2b",
        wandb_name="-1",
        submodule_name="resid_post_layer_11",
    ),
    buffer=dict(
        d_submodule=2304,
        io="out",
        n_ctxs=2000,
        ctx_len=128,
        refresh_batch_size=32,
        out_batch_size=4096,
        device="cuda:0",
    ),
)

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Loading the model is only required if you want to make sure that the SAEs have the correct L0

import torch
from nnsight import LanguageModel
import json

from dictionary_learning import AutoEncoder, ActivationBuffer
from dictionary_learning.dictionary import (
    IdentityDict,
    GatedAutoEncoder,
    AutoEncoderNew,
)
from dictionary_learning.trainers.top_k import AutoEncoderTopK



# Load the language model through nnsight on CPU in bfloat16 and capture the
# output of one decoder layer for a short prompt. The saved activations are
# used further down to sanity-check an SAE.
model_name = "google/gemma-2-2b"
device = "cpu"
model_dtype = torch.bfloat16

model = LanguageModel(
    model_name,
    device_map=device,
    dispatch=True,
    attn_implementation="eager",
    torch_dtype=model_dtype,
)

# Layer whose output is traced below.
layer = 11
submodule = model.model.layers[layer]
with model.trace("Hello World"):
    activations_BLD = submodule.output

    # NOTE(review): presumably the layer output is a tuple whose first element
    # holds the hidden states, and this check detects that case inside the
    # trace — confirm against nnsight's proxy semantics.
    if type(submodule.output.shape) == tuple:
        activations_BLD = activations_BLD[0]

    # .save() keeps the value available after the trace context exits.
    # NOTE(review): the _BLD suffix presumably means (batch, length, dim) — confirm.
    orig_activations_BLD = activations_BLD.save()

In [None]:
from huggingface_hub import hf_hub_download
import os
import numpy as np
import torch

from dictionary_learning.dictionary import JumpReluAutoEncoder

def save_folder(layer: int, l0s: list[int]):
    """Download Gemma Scope residual SAEs for one layer and re-save them locally.

    For every entry of ``l0s`` the matching ``params.npz`` (16k width) is
    fetched from the ``google/gemma-scope-2b-pt-res`` Hugging Face repo, loaded
    into a ``JumpReluAutoEncoder`` and written to
    ``gemma-2-2b_sweep_jumprelu_0902/resid_post_layer_{layer}/trainer_{idx}/``
    as ``ae.pt`` (the state dict) together with a ``config.json``.

    Everything runs on CPU, so no CUDA device is required.

    Args:
        layer: Layer index used in the repo filename, the output folder name
            and the saved config.
        l0s: ``average_l0_*`` values selecting the SAE variants to download.
            The position of a value in the list becomes its ``trainer_{idx}``
            folder index.

    Side effects:
        Downloads files (cached under ``gemma-2-2b``), writes to disk, and
        updates ``activation_dim``, ``dict_size``, ``layer`` and
        ``submodule_name`` in the module-level ``config["trainer"]``.
    """

    save_dir = "gemma-2-2b"

    local_dir = f"gemma-2-2b_sweep_jumprelu_0902/resid_post_layer_{layer}"

    for idx, l0 in enumerate(l0s):

        sae_dir = os.path.join(local_dir, f"trainer_{idx}")

        path_to_params = hf_hub_download(
            repo_id="google/gemma-scope-2b-pt-res",
            filename=f"layer_{layer}/width_16k/average_l0_{l0}/params.npz",
            force_download=False,
            cache_dir=save_dir,
        )

        os.makedirs(sae_dir, exist_ok=True)

        params = np.load(path_to_params)
        # Keep the tensors on CPU: the SAE below lives on CPU and
        # load_state_dict copies values into its existing parameters, so a
        # round trip through the GPU would only add a hard CUDA requirement.
        pt_params = {k: torch.from_numpy(v) for k, v in params.items()}

        # Dimensions are read from the encoder weight: (activation_dim, dict_size).
        embed_dim = params['W_enc'].shape[0]
        latent_dim = params['W_enc'].shape[1]

        sae = JumpReluAutoEncoder(
            activation_dim=embed_dim,
            dict_size=latent_dim,
            device="cpu"
        )

        # For original JumpReluAutoEncoder using nn.Parameter instead of nn.Linear
        sae.load_state_dict(pt_params)

        # If you want to use nn.Linear instead of nn.Parameter
        # sae.encoder.weight.data = pt_params['W_enc'].T
        # sae.decoder.weight.data = pt_params['W_dec'].T
        # sae.b_enc.data = pt_params['b_enc']
        # sae.b_dec.data = pt_params['b_dec']
        # sae.threshold.data = pt_params['threshold']

        torch.save(sae.state_dict(), os.path.join(sae_dir, "ae.pt"))

        # Record the actual dimensions and layer in the shared config template
        # before writing it next to the weights.
        config["trainer"]["activation_dim"] = embed_dim
        config["trainer"]["dict_size"] = latent_dim

        config["trainer"]["layer"] = layer
        config["trainer"]["submodule_name"] = f"resid_post_layer_{layer}"

        with open(os.path.join(sae_dir, "config.json"), "w") as f:
            json.dump(config, f, indent=2)

# Maps a layer index to the list of average-L0 values that are passed to
# save_folder for that layer.
l0_dict = {
    3: [14, 28, 59, 142, 315],
    7: [20, 36, 69, 137, 285],
    11: [22, 41, 80, 168, 393],
    15: [23, 41, 78, 150, 308],
    19: [23, 40, 73, 137, 279]
}

# Dedicated loop names are used so the notebook-level `layer` set in the
# model-loading cell is not overwritten by this sweep.
for sweep_layer, sweep_l0s in l0_dict.items():
    save_folder(sweep_layer, sweep_l0s)


In [None]:
# Earlier checkpoint location; it was assigned and then immediately
# overwritten, so it is kept only as a reference:
# ae_path = "/workspace/sae_eval/dictionary_learning/dictionaries/gemma-2-2b_sweep_jumprelu_0902/resid_post_layer_11/trainer_0/ae.pt"
ae_path = "/workspace/sae_eval/experiments/gemma-2-2b_sweep_jumprelu_0902/resid_post_layer_11/trainer_0/ae.pt"

# Rebuild the layer-11, 16k-wide SAE on CPU with the dimensions used when saving.
sae = JumpReluAutoEncoder(
    activation_dim=2304,
    dict_size=16384,
    device="cpu"
)

# map_location pins the loaded tensors to CPU regardless of the device they
# were saved from, matching the CPU-resident SAE above.
sae.load_state_dict(torch.load(ae_path, map_location="cpu"))

In [None]:

# Sanity check: run the saved layer activations through the SAE and report the
# mean number of active latents per token (L0) and the mean L2 reconstruction
# error.
print(orig_activations_BLD.shape)

# Match the SAE's dtype to the model activations before encoding.
sae = sae.to(dtype=model_dtype)
activations_BLD = orig_activations_BLD[:, 1:, :] # Skip the BOS token

# Evaluation only: no gradients are needed, so avoid building an autograd graph.
with torch.no_grad():
    ae_activations_BLF = sae.encode(activations_BLD)
    print(ae_activations_BLF.shape)
    reconstructed_activations_BLD = sae.decode(ae_activations_BLF)
    print(reconstructed_activations_BLD.shape)

    # Count non-zero latents along the last dimension, then average.
    l0 = (ae_activations_BLF != 0).float().sum(dim=-1).mean()
    # Norm of the reconstruction residual along the last dimension, then average.
    l2_loss = torch.linalg.norm(activations_BLD - reconstructed_activations_BLD, dim=-1).mean()

print(l0, l2_loss)

In [None]:
# Show the shape of the SAE's decoder weight.
# NOTE(review): presumably (dict_size, activation_dim) = (16384, 2304) — confirm.
print(sae.W_dec.shape)