In [88]:
import torch
import einops

from circuit_lens import get_model_encoders
from z_sae import ZSAE
from mlp_transcoder import SparseTranscoder
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int
from torch import Tensor
from typing import List, Dict, TypedDict, Any, Union, Tuple, Optional
from tqdm import trange
from plotly_utils import imshow
from pprint import pprint
from transformer_lens.utils import get_act_name, to_numpy
from enum import Enum
from dataclasses import dataclass
from tqdm import tqdm 

# Import plotly stuff
import plotly.graph_objects as go
import plotly.express as px
import plotly.io as pio

In [90]:
model, z_saes, transcoders = get_model_encoders(device='cpu')

## Load the Pile dataset we'll use for activations

In [91]:
from datasets import load_dataset

dataset = load_dataset("NeelNanda/pile-10k")

Found cached dataset parquet (/Users/charlesoneill/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/2a3b91fbd88a2c90d1dbbb32b460cf621d31bd5b05b934492fdef7d8d6f236ec)


  0%|          | 0/1 [00:00<?, ?it/s]

In [94]:
# Split the huggingface dataset up into seq_len text
seq_len = 128
batch_size = 4096
model_name = 'gpt2-small'

model = HookedTransformer.from_pretrained(model_name, device='cpu')

tokenized_dataset = []
# Concat all the text together
text = " ".join(dataset['train']['text'])

