# Sparse Autoencoders & Superposition

## Monday 4/29/24

- Paired with Peter Kang.
- Read [Towards Monosemanticity](https://transformer-circuits.pub/2023/monosemantic-features).
- We started out (too ambitiously) trying to implement and train a one-layer transformer from scratch simply for the purposes of this study.
- We realized that would be worthy of its own separate study and notebook, and that the main purpose of this study is to study superposition and to implement SAEs, so we chose to use an off-the shelf model from TransformerLens.
- At first we wanted to train it ourselves. We ran into some annoying bugs with HF datasets that wouldn't let us iterate over the dataset for training.

In [302]:
import torch as t
import numpy as np
from datasets import load_dataset
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint
import wandb

dataset = load_dataset('JeanKaddour/minipile')
# dataset.save_to_disk('minipile')

cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=128,
    n_heads=8,
    d_head=64,
    n_ctx=1024,
    tokenizer_name='EleutherAI/gpt-neox-20b',
    act_fn='relu'
)
model = HookedTransformer(cfg=cfg)
pretrained_gelu_model = HookedTransformer.from_pretrained('gelu-1l')


training_args = TrainingArguments(output_dir='train', 
                                  per_device_train_batch_size=10000,
                                  per_device_eval_batch_size=10000,
                                  report_to=None)

trainer = Trainer(
    model,
    training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    tokenizer=model.tokenizer
)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  mps



Passing the following arguments to `Accelerator` is deprecated and will be removed in version 1.0 of Accelerate: dict_keys(['dispatch_batches', 'split_batches', 'even_batches', 'use_seedable_sampler']). Please pass an `accelerate.DataLoaderConfiguration` instead: 
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)



This call breaks when enumerating the dataset.

In [303]:
try:
    trainer.train()
except IndexError as e:
    print(e)
    wandb.finish()



Moving model to device:  mps


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


  0%|          | 0/300 [00:00<?, ?it/s]

Invalid key: 999976 is out of bounds for size 0


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

This reproduces the bug:

In [301]:
train_dataloader = trainer.get_train_dataloader()
try:
    for i, batch in enumerate(train_dataloader):
        print(f'{i}, {batch}')
except IndexError as e:
    print(e)

Invalid key: 999976 is out of bounds for size 0


Might want to keep loking into it, or might be worth just working with a pretrained one-layer transformer, since, again, the main goal for this study is to train an SAE and use it to interpret activations.

## Tuesday 4/30/24

- Decided to use pretrained 1L transformer from TransformerLens.
- Created SAE class.
- Created dataset to train SAE by sampling MLP activations using minipile dataset.
- Implemented basic SAE training loop.
- Trained SAE.

Loading pre-trained GELU model and creating SAE class.

In [112]:
import torch as t
from torch import nn
import datasets
import numpy as np
import einops

from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

gelu_model = HookedTransformer.from_pretrained('gelu-1l')

sae_input_dim = gelu_model.cfg.d_mlp
sae_hidden_dim = 8 * sae_input_dim
sae_output_dim = sae_input_dim

class SAE(nn.Module):
    
    def __init__(self, 
                 input_dim=sae_input_dim,
                 hidden_dim=sae_hidden_dim,
                 init_range=0.04,
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_range = init_range

        self.bias_d = nn.Parameter(t.empty(input_dim))
        self.bias_e = nn.Parameter(t.empty(hidden_dim))
        self.W_e = nn.Parameter(t.empty(hidden_dim, input_dim))
        self.W_d = nn.Parameter(t.empty(input_dim, hidden_dim))
        nn.init.normal_(self.bias_d, std=init_range)
        nn.init.normal_(self.bias_e, std=init_range)
        nn.init.normal_(self.W_e, std=init_range)
        self.activation = nn.functional.relu

    def encode(self, input):
        inputs_centered = input - self.bias_d
        act_input = einops.einsum(inputs_centered, self.W_e,
                                  'batch input_dim, hidden_dim input_dim -> batch hidden_dim')
        act_input = act_input + self.bias_e
        return self.activation(act_input)

    def decode(self, features):
        output = einops.einsum(features, self.W_d,
                               'batch hidden_dim, input_dim hidden_dim -> batch input_dim')
        return output + self.bias_d
        
    def forward(self, input):
        return self.decode(self.encode(input))

Loaded pretrained model gelu-1l into HookedTransformer


Helper tokenization functions, running sample text through model and caching activations.

In [264]:
def strings_to_tokens(strings, model):
    tokens = [t.Tensor(model.tokenizer(string)['input_ids']).to(int) for string in strings]
    return t.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)

def decode(logits, model, separate=True):
    if separate:
        return [model.tokenizer.decode(token) for token in logits.argmax(-1).squeeze()]
    else:
        return model.tokenizer.decode(logits.argmax(-1))

strings = ['we are so back lmao', 'omg', 'pad see ew']
input_tokens = strings_to_tokens(strings, model)
print(input_tokens)

gelu_model.eval()
logits, cache = gelu_model.run_with_cache(input_tokens)
print(decode(logits, gelu_model))

tensor([[  664,   403,   594,   896,   298,   785,    80],
        [  297,    72,     0,     0,     0,     0,     0],
        [11022,   923,   299,    88,     0,     0,     0]])
['ausausausausausausaus', '0stThe####', '\r\ns\nendfrom##']


Running cached activations through SAE and verifying shapes match.

In [267]:
sae = SAE().to('mps')
sae.eval()

acts = cache['blocks.0.mlp.hook_post']
batch_size = acts.shape[0]
acts_flat = einops.rearrange(acts, 'b s d_mlp -> (b s) d_mlp')

encodings = sae.encode(acts_flat)
sae_output = sae(acts_flat)
sae_output = einops.rearrange(sae_output, '(b s) d_mlp -> b s d_mlp', b=batch_size)

print(f'Original activations shape: {acts.shape}')
print(f'SAE output shape: {sae_output.shape}')

Original activations shape: torch.Size([3, 7, 2048])
SAE output shape: torch.Size([3, 7, 2048])


Generating SAE training dataset by caching MLP activations.

In [139]:
from torch.utils.data import DataLoader
from datasets import load_dataset

batch_size = 64
dataset_name = 'JeanKaddour/minipile'
train_text = load_dataset(dataset_name, split='train[:1024]')
max_length = 512

def process_row(row):
    tokenized_row = gelu_model.tokenizer(row['text'], max_length=max_length)
    pad_length = max_length - len(tokenized_row['input_ids'])
    tokenized_row['input_ids'] += [gelu_model.tokenizer.pad_token_id] * pad_length
    tokenized_row['attention_mask'] += [1] * pad_length
    tokenized_row['token_type_ids'] += [0] * pad_length
    return tokenized_row

In [None]:

train_dataset = train_text.map(process_row)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

device = t.device('mps')
all_acts = t.tensor([]).to('cpu')
for i, row in enumerate(train_dataloader):
    batch = t.stack(row['input_ids']).T.to(device)
    _ , cache = gelu_model.run_with_cache(batch)
    acts = cache['blocks.0.mlp.hook_post']
    acts = einops.rearrange(acts, 'batch sequence d_mlp -> (batch sequence) d_mlp').to('cpu')
    all_acts = t.concat((all_acts, acts), dim=0)
    
