In [None]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification


# Download the model
device = "cuda:0"
# model_name="BlueSunflower/Pythia-70M-chess"
model_name="EleutherAI/Pythia-70M-deduped"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
max_seq_length=30

from torch import nn
from torchtyping import TensorType
class TiedSAE(nn.Module):
    def __init__(self, activation_size, n_dict_components):
        super().__init__()
        self.encoder = nn.Parameter(torch.empty((n_dict_components, activation_size)))
        nn.init.xavier_uniform_(self.encoder)
        self.encoder_bias = nn.Parameter(torch.zeros((n_dict_components,)))

    def get_learned_dict(self):
        norms = torch.norm(self.encoder, 2, dim=-1)
        return self.encoder / torch.clamp(norms, 1e-8)[:, None]

    def encode(self, batch):
        c = torch.einsum("nd,bd->bn", self.encoder, batch)
        c = c + self.encoder_bias
        c = torch.clamp(c, min=0.0)
        return c

    def decode(self, code: TensorType["_batch_size", "_n_dict_components"]) -> TensorType["_batch_size", "_activation_size"]:
        learned_dict = self.get_learned_dict()
        x_hat = torch.einsum("nd,bn->bd", learned_dict, code)
        return x_hat

    def forward(self, batch: TensorType["_batch_size", "_activation_size"]) -> TensorType["_batch_size", "_activation_size"]:
        c = self.encode(batch)
        x_hat = self.decode(c)
        return x_hat, c

    def n_dict_components(self):
        return self.get_learned_dict().shape[0]
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
model_id = "Elriggs/pythia-70M-deduped-sae"
filename = ["pythia-70m-deduped_r4_gpt_neox.layers.1.pt", "pythia-70m-deduped_r4_gpt_neox.layers.1.attention.pt"]
autoencoders = []
for filen in filename:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    autoencoder = torch.load(ae_download_location)
    autoencoder.to(device)
    autoencoders.append(autoencoder)
cache_names = ["_".join(model_id.split("_")[-2:]) for model_id in filename]
cache_names = [".".join(cache_name.split(".")[:-1]) for cache_name in cache_names]

In [None]:
from datasets import load_dataset
dataset_name = "NeelNanda/pile-10k"
dataset = load_dataset(dataset_name, split="train").map(
    lambda x: tokenizer(x['text']),
    batched=True
).filter(
    lambda x: len(x['input_ids']) > max_seq_length
).map(
    lambda x: {'input_ids': x['input_ids'][:max_seq_length]}
)

In [None]:
from baukit import TraceDict
# with TraceDict(model, cache_names) as ret:

num_features, d_model = autoencoder.encoder.shape
datapoints = dataset.num_rows
dictionary_activations_list = [torch.zeros((datapoints*max_seq_length, num_features)) for _ in range(len(cache_names))]
batch_size = 32
with torch.no_grad(), dataset.formatted_as("pt"):
    dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
    for i, batch in enumerate(tqdm(dl)):
        batch = batch.to(device)
        # Get LLM intermediate activations
        with TraceDict(model, cache_names) as ret:
            _ = model(batch)
        # Get SAE intermediate codes
        for ae_ind, cache_name in enumerate(cache_names):
            autoencoder = autoencoders[ae_ind]
            internal_activations = ret[cache_name].output
            # check if instance tuple ie a layer output
            if(isinstance(internal_activations, tuple)):
                internal_activations = internal_activations[0]
            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n" )
            batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
            dictionary_activations_list[ae_ind][i*batch_size*max_seq_length:(i+1)*batch_size*max_seq_length,:] = batched_dictionary_activations.cpu()
        # break
for dictionary_activations in dictionary_activations_list:
    sparsity = dictionary_activations[:200].count_nonzero(dim=1).float().mean()
    print(f"Sparsity: {sparsity:.4f}")

In [None]:
for ind, dictionary_activations in enumerate(dictionary_activations_list):
    sparsity = dictionary_activations[:200].count_nonzero(dim=1).float().mean()
    print(f"Sparsity: {sparsity:.4f} | {cache_names[ind]}")

# QK OV

In [53]:
import einops
from transformer_lens import FactoredMatrix

def get_gpt_neox_weights(model, block):
    with torch.no_grad():
        # get q, k, v weights
        n_heads = model.gpt_neox.layers[block].attention.num_attention_heads
        W = model.gpt_neox.layers[block].attention.query_key_value.weight
        W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=n_heads, qkv=3)
        W_Q = W[0]
        W_K = W[1]
        W_V = W[2]

        # get o weights
        W_O = model.gpt_neox.layers[block].attention.dense.weight
        W_O = einops.rearrange(W_O, "m (i h)->i h m", i=n_heads)

        # get QK and OV
        W_K_T = einops.rearrange(W_K, "head_index d_model d_head -> head_index d_head d_model")
        QK = FactoredMatrix(W_Q, W_K_T)
        OV = FactoredMatrix(W_V, W_O)

    return W_Q, W_K, W_V, W_O, QK, OV

