In [17]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import argparse
import json
import pprint
import numpy as np
import random
import torch.nn as nn

# Notebook-wide setup: choose the plotly renderer, pick the compute device,
# and switch autograd off by default.
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"
# torch.autograd.set_grad_enabled is the same callable as torch.set_grad_enabled,
# so a single call is enough to disable gradient tracking globally.
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# Load the pretrained Pythia-70m checkpoint into a HookedTransformer on the
# selected device, with the weight-processing options spelled out explicitly.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

# Europarl text samples, one JSON file per language.
german_data = haystack_utils.load_json_data("data/german_europarl.json")
english_data = haystack_utils.load_json_data("data/english_europarl.json")

Downloading model.safetensors:   0%|          | 0.00/166M [00:00<?, ?B/s]


To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [13]:

# Experiment hyperparameters, collected in one dict so later cells can read them.
cfg = {
    "seed": 47,
    "batch_size": 32,
    "model_batch_size": 128,  # placeholder; re-derived from batch_size and seq_len below
    "lr": 1e-4,
    "num_tokens": int(1e7),
    "l1_coeff": 3e-3,
    "wd": 1e-2,
    "beta1": 0.9,
    "beta2": 0.99,
    "dict_mult": 8,
    "seq_len": 128,
}
# `batch_size // seq_len` truncates to 0 whenever batch_size < seq_len (as with
# 32 and 128 above), which used to produce a model batch size of 0. Clamp the
# quotient to at least 1; results are unchanged when batch_size >= seq_len.
cfg["model_batch_size"] = max(1, cfg["batch_size"] // cfg["seq_len"]) * 16


In [14]:
# Seed Python's, NumPy's and torch's RNGs from the shared config value.
# The object returned by torch.manual_seed is kept as GENERATOR so later
# sampling calls can pass it explicitly.
SEED = cfg["seed"]
random.seed(SEED)
np.random.seed(SEED)
GENERATOR = torch.manual_seed(SEED)
# Turn autograd back on; the setup cell at the top of the notebook disabled it.
torch.set_grad_enabled(True)

<torch.autograd.grad_mode.set_grad_enabled at 0x257a9d1f5d0>

In [15]:
# Short aliases for the model dimensions read from the loaded model's config.
n_layers, d_model = model.cfg.n_layers, model.cfg.d_model
n_heads, d_head = model.cfg.n_heads, model.cfg.d_head
d_mlp, d_vocab = model.cfg.d_mlp, model.cfg.d_vocab

In [16]:
@torch.no_grad()
def get_mlp_acts(tokens, batch_size=1024, layer=5):
    """Collect MLP post-activations for one layer and draw a random subsample.

    Runs the notebook-level ``model`` with a cache restricted to
    ``blocks.{layer}.mlp.hook_post``, flattens the cached activations to rows
    of width ``d_mlp``, and picks up to ``batch_size`` rows at random using the
    notebook-level ``GENERATOR``.

    Args:
        tokens: Input forwarded unchanged to ``model.run_with_cache``
            (the example call in this cell passes one text sample).
        batch_size: Maximum number of rows in the subsample; if fewer rows are
            available, all of them are returned in shuffled order.
        layer: Index of the block whose MLP post-activation is cached.

    Returns:
        A tuple ``(subsampled_rows, all_rows)``, both with last dimension
        ``d_mlp``.
    """
    hook_name = f"blocks.{layer}.mlp.hook_post"
    _, cache = model.run_with_cache(tokens, names_filter=hook_name)
    flat_acts = cache[hook_name].reshape(-1, d_mlp)
    # Shuffle all row indices, then keep the first `batch_size` of them.
    row_order = torch.randperm(flat_acts.shape[0], generator=GENERATOR)
    chosen_rows = row_order[:batch_size]
    return flat_acts[chosen_rows, :], flat_acts

# Smoke test on the first German sample: request a single subsampled row and
# display the shapes of the subsample and of the full activation matrix.
sub, acts = get_mlp_acts(german_data[0], batch_size=1)
(sub.shape, acts.shape)

(torch.Size([1, 2048]), torch.Size([60, 2048]))

In [None]:
class AutoEncoder(nn.Module):
    """Single-hidden-layer autoencoder with a ReLU bottleneck and an L1 penalty.

    The encoder maps inputs of width ``d_mlp`` to ``d_hidden`` non-negative
    features; the decoder maps them back to width ``d_mlp``. ``forward``
    returns the reconstruction together with the loss terms.

    Note: the input width is read from the notebook-level ``d_mlp`` variable,
    which must be defined before an instance is constructed.
    """

    def __init__(self, d_hidden, l1_coeff, dtype=torch.float32, seed=47):
        """Create the encoder/decoder parameters.

        Args:
            d_hidden: Number of hidden features.
            l1_coeff: Multiplier applied to the L1 penalty in ``forward``.
            dtype: dtype used for every parameter.
            seed: Passed to ``torch.manual_seed`` before initialisation. This
                reseeds torch's global RNG, not just this module.
        """
        super().__init__()
        torch.manual_seed(seed)
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_mlp, d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, d_mlp, dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(d_mlp, dtype=dtype))

        # Rescale each decoder row to unit L2 norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

    def forward(self, x):
        """Encode ``x``, reconstruct it, and compute the loss terms.

        Args:
            x: Input whose last dimension matches the encoder input width.

        Returns:
            A tuple ``(loss, x_reconstruct, acts, l2_loss, l1_loss)`` where
            ``loss = l2_loss + l1_loss``.
        """
        # The decoder bias is subtracted before encoding and added back after decoding.
        x_cent = x - self.b_dec
        # torch.relu is used because no functional alias `F` is imported in this notebook.
        acts = torch.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        # Squared error summed over the last dimension, averaged over the first.
        l2_loss = (x_reconstruct - x).pow(2).sum(-1).mean(0)
        # The L1 term is summed over every element of `acts` (first dimension
        # included), unlike the L2 term, which is averaged over it.
        l1_loss = self.l1_coeff * (acts.abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def remove_parallel_component_of_grads(self):
        """Project out of ``W_dec.grad`` the component parallel to each decoder row.

        Modifies ``self.W_dec.grad`` in place; it must already be populated
        (i.e. a backward pass has run).
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        # Per-row dot product of grad with the unit direction, scaled back onto that direction.
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj