In [None]:
import os

# Hugging Face libraries resolve their cache directory when they are first
# imported, so this override must be in place *before* transformer_lens /
# datasets / huggingface_hub are imported -- setting it afterwards is ignored.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME / HF_HUB_CACHE -- confirm against the installed version.
os.environ["TRANSFORMERS_CACHE"] = "/workspace/cache/"

import gc
import json
import pprint
import random
from functools import partial
from pathlib import Path

import numpy as np
import pandas as pd
import plotly.express as px
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm.notebook as tqdm
import transformer_lens
from datasets import load_dataset
from huggingface_hub import HfApi
from IPython.display import HTML
from transformer_lens import HookedTransformer, utils

# Project star-imports. Later cells use names they never import explicitly
# (e.g. `histogram`, `line`, `scatter`, `nutils`, `to_numpy`) -- presumably
# provided here; TODO confirm. Kept last, as before, so their exports win.
from neel.imports import *
from neel_plotly import *

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
# Inference-only notebook: disable autograd globally.
torch.set_grad_enabled(False)

# Set Up 


In [None]:
# Model Loading
# The TransformerLens "gelu-2l" model (later cells index its layers 0 and 1).
model = HookedTransformer.from_pretrained("gelu-2l")

# Module-level shortcuts for the architecture hyper-parameters.
(n_layers, d_model, n_heads, d_head, d_mlp, d_vocab) = (
    model.cfg.n_layers,
    model.cfg.d_model,
    model.cfg.n_heads,
    model.cfg.d_head,
    model.cfg.d_mlp,
    model.cfg.d_vocab,
)
d_vocab = model.cfg.d_vocab

In [None]:
from pathlib import Path

DTYPES = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}
SAVE_DIR = Path("/workspace/1L-Sparse-Autoencoder/checkpoints")


