In [1]:
import torch
from torch import nn
from torchtyping import TensorType

class TiedSAE(nn.Module):
    """Tied-weight sparse autoencoder over model activations.

    A single parameter matrix ``encoder`` of shape
    (n_dict_components, activation_size) is used both to encode (raw rows plus a
    bias, then ReLU) and to decode (rows rescaled to unit L2 norm).
    """

    def __init__(self, activation_size, n_dict_components):
        super().__init__()
        # One row per dictionary feature, Xavier-uniform initialised.
        self.encoder = nn.Parameter(torch.empty((n_dict_components, activation_size)))
        nn.init.xavier_uniform_(self.encoder)
        self.encoder_bias = nn.Parameter(torch.zeros((n_dict_components,)))

    def get_learned_dict(self):
        """Return the dictionary: ``encoder`` rows divided by their L2 norm.

        Norms are clamped to at least 1e-8 so an all-zero row yields zeros
        instead of NaNs.
        """
        norms = torch.norm(self.encoder, 2, dim=-1)
        return self.encoder / norms.clamp(min=1e-8)[:, None]

    def encode(self, batch: "TensorType['_batch_size', '_activation_size']") -> "TensorType['_batch_size', '_n_dict_components']":
        """Sparse code ``relu(batch @ encoder.T + encoder_bias)``.

        Note that the *unnormalised* encoder rows are used here, whereas
        ``decode`` uses the row-normalised dictionary.
        """
        pre_activation = torch.einsum("nd,bd->bn", self.encoder, batch) + self.encoder_bias
        return torch.clamp(pre_activation, min=0.0)

    def decode(self, code: "TensorType['_batch_size', '_n_dict_components']") -> "TensorType['_batch_size', '_activation_size']":
        """Reconstruct activations as ``code @ get_learned_dict()``."""
        learned_dict = self.get_learned_dict()
        return torch.einsum("nd,bn->bd", learned_dict, code)

    def forward(self, batch: "TensorType['_batch_size', '_activation_size']") -> "tuple[TensorType['_batch_size', '_activation_size'], TensorType['_batch_size', '_n_dict_components']]":
        """Encode then decode ``batch``.

        Returns:
            A 2-tuple ``(x_hat, c)``: the reconstruction, shape
            (batch, activation_size), and the sparse code, shape
            (batch, n_dict_components).
        """
        c = self.encode(batch)
        x_hat = self.decode(c)
        return x_hat, c

    def n_dict_components(self):
        """Number of dictionary features (rows of the learned dictionary)."""
        return self.get_learned_dict().shape[0]

In [3]:
import torch
from transformer_lens import HookedTransformer
import numpy as np
from torch.utils.data import DataLoader
from datasets import Dataset, load_dataset
from tqdm.auto import tqdm
from functools import partial
from einops import rearrange
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForSequenceClassification

# --- Experiment configuration ---
# Transformer block to study; reused in the SAE filename and `cache_name` below.
layer = 9
# dataset_name = "NeelNanda/pile-10k"
# dataset_name = "BlueSunflower/chess_games_base"
dataset_name = "stas/openwebtext-10k"
device = "cuda:0"
# model_name = "EleutherAI/pythia-70m-deduped"
# model_name="usvsnsp/pythia-6.9b-rm-full-hh-rlhf"
# model_name = "BlueSunflower/Pythia-160M-chess"
model_name="reciprocate/dahoas-gptj-rm-static"

# Number of tokens kept per datapoint when the dataset is tokenised.
max_seq_length=60
# "input_only" or "both"; compared in the feature-visualisation loop
# (NOTE(review): check that no later cell re-assigns it).
input_setting = "input_only"
# input_setting = "both"
# "reward_model" loads a sequence-classification head below; anything else loads a causal LM.
model_type="reward_model"
# model_type="causal"
# Intended to toggle the context-ablation plots.
# NOTE(review): confirm every downstream use actually checks this flag.
ablate_context = True

# Hugging Face repo + file holding the pretrained sparse autoencoder for `layer`.
# model_id = "Elriggs/pythia-70m-deduped"
# model_id = "Elriggs/pythia-6.9-rm"
model_id = "Elriggs/dahoas-gptj-rm-sae"
# model_id = "jbrinkma/Pythia-160M-chess-sp76-r4-gpt-neox-layers-4"
filename = f"dahoas-gptj-rm-static_r4_transformer.h.{layer}.pt"
# filename = f"tied_residual_l{layer}_r6/_63/learned_dicts.pt" 
ae_download_location = hf_hub_download(repo_id=model_id, filename=filename)
# all_autoencoders = torch.load(ae_download_location)
# num_l1s = len(all_autoencoders)
# all_l1s = [hyperparams["l1_alpha"] for autoencoder, hyperparams in all_autoencoders]
# print(all_l1s)
# auto_num = 5
# autoencoder, hyperparams = all_autoencoders[auto_num]
# SECURITY: torch.load unpickles arbitrary Python objects -- only load checkpoints
# from repositories you trust. The result is used directly as an autoencoder
# object (not a state_dict), so its class must be importable in this kernel.
autoencoder = torch.load(ae_download_location)
# autoencoder.to_device(device)
# NOTE(review): the return value is discarded, so this relies on .to() moving the
# autoencoder in place -- confirm for the autoencoder's class.
autoencoder.to(device)
# cache_name = f"gpt_neox.layers.{layer}"
# Module name of block `layer` in the HF GPT-J model; used as the hook point.
cache_name=f"transformer.h.{layer}"


# Reward models need the sequence-classification head; causal LMs the LM head.
if(model_type == "reward_model"):
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(device)
else:
    model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_name)

Downloading (…)4_transformer.h.9.pt: 100%|██████████| 269M/269M [00:05<00:00, 48.7MB/s] 
Loading checkpoint shards: 100%|██████████| 3/3 [00:25<00:00,  8.57s/it]
Some weights of the model checkpoint at reciprocate/dahoas-gptj-rm-static were not used when initializing GPTJForSequenceClassification: ['transformer.h.1.attn.masked_bias', 'transformer.h.21.attn.bias', 'transformer.h.19.attn.bias', 'transformer.h.6.attn.bias', 'transformer.h.17.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.26.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.22.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.23.attn.bias', 'transformer.h.10.attn.bias', 'transformer.h.23.attn.masked_bias', 'transformer.h.7.attn.masked_bias', 'transformer.h.9.attn.bias', 'transformer.h.4.attn.bias', 'transformer.h.8.attn.masked_bias', 'transformer.h.20.attn.masked_bias', 'transformer.h.3.attn.bias', 'transformer.h.1

In [None]:
# import json
# from datasets import Dataset, load_from_disk

# # Create a Hugging Face dataset
# dataset = Dataset.from_json("test_stockfish_5000.json").map(
#     lambda x: tokenizer(x["moves"]),
#     batched=True,
# ).filter(
#     lambda x: len(x['input_ids']) > max_seq_length
# ).map(
#     lambda x: {'input_ids': x['input_ids'][:max_seq_length]}
# )

In [4]:
def download_dataset(dataset_name, tokenizer, max_length=256, num_datapoints=None):
    """Load a HF dataset's ``train`` split, tokenize it and keep fixed-length prefixes.

    Args:
        dataset_name: Hugging Face hub dataset id; must have a ``text`` column.
        tokenizer: tokenizer applied (batched) to the ``text`` column.
        max_length: tokens kept per datapoint. Texts shorter than this are
            dropped, so every remaining row has exactly ``max_length`` tokens.
        num_datapoints: if truthy, only the first ``num_datapoints`` rows of the
            split are loaded; None (or 0) loads the whole split.

    Returns:
        A ``datasets.Dataset`` whose per-token columns (``input_ids`` and, when
        the tokenizer produces them, ``attention_mask`` / ``token_type_ids``)
        are all truncated to ``max_length``.
    """
    split_text = f"train[:{num_datapoints}]" if num_datapoints else "train"

    def _truncate(example):
        # Truncate every per-token column together so they stay aligned;
        # previously only input_ids was cut and attention_mask kept full length.
        return {
            key: example[key][:max_length]
            for key in ("input_ids", "attention_mask", "token_type_ids")
            if key in example
        }

    dataset = load_dataset(dataset_name, split=split_text).map(
        lambda batch: tokenizer(batch["text"]),
        batched=True,
    ).filter(
        # >= : a text of exactly max_length tokens already fills the window.
        lambda example: len(example["input_ids"]) >= max_length
    ).map(_truncate)
    return dataset

print("Downloading", dataset_name)
# num_datapoints=None -> load the whole train split.
dataset = download_dataset(
    dataset_name,
    tokenizer,
    max_seq_length,
    None,
)

Downloading stas/openwebtext-10k


Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-73ad36b49865c8de.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-674bb14400f25eb6.arrow
Loading cached processed dataset at /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b/cache-a9d221474ae6ef99.arrow


In [5]:
from alpha_utils_interp import *
import os
# Output directory for the rendered feature images.
save_path = "features/"
os.makedirs(save_path, exist_ok=True)
num_feature_datapoints = 10  # example datapoints rendered per feature
# Dictionary activations for every token of `dataset`; one column per dictionary
# feature (presumably (num_rows * max_seq_length, n_dict_components) -- confirm).
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, cache_name, max_seq_length, autoencoder, batch_size=32)

num_features = 30    # number of alive features to visualise
feature = 0          # first feature index to try
dead_threshold = 10  # features with fewer nonzero activations than this are treated as dead
total_features = dictionary_activations.shape[1]
# `input_setting`, `model_type` and `ablate_context` come from the config cell;
# they are deliberately not re-assigned here.
for _ in range(num_features):
    # Skip dead features, without running past the last dictionary column.
    while feature < total_features and dictionary_activations[:, feature].count_nonzero() < dead_threshold:
        print(f"Feature {feature} is dead")
        feature += 1
    if feature >= total_features:
        print("No more alive features to display")
        break
    uniform_indices = get_feature_indices(feature, dictionary_activations, k=num_feature_datapoints, setting="uniform")
    text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], tokenizer, max_seq_length, dataset)
    if input_setting == "input_only":
        # Effect of ablating this feature on the model output for the example sentences.
        logit_diffs = ablate_feature_direction(model, full_token_list, cache_name, max_seq_length, autoencoder, feature=feature, batch_size=32, setting="sentences", model_type=model_type)
        save_token_display(full_token_list, full_activations, tokenizer, path=f"{save_path}uniform_{feature}.png", logit_diffs=logit_diffs, model_type=model_type, show=True)
        if ablate_context:
            # How the feature's activation changes as context tokens are ablated one at a time.
            all_changed_activations = ablate_context_one_token_at_a_time(model, token_list, cache_name, autoencoder, feature, max_ablation_length=30)
            save_token_display(token_list, all_changed_activations, tokenizer, path=f"{save_path}ablate_context_{feature}.png", model_type=model_type, show=True)
    else:
        # Pass model_type here too, matching the "input_only" branch.
        logit_diffs = ablate_feature_direction(model, dataset, cache_name, max_seq_length, autoencoder, feature=feature, batch_size=32, setting="dataset", model_type=model_type)
        _, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(uniform_indices, logit_diffs, tokenizer, max_seq_length, dataset)
        get_token_statistics(feature, logit_diffs, dataset, tokenizer, max_seq_length, tokens_for_each_datapoint, save_location=save_path, setting="output", num_unique_tokens=10)
        save_token_display(full_token_list_ablated, full_activations, tokenizer, path=f"{save_path}uniform_{feature}.png", logit_diffs=full_activations_ablated, model_type=model_type)
    feature += 1

100%|██████████| 313/313 [08:25<00:00,  1.62s/it]


Feature 0 is dead
Feature 1 is dead
Feature 2 is dead
Feature 3 is dead
Feature 4 is dead
Feature 5 is dead
Feature 6 is dead
Feature 7 is dead
Feature 8 is dead
Feature 9 is dead
Feature 10 is dead
Feature 11 is dead


  bins = torch.bucketize(best_feature_activations, bin_boundaries)


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 13 is dead
Feature 14 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 16 is dead
Feature 17 is dead
Feature 18 is dead
Feature 19 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 21 is dead
Feature 22 is dead
Feature 23 is dead
Feature 24 is dead
Feature 25 is dead
Feature 26 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 28 is dead
Feature 29 is dead
Feature 30 is dead
Feature 31 is dead
Feature 32 is dead
Feature 33 is dead
Feature 34 is dead
Feature 35 is dead
Feature 36 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 38 is dead
Feature 39 is dead
Feature 40 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 42 is dead
Feature 43 is dead
Feature 44 is dead
Feature 45 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 47 is dead
Feature 48 is dead
Feature 49 is dead
Feature 50 is dead
Feature 51 is dead
Feature 52 is dead
Feature 53 is dead
Feature 54 is dead
Feature 55 is dead
Feature 56 is dead
Feature 57 is dead
Feature 58 is dead
Feature 59 is dead
Feature 60 is dead
Feature 61 is dead
Feature 62 is dead
Feature 63 is dead
Feature 64 is dead
Feature 65 is dead
Feature 66 is dead
Feature 67 is dead
Feature 68 is dead
Feature 69 is dead
Feature 70 is dead
Feature 71 is dead
Feature 72 is dead
Feature 73 is dead
Feature 74 is dead
Feature 75 is dead
Feature 76 is dead
Feature 77 is dead
Feature 78 is dead
Feature 79 is dead
Feature 80 is dead
Feature 81 is dead
Feature 82 is dead
Feature 83 is dead
Feature 84 is dead
Feature 85 is dead
Feature 86 is dead
Feature 87 is dead
Feature 88 is dead
Feature 89 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 91 is dead
Feature 92 is dead
Feature 93 is dead
Feature 94 is dead
Feature 95 is dead
Feature 96 is dead
Feature 97 is dead
Feature 98 is dead
Feature 99 is dead
Feature 100 is dead
Feature 101 is dead
Feature 102 is dead
Feature 103 is dead
Feature 104 is dead
Feature 105 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 107 is dead
Feature 108 is dead
Feature 109 is dead
Feature 110 is dead
Feature 111 is dead
Feature 112 is dead
Feature 113 is dead
Feature 114 is dead
Feature 115 is dead
Feature 116 is dead
Feature 117 is dead
Feature 118 is dead
Feature 119 is dead
Feature 120 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 122 is dead
Feature 123 is dead
Feature 124 is dead
Feature 125 is dead
Feature 126 is dead
Feature 127 is dead
Feature 128 is dead
Feature 129 is dead
Feature 130 is dead
Feature 131 is dead
Feature 132 is dead
Feature 133 is dead
Feature 134 is dead
Feature 135 is dead
Feature 136 is dead
Feature 137 is dead
Feature 138 is dead
Feature 139 is dead
Feature 140 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 143 is dead
Feature 144 is dead
Feature 145 is dead
Feature 146 is dead
Feature 147 is dead
Feature 148 is dead
Feature 149 is dead
Feature 150 is dead
Feature 151 is dead
Feature 152 is dead
Feature 153 is dead
Feature 154 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 156 is dead
Feature 157 is dead
Feature 158 is dead
Feature 159 is dead
Feature 160 is dead
Feature 161 is dead
Feature 162 is dead
Feature 163 is dead
Feature 164 is dead
Feature 165 is dead
Feature 166 is dead
Feature 167 is dead
Feature 168 is dead
Feature 169 is dead
Feature 170 is dead
Feature 171 is dead
Feature 172 is dead
Feature 173 is dead
Feature 174 is dead
Feature 175 is dead
Feature 176 is dead
Feature 177 is dead
Feature 178 is dead
Feature 179 is dead
Feature 180 is dead
Feature 181 is dead
Feature 182 is dead
Feature 183 is dead
Feature 184 is dead
Feature 185 is dead
Feature 186 is dead
Feature 187 is dead
Feature 188 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 190 is dead
Feature 191 is dead
Feature 192 is dead
Feature 193 is dead
Feature 194 is dead
Feature 195 is dead
Feature 196 is dead
Feature 197 is dead
Feature 198 is dead
Feature 199 is dead
Feature 200 is dead
Feature 201 is dead
Feature 202 is dead
Feature 203 is dead
Feature 204 is dead
Feature 205 is dead
Feature 206 is dead
Feature 207 is dead
Feature 208 is dead
Feature 209 is dead
Feature 210 is dead
Feature 211 is dead
Feature 212 is dead
Feature 213 is dead
Feature 214 is dead
Feature 215 is dead
Feature 216 is dead
Feature 217 is dead
Feature 218 is dead
Feature 219 is dead
Feature 220 is dead
Feature 221 is dead
Feature 222 is dead
Feature 223 is dead
Feature 224 is dead
Feature 225 is dead
Feature 226 is dead
Feature 227 is dead
Feature 228 is dead
Feature 229 is dead
Feature 230 is dead
Feature 231 is dead
Feature 232 is dead
Feature 233 is dead
Feature 234 is dead
Feature 235 is dead
Feature 236 is dead
Feature 237 is dead
Feature 238 is dead
Feature 239 is dead


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 259 is dead
Feature 260 is dead
Feature 261 is dead
Feature 262 is dead
Feature 263 is dead
Feature 264 is dead
Feature 265 is dead
Feature 266 is dead
Feature 267 is dead
Feature 268 is dead
Feature 269 is dead
Feature 270 is dead
Feature 271 is dead
Feature 272 is dead
Feature 273 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 275 is dead
Feature 276 is dead
Feature 277 is dead
Feature 278 is dead
Feature 279 is dead
Feature 280 is dead
Feature 281 is dead
Feature 282 is dead
Feature 283 is dead
Feature 284 is dead
Feature 285 is dead
Feature 286 is dead
Feature 287 is dead
Feature 288 is dead
Feature 289 is dead
Feature 290 is dead
Feature 291 is dead
Feature 292 is dead
Feature 293 is dead
Feature 294 is dead
Feature 295 is dead
Feature 296 is dead
Feature 297 is dead
Feature 298 is dead
Feature 299 is dead
Feature 300 is dead
Feature 301 is dead
Feature 302 is dead
Feature 303 is dead
Feature 304 is dead
Feature 305 is dead
Feature 306 is dead
Feature 307 is dead
Feature 308 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 310 is dead
Feature 311 is dead
Feature 312 is dead
Feature 313 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 315 is dead
Feature 316 is dead
Feature 317 is dead
Feature 318 is dead
Feature 319 is dead
Feature 320 is dead
Feature 321 is dead
Feature 322 is dead
Feature 323 is dead
Feature 324 is dead
Feature 325 is dead
Feature 326 is dead
Feature 327 is dead
Feature 328 is dead
Feature 329 is dead
Feature 330 is dead
Feature 331 is dead
Feature 332 is dead
Feature 333 is dead
Feature 334 is dead
Feature 335 is dead
Feature 336 is dead
Feature 337 is dead
Feature 338 is dead
Feature 339 is dead
Feature 340 is dead
Feature 341 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 343 is dead
Feature 344 is dead
Feature 345 is dead
Feature 346 is dead
Feature 347 is dead
Feature 348 is dead
Feature 349 is dead
Feature 350 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 352 is dead
Feature 353 is dead
Feature 354 is dead
Feature 355 is dead
Feature 356 is dead
Feature 357 is dead
Feature 358 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 360 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 362 is dead
Feature 363 is dead
Feature 364 is dead
Feature 365 is dead
Feature 366 is dead
Feature 367 is dead
Feature 368 is dead
Feature 369 is dead
Feature 370 is dead
Feature 371 is dead
Feature 372 is dead
Feature 373 is dead
Feature 374 is dead
Feature 375 is dead
Feature 376 is dead
Feature 377 is dead
Feature 378 is dead
Feature 379 is dead
Feature 380 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 382 is dead
Feature 383 is dead
Feature 384 is dead
Feature 385 is dead
Feature 386 is dead
Feature 387 is dead
Feature 388 is dead
Feature 389 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 392 is dead
Feature 393 is dead
Feature 394 is dead
Feature 395 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 397 is dead
Feature 398 is dead
Feature 399 is dead
Feature 400 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 402 is dead
Feature 403 is dead
Feature 404 is dead
Feature 405 is dead
Feature 406 is dead
Feature 407 is dead
Feature 408 is dead
Feature 409 is dead
Feature 410 is dead
Feature 411 is dead
Feature 412 is dead
Feature 413 is dead
Feature 414 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 416 is dead
Feature 417 is dead
Feature 418 is dead
Feature 419 is dead
Feature 420 is dead
Feature 421 is dead
Feature 422 is dead
Feature 423 is dead
Feature 424 is dead
Feature 425 is dead
Feature 426 is dead
Feature 427 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Feature 429 is dead
Feature 430 is dead
Feature 431 is dead
Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


