In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from baukit import TraceDict

# Download the model
# Use the first GPU when CUDA is available; otherwise fall back to the CPU so the
# cell does not crash on a machine without a GPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name="EleutherAI/Pythia-70M-deduped"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Seed torch and numpy so any later sampling is repeatable.
torch.manual_seed(0)
np.random.seed(0)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from autoencoders import *
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
model_id = "jbrinkma/Pythia-70M-deduped-SAEs"
autoencoders = []
# layers = model.config.num_hidden_layers
layers = 2
# For each i, hook the output of layer i and the attention module of layer i+1.
# An earlier `.mlp` variant of this list was assigned and immediately overwritten,
# so it never had any effect; only the attention variant (which matches the
# "...-attention-..." checkpoint file names below) is kept.
cache_names = [(f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.attention") for i in range(layers-1)]
# Number of (layer, attention) pairs, counted before the pairs are flattened.
num_layers = len(cache_names)
cache_names = [item for sublist in cache_names for item in sublist]
# One checkpoint per hooked module, in the same order as cache_names.
filenames = [(f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-attention-{i+1}.pt") for i in range(layers-1)]
filenames = [item for sublist in filenames for item in sublist]
for filen in filenames:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    # SECURITY: torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from a repository you trust.
    # map_location places the stored tensors on `device` regardless of the device
    # they were saved from.
    autoencoder = torch.load(ae_download_location, map_location=device)
    autoencoder.to_device(device)
    # Freeze autoencoder weights
    autoencoder.encoder.requires_grad_(False)
    autoencoder.encoder_bias.requires_grad_(False)
    autoencoders.append(autoencoder)

In [3]:
from activation_dataset import chunk_and_tokenize
# Fetch the first 500 training documents and re-chunk them with the model tokenizer.
# chunk_and_tokenize is called with max_length=max_seq_length; the token count printed
# below assumes every resulting row holds exactly max_seq_length tokens.
max_seq_length = 32
dataset_name = "stas/openwebtext-10k"
dataset = load_dataset(dataset_name, split="train[:500]")
dataset, _ = chunk_and_tokenize(dataset, tokenizer, max_length=max_seq_length)
max_tokens = max_seq_length * dataset.num_rows
print(f"Number of tokens: {max_tokens/1e6:.2f}M")

Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
                                                                           

Number of tokens: 0.57M




In [4]:
from baukit import TraceDict
# with TraceDict(model, cache_names) as ret:

# Encode the LLM activations of every token in `dataset` with each sparse autoencoder.
# Sizes are read from the `autoencoders` list itself rather than from whatever the
# name `autoencoder` was last bound to in an earlier cell.
assert len(autoencoders) == len(cache_names), "need exactly one autoencoder per hooked module"
num_features, d_model = autoencoders[-1].encoder.shape
datapoints = dataset.num_rows
# One (tokens, features) CPU buffer per hooked module, each sized from its own
# autoencoder so dictionaries of different widths still fit the writes below.
dictionary_activations_list = [
    torch.zeros((datapoints*max_seq_length, ae.encoder.shape[0])) for ae in autoencoders
]
batch_size = 32
with torch.no_grad(), dataset.formatted_as("pt"):
    dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
    # Running write offset (in token rows); advancing by the real batch size keeps
    # the final, possibly shorter, batch aligned without relying on slice clipping.
    row_start = 0
    for i, batch in enumerate(tqdm(dl)):
        batch = batch.to(device)
        n_rows = batch.numel()  # tokens in this batch, i.e. rows after "(b s)" flattening
        # Get LLM intermediate activations
        with TraceDict(model, cache_names) as ret:
            _ = model(batch)
        # Get SAE intermediate codes
        for ae_ind, cache_name in enumerate(cache_names):
            autoencoder = autoencoders[ae_ind]
            internal_activations = ret[cache_name].output
            # A traced module may return a tuple; the activations are its first element.
            if isinstance(internal_activations, tuple):
                internal_activations = internal_activations[0]
            batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n" )
            batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
            dictionary_activations_list[ae_ind][row_start:row_start + n_rows, :] = batched_dictionary_activations.cpu()
        row_start += n_rows

100%|██████████| 553/553 [00:24<00:00, 22.45it/s]


In [68]:
# Inspect the per-module activation buffers together with the hooked module names.
# NOTE(review): this displays the repr of the full tensors; a shape or small slice would be lighter.
dictionary_activations_list, cache_names 

([tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]]),
  tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]])],
 ['gpt_neox.layers.0', 'gpt_neox.layers.1.attention'])