class AutoEncoder(nn.Module):
    """Sparse autoencoder (SAE) over activation vectors.

    ``cfg`` keys read: ``dict_size`` (number of hidden features), ``l1_coeff``,
    ``enc_dtype`` (a key of ``DTYPES``), ``seed``, ``act_size`` (input width)
    and ``device``.
    """

    def __init__(self, cfg):
        super().__init__()
        # Keep a copy of the config so `save` can write it next to the weights.
        self.cfg = dict(cfg)
        d_hidden = cfg["dict_size"]
        l1_coeff = cfg["l1_coeff"]
        dtype = DTYPES[cfg["enc_dtype"]]
        # NOTE: reseeds the *global* torch RNG so initialisation is reproducible.
        torch.manual_seed(cfg["seed"])
        self.W_enc = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(cfg["act_size"], d_hidden, dtype=dtype)))
        self.W_dec = nn.Parameter(torch.nn.init.kaiming_uniform_(torch.empty(d_hidden, cfg["act_size"], dtype=dtype)))
        self.b_enc = nn.Parameter(torch.zeros(d_hidden, dtype=dtype))
        self.b_dec = nn.Parameter(torch.zeros(cfg["act_size"], dtype=dtype))

        # Start every decoder row (one feature direction) at unit L2 norm.
        self.W_dec.data[:] = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)

        self.d_hidden = d_hidden
        self.l1_coeff = l1_coeff

        self.to(cfg["device"])

    def forward(self, x):
        """Encode ``x`` into non-negative feature activations and decode it back.

        Returns ``(loss, x_reconstruct, acts, l2_loss, l1_loss)``:
        ``l2_loss`` is the squared error summed over the last dim and averaged
        over dim 0 only (so inputs with extra leading dims give a non-scalar
        loss); ``l1_loss`` is ``l1_coeff`` times the L1 norm of *all*
        activations (summed, not averaged, over the batch).
        """
        x_cent = x - self.b_dec
        acts = F.relu(x_cent @ self.W_enc + self.b_enc)
        x_reconstruct = acts @ self.W_dec + self.b_dec
        l2_loss = (x_reconstruct.float() - x.float()).pow(2).sum(-1).mean(0)
        l1_loss = self.l1_coeff * (acts.float().abs().sum())
        loss = l2_loss + l1_loss
        return loss, x_reconstruct, acts, l2_loss, l1_loss

    @torch.no_grad()
    def make_decoder_weights_and_grad_unit_norm(self):
        """Remove the gradient component parallel to each decoder row, then
        renormalise the rows to unit norm. Requires ``W_dec.grad`` to be set."""
        W_dec_normed = self.W_dec / self.W_dec.norm(dim=-1, keepdim=True)
        W_dec_grad_proj = (self.W_dec.grad * W_dec_normed).sum(-1, keepdim=True) * W_dec_normed
        self.W_dec.grad -= W_dec_grad_proj
        # Bugfix(?) for ensuring W_dec retains unit norm, this was not there when I trained my original autoencoders.
        self.W_dec.data = W_dec_normed

    def get_version(self, save_dir=None):
        """Return the next free checkpoint version in ``save_dir`` (default
        ``SAVE_DIR``): one more than the largest ``<int>.pt`` file, else 0."""
        save_dir = Path(SAVE_DIR if save_dir is None else save_dir)
        # Only files literally named "<int>.pt" count; other files (configs,
        # non-numeric checkpoints) are ignored instead of crashing int().
        version_list = [
            int(file.stem)
            for file in save_dir.iterdir()
            if file.suffix == ".pt" and file.stem.isdigit()
        ]
        return 1 + max(version_list) if version_list else 0

    def save(self, save_dir=None):
        """Save the state dict as ``<version>.pt`` and this instance's config as
        ``<version>_cfg.json`` in ``save_dir`` (default ``SAVE_DIR``)."""
        save_dir = Path(SAVE_DIR if save_dir is None else save_dir)
        version = self.get_version(save_dir)
        torch.save(self.state_dict(), save_dir / f"{version}.pt")
        with open(save_dir / f"{version}_cfg.json", "w") as f:
            json.dump(self.cfg, f)
        print("Saved as version", version)

    @classmethod
    def load(cls, version, save_dir=None):
        """Rebuild an autoencoder from ``<version>_cfg.json`` / ``<version>.pt``
        in ``save_dir`` (default ``SAVE_DIR``)."""
        save_dir = Path(SAVE_DIR if save_dir is None else save_dir)
        with open(save_dir / f"{version}_cfg.json", "r") as f:
            cfg = json.load(f)
        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        # map_location lets checkpoints saved on another device load here.
        self.load_state_dict(torch.load(save_dir / f"{version}.pt", map_location=cfg["device"]))
        return self

    @classmethod
    def load_from_hf(cls, version, device_override=None):
        """
        Loads a saved autoencoder from the HuggingFace repo "NeelNanda/sparse_autoencoder".

        ``version`` is the checkpoint file stem: an int, the alias "run1"
        (-> 25, final checkpoint of the first autoencoder run), the alias
        "run2" (-> 47, final checkpoint of the second run), or any other name
        such as "gelu-2l_L0_16384_mlp_out_51". The repo must contain
        ``{version}_cfg.json`` and ``{version}.pt``.

        ``device_override``, if given, replaces ``cfg["device"]`` before the
        module is built.
        """
        if version=="run1":
            version = 25
        elif version=="run2":
            version = 47
        
        cfg = utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}_cfg.json")
        if device_override is not None:
            cfg["device"] = device_override

        pprint.pprint(cfg)
        self = cls(cfg=cfg)
        self.load_state_dict(utils.download_file_from_hf("NeelNanda/sparse_autoencoder", f"{version}.pt", force_is_torch=True))
        return self


In [None]:
# Pretrained SAEs on the layer-0 and layer-1 MLP outputs, placed on the GPU
# (the checkpoint names suggest 16384 features each).
encoder0 = AutoEncoder.load_from_hf("gelu-2l_L0_16384_mlp_out_51", device_override="cuda")
encoder1 = AutoEncoder.load_from_hf("gelu-2l_L1_16384_mlp_out_50", device_override="cuda")

In [None]:

# Tokenise the C4 sample and pack it into 128-token rows.
data = load_dataset("NeelNanda/c4-10k", split="train")
tokenized_data = utils.tokenize_and_concatenate(
    data,
    model.tokenizer,
    max_length=128,
)
# Peek at the first packed row.
tokenized_data[0]

In [None]:
# Render Plotly figures with the VS Code notebook renderer.
from plotly import io as pio

pio.renderers.default = "vscode"