Loading page (1/2)
Rendering (2/2)                                                    
Done                                                               


In [7]:
# Quick sparsity / liveness summary of the dictionary activations.
num_sparsity_rows = 100      # token rows used for the L0 estimate
num_alive_rows = 100_000     # token rows scanned for alive features
# Mean number of nonzero features per token (L0) over the first rows.
sparsity = dictionary_activations[:num_sparsity_rows].count_nonzero(dim=1).float().mean().item()
print(f"Sparsity: {sparsity}")
# A feature counts as alive if its summed activation over the sample is nonzero.
# This equals "fired at least once" only if activations are non-negative
# (assumed: ReLU codes) -- confirm for the loaded autoencoder.
alive_features = dictionary_activations[:num_alive_rows].sum(0).count_nonzero().item()
print(f"Alive features (first {num_alive_rows} rows): {alive_features}")

Sparsity: 9.130000114440918
Alive features: 1315


In [36]:
# Number of features with a nonzero total activation over every row.
int(dictionary_activations.sum(dim=0).count_nonzero())

7250

In [35]:
dictionary_activations.size()

torch.Size([600000, 16384])

In [None]:
from collections import defaultdict

# Inline copy of the first half of get_token_statistics for one feature (debugging).
# NOTE(review): `feature` is left over from the feature loop above (one past the
# last feature displayed) and may be dead, which trips the assert at the end --
# set it explicitly before running this cell.
setting = "input"
feature_activation = dictionary_activations[:, feature]
save_location = ""
num_unique_tokens = 10
negative_threshold = -0.01

# "input": every position where the feature is nonzero; otherwise: positions whose
# value is below negative_threshold.
if setting == "input":
    nonzero_indices = feature_activation.nonzero()[:, 0]
else:
    nonzero_indices = (feature_activation < negative_threshold).nonzero()[:, 0]
nonzero_values = feature_activation[nonzero_indices].abs()  # magnitudes at those positions

# Flat row index -> (datapoint, position), assuming rows are laid out row-major as
# (dataset.num_rows, max_seq_length) -- the same layout np.unravel_index assumed.
datapoint_indices = [divmod(int(i), max_seq_length) for i in nonzero_indices]
all_tokens = [dataset[md]["input_ids"][s_ind] for md, s_ind in datapoint_indices]

# Largest activation magnitude seen for each unique token.
token_value_dict = defaultdict(float)
for token, value in zip(all_tokens, nonzero_values.tolist()):
    token_value_dict[token] = max(token_value_dict[token], value)
sorted_tokens = sorted(token_value_dict, key=token_value_dict.get, reverse=True)
# Top `num_unique_tokens` tokens (slicing already copes with fewer).
max_tokens = sorted_tokens[:num_unique_tokens]
total_sums = nonzero_values.sum()  # values are already absolute
max_token_sums = []
token_activations = []
assert len(max_tokens) > 0, f"Feature {feature} has no activations for setting={setting!r}; it is probably dead"

In [None]:
torch.count_nonzero(feature_activation)

In [None]:
# Nonzero counts for the first 10 features only. Slicing before counting avoids
# reducing (and likely allocating a mask over) all dictionary columns just to keep 10.
dictionary_activations[:, :10].count_nonzero(dim=0)

In [None]:
torch.stack(full_token_list, dim=0)

In [None]:
# Tokenise a short example prompt (placed on the configured device).
text = "I like to eat pizza."
tokens = tokenizer(text, return_tensors="pt")["input_ids"].to(device)

