In [1]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch as t
import numpy as np
import gradio as gr

import torch as t
#from google.colab import drive

# This will prompt for authorization.
#drive.mount('/content/drive')

import einops
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import tqdm
from functools import partial
from datasets import load_dataset
from IPython.display import display

In [2]:
# load gpt2-small
model = HookedTransformer.from_pretrained("gpt2-small").to('mps')

Loaded pretrained model gpt2-small into HookedTransformer
Moving model to device:  mps


In [3]:
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data = tokenized_data.shuffle(22)



In [4]:
import torch 
from pathlib import Path
import torch.nn as nn
import pprint
import json 
import torch.nn.functional as F

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])
    
    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss
    
    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed
    
    def get_version(self):
        version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)
    
    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads the saved autoencoder from HuggingFace. 
        
        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47
        
        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self


In [5]:
point, layer = "resid_pre", 9
dic = utils.download_file_from_hf("jacobcd52/gpt2-small-sparse-autoencoders", f"gpt2-small_6144_{point}_{layer}.pt", force_is_torch=True)
W_dec , b_dec, W_enc, b_enc = dic["W_dec"], dic["b_dec"], dic["W_enc"], dic["b_enc"]

In [6]:
dic.keys()

odict_keys(['W_enc', 'W_dec', 'b_enc', 'b_dec'])

In [7]:
cfg = {
    "dict_size": 6144,
    "act_size": 768,
    "l1_coeff": 0.001,
    "enc_dtype": "fp32",
    "seed": 0,
    "device": "mps",
    "model_batch_size": 1028,
}
encoder_resid_pre_9 = AutoEncoder(cfg)
encoder_resid_pre_9.load_state_dict(dic)

<All keys matched successfully>

In [8]:
with torch.no_grad():
    example_tokens = tokenized_data[:200]["tokens"].to("mps")
    logits, cache = model.run_with_cache(example_tokens)
    per_token_loss = model.loss_fn(logits, example_tokens, True)
    # imshow(per_token_loss)

RuntimeError: MPS backend out of memory (MPS allocated: 32.95 GB, other allocations: 2.25 GB, max allowed: 36.27 GB). Tried to allocate 4.76 GB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [None]:
# get the reconstruction loss

def replacement_hook(mlp_post, hook, encoder):
    mlp_post_reconstr = encoder(mlp_post)[1]
    return mlp_post_reconstr

def mean_ablate_hook(mlp_post, hook):
    mlp_post[:] = mlp_post.mean([0, 1])
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    mlp_post[:] = 0.
    return mlp_post

@t.no_grad()
def get_recons_loss(encoder, all_tokens, num_batches=5, local_encoder=None):
    loss_list = []
    for i in range(num_batches):
        tokens = all_tokens[t.randperm(len(all_tokens))[:cfg["model_batch_size"]]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("resid_pre", 9), partial(replacement_hook, encoder=local_encoder))])
        # mean_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("post", 0), mean_ablate_hook)])
        zero_abl_loss = model.run_with_hooks(tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("resid_pre", 9), zero_ablate_hook)])
        loss_list.append((loss, recons_loss, zero_abl_loss))
    losses = t.tensor(loss_list)
    loss, recons_loss, zero_abl_loss = losses.mean(0).tolist()

    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")
    score = ((zero_abl_loss - recons_loss)/(zero_abl_loss - loss))
    print(f"Reconstruction Score: {score:.2%}")
    # print(f"{((zero_abl_loss - mean_abl_loss)/(zero_abl_loss - loss)).item():.2%}")
    return score, loss, recons_loss, zero_abl_loss

local_encoder = encoder_resid_pre_9
get_recons_loss(encoder_resid_pre_9, example_tokens, num_batches=5, local_encoder=local_encoder)

KeyboardInterrupt: 

In [None]:
all_tokens = example_tokens