In [None]:
# with torch.no_grad():
#     example_tokens = tokenized_data[:200]["tokens"]
#     logits, cache = model.run_with_cache(example_tokens)
#     per_token_loss = model.loss_fn(logits, example_tokens, True)
#     imshow(per_token_loss)

## Out-take: reconstruction-loss checks (kept commented out)

In [None]:

# original_mlp_out = cache["mlp_out", 1]
# loss, reconstr_mlp_out, hidden_acts, l2_loss, l1_loss = encoder1(original_mlp_out)
# def reconstr_hook(mlp_out, hook, new_mlp_out):
#     return new_mlp_out
# def zero_abl_hook(mlp_out, hook):
#     return torch.zeros_like(mlp_out)
# print("reconstr", model.run_with_hooks(example_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", 1), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out))], return_type="loss"))
# print("Orig", model(example_tokens, return_type="loss"))
# print("Zero", model.run_with_hooks(example_tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("mlp_out", 1), zero_abl_hook)]))
# # %%

# original_mlp_out = cache["mlp_out", 0]
# loss, reconstr_mlp_out, hidden_acts, l2_loss, l1_loss = encoder0(original_mlp_out)
# def reconstr_hook(mlp_out, hook, new_mlp_out):
#     return new_mlp_out
# def zero_abl_hook(mlp_out, hook):
#     return torch.zeros_like(mlp_out)
# print("reconstr", model.run_with_hooks(example_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", 0), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out))], return_type="loss"))
# print("Orig", model(example_tokens, return_type="loss"))
# print("Zero", model.run_with_hooks(example_tokens, return_type="loss", fwd_hooks=[(utils.get_act_name("mlp_out", 0), zero_abl_hook)]))
# # %%
# orig_logits = model(example_tokens)
# orig_ptl = model.loss_fn(orig_logits, example_tokens, True)

# zero_logits = model.run_with_hooks(example_tokens, return_type="logits", fwd_hooks=[(utils.get_act_name("mlp_out", 0), zero_abl_hook)])
# zero_ptl = model.loss_fn(zero_logits, example_tokens, True)

# recons_logits = model.run_with_hooks(example_tokens, fwd_hooks=[(utils.get_act_name("mlp_out", 0), partial(reconstr_hook, new_mlp_out=reconstr_mlp_out))], return_type="logits")
# recons_ptl = model.loss_fn(recons_logits, example_tokens, True)
# # %%
# histogram(recons_ptl.flatten())
# # %%
# scatter(x=(recons_ptl-orig_ptl).flatten(), y=(zero_ptl-orig_ptl).flatten())
# delta_ptl = recons_ptl - orig_ptl
# histogram(delta_ptl.flatten(), marginal="box")
# # %%
# scipy.stats.kurtosis(to_numpy(delta_ptl).flatten())
# # %%
# token_df = nutils.make_token_df(example_tokens).query("pos>=1")
# token_df["delta_ptl"] = to_numpy(delta_ptl.flatten())

# # %%
# display(token_df.sort_values("delta_ptl", ascending=False).head(20))
# display(token_df.sort_values("delta_ptl", ascending=True).head(20))
# # %%


In [None]:

# "Virtual weights" from every layer-0 SAE feature to every layer-1 SAE feature:
# SAE0 decoder -> MLP1 W_in -> MLP1 W_out -> SAE1 encoder.
# NOTE: purely linear path -- MLP1's activation function (and any normalisation
# before it) is skipped, so this is only a first-order proxy.
with torch.no_grad():
    virtual_weights = encoder0.W_dec @ model.W_in[1] @ model.W_out[1] @ encoder1.W_enc

# Top-level so Jupyter displays it (an expression inside a `with` body is not shown).
virtual_weights.shape


In [None]:
# Residual-stream width of the loaded model.
getattr(model.cfg, "d_model")

In [None]:
# Shape of the layer-1 SAE encoder matrix.
encoder1.W_enc.size()

In [None]:

# %%
# Distribution of a strided sample (every 1001st entry) of the virtual weights.
histogram(virtual_weights.reshape(-1)[::1001])
# Direct neuron-to-neuron weights: layer-0 MLP output -> layer-1 MLP input.
with torch.no_grad():
    neuron2neuron = torch.matmul(model.W_out[0], model.W_in[1])