# From here on, use whichever device the model actually lives on.
device = model.device
def less_than_rank_1_ablate(value, hook):
    """Subtract one dictionary feature's reconstruction from a layer's output.

    For every token, the autoencoder's code for ``feature`` (module-level global)
    is computed and ``code * learned_dict[feature]`` is subtracted in place from
    the hidden states. Positions where the code is zero are left unchanged.

    Args:
        value: the hooked module's output -- either a tensor of shape
            (batch, seq, hidden) or a tuple whose first element is that tensor
            (the notebook's commented-out Trace code anticipated tuple outputs).
        hook: unused; present to match the hook-callback signature.

    Returns:
        ``value`` itself, with its hidden-state tensor modified in place.
    """
    hidden = value[0] if isinstance(value, tuple) else value
    batch, seq_len, hidden_size = hidden.shape
    # Flatten (batch, seq) so the autoencoder sees one row per token.
    act = autoencoder.encode(hidden.reshape(batch * seq_len, hidden_size))
    direction = autoencoder.get_learned_dict()[feature]
    # Broadcast outer product; unlike torch.outer(x.squeeze(), ...) this also
    # works for a single token (batch * seq_len == 1).
    feature_direction = act[:, feature, None] * direction[None, :]
    hidden -= feature_direction.reshape(batch, seq_len, hidden_size)
    return value

# Run the example batch with and without the feature ablated at `cache_name`.
batch = tokens
with torch.no_grad():
    original_logits = model(batch).logits
    with Trace(model, cache_name, edit_output=less_than_rank_1_ablate) as ret:
        ablated_logits = model(batch).logits
# Compare like with like. For a causal LM use log-probs for both runs (previously
# only the original run was normalised). For a reward model keep raw scores:
# assuming a single reward logit, log_softmax over it would be identically zero.
if model_type != "reward_model":
    original_logits = original_logits.log_softmax(dim=-1)
    ablated_logits = ablated_logits.log_softmax(dim=-1)

In [None]:

# Scratch version of ablate_feature_direction, written for a TransformerLens model.
# HookedTransformer exposes its device as `model.cfg.device`; HF models (as loaded in
# the config cell) have no `.cfg` and expose `model.device` instead.
device = model.cfg.device if hasattr(model, "cfg") else model.device
def less_than_rank_1_ablate(value, hook):
    """Subtract one dictionary feature's reconstruction from a layer's output.

    NOTE: this redefines ``less_than_rank_1_ablate`` from an earlier cell; once
    this cell runs, this definition is the one every later cell sees, so it
    accepts both output forms below.

    For every token, the autoencoder's code for ``feature`` (module-level global)
    is computed and ``code * learned_dict[feature]`` is subtracted in place from
    the hidden states. Positions where the code is zero are left unchanged.

    Args:
        value: the hooked activation -- either a tensor of shape
            (batch, seq, hidden) or a tuple whose first element is that tensor.
        hook: unused; present to match the hook-callback signature.

    Returns:
        ``value`` itself, with its hidden-state tensor modified in place.
    """
    hidden = value[0] if isinstance(value, tuple) else value
    batch, seq_len, hidden_size = hidden.shape
    # Flatten (batch, seq) so the autoencoder sees one row per token.
    act = autoencoder.encode(hidden.reshape(batch * seq_len, hidden_size))
    direction = autoencoder.get_learned_dict()[feature]
    # Broadcast outer product; unlike torch.outer(x.squeeze(), ...) this also
    # works for a single token (batch * seq_len == 1).
    feature_direction = act[:, feature, None] * direction[None, :]
    hidden -= feature_direction.reshape(batch, seq_len, hidden_size)
    return value

# Per-token change in log-prob of the actual next token when `feature` is ablated.
# NOTE(review): `model(...)` returning logits directly and `run_with_hooks` are
# TransformerLens HookedTransformer APIs; this cell will not run with the HF model
# loaded in the config cell.
tokens_batched = torch.stack(full_token_list)
tokens_on_device = tokens_batched.to(device)

with torch.no_grad():
    original_logits = model(tokens_on_device).log_softmax(dim=-1)
    ablated_logits = model.run_with_hooks(tokens_on_device, fwd_hooks=[(cache_name, less_than_rank_1_ablate)]).log_softmax(dim=-1)
    diff_logits = ablated_logits - original_logits  # ablated > original -> positive diff
    # Log-prob diff of the true next token: the prediction at position t scores token t+1.
    gather_tokens = rearrange(tokens_on_device[:, 1:], "b s -> b s 1")
    gathered = diff_logits[:, :-1].gather(-1, gather_tokens)
    # Position 0 has no prediction; pad it with 0 so there is one value per token.
    gathered = torch.cat([torch.zeros((gathered.shape[0], 1, 1), device=gathered.device), gathered], dim=1)
    logit_diffs = rearrange(gathered, "b s n -> (b s n)").cpu()

