In [6]:
import torch
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification


# Download the model. Fall back to CPU when no GPU is visible so the notebook
# still runs end-to-end (just more slowly) on CPU-only machines.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model_name = "BlueSunflower/Pythia-70M-chess"
model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Number of tokens every dataset example is cut to in the data-loading cell.
max_seq_length = 30


Downloading (…)lve/main/config.json: 100%|██████████| 630/630 [00:00<00:00, 2.54MB/s]
Downloading pytorch_model.bin: 100%|██████████| 307M/307M [00:10<00:00, 29.3MB/s] 
Downloading (…)neration_config.json: 100%|██████████| 111/111 [00:00<00:00, 409kB/s]
Downloading (…)okenizer_config.json: 100%|██████████| 284/284 [00:00<00:00, 984kB/s]
Downloading (…)/main/tokenizer.json: 100%|██████████| 2.11M/2.11M [00:00<00:00, 30.8MB/s]
Downloading (…)cial_tokens_map.json: 100%|██████████| 131/131 [00:00<00:00, 481kB/s]


## Download the autoencoders

In [64]:
# One sparse autoencoder per hooked module of the chess model.
ae_model_id = [
    "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.1",
    "jbrinkma/Pythia-70M-chess_sp51_r4_gpt_neox.layers.2.mlp",
]
filename = "sae.pt"
autoencoders = []
for model_id in ae_model_id:
    ae_download_location = hf_hub_download(repo_id=model_id, filename=filename)
    # sae.pt holds a pickled autoencoder *object* (it has `.to_device`), not a
    # plain state dict, so `weights_only=False` is required: torch >= 2.6
    # defaults to True and would refuse to unpickle it. Unpickling can execute
    # arbitrary code -- only load files from repositories you trust.
    # `map_location` lets a checkpoint saved on GPU load on a CPU-only machine.
    autoencoder = torch.load(ae_download_location, map_location=device, weights_only=False)
    autoencoder.to_device(device)
    autoencoders.append(autoencoder)

# Names passed to the analysis helpers as `cache_name`, taken from the last two
# "_"-separated parts of each repo id, e.g. "..._gpt_neox.layers.1" -> "gpt_neox.layers.1".
cache_names = ["_".join(model_id.split("_")[-2:]) for model_id in ae_model_id]

Downloading sae.pt: 100%|██████████| 4.20M/4.20M [00:00<00:00, 7.24MB/s]
Downloading sae.pt: 100%|██████████| 4.20M/4.20M [00:00<00:00, 16.8MB/s]


## Download data

In [8]:
import json
import os
import requests
import tarfile
from pathlib import Path


def download_data(
    data_dir="./data",
    url="https://huggingface.co/datasets/BlueSunflower/chess_games_base/resolve/main/data_stockfish_262k.tar.gz",
    archive_name="data_stockfish_262k.tar.gz",
    timeout=60,
):
    """Download the Stockfish chess-games archive and unpack it into ``data_dir``.

    The archive is only fetched when ``<data_dir>/<archive_name>`` does not
    already exist. It is streamed to a temporary ``.part`` file, extracted, and
    only then renamed to its final name. A failed or interrupted download or
    extraction therefore never leaves behind a file that later runs would
    mistake for a complete, already-extracted archive.

    Args:
        data_dir: Directory to download into and extract to; created if missing.
        url: Location of the ``.tar.gz`` archive.
        archive_name: File name the archive is stored under inside ``data_dir``.
        timeout: Seconds ``requests`` waits on the server before giving up.

    Returns:
        ``pathlib.Path`` of ``data_dir``.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        tarfile.TarError: If the downloaded payload is not a valid tar archive.
    """
    data_path = Path(data_dir)
    data_path.mkdir(parents=True, exist_ok=True)

    archive_path = data_path / archive_name
    if archive_path.is_file():
        return data_path

    partial_path = data_path / (archive_name + ".part")
    try:
        # Stream to disk rather than holding the whole archive in memory.
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTML error page would be saved as the archive.
            response.raise_for_status()
            with open(partial_path, "wb") as out_file:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    out_file.write(chunk)

        with tarfile.open(partial_path) as archive:
            # The "data" filter (Python 3.12, backported to recent 3.8-3.11
            # patch releases) rejects members that would land outside
            # `data_path` (absolute paths, "..", links pointing outside).
            if hasattr(tarfile, "data_filter"):
                archive.extractall(data_path, filter="data")
            else:
                archive.extractall(data_path)

        # Only mark the download as complete once extraction succeeded.
        os.replace(partial_path, archive_path)
    finally:
        if partial_path.exists():
            partial_path.unlink()

    return data_path