dict_res = autoencoders[0].get_learned_dict()
dict_attn = autoencoders[1].get_learned_dict()
n_features = dict_res.shape[0]

qk_index = 2
ov_index = 1
head = 0
layer = 1
block = layer + 1
W_Q, W_K, W_V, W_O, QK, OV = get_gpt_neox_weights(model, block)
from alpha_utils_interp import *
import os
# make features/ dir if not exist
save_path = "features/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
num_feature_datapoints = 10

input_setting = "input_only"
model_type = "causal"

with torch.no_grad():
    QK = W_Q[head] @ W_K[head].T
    OV = W_V[head] @ W_O[head]

    QK_circuit = dict_res @ QK @ dict_res.T
    OV_circuit = dict_res @ OV @ dict_attn.T

    top_val, top_ind = QK_circuit.flatten().topk(10)
    source_feature, target_feature = top_ind // n_features, top_ind % n_features   

    print(source_feature)
    print(target_feature)
    
    s_f = source_feature[qk_index].item()
    t_f = target_feature[qk_index].item()
    # o_f = OV_circuit[s_f].topk(10).indices[ov_index].item()
    d_s = dictionary_activations_list[0][:, s_f]
    d_t = dictionary_activations_list[0][:, t_f]
    
    # find where both dictionary activations are > 0
    threshold = 0.1
    max_s = d_s.max()
    max_t = d_t.max()
    mask = (d_s > max_s // 2) & (d_t > max_t // 2)
    # mask = (d_s > threshold) & (d_t > threshold)

    # find those indices
    indices = torch.arange(len(d_s))[mask]

    # sort the mask by max d_s features
    indices = indices[d_s[mask].argsort(descending=True)]
    #Take the first num_feature_datapoints
    indices = indices[:num_feature_datapoints]
    o_f_pre = OV_circuit[s_f]
    o_f = (o_f_pre * dictionary_activations_list[1][indices[ov_index]].to(device)).argmax()
    d_o = dictionary_activations_list[1][:, o_f]

    # Check if indices is empty
    if len(indices) == 0:
        assert False, "No indices found"
    _, _, _, full_token_list, _, full_activations_s = get_feature_datapoints(indices, d_s, tokenizer, max_seq_length, dataset)
    _, _, _, _, _, full_activations_t = get_feature_datapoints(indices, d_t, tokenizer, max_seq_length, dataset)
    _, _, _, _, _, full_activations_o = get_feature_datapoints(indices, d_o, tokenizer, max_seq_length, dataset)

    # Repeat each element of full_token_list 3 times. e.g. [A, B] -> [A, A, A, B, B, B]
    # full_token_list is a list of tensors w/ tokens
    full_token_tensor = torch.stack(full_token_list)  # Stack the list of tensors
    repeated_full_token_tensor = full_token_tensor.repeat_interleave(3, dim=0)
    interleaved_activations = [element for trio in zip(full_activations_s, full_activations_t, full_activations_o) for element in trio]
    logit_diff_s = ablate_feature_direction(model, full_token_list, cache_names[0], max_seq_length, autoencoders[0], feature = s_f, batch_size=32, setting="sentences", model_type=model_type)
    logit_diff_t = ablate_feature_direction(model, full_token_list, cache_names[0], max_seq_length, autoencoders[0], feature = t_f, batch_size=32, setting="sentences", model_type=model_type)
    logit_diff_o = ablate_feature_direction(model, full_token_list, cache_names[1], max_seq_length, autoencoders[1], feature = o_f, batch_size=32, setting="sentences", model_type=model_type)
    logit_diffs = [element for trio in zip(logit_diff_s, logit_diff_t, logit_diff_o) for element in trio]
save_token_display(repeated_full_token_tensor, interleaved_activations, tokenizer, path =f"{save_path}qko_{s_f}_{t_f}_{o_f}.png", logit_diffs = logit_diffs, model_type=model_type, show=True)

tensor([ 604,  604, 1484,  604, 1208,  604, 1442,   93, 1208,  305],
       device='cuda:0')
tensor([ 334,  674, 1484, 1542,  305, 1631, 1442,   93, 1550,  334],
       device='cuda:0')


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


In [48]:
full_activations_o[-1]

[1.8221529722213745,
 1.9656373262405396,
 1.8898800611495972,
 1.827553391456604,
 1.9026468992233276,
 1.7848013639450073,
 1.9954370260238647,
 1.7708600759506226,
 1.2813901901245117,
 1.370510458946228,
 1.4281545877456665,
 1.6888474225997925,
 1.3720805644989014,
 1.5775412321090698,
 1.2513877153396606,
 1.324180006980896,
 1.4201960563659668,
 1.2852237224578857,
 1.361263632774353,
 1.2659509181976318,
 1.1160552501678467,
 1.4646660089492798,
 1.1142953634262085,
 1.3184452056884766,
 1.1532682180404663,
 1.0613470077514648,
 0.8021490573883057,
 0.8129143714904785,
 0.8632248640060425,
 1.0199837684631348]

In [None]:
# OV_circuit
# visualize the QK circuit
from matplotlib import pyplot as plt

plt.figure(figsize=(10,10))
# plt.imshow(QK_circuit[:100,:100].cpu().numpy())
# plt.colorbar()
for x in range(20):
    plt.hist(QK_circuit[:, x].cpu().numpy(), bins=100)
plt.title("QK distribution of attention for 10 features")
plt.xlabel("Attention score (pre-softmax)")
plt.show()

In [None]:
a = torch.tensor([[1,2], [3,4]])
a[:, 0]

In [None]:
from alpha_utils_interp import *
import os
# make features/ dir if not exist
save_path = "features/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
num_feature_datapoints = 10
sae_index = 0
cache_name = cache_names[sae_index]
autoencoder = autoencoders[sae_index]
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32)
sparsity = dictionary_activations[:80].count_nonzero(dim=1).float().mean()
print(f"Sparsity: {sparsity}")
max_values = dictionary_activations.max(dim=0)

