In [1]:
import os
import transformer_lens
from transformer_lens import HookedTransformer, utils
import torch
import numpy as np
import pprint
import json
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from functools import partial
import tqdm.notebook as tqdm
import plotly.express as px
import pandas as pd

os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"
from neel.imports import *
from neel_plotly import *
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
torch.set_grad_enabled(False)

In IPython
In IPython
Set autoreload
Imported everything!


<torch.autograd.grad_mode.set_grad_enabled at 0x7fc6b1378850>

# Set Up 


In [2]:
# Model Loading

model = HookedTransformer.from_pretrained("gelu-2l",device="cuda")

n_layers = model.cfg.n_layers
d_model = model.cfg.d_model
n_heads = model.cfg.n_heads
d_head = model.cfg.d_head
d_mlp = model.cfg.d_mlp
d_vocab = model.cfg.d_vocab

Loaded pretrained model gelu-2l into HookedTransformer


In [3]:
DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")
class AutoEncoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])
    
    def forward(self, x):
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss
    
    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed
    
    def get_version(self):
        version_list = [int(file.name.split(".")[0]) for file in list(SAVE_DIR.iterdir()) if "pt" in str(file)]
        if len(version_list):
            return 1+max(version_list)
        else:
            return 0

    def save(self):
        version = self.get_version()
        torch.save(self.state_dict(), SAVE_DIR/(str(version)+".pt"))
        with open(SAVE_DIR/(str(version)+"_cfg.json"), "w") as f:
            json.dump(cfg, f)
        print("Saved as version", version)
    
    @classmethod
    def load(cls, version):
        cfg = (json.load(open(SAVE_DIR/(str(version)+"_cfg.json"), "r")))
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(torch.load(SAVE_DIR/(str(version)+".pt")))
        return self

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads the saved autoencoder from HuggingFace. 
        
        Version is expected to be an int, or "run1" or "run2"

        version 25 is the final checkpoint of the first autoencoder run,
        version 47 is the final checkpoint of the second autoencoder run.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47
        
        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self


In [4]:
encoder0 = AutoEncoder.load_from_hf("gelu-2l_L0_16384_mlp_out_51", "cuda")
encoder1 = AutoEncoder.load_from_hf("gelu-2l_L1_16384_mlp_out_50", "cuda")

{'act_name': 'blocks.0.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'device': 'cuda',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 0,
 'lr': 0.0001,
 'model_batch_size': 512,
 'model_name': 'gelu-2l',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 51,
 'seq_len': 128,
 'site': 'mlp_out'}
{'act_name': 'blocks.1.hook_mlp_out',
 'act_size': 512,
 'batch_size': 4096,
 'beta1': 0.9,
 'beta2': 0.99,
 'buffer_batches': 12288,
 'buffer_mult': 384,
 'buffer_size': 1572864,
 'device': 'cuda',
 'dict_mult': 32,
 'dict_size': 16384,
 'enc_dtype': 'fp32',
 'l1_coeff': 0.0003,
 'layer': 1,
 'lr': 0.0001,
 'model_batch_size': 512,
 'model_name': 'gelu-2l',
 'num_tokens': 2000000000,
 'remove_rare_dir': False,
 'seed': 50,
 'seq_len': 128,
 'site': 'mlp_out'}