In [5]:
# For each hooked module, report the mean number of non-zero dictionary features per
# token row, estimated from the first 200 rows only.
for ind, dictionary_activations in enumerate(dictionary_activations_list):
    active_per_row = dictionary_activations[:200].count_nonzero(dim=1)
    sparsity = active_per_row.float().mean()
    print(f"Sparsity: {sparsity:.4f} | {cache_names[ind]}")

Sparsity: 8.2150 | gpt_neox.layers.0
Sparsity: 8.4450 | gpt_neox.layers.1.attention


In [61]:
# Shape of column t_f of OV_circuit (one value per row of the matrix).
# NOTE(review): OV_circuit and t_f are only assigned in the QK/OV cell further down,
# so this cell fails when the notebook is run top to bottom on a fresh kernel.
OV_circuit[:, t_f].shape

torch.Size([3072])

# QK and OV circuits

In [76]:
import einops
from transformer_lens import FactoredMatrix

def get_gpt_neox_weights(model, block):
    """Split one GPT-NeoX attention layer's weights into per-head Q, K, V and O matrices.

    Only the ``.weight`` tensors of ``query_key_value`` and ``dense`` are read; the
    biases of those layers are not included in anything returned here.

    Args:
        model: model exposing ``model.gpt_neox.layers`` (and ``model.config``).
        block: index into ``model.gpt_neox.layers`` selecting the layer.

    Returns:
        Tuple ``(W_Q, W_K, W_V, W_O, QK, OV)`` where

        * ``W_Q``, ``W_K``, ``W_V`` are indexed as (head, d_model, d_head),
        * ``W_O`` is indexed as (head, d_head, d_model),
        * ``QK`` is ``FactoredMatrix(W_Q, W_K^T)`` and ``OV`` is ``FactoredMatrix(W_V, W_O)``.
    """
    with torch.no_grad():
        attention = model.gpt_neox.layers[block].attention

        # Prefer the per-module attribute; fall back to the model config in case the
        # attention module does not expose the head count itself.
        n_heads = getattr(attention, "num_attention_heads", None)
        if n_heads is None:
            n_heads = model.config.num_attention_heads

        # The fused projection's rows are laid out as (head, q/k/v, d_head); move the
        # q/k/v axis to the front so each projection can be picked out by index.
        W = attention.query_key_value.weight
        W = einops.rearrange(W, "(i qkv h) m->qkv i m h", i=n_heads, qkv=3)
        W_Q = W[0]
        W_K = W[1]
        W_V = W[2]

        # Output projection: split its input axis into (head, d_head).
        W_O = attention.dense.weight
        W_O = einops.rearrange(W_O, "m (i h)->i h m", i=n_heads)

        # Low-rank factorisations of the per-head QK and OV products.
        W_K_T = einops.rearrange(W_K, "head_index d_model d_head -> head_index d_head d_model")
        QK = FactoredMatrix(W_Q, W_K_T)
        OV = FactoredMatrix(W_V, W_O)

    return W_Q, W_K, W_V, W_O, QK, OV

# Learned dictionaries (one row per feature) of the two autoencoders:
# autoencoders[0] pairs with cache_names[0], autoencoders[1] with cache_names[1].
dict_res = autoencoders[0].get_learned_dict()
dict_attn = autoencoders[1].get_learned_dict()
n_features = dict_res.shape[0]

qk_index = 4   # which of the top-10 QK feature pairs (by score) to visualise
ov_index = 0   # not used by the live code in this cell
head = 1
layer = 0
block = layer + 1
W_Q, W_K, W_V, W_O, QK, OV = get_gpt_neox_weights(model, block)
from alpha_utils_interp import *
import os
# Make sure the output directory exists (no error if it is already there).
save_path = "features/"
os.makedirs(save_path, exist_ok=True)
num_feature_datapoints = 10

