In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch as t
import torch.nn as nn
import torch.nn.functional as F
from beartype import beartype as typed
from beartype.door import die_if_unbearable as assert_type
from datasets import load_dataset
from jaxtyping import Float, Int, Bool
from typing import Callable
from torch import Tensor as TT
from transformers import AutoModelForCausalLM, AutoTokenizer
from einops import einops as ein

%load_ext autoreload
%autoreload 2

In [146]:
from tqdm import tqdm


class SparseAutoEncoder(nn.Module):
    """Tied-weight sparse autoencoder with a ReLU code layer.

    The encoder computes ``relu(x @ weight + bias)``. The decoder maps a code
    back with ``code @ weight.T``: the same ``weight`` matrix is used in both
    directions, and the decoder has no bias of its own.

    Args:
        in_features: Size of the last dimension of the inputs.
        h_features: Number of hidden (code) units.
    """

    @typed
    def __init__(self, in_features: int, h_features: int):
        super().__init__()
        self.in_features = in_features
        self.h_features = h_features
        # Shared encoder/decoder matrix, shape (in_features, h_features).
        self.weight = nn.Parameter(t.empty((in_features, h_features)))
        # Encoder-only bias, shape (h_features,).
        self.bias = nn.Parameter(t.empty((h_features,)))
        # 1 / sqrt(geometric mean of in_features and h_features).
        bound = (in_features * h_features) ** -0.25
        # Sample uniformly from [-bound, bound]. Note: `normal_(w, -bound, bound)`
        # would mean mean=-bound, std=bound, skewing every parameter negative.
        t.nn.init.uniform_(self.weight, -bound, bound)
        t.nn.init.uniform_(self.bias, -bound, bound)

    @typed
    def encode(self, x: Float[TT, "... in_features"]) -> Float[TT, "... h_features"]:
        """Map inputs to non-negative codes: ``relu(x @ weight + bias)``."""
        return F.relu(x @ self.weight + self.bias)

    @typed
    def decode(self, x: Float[TT, "... h_features"]) -> Float[TT, "... in_features"]:
        """Map codes back to input space with the transposed (tied) weight."""
        return x @ self.weight.T

    @typed
    def forward(
        self, x: Float[TT, "... in_features"]
    ) -> tuple[Float[TT, "... in_features"], Float[TT, "... h_features"]]:
        """Encode then decode ``x``.

        Returns:
            A ``(reconstruction, code)`` pair. The code is returned so callers
            can apply a sparsity penalty to it.
        """
        code = self.encode(x)
        decoded = self.decode(code)
        return decoded, code


@typed
def fit_sae(
    input: Float[TT, "total_tokens d"],
    hidden_dim: int,
    lr: float,
    l1: float = 0.0,
    batch_size: int = 512,
    epochs: int = 10,
) -> SparseAutoEncoder:
    """Train a ``SparseAutoEncoder`` on a matrix of row vectors.

    Each step minimises ``mse(reconstruction, x) + l1 * mean(|code|)`` with
    Adam. The progress bar shows the mean loss over the most recent half of
    all steps taken so far.

    Args:
        input: Training rows; the last dimension becomes ``in_features``.
        hidden_dim: Number of code units.
        lr: Adam learning rate.
        l1: Weight of the L1 penalty on the code activations.
        batch_size: Rows per optimisation step.
        epochs: Number of passes over ``input``; rows are reshuffled each pass.

    Returns:
        The trained autoencoder, on the same device as ``input``.
    """
    # Keep parameters on the data's device so the matmuls don't mix devices.
    model = SparseAutoEncoder(input.size(-1), hidden_dim).to(input.device)
    optim = t.optim.Adam(model.parameters(), lr=lr)
    dataloader = t.utils.data.DataLoader(
        t.utils.data.TensorDataset(input),
        batch_size=batch_size,
        shuffle=True,
    )
    pbar = tqdm(range(epochs))
    # loss_prefix[i] == sum of the first i step losses. This makes the
    # "second-half mean" shown in the progress bar O(1) per step instead of
    # re-summing a slice whose length grows with training.
    loss_prefix = [0.0]
    for _ in pbar:
        # TensorDataset items are 1-tuples, so each batch is a one-element list.
        for (x,) in dataloader:
            optim.zero_grad()
            p, z = model(x)
            loss = F.mse_loss(p, x) + l1 * z.abs().mean()
            loss.backward()
            optim.step()
            loss_prefix.append(loss_prefix[-1] + loss.item())
            steps = len(loss_prefix) - 1
            half_size = (steps + 1) // 2
            second_half_mean = (loss_prefix[-1] - loss_prefix[-1 - half_size]) / half_size
            pbar.set_postfix_str(f": {second_half_mean:.3f}")
    return model