feature_ind = max_values.values.topk(20).indices

In [None]:

input_setting = "input_only"
model_type = "causal"
features = feature_ind.tolist()
# features = [604,  334,  674]
# features = [1484]
num_features = 10
sae_index = 0
dictionary_activations = dictionary_activations_list[sae_index]
cache_name = cache_names[sae_index]
autoencoder = autoencoders[sae_index]
for feature in features:
    # Check if feature is dead (<10 activations)
    dead_threshold = 10
    # if(dictionary_activations[:, current_feature].count_nonzero() < dead_threshold):
    if(dictionary_activations[:, feature].count_nonzero() < dead_threshold):
        print(f"Feature {feature} is dead")
        continue
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset)
    # get_token_statistics(feature, dictionary_activations[:, feature], dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location = save_path, num_unique_tokens=10)
    if(input_setting == "input_only"):
        # Calculate logit diffs on this feature for the full_token_list
        logit_diffs = ablate_feature_direction(model, full_token_list, cache_name, max_seq_length, autoencoder, feature = feature, batch_size=32, setting="sentences", model_type=model_type)
        # save_token_display(full_token_list, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = logit_diffs, model_type=model_type)
        save_token_display(full_token_list, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = logit_diffs, model_type=model_type, show=True)
        all_changed_activations = ablate_context_one_token_at_a_time(model, token_list, cache_name, autoencoder, feature, max_ablation_length=30)
        save_token_display(token_list, all_changed_activations, tokenizer, path =f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True)
    else:
        logit_diffs = ablate_feature_direction(model, dataset, cache_name, max_seq_length, autoencoder, feature = feature, batch_size=32, setting="dataset")
        _, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(uniform_indices, logit_diffs, tokenizer, max_seq_length, dataset)
        get_token_statistics(feature, logit_diffs, dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location = save_path, setting="output", num_unique_tokens=10)
        save_token_display(full_token_list_ablated, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = full_activations_ablated)

In [None]:
OV_circuit[t_f].topk(10).indices[0].item()

In [None]:
save_token_display(repeated_full_token_tensor, interleaved_activations, tokenizer, path =f"{save_path}uniform_{0}.png", logit_diffs = None, model_type=model_type, show=True)

In [None]:
OV_circuit[target_feature].topk(10)

In [None]:
from alpha_utils_interp import *
import os
# make features/ dir if not exist
save_path = "features/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
num_feature_datapoints = 10

input_setting = "input_only"
model_type = "causal"
# features = feature_ind.tolist()
# features = [604,  334,  674]
features = [604, 334]
# features = [418]
num_features = 10
sae_index = 0
dictionary_activations = dictionary_activations_list[sae_index]
cache_name = cache_names[sae_index]
autoencoder = autoencoders[sae_index]
for feature in features:
    # Check if feature is dead (<10 activations)
    dead_threshold = 10
    # if(dictionary_activations[:, current_feature].count_nonzero() < dead_threshold):
    if(dictionary_activations[:, feature].count_nonzero() < dead_threshold):
        print(f"Feature {feature} is dead")
        continue
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    uniform_indices = torch.tensor([1255, 103303, 115751, 270277])
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset)
    
    # get_token_statistics(feature, dictionary_activations[:, feature], dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location = save_path, num_unique_tokens=10)
    if(input_setting == "input_only"):
        # Calculate logit diffs on this feature for the full_token_list
        logit_diffs = ablate_feature_direction(model, full_token_list, cache_name, max_seq_length, autoencoder, feature = feature, batch_size=32, setting="sentences", model_type=model_type)
        # save_token_display(full_token_list, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = logit_diffs, model_type=model_type)
        save_token_display(full_token_list, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = logit_diffs, model_type=model_type, show=True)
        all_changed_activations = ablate_context_one_token_at_a_time(model, token_list, cache_name, autoencoder, feature, max_ablation_length=30)
        save_token_display(token_list, all_changed_activations, tokenizer, path =f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True)