input_setting = "input_only"
model_type = "causal"

with torch.no_grad():
    # Dense matrices for the selected head. These rebind QK / OV, replacing the
    # FactoredMatrix objects returned by get_gpt_neox_weights above.
    QK = W_Q[head] @ W_K[head].T
    OV = W_V[head] @ W_O[head]

    # Feature-by-feature interaction matrices.
    # QK_circuit rows multiply W_Q and columns multiply W_K.
    # OV_circuit rows are dict_res features, columns are dict_attn features.
    QK_circuit = dict_res @ QK @ dict_res.T
    OV_circuit = dict_res @ OV @ dict_attn.T

    # Ten largest QK entries; recover (row, column) from the flattened index.
    top_val, top_ind = QK_circuit.flatten().topk(10)
    source_feature, target_feature = top_ind // n_features, top_ind % n_features   
    print(top_val)
    print(source_feature)
    print(target_feature)

    s_f = source_feature[qk_index].item()   # row (W_Q side) feature of the chosen pair
    t_f = target_feature[qk_index].item()   # column (W_K side) feature of the chosen pair
    # dict_attn feature with the largest OV_circuit value in row t_f.
    o_f = OV_circuit[t_f].topk(10).indices[0].item()
    d_s = dictionary_activations_list[0][:, s_f]
    d_t = dictionary_activations_list[0][:, t_f]
    d_o = dictionary_activations_list[1][:, o_f]
    
    # Keep the token rows where the t_f feature fires above the threshold.
    threshold = 2.0
    mask = (d_t > threshold)

    # Row indices of the selected tokens.
    indices = torch.arange(len(d_s))[mask]

    # Order them by t_f activation, strongest first, and keep the top few.
    indices = indices[d_t[mask].argsort(descending=True)]
    indices = indices[:num_feature_datapoints]

    # Fail loudly when nothing passes the threshold. Unlike `assert`, this check is
    # not stripped when Python runs with -O.
    if len(indices) == 0:
        raise ValueError("No indices found")
    # Same token rows for all three features; only the activations differ.
    _, _, _, full_token_list, _, full_activations_s = get_feature_datapoints(indices, d_s, tokenizer, max_seq_length, dataset)
    _, _, _, _, _, full_activations_t = get_feature_datapoints(indices, d_t, tokenizer, max_seq_length, dataset)
    _, _, _, _, _, full_activations_o = get_feature_datapoints(indices, d_o, tokenizer, max_seq_length, dataset)

    # Repeat each element of full_token_list 3 times. e.g. [A, B] -> [A, A, A, B, B, B]
    # so every example is shown as three consecutive rows: s_f, t_f, o_f activations.
    full_token_tensor = torch.stack(full_token_list)  # Stack the list of tensors
    repeated_full_token_tensor = full_token_tensor.repeat_interleave(3, dim=0)
    interleaved_activations = [element for trio in zip(full_activations_s, full_activations_t, full_activations_o) for element in trio]
    # Logit differences from ablating each feature, interleaved in the same s/t/o order.
    logit_diff_s = ablate_feature_direction(model, full_token_list, cache_names[0], max_seq_length, autoencoders[0], feature = s_f, batch_size=32, setting="sentences", model_type=model_type)
    logit_diff_t = ablate_feature_direction(model, full_token_list, cache_names[0], max_seq_length, autoencoders[0], feature = t_f, batch_size=32, setting="sentences", model_type=model_type)
    logit_diff_o = ablate_feature_direction(model, full_token_list, cache_names[1], max_seq_length, autoencoders[1], feature = o_f, batch_size=32, setting="sentences", model_type=model_type)
    logit_diffs = [element for trio in zip(logit_diff_s, logit_diff_t, logit_diff_o) for element in trio]
save_token_display(repeated_full_token_tensor, interleaved_activations, tokenizer, path =f"{save_path}qko_{s_f}_{t_f}_{o_f}.png", logit_diffs = logit_diffs, model_type=model_type, show=True)