# Fetch and unpack the games archive; the trailing ";" keeps the cell from
# echoing whatever download_data() returns.
download_data();

In [26]:
from datasets import load_dataset

# Test games, presumably extracted from the archive downloaded above.
local_path = "data/test_stockfish_5000.json"

# Tokenizer outputs that carry one entry per token and so must be cut to the
# same length as `input_ids` (only the ones actually present are touched).
token_aligned_columns = ("input_ids", "attention_mask", "token_type_ids")

# load_dataset("json", data_files=...) places all rows in a single "train" split.
dataset = (
    load_dataset("json", data_files=local_path, split="train")
    # Tokenize the move strings.
    .map(lambda batch: tokenizer(batch['moves']), batched=True)
    # Keep only games longer than the context window.
    # NOTE(review): strict `>` also drops games of exactly `max_seq_length`
    # tokens; switch to `>=` if those should be kept.
    .filter(lambda example: len(example['input_ids']) > max_seq_length)
    # Truncate every token-aligned column, not just input_ids; otherwise
    # attention_mask keeps its full length and no longer matches input_ids.
    .map(lambda example: {
        column: example[column][:max_seq_length]
        for column in token_aligned_columns
        if column in example
    })
)

Found cached dataset json (/root/.cache/huggingface/datasets/json/default-a6a5482aabd33742/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4)
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-a6a5482aabd33742/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-58fdf3e322afcc8e.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-a6a5482aabd33742/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-172a37628a616d5d.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/json/default-a6a5482aabd33742/0.0.0/e347ab1c932092252e717ff3f949105a4dd28b27e842dd53157d2f72e276c2e4/cache-1aed0094d3128987.arrow


## Visualize the top-activating features

In [81]:
# NOTE(review): star import. The helpers used in this and the next cell
# (get_dictionary_activations, get_feature_indices, get_feature_datapoints,
# ablate_feature_direction, ablate_context_one_token_at_a_time,
# save_token_display, get_token_statistics) presumably come from this local
# module; importing them by name would make that explicit.
from alpha_utils_interp import *
import os

# Directory the per-feature images are written to.
save_path = "features/"
os.makedirs(save_path, exist_ok=True)

# Example datapoints to sample per feature.
num_feature_datapoints = 10
# Leading rows of the activation matrix used for the sparsity estimate.
num_sparsity_rows = 80
# Number of features, ranked by their maximum activation, to inspect below.
num_top_features = 20

ae_index = 0
cache_name = cache_names[ae_index]
autoencoder = autoencoders[ae_index]
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32)

# "Sparsity" is the mean number of non-zero dictionary features per row (an L0
# count, not a fraction), estimated from the first `num_sparsity_rows` rows only.
sparsity = dictionary_activations[:num_sparsity_rows].count_nonzero(dim=1).float().mean()
print(f"Sparsity: {sparsity}")

# Per-feature (column-wise) maximum activation; keep the strongest features.
max_values = dictionary_activations.max(dim=0)
feature_ind = max_values.values.topk(num_top_features).indices

100%|██████████| 157/157 [00:05<00:00, 31.03it/s]


Sparsity: 60.412498474121094