# t.save(all_acts.detach(), f='mlp_activations_64bs_512l_minipile1024_gelu1l.pkl')

Training SAE.

In [None]:
from tqdm import tqdm

def get_loss(acts, sae, l1_coeff=0.006):
    features = sae.encode(acts)
    sae_output = sae.decode(features)
    rec_loss = (acts - sae_output).pow(2).mean()
    l1_loss = l1_coeff * features.abs().mean()
    return (rec_loss + l1_loss)

rand_indices = t.randperm(all_acts.shape[0])
all_acts = all_acts[rand_indices]
all_acts.shape

n_epochs = 3
batch_size = 1024
l1_coeff = 0.006

device = t.device('mps')
sae = SAE().to(device)
optimizer = t.optim.Adam(sae.parameters())

losses = []

sae.train()
for i in tqdm(range(n_epochs)):
    n_batches = all_acts.shape[0] // batch_size
    epoch_losses = []
    for i_batch in tqdm(range(n_batches)):
        batch = all_acts[i_batch * batch_size: (i_batch + 1) * batch_size].to(device)
        loss = get_loss(acts=batch, sae=sae, l1_coeff=l1_coeff)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_losses.append(loss.item())
        if (i_batch%100 == 0 or (i_batch < 100 and i_batch % 10 == 0 and i==0)):
            print(f'{i}/{i_batch}: {loss.item()}')
    losses.append(epoch_losses)


Plotting training loss.

In [115]:
import plotly.express as px
all_losses = losses[0] + losses[1] + losses[2]
px.line(all_losses, log_x=True)

NameError: name 'losses' is not defined

One-hot SAE feature vectors don't always get well reconstructed by decoding and encoding. We loop through features and measure a few ratios:
- Ratio between reconstructed value at one-hot encoded feature and the max value of that reconstruction.
- Ratio between reconstructed value at one-hot encoded feature and mean value of all SAE reconstructions.
- Same but with squares.

In [114]:
def test_features_roundtrip(sae):
    sae.eval()
    max_ratios = []
    mean_ratios = []
    square_ratios = []
    for test_feature_index in tqdm(range(sae_hidden_dim)):
        test_feature = t.zeros(sae_hidden_dim).to(device)
        test_feature[test_feature_index] = 1
        test_feature.unsqueeze_(0)

        test_mlp_activation = sae.decode(test_feature)
        encoded_features = sae.encode(test_mlp_activation)
        max_ratio = encoded_features[0,test_feature_index]/encoded_features.max()
        mean_ratio = encoded_features[0,test_feature_index]/encoded_features.mean()
        square_ratio = encoded_features[0,test_feature_index].pow(2)/encoded_features.pow(2).mean()
        max_ratios.append(max_ratio.item())
        mean_ratios.append(mean_ratio.item())
        square_ratios.append(square_ratio.item())
    # px.line(encoded_features[0].cpu().detach().numpy())
    return max_ratios, mean_ratios, square_ratios

max_ratios, mean_ratios, square_ratios = test_features_roundtrip(sae)
px.line(sorted(max_ratios, reverse=True))

100%|██████████| 16384/16384 [01:30<00:00, 180.98it/s]


NameError: name 'px' is not defined

## Thursday 5/2/24

- Used TransformerLens hooks to plug SAE after MLP in transformer.
- Wrote some helper functions to compare performance with/without SAE.

Plugging SAE into transformer using TransformerLens hooks.

In [113]:
device = t.device('mps')
sae = t.load('sae-0430.pt').to(device)
dataset = t.load('mlp_activations_64bs_512l_minipile1024_gelu1l.pkl').to(device)
act = dataset[0].unsqueeze(0)
feats = sae.encode(act)

In [63]:
input_str = 'Hello from the other side this is incredible'
input_tokens = t.Tensor(gelu_model.tokenizer.encode(input_str)).unsqueeze(0).to(int)
input_str_tokenized = gelu_model.tokenizer.tokenize(input_str)

mlp_acts = []
sae_feats = []
mlp_acts_reconstructed = []

def save_mlp_acts_hook(value, hook):
    mlp_acts.append(value)

def sae_reconstruction_hook(value, hook):
    mlp_acts.append(value)
    batch_size = value.shape[0]
    flat_value = einops.rearrange(value, 'b s d_mlp -> (b s) d_mlp')
    # Encode MLP acts into SAE features
    flat_sae_feats = sae.encode(flat_value)
    reshaped_sae_feats = einops.rearrange(flat_sae_feats, '(b s) hidden_dim -> b s hidden_dim', b=batch_size)
    sae_feats.append(reshaped_sae_feats)
    # Decode SAE features to reconstruct MLP activations
    flat_mlp_acts_reconstructed = sae.decode(flat_sae_feats)
    reshaped_mlp_acts_reconstructed = einops.rearrange(flat_mlp_acts_reconstructed,
                                                       '(b s) d_mlp -> b s d_mlp', b=batch_size)
    mlp_acts_reconstructed.append(reshaped_mlp_acts_reconstructed)
    return reshaped_mlp_acts_reconstructed


gelu_model.eval()
sae.eval()
logits = gelu_model(input_tokens)
logits_sae = gelu_model.run_with_hooks(
    input_tokens,
    fwd_hooks=[
    (
        'blocks.0.mlp.hook_post',
        sae_reconstruction_hook
    )
    ]
)
logits = logits[0]
logits_sae = logits_sae[0]

mlp_acts = mlp_acts[0]
sae_feats = sae_feats[0]
mlp_acts_reconstructed = mlp_acts_reconstructed[0]

A few helper functions to start comparing performance with/without SAE.

In [60]:
def get_top_tokens(logits, model, top_k=5):
    sorted_logits, sorted_inds = t.sort(logits, descending=True, dim=-1)
    top_k = 5
    top_tokens = []
    for seq_pos in range(logits.shape[0]):
        pos_top_tokens = sorted_inds[seq_pos]
        pos_top_tokens_decoded = []
        for top_k_ind in range(top_k):
            pos_top_tokens_decoded.append(model.tokenizer.decode(pos_top_tokens[top_k_ind]))
        top_tokens.append(pos_top_tokens_decoded)
    return top_tokens

In [292]:
top_tokens_orig = get_top_tokens(logits, gelu_model)
top_tokens_recons = get_top_tokens(logits_sae, gelu_model)

In [64]:
def compare_top_tokens(input_str, logits_sae, logits, model):
    top_tokens_orig = get_top_tokens(logits, model)
    top_tokens_recons = get_top_tokens(logits_sae, model)
    for input_str_token, top_k_orig, top_k_recons in zip(input_str, 
                                                        top_tokens_orig, 
                                                        top_tokens_recons):
        print(f'Input: {input_str_token}\tOriginal top 5:{top_k_orig}\tReconstructed top 5:{top_k_recons}')
compare_top_tokens(input_str_tokenized, logits_sae, logits, gelu_model)