# histogram(neuron2neuron.flatten()[::101])
# # %%
# histogram(virtual_weights.mean(0), title="Ave by end feature")
# histogram(virtual_weights.mean(1), title="Ave by start feature")
# histogram(virtual_weights.median(0).values, title="Median by end feature")
# histogram(virtual_weights.median(1).values, title="Median by start feature")
# %%


## Co-Occurrence Investigation

In [None]:
import gc

# Collect unreachable Python objects first so the CUDA tensors they own go back
# to PyTorch's caching allocator, *then* release cached blocks to the driver.
# In the opposite order, memory freed by gc.collect() stays in PyTorch's cache.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# START HERE!
# Run the model on a fixed slice of the data and encode both MLP outputs with
# their SAEs.
example_tokens = tokenized_data[:800]["tokens"]
with torch.no_grad():
    # Only mlp_out of layers 0 and 1 is needed, so stop after block 1.
    _, cache = model.run_with_cache(example_tokens, stop_at_layer=2, names_filter=lambda x: "mlp_out" in x)
    loss, recons_mlp_out0, hidden_acts0, l2_loss, l1_loss = encoder0(cache["mlp_out", 0])
    loss, recons_mlp_out1, hidden_acts1, l2_loss, l1_loss = encoder1(cache["mlp_out", 1])

# Drop position 0 of every sequence (presumably BOS) and flatten to one row per
# token: (batch, pos, d_enc) -> (batch * (pos - 1), d_enc), batch-major.
# No try/except: both tensors were just recomputed as 3-D above, so a failure
# here is a real error. Swallowing it would leave 3-D tensors behind and
# silently break every downstream cell.
hidden_acts0 = hidden_acts0[:, 1:, :].flatten(0, 1)
hidden_acts1 = hidden_acts1[:, 1:, :].flatten(0, 1)

# Boolean "feature fired on this token" masks.
hidden_is_pos0 = hidden_acts0 > 0
hidden_is_pos1 = hidden_acts1 > 0

In [None]:
# Co-occurrence of layer-0 SAE features on the same token.
d_enc = hidden_acts0.shape[-1]
# cooccur_count[j, i] = number of tokens on which features j and i both fire.
# Symmetric and exact: the counts stay far below 2**24, where float32 is exact.
cooccur_count = torch.zeros((d_enc, d_enc), device=hidden_is_pos0.device, dtype=torch.float32)
for other_i in tqdm.trange(d_enc):
    cooccur_count[:, other_i] = hidden_is_pos0[hidden_is_pos0[:, other_i]].float().sum(0)

num_firings0 = hidden_is_pos0.sum(0)
# Shared firings as a fraction of the more frequent feature's firings. This is
# 1.0 only when both fire on exactly the same tokens; 0/0 (neither fires) -> NaN.
cooccur_freq = cooccur_count / torch.maximum(num_firings0[:, None], num_firings0[None, :])
histogram(cooccur_freq[cooccur_freq > 0.1], log_y=True)

# Treat NaN (neither feature ever fires) as no co-occurrence.
cooccur_freq[cooccur_freq.isnan()] = 0.
# Ignore each feature's trivial co-occurrence with itself.
cooccur_freq.fill_diagonal_(0)
# Ignore features firing on fewer than 10 tokens: their ratios are too noisy.
cooccur_freq[num_firings0 < 10, :] = 0
cooccur_freq[:, num_firings0 < 10] = 0

# cooccur_freq is symmetric, so a top-k over the whole matrix returns every pair
# twice, as (i, j) and (j, i). Search only the strict upper triangle instead.
val, ind = torch.triu(cooccur_freq, diagonal=1).flatten().topk(100)
start_topk_ind = ind // d_enc  # row index (< other_topk_ind for every non-zero val)
other_topk_ind = ind % d_enc   # column index

# "earlier_features" / "other_features" hold each feature's firing count.
df = pd.DataFrame(
    {
        "start_topk_ind": start_topk_ind.cpu().numpy(),
        "other_topk_ind": other_topk_ind.cpu().numpy(),
        "val": val.cpu().numpy(),
        "earlier_features": num_firings0[start_topk_ind].cpu().numpy(),
        "other_features": num_firings0[other_topk_ind].cpu().numpy(),
    }
)

df.head(30)