# Tokenize the text
for i in trange(0, len(text)//100, 2500):
    tokens = model.to_tokens(text[i:i+2500]).squeeze()
    # Split into seq_len chunks
    for j in range(0, len(tokens), seq_len):
        tokenized_dataset.append(tokens[j:j+seq_len])

Loaded pretrained model gpt2-small into HookedTransformer


100%|██████████| 243/243 [00:00<00:00, 809.81it/s]


In [116]:
# Keep only examples with seq_len 128
tokenized_dataset = [x for x in tokenized_dataset if len(x) == seq_len]

In [171]:
len(tokenized_dataset)

1124

In [118]:
# Assert all tensors have shape seq_length
for i, tokens in enumerate(tokenized_dataset):
    assert tokens.shape[0] == seq_len, f"Token {i} has shape {tokens.shape}"

In [119]:
# Turn tokenized_dataset (a list of tensors) into a Pytorch Dataset
from torch.utils.data import Dataset

batch_size = 16

class TokenizedDataset(Dataset):
    def __init__(self, tokenized_dataset):
        self.tokenized_dataset = tokenized_dataset

    def __len__(self):
        return len(self.tokenized_dataset)

    def __getitem__(self, idx):
        return self.tokenized_dataset[idx]
    
dataset = TokenizedDataset(tokenized_dataset)

from torch.utils.data import DataLoader

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
print(next(iter(dataloader)).shape)

torch.Size([16, 128])


In [120]:
# Disable torch grad
torch.set_grad_enabled(False)

<torch.autograd.grad_mode.set_grad_enabled at 0x4942133b0>

In [134]:
layer = 9
sae = z_saes[layer]

# Get all z activations
z_acts = []
for batch in tqdm(dataloader):
    logits, cache = model.run_with_cache(batch)
    z = cache["z", layer] # batch_size x seq_len x n_heads x d_head
    del logits
    del cache
    z = einops.rearrange(
        z, 
        "b s n d -> (b s) (n d)"
    )
    z_acts.append(z)

# Stack all z activations along first dimension
z_acts = torch.cat(z_acts, dim=0)
z_acts.shape

100%|██████████| 71/71 [00:53<00:00,  1.34it/s]


torch.Size([143872, 768])

In [135]:
torch.save(z_acts, 'z_acts.pt')

In [136]:
# Load z_acts
z_acts = torch.load("z_acts.pt")

## Get SAE reconstructions and errors

In [137]:
z_acts.shape

torch.Size([143872, 768])

In [138]:
# Create SAE dataset
class SAEDataset(Dataset):
    def __init__(self, z_acts):
        self.z_acts = z_acts

    def __len__(self):
        return len(self.z_acts)

    def __getitem__(self, idx):
        return self.z_acts[idx]
    
sae_dataset = SAEDataset(z_acts)

# Create SAE dataloader
sae_dataloader = DataLoader(sae_dataset, batch_size=batch_size, shuffle=True)

print(next(iter(sae_dataloader)).shape)

torch.Size([16, 768])


In [146]:
# Get SAE errors on each z_acts - we need to store the errors, and the original z_acts
sae_errors = []
original_z = []
for z_batch in tqdm(sae_dataloader):
    _, z_recon, z_acts, _, _ = sae(z_batch)
    sae_error = z_batch - z_recon
    sae_errors.append(sae_error)
    original_z.append(z_batch)
    
# Stack all sae errors along first dimension
sae_errors = torch.cat(sae_errors, dim=0)
original_z = torch.cat(original_z, dim=0)

  0%|          | 0/8992 [00:00<?, ?it/s]

100%|██████████| 8992/8992 [01:08<00:00, 130.99it/s]


In [147]:
sae_errors.shape, original_z.shape

(torch.Size([143872, 768]), torch.Size([143872, 768]))

In [148]:
# Save both
torch.save(sae_errors, 'sae_errors.pt')
torch.save(original_z, 'original_z.pt')

## Train a gated SAE to predict the errors

In [151]:
import torch
import einops
from torch import Tensor
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F

In [152]:
# Load in the sae errors and original z
sae_errors = torch.load('sae_errors.pt')
original_z = torch.load('original_z.pt')

In [163]:
# Gated SAE
class GatedSAE(nn.Module):

    def __init__(self, n_input_features, n_learned_features, l1_coefficient=0.01):

        super().__init__()

        self.n_input_features = n_input_features
        self.n_learned_features = n_learned_features
        self.l1_coefficient = l1_coefficient

        self.W_enc = nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(self.n_input_features, self.n_learned_features))   
        )
        self.W_dec = nn.Parameter(
            torch.nn.init.kaiming_uniform_(torch.empty(self.n_learned_features, self.n_input_features))   
        )

        self.r_mag = nn.Parameter(
            torch.zeros(self.n_learned_features)
        )
        self.b_mag = nn.Parameter(
            torch.zeros(self.n_learned_features)
        )
        self.b_gate = nn.Parameter(
            torch.zeros(self.n_learned_features)
        )
        self.b_dec = nn.Parameter(
            torch.zeros(self.n_input_features)
        )

        self.activation_fn = nn.ReLU()

    def forward(self, x_act, y_error):
        # Assert x_act (original z activations i.e. the input) and the y_error (SAE error i.e. the target) have the same shape
        assert x_act.shape == y_error.shape, f"x_act shape {x_act.shape} does not match y_error shape {y_error.shape}"

        hidden_pre = einops.einsum(x_act, self.W_enc, "... d_in, d_in d_sae -> ... d_sae")

        # Gated SAE
        hidden_pre_mag = hidden_pre * torch.exp(self.r_mag) + self.b_mag
        hidden_post_mag = self.activation_fn(hidden_pre_mag)  
        hidden_pre_gate = hidden_pre + self.b_gate
        hidden_post_gate = (torch.sign(hidden_pre_gate) + 1) / 2
        hidden_post = hidden_post_mag * hidden_post_gate

        sae_out = einops.einsum(hidden_post, self.W_dec, "... d_sae, d_sae d_in -> ... d_in") + self.b_dec

        # Now we need to handle all the loss stuff
        # Reconstruction loss
        per_item_mse_loss = self.per_item_mse_loss_with_target_norm(sae_out, y_error)
        mse_loss = per_item_mse_loss.mean()
        # L1 loss
        via_gate_feature_magnitudes = F.relu(hidden_pre_gate)
        sparsity = via_gate_feature_magnitudes.norm(p=1, dim=1).mean(dim=(0,))
        l1_loss = self.l1_coefficient * sparsity
        # Auxiliary loss
        via_gate_reconstruction = einops.einsum(via_gate_feature_magnitudes, self.W_dec.detach(), "... d_sae, d_sae d_in -> ... d_in") + self.b_dec.detach()
        aux_loss = F.mse_loss(via_gate_reconstruction, y_error, reduction="mean")
        
        loss = mse_loss + l1_loss + aux_loss

        return sae_out, loss

    def per_item_mse_loss_with_target_norm(self, preds, target):
        return torch.nn.functional.mse_loss(preds, target, reduction='none')

