In [3]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification
from baukit import TraceDict


# Download the model
device = "cuda:0"
model_name="EleutherAI/Pythia-70M-deduped"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
from autoencoders import *
# ae_model_id = ["jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1", "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp"]
model_id = "jbrinkma/Pythia-70M-deduped-SAEs"
autoencoders = []
layers = model.config.num_hidden_layers
cache_names = [(f"gpt_neox.layers.{i}", f"gpt_neox.layers.{i+1}.mlp") for i in range(layers-1)]
num_layers = len(cache_names)
cache_names = [item for sublist in cache_names for item in sublist]
filenames = [(f"Pythia-70M-deduped-{i}.pt", f"Pythia-70M-deduped-mlp-{i+1}.pt") for i in range(layers-1)]
filenames = [item for sublist in filenames for item in sublist]
for filen in filenames:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filen)
    autoencoder = torch.load(ae_download_location)
    autoencoder.to_device(device)
    # Freeze autoencoder weights
    autoencoder.encoder.requires_grad_(False)
    autoencoder.encoder_bias.requires_grad_(False)
    autoencoders.append(autoencoder)

In [5]:
from activation_dataset import chunk_and_tokenize

# Download the dataset
# This formats it, so every datapoint is max_length tokens long
dataset_name="stas/openwebtext-10k"
max_seq_length=32
dataset = load_dataset(dataset_name, split="train[:300]")
dataset, _ = chunk_and_tokenize(dataset, tokenizer, max_length=max_seq_length)
max_tokens = dataset.num_rows*max_seq_length
print(f"Number of tokens: {max_tokens/1e6:.2f}M")

Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-ffa9653217360dc9_*_of_00008.arrow


Number of tokens: 0.33M


In [27]:
import torch as torch
import os
from torch import nn
    
class mlp_no_bias(nn.Module):
    def __init__(self, input_size, output_size=None):
        super().__init__()
        if(output_size is None):
            output_size = input_size
        self.linear = nn.Linear(input_size, input_size, bias=False)
        self.linear2 = nn.Linear(input_size, output_size, bias=False)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x
    def hidden_forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        return x
    
class linear_no_bias(nn.Module):
    def __init__(self, input_size, output_size=None):
        super().__init__()
        if(output_size is None):
            output_size = input_size
        self.linear = nn.Linear(input_size, output_size, bias=False)
    def forward(self, x):
        x = self.linear(x)
        return x
    
class sae(nn.Module):
    def __init__(self, input_size, output_size=None):
        super().__init__()
        if(output_size is None):
            output_size = input_size
        self.linear = nn.Linear(input_size, input_size, bias=False)
        self.linear2 = nn.Linear(input_size, output_size, bias=False)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = self.linear(x)
        x_hidden = self.relu(x)
        x = self.linear2(x_hidden)
        return x, x_hidden

num_layers = 5
dir = "sparse_weights"
linear_name = "linear_weights_no_bias_{layer}.pt"
mlp_name = "mlp_weights_no_bias_{layer}.pt"
sae_name = "Pythia-70M-deduped-mlp-{layer}-l1_alpha-0.0004.pt"
linear_sparse_connections = []
mlp_sparse_connections = []
sae_sparse_connections = []
for layer in range(num_layers):
    linear_sparse_connections.append(torch.load(os.path.join(dir, linear_name.format(layer=layer))))
    mlp_sparse_connections.append(torch.load(os.path.join(dir, mlp_name.format(layer=layer))))
    sae_sparse_connections.append(torch.load(os.path.join(dir, sae_name.format(layer=layer+1))))

In [7]:
# Load in the features indices
# Note: linear & MLP features are a subet of alive features, not total features
import pickle
with open("linear_features.pkl", "rb") as f:
    linear_features = pickle.load(f)
with open("mlp_features.pkl", "rb") as f:
    mlp_features = pickle.load(f)
with open("alive_features.pkl", "rb") as f:
    alive_features_ind = pickle.load(f)

In [28]:
from baukit import Trace
from alpha_utils_interp import *

layer = 3
batch_size = 32
ind_num = 0
cache_names_layer = cache_names[layer*2:layer*2+2]
autoencoders_layer = autoencoders[layer*2:layer*2+2]
cache_name = cache_names_layer[ind_num]
autoencoder = autoencoders_layer[ind_num]
sae = sae_sparse_connections[layer].to(device)
mlp = mlp_sparse_connections[layer].to(device)

device = model.device
num_features, d_model = autoencoder.encoder.shape
datapoints = dataset.num_rows
dictionary_activations = torch.zeros((datapoints*max_seq_length, num_features))
token_list = torch.zeros((datapoints*max_seq_length), dtype=torch.int64)
with torch.no_grad(), dataset.formatted_as("pt"):
    dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
    for i, batch in enumerate(tqdm(dl)):
        batch = batch.to(device)
        token_list[i*batch_size*max_seq_length:(i+1)*batch_size*max_seq_length] = rearrange(batch, "b s -> (b s)")
        with Trace(model, cache_name) as ret:
            _ = model(batch).logits
            internal_activations = ret.output
            if(isinstance(internal_activations, tuple)):
                internal_activations = internal_activations[0]
        batched_neuron_activations = rearrange(internal_activations, "b s n -> (b s) n" )
        batched_dictionary_activations = autoencoder.encode(batched_neuron_activations)
        # _, batched_dictionary_activations = sae(batched_dictionary_activations)
        batched_dictionary_activations = mlp.hidden_forward(batched_dictionary_activations)
        dictionary_activations[i*batch_size*max_seq_length:(i+1)*batch_size*max_seq_length,:] = batched_dictionary_activations.cpu()