tensor([0.1230, 0.1212, 0.1158, 0.1135, 0.1131, 0.1121, 0.1116, 0.1113, 0.1110,
        0.1102], device='cuda:0')
tensor([1633, 2769,  539,  539, 1195, 1195, 2769, 1195,  568,  539],
       device='cuda:0')
tensor([ 537, 1929, 2313,  248, 1976,  409,  636, 2993,   22, 2031],
       device='cuda:0')


  full_tok = torch.tensor(dataset[md]["input_ids"])


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


In [58]:
# Ten largest o_f activations among the token rows selected by `mask` (d_t > threshold).
# The returned indices are positions within the masked subset, not dataset row numbers.
d_o[mask].topk(10)

torch.return_types.topk(
values=tensor([0.4463, 0.1485, 0.1364, 0.1278, 0.1276, 0.1230, 0.0754, 0.0706, 0.0505,
        0.0466]),
indices=tensor([190, 224, 455, 391, 363, 255,  42, 394, 176, 279]))

In [None]:
# OV_circuit
# visualize the QK circuit
from matplotlib import pyplot as plt

# Overlay one histogram per column of QK_circuit for the first few columns.
# The count is a single constant so the loop and the title cannot drift apart
# (the title previously said 10 while 20 columns were plotted).
num_hist_features = 20
fig, ax = plt.subplots(figsize=(10, 10))
for x in range(num_hist_features):
    ax.hist(QK_circuit[:, x].cpu().numpy(), bins=100)
ax.set_title(f"QK distribution of attention for {num_hist_features} features")
ax.set_xlabel("Attention score (pre-softmax)")
plt.show()

In [None]:
# Scratch check of column indexing: build [[1, 2], [3, 4]] and look at its first column.
a = torch.arange(1, 5).reshape(2, 2)
a[:, 0]

In [None]:
from alpha_utils_interp import *
import os
# Output directory for rendered feature images; create it on first use.
save_path = "features/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
num_feature_datapoints = 10

# Work with the first autoencoder and the module it was trained on.
sae_index = 0
autoencoder = autoencoders[sae_index]
cache_name = cache_names[sae_index]
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(
    model,
    dataset,
    cache_name,
    max_seq_length,
    autoencoder,
    batch_size=32,
)

# Mean count of non-zero features per row, over the first 80 rows only.
active_per_row = dictionary_activations[:80].count_nonzero(dim=1)
sparsity = active_per_row.float().mean()
print(f"Sparsity: {sparsity}")

# Per-feature maximum over all rows, then the 20 features with the largest maximum.
max_values = dictionary_activations.max(dim=0)
top_by_max = max_values.values.topk(20)
feature_ind = top_by_max.indices

In [None]:

# Render example contexts and ablation effects for each selected feature of one autoencoder.
input_setting = "input_only"
model_type = "causal"
# Features come from feature_ind; hand-picked lists tried earlier: [604, 334, 674] and [1484].
features = feature_ind.tolist()
num_features = 10
sae_index = 0
dictionary_activations = dictionary_activations_list[sae_index]
cache_name = cache_names[sae_index]
autoencoder = autoencoders[sae_index]
for feature in features:
    # A feature with fewer than dead_threshold non-zero activations is treated as dead.
    dead_threshold = 10
    feature_activations = dictionary_activations[:, feature]
    if feature_activations.count_nonzero() < dead_threshold:
        print(f"Feature {feature} is dead")
        continue

    uniform_indices = get_feature_indices(
        feature, dictionary_activations, k=num_feature_datapoints, setting="uniform"
    )
    (
        text_list,
        full_text,
        token_list,
        full_token_list,
        partial_activations,
        full_activations,
    ) = get_feature_datapoints(
        uniform_indices, feature_activations, tokenizer, max_seq_length, dataset
    )

    if input_setting != "input_only":
        # Dataset-wide ablation: show where removing the feature changes the output.
        logit_diffs = ablate_feature_direction(
            model, dataset, cache_name, max_seq_length, autoencoder,
            feature=feature, batch_size=32, setting="dataset",
        )
        _, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(
            uniform_indices, logit_diffs, tokenizer, max_seq_length, dataset
        )
        get_token_statistics(
            feature, logit_diffs, dataset, tokenizer, max_seq_length, tokens_for_each_datapoint,
            save_location=save_path, setting="output", num_unique_tokens=10,
        )
        save_token_display(
            full_token_list_ablated, full_activations, tokenizer,
            path=f"{save_path}uniform_{feature}.png", logit_diffs=full_activations_ablated,
        )
    else:
        # Logit diffs from ablating this feature on the sampled sentences only.
        logit_diffs = ablate_feature_direction(
            model, full_token_list, cache_name, max_seq_length, autoencoder,
            feature=feature, batch_size=32, setting="sentences", model_type=model_type,
        )
        save_token_display(
            full_token_list, full_activations, tokenizer,
            path=f"{save_path}uniform_{feature}.png", logit_diffs=logit_diffs,
            model_type=model_type, show=True,
        )
        # Activation change when context tokens are ablated one at a time.
        all_changed_activations = ablate_context_one_token_at_a_time(
            model, token_list, cache_name, autoencoder, feature, max_ablation_length=30
        )
        save_token_display(
            token_list, all_changed_activations, tokenizer,
            path=f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True,
        )