Input: Hello	Original top 5:['0', ' I', ' Thanks', '1', ' Thank']	Reconstructed top 5:[' I', ' my', ' Thank', 'Hello', ',']
Input: Ġfrom	Original top 5:[' my', ' I', '0', ' Welcome', ' 2']	Reconstructed top 5:[' my', ' I', 'Hello', '0', 'D']
Input: Ġthe	Original top 5:[' same', ' I', ' website', ' Welcome', ' name']	Reconstructed top 5:[' I', 'Hello', ' my', ' course', 'Hi']
Input: Ġother	Original top 5:[' website', 'Hello', ' course', ' part', ' blog']	Reconstructed top 5:[' I', 'Hello', 'I', 'Hi', ' my']
Input: Ġside	Original top 5:[' of', ' I', 'Hello', ' and', '0']	Reconstructed top 5:[' I', ' of', 'I', 'Hello', 'Hi']
Input: Ġthis	Original top 5:['Hello', ' blog', ' is', 'Hi', ' I']	Reconstructed top 5:[' I', ' my', 'Hello', ' blog', ' course']
Input: Ġis	Original top 5:[' my', 'Hello', ' work', ' the', ' a']	Reconstructed top 5:[' my', ' I', ' the', ',', ' blog']
Input: Ġincredible	Original top 5:[' and', ' work', 'Hello', ' I', ',']	Reconstructed top 5:[' I', ',', 'Hello', ' blog

In [294]:
def get_crossentropy(logits_sae, logits):
    crossentropy = t.nn.functional.cross_entropy(input=logits_sae, 
                                                target=logits)
    return crossentropy

tensor(-453661.6250, device='mps:0', grad_fn=<DivBackward1>)

## Monday 5/6/24

SAE with batch dimension.

In [1]:
import torch as t
from torch import nn
import datasets
import numpy as np
import einops

from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

device = t.device('mps')
gelu_model = HookedTransformer.from_pretrained('gelu-1l').to(device)

sae_input_dim = gelu_model.cfg.d_mlp
sae_hidden_dim = 8 * sae_input_dim
sae_output_dim = sae_input_dim

class SAE(nn.Module):
    
    def __init__(self, 
                 input_dim=sae_input_dim,
                 hidden_dim=sae_hidden_dim,
                 init_range=0.04,
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_range = init_range

        self.bias_d = nn.Parameter(t.empty(input_dim))
        self.bias_e = nn.Parameter(t.empty(hidden_dim))
        self.W_e = nn.Parameter(t.empty(hidden_dim, input_dim))
        self.W_d = nn.Parameter(t.empty(input_dim, hidden_dim))
        nn.init.normal_(self.bias_d, std=init_range)
        nn.init.normal_(self.bias_e, std=init_range)
        nn.init.normal_(self.W_e, std=init_range)
        self.activation = nn.functional.relu

    def encode(self, input):
        inputs_centered = input - self.bias_d
        act_input = einops.einsum(inputs_centered, self.W_e,
                                  'batch seq input_dim, hidden_dim input_dim -> batch seq hidden_dim')
        act_input = act_input + self.bias_e
        return self.activation(act_input)

    def decode(self, features):
        output = einops.einsum(features, self.W_d,
                               'batch seq hidden_dim, input_dim hidden_dim -> batch seq input_dim')
        return output + self.bias_d
        
    def forward(self, input):
        return self.decode(self.encode(input))

Loaded pretrained model gelu-1l into HookedTransformer
Moving model to device:  mps


In [5]:
sae_flat = t.load('sae-0430.pt')
sae = SAE().to(device)
sae.load_state_dict(sae_flat.state_dict())
all_acts = t.load('mlp_activations_64bs_512l_minipile1024_gelu1l.pkl')
rand_indices = t.randperm(all_acts.shape[0])
all_acts = all_acts[rand_indices]
all_acts.shape


torch.Size([524288, 2048])

Eval functions including updated ones

In [2]:
from functools import partial
from dataclasses import dataclass

device = t.device('mps')

@dataclass
class SAEActivationTracker:
    mlp_acts = t.empty(0)
    feats = t.empty(0)
    mlp_acts_reconstructed = t.empty(0)

    def reconstruction_hook(self, value, hook, sae):
        self.mlp_acts = value.detach()
        features = sae.encode(value)
        self.features = features.detach()
        acts_reconstructed = sae.decode(features)
        self.mlp_acts_reconstructed = acts_reconstructed.detach()
        return acts_reconstructed
    
    def get_act_mse(self):
        return t.nn.functional.mse_loss(self.mlp_acts_reconstructed, self.mlp_acts).detach()

def run_with_sae(model, sae, input_tokens, sae_act_tracker):
    hook_fn = partial(sae_act_tracker.reconstruction_hook, sae=sae)
    logits = model.run_with_hooks(
        input_tokens,
        fwd_hooks=[(
           'blocks.0.mlp.hook_post',
           hook_fn
        )] 
    )
    return logits
    
def run_test_input(sae, model, input_text):
    input_tokens = t.Tensor(model.tokenizer.encode(input_text)).unsqueeze(0).to(int).to(device)
    input_str_tokenized = model.tokenizer.tokenize(input_text)
    logits = model(input_tokens)
    sae_act_tracker = SAEActivationTracker()
    logits_sae = run_with_sae(model, sae, input_tokens, sae_act_tracker)
    return logits, logits_sae, sae_act_tracker

def get_top_tokens(logits, model, top_k=5):
    logits = logits[0]
    sorted_logits, sorted_inds = t.sort(logits, descending=True, dim=-1)
    top_k = 5
    top_tokens = []
    for seq_pos in range(logits.shape[0]):
        pos_top_tokens = sorted_inds[seq_pos]
        pos_top_tokens_decoded = []
        for top_k_ind in range(top_k):
            pos_top_tokens_decoded.append(model.tokenizer.decode(pos_top_tokens[top_k_ind]))
        top_tokens.append(pos_top_tokens_decoded)
    return top_tokens

def compare_top_tokens(input_text, logits_sae, logits, model):
    input_tokens = model.tokenizer.tokenize(input_text)
    top_tokens_orig = get_top_tokens(logits, model)
    top_tokens_recons = get_top_tokens(logits_sae, model)
    for input_str_token, top_k_orig, top_k_recons in zip(input_tokens, 
                                                        top_tokens_orig, 
                                                        top_tokens_recons):
        print(f'Input: {input_str_token}\tOriginal top 5:{top_k_orig}\tReconstructed top 5:{top_k_recons}')

def get_crossentropy(logits_sae, logits):
    crossentropy = t.nn.functional.cross_entropy(input=logits_sae, 
                                                target=logits)
    return crossentropy

def test_features_roundtrip(sae):
    max_ratios = []
    mean_ratios = []
    square_ratios = []
    for test_feature_index in tqdm(range(sae_hidden_dim)):
        test_feature = t.zeros(sae_hidden_dim).to(device)
        test_feature[test_feature_index] = 1
        test_feature.unsqueeze_(0)

        test_mlp_activation = sae.decode(test_feature)
        encoded_features = sae.encode(test_mlp_activation)
        max_ratio = encoded_features[0,test_feature_index]/encoded_features.max()
        mean_ratio = encoded_features[0,test_feature_index]/encoded_features.mean()
        square_ratio = encoded_features[0,test_feature_index].pow(2)/encoded_features.pow(2).mean()
        max_ratios.append(max_ratio.item())
        mean_ratios.append(mean_ratio.item())
        square_ratios.append(square_ratio.item())
    # px.line(encoded_features[0].cpu().detach().numpy())
    return max_ratios, mean_ratios, square_ratios

In [126]:
DEF_INPUT_TEXT = 'Hello from the other side this is incredible'
def eval_sae_single_input(sae, model, input_text=DEF_INPUT_TEXT):
    sae.eval()
    model.eval()
    input_text = DEF_INPUT_TEXT
    print(f'Input text: {input_text}')
    logits, logits_sae, sae_act_tracker = run_test_input(sae, model, input_text)
    activation_mse = sae_act_tracker.get_act_mse()
    print(f'MSE of MLP activation reconstruction: {activation_mse}')
    crossentropy = get_crossentropy(logits_sae, logits)
    print(f'Logit cross-entropy: {crossentropy}')
    compare_top_tokens(input_text, logits_sae, logits, model)
    # metrics = test_features_roundtrip(sae)
eval_sae_single_input(sae, gelu_model)

Input text: Hello from the other side this is incredible
MSE of MLP activation reconstruction: 0.06152134761214256
Logit cross-entropy: -5.282353401184082
Input: Hello	Original top 5:['0', ' I', ' Thanks', '1', ' Thank']	Reconstructed top 5:[' I', ' my', ' Thank', 'Hello', ',']
Input: Ġfrom	Original top 5:[' my', ' I', '0', ' Welcome', ' 2']	Reconstructed top 5:[' my', ' I', 'Hello', '0', 'D']
Input: Ġthe	Original top 5:[' same', ' I', ' website', ' Welcome', ' name']	Reconstructed top 5:[' I', 'Hello', ' my', ' course', 'Hi']
Input: Ġother	Original top 5:[' website', 'Hello', ' course', ' part', ' blog']	Reconstructed top 5:[' I', 'Hello', 'I', 'Hi', ' my']
Input: Ġside	Original top 5:[' of', ' I', 'Hello', ' and', '0']	Reconstructed top 5:[' I', ' of', 'I', 'Hello', 'Hi']
Input: Ġthis	Original top 5:['Hello', ' blog', ' is', 'Hi', ' I']	Reconstructed top 5:[' I', ' my', 'Hello', ' blog', ' course']
Input: Ġis	Original top 5:[' my', 'Hello', ' work', ' the', ' a']	Reconstructed top 5:

## Tuesday 5/7/24

Simple reconstruction / cross-entropy eval with batches.

In [15]:
from torch.utils.data import DataLoader
from datasets import load_dataset
max_length = 64

def process_row(row):
    tokenized_row = gelu_model.tokenizer(row['text'], max_length=max_length)
    pad_length = max_length - len(tokenized_row['input_ids'])
    tokenized_row['input_ids'] += [gelu_model.tokenizer.pad_token_id] * pad_length
    tokenized_row['attention_mask'] += [1] * pad_length
    tokenized_row['token_type_ids'] += [0] * pad_length
    return tokenized_row

def eval_sae_batch(sae, model, batch_size=256):
    model.eval()
    sae.eval()
    dataset = gelu_model.load_sample_training_dataset()
    eval_dataset = dataset.map(process_row)
    eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
    sae_act_tracker = SAEActivationTracker()
    batch = t.stack(next(iter(eval_dataloader))['input_ids']).T.to(device)
    logits = gelu_model(batch)
    logits_sae = run_with_sae(gelu_model, sae, batch, sae_act_tracker)
    crossentropy = get_crossentropy(logits_sae, logits)
    act_mse = sae_act_tracker.get_act_mse()
    print(f'MSE of MLP activation reconstruction: {act_mse}')
    print(f'Logit cross-entropy: {crossentropy}')

In [16]:
eval_sae_batch(sae, gelu_model)

Map:   0%|          | 0/20000 [00:00<?, ? examples/s]

MSE of MLP activation reconstruction: 0.05424727499485016
Logit cross-entropy: 3.7227911949157715


--- 
## Messy from here on

### HF dataset

In [1]:
import torch as t
import datasets
import numpy as np


In [2]:
from datasets import load_dataset
dataset = load_dataset('JeanKaddour/minipile')

In [3]:
dataset.save_to_disk('minipile')

Saving the dataset (0/12 shards):   0%|          | 0/1000000 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/500 [00:00<?, ? examples/s]

Saving the dataset (0/1 shards):   0%|          | 0/10000 [00:00<?, ? examples/s]

In [42]:
small_dataset_train  = dataset['train'].select(range(10000))
small_dataset_test  = dataset['test'].select(range(10000))

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x49a8ea6d0>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 2b74d7cd0, raw_cell="small_dataset_train  = dataset['train'].select(ran.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/dani/code/mech-interp-notebooks/03-sae/sae.ipynb#W3sZmlsZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x49a8ea6d0>> (for post_run_cell), with arguments args (<ExecutionResult object at 497bd6f10, execution_count=42 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 2b74d7cd0, raw_cell="small_dataset_train  = dataset['train'].select(ran.." store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/dani/code/mech-interp-notebooks/03-sae/sae.ipynb#W3sZmlsZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [30]:
raw_datasets = load_dataset("glue", "mrpc")

Error in callback <bound method _WandbInit._resume_backend of <wandb.sdk.wandb_init._WandbInit object at 0x49a8ea6d0>> (for pre_run_cell), with arguments args (<ExecutionInfo object at 49158b1d0, raw_cell="raw_datasets = load_dataset("glue", "mrpc")" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/dani/code/mech-interp-notebooks/03-sae/sae.ipynb#X23sZmlsZQ%3D%3D>,),kwargs {}:


TypeError: _WandbInit._resume_backend() takes 1 positional argument but 2 were given

Downloading readme:   0%|          | 0.00/35.3k [00:00<?, ?B/s]

Downloading data: 100%|██████████| 649k/649k [00:00<00:00, 1.29MB/s]
Downloading data: 100%|██████████| 75.7k/75.7k [00:00<00:00, 549kB/s]
Downloading data: 100%|██████████| 308k/308k [00:00<00:00, 3.64MB/s]


Generating train split:   0%|          | 0/3668 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/408 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1725 [00:00<?, ? examples/s]

Error in callback <bound method _WandbInit._pause_backend of <wandb.sdk.wandb_init._WandbInit object at 0x49a8ea6d0>> (for post_run_cell), with arguments args (<ExecutionResult object at 49ac82e10, execution_count=30 error_before_exec=None error_in_exec=None info=<ExecutionInfo object at 49158b1d0, raw_cell="raw_datasets = load_dataset("glue", "mrpc")" store_history=True silent=False shell_futures=True cell_id=vscode-notebook-cell:/Users/dani/code/mech-interp-notebooks/03-sae/sae.ipynb#X23sZmlsZQ%3D%3D> result=None>,),kwargs {}:


TypeError: _WandbInit._pause_backend() takes 1 positional argument but 2 were given

In [14]:
np.random.choice(dataset['train'], 30)

array([{'text': "Springfield seeks new fire levy\n\nWednesday\n\nOct 20, 2010 at 12:01 AM\n\nSPRINGFIELD — On Nov. 2, residents will choose whether or not to continue paying a little extra in property taxes to cover the cost of a 12-person fire engine crew.\n\nThe special levy renewal proposes reducing the amount property owners are currently paying by 4 cents per $1,000 of assessed value, thanks to an increase in property values and new construction in the city during the past four years.\n\nVoters first approved the fire levy in 2002. The levy has paid for adding a fifth fire engine crew of 12 firefighter-paramedics to the city’s ranks to fill a gap created in 1997 when the city added a fifth fire station in the Gateway area.\n\nThe new fire station was completed at about the time that statewide property tax-limiting measures cut the city’s general fund income by $1.6 million.\n\nAt first, firefighters tried to make do, moving an engine crew from the fire station at 28th Street and C

In [12]:
from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

# cfg = HookedTransformerConfig(
#     d_model=128,
#     act_fn='relu',
#     n_layers=1,
#     d_mlp=512,
# )
gelu_model = HookedTransformer.from_pretrained('gelu-1l')

Loaded pretrained model gelu-1l into HookedTransformer


In [4]:
gelu_model.cfg
cfg = HookedTransformerConfig(
    n_layers=1,
    d_model=128,
    n_heads=8,
    d_head=64,
    n_ctx=1024,
    tokenizer_name='EleutherAI/gpt-neox-20b',
    act_fn='relu'
)
model = HookedTransformer(cfg=cfg)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [5]:
model.tokenizer.encode(dataset['train'][0]['text'])

[5324,
 36,
 434,
 657,
 422,
 1294,
 48691,
 310,
 2130,
 281,
 638,
 14,
 2621,
 323,
 370,
 24,
 1525,
 187,
 187,
 1231,
 1849,
 2326,
 9828,
 273,
 2325,
 1832,
 14,
 34821,
 611,
 3027,
 13139,
 275,
 776,
 673,
 13,
 690,
 1805,
 685,
 2571,
 15,
 29578,
 13,
 2299,
 13,
 1132,
 3240,
 594,
 3587,
 327,
 253,
 1416,
 347,
 42321,
 23751,
 434,
 2325,
 1507,
 15,
 1198,
 370,
 1099,
 13,
 18943,
 1336,
 755,
 247,
 873,
 273,
 40613,
 326,
 8800,
 1652,
 1480,
 36199,
 281,
 3196,
 15,
 49529,
 434,
 9797,
 12317,
 273,
 4327,
 13,
 533,
 403,
 642,
 5545,
 3033,
 281,
 21097,
 3858,
 1969,
 387,
 1878,
 13,
 598,
 1919,
 597,
 923,
 247,
 5230,
 21436,
 18479,
 3185,
 273,
 247,
 2406,
 14,
 5045,
 378,
 15,
 11239,
 2920,
 13,
 627,
 434,
 625,
 281,
 352,
 685,
 816,
 38420,
 285,
 21643,
 21736,
 15,
 6000,
 7471,
 588,
 1421,
 281,
 247,
 23672,
 273,
 45364,
 320,
 1507,
 313,
 5371,
 2010,
 14677,
 281,
 253,
 6347,
 3972,
 5410,
 11431,
 6022,
 273,
 21491,
 3928,
 15,
 1

In [6]:
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding

tokenizer = model.tokenizer
def tokenize_function(example):
    res = tokenizer(example['text'])
    return {'input_ids': res['input_ids']}


tokenized_dataset_train = small_dataset_train.map(tokenize_function, batched=True)
tokenized_dataset_test = small_dataset_test.map(tokenize_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

NameError: name 'small_dataset_train' is not defined

In [7]:
tokenizer('helloooooo')

{'input_ids': [14440, 33363, 3288], 'attention_mask': [1, 1, 1]}

In [88]:
print(tokenized_dataset_train)

Dataset({
    features: ['text', 'input_ids'],
    num_rows: 10000
})


In [8]:
# collator = DataCollatorWithPadding(tokenizer=model.tokenizer)
training_args = TrainingArguments(output_dir='train', 
                                  per_device_train_batch_size=10000,
                                  per_device_eval_batch_size=10000,
                                  report_to=None)

trainer = Trainer(
    model,
    training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['test'],
    tokenizer=model.tokenizer
    # data_collator=collator,
)


Moving model to device:  mps


dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [10]:
trainer.train()


Moving model to device:  mps


  0%|          | 0/300 [00:00<?, ?it/s]

IndexError: Invalid key: 999976 is out of bounds for size 0

In [99]:
tokenized_dataset_train

Dataset({
    features: ['text', 'input_ids'],
    num_rows: 10000
})

In [93]:
train_dataloader = trainer.get_train_dataloader()

In [94]:
for i, batch in enumerate(train_dataloader):
    print(f'{i}, {batch}')

ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided []

In [60]:
import wandb
wandb.finish()

In [61]:
import os
os.environ['WANDB_MODE'] = 'disabled'
import wandb as wb
wb = None

In [58]:
print('Hello world')

Hello world


### GELU Model

In [1]:
import torch as t
import datasets
import numpy as np

from transformer_lens import utils, ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

# cfg = HookedTransformerConfig(
#     d_model=128,
#     act_fn='relu',
#     n_layers=1,
#     d_mlp=512,
# )
gelu_model = HookedTransformer.from_pretrained('gelu-1l')

Loaded pretrained model gelu-1l into HookedTransformer


In [18]:
logits

tensor([[[11.2926, -5.7966, -5.7007,  ..., -5.8662, -5.7967, -5.7530],
         [ 5.7718, -5.3477, -5.2548,  ..., -5.4195, -5.3171, -5.4112],
         [ 5.0246, -5.4987, -5.4025,  ..., -5.5617, -5.4628, -5.5536],
         [ 7.8723, -5.9839, -5.8257,  ..., -6.0378, -5.8949, -5.9977]]],
       device='mps:0', grad_fn=<AddBackward0>)

In [36]:
[gelu_model.tokenizer.decode(token) for token in tokens.squeeze()]

['am',
 'Z',
 'v',
 'Y',
 'X',
 'dq',
 'Z',
 'W',
 '9',
 'p',
 'Z',
 'mp',
 'z',
 'Z',
 'G',
 '9',
 'p',
 'Z',
 '2',
 'h',
 'ha',
 'W',
 '9',
 '3',
 'Z',
 'W',
 'Z',
 'o',
 'Y',
 'X',
 'dv',
 'Z',
 'W',
 'lm',
 'aml',
 'hd',
 '2',
 '9',
 'kc',
 '2',
 'Z',
 'z',
 'ZA',
 'o',
 '=']

In [37]:
predicted_tokens = logits.argmax(-1)
[gelu_model.tokenizer.decode(token) for token in predicted_tokens.squeeze()]

['-',
 'i',
 'am',
 'am',
 'am',
 'am',
 'am',
 'Y',
 '9',
 'Z',
 'Y',
 'y',
 'Y',
 'Y',
 'Z',
 '9',
 'Z',
 'Y',
 '9',
 'Z',
 'Z',
 'Z',
 'p',
 '9',
 'p',
 'Z',
 'Y',
 'Z',
 'Z',
 'Z',
 'p',
 'Y',
 'Z',
 'Z',
 'Z',
 'Y',
 'd',
 'Z',
 'Z',
 '9',
 'p',
 'Y',
 'Z',
 'Z',
 'z']

In [41]:
predicted_tokens = logits.argmax(-1)
[(gelu_model.tokenizer.decode(token), gelu_model.tokenizer.decode(predicted_token)) for (token, predicted_token) in zip(tokens.squeeze(), predicted_tokens.squeeze())]


[('A', 'A'),
 ('AT', 'A'),
 ('TT', 'A'),
 ('GAT', 'A'),
 ('CC', 'A'),
 ('CT', 'A'),
 ('CT', 'A'),
 ('AA', 'AC'),
 ('GG', 'A'),
 ('ACT', 'A'),
 ('AC', 'AA'),
 ('CT', 'A'),
 ('AG', 'A'),
 ('GT', 'A'),
 ('AT', 'A'),
 ('AC', 'CT'),
 ('GG', 'A'),
 ('TAG', 'A'),
 ('AC', 'AT'),
 ('GAC', 'AC'),
 ('CT', 'A'),
 ('AG', 'A'),
 ('C', 'AA'),
 ('GC', 'AG'),
 ('GC', 'AG'),
 ('GT', 'AG'),
 ('AT', 'AG'),
 ('C', 'AA'),
 ('ATT', 'AG'),
 ('CT', 'A'),
 ('AT', 'AC'),
 ('AC', 'CT'),
 ('ACT', 'A'),
 ('AA', 'AC'),
 ('GCC', 'A'),
 ('GC', 'AT'),
 ('GC', 'AG'),
 ('GGT', 'AC'),
 ('AT', 'AC'),
 ('C', 'AA'),
 ('GC', 'AG'),
 ('AT', 'AG'),
 ('AA', 'AC'),
 ('GG', 'A'),
 ('CG', 'AC'),
 ('AA', 'AC'),
 ('GCC', 'AC'),
 ('C', 'GC'),
 ('GT', 'AG'),
 ('AT', 'AG'),
 ('AC', 'AT'),
 ('CC', 'AT'),
 ('GTC', 'AC'),
 ('GT', 'AC'),
 ('GCC', 'AC'),
 ('CA', 'AT'),
 ('CTT', 'AC'),
 ('GA', 'AC'),
 ('CT', 'AC'),
 ('GAG', 'AC'),
 ('TC', 'AC'),
 ('CTT', 'AC'),
 ('GG', 'AC'),
 ('ACT', 'AC'),
 ('GAT', 'AC'),
 ('GT', 'AC'),
 ('AG', 'AC'),
 ('GT

In [27]:
logits = gelu_model(t.Tensor([[59]]).to(int))

In [31]:
gelu_model.tokenizer.decode(188)

'\n'

In [30]:
logits.argmax(-1)

tensor([[188]], device='mps:0')

In [3]:
model = gelu_model
def strings_to_tokens(strings, model):
    tokens = [t.Tensor(model.tokenizer(string)['input_ids']).to(int) for string in strings]
    return t.nn.utils.rnn.pad_sequence(tokens, batch_first=True, padding_value=0)

def decode(logits, model, separate=True):
    if separate:
        return [model.tokenizer.decode(token) for token in logits.argmax(-1).squeeze()]
    else:
        return model.tokenizer.decode(logits.argmax(-1))

strings = ['we are so back lmao', 'omg', 'pad see ew']
input_tokens = strings_to_tokens(strings, model)
print(input_tokens)

tensor([[  662,   403,   593,   888,   299,   781,    81],
        [  298,    73,     0,     0,     0,     0,     0],
        [10737,   915,   300,    89,     0,     0,     0]])


In [76]:
model = gelu_model
logits, cache = model.run_with_cache(input_tokens)
decode(logits, model)

['wewewewewewewe', 'ishingomTheThe###', 'selfpadpadpadThe##']

In [54]:
model.parameters

<bound method Module.parameters of HookedTransformer(
  (embed): Embed()
  (hook_embed): HookPoint()
  (pos_embed): PosEmbed()
  (hook_pos_embed): HookPoint()
  (blocks): ModuleList(
    (0): TransformerBlock(
      (ln1): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (ln2): LayerNormPre(
        (hook_scale): HookPoint()
        (hook_normalized): HookPoint()
      )
      (attn): Attention(
        (hook_k): HookPoint()
        (hook_q): HookPoint()
        (hook_v): HookPoint()
        (hook_z): HookPoint()
        (hook_attn_scores): HookPoint()
        (hook_pattern): HookPoint()
        (hook_result): HookPoint()
      )
      (mlp): MLP(
        (hook_pre): HookPoint()
        (hook_post): HookPoint()
      )
      (hook_attn_in): HookPoint()
      (hook_q_input): HookPoint()
      (hook_k_input): HookPoint()
      (hook_v_input): HookPoint()
      (hook_mlp_in): HookPoint()
      (hook_attn_out): HookPoint()
      (hook_mlp

In [77]:
acts = cache['blocks.0.mlp.hook_post']

In [79]:
decode(logits, model)

['wewewewewewewe', 'ishingomTheThe###', 'selfpadpadpadThe##']

In [81]:
gelu_model.cfg

HookedTransformerConfig:
{'act_fn': 'gelu',
 'attention_dir': 'causal',
 'attn_only': False,
 'attn_types': None,
 'checkpoint_index': None,
 'checkpoint_label_type': None,
 'checkpoint_value': None,
 'd_head': 64,
 'd_mlp': 2048,
 'd_model': 512,
 'd_vocab': 48262,
 'd_vocab_out': 48262,
 'default_prepend_bos': True,
 'device': device(type='mps'),
 'dtype': torch.float32,
 'eps': 1e-05,
 'final_rms': False,
 'from_checkpoint': False,
 'gated_mlp': False,
 'init_mode': 'gpt2',
 'init_weights': False,
 'initializer_range': 0.035355339059327376,
 'model_name': 'GELU_1L512W_C4_Code',
 'n_ctx': 1024,
 'n_devices': 1,
 'n_heads': 8,
 'n_key_value_heads': None,
 'n_layers': 1,
 'n_params': 3145728,
 'normalization_type': 'LNPre',
 'original_architecture': 'neel',
 'parallel_attn_mlp': False,
 'positional_embedding_type': 'standard',
 'post_embedding_ln': False,
 'rotary_adjacent_pairs': False,
 'rotary_base': 10000,
 'rotary_dim': None,
 'scale_attn_by_inverse_layer_idx': False,
 'seed': Non

### SAE

In [2]:
from torch import nn
import einops

sae_input_dim = gelu_model.cfg.d_mlp
sae_hidden_dim = 8 * sae_input_dim
sae_output_dim = sae_input_dim

class SAE(nn.Module):
    
    def __init__(self, 
                 input_dim=sae_input_dim,
                 hidden_dim=sae_hidden_dim,
                 init_range=0.04,
                 ):
        super().__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.init_range = init_range

        self.bias_d = nn.Parameter(t.empty(input_dim))
        self.bias_e = nn.Parameter(t.empty(hidden_dim))
        self.W_e = nn.Parameter(t.empty(hidden_dim, input_dim))
        self.W_d = nn.Parameter(t.empty(input_dim, hidden_dim))
        nn.init.normal_(self.bias_d, std=init_range)
        nn.init.normal_(self.bias_e, std=init_range)
        nn.init.normal_(self.W_e, std=init_range)
        self.activation = nn.functional.relu

    def encode(self, input):
        inputs_centered = input - self.bias_d
        act_input = einops.einsum(inputs_centered, self.W_e,
                                  'batch input_dim, hidden_dim input_dim -> batch hidden_dim')
        act_input = act_input + self.bias_e
        return self.activation(act_input)

    def decode(self, features):
        output = einops.einsum(features, self.W_d,
                               'batch hidden_dim, input_dim hidden_dim -> batch input_dim')
        return output + self.bias_d
        
    def forward(self, input):
        return self.decode(self.encode(input))

In [6]:
sae = SAE().to('mps')

In [94]:
acts_flat = einops.rearrange(acts, 'b s act -> (b s) act')
encodings = sae.encode(acts_flat)
encodings

tensor([[0.0000, 0.1662, 0.0000,  ..., 0.0000, 0.7530, 0.0000],
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 1.2454, 0.0000],
        [0.0000, 0.2556, 0.0000,  ..., 0.0000, 1.8661, 0.0000],
        ...,
        [0.1642, 0.2159, 0.1225,  ..., 0.0000, 1.0308, 0.0000],
        [0.3733, 0.2535, 0.0264,  ..., 0.5381, 0.9750, 0.0000],
        [0.3934, 0.2057, 0.0000,  ..., 0.7450, 0.9278, 0.0000]],
       device='mps:0', grad_fn=<ReluBackward0>)

In [97]:
sae_output = sae(acts_flat)
print(sae_output.shape)

torch.Size([21, 2048])


### Generating activation dataset

In [51]:
def get_loss(acts, sae, l1_coeff=0.006):
    features = sae.encode(acts)
    sae_output = sae.decode(features)
    rec_loss = (acts - sae_output).pow(2).mean()
    l1_loss = l1_coeff * features.abs().mean()
    return (rec_loss + l1_loss)

In [6]:
from torch.utils.data import DataLoader
from datasets import load_dataset

batch_size = 64
dataset_name = 'JeanKaddour/minipile'
train_text = load_dataset(dataset_name, split='train[:1024]')
max_length = 512

def process_row(row):
    tokenized_row = gelu_model.tokenizer(row['text'], max_length=max_length)
    pad_length = max_length - len(tokenized_row['input_ids'])
    tokenized_row['input_ids'] += [gelu_model.tokenizer.pad_token_id] * pad_length
    tokenized_row['attention_mask'] += [1] * pad_length
    tokenized_row['token_type_ids'] += [0] * pad_length
    return tokenized_row


train_dataset = train_text.map(process_row)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

Map:   0%|          | 0/1024 [00:00<?, ? examples/s]

Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. Defaulting to 'longest_first' truncation strategy. If you encode pairs of sequences (GLUE-style) with the tokenizer you can select this strategy more precisely by providing a specific strategy to `truncation`.


In [159]:
print(len(train_dataset[1]['input_ids']))

512


In [7]:
device = t.device('mps')
all_acts = t.tensor([]).to('cpu')
for i, row in enumerate(train_dataloader):
    batch = t.stack(row['input_ids']).T.to(device)
    _ , cache = gelu_model.run_with_cache(batch)
    acts = cache['blocks.0.mlp.hook_post']
    acts = einops.rearrange(acts, 'batch sequence d_mlp -> (batch sequence) d_mlp').to('cpu')
    all_acts = t.concat((all_acts, acts), dim=0)
    

In [10]:
t.save(all_acts.detach(), f='mlp_activations_64bs_512l_minipile1024_gelu1l.pkl')

### Training

In [3]:
all_acts = t.load(f='mlp_activations_64bs_512l_minipile1024_gelu1l.pkl')

rand_indices = t.randperm(all_acts.shape[0])
all_acts = all_acts[rand_indices]
all_acts.shape

torch.Size([524288, 2048])

In [52]:
from tqdm import tqdm

n_epochs = 3
batch_size = 1024
l1_coeff = 0.006

device = t.device('mps')
sae = SAE().to(device)
optimizer = t.optim.Adam(sae.parameters())

losses = []

sae.train()
for i in tqdm(range(n_epochs)):
    n_batches = all_acts.shape[0] // batch_size
    epoch_losses = []
    for i_batch in tqdm(range(n_batches)):
        batch = all_acts[i_batch * batch_size: (i_batch + 1) * batch_size].to(device)
        loss = get_loss(acts=batch, sae=sae, l1_coeff=l1_coeff)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        epoch_losses.append(loss.item())
        if (i_batch%100 == 0 or (i_batch < 100 and i_batch % 10 == 0 and i==0)):
            print(f'{i}/{i_batch}: {loss.item()}')
    losses.append(epoch_losses)


  0%|          | 0/3 [00:00<?, ?it/s]

0/0: 8.264573097229004




0/10: 0.7321810722351074




0/20: 0.2547191381454468




0/30: 0.17331887781620026




0/40: 0.13515867292881012




0/50: 0.11307991296052933




0/60: 0.1077914908528328




0/70: 0.10253268480300903




0/80: 0.09983061999082565




0/90: 0.09604736417531967




0/100: 0.08958395570516586




0/200: 0.07334491610527039




0/300: 0.062314774841070175




0/400: 0.05602551996707916




0/500: 0.05334043875336647


100%|██████████| 512/512 [00:36<00:00, 14.14it/s]
 33%|███▎      | 1/3 [00:36<01:12, 36.22s/it]


1/0: 0.05227159336209297


  1%|          | 4/512 [00:00<00:33, 15.02it/s][A

1/100: 0.04738108813762665




1/200: 0.045205436646938324




1/300: 0.04433976486325264



 79%|███████▉  | 404/512 [00:27<00:07, 14.69it/s]

1/400: 0.042124148458242416


[A

1/500: 0.04332846775650978


100%|██████████| 512/512 [00:34<00:00, 14.82it/s]
 67%|██████▋   | 2/3 [01:10<00:35, 35.24s/it]

2/0: 0.042585015296936035




2/100: 0.040313106030225754




2/200: 0.0392337292432785



 59%|█████▉    | 304/512 [00:20<00:13, 14.94it/s]

2/300: 0.039021700620651245


[A

2/400: 0.037467144429683685




2/500: 0.038338176906108856


100%|██████████| 512/512 [00:34<00:00, 14.91it/s]
100%|██████████| 3/3 [01:45<00:00, 35.04s/it]


In [268]:
import plotly.express as px
all_losses = losses[0] + losses[1] + losses[2]
px.line(all_losses, log_x=True)

In [91]:
sae.eval()
max_ratios = []
mean_ratios = []
square_ratios = []
for test_feature_index in tqdm(range(sae_hidden_dim)):
    test_feature = t.zeros(sae_hidden_dim).to(device)
    test_feature[test_feature_index] = 1
    test_feature.unsqueeze_(0)

    test_mlp_activation = sae.decode(test_feature)
    encoded_features = sae.encode(test_mlp_activation)
    max_ratio = encoded_features[0,test_feature_index]/encoded_features.max()
    mean_ratio = encoded_features[0,test_feature_index]/encoded_features.mean()
    square_ratio = encoded_features[0,test_feature_index].pow(2)/encoded_features.pow(2).mean()
    max_ratios.append(max_ratio.item())
    mean_ratios.append(mean_ratio.item())
    square_ratios.append(square_ratio.item())
# px.line(encoded_features[0].cpu().detach().numpy())


100%|██████████| 16384/16384 [01:50<00:00, 148.36it/s]


In [243]:
px.line(sorted(max_ratios, reverse=True))

In [93]:
px.line(mean_ratios)

In [88]:
ratios[19]

1.0

In [241]:
test_feature = t.zeros(sae_hidden_dim).to(device)
test_feature_index = 39
test_feature[test_feature_index] = 1
test_feature.unsqueeze_(0)

test_mlp_activation = sae.decode(test_feature)
encoded_features = sae.encode(test_mlp_activation)
test_feature_to_max = encoded_features[0,test_feature_index]/encoded_features.max()
px.line(encoded_features[0].cpu().detach().numpy())

In [100]:
print(gelu_model.W_out[0].shape)
mlp_out = einops.einsum(test_mlp_activation, gelu_model.W_out[0],
                        'batch d_mlp, d_mlp d_model -> batch d_model')
logits = einops.einsum(mlp_out, gelu_model.W_U,
                       'batch d_model, d_model d_vocab -> batch d_vocab')
gelu_model.tokenizer.decode(logits.argmax(-1).squeeze())

torch.Size([2048, 512])


'itage'

In [108]:
# px.line(sorted(logits[0].detach().cpu().numpy(), reverse=True))
sorted_logits, sorted_inds = t.sort(logits[0], descending=True)
sorted_inds
for token_ind in sorted_inds[:10]:
    print(gelu_model.tokenizer.decode(token_ind.to(int)))

itage
zel
itt
akers
bra
anc
orthy
ater
orph
eler


In [109]:
t.save(sae, 'sae-0430.pt')

In [110]:
from transformer_lens.utils import get_dataset

dataset = gelu_model.load_sample_training_dataset()

Downloading readme:   0%|          | 0.00/754 [00:00<?, ?B/s]

Downloading data: 100%|██████████| 42.8M/42.8M [00:02<00:00, 19.9MB/s]


Generating train split:   0%|          | 0/20000 [00:00<?, ? examples/s]

In [112]:
gelu_model.cfg.original_architecture

'neel'

In [58]:
px.line((act-reconstruction)[0].detach().cpu().numpy())

### Plugging SAE into transformer

In [142]:
device = t.device('mps')
sae = t.load('sae-0430.pt').to(device)
dataset = t.load('mlp_activations_64bs_512l_minipile1024_gelu1l.pkl').to(device)
act = dataset[0].unsqueeze(0)
feats = sae.encode(act)

In [275]:
input_str = 'Hello from the other side this is incredible'
input_tokens = t.Tensor(gelu_model.tokenizer.encode(input_str)).unsqueeze(0).to(int)
input_str_tokenized = gelu_model.tokenizer.tokenize(input_str)

mlp_acts = []
sae_feats = []
mlp_acts_reconstructed = []

def save_mlp_acts_hook(value, hook):
    mlp_acts.append(value)

def sae_reconstruction_hook(value, hook):
    mlp_acts.append(value)
    batch_size = value.shape[0]
    flat_value = einops.rearrange(value, 'b s d_mlp -> (b s) d_mlp')
    # Encode MLP acts into SAE features
    flat_sae_feats = sae.encode(flat_value)
    reshaped_sae_feats = einops.rearrange(flat_sae_feats, '(b s) hidden_dim -> b s hidden_dim', b=batch_size)
    sae_feats.append(reshaped_sae_feats)
    # Decode SAE features to reconstruct MLP activations
    flat_mlp_acts_reconstructed = sae.decode(flat_sae_feats)
    reshaped_mlp_acts_reconstructed = einops.rearrange(flat_mlp_acts_reconstructed,
                                                       '(b s) d_mlp -> b s d_mlp', b=batch_size)
    mlp_acts_reconstructed.append(reshaped_mlp_acts_reconstructed)
    return reshaped_mlp_acts_reconstructed


gelu_model.eval()
sae.eval()
logits = gelu_model(input_tokens)
logits_sae = gelu_model.run_with_hooks(
    input_tokens,
    fwd_hooks=[
    (
        'blocks.0.mlp.hook_post',
        sae_reconstruction_hook
    )
    ]
)
logits = logits[0]
logits_sae = logits_sae[0]

mlp_acts = mlp_acts[0]
sae_feats = sae_feats[0]
mlp_acts_reconstructed = mlp_acts_reconstructed[0]

In [271]:
def get_top_tokens(logits, model, top_k=5):
    logits = logits[0]
    sorted_logits, sorted_inds = t.sort(logits, descending=True, dim=-1)
    return model.tokenizer.decode(sorted_inds[:, :top_k])
    
get_top_tokens(logits, gelu_model)

TypeError: argument 'ids': 'list' object cannot be interpreted as an integer

In [272]:
def get_top_tokens(logits, model, top_k=5):
    sorted_logits, sorted_inds = t.sort(logits, descending=True, dim=-1)
    top_k = 5
    top_tokens = []
    for seq_pos in range(logits.shape[0]):
        pos_top_tokens = sorted_inds[seq_pos]
        pos_top_tokens_decoded = []
        for top_k_ind in range(top_k):
            pos_top_tokens_decoded.append(model.tokenizer.decode(pos_top_tokens[top_k_ind]))
        top_tokens.append(pos_top_tokens_decoded)
    return top_tokens

In [276]:
top_tokens_orig = get_top_tokens(logits, gelu_model)
top_tokens_recons = get_top_tokens(logits_sae, gelu_model)

In [277]:
for input_str_token, top_k_orig, top_k_recons in zip(input_str_tokenized, 
                                                     top_tokens_orig, 
                                                     top_tokens_recons):
    print(f'Input: {input_str_token}\tOriginal top 5:{top_k_orig}\tReconstructed top 5:{top_k_recons}')

Input: Hello	Original top 5:['0', ' I', ' Thanks', '1', ' Thank']	Reconstructed top 5:['1', 'oth', 'k', 'EE', '2']
Input: Ġfrom	Original top 5:[' my', ' I', '0', ' Welcome', ' 2']	Reconstructed top 5:['k', 't', 'est', 'h', 'self']
Input: Ġthe	Original top 5:[' same', ' I', ' website', ' Welcome', ' name']	Reconstructed top 5:['t', 'k', 'eter', 'self', 'body']
Input: Ġother	Original top 5:[' website', 'Hello', ' course', ' part', ' blog']	Reconstructed top 5:[' is', ' are', '.', 'h', 'le']
Input: Ġside	Original top 5:[' of', ' I', 'Hello', ' and', '0']	Reconstructed top 5:[' are', ' is', ' have', ' has', ' may']
Input: Ġthis	Original top 5:['Hello', ' blog', ' is', 'Hi', ' I']	Reconstructed top 5:['text', ' is', ' are', 'atable', 'le']
Input: Ġis	Original top 5:[' my', 'Hello', ' work', ' the', ' a']	Reconstructed top 5:[' is', ' are', 'i', ' has', 'yl']
Input: Ġincredible	Original top 5:[' and', ' work', 'Hello', ' I', ',']	Reconstructed top 5:['h', 'le', 'body', 'f', 'ff']


In [285]:
crossentropy = t.nn.functional.cross_entropy(input=logits_sae, 
                                             target=logits)
crossentropy

tensor(-215541.6250, device='mps:0', grad_fn=<DivBackward1>)

In [278]:
t.nn.functional.cross_entropy(logits.softmax(-1), logits.softmax(-1))

tensor(10.7419, device='mps:0', grad_fn=<DivBackward1>)

In [234]:
px.line(logits.softmax(-1).detach().cpu().numpy()[0])

In [236]:
px.line(logits_sae.softmax(-1).detach().cpu().numpy()[0])