In [None]:
# Co-occurrence of layer-1 SAE features on the same token.
d_enc = hidden_acts1.shape[-1]
# cooccur_count[j, i] = number of tokens on which features j and i both fire.
# Symmetric and exact: the counts stay far below 2**24, where float32 is exact.
cooccur_count = torch.zeros((d_enc, d_enc), device=hidden_is_pos1.device, dtype=torch.float32)
for other_i in tqdm.trange(d_enc):
    cooccur_count[:, other_i] = hidden_is_pos1[hidden_is_pos1[:, other_i]].float().sum(0)

num_firings1 = hidden_is_pos1.sum(0)
# Shared firings as a fraction of the more frequent feature's firings. This is
# 1.0 only when both fire on exactly the same tokens; 0/0 (neither fires) -> NaN.
cooccur_freq = cooccur_count / torch.maximum(num_firings1[:, None], num_firings1[None, :])
histogram(cooccur_freq[cooccur_freq > 0.1], log_y=True)

# Treat NaN (neither feature ever fires) as no co-occurrence.
cooccur_freq[cooccur_freq.isnan()] = 0.
# Ignore each feature's trivial co-occurrence with itself.
cooccur_freq.fill_diagonal_(0)
# Ignore features firing on fewer than 10 tokens: their ratios are too noisy.
cooccur_freq[num_firings1 < 10, :] = 0
cooccur_freq[:, num_firings1 < 10] = 0

# cooccur_freq is symmetric, so a top-k over the whole matrix returns every pair
# twice, as (i, j) and (j, i). Search only the strict upper triangle instead.
val, ind = torch.triu(cooccur_freq, diagonal=1).flatten().topk(100)
start_topk_ind = ind // d_enc  # row index (< other_topk_ind for every non-zero val)
other_topk_ind = ind % d_enc   # column index

# "earlier_features" / "other_features" hold each feature's firing count.
df = pd.DataFrame(
    {
        "start_topk_ind": start_topk_ind.cpu().numpy(),
        "other_topk_ind": other_topk_ind.cpu().numpy(),
        "val": val.cpu().numpy(),
        "earlier_features": num_firings1[start_topk_ind].cpu().numpy(),
        "other_features": num_firings1[other_topk_ind].cpu().numpy(),
    }
)
# Number of top pairs each feature takes part in. Each pair appears once, so
# count both columns to find "hub" features.
print(pd.concat([df["start_topk_ind"], df["other_topk_ind"]]).value_counts())
df.head(30)

In [None]:
feature = 6269
# Jupyter auto-displays only a cell's last expression, so both tables are shown
# explicitly; previously the first query's result was silently discarded.
# Pairs where `feature` is the first (row) index:
display(df.query("start_topk_ind==@feature").sort_values("val", ascending=False).head(20))
# Pairs where `feature` is the second (column) index:
display(df.query("other_topk_ind==@feature").sort_values("val", ascending=False).head(20))


In [None]:
# is_token = example_tokens[:, 1:].flatten() == model.to_single_token(" big")
# line(hidden_is_pos0[is_token].float().mean(0))
# line(hidden_is_pos1[is_token].float().mean(0))

In [None]:
# feature_ids = [3698

In [None]:
# Top-activating tokens for three layer-0 SAE features.
feature_id_1 = 12989
feature_id_2 = 14809
feature_id_3 = 15690
token_df = nutils.make_token_df(example_tokens).query("pos>=1")
for col_name, feat in (("val1", feature_id_1), ("val2", feature_id_2), ("val3", feature_id_3)):
    token_df[col_name] = to_numpy(hidden_acts0[:, feat])
pd.set_option('display.max_rows', 40)
for col_name in ("val1", "val2", "val3"):
    display(token_df.sort_values(col_name, ascending=False).head(20))

In [None]:
# 3-D view of the three layer-0 feature activations, one point per token,
# enlarged to 800x800 with small markers.
fig = (
    px.scatter_3d(
        token_df,
        x="val1",
        y="val2",
        z="val3",
        opacity=0.4,
        hover_data=["str_tokens", "context"],
        title="Feature 1 vs Feature 2 vs Feature 3",
    )
    .update_traces(marker={"size": 5})
    .update_layout(width=800, height=800)
)
fig.show()

### Looking at Co-Occurrence Graphically