In [None]:
# Column index of the largest value in row t_f of OV_circuit (first entry of its top-10).
OV_circuit[t_f].topk(10).indices[0].item()

In [None]:
# Re-render the interleaved token display without logit diffs.
# `{0}` in the f-string formats the constant 0, so the file name always ends in "uniform_0.png".
save_token_display(repeated_full_token_tensor, interleaved_activations, tokenizer, path =f"{save_path}uniform_{0}.png", logit_diffs = None, model_type=model_type, show=True)

In [None]:
# target_feature is a tensor of row indices, so this selects several rows of OV_circuit
# and returns the ten largest values (and their column indices) for each selected row.
OV_circuit[target_feature].topk(10)

In [None]:
from alpha_utils_interp import *
import os
# Output directory for rendered feature images; create it on first use.
save_path = "features/"
if not os.path.exists(save_path):
    os.makedirs(save_path)
num_feature_datapoints = 10

input_setting = "input_only"
model_type = "causal"
# Hand-picked features; other choices tried: feature_ind.tolist(), [604, 334, 674], [418].
features = [604, 334]
num_features = 10
sae_index = 0
dictionary_activations = dictionary_activations_list[sae_index]
cache_name = cache_names[sae_index]
autoencoder = autoencoders[sae_index]
for feature in features:
    # A feature with fewer than dead_threshold non-zero activations is treated as dead.
    dead_threshold = 10
    activations_for_feature = dictionary_activations[:, feature]
    is_dead = activations_for_feature.count_nonzero() < dead_threshold
    if is_dead:
        print(f"Feature {feature} is dead")
        continue

    uniform_indices = get_feature_indices(
        feature, dictionary_activations, k=num_feature_datapoints, setting="uniform"
    )
    # NOTE(review): the sampled indices are replaced straight away by a fixed set of
    # token rows, so every feature in this cell is shown on the same four examples.
    uniform_indices = torch.tensor([1255, 103303, 115751, 270277])
    datapoints_for_feature = get_feature_datapoints(
        uniform_indices, activations_for_feature, tokenizer, max_seq_length, dataset
    )
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = datapoints_for_feature

    if input_setting == "input_only":
        # Logit diffs from ablating this feature on the selected sentences.
        logit_diffs = ablate_feature_direction(
            model, full_token_list, cache_name, max_seq_length, autoencoder,
            feature=feature, batch_size=32, setting="sentences", model_type=model_type,
        )
        save_token_display(
            full_token_list, full_activations, tokenizer,
            path=f"{save_path}uniform_{feature}.png", logit_diffs=logit_diffs,
            model_type=model_type, show=True,
        )
        # Activation change when context tokens are ablated one at a time.
        all_changed_activations = ablate_context_one_token_at_a_time(
            model, token_list, cache_name, autoencoder, feature, max_ablation_length=30
        )
        save_token_display(
            token_list, all_changed_activations, tokenizer,
            path=f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True,
        )