In [None]:
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch as t
import numpy as np
import gradio as gr

import torch as t
#from google.colab import drive

# This will prompt for authorization.
#drive.mount('/content/drive')

import einops
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import tqdm
from functools import partial
from datasets import load_dataset
from IPython.display import display

In [None]:
# GPT-2 small via TransformerLens, moved onto the GPU.
model = HookedTransformer.from_pretrained("gpt2-small").to('cuda')

# 10k OpenWebText documents, tokenized and concatenated by utils.tokenize_and_concatenate
# with max_length=128, then shuffled with a fixed seed so runs see the same row order.
data = load_dataset("stas/openwebtext-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data, model.tokenizer, max_length=128
).shuffle(seed=22)

In [None]:
import torch 
from pathlib import Path
import torch.nn as nn
import pprint
import json 
import torch.nn.functional as F

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")


class AutoEncoder(nn.Module):
    """Sparse autoencoder with a ReLU encoder and a linear decoder.

    The input is centred by subtracting ``b_dec`` before encoding, and ``b_dec`` is added
    back after decoding. The loss is the squared reconstruction error plus an L1 penalty
    on the hidden activations.

    ``cfg`` keys read here: ``dict_size`` (number of hidden features), ``act_size``
    (input width), ``l1_coeff``, ``enc_dtype`` (a key of ``DTYPES``), ``seed`` and
    ``device``. The dict is kept on ``self.cfg`` so ``save`` writes the config this
    instance was actually built with.
    """

    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        # Seeds the global torch RNG so initialisation is reproducible for a given cfg.
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        # Every decoder row (one feature's output direction) starts with unit L2 norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff
        # Kept so save() serialises this instance's config rather than a notebook global.
        self.cfg = cfg

        self.to(cfg["device"])

    def forward(self, x):
        """Encode and reconstruct ``x``.

        Returns ``(loss, x_reconstruct, acts, l2_loss, l1_loss)``:
        ``l2_loss`` is the squared error summed over the last dim and averaged over dim 0;
        ``l1_loss`` is ``l1_coeff`` times the absolute sum of all of ``acts``;
        ``loss = l2_loss + l1_loss``.
        """
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        # Losses are computed in fp32 even when the weights are fp16/bf16.
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Renormalise decoder rows to unit norm and strip the radial part of their gradient.

        If ``W_dec.grad`` is None (e.g. before the first backward pass) only the weights
        are renormalised.
        """
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        if self.W_dec.grad is not None:
            # Remove the gradient component parallel to each (normalised) decoder row.
            W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
            self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self):
        """Return the next unused integer checkpoint version in ``SAVE_DIR`` (0 if none)."""
        if not SAVE_DIR.exists():
            return 0
        # Only "<int>.pt" files count; "<int>_cfg.json" and unrelated files are ignored.
        version_list = [
            int(file.stem)
            for file in SAVE_DIR.iterdir()
            if file.suffix == ".pt" and file.stem.isdigit()
        ]
        return 1 + max(version_list) if version_list else 0

    def save(self):
        """Write ``<version>.pt`` (state dict) and ``<version>_cfg.json`` into ``SAVE_DIR``."""
        SAVE_DIR.mkdir(parents=True, exist_ok=True)
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR / f"{version}.pt")
        with open(SAVE_DIR / f"{version}_cfg.json", "w") as f:
            json.dump(self.cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version):
        """Rebuild an AutoEncoder from ``<version>_cfg.json`` and ``<version>.pt`` in ``SAVE_DIR``."""
        with open(SAVE_DIR / f"{version}_cfg.json", "r") as f:
            cfg = json.load(f)
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        # Map straight onto the configured device instead of the device it was saved from.
        self.load_state_dict(torch.load(SAVE_DIR / f"{version}.pt", map_location=cfg["device"]))
        return self

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads the saved autoencoder from HuggingFace.

        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version == "run1":
            version = 25
        elif version == "run2":
            version = 47

        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self


In [None]:
hook_point, layer = "resid_pre", 9

# Pretrained 6144-feature SAE weights for this hook point, fetched from the HF hub.
dic = utils.download_file_from_hf(
    "jacobcd52/gpt2-small-sparse-autoencoders",
    f"gpt2-small_6144_{hook_point}_{layer}.pt",
    force_is_torch=True,
)
W_dec, b_dec, W_enc, b_enc = (dic[key] for key in ("W_dec", "b_dec", "W_enc", "b_enc"))

cfg = dict(
    dict_size=6144,
    act_size=768,
    l1_coeff=0.001,
    enc_dtype="fp32",
    seed=0,
    device="cuda",
    # NOTE(review): 1028 looks like it may have been meant as 1024 — confirm.
    model_batch_size=1028,
)
encoder_resid_pre_9 = AutoEncoder(cfg)
encoder_resid_pre_9.load_state_dict(dic)

In [None]:
# get the reconstruction loss

def replacement_hook(mlp_post, hook, encoder):
    """Forward hook that swaps the hooked activation for ``encoder``'s reconstruction.

    ``encoder(x)`` is indexed at position 1, which for AutoEncoder.forward is
    ``x_reconstruct``. ``hook`` is unused.
    """
    encoder_outputs = encoder(mlp_post)
    reconstruction = encoder_outputs[1]
    return reconstruction

def mean_ablate_hook(mlp_post, hook):
    """Forward hook: overwrite the activation, in place, with its mean over dims 0 and 1.

    The mean keeps the trailing dims and is broadcast back over dims 0 and 1.
    Returns the same (mutated) tensor. ``hook`` is unused.
    """
    feature_means = mlp_post.mean(dim=(0, 1))
    mlp_post[...] = feature_means
    return mlp_post

def zero_ablate_hook(mlp_post, hook):
    """Forward hook: zero the hooked activation in place and return it. ``hook`` is unused."""
    return mlp_post.zero_()

@t.no_grad()
def get_recons_loss(encoder, all_tokens, num_batches=5, local_encoder=None,
                    hook_point="resid_pre", layer=9, batch_size=None):
    """Measure how much LM loss is recovered when an SAE replaces one activation.

    For each of ``num_batches`` random batches drawn from ``all_tokens``, this computes:
    the clean loss, the loss with the activation replaced by the SAE reconstruction,
    and the loss with it zero-ablated. It then averages each loss over batches.

    Args:
        encoder: autoencoder to evaluate (used when ``local_encoder`` is None).
        all_tokens: token ids; batches are random rows along dim 0.
        num_batches: number of random batches to average over.
        local_encoder: if given, used instead of ``encoder`` (kept for existing callers).
        hook_point, layer: activation to replace/ablate; defaults to resid_pre of layer 9.
        batch_size: rows per batch; defaults to the notebook's ``cfg["model_batch_size"]``.

    Returns:
        ``(score, loss, recons_loss, zero_abl_loss)``, where
        ``score = (zero_abl_loss - recons_loss) / (zero_abl_loss - loss)``.
        ``score`` is NaN if the denominator is 0.
    """
    sae = encoder if local_encoder is None else local_encoder
    act_name = utils.get_act_name(hook_point, layer)
    if batch_size is None:
        batch_size = cfg["model_batch_size"]

    loss_list = []
    for _ in range(num_batches):
        tokens = all_tokens[t.randperm(len(all_tokens))[:batch_size]]
        loss = model(tokens, return_type="loss")
        recons_loss = model.run_with_hooks(
            tokens, return_type="loss",
            fwd_hooks=[(act_name, partial(replacement_hook, encoder=sae))],
        )
        zero_abl_loss = model.run_with_hooks(
            tokens, return_type="loss",
            fwd_hooks=[(act_name, zero_ablate_hook)],
        )
        loss_list.append((loss.item(), recons_loss.item(), zero_abl_loss.item()))

    loss, recons_loss, zero_abl_loss = t.tensor(loss_list).mean(0).tolist()
    print(f"loss: {loss:.4f}, recons_loss: {recons_loss:.4f}, zero_abl_loss: {zero_abl_loss:.4f}")

    denominator = zero_abl_loss - loss
    score = (zero_abl_loss - recons_loss) / denominator if denominator != 0 else float("nan")
    print(f"Reconstruction Score: {score:.2%}")
    return score, loss, recons_loss, zero_abl_loss

local_encoder = encoder_resid_pre_9

# Reference forward pass on the first 200 (shuffled) token rows.
with torch.no_grad():
    example_tokens = tokenized_data[:200]["tokens"].to("cuda")
    # NOTE(review): run_with_cache keeps every activation of this batch alive in `cache`;
    # if nothing downstream reads `cache`, `logits = model(example_tokens)` is much lighter.
    logits, cache = model.run_with_cache(example_tokens)
    # per_token is passed by keyword: recent TransformerLens releases put `attention_mask`
    # third in loss_fn's signature, so a positional True would not mean per_token.
    per_token_loss = model.loss_fn(logits, example_tokens, per_token=True)

get_recons_loss(encoder_resid_pre_9, example_tokens, num_batches=5, local_encoder=local_encoder)

In [None]:
def _sae_encode(acts, W_enc, b_dec, b_enc):
    """ReLU SAE encoder, relu((acts - b_dec) @ W_enc + b_enc), for a [batch, seq, resid] tensor."""
    pre_acts = einops.einsum(acts - b_dec, W_enc, "batch seq resid , resid mlp -> batch seq mlp")
    return t.relu(pre_acts + b_enc)


@t.no_grad()
def get_feature_acts(point, layer, dic, num_batches=1000, minibatch_size=50,
                     n_random_per_batch=10, device="cuda"):
    """Encode model activations at one hook point with an SAE and with a random baseline.

    Args:
        point, layer: activation to read, e.g. ("resid_pre", 9).
        dic: SAE parameters with keys "W_enc", "b_enc", "b_dec" (e.g. an AutoEncoder state_dict).
        num_batches: number of rows of ``tokenized_data`` to use (rows, not minibatches).
        minibatch_size: rows per model forward pass.
        n_random_per_batch: rows of each minibatch also encoded with a random W_enc.
        device: device for the tokens, activations and random encoder.

    Returns:
        toks: the token rows used.
        feature_acts: SAE activations flattened to [rows * seq, d_hidden], position 0 zeroed.
        random_feature_acts: random-encoder activations, same layout but fewer rows; each
            token's vector is rescaled to the L2 norm of the real SAE vector for that token.

    Raises:
        ValueError: if ``num_batches`` selects no rows.
    """
    toks = tokenized_data["tokens"][:num_batches].to(device)
    n_rows = toks.size(0)
    if n_rows == 0:
        raise ValueError("get_feature_acts: num_batches selected no token rows")

    act_name = utils.get_act_name(point, layer)
    # Random-encoder baseline with the same shape as the real W_enc.
    random_W_enc = t.randn(dic["W_enc"].size()).to(device)

    feature_chunks, random_chunks = [], []
    # Cover every row, including a final partial minibatch, so feature_acts stays aligned
    # with the returned toks.
    for start in tqdm.tqdm(range(0, n_rows, minibatch_size)):
        toks_batch = toks[start : start + minibatch_size]
        # Run only up to the block containing the hook, caching just that activation.
        _, cache = model.run_with_cache(toks_batch, stop_at_layer=layer + 1, names_filter=act_name)
        act_batch = cache[act_name].to(device)
        del cache

        feature_act_batch = _sae_encode(act_batch, dic["W_enc"], dic["b_dec"], dic["b_enc"])

        # Random baseline on fewer rows to save memory, rescaled per token to the real
        # activation norm; the clamp avoids 0/0 = NaN when every random unit is inactive.
        random_act_batch = _sae_encode(act_batch[:n_random_per_batch], random_W_enc, dic["b_dec"], dic["b_enc"])
        random_norm = random_act_batch.norm(dim=-1, keepdim=True).clamp_min(1e-12)
        target_norm = feature_act_batch[:n_random_per_batch].norm(dim=-1, keepdim=True)
        random_act_batch = random_act_batch / random_norm * target_norm
        del act_batch

        feature_chunks.append(feature_act_batch)
        random_chunks.append(random_act_batch)

    # Concatenate once instead of re-copying a growing tensor on every iteration.
    feature_acts = t.cat(feature_chunks, dim=0)
    random_feature_acts = t.cat(random_chunks, dim=0)
    del feature_chunks, random_chunks

    # set BOS acts to zero (position 0 of every row)
    feature_acts[:, 0, :] = 0
    random_feature_acts[:, 0, :] = 0

    # flatten [batch n_seq] dimensions
    feature_acts = feature_acts.reshape(-1, feature_acts.size(2))
    random_feature_acts = random_feature_acts.reshape(-1, random_feature_acts.size(2))

    print("feature_acts has size:", feature_acts.size())

    return toks, feature_acts, random_feature_acts


toks, feature_acts, random_feature_acts = get_feature_acts(
    point="resid_pre",
    layer=9,
    dic=encoder_resid_pre_9.state_dict(),
    num_batches=1000,
    minibatch_size=50,
)

In [None]:
# L0 = number of SAE features active on a token, histogrammed over tokens.
# (The previous `(feature_acts > 0).float().mean(0) * 128` was a per-feature firing rate
# per 128-token row, not L0.) Rows of feature_acts are [row, position] flattened with
# 128 positions per row (max_length=128 when tokenizing); position 0 was zeroed as BOS in
# get_feature_acts, so it is dropped rather than counted as an L0 of 0.
SEQ_LEN = 128
l0_metric = (feature_acts > 0).float().sum(-1).reshape(-1, SEQ_LEN)[:, 1:].flatten()
px.histogram(
    l0_metric.cpu().numpy(),
    nbins=100,
    title=f"L0 per token (mean {l0_metric.mean().item():.1f})",
    labels={"value": "active features per token"},
)

In [None]:
# Fraction of tokens on which each feature fires (BOS positions included; they are all 0).
feature_density = (feature_acts > 0).float().mean(0).detach().cpu()
# log10 with a small epsilon so never-firing features land at -8 instead of -inf.
feature_density_log = np.log10(feature_density.numpy() + 1e-8)
px.histogram(
    feature_density_log,
    title="Feature Density",
    labels={"value": "log10(feature density)", "variable": "Feature"},
)

In [None]:
# cooccur_count[i, j] = number of tokens on which features i and j are both active,
# i.e. hidden_is_pos0.T @ hidden_is_pos0. It is computed in token chunks so only one float
# copy of a chunk is alive at a time. With 0/1 inputs and fewer than 2**24 tokens the
# float32 result is exact, matching the previous per-feature masked-sum loop.
hidden_is_pos0 = feature_acts > 0
d_enc = feature_acts.shape[-1]
cooccur_count = torch.zeros((d_enc, d_enc), device="cuda", dtype=torch.float32)
COOCCUR_CHUNK = 8192
for start in tqdm.trange(0, hidden_is_pos0.shape[0], COOCCUR_CHUNK):
    chunk = hidden_is_pos0[start : start + COOCCUR_CHUNK].float()
    cooccur_count += chunk.T @ chunk
del chunk

num_firings0 = hidden_is_pos0.sum(0)
# Normalise by the more frequent feature of each pair; 0/0 (neither fires) gives NaN.
cooccur_freq = cooccur_count / torch.maximum(num_firings0[:, None], num_firings0[None, :])


In [None]:
# 0/0 entries (pairs of never-firing features) -> 0.
cooccur_freq[cooccur_freq.isnan()] = 0.
# A feature trivially co-occurs with itself.
cooccur_freq.fill_diagonal_(0)
# Remove dense features (density > 1e-2) from BOTH sides of every pair; zeroing only rows
# let them through as `other_topk_ind`.
is_dense = torch.from_numpy(feature_density_log > -2).to(cooccur_freq.device)
cooccur_freq[is_dense, :] = 0
cooccur_freq[:, is_dense] = 0

# cooccur_freq is symmetric, so rank each unordered pair once via the strict upper triangle
# (otherwise every pair fills two of the 500 slots).
val, ind = cooccur_freq.triu(diagonal=1).flatten().topk(500)
start_topk_ind = ind // d_enc
other_topk_ind = ind % d_enc

start_topk_cpu, other_topk_cpu = start_topk_ind.cpu(), other_topk_ind.cpu()
start_density = feature_density[start_topk_cpu].numpy()
other_density = feature_density[other_topk_cpu].numpy()
df = pd.DataFrame(
    {
        "start_topk_ind": start_topk_cpu.numpy(),
        "other_topk_ind": other_topk_cpu.numpy(),
        "val": val.cpu().numpy(),
        "earlier_features": num_firings0[start_topk_ind].cpu().numpy(),
        "other_features": num_firings0[other_topk_ind].cpu().numpy(),
        "start_top_k_density": start_density,
        "other_top_k_density": other_density,
        "feature_density_ratio": start_density / other_density,
    }
)

df.sample(40)

In [None]:
# Distribution of the stronger co-occurrence frequencies (> 0.3), with log-scaled counts.
strong_cooccur = cooccur_freq[cooccur_freq > 0.3].detach().cpu().numpy()
fig = px.histogram(
    strong_cooccur,
    log_y=True,
    title="Feature-pair co-occurrence frequencies > 0.3",
    labels={"value": "co-occurrence frequency"},
)
fig.show()

In [None]:
feature_acts.size()

In [None]:
# Imported here as well so this cell works on a fresh kernel (the next cell also imports it).
import neel.utils as nutils

# Presumably one row per token of `toks`, in the same order as the rows of feature_acts;
# the following cells rely on that alignment when adding feature columns.
token_df = nutils.make_token_df(toks)

In [121]:
import neel.utils as nutils

# Store two features' per-token activations as columns, then show the 30 tokens on which
# each feature fires most strongly.
for column, feature_id in [("val1", 2271), ("val2", 2158)]:
    token_df[column] = feature_acts[:, feature_id].detach().cpu().numpy()
pd.set_option('display.max_rows', 50)
for column in ["val1", "val2"]:
    display(token_df.sort_values(column, ascending=False).head(30))

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,val,val1,val2
40595,This,This/19,"academic literature.\n\n|This|,",317,19,317/19,41.937229,41.937229,1.09469
33549,This,This/13,every citizen.\n\n|This| article,262,13,262/13,41.79977,41.79977,1.875313
53424,This,This/48,search query.\n\n|This| is,417,48,417/48,40.419174,40.419174,1.086643
28693,This,This/21,�ites.\n\n|This| is,224,21,224/21,40.028366,40.028366,2.84594
18174,This,This/126,"to work.""\n\n|This|",141,126,141/126,39.188549,39.188549,4.416582
88655,This,This/79,with 5 percent undecided.| This| is,692,79,692/79,38.682384,38.682384,0.011215
68906,This,This/42,��s arms.| This| is,538,42,538/42,38.680344,38.680344,1.351418
81508,This,This/100,"a mountain"".\n\n|This| evidently",636,100,636/100,38.573799,38.573799,0.353576
96352,This,This/96,ari said.\n\n|This| year,752,96,752/96,38.561771,38.561771,3.267462
113590,This,This/54,criminal code.\n\n|This| May,887,54,887/54,38.256516,38.256516,1.771352


Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,val,val1,val2
90132,this,this/20,subjected to an attack on| this| scale,704,20,704/20,0.0,0.0,25.313328
73814,this,this/86,we near a conclusion to| this| evaluation,576,86,576/86,2.233396,2.233396,24.943745
22199,this,this/55,rural and urban areas in| this| respect,173,55,173/55,0.654719,0.654719,23.604286
94766,this,this/46,the legislation to include in| this| database,740,46,740/46,0.0,0.0,23.482731
105736,this,this/8,publishers and exchanges centered on| this| last,826,8,826/8,2.496266,2.496266,23.41185
109204,this,this/20,on Twitter with comments on| this| nearly,853,20,853/20,4.13606,4.13606,23.355202
108427,this,this/11,are with those affected by| this| horrific,847,11,847/11,2.344148,2.344148,22.97123
121423,this,this/79,very little to do with| this| technology,948,79,948/79,0.0,0.0,22.868614
27384,this,this/120,a far bigger role in| this| election,213,120,213/120,0.0,0.0,22.755737
63950,this,this/78,The only other pages from| this| book,499,78,499/78,3.59036,3.59036,22.754097


In [119]:
# Activations of the two features against each other, over tokens where either one fires.
px.scatter(
    token_df.query("val1 > 0 or val2 > 0"),
    x="val1",
    y="val2",
    hover_data=["str_tokens", "context"],
    title="Feature 1 vs Feature 2",
)