In [None]:
# Pairwise scatter plots: three layer-0 features vs three layer-1 features.
feature_id_1, feature_id_2, feature_id_3 = 12989, 14809, 15690
token_df = nutils.make_token_df(example_tokens).query("pos>=1")
for col_name, feat in (("0_val1", feature_id_1), ("0_val2", feature_id_2), ("0_val3", feature_id_3)):
    token_df[col_name] = to_numpy(hidden_acts0[:, feat])

# Re-point the feature ids at the layer-1 features of interest.
feature_id_1, feature_id_2, feature_id_3 = 1245, 1353, 9604
for col_name, feat in (("1_val1", feature_id_1), ("1_val2", feature_id_2), ("1_val3", feature_id_3)):
    token_df[col_name] = to_numpy(hidden_acts1[:, feat])

fig = (
    px.scatter_matrix(
        token_df,
        dimensions=["0_val1", "0_val2", "0_val3", "1_val1", "1_val2", "1_val3"],
        hover_data=["str_tokens", "context"],
    )
    .update_traces(marker={"size": 5})
    .update_layout(width=1200, height=1200)
)
fig.show()

# fig = px.scatter_3d(
#     token_df, x="0_val1", y="0_val2", z="0_val3", 
#     color="1_val3",
#     hover_data=["str_tokens", "context"], title="Feature 1 vs Feature 2 vs Feature 3")
# fig.update_traces(marker=dict(size=5))
# # make it much larger
# fig.update_layout(width=800, height=800)
# fig.show()

### Co-Occurrence Between SAEs

In [None]:
# Co-occurrence between layer-0 and layer-1 SAE features on the same token.
d_enc = hidden_acts0.shape[-1]   # layer-0 SAE width (name is reused by later cells)
d_enc1 = hidden_acts1.shape[-1]  # layer-1 SAE width
# cooccur_count[i, j] = number of tokens where layer-0 feature i and layer-1
# feature j both fire. Sized by each SAE's own width, so differing SAE sizes work.
cooccur_count = torch.zeros((d_enc, d_enc1), device=hidden_is_pos0.device, dtype=torch.float32)
for end_i in tqdm.trange(d_enc1):
    cooccur_count[:, end_i] = hidden_is_pos0[hidden_is_pos1[:, end_i]].float().sum(0)

num_firings0 = hidden_is_pos0.sum(0)
num_firings1 = hidden_is_pos1.sum(0)
# Shared firings as a fraction of the more frequent feature's firings (0/0 -> NaN).
cooccur_freq = cooccur_count / torch.maximum(num_firings0[:, None], num_firings1[None, :])
histogram(cooccur_freq[cooccur_freq > 0.1], log_y=True)


In [None]:

# %%
# Top cross-layer pairs. NaN (0/0: neither feature fires) counts as no co-occurrence.
# NOTE(review): unlike the within-SAE cells, rarely firing features are not
# filtered out here, so pairs of very rare features can reach 1.0.
cooccur_freq[cooccur_freq.isnan()] = 0.
val, ind = cooccur_freq.flatten().topk(10)

# Un-flatten with the matrix's own column count (the layer-1 SAE width), so this
# stays correct even if the two SAEs differ in size.
n_cols = cooccur_freq.shape[1]
start_topk_ind = ind // n_cols  # layer-0 feature
end_topk_ind = ind % n_cols     # layer-1 feature
print(val)
print(start_topk_ind)
print(end_topk_ind)

print("earlier_features", num_firings0[start_topk_ind])
print("later features", num_firings1[end_topk_ind])


### Dig Into a Specific Feature

In [None]:
# %%
feature_id = 2593
# Which SAE's activations to read: 0 -> hidden_acts0, 1 -> hidden_acts1.
# This was previously `layer = 0` while the code read hidden_acts1; `layer` now
# drives the choice. The following logit-lens cell uses encoder1, i.e. layer 1.
layer = 1
token_df = nutils.make_token_df(example_tokens).query("pos>=1")
token_df["val"] = to_numpy((hidden_acts0, hidden_acts1)[layer][:, feature_id])
pd.set_option('display.max_rows', 50)
display(token_df.sort_values("val", ascending=False).head(50))


