In [1]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from baukit import TraceDict


# Fetch Pythia-70M-deduped (causal LM) and its tokenizer from the Hugging Face Hub,
# then place the model on `device`; later cells run forward passes on this model.
device = "cuda:0"
model_name = "EleutherAI/Pythia-70M-deduped"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from autoencoders import *
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
# Sparse autoencoders (SAEs) for Pythia-70M-deduped, hosted on the Hugging Face Hub.
model_id = "jbrinkma/Pythia-70M-deduped-SAEs"
# Sublayer whose SAE is paired with each residual SAE. The naming scheme below is used
# for both "mlp" and "attention" (switch here instead of via commented-out copies).
sae_sublayer = "mlp"
autoencoders = []
layers = model.config.num_hidden_layers
# Pair the output of block i with the `sae_sublayer` of block i+1.
# To restrict to a single pair, replace range(layers-1) with e.g. range(1, 2).
cache_names = [(f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.{sae_sublayer}") for i in range(layers-1)]
num_layers = len(cache_names)  # number of (block, sublayer) pairs, counted before flattening
# Flatten to [block_0, sub_1, block_1, sub_2, ...] so indices 2*k and 2*k+1 select pair k.
cache_names = [name for pair in cache_names for name in pair]
filenames = [(f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-{sae_sublayer}-{i+1}.pt") for i in range(layers-1)]
filenames = [fname for pair in filenames for fname in pair]
for filen in filenames:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    # Each .pt file is a whole pickled autoencoder object (it is used via methods below),
    # so torch.load unpickles arbitrary code: only load from a repo you trust.
    # NOTE(review): on torch >= 2.6 `weights_only` defaults to True, which rejects pickled
    # objects; pass weights_only=False there if this load fails.
    # map_location lets checkpoints saved on a different device load onto `device`.
    autoencoder = torch.load(ae_download_location, map_location=device)
    autoencoder.to_device(device)
    # Freeze autoencoder weights
    autoencoder.encoder.requires_grad_(False)
    autoencoder.encoder_bias.requires_grad_(False)
    autoencoders.append(autoencoder)

In [3]:
max_seq_length = 30  # number of tokens kept per data point
from datasets import load_dataset
# dataset_name = "NeelNanda/pile-10k"
dataset_name = "stas/openwebtext-10k"
# Tokenize, keep only documents strictly longer than max_seq_length tokens, then cut
# every token-aligned column to exactly max_seq_length so all rows share one length.
dataset = load_dataset(dataset_name, split="train").map(
    lambda x: tokenizer(x['text']),
    batched=True
).filter(
    lambda x: len(x['input_ids']) > max_seq_length
).map(
    # Truncate attention_mask along with input_ids; previously only input_ids was cut,
    # leaving attention_mask at the full document length and misaligned with it.
    lambda x: {key: x[key][:max_seq_length] for key in ("input_ids", "attention_mask") if key in x}
)

Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-28e2aeab2f86f98a.arrow
                                                                      

In [4]:
import torch as torch
import os
from torch import nn
    
class mlp_no_bias(nn.Module):
    """Two-layer bias-free MLP: Linear(in -> in) -> ReLU -> Linear(in -> out).

    Args:
        input_size: width of the input and of the hidden layer.
        output_size: width of the output; defaults to `input_size` when None.

    NOTE: attribute names (`linear`, `linear2`, `relu`) are kept unchanged because pickled
    instances of this class (presumably the checkpoints loaded below) depend on them.
    """
    def __init__(self, input_size, output_size=None):
        super().__init__()
        out_features = input_size if output_size is None else output_size
        # Create submodules in the original order so parameter ordering and the
        # random initialisation sequence are identical.
        self.linear = nn.Linear(input_size, input_size, bias=False)
        self.linear2 = nn.Linear(input_size, out_features, bias=False)
        self.relu = nn.ReLU()
    def forward(self, x):
        hidden = self.relu(self.linear(x))
        return self.linear2(hidden)
    
class linear_no_bias(nn.Module):
    """A single bias-free linear map from `input_size` to `output_size` features.

    Args:
        input_size: number of input features.
        output_size: number of output features; defaults to `input_size` (square map).

    NOTE: the attribute name `linear` is kept unchanged because pickled instances of this
    class (presumably the checkpoints loaded below) depend on it.
    """
    def __init__(self, input_size, output_size=None):
        super().__init__()
        out_features = input_size if output_size is None else output_size
        self.linear = nn.Linear(input_size, out_features, bias=False)
    def forward(self, x):
        return self.linear(x)

# Load the learned sparse connections: one linear and one MLP checkpoint per layer index
# 0..num_layers-1, read from sparse_weights/.
# NOTE(review): hard-coded; overwrites the num_layers computed in the SAE cell. It must
# match the number of checkpoint pairs on disk — confirm it agrees with that value.
num_layers = 5
# Renamed from `dir`, which shadowed the builtin dir() for the rest of the session.
sparse_weights_dir = "sparse_weights"
linear_name = "linear_weights_no_bias_{layer}.pt"
mlp_name = "mlp_weights_no_bias_{layer}.pt"
linear_sparse_connections = []
mlp_sparse_connections = []
for layer in range(num_layers):
    # These are whole pickled modules (later cells access `.linear.weight` on them),
    # presumably instances of the classes above; torch.load unpickles them, so only load trusted files.
    linear_sparse_connections.append(torch.load(os.path.join(sparse_weights_dir, linear_name.format(layer=layer))))
    mlp_sparse_connections.append(torch.load(os.path.join(sparse_weights_dir, mlp_name.format(layer=layer))))

In [5]:
# Load the feature indices.
# Note: linear & MLP features are a subset of alive features, not the total feature set.
import pickle

def _load_pickle(path):
    """Return the object unpickled from `path`.

    pickle can execute arbitrary code, so only use this on trusted local files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)

linear_features = _load_pickle("linear_features.pkl")
mlp_features = _load_pickle("mlp_features.pkl")
alive_features_ind = _load_pickle("alive_features.pkl")

In [6]:
threshold = 0.0005
# For each of the first 10 rows of layer 0's sparse linear weight matrix, print how many
# entries are strictly greater than `threshold` (signed comparison, not magnitude).
for i in range(10):
    large_weights = torch.sum(linear_sparse_connections[0].linear.weight[i] > threshold)
    print(f"Large weights: {large_weights}")

Large weights: 5
Large weights: 6
Large weights: 3
Large weights: 6
Large weights: 13
Large weights: 21
Large weights: 3
Large weights: 3
Large weights: 4
Large weights: 3


In [51]:
# Figure out Features: pick restricted output row `output_index_restricted` of layer
# `layer`'s sparse linear map and the two inputs with the largest weights into it.
layer = 1
pair_slice = slice(2 * layer, 2 * layer + 2)  # the two cache names / SAEs for this layer
cache_names_layer = cache_names[pair_slice]
autoencoders_layer = autoencoders[pair_slice]
output_index_restricted = 1
top_weights = linear_sparse_connections[layer].linear.weight.topk(4)
# NOTE(review): these input indices are later used as raw SAE feature indices without any
# remapping, while output_index below is remapped via linear_features — confirm the linear
# map's input dimension indexes all features rather than a restricted subset.
input_index = top_weights.indices[output_index_restricted, :2]
# Map the restricted output index to a full index: the output_index_restricted-th
# True entry of linear_features[layer].
output_index = linear_features[layer].nonzero()[0][output_index_restricted]
features_all = [input_index.tolist(), [output_index]]

# Now view the features: ind_num 0 -> input features with cache_names_layer[0],
# ind_num 1 -> output feature with cache_names_layer[1].
ind_num = 1
features = features_all[ind_num]
cache_name = cache_names_layer[ind_num]
autoencoder = autoencoders_layer[ind_num]
cache_name

'gpt_neox.layers.2.mlp'

In [55]:
from alpha_utils_interp import *
import os
# Make the features/ output directory (no error if it already exists).
save_path = "features/"
os.makedirs(save_path, exist_ok=True)

# Figure out Features (same selection as the previous cell, repeated so this cell runs on its own).
layer = 1
cache_names_layer = cache_names[layer*2:layer*2+2]
autoencoders_layer = autoencoders[layer*2:layer*2+2]
output_index_restricted = 1
# Two largest-weight inputs into restricted output row `output_index_restricted`.
# NOTE(review): used below as raw SAE feature indices without remapping, whereas
# output_index is remapped via linear_features — confirm the linear map's input
# dimension indexes all features rather than a restricted subset.
input_index = linear_sparse_connections[layer].linear.weight.topk(4).indices[output_index_restricted][:2]
# Find the output_index_restricted'th True index of linear_features[layer]
output_index = linear_features[layer].nonzero()[0][output_index_restricted]
features_all = [input_index.tolist(), [output_index]]

# Now view the features: ind_num 0 -> input features with cache_names_layer[0],
# ind_num 1 -> output feature with cache_names_layer[1].
ind_num = 0
features = features_all[ind_num]
cache_name = cache_names_layer[ind_num]
autoencoder = autoencoders_layer[ind_num]

num_feature_datapoints = 10
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32)

num_features = 30  # NOTE(review): not used in this cell
input_setting = "input_only"
model_type = "causal"
ablate_context = True  # NOTE(review): only referenced by the commented-out block in the loop
# A feature with fewer than `dead_threshold` nonzero activations over the dataset is
# skipped as dead; with 1, only features that never fire are skipped.
dead_threshold = 1

for feature in features:
    if dictionary_activations[:, feature].count_nonzero() < dead_threshold:
        print(f"Feature {feature} is dead")
        continue
    print(f"Feature {feature} is alive")
    # Use setting="max" instead of "uniform" to show the top-activating datapoints.
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset)
    if input_setting == "input_only":
        # Calculate logit diffs on this feature for the full_token_list
        logit_diffs = ablate_feature_direction(model, full_token_list, cache_name, max_seq_length, autoencoder, feature = feature, batch_size=32, setting="sentences", model_type=model_type)
        save_token_display(full_token_list, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = logit_diffs, model_type=model_type, show=True)
        all_changed_activations = ablate_context_one_token_at_a_time(model, token_list, cache_name, autoencoder, feature, max_ablation_length=30)
        save_token_display(token_list, all_changed_activations, tokenizer, path =f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True)
    else:
        logit_diffs = ablate_feature_direction(model, dataset, cache_name, max_seq_length, autoencoder, feature = feature, batch_size=32, setting="dataset")
        _, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(uniform_indices, logit_diffs, tokenizer, max_seq_length, dataset)
        get_token_statistics(feature, logit_diffs, dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location = save_path, setting="output", num_unique_tokens=10)
        save_token_display(full_token_list_ablated, full_activations, tokenizer, path =f"{save_path}uniform_{feature}.png", logit_diffs = full_activations_ablated)
    # if(ablate_context ==True):
    #     all_changed_activations = ablate_context_one_token_at_a_time(model, dataset, cache_name, autoencoder, feature, max_ablation_length=20)
    #     save_token_display(token_list, all_changed_activations, tokenizer, path =f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True)
    # combine_images(feature, setting=input_setting, ablate_context=ablate_context)
    # (Removed a trailing `feature += 1`: it had no effect, since `for` rebinds `feature` each iteration.)

100%|██████████| 313/313 [00:10<00:00, 30.60it/s]


Feature 2191 is alive
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 3014 is alive
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               