In [82]:
# "input_only": ablate each feature on its sampled sentences only; any other
# value runs the ablation over the whole `dataset`.
input_setting = "input_only"
model_type = "causal"
features = feature_ind.tolist()
num_features = 10  # NOTE(review): not referenced in this cell -- confirm it can be dropped
# Features with fewer than this many non-zero activations are reported as dead and skipped.
dead_threshold = 10

for feature in features:
    # This feature's activation on every row of `dictionary_activations`.
    feature_activations = dictionary_activations[:, feature]
    if feature_activations.count_nonzero() < dead_threshold:
        print(f"Feature {feature} is dead")
        continue

    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, feature_activations, tokenizer, max_seq_length, dataset)
    # Optional, currently disabled: token statistics for this feature.
    # get_token_statistics(feature, feature_activations, dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location=save_path, num_unique_tokens=10)

    if input_setting == "input_only":
        # Logit diffs from ablating this feature on the sampled sentences.
        logit_diffs = ablate_feature_direction(model, full_token_list, cache_name, max_seq_length, autoencoder, feature=feature, batch_size=32, setting="sentences", model_type=model_type)
        save_token_display(full_token_list, full_activations, tokenizer, path=f"{save_path}uniform_{feature}.png", logit_diffs=logit_diffs, model_type=model_type, show=True)
        # Ablate context tokens one at a time, up to the full context length.
        all_changed_activations = ablate_context_one_token_at_a_time(model, token_list, cache_name, autoencoder, feature, max_ablation_length=max_seq_length)
        save_token_display(token_list, all_changed_activations, tokenizer, path=f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True)
    else:
        # Pass model_type explicitly, exactly as the branch above does, rather
        # than silently relying on the helpers' defaults.
        logit_diffs = ablate_feature_direction(model, dataset, cache_name, max_seq_length, autoencoder, feature=feature, batch_size=32, setting="dataset", model_type=model_type)
        _, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(uniform_indices, logit_diffs, tokenizer, max_seq_length, dataset)
        get_token_statistics(feature, logit_diffs, dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location=save_path, setting="output", num_unique_tokens=10)
        save_token_display(full_token_list_ablated, full_activations, tokenizer, path=f"{save_path}uniform_{feature}.png", logit_diffs=full_activations_ablated, model_type=model_type)

Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


## Find related features
### Gradient-based method

In [None]:
# Gradient-based search for related features -- NOT IMPLEMENTED YET.
#
# Plan:
#   1. Grab a datapoint.
#   2. Run it through the model with both autoencoders intervened, keeping the
#      upstream autoencoder's dictionary activations in the autograd graph.
#   3. Differentiate a downstream dictionary feature w.r.t. those upstream
#      activations (helper below); large-magnitude entries point to candidate
#      related upstream features.
#
# The sketch that used to live here could not run on a fresh kernel. It passed
# the builtin `input` to `model`, and it used `criterion`, `target`,
# `optimizer`, `param_N` and `param_N_minus_1`, none of which are defined in the
# cells above. It also asked for the gradient of one parameter w.r.t. another,
# which autograd cannot give: parameters are graph leaves, not functions of
# each other. Feature-to-feature attribution needs no loss, optimizer,
# zero_grad() or backward() call.
import torch


def downstream_feature_gradient(downstream_activation, upstream_activations):
    """Gradient of a downstream activation w.r.t. upstream activations.

    Args:
        downstream_activation: Tensor computed, with autograd enabled, from
            ``upstream_activations``. A non-scalar tensor is summed first.
        upstream_activations: Tensor that is part of the same autograd graph,
            either requiring grad itself or derived from something that does.

    Returns:
        Tensor shaped like ``upstream_activations``. It is all zeros when
        ``downstream_activation`` does not depend on it.

    The graph is retained, so several downstream features can be
    differentiated from one forward pass.
    """
    (grad,) = torch.autograd.grad(
        downstream_activation.sum(),
        upstream_activations,
        retain_graph=True,
        allow_unused=True,  # unrelated upstream tensor -> None instead of an error
    )
    if grad is None:
        return torch.zeros_like(upstream_activations)
    return grad


### Correlational -> Causal