100%|██████████| 320/320 [00:11<00:00, 27.08it/s]


In [29]:
max_activating_features = dictionary_activations[:10000, :].max(0).values
alive_features = max_activating_features.count_nonzero()
print(f"Number of alive features: {alive_features}")
features = max_activating_features.topk(20).indices.tolist()

Number of alive features: 2775


In [30]:
num_feature_datapoints = 5
# num_features = 30
# shifted = 0
# features = [i+shifted for i in range(num_features)]
model_type="causal"

for feature in features:
    # Check if feature is dead (<10 activations)
    dead_threshold = 1
    # if(dictionary_activations[:, current_feature].count_nonzero() < dead_threshold):
    if(dictionary_activations[:, feature].count_nonzero() < dead_threshold):
        print(f"Feature {feature} is dead")
        # Go to next feature
        continue
    print(f"Feature {feature} is alive")
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset)
    save_token_display(full_token_list, full_activations, tokenizer, path =f"n/a", logit_diffs = None, save=False, model_type=model_type, show=True)

Feature 1422 is alive


  full_tok = torch.tensor(dataset[md]["input_ids"])


Feature 1431 is alive


Feature 2821 is alive


Feature 27 is alive


Feature 2461 is alive


Feature 2107 is alive


Feature 2935 is alive


Feature 2203 is alive


Feature 494 is alive


Feature 816 is alive


Feature 740 is alive


Feature 1302 is alive


Feature 516 is alive


Feature 2640 is alive


Feature 172 is alive


Feature 1995 is alive


Feature 1723 is alive


Feature 1299 is alive


Feature 630 is alive


Feature 1308 is alive


In [31]:
import os
# make features/ dir if not exist
save_path = "features/"
if not os.path.exists(save_path):
    os.makedirs(save_path)

####################################################################################################################################################################
# Other way around: restricted linear indices -> alive indices -> original indices
def linear_to_original(ind, layer, linear_or_mlp_features):
    linear_to_alive_map = linear_or_mlp_features[layer].nonzero()[0]
    alive_to_original_map = alive_features_ind[layer].nonzero()[:,0]
    return alive_to_original_map[linear_to_alive_map[ind]]
# Figure out Features
layer = 0

# setting = "linear"
setting = "sae"
k = 3
cache_names_layer = cache_names[layer*2:layer*2+2]
autoencoders_layer = autoencoders[layer*2:layer*2+2]
# Index for output (mlp_out)
# linear: 359, 448, 404, 472, 377
output_index_restricted = 8
if(setting == "linear"):
    output_index = linear_to_original(output_index_restricted, layer, linear_features)
    input_weights, input_index = linear_sparse_connections[layer].linear.weight.topk(k)
if(setting == "mlp"):    
    output_index = linear_to_original(output_index_restricted, layer, mlp_features)
    input_weights, input_index = mlp_sparse_connections[layer].linear2.weight.topk(k)
if(setting == "sae"):
    output_index = linear_to_original(output_index_restricted, layer, mlp_features)
    input_weights, input_index = sae_sparse_connections[layer].linear2.weight.topk(k)
    # input_index, input_weights = mlp_sparse_connections[layer].linear.weight.topk(k)

input_index = input_index[output_index_restricted]
input_weights = input_weights[output_index_restricted]
# output_index = mlp_features[layer].nonzero()[0][output_index_restricted]
# output_index = 835
# features_all = [input_index.tolist(), [output_index]]
features_all = [[], indices]

# Now view the features
# 0 is input feature (ie {} residual stream)
# 1 is output feature (ie mlp_out)
print(f"Weights:{input_weights}")
ind_num = 0

features = features_all[ind_num]
cache_name = cache_names_layer[ind_num]
autoencoder = autoencoders_layer[ind_num]

dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32)

num_feature_datapoints = 5
num_features = 30
model_type="causal"

for feature in features:
    # Check if feature is dead (<10 activations)
    dead_threshold = 1
    # if(dictionary_activations[:, current_feature].count_nonzero() < dead_threshold):
    if(dictionary_activations[:, feature].count_nonzero() < dead_threshold):
        print(f"Feature {feature} is dead")
        # Go to next feature
        continue
    print(f"Feature {feature} is alive")
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset)
    save_token_display(full_token_list, full_activations, tokenizer, path =f"n/a", logit_diffs = None, save=False, model_type=model_type, show=True)

[2507, 1910, 1404, 2227, 2960]
Weights:tensor([0.3598, 0.3579, 0.3552], grad_fn=<SelectBackward0>)
Output feature


100%|██████████| 320/320 [00:10<00:00, 29.33it/s]


Feature 2507 is alive


  full_tok = torch.tensor(dataset[md]["input_ids"])


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 1910 is alive


  full_tok = torch.tensor(dataset[md]["input_ids"])


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 1404 is alive


  full_tok = torch.tensor(dataset[md]["input_ids"])


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 2227 is alive


  full_tok = torch.tensor(dataset[md]["input_ids"])


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 2960 is alive


  full_tok = torch.tensor(dataset[md]["input_ids"])


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Input feature


100%|██████████| 320/320 [00:11<00:00, 28.68it/s]
