In [32]:
from transformer_lens import HookedTransformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import json
from huggingface_hub import PyTorchModelHubMixin

# Pick the compute device: use CUDA when PyTorch reports an available GPU,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [15]:
# Identifier of the pretrained checkpoint to load through TransformerLens.
model_name = "tiny-stories-2L-33M"
# Load the pretrained weights into a HookedTransformer, passing `device` so the
# loader places the model on the selected device.
model = HookedTransformer.from_pretrained(model_name, device=device)

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [59]:
class SparseAutoEncoder(nn.Module, PyTorchModelHubMixin):
    """Sparse autoencoder with one ReLU hidden layer and a shared decoder bias.

    The model holds four parameters: ``W_enc`` of shape ``(d_in, d_hidden)``,
    ``W_dec`` of shape ``(d_hidden, d_in)``, ``b_enc`` of shape ``(d_hidden,)``
    and ``b_dec`` of shape ``(d_in,)``. ``b_dec`` is subtracted from the input
    before encoding and added back after decoding.

    Args:
        config: Mapping that must contain the keys ``"seed"`` (passed to
            ``torch.manual_seed``), ``"d_in"``, ``"d_hidden"`` and ``"dtype"``
            (the name of a ``torch`` dtype attribute, e.g. ``"float32"``).

    Raises:
        ValueError: If ``config`` is ``None``.
    """

    def __init__(self, config=None):
        super().__init__()
        if config is None:
            raise ValueError("Config not provided")

        # Seeds the *global* torch RNG so the weight initialisation below is
        # reproducible for a given config.
        torch.manual_seed(config["seed"])
        n_in = config["d_in"]
        n_hidden = config["d_hidden"]
        param_dtype = getattr(torch, config["dtype"])

        # Draw the encoder first, then the decoder, so the RNG is consumed in
        # a fixed order.
        enc_init = torch.nn.init.kaiming_uniform_(
            torch.empty(n_in, n_hidden, dtype=param_dtype)
        )
        dec_init = torch.nn.init.kaiming_uniform_(
            torch.empty(n_hidden, n_in, dtype=param_dtype)
        )
        # Rescale every decoder row (one row per hidden unit) to unit L2 norm.
        dec_init = dec_init / dec_init.norm(dim=-1, keepdim=True)

        self.W_enc = nn.Parameter(enc_init)
        self.W_dec = nn.Parameter(dec_init)
        self.b_enc = nn.Parameter(torch.zeros(n_hidden, dtype=param_dtype))
        self.b_dec = nn.Parameter(torch.zeros(n_in, dtype=param_dtype))

        self.d_hidden = n_hidden

    def forward(self, x: torch.Tensor):
        """Encode ``x`` into hidden activations and decode it back.

        Args:
            x: Tensor whose last dimension equals ``d_in``.

        Returns:
            A tuple ``(x_reconstruct, acts)``: the reconstruction (last
            dimension ``d_in``) and the post-ReLU hidden activations (last
            dimension ``d_hidden``).
        """
        centered = x - self.b_dec
        pre_acts = centered @ self.W_enc + self.b_enc
        acts = F.relu(pre_acts)
        return acts @ self.W_dec + self.b_dec, acts

In [38]:
# Lookup table between two sets of regularisation labels.
# NOTE(review): presumably keys are the training-config `reg` values and the
# values are the names used for display/reporting -- confirm against the caller.
reg_names = dict(
    l1="l1",
    sqrt="l1_sqrt",
    pure_sqrt="sqrt",
    combined_hoyer_l1="l1_hoyer",
    combined_hoyer_sqrt="l1_sqrt_hoyer",
    combined_hoyer_pure_sqrt="sqrt_hoyer",
)