In [None]:
def make_colorbar(min_value, max_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Return an HTML legend of colored swatches covering [min_value, max_value].

    Left to right: four red swatches labelled min_value * 1, 0.75, 0.5, 0.25 (only when
    ``min_value < -negative_threshold``), a ``white`` swatch labelled 0.0, then four blue
    swatches labelled max_value * 0.25 ... 1 (only when ``max_value > positive_threshold``).
    Labels on swatches whose scale exceeds 0.5 are drawn in white.
    """
    num_colors = 4
    ratios = [step / (num_colors) for step in range(1, num_colors + 1)]

    def fade(ratio):
        # Channel value shrinking from red_blue_ness toward 0 as the ratio grows.
        return int(red_blue_ness-(red_blue_ness*ratio))

    def label_rgb(ratio):
        return "255,255,255" if ratio > 0.5 else "0,0,0"

    parts = []
    if min_value < -negative_threshold:
        for ratio in ratios[::-1]:
            shade = fade(ratio)
            parts.append(
                f'<span style="background-color:rgba(255, {shade},{shade},1); color:rgb({label_rgb(ratio)})">'
                f'&nbsp{round((min_value*ratio),1)}&nbsp</span>'
            )
    parts.append(f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>')
    if max_value > positive_threshold:
        for ratio in ratios:
            shade = fade(ratio)
            parts.append(
                f'<span style="background-color:rgba({shade},{shade},255,1);color:rgb({label_rgb(ratio)})">'
                f'&nbsp{round((max_value*ratio),1)}&nbsp</span>'
            )
    return "".join(parts)

def value_to_color(activation, max_value, min_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Return ``(text_rgb, background_rgba)`` CSS color strings for one value.

    Values above ``positive_threshold`` get a blue background scaled by
    ``activation / max_value``; values below ``-negative_threshold`` get a red one scaled
    by ``activation / min_value``; anything in between gets the neutral ``white``
    background with black text. Text turns white once the scale exceeds 0.5.
    """
    if activation > positive_threshold:
        scale, is_blue = activation/max_value, True
    elif activation < -negative_threshold:
        scale, is_blue = activation/min_value, False
    else:
        return "0,0,0", f'rgba({white},{white},{white},1)'

    text_rgb = "0,0,0" if scale <= 0.5 else "255,255,255"
    fade = int(red_blue_ness-(red_blue_ness*scale))
    if is_blue:
        return text_rgb, f'rgba({fade},{fade},255,1)'
    return text_rgb, f'rgba(255, {fade},{fade},1)'

def convert_token_array_to_list(array):
    """Normalize token ids / per-token values to a list of per-sequence Python lists.

    Accepts a 1-D or 2-D ``torch.Tensor``, a flat list of scalars (treated as a single
    sequence — ints *or* floats, so activation lists work too), or a list of sequences
    (lists, tuples or 1-D tensors, the latter converted with ``tolist``). An empty list
    is returned as-is; any other type is returned unchanged.

    Raises:
        NotImplementedError: if a tensor is not 1- or 2-dimensional.
    """
    import numbers

    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            return [array.tolist()]
        if array.dim() == 2:
            return array.tolist()
        raise NotImplementedError("tokens must be 1 or 2 dimensional")
    if isinstance(array, list):
        if not array:
            return array
        first = array[0]
        # A flat list of scalars is one sequence. Previously only ints were detected, so a
        # flat list of float activations slipped through and broke max(row) downstream.
        if isinstance(first, numbers.Number) or (isinstance(first, torch.Tensor) and first.dim() == 0):
            return [array]
        return [row.tolist() if isinstance(row, torch.Tensor) else row for row in array]
    return array

def tokens_and_activations_to_html(toks, activations, tokenizer, logit_diffs=None):
    """Render token sequences as HTML on a black page, shading each token by its value.

    Args:
        toks: token ids — anything ``convert_token_array_to_list`` accepts.
        activations: per-token values with the same layout as ``toks``.
        tokenizer: object with ``decode(token_id) -> str``.
        logit_diffs: optional per-token values drawn as a thin colored bar under each
            token (same layout as ``toks``); skipped when ``None`` or empty.

    Returns:
        An HTML string: color legend(s) followed by one line of tokens per sequence.
    """
    from html import escape

    # text_spacing = "0.07em"
    text_spacing = "0.00em"
    toks = convert_token_array_to_list(toks)
    activations = convert_token_array_to_list(activations)
    # `if logit_diffs:` is ambiguous for tensors; test presence/emptiness explicitly.
    has_logits = logit_diffs is not None and len(logit_diffs) > 0
    if has_logits:
        logit_diffs = convert_token_array_to_list(logit_diffs)

    def to_display_text(token_id):
        # Make byte-level spaces and newlines visible, then HTML-escape so tokens such as
        # "<", ">" or "&" cannot break the markup; spaces become non-breaking last.
        text = tokenizer.decode(token_id).replace('Ġ', ' ').replace('\n', '\\n')
        return escape(text, quote=False).replace(' ', '&nbsp;')

    token_texts = [[to_display_text(t) for t in seq] for seq in toks]

    highlighted_text = []
    # Black page background
    highlighted_text.append("""
<body style="background-color: black; color: white;">
""")
    max_value = max(max(row) for row in activations)
    min_value = min(min(row) for row in activations)

    # Color legends
    highlighted_text.append("Token Activations: " + make_colorbar(min_value, max_value))
    if has_logits:
        logit_max_value = max(max(row) for row in logit_diffs)
        logit_min_value = min(min(row) for row in logit_diffs)
        highlighted_text.append('<div style="margin-top: 0.1em;"></div>')
        highlighted_text.append("Logit Diff: " + make_colorbar(logit_min_value, logit_max_value))

    highlighted_text.append('<div style="margin-top: 0.5em;"></div>')
    for seq_ind, (act, tok) in enumerate(zip(activations, token_texts)):
        for act_ind, (a, t) in enumerate(zip(act, tok)):
            if has_logits:
                # Stack the token and its logit-diff bar in one inline block.
                highlighted_text.append('<div style="display: inline-block;">')
            text_color, background_color = value_to_color(a, max_value, min_value)
            highlighted_text.append(f'<span style="background-color:{background_color};margin-right: {text_spacing}; color:rgb({text_color})">{t}</span>')
            if has_logits:
                _, logit_background_color = value_to_color(logit_diffs[seq_ind][act_ind], logit_max_value, logit_min_value)
                highlighted_text.append(f'<div style="display: block; margin-right: {text_spacing}; height: 10px; background-color:{logit_background_color}; text-align: center;"></div></div>')
        highlighted_text.append('<div style="margin-top: 0.2em;"></div>')
    return ''.join(highlighted_text)

def display_tokens(tokens, activations, tokenizer, logit_diffs=None):
    """Show the HTML from ``tokens_and_activations_to_html`` inline in the notebook.

    ``display``/``HTML`` are imported locally because the notebook otherwise imports them
    only in later cells, so calling this before those cells ran raised NameError.
    """
    from IPython.display import display, HTML

    return display(HTML(tokens_and_activations_to_html(tokens, activations, tokenizer, logit_diffs)))

def save_token_display(tokens, activations, tokenizer, path, logit_diffs=None):
    """Render the token/activation HTML and write it to ``path`` as an image via imgkit.

    imgkit is imported locally because the notebook otherwise imports it only in a later
    cell, so this function failed on a fresh kernel.
    """
    import imgkit

    html = tokens_and_activations_to_html(tokens, activations, tokenizer, logit_diffs)
    imgkit.from_string(html, path)
# NOTE(review): `full_token_list_ablated` and `full_activations_ablated` are only assigned in a
# later cell, and `save_path` is not assigned anywhere in this part of the notebook, so this
# line only works after out-of-order execution. TODO move below those definitions.
save_token_display(full_token_list_ablated, full_activations, model.tokenizer, path =f"{save_path}uniform_black_{feature}.png", logit_diffs = full_activations_ablated)
# display_tokens(full_token_list_ablated, full_activations, model.tokenizer, logit_diffs = full_activations_ablated)

In [None]:
from PIL import Image, ImageDraw, ImageFont

# Load an existing image
image = Image.open("concatenated_image.png")

# Create a drawing object
draw = ImageDraw.Draw(image)

# Title text (f-string so the current feature number is actually interpolated).
text = f"Feature {feature}"

# Prefer a TrueType font; fall back to PIL's built-in bitmap font if Arial is unavailable.
try:
    font = ImageFont.truetype("arial.ttf", size=30)
except IOError:
    font = ImageFont.load_default()

# Measure the rendered text so it can be centered horizontally.
# ImageDraw.textsize was removed in Pillow 10; textbbox returns (left, top, right, bottom).
if hasattr(draw, "textbbox"):
    left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
    text_width, text_height = right - left, bottom - top
else:  # very old Pillow without textbbox
    text_width, text_height = draw.textsize(text, font=font)
text_position = ((image.width - text_width) // 2, 100)

# Add the text to image
draw.text(text_position, text, font=font, fill="white")

# Save or show the image
image.save("image_with_title.png")
image.show()


In [None]:
# Debug: inspect the size the font's underlying object reports for `text`.
font.font.getsize(text)

In [None]:
# Debug: inspect the values used to place the title.
text_position, text, font

In [None]:
from IPython.core.display import display, HTML

# Opening HTML document; the <body> style sets the page background (blue here).
initial_html = """
<!DOCTYPE html>
<html>
<head>
    <title>My Page</title>
</head>
<body style="background-color: blue;">
"""

# Text spans placed inside the body
span_text = """
    <span style="color: white;">This is some text.</span>
    <span style="color: green;">This is some more text.</span>
"""

# Closing tags for the HTML
closing_html = """
</body>
</html>
"""

# Combine everything
full_html = initial_html + span_text + closing_html

# Display the HTML
display(HTML(full_html))

In [None]:
# Sanity check: the device recorded in the model's config.
model.cfg.device

In [None]:
# The autoencoder's input size (last dim of its encoder) must equal the model's d_model.
d_model = model.cfg.d_model
assert (d_model == autoencoder.encoder.shape[-1]), f"Model and autoencoder must have same hidden size. Model: {d_model}, Autoencoder: {autoencoder.encoder.shape[-1]}"

In [None]:
# Now we can use the model to get the activations
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from einops import rearrange
def get_dictionary_activations(model, dataset, cache_name, autoencoder, batch_size=32):
    """Encode every token of ``dataset`` with ``autoencoder`` at hook point ``cache_name``.

    Relies on the notebook globals ``max_seq_length`` (tokens per dataset row) and
    ``device``.

    Returns:
        (dictionary_activations, token_list): a CPU float tensor of shape
        (num_rows * max_seq_length, num_features) and the matching flat int64 tensor of
        token ids, both ordered row-major by (dataset row, position).

    Raises:
        ValueError: if a batch's sequence length differs from ``max_seq_length``, which
            would otherwise misalign the flat layout.
    """
    num_features = autoencoder.encoder.shape[0]
    total_tokens = dataset.num_rows * max_seq_length
    dictionary_activations = torch.zeros((total_tokens, num_features))
    token_list = torch.zeros(total_tokens, dtype=torch.int64)
    offset = 0
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for batch in tqdm(dl):
            if batch.shape[1] != max_seq_length:
                raise ValueError(f"Expected sequences of length {max_seq_length}, got {batch.shape[1]}")
            # Advance by the tokens actually present (the final batch may be smaller).
            end = offset + batch.numel()
            token_list[offset:end] = rearrange(batch, "b s -> (b s)")
            # Only cache the hook we need instead of every activation in the model.
            _, cache = model.run_with_cache(batch.to(device), names_filter=cache_name)
            batched_neuron_activations = rearrange(cache[cache_name], "b s n -> (b s) n")
            dictionary_activations[offset:end, :] = autoencoder.encode(batched_neuron_activations).cpu()
            offset = end
    return dictionary_activations, token_list

# Encode every dataset token with the autoencoder; `tokens_for_each_datapoint` is the matching
# flat token-id tensor and is read as a global by later cells.
print("Getting dictionary activations")
dictionary_activations, tokens_for_each_datapoint = get_dictionary_activations(model, dataset, cache_name, autoencoder, batch_size=32)

In [None]:

def ablate_feature_direction(model, dataset, cache_name, autoencoder, feature, batch_size=32):
    """Measure how removing one dictionary feature changes next-token log-probabilities.

    Each batch of ``dataset`` is run through ``model`` twice: once normally and once with a
    hook at ``cache_name`` that subtracts the feature's reconstruction. Relies on the
    notebook globals ``max_seq_length`` and ``device``.

    Returns:
        Flat CPU tensor of length ``num_rows * max_seq_length``. Entry ``(row, pos)`` is the
        change (ablated - original) in log-prob assigned to the token at ``pos``; position 0
        of every row is 0 because it is never predicted.
    """
    def less_than_rank_1_ablate(value, hook):
        # Subtract code * direction. Codes are ReLU'd, so only positions where the feature
        # fires (above its negative bias) are changed.
        n_batch, seq_len, hidden_size = value.shape
        flat = value.reshape(n_batch * seq_len, hidden_size)
        codes = autoencoder.encode(flat)[:, feature]
        direction = autoencoder.get_learned_dict()[feature]
        # Broadcast rather than torch.outer(x.squeeze(), ...), which fails for one position.
        value -= (codes[:, None] * direction[None, :]).reshape(n_batch, seq_len, hidden_size)
        return value

    logit_diffs = torch.zeros(dataset.num_rows * max_seq_length)
    offset = 0
    with torch.no_grad(), dataset.formatted_as("pt"):
        dl = DataLoader(dataset["input_ids"], batch_size=batch_size)
        for batch in tqdm(dl):
            tokens = batch.to(device)
            original_logits = model(tokens).log_softmax(dim=-1)
            ablated_logits = model.run_with_hooks(tokens, fwd_hooks=[(cache_name, less_than_rank_1_ablate)]).log_softmax(dim=-1)
            # ablated - original: negative where the ablation lowers the token's log-prob
            diff_logits = ablated_logits - original_logits
            next_tokens = rearrange(tokens[:, 1:], "b s -> b s 1")
            gathered = diff_logits[:, :-1].gather(-1, next_tokens)
            # Prepend a zero for the first (never predicted) token so every position has a value.
            gathered = torch.cat([gathered.new_zeros((gathered.shape[0], 1, 1)), gathered], dim=1)
            diff = rearrange(gathered, "b s n -> (b s n)").cpu()
            logit_diffs[offset:offset + diff.numel()] = diff
            offset += diff.numel()
    return logit_diffs
# Feature under study; later cells (and the module-level ablation hook) read it as a global.
feature = 1
# Per-token change in next-token log-prob when `feature` is ablated, flattened over the dataset.
logit_diffs = ablate_feature_direction(model, dataset, cache_name, autoencoder, feature = feature, batch_size=32)

In [None]:
# # from interp_utils import *
# if isinstance(features, int):
#     features = [features]
# for feature in features:
#     text_list, full_text, token_list, full_token_list = get_feature_datapoints(feature, dictionary_activations, model.tokenizer, max_seq_length, dataset, setting="uniform")
#     # text_list, full_text, token_list, full_token_list = get_feature_datapoints(feature, dictionary_activations, dataset, setting="max")
#     # visualize_text(full_text, feature, model, autoencoder, layer)
# l = visualize_text(text_list, feature, model, autoencoder, layer)

In [None]:
from IPython.display import display, HTML
import imgkit

def make_colorbar(min_value, max_value, white = 245, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Return an HTML legend of colored swatches covering [min_value, max_value].

    Left to right: four red swatches labelled min_value * 1, 0.75, 0.5, 0.25 (only when
    ``min_value < -negative_threshold``), a ``white`` swatch labelled 0.0, then four blue
    swatches labelled max_value * 0.25 ... 1 (only when ``max_value > positive_threshold``).
    Labels on swatches whose scale exceeds 0.5 are drawn in white.
    """
    num_colors = 4
    ratios = [step / (num_colors) for step in range(1, num_colors + 1)]

    def fade(ratio):
        # Channel value shrinking from red_blue_ness toward 0 as the ratio grows.
        return int(red_blue_ness-(red_blue_ness*ratio))

    def label_rgb(ratio):
        return "255,255,255" if ratio > 0.5 else "0,0,0"

    parts = []
    if min_value < -negative_threshold:
        for ratio in ratios[::-1]:
            shade = fade(ratio)
            parts.append(
                f'<span style="background-color:rgba(255, {shade},{shade},1); color:rgb({label_rgb(ratio)})">'
                f'&nbsp{round((min_value*ratio),1)}&nbsp</span>'
            )
    parts.append(f'<span style="background-color:rgba({white},{white},{white},1);color:rgb(0,0,0)">&nbsp0.0&nbsp</span>')
    if max_value > positive_threshold:
        for ratio in ratios:
            shade = fade(ratio)
            parts.append(
                f'<span style="background-color:rgba({shade},{shade},255,1);color:rgb({label_rgb(ratio)})">'
                f'&nbsp{round((max_value*ratio),1)}&nbsp</span>'
            )
    return "".join(parts)

def value_to_color(activation, max_value, min_value, white = 245, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):
    """Return ``(text_rgb, background_rgba)`` CSS color strings for one value.

    Values above ``positive_threshold`` get a blue background scaled by
    ``activation / max_value``; values below ``-negative_threshold`` get a red one scaled
    by ``activation / min_value``; anything in between gets the neutral ``white``
    background with black text. Text turns white once the scale exceeds 0.5.
    """
    if activation > positive_threshold:
        scale, is_blue = activation/max_value, True
    elif activation < -negative_threshold:
        scale, is_blue = activation/min_value, False
    else:
        return "0,0,0", f'rgba({white},{white},{white},1)'

    text_rgb = "0,0,0" if scale <= 0.5 else "255,255,255"
    fade = int(red_blue_ness-(red_blue_ness*scale))
    if is_blue:
        return text_rgb, f'rgba({fade},{fade},255,1)'
    return text_rgb, f'rgba(255, {fade},{fade},1)'

def convert_token_array_to_list(array):
    """Normalize token ids / per-token values to a list of per-sequence Python lists.

    Accepts a 1-D or 2-D ``torch.Tensor``, a flat list of scalars (treated as a single
    sequence — ints *or* floats, so activation lists work too), or a list of sequences
    (lists, tuples or 1-D tensors, the latter converted with ``tolist``). An empty list
    is returned as-is; any other type is returned unchanged.

    Raises:
        NotImplementedError: if a tensor is not 1- or 2-dimensional.
    """
    import numbers

    if isinstance(array, torch.Tensor):
        if array.dim() == 1:
            return [array.tolist()]
        if array.dim() == 2:
            return array.tolist()
        raise NotImplementedError("tokens must be 1 or 2 dimensional")
    if isinstance(array, list):
        if not array:
            return array
        first = array[0]
        # A flat list of scalars is one sequence. Previously only ints were detected, so a
        # flat list of float activations slipped through and broke max(row) downstream.
        if isinstance(first, numbers.Number) or (isinstance(first, torch.Tensor) and first.dim() == 0):
            return [array]
        return [row.tolist() if isinstance(row, torch.Tensor) else row for row in array]
    return array

def tokens_and_activations_to_html(toks, activations, tokenizer, logit_diffs=None):
    """Render token sequences as HTML, shading each token by its value.

    Args:
        toks: token ids — anything ``convert_token_array_to_list`` accepts.
        activations: per-token values with the same layout as ``toks``.
        tokenizer: object with ``decode(token_id) -> str``.
        logit_diffs: optional per-token values drawn as a thin colored bar under each
            token (same layout as ``toks``); skipped when ``None`` or empty.

    Returns:
        An HTML string: color legend(s) followed by one line of tokens per sequence.
    """
    from html import escape

    toks = convert_token_array_to_list(toks)
    activations = convert_token_array_to_list(activations)
    # `if logit_diffs:` is ambiguous for tensors; test presence/emptiness explicitly.
    has_logits = logit_diffs is not None and len(logit_diffs) > 0
    if has_logits:
        logit_diffs = convert_token_array_to_list(logit_diffs)

    def to_display_text(token_id):
        # Make byte-level spaces and newlines visible, then HTML-escape so tokens such as
        # "<", ">" or "&" cannot break the markup; spaces become non-breaking last.
        text = tokenizer.decode(token_id).replace('Ġ', ' ').replace('\n', '\\n')
        return escape(text, quote=False).replace(' ', '&nbsp;')

    token_texts = [[to_display_text(t) for t in seq] for seq in toks]

    highlighted_text = []
    max_value = max(max(row) for row in activations)
    min_value = min(min(row) for row in activations)

    # Color legends
    highlighted_text.append("Token Activations: " + make_colorbar(min_value, max_value))
    if has_logits:
        logit_max_value = max(max(row) for row in logit_diffs)
        logit_min_value = min(min(row) for row in logit_diffs)
        highlighted_text.append('<br><br>')
        highlighted_text.append("Logit Diff: " + make_colorbar(logit_min_value, logit_max_value))

    highlighted_text.append('<br><br>')
    for seq_ind, (act, tok) in enumerate(zip(activations, token_texts)):
        for act_ind, (a, t) in enumerate(zip(act, tok)):
            if has_logits:
                # Stack the token and its logit-diff bar in one inline block.
                highlighted_text.append('<div style="display: inline-block;">')
            text_color, background_color = value_to_color(a, max_value, min_value)
            highlighted_text.append(f'<span style="background-color:{background_color};color:rgb({text_color})">{t}</span>')
            if has_logits:
                _, logit_background_color = value_to_color(logit_diffs[seq_ind][act_ind], logit_max_value, logit_min_value)
                highlighted_text.append(f'<div style="display: block; height: 10px; background-color:{logit_background_color}; text-align: center;"></div></div>')
        highlighted_text.append('<br><br>')
    return ''.join(highlighted_text)

def display_tokens(tokens, activations, tokenizer, logit_diffs=None):
    """Show the token/value heat-map from ``tokens_and_activations_to_html`` inline."""
    rendered = tokens_and_activations_to_html(tokens, activations, tokenizer, logit_diffs)
    return display(HTML(rendered))

def save_token_display(tokens, activations, tokenizer, path, logit_diffs=None):
    """Render the token/value heat-map and write it to ``path`` as an image via imgkit."""
    imgkit.from_string(
        tokens_and_activations_to_html(tokens, activations, tokenizer, logit_diffs),
        path,
    )

In [None]:
def get_feature_indices(feature_index, dictionary_activations, tokenizer, token_amount, dataset, k=10, setting="max", seed=None):
    """Pick flat token indices at which to inspect one dictionary feature.

    Args:
        feature_index: column of ``dictionary_activations`` to inspect.
        dictionary_activations: 2-D tensor (num_tokens, num_features).
        tokenizer, token_amount, dataset: unused; kept for call compatibility.
        k: number of indices (``"max"`` / random) or number of bins (``"uniform"``).
        setting: ``"max"`` — indices of the k largest activations; ``"uniform"`` — split
            [min, max] into k equal-width bins and draw one index from each non-empty bin
            above the minimum, highest bin first; anything else — up to k random indices
            with nonzero activation.
        seed: optional int that makes the random settings reproducible; ``None`` keeps
            using the global numpy / torch RNGs.

    Returns:
        1-D int64 tensor of flat indices into the first dim of ``dictionary_activations``.
    """
    best_feature_activations = dictionary_activations[:, feature_index]
    if setting == "max":
        return torch.argsort(best_feature_activations, descending=True)[:k]

    if setting == "uniform":
        rng = np.random if seed is None else np.random.default_rng(seed)
        min_value = torch.min(best_feature_activations)
        max_value = torch.max(best_feature_activations)
        bin_boundaries = torch.linspace(min_value, max_value, k + 1)
        # bucketize (right=False): values equal to the minimum land in bin 0, values in
        # (b[i-1], b[i]] land in bin i.
        bins = torch.bucketize(best_feature_activations, bin_boundaries)
        sampled_indices = []
        for bin_idx in torch.unique(bins):
            # Bin 0 holds only values equal to the minimum (typically the many zero
            # activations), so it is skipped.
            if bin_idx == 0:
                continue
            bin_indices = torch.nonzero(bins == bin_idx, as_tuple=False).squeeze(dim=1)
            sampled_indices.extend(rng.choice(bin_indices.numpy(), size=1, replace=False))
        # torch.unique yields ascending bins; flip so the highest-activation sample is first.
        return torch.tensor(sampled_indices).long().flip(dims=[0])

    # Random: shuffle the indices with nonzero activation and keep the first k.
    generator = None if seed is None else torch.Generator().manual_seed(seed)
    nonzero_indices = torch.nonzero(best_feature_activations)[:, 0]
    shuffled_indices = nonzero_indices[torch.randperm(nonzero_indices.shape[0], generator=generator)]
    return shuffled_indices[:k]
def get_feature_datapoints(found_indices, best_feature_activations, tokenizer, token_amount, dataset):
    """Gather the text, tokens and activations around each selected flat token index.

    ``best_feature_activations`` holds one value per token, flattened row-major over
    (dataset row, position) with ``token_amount`` positions per row. Each index in
    ``found_indices`` is split into (row, position); "partial" outputs stop at and include
    that position, while "full" outputs cover the whole row.

    Returns:
        (text_list, full_text, token_list, full_token_list, partial_activations,
        full_activations), each a list with one entry per index in ``found_indices``.
    """
    grid_shape = (dataset.num_rows, token_amount)
    locations = [np.unravel_index(flat_index, grid_shape) for flat_index in found_indices]
    per_row_activations = best_feature_activations.reshape(*grid_shape).tolist()

    text_list, full_text, token_list, full_token_list = [], [], [], []
    partial_activations, full_activations = [], []
    for row_np, pos_np in locations:
        row, pos = int(row_np), int(pos_np)
        row_ids = dataset[row]["input_ids"]
        full_tok = torch.tensor(row_ids)
        prefix = row_ids[:pos + 1]

        full_text.append(tokenizer.decode(full_tok))
        full_activations.append(per_row_activations[row])
        partial_activations.append(per_row_activations[row][:pos + 1])
        text_list.append(tokenizer.decode(prefix))
        token_list.append(prefix)
        full_token_list.append(full_tok)
    return text_list, full_text, token_list, full_token_list, partial_activations, full_activations

# Sample about one position per activation bin for `feature`, then show each row's prefix up
# to (and including) the sampled position, shaded by the feature's activation.
uniform_indices = get_feature_indices(feature, dictionary_activations, model.tokenizer, max_seq_length, dataset, k=10, setting="uniform")
text_list, full_text, token_list, full_token_list, partial_activations, full_activations = get_feature_datapoints(uniform_indices, dictionary_activations[:, feature], model.tokenizer, max_seq_length, dataset)
display_tokens(token_list, partial_activations, model.tokenizer)

In [None]:
from collections import defaultdict
import matplotlib.pyplot as plt

def get_token_statistics(feature, feature_activation, dataset, max_seq_length, save_location="", num_unique_tokens=10, setting="input", negative_threshold=-0.01):
    """Plot which tokens account for a feature's activations (or for logit differences).

    Args:
        feature: feature id; used only in plot titles and file names.
        feature_activation: flat tensor with one value per (dataset row, position),
            row-major with ``max_seq_length`` positions per row.
        dataset: provides ``num_rows`` and ``dataset[row]["input_ids"]``.
        max_seq_length: positions per dataset row.
        save_location: prefix prepended to the saved PNG file names.
        num_unique_tokens: how many tokens (ranked by their largest |value|) to plot.
        setting: ``"input"`` keeps nonzero values; anything else keeps values below
            ``negative_threshold`` and labels the plots as logit differences.
        negative_threshold: cutoff used when ``setting != "input"``.

    Saves ``{save_location}feature_{feature}_{name}_boxplot.png`` and ``..._bar.png``.
    Also reads the notebook globals ``tokens_for_each_datapoint`` and ``model``.
    """
    if setting == "input":
        nonzero_indices = feature_activation.nonzero()[:, 0]
    else:
        nonzero_indices = (feature_activation < negative_threshold).nonzero()[:, 0]
    nonzero_values = feature_activation[nonzero_indices].abs()

    # Token at each selected position. Unravel all indices at once and read every dataset
    # row only once instead of once per selected position.
    row_ids = (nonzero_indices // max_seq_length).tolist()
    positions = (nonzero_indices % max_seq_length).tolist()
    row_cache = {}
    all_tokens = []
    for row, pos in zip(row_ids, positions):
        if row not in row_cache:
            row_cache[row] = dataset[row]["input_ids"]
        all_tokens.append(row_cache[row][pos])

    # Largest |value| seen for each distinct token; keep the top `num_unique_tokens`.
    token_value_dict = defaultdict(int)
    for token, value in zip(all_tokens, nonzero_values):
        token_value_dict[token] = max(token_value_dict[token], value)
    sorted_tokens = sorted(token_value_dict, key=lambda tok: -token_value_dict[tok])
    max_tokens = sorted_tokens[:num_unique_tokens]

    total_sums = nonzero_values.sum()
    # Loop-invariant: token id at every selected position.
    selected_token_ids = tokens_for_each_datapoint[nonzero_indices]
    max_token_sums = []
    token_activations = []
    for max_token in max_tokens:
        max_token_values = nonzero_values[selected_token_ids == max_token]
        max_token_sums.append(max_token_values.sum())
        token_activations.append(max_token_values)

    if setting == "input":
        title_text, save_name, y_label = "Input Token Activations", "input", "Feature Activation"
    else:
        title_text, save_name, y_label = "Output Logit Difference", "logit_diff", "Logit Difference"

    max_text = [model.tokenizer.decode([t]).replace("\n", "\\n").replace(" ", "_") for t in max_tokens]

    # Box plot of each token's value distribution (reversed: strongest token on the right).
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_title(f'{title_text}: feature={feature}')
    ax.set_xlabel('Token')
    ax.set_ylabel(y_label)
    ax.boxplot(token_activations[::-1])
    # Set tick labels explicitly (works across matplotlib versions, where boxplot's
    # `labels` kwarg was renamed) and rotate them after they exist.
    ax.set_xticks(range(1, len(max_text) + 1))
    ax.set_xticklabels(max_text[::-1])
    ax.tick_params(axis="x", labelrotation=30)
    fig.savefig(f'{save_location}feature_{feature}_{save_name}_boxplot.png')

    # Bar chart: each token's share (%) of the summed |values|.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.set_title(f'Weighted Percentage of {title_text}: feature={feature}')
    ax.set_xlabel('Token')
    ax.set_ylabel(f'Weighted Percentage of Total {y_label}')
    ax.bar(max_text[::-1], [t/total_sums*100 for t in max_token_sums[::-1]])
    ax.tick_params(axis="x", labelrotation=30)
    fig.savefig(f'{save_location}feature_{feature}_{save_name}_bar.png')
# get_token_statistics(feature, dictionary_activations[:, feature], dataset, max_seq_length, save_location = "features/", num_unique_tokens=10)
# Token statistics over positions where ablating `feature` lowers the next-token log-prob by
# more than 0.01 (setting="output" keeps values below the default -0.01 threshold).
get_token_statistics(feature, logit_diffs, dataset, max_seq_length, save_location = "features/", setting="output", num_unique_tokens=10)

In [None]:
# Positions where ablation lowers the next-token log-prob vs. positions where the feature fires.
# (`feature_activation` exists only as a parameter of get_token_statistics, so index the
# module-level activation matrix directly.)
(logit_diffs < -0.01).nonzero()[:, 0], dictionary_activations[:, feature].nonzero()[:, 0]

In [None]:
# Same sampled rows, now pairing each token's feature activation with a strip showing the
# per-token logit difference (`full_activations_ablated` holds values from `logit_diffs`).
_, _, _, full_token_list_ablated, _, full_activations_ablated = get_feature_datapoints(uniform_indices, logit_diffs, model.tokenizer, max_seq_length, dataset)
display_tokens(full_token_list_ablated, full_activations, model.tokenizer, logit_diffs = full_activations_ablated)

In [None]:
# Debug: inspect the raw HTML produced for the cell above.
tokens_and_activations_to_html(full_token_list_ablated, full_activations, model.tokenizer, logit_diffs = full_activations_ablated)

In [None]:
# Debug: hard-coded rendered snippet for a few tokens, used to check the inline-block layout.
display(HTML('<div style="display: inline-block;"><span style="background-color:rgba(245,245,245,1);color:rgb(0,0,0)">About</span><div style="display: block; height: 10px; background-color:rgba(245,245,245,1); text-align: center;"></div></div><div style="display: inline-block;"><span style="background-color:rgba(245,245,245,1);color:rgb(0,0,0)">\\n</span><div style="display: block; height: 10px; background-color:rgba(245,245,245,1); text-align: center;"></div></div><div style="display: inline-block;"><span style="background-color:rgba(245,245,245,1);color:rgb(0,0,0)">\\n</span><div style="display: block; height: 10px; background-color:rgba(245,245,245,1); text-align: center;"></div></div><div style="display: inline-block;"><span style="background-color:rgba(245,245,245,1);color:rgb(0,0,0)">Te</span><div style="display: block; height: 10px; background-color:rgba(245,245,245,1); text-align: center;"></div></div><div style="display: inline-block;"><span style="background-color:rgba(245,245,245,1);color:rgb(0,0,0)">ams</span><div style="display: block; height: 10px; background-color:rgba(245,245,245,1); text-align: center;"></div></div><div style="display: inline-block;"><span style="background-color:rgba(245,245,245,1);color:rgb(0,0,0)">&nbspare</span><div style="display: block; height: 10px; background-color:rgba(245,245,245,1); text-align: center;"></div></div>'))

In [None]:
# For each selected prefix in `token_list`, delete one context token at a time and record how
# much the feature's activation on the final token changes (ablated - original).
all_changed_activations = []
for token_ind, token_l in enumerate(token_list):
    seq_tokens = torch.as_tensor(token_l)
    seq_size = len(seq_tokens)
    original_activation = partial_activations[token_ind][-1]
    # Start every position at the original value so positions that are not ablated (the
    # final token itself, or a length-1 prefix with no context) read as a change of 0.
    changed_activations = torch.full((seq_size,), float(original_activation))
    for i in range(seq_size - 1):
        # Drop token i; the final token stays last, so position -1 is still the one tracked.
        ablated_tokens = torch.cat([seq_tokens[:i], seq_tokens[i + 1:]]).unsqueeze(0)
        with torch.no_grad():
            _, cache = model.run_with_cache(ablated_tokens.to(device), names_filter=cache_name)
            neuron_activations = rearrange(cache[cache_name], "b s n -> (b s) n")
            # Local name, so the notebook-level `dictionary_activations` matrix is not clobbered.
            ablated_feature_acts = autoencoder.encode(neuron_activations)
            changed_activations[i] = ablated_feature_acts[-1, feature].item()
    changed_activations -= original_activation
    # Append for every prefix (including length-1) so entries stay aligned with `token_list`
    # when both are passed to display_tokens.
    all_changed_activations.append(changed_activations.tolist())

In [None]:
# Shade each prefix token by the corresponding value in `all_changed_activations`.
display_tokens(token_list, all_changed_activations, model.tokenizer)

In [None]:
# Inspect the raw per-position values.
all_changed_activations

In [None]:
# Original feature activations on the same prefixes, for comparison.
display_tokens(token_list, partial_activations, model.tokenizer)

In [None]:
# NOTE(review): `ablate_feature_direction_display` is not defined in this part of the notebook —
# TODO confirm it exists in another cell/module or remove this call.
ablate_feature_direction_display(full_text, autoencoder, model, layer, features=feature)