In [None]:
! pip install -r requirements.txt

In [1]:
import torch
import torch.nn as nn
import numpy as np
from datasets import load_dataset
from transformers import GPT2Tokenizer, GPT2Model
from einops import rearrange
from tqdm import tqdm
import matplotlib.pyplot as plt

## Running on macbook M2
device = "mps" if torch.backends.mps.is_available() else "cpu"
device


  from .autonotebook import tqdm as notebook_tqdm


'mps'

In [2]:
# Load dataset
dataset = load_dataset("roneneldan/TinyStories", split="train[:1%]")  # tiny subset
texts = dataset["text"][:5000]  # limit for speed
len(texts)

5000

In [3]:
# Load  GPT-2 small (124M)
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2Model.from_pretrained("gpt2", output_hidden_states=True)
model.to(device)
model.eval()

print("GPT-2 Loaded.")

GPT-2 Loaded.


In [4]:
model

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D(nf=2304, nx=768)
        (c_proj): Conv1D(nf=768, nx=768)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D(nf=3072, nx=768)
        (c_proj): Conv1D(nf=768, nx=3072)
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)

In [15]:
model.h[LAYER]

GPT2Block(
  (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (attn): GPT2Attention(
    (c_attn): Conv1D(nf=2304, nx=768)
    (c_proj): Conv1D(nf=768, nx=768)
    (attn_dropout): Dropout(p=0.1, inplace=False)
    (resid_dropout): Dropout(p=0.1, inplace=False)
  )
  (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  (mlp): GPT2MLP(
    (c_fc): Conv1D(nf=3072, nx=768)
    (c_proj): Conv1D(nf=768, nx=3072)
    (act): NewGELUActivation()
    (dropout): Dropout(p=0.1, inplace=False)
  )
)

In [5]:
# Select a mid-layer where polysemantic neurons are common
LAYER = 6 

# Hook layer to capture activations
mlp_acts = []

def hook_fn(module, inp, out):
    # out is (batch, seq, hidden)
    mlp_acts.append(out.detach().cpu())

# GPT-2 MLP is in model.h[layer].mlp.c_fc
layer_mlp = model.h[LAYER].mlp.c_fc
hook = layer_mlp.register_forward_hook(hook_fn)


In [6]:
# Collect activations over dataset
all_acts = []

for text in tqdm(texts):
    tokens = tokenizer(text, return_tensors="pt", truncation=True, max_length=128).to(device)
    mlp_acts.clear()
    with torch.no_grad():
        _ = model(**tokens)
    if len(mlp_acts) > 0:
        acts = mlp_acts[0].squeeze(0)  # (seq, hidden)
        all_acts.append(acts)

all_acts = torch.cat(all_acts, dim=0)  # shape (N, d_hidden)
all_acts.shape


100%|██████████| 5000/5000 [01:46<00:00, 46.79it/s]


torch.Size([636269, 3072])

In [7]:
# 7. Detect polysemantic neurons
#    Method: cluster activation patterns and see if 
#    neuron responds to multiple distinct contexts.

from sklearn.cluster import KMeans

hidden_size = all_acts.shape[-1]

# compute top-k activating tokens for each neuron
k = 200
top_examples = torch.topk(all_acts, k=k, dim=0).indices  # shape (k, hidden)

# Represent each neuron by the contexts of its top activations
neuron_vectors = []

for neuron in range(hidden_size):
    idxs = top_examples[:, neuron]
    context_mats = all_acts[idxs]  # (k, hidden)
    neuron_vectors.append(context_mats.mean(dim=0).numpy())

neuron_vectors = np.stack(neuron_vectors)
print(neuron_vectors.shape)


(3072, 3072)


In [8]:
# 8. KMeans to find neurons with multi-cluster activation patterns
polysemantic_neurons = []

for neuron in range(hidden_size):
    activations = all_acts[top_examples[:, neuron], neuron].reshape(-1, 1)
    kmeans = KMeans(n_clusters=2, n_init=5).fit(activations)
    if kmeans.cluster_centers_[0][0] * kmeans.cluster_centers_[1][0] < 0.8 * max(kmeans.cluster_centers_):
        polysemantic_neurons.append(neuron)

print("Estimated polysemantic neurons:", len(polysemantic_neurons))
polysemantic_neurons[:20]


Estimated polysemantic neurons: 115


[14,
 31,
 72,
 87,
 96,
 136,
 140,
 162,
 178,
 227,
 252,
 257,
 294,
 310,
 332,
 341,
 368,
 375,
 389,
 442]

In [9]:
# 9. Build an Anthropic-Style Sparse Autoencoder (SAE)
class AnthropicSAE(nn.Module):
    def __init__(self, d_in, d_hidden):
        super().__init__()
        # No biases, small init
        self.W_enc = nn.Parameter(torch.randn(d_hidden, d_in) * 0.02)
        self.W_dec = nn.Parameter(torch.randn(d_in, d_hidden) * 0.02)

    def forward(self, x):
        # Normalize decoder weights column-wise
        W_dec_norm = self.W_dec / (self.W_dec.norm(dim=0, keepdim=True) + 1e-6)

        # Encoder → ReLU sparse code
        z = torch.relu(x @ self.W_enc.t())  # shape: [batch, d_hidden]

        # Decoder reconstruction
        x_hat = z @ W_dec_norm.t()

        return x_hat, z


# Instantiate SAE
sae = AnthropicSAE(hidden_size, d_hidden=all_acts.shape[-1]*4).to(device)


In [10]:

# Stable optimizer + LR
opt = torch.optim.AdamW(sae.parameters(), lr=3e-4, weight_decay=0.0)

# Normalize activations before training (VERY IMPORTANT)
acts_mean = all_acts.mean(dim=0, keepdim=True)
acts_std = all_acts.std(dim=0, keepdim=True) + 1e-6

acts_normed = (all_acts - acts_mean) / acts_std
acts_train = acts_normed.to(device)



In [20]:
# 10. Train SAE
BATCH = 512
EPOCHS = 3
L1 = 1e-5

for epoch in tqdm(range(EPOCHS), desc="Epoch"):
    perm = torch.randperm(acts_train.shape[0])
    acts_train = acts_train[perm]

    running_mse = 0
    running_l1 = 0
    steps = 0

    for i in tqdm(range(0, len(acts_train), BATCH), desc="Batch"):
        batch = acts_train[i:i+BATCH]
        opt.zero_grad()

        recon, z = sae(batch)

        mse = ((recon - batch)**2).mean()
        sparsity = z.abs().mean()

        loss = mse + L1 * sparsity
        loss.backward()
        opt.step()

        running_mse += mse.item()
        running_l1 += sparsity.item()
        steps += 1

    print(f"Epoch {epoch+1}: MSE={running_mse/steps:.4f}, Sparsity={running_l1/steps:.4f}")


Batch: 100%|██████████| 1243/1243 [01:15<00:00, 16.46it/s]
Epoch:  33%|███▎      | 1/3 [01:21<02:42, 81.18s/it]

Epoch 1: MSE=0.0717, Sparsity=0.7247


Batch: 100%|██████████| 1243/1243 [01:04<00:00, 19.16it/s]
Epoch:  67%|██████▋   | 2/3 [02:26<01:11, 71.61s/it]

Epoch 2: MSE=0.0100, Sparsity=0.9740


Batch: 100%|██████████| 1243/1243 [01:05<00:00, 19.04it/s]
Epoch: 100%|██████████| 3/3 [03:31<00:00, 70.46s/it]

Epoch 3: MSE=0.0069, Sparsity=1.0572





In [23]:
# 11. Demonstrate monosemanticity:
# Compare neuron vs SAE-feature before/after.

def show_top_sentences(neuron_idx, n=10):
    # baseline: original GPT-2 neuron
    idxs = torch.topk(all_acts[:, neuron_idx], k=n).indices
    return [texts[i % len(texts)] for i in idxs]


def show_top_feature(feature_idx, n=10):
    with torch.no_grad():
        # Use matrix multiplication, not a callable parameter
        z = torch.relu(all_acts.to(device) @ sae.W_enc.t()).cpu()
    idxs = torch.topk(z[:, feature_idx], k=n).indices
    return [texts[i % len(texts)] for i in idxs]


neuron = polysemantic_neurons[0]

print("=== BEFORE (polysemantic neuron) ===")
for s in show_top_sentences(neuron, 5):
    print("-", s[:100])

print("\n=== AFTER (SAE monosemantic feature) ===")
for s in show_top_feature(0, 5):
    print("-", s[:100])


=== BEFORE (polysemantic neuron) ===
- Once upon a time there was a mommy, a daddy and a baby. Every day, Mommy and Daddy would take Baby t
- Once upon a time, there was a little girl named Lily. She loved to play with her toys and eat snacks
- Tilly had the best day! She woke up in the morning and put on her new pants. They were so big that t
- Once upon a time, there were two friends, Jack and Jane. They were playing together in the park with
- Mum and Dad were packing for a trip.

Mum said to Dad, "Let's get the egg in the suitcase."

Dad agr

=== AFTER (SAE monosemantic feature) ===
- Once upon a time there was a little girl named Jane. She left the house early one morning to go visi
- Once there was a little girl who wanted to watch the television. She approached the large television
- Once upon a time, there was a foolish little mouse. He was so foolish that he thought he could bite 
- One day, there was a little girl called Sam who loved to play games. She particularly liked pla