In [1]:
from transformer_lens import HookedTransformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import json

# Prefer the GPU when one is visible to torch; otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Load the pretrained "tiny-stories-2L-33M" checkpoint as a TransformerLens
# HookedTransformer, placed on the device selected above.
model = HookedTransformer.from_pretrained("tiny-stories-2L-33M", device=device)

Loaded pretrained model tiny-stories-2L-33M into HookedTransformer


In [3]:
class AutoEncoder(nn.Module):
    """Sparse autoencoder with a ReLU encoder and a linear decoder.

    The forward pass subtracts the decoder bias ``b_dec``, encodes with
    ``W_enc``/``b_enc`` followed by ReLU, and decodes with ``W_dec`` plus
    ``b_dec``.

    Args:
        d_hidden: Number of hidden (dictionary) features.
        d_in: Dimension of the input activations.
        dtype: Dtype used for every parameter.
        seed: Seed for the random initialisation. Seeding happens inside a
            forked CPU RNG, so constructing the module leaves the global torch
            RNG state untouched.
    """

    def __init__(
        self, d_hidden: int, d_in: int, dtype=torch.float32, seed=47
    ):
        super().__init__()
        # Keep seeding local to this initialisation. A bare global
        # `torch.manual_seed` would silently reset the RNG (CPU and CUDA) for
        # the rest of the notebook. The tensors are created on CPU, so seeding
        # the CPU default generator yields the same values as seeding globally.
        with torch.random.fork_rng(devices=[]):
            torch.default_generator.manual_seed(seed)
            w_enc_init = torch.nn.init.kaiming_uniform_(
                torch.empty(d_in, d_hidden, dtype=dtype)
            )
            w_dec_init = torch.nn.init.kaiming_uniform_(
                torch.empty(d_hidden, d_in, dtype=dtype)
            )

        # Registration order (W_enc, W_dec, b_enc, b_dec) matches the saved
        # state_dict layout.
        self.W_enc = nn.Parameter(w_enc_init)
        # Each decoder row starts with unit L2 norm.
        self.W_dec = nn.Parameter(
            w_dec_init / w_dec_init.norm(dim=-1, keepdim=True)
        )
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_in, dtype=dtype))

        self.d_hidden = d_hidden
        self.d_in = d_in

    def forward(self, x: torch.Tensor):
        """Encode ``x`` into feature activations and reconstruct it.

        Args:
            x: Activations whose last dimension is ``d_in``.

        Returns:
            Tuple ``(x_reconstruct, acts)``. ``x_reconstruct`` has the same
            trailing dimension as ``x``. ``acts`` holds the non-negative
            feature activations, with trailing dimension ``d_hidden``.
        """
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        return x_reconstruct, acts

In [4]:
save_path = "/workspace/tiny-stories-2L-33M"
run_name = "189_giddy_water"

# The run's JSON config supplies d_in, expansion_factor, layer and act.
with open(f"{save_path}/{run_name}.json", "r") as f:
    cfg = json.load(f)

d_in = cfg["d_in"]
d_hidden = d_in * cfg["expansion_factor"]
# Hook point name built from the config's layer index and activation name.
hook_name = f"blocks.{cfg['layer']}.{cfg['act']}"

encoder = AutoEncoder(d_hidden, d_in)
# Load onto CPU (where the freshly built parameters live) so a checkpoint saved
# from a GPU also loads on a CPU-only machine; weights_only avoids unpickling
# arbitrary objects from the file.
encoder.load_state_dict(
    torch.load(f"{save_path}/{run_name}.pt", map_location="cpu", weights_only=True)
)
encoder.to(device)

AutoEncoder()

In [5]:
# Example prompt used for the evaluation below; more stories are available in
# the TinyStories dataset: https://huggingface.co/datasets/roneneldan/TinyStories
prompt = 'Once upon a time, in a big forest, there lived a rhinoceros named Roxy. Roxy loved to climb. She climbed trees, rocks, and hills. One day, Roxy found an icy hill. She had never seen anything like it before. It was shiny and cold, and she wanted to climb it.\n\nRoxy tried to climb the icy hill, but it was very slippery. She tried again and again, but she kept falling down. Roxy was sad. She wanted to climb the icy hill so much. Then, she saw a little bird named Billy. Billy saw that Roxy was sad and asked, "Why are you sad, Roxy?"\n\nRoxy told Billy about the icy hill and how she couldn\'t climb it. Billy said, "I have an idea! Let\'s find some big leaves to put under your feet. They will help you climb the icy hill." Roxy and Billy looked for big leaves and found some. Roxy put the leaves under her feet and tried to climb the icy hill again.\n\nThis time, Roxy didn\'t slip. She climbed and climbed until she reached the top of the icy hill. Roxy was so happy! She and Billy played on the icy hill all day. From that day on, Roxy and Billy were the best of friends, and they climbed and played together all the time. And Roxy learned that with a little help from a friend, she could climb anything.'

In [6]:
# Run model through the SAE: replace the activation at `hook_name` with the SAE
# reconstruction of the activation actually received. The previous version
# returned a precomputed global that only matched this exact prompt.
def sae_hook(value, hook):
    """Forward hook returning the SAE reconstruction of ``value``."""
    reconstructed, _ = encoder(value)
    return reconstructed


# Evaluation only: disable autograd so no graph is built through the SAE params.
with torch.no_grad():
    loss, cache = model.run_with_cache(prompt, return_type="loss")
    print(f"Original model loss: {loss:.2f}")

    # Use the SAE to reconstruct the activations cached at `hook_name`
    x_reconstruct, acts = encoder(cache[hook_name])
    # L0: mean number of SAE features with a positive activation per position.
    print(f"Active SAE directions (L0): {(acts>0).sum(-1).float().mean():.2f}")

    with model.hooks(fwd_hooks=[(hook_name, sae_hook)]):
        reconstruct_loss = model(prompt, return_type="loss")
    print(f"Model reconstruction loss: {reconstruct_loss:.2f}")

Original model loss: 0.87
Active SAE directions (L0): 33.10
Model reconstruction loss: 1.17