In [5]:
data = load_dataset("NeelNanda/c4-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(data, model.tokenizer, max_length=128)
tokenized_data[0]

{'tokens': tensor([    1,  6957,  1045,   282,  1074,   561,   843,  2019,   353,   311,
           282, 26952,  3249, 15776, 28316,  9385,  1520,  5082,    16,   188,
          5270,  2019,   311,   353,   282, 26952,  3249, 15776, 28316,  9385,
          1520,  5082,    33,   188, 17073,   479,   668,  4149,   282, 26952,
          3249, 15776, 28316,  9385,  1520,  5082,    16,   310,   807,  3061,
           889,   248,  8388,   390,   248,  1922,  4741,   274, 26952,  3249,
            16,   188,  2433,   626,  9564,   282, 15813,   248,  1096,   276,
         26952,  3249,    33,   188,    43,   650,   669,   321,  5940,   276,
          4902,   282,  5838,   276, 26952,  3249, 15776, 28316,  9385,  1520,
          5082,    16,     0, 11410,   308,  1326,   311,   769, 38830,   666,
         27157,   900,   324,   254, 47665,  1914,    16,   826,  2766,   555,
          4311,  5164,    14,   286, 30585,   353,  1042,   417,   452, 39619,
            16,   731,  4288,  1638,   303

## Co-Occurence investigation

In [6]:

def get_feature_acts(tokenized_data, point, layer, sae, num_batches = 1000, minibatch_size = 50, device = "cuda"):
  try:
    del feature_acts
    del random_feature_acts
  except NameError:
    pass

  # get however many tokens we need
  toks = tokenized_data["tokens"][:num_batches]

  # get activations on test tokens at point of interest. Run model on batches of tokens with size [batch_size, 128]. Be careful with RAM.

  for i in tqdm.tqdm(range(toks.size(0)//minibatch_size)):
    # split toks into minibatch and run model with cache on minibatch
    toks_batch = toks[minibatch_size*i : minibatch_size*(i+1), :].to(device)
    logits, cache = model.run_with_cache(toks_batch, stop_at_layer=layer+1, names_filter=utils.get_act_name(point, layer))
    del logits

    act_batch = cache[point, layer]
    act_batch = act_batch.detach().to(device)
    del cache

    # get feature acts on this minibatch (fewer random ones to save RAM)
    feature_act_batch = torch.relu(einops.einsum(act_batch - sae.b_dec, sae.W_enc, "batch seq resid , resid mlp -> batch seq mlp")  + sae.b_enc)

    del act_batch

    # append minibatch feature acts to storage variable
    if i == 0:  # on first iteration, create feature_acts
      feature_acts = feature_act_batch
    else:  # then add to it
      feature_acts = torch.cat([feature_acts, feature_act_batch], dim=0)

    del feature_act_batch

  # set BOS acts to zero
  feature_acts[:, 0, :] = 0

  # flatten [batch n_seq] dimensions
  #feature_acts = feature_acts.reshape(-1, feature_acts.size(2))

  print("feature_acts has size:", feature_acts.size())

  return toks, feature_acts

def compute_cooccurrences(hidden_acts, device="cuda"):
    try:
        hidden_acts = hidden_acts[:, 1:, :]
        hidden_acts = einops.rearrange(hidden_acts, "batch pos d_enc -> (batch pos) d_enc")
    except Exception as e:
        print("FAILED:", e)
        return None

    hidden_is_pos = hidden_acts > 0
    d_enc = hidden_acts.shape[-1]

    cooccur_count = torch.zeros((d_enc, d_enc), device=device, dtype=torch.float32)
    for end_i in tqdm.trange(d_enc):
        cooccur_count[:, end_i] = hidden_is_pos[hidden_is_pos[:, end_i]].float().sum(0)

    num_firings = hidden_is_pos.sum(0)
    cooccur_freq = cooccur_count / torch.maximum(num_firings[:, None], num_firings[None, :])
    cooccur_freq[cooccur_freq.isnan()] = 0.

    cooccur_freq.fill_diagonal_(0.)

    return cooccur_freq


toks, acts, = get_feature_acts(tokenized_data, "mlp_out", 1, encoder1, num_batches=768, minibatch_size=64)
cooccur_table = compute_cooccurrences(acts)

  0%|          | 0/12 [00:00<?, ?it/s]

feature_acts has size: torch.Size([768, 128, 16384])


  0%|          | 0/16384 [00:00<?, ?it/s]

In [7]:
torch.cuda.empty_cache()

In [8]:
df = pd.DataFrame(cooccur_table.detach().cpu().numpy())
df = df.stack().reset_index()
mean_acts = (acts > 0).float().mean(dim=(0,1))
print(mean_acts.shape)
d_enc = mean_acts.shape[-1]
df["mean_act_feature1"] = mean_acts.detach().cpu().repeat(d_enc)
df["mean_act_feature2"] = mean_acts.detach().cpu().repeat(1, d_enc).flatten()
df.columns = ["feature1", "feature2", "cooccur_freq", "mean_act_feature1", "mean_act_feature2"]
df = df[df["cooccur_freq"] > 0.3]
print(df.shape)
df.head()

torch.Size([16384])
(4280, 5)


Unnamed: 0,feature1,feature2,cooccur_freq,mean_act_feature1,mean_act_feature2
16533,1,149,0.346626,0.003316,0.003316
18680,1,2296,0.325444,0.003438,0.003438
19885,1,3501,0.3769,0.003347,0.003347
21311,1,4927,0.302521,0.004842,0.004842
21606,1,5222,0.32853,0.00353,0.00353


In [9]:
df["max_feature_act"] = df[["mean_act_feature1", "mean_act_feature2"]].max(axis=1)
df["min_feature_act"] = df[["mean_act_feature1", "mean_act_feature2"]].min(axis=1)

In [28]:
px.scatter(df.sample(1000), x="min_feature_act", y="cooccur_freq", color="cooccur_freq", hover_data=["feature1", "feature2"],
           log_x=True, log_y=True, opacity=0.9, color_continuous_scale="Blues", template="plotly_dark").show(),


(None,)

In [49]:
df[df.feature1==2448].sort_values("cooccur_freq", ascending=False).head(10)

Unnamed: 0,feature1,feature2,cooccur_freq,mean_act_feature1,mean_act_feature2,max_feature_act,min_feature_act
40114498,2448,6466,0.37766,0.001801,0.001801,0.001801,0.001801
40122699,2448,14667,0.356383,0.001516,0.001516,0.001516,0.001516
40116991,2448,8959,0.303191,0.001119,0.001119,0.001119,0.001119


In [51]:
acts = acts.view(-1, acts.size(-1))

In [50]:

feature_id_1 = 2448
feature_id_2 = 14667
feature_id_3 = 6466
token_df = nutils.make_token_df(toks, len_suffix=5)
token_df["val1"] = to_numpy(acts[:, feature_id_1])
token_df["val2"] = to_numpy(acts[:, feature_id_2])
token_df["val3"] = to_numpy(acts[:, feature_id_3])
pd.set_option('display.max_rows', 40)
display(token_df.sort_values("val1", ascending=False).head(20))
display(token_df.sort_values("val2", ascending=False).head(20))

Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,val1,val2,val3
75861,when,when/85,"person, or looks different| when| placed next...",592,85,592/85,6.521091,0.0,1.77371
82066,when,when/18,sure Terry* was skeptical| when| we started t...,641,18,641/18,6.271078,3.29708,4.011731
23895,when,when/87,The history is additionally crucial| when| dis...,186,87,186/87,6.235761,1.697388,0.292801
31801,when,when/57,challenged my faith so much| when| I was in t...,248,57,248/57,5.007453,2.170938,2.137739
53979,when,when/91,<|EOS|>Students face unique challenges| when| ...,421,91,421/91,4.67743,3.286412,0.0
11240,when,when/104,than the sum of both| when| viewed apart.\nSo,87,104,87/104,4.302068,0.0,0.566836
9374,when,when/30,of language and its conventions| when| writin...,73,30,73/30,4.208656,1.207078,0.0
15824,when,when/80,which apply concentrated nutritional medicine...,123,80,123/80,4.009836,0.670157,0.0
40026,when,when/90,upholstery! Especially| when| it comes at a p...,312,90,312/90,3.882013,0.0,0.0
70984,when,when/72,the can get pretty nasty| when| they are corn...,554,72,554/72,3.818575,0.0,5.152675


Unnamed: 0,str_tokens,unique_token,context,batch,pos,label,val1,val2,val3
63828,when,when/84,arrested Tuesday in King County| when| he was...,498,84,498/84,0.0,10.715939,0.0
62075,when,when/123,"ously successful first season,| when| Netflix'...",484,123,484/123,0.84266,9.879725,0.0
36180,when,when/84,"couple’s history,| when| they are no longer s...",282,84,282/84,0.0,8.923348,1.571835
27352,when,when/88,since 2010| when| the Hay Family Racing team,213,88,213/88,1.395564,6.986862,0.706428
37570,when,when/66,cook make food for him| when| he came down wi...,293,66,293/66,0.373414,6.629521,4.489678
64544,when,when/32,young man who started bowling| when| he was 8...,504,32,504/32,0.566542,5.981678,3.093635
46234,when,when/26,"yan, 26 -| when| they passed a table where",361,26,361/26,0.0,5.921212,4.416834
30372,When,When/36,could trust the Lord.| When| the pagans saw t...,237,36,237/36,0.661695,5.255411,0.836557
5064,when,when/72,.\nIt’s| when| people start obsessing over,39,72,39/72,1.751745,5.200799,0.394058
75908,when,when/4,<|BOS|>\nThis is| when| a design truly stands out,593,4,593/4,0.941269,5.030762,0.413388


In [52]:
# px.scatter(token_df, x="val1", y="val2", hover_data=["str_tokens", "context"], title="Feature 1 vs Feature 2")

# filter for val1 positive
tmp = token_df[(token_df.val1 > 0 ) | (token_df.val2 > 0) | (token_df.val3 > 0)]
fig = px.scatter_3d(
    tmp, x="val1", y="val2", z="val3", 
    opacity = 0.4,
    hover_data=["str_tokens", "context"], title="Feature 1 vs Feature 2 vs Feature 3")
fig.update_traces(marker=dict(size=5))
# make it much larger
fig.update_layout(width=800, height=800)
fig.show()
fig.write_html("bos_tokens.html")

### Looking at Co-Occurence graphically



In [None]:
feature_id_1 = 12989
feature_id_2 = 14809
feature_id_3 = 15690
token_df = nutils.make_token_df(example_tokens).query("pos>=1")
token_df["0_val1"] = to_numpy(hidden_acts0[:, feature_id_1])
token_df["0_val2"] = to_numpy(hidden_acts0[:, feature_id_2])
token_df["0_val3"] = to_numpy(hidden_acts0[:, feature_id_3])
feature_id_1 = 1245
feature_id_2 = 1353
feature_id_3 = 9604
token_df["1_val1"] = to_numpy(hidden_acts1[:, feature_id_1])
token_df["1_val2"] = to_numpy(hidden_acts1[:, feature_id_2])
token_df["1_val3"] = to_numpy(hidden_acts1[:, feature_id_3])

fig = px.scatter_matrix(
    token_df,
    dimensions=["0_val1", "0_val2", "0_val3", "1_val1", "1_val2", "1_val3"],
    # color="species"
    hover_data=["str_tokens", "context"],
)
fig.update_traces(marker=dict(size=5))
# make it much larger
fig.update_layout(width=1200, height=1200)
fig.show()

# fig = px.scatter_3d(
#     token_df, x="0_val1", y="0_val2", z="0_val3", 
#     color="1_val3",
#     hover_data=["str_tokens", "context"], title="Feature 1 vs Feature 2 vs Feature 3")
# fig.update_traces(marker=dict(size=5))
# # make it much larger
# fig.update_layout(width=800, height=800)
# fig.show()

### Co-Occurence between SAE's.

In [None]:
# Co-Occur between SAEs

d_enc = hidden_acts0.shape[-1]
cooccur_count = torch.zeros((d_enc, d_enc), device="cuda", dtype=torch.float32)
for end_i in tqdm.trange(d_enc):
    cooccur_count[:, end_i] = hidden_is_pos0[hidden_is_pos1[:, end_i]].float().sum(0)
# %%
num_firings0 = hidden_is_pos0.sum(0)
num_firings1 = hidden_is_pos1.sum(0)
cooccur_freq = cooccur_count / torch.maximum(num_firings0[:, None], num_firings1[None, :])
# %%
# cooccur_count = cooccur_count.float() / hidden_acts0.shape[0]
# %%
histogram(cooccur_freq[cooccur_freq>0.1], log_y=True)


In [None]:

# %%
cooccur_freq[cooccur_freq.isnan()] = 0.
val, ind = cooccur_freq.flatten().topk(10)

start_topk_ind = (ind // d_enc)
end_topk_ind = (ind % d_enc)
print(val)
print(start_topk_ind)
print(end_topk_ind)

print("earlier_features", num_firings0[start_topk_ind])
print("later features", num_firings1[end_topk_ind])


###  Dig Into a specific Feature

In [None]:
# %%
feature_id = 2593
layer = 0
token_df = nutils.make_token_df(example_tokens).query("pos>=1")
token_df["val"] = to_numpy(hidden_acts1[:, feature_id])
pd.set_option('display.max_rows', 50)
display(token_df.sort_values("val", ascending=False).head(50))


In [None]:
logit_weights = encoder1.W_dec[feature_id, :] @ model.W_U
vocab_df = nutils.create_vocab_df(logit_weights)
vocab_df["has_space"] = vocab_df["token"].apply(lambda x: nutils.SPACE in x)
px.histogram(vocab_df, x="logit", color="has_space", barmode="overlay", marginal="box")
display(vocab_df.head(20))

In [None]:
is_an = example_tokens[:, 1:].flatten() == model.to_single_token(" an")
line(hidden_is_pos0[is_an].float().mean(0))
line(hidden_is_pos1[is_an].float().mean(0))

In [None]:
scatter(x=virtual_weights[5740], y=hidden_is_pos1[is_an].float().mean(0), hover=np.arange(d_enc))

In [None]:
l0_feature_an = 5740
line(hidden_acts0[5740])
replacement_for_feature = torch.zeros_like(example_tokens).cuda()
replacement_for_feature[:, 1:] = hidden_acts0[:, 5740].reshape(600, 127)
mlp_out0_diff = replacement_for_feature[:, :, None] * encoder0.W_dec[l0_feature_an]
new_end_topk = end_topk_ind[start_topk_ind==5740]

### Causal Intervention

In [None]:

def remove_an_feature(mlp_out, hook):
    mlp_out[:, :] -= mlp_out0_diff
    return mlp_out
    
model.reset_hooks()
model.blocks[0].hook_mlp_out.add_hook(remove_an_feature)
_, new_cache = model.run_with_cache(example_tokens, stop_at_layer=2, names_filter=lambda x: "mlp_out" in x)
with torch.no_grad():
    loss, x_reconstruct, hidden_acts_new, l2_loss, l1_loss = encoder1(new_cache["mlp_out", 1])
model.reset_hooks()
# %%
new_hidden_acts_on_an = hidden_acts_new[:, :, new_end_topk].reshape(-1, 5)[example_tokens.flatten()==model.to_single_token(" an"), :]
old_hidden_acts_on_an = hidden_acts1[:, new_end_topk][example_tokens[:, 1:].flatten()==model.to_single_token(" an"), :]
for i in range(5):
    scatter(x=old_hidden_acts_on_an[:, i], y=new_hidden_acts_on_an[:, i], hover=np.arange(180), title=new_end_topk[i].item(), include_diag=True, yaxis="POst Ablation", xaxis="Pre Ablation")
# %%
