In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import plotly.graph_objects as go
from scipy.stats import norm, variation, skew, kurtosis
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import json
from tqdm import trange

import haystack_utils
from haystack_utils import get_mlp_activations
from hook_utils import get_ablate_neuron_hook, save_activation
from pythia_160m_utils import get_neuron_accuracy, ablation_effect
import plotting_utils
from plotting_utils import plot_neuron_acts, color_binned_histogram

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)
%reload_ext autoreload
%autoreload 2

In [4]:
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda")

with open('data/TinyStories-train.txt', 'r') as f:
    full_data = f.readlines()

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


In [3]:
print(len(full_data))
prompts = [data[:400] for data in full_data[:200]]

14815490


In [4]:
tokens_set = set()
for prompt in full_data[:20000]:
    tokens = model.to_tokens(prompt).flatten().tolist()
    tokens_set.update(tokens)


print(len(tokens_set))

7357


In [5]:
import operator

N = 10_000

# Resize the tokenizer vocabulary to the top 10,000 tokens
# restricted_vocab = sorted(model.tokenizer.get_vocab().items(), key=operator.itemgetter(1))[:N]
# model.tokenizer = {k: v for k, v in restricted_vocab}

In [5]:
def get_neuron_token_variance(
        model: HookedTransformer, data: list[str],
        disable_tqdm=True, hook_pre=False
) -> pd.DataFrame:
    '''
    Get the variance of the activations of a neuron for each token in the dataset.
    This looks like a tensor of tokens and a tensor of activations, and building:
      a tensor of token x activation_sum (10000 x d_mlp)
      a tensor of token_count
      repeat layer by layer (or do it all at once if we have memory)
    '''
    token_neuron_acts = torch.zeros(model.cfg.n_layers, model.cfg.d_vocab, model.cfg.d_mlp).cuda()
    token_neuron_acts_squared = torch.zeros(model.cfg.n_layers, model.cfg.d_vocab, model.cfg.d_mlp).cuda()

    tokens = torch.empty(0, dtype=int).cuda()
    for item in data:
        tokens = torch.cat([tokens, model.to_tokens(item)[0]], dim=0)
    tokens = tokens.flatten()
    token_counts = torch.bincount(tokens, minlength=model.cfg.d_vocab).repeat(model.cfg.d_mlp, 1).reshape(-1, model.cfg.d_mlp)

    for layer in trange(model.cfg.n_layers):
        acts = get_mlp_activations(data, layer, model, mean=False, disable_tqdm=disable_tqdm, hook_pre=hook_pre, 
                                   context_crop_start=0, context_crop_end=400)
        print("acts nan", acts.isnan().sum())
        for i in range(tokens.shape[0]):
            token_neuron_acts[layer, tokens[i]] += acts[i]
            token_neuron_acts_squared[layer, tokens[i]] += acts[i] ** 2

        mean_acts = token_neuron_acts[layer] / token_counts + 1e-8
        print("mean acts nan", mean_acts.isnan().sum())
        mean_acts_squared = token_neuron_acts_squared[layer] / token_counts + 1e-8
        var_acts = mean_acts_squared - mean_acts ** 2
        token_neuron_acts[layer] = var_acts

    return token_neuron_acts

token_neuron_acts = get_neuron_token_variance(model, full_data[:4000])

  0%|          | 0/8 [00:00<?, ?it/s]

In [79]:
for layer in token_neuron_acts.keys():
    with open(f'data/neuron_token_vars_{layer}', 'wb') as f:
        pickle.dump(token_neuron_acts[layer], f)

In [81]:
valid_token_indices = list(tokens_set)
print(len(valid_token_indices))
token_neuron_acts[0][valid_token_indices, :].shape

7357


torch.Size([7357, 256])

In [89]:
token_neuron_acts[0, 0]

tensor([ 5.4765e-05, -4.8202e-02,         nan,         nan,         nan,
                nan, -1.1680e+01,         nan,         nan,         nan,
                nan,  7.7938e-02, -1.8480e-02,  1.0072e-03,         nan,
                nan,         nan,         nan,         nan,         nan,
                nan,         nan,         nan,         nan,         nan,
        -9.1745e-01, -8.1557e+01,         nan,         nan,         nan,
        -3.7388e-01,         nan, -3.1596e+01, -8.5876e+01, -3.8638e+01,
                nan, -3.8681e+01,         nan, -1.4826e+03, -4.1907e+03,
        -6.4918e-02, -1.4258e+01, -6.9301e+02, -1.4280e+01, -2.1543e-02,
        -4.6847e+03, -2.2924e+00, -3.9207e+02,         nan, -9.0109e+01,
        -7.6054e+01, -2.5891e+02,         nan,         nan, -1.6895e+03,
                nan, -1.8139e+02, -3.8759e+02,         nan,         nan,
                nan,         nan,         nan,         nan,  1.2338e-03,
        -1.2670e+01, -1.0005e+02, -3.6061e+03, -5.7

In [None]:
tensors = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]
lengths = [len(tensor) for tensor in tensors]
result = torch.empty(sum(lengths), dtype=torch.int64)

start = 0
for tensor in tensors:
    end = start + len(tensor)
    result[start:end] = tensor
    start = end

In [None]:
# large_acts_df = get_neuron_moments(model, prompts, [])
# layer_neuron_tuple = large_acts_df.sort_values('skew', ascending=False).iloc[:1][['layer', 'neuron']].values.tolist()[0]
# layer, neuron = layer_neuron_tuple

# hook_name = f'blocks.{layer}.mlp.hook_post'
# with model.hooks([(hook_name, save_activation)]):
#     model(prompts[-1])
# acts = model.hook_dict[hook_name].ctx['activation']
# haystack_utils.clean_print_strings_as_html(model.to_str_tokens(prompts[-1]), acts[0, :, neuron], max_value=1)