In [164]:
n_input_features = 768
projection_up = 4
gated_sae = GatedSAE(n_input_features=768, n_learned_features=n_input_features*projection_up)

In [165]:
# Test the forward pass
x = original_z[:16, :]
y = sae_errors[:16, :]
sae_out, loss = gated_sae(x, y)
sae_out.shape, loss

(torch.Size([16, 768]), tensor(2.5217))

In [166]:
# Create GatedSAE dataset
class GatedSAEDataset(Dataset):
    def __init__(self, original_z, sae_errors):
        self.original_z = original_z
        self.sae_errors = sae_errors

    def __len__(self):
        return len(self.original_z)

    def __getitem__(self, idx):
        return self.original_z[idx], self.sae_errors[idx]
    
gated_sae_dataset = GatedSAEDataset(original_z, sae_errors)

# Create GatedSAE train dataloader and test dataloader
train_size = int(0.8 * len(gated_sae_dataset))
test_size = len(gated_sae_dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(gated_sae_dataset, [train_size, test_size])

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [168]:
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x4b78acbc0>

In [170]:
# Training loop
import torch.optim as optim

n_epochs = 10
gated_sae = GatedSAE(n_input_features=768, n_learned_features=768*4)
optimizer = optim.Adam(gated_sae.parameters(), lr=0.001)

for epoch in range(n_epochs):
    print(f"Epoch {epoch}")
    for i, (x, y) in enumerate(train_dataloader):
        optimizer.zero_grad()
        sae_out, loss = gated_sae(x, y)
        loss.backward()
        print(f"Batch {i} Loss {loss.item()}")
        optimizer.step()
        if i % (n_epochs // 10) == 0:
            # Evaluate on test set
            test_loss = 0
            for x, y in test_dataloader:
                sae_out, loss = gated_sae(x, y)
                test_loss += loss.item()

Epoch 0
Batch 0 Loss 2.2689952850341797
Batch 1 Loss 1.7172391414642334
Batch 2 Loss 0.928523600101471
Batch 3 Loss 0.4895963668823242
Batch 4 Loss 0.722083568572998
Batch 5 Loss 0.24162815511226654
Batch 6 Loss 0.2363552749156952
Batch 7 Loss 0.16351088881492615
Batch 8 Loss 0.10887334495782852
Batch 9 Loss 0.05956375598907471
Batch 10 Loss 0.089531809091568
Batch 11 Loss 0.06512608379125595
Batch 12 Loss 0.08334121108055115
Batch 13 Loss 0.059068404138088226
Batch 14 Loss 0.04388595372438431
Batch 15 Loss 0.039961010217666626
Batch 16 Loss 0.04903621971607208
Batch 17 Loss 0.05079043656587601
Batch 18 Loss 0.11217576265335083
Batch 19 Loss 0.05086073279380798
Batch 20 Loss 0.059194840490818024
Batch 21 Loss 0.06037326902151108
Batch 22 Loss 0.04465043544769287
Batch 23 Loss 0.05342236906290054
Batch 24 Loss 0.05961407721042633
Batch 25 Loss 0.05566681548953056
Batch 26 Loss 0.057308271527290344
Batch 27 Loss 0.05143804848194122
Batch 28 Loss 0.04597126692533493
Batch 29 Loss 0.053627

KeyboardInterrupt: 