In [None]:
# Direct effect of the layer-1 feature's decoder direction on the logits.
# Skips any final normalisation before the unembedding, so magnitudes are only
# indicative.
logit_weights = encoder1.W_dec[feature_id, :] @ model.W_U
vocab_df = nutils.create_vocab_df(logit_weights)
vocab_df["has_space"] = vocab_df["token"].apply(lambda x: nutils.SPACE in x)
# Only a cell's last expression is auto-displayed, so show the figure explicitly
# (previously it was built but never rendered).
px.histogram(vocab_df, x="logit", color="has_space", barmode="overlay", marginal="box").show()
display(vocab_df.head(20))

In [None]:
# Mask over positions 1: of example_tokens, flattened so it lines up with the
# rows of hidden_is_pos0 / hidden_is_pos1.
is_an = example_tokens[:, 1:].reshape(-1) == model.to_single_token(" an")
# Fraction of " an" tokens on which each SAE feature fires, per layer.
line(hidden_is_pos0[is_an].float().mean(dim=0))
line(hidden_is_pos1[is_an].float().mean(dim=0))

In [None]:
# Virtual weight from layer-0 feature 5740 into each layer-1 feature, against
# how often that layer-1 feature fires on " an" tokens.
scatter(
    x=virtual_weights[5740, :],
    y=hidden_is_pos1[is_an].float().mean(dim=0),
    hover=np.arange(d_enc),
)

In [None]:
# Layer-0 SAE feature under study (associated with " an").
l0_feature_an = 5740
# This feature's activation on every flattened token. Note that
# `hidden_acts0[i]` would select token row i, not feature i.
line(hidden_acts0[:, l0_feature_an])

n_batch, n_pos = example_tokens.shape
# Lay the activation back out as (batch, pos); position 0 was dropped from
# hidden_acts0, so it stays 0. Must be floating point: zeros_like(example_tokens)
# alone would be integer-typed and truncate the activations.
replacement_for_feature = torch.zeros(
    (n_batch, n_pos), dtype=hidden_acts0.dtype, device=hidden_acts0.device
)
replacement_for_feature[:, 1:] = hidden_acts0[:, l0_feature_an].reshape(n_batch, n_pos - 1)
# The feature's reconstructed contribution to layer-0 mlp_out:
# activation * decoder direction, shape (batch, pos, d_model).
mlp_out0_diff = replacement_for_feature[:, :, None] * encoder0.W_dec[l0_feature_an]
# Layer-1 features among the top cross-layer pairs whose layer-0 partner is this feature.
new_end_topk = end_topk_ind[start_topk_ind == l0_feature_an]

### Causal Intervention

In [None]:

def remove_an_feature(mlp_out, hook):
    """Forward hook: subtract the " an" feature's reconstructed contribution
    (the global `mlp_out0_diff`) from layer-0 mlp_out, in place."""
    mlp_out[:, :] -= mlp_out0_diff
    return mlp_out


model.reset_hooks()
model.blocks[0].hook_mlp_out.add_hook(remove_an_feature)
try:
    with torch.no_grad():
        _, new_cache = model.run_with_cache(example_tokens, stop_at_layer=2, names_filter=lambda x: "mlp_out" in x)
finally:
    # Always detach the hook, even if the forward pass fails; a leaked hook
    # would silently alter every later forward pass.
    model.reset_hooks()
with torch.no_grad():
    loss, x_reconstruct, hidden_acts_new, l2_loss, l1_loss = encoder1(new_cache["mlp_out", 1])

an_token = model.to_single_token(" an")
n_feats = len(new_end_topk)
# Post-ablation activations still include position 0, so mask over all positions.
# Pre-ablation hidden_acts1 already dropped position 0, so mask positions 1:.
# Both pick the same " an" tokens in the same order as long as " an" never
# occurs at position 0.
new_hidden_acts_on_an = hidden_acts_new[:, :, new_end_topk].reshape(-1, n_feats)[example_tokens.flatten() == an_token, :]
old_hidden_acts_on_an = hidden_acts1[:, new_end_topk][example_tokens[:, 1:].flatten() == an_token, :]
n_an = old_hidden_acts_on_an.shape[0]
for i in range(n_feats):
    scatter(x=old_hidden_acts_on_an[:, i], y=new_hidden_acts_on_an[:, i], hover=np.arange(n_an), title=new_end_topk[i].item(), include_diag=True, yaxis="Post Ablation", xaxis="Pre Ablation")
# %%