In [42]:
def load_encoder(run_name, model_name, save_dir="/workspace"):
    """Rebuild a trained SparseAutoEncoder from its saved config and weights.

    Reads the training config from ``{save_dir}/{model_name}/{run_name}.json``
    and the state dict from ``{save_dir}/{model_name}/{run_name}.pt``.

    Args:
        run_name: File stem shared by the ``.json`` config and ``.pt`` weights.
        model_name: Sub-directory of ``save_dir`` holding the run files.
        save_dir: Root directory of the saved runs. Defaults to ``"/workspace"``.

    Returns:
        A tuple ``(encoder, sae_config)``: the SparseAutoEncoder with the saved
        weights loaded (left on the CPU; call ``.to(...)`` to move it) and the
        config dict it was constructed from.

    Raises:
        FileNotFoundError: If the config or weights file does not exist.
        KeyError: If the saved config lacks one of the keys read below.
    """
    save_path = f"{save_dir}/{model_name}"

    with open(f"{save_path}/{run_name}.json", "r") as f:
        cfg = json.load(f)

    d_in = cfg["d_in"]
    d_hidden = d_in * cfg["expansion_factor"]
    hook_name = f"blocks.{cfg['layer']}.{cfg['act']}"

    sae_config = {
        "d_in": d_in,
        "d_hidden": d_hidden,
        "dtype": "float32",
        "seed": cfg["seed"],
        "wandb_name": cfg["wandb_name"],
        "hook_name": hook_name,
        "layer": cfg["layer"],
        "model": cfg["model"],
        "regularization": cfg["reg"],
        # The id is whatever follows the last "-" in the wandb run name.
        "wandb_id": cfg["wandb_name"].split("-")[-1],
    }
    encoder = SparseAutoEncoder(sae_config)

    # map_location="cpu" lets checkpoints saved from a CUDA device load on
    # machines without a GPU and avoids staging the tensors on the GPU first;
    # the encoder itself stays on the CPU either way.
    # NOTE(security): torch.load unpickles the file -- only load checkpoints
    # from a trusted source.
    state_dict = torch.load(f"{save_path}/{run_name}.pt", map_location="cpu")
    encoder.load_state_dict(state_dict)
    return encoder, sae_config

In [28]:
# Run names of saved SAE checkpoints, in the order they are to be processed.
# NOTE(review): the "l1" prefix presumably means these runs used L1
# regularisation -- confirm against the training configs.
l1_models = [
    "228_earnest_voice",
    "229_comfy_haze",
    "230_divine_resonance",
    "236_dainty_bush",
    "262_azure_violet",
    "263_crimson_sunset",
    "264_prime_capybara",
    "265_tough_lake",
    "266_avid_firebrand",
    "267_flowing_elevator",
    "268_earnest_breeze",
    "269_pretty_violet",
    "273_spring_moon",
    "278_clear_plant",
    "277_icy_cherry",
    "289_giddy_firebrand",
    "290_neat_microwave",
    "291_lucky_voice",
]

def upload_to_HF(encoder_name, model_name):
    """Load a saved encoder and push it, with its config, to the Hugging Face Hub.

    The target repository name is ``SAE-{model}-L{layer}-{wandb_id}``, built
    from the config returned by ``load_encoder``.

    Args:
        encoder_name: Run name forwarded to ``load_encoder``.
        model_name: Model directory name forwarded to ``load_encoder``.
    """
    sae, sae_cfg = load_encoder(encoder_name, model_name)
    repo_name = "SAE-{}-L{}-{}".format(
        sae_cfg["model"], sae_cfg["layer"], sae_cfg["wandb_id"]
    )
    sae.push_to_hub(repo_name, config=sae_cfg)

In [30]:
# for encoder_name in l1_models:
#     upload_to_HF(encoder_name, "tiny-stories-2L-33M")

In [60]:
# Fetch a published SAE from the Hugging Face Hub and instantiate it.
# `from_pretrained` is not defined on SparseAutoEncoder itself, so it is
# inherited -- presumably from PyTorchModelHubMixin (TODO confirm).
# The repo id follows the `SAE-{model}-L{layer}-{wandb_id}` naming pattern.
encoder = SparseAutoEncoder.from_pretrained("lovish/SAE-tiny-stories-2L-33M-L1-291")

model.safetensors:   0%|          | 0.00/537M [00:00<?, ?B/s]