In [1]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import plotly.graph_objects as go
from scipy.stats import norm, variation, skew, kurtosis
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import json
from tqdm import trange

import haystack_utils
from haystack_utils import get_mlp_activations
from hook_utils import get_ablate_neuron_hook, save_activation
from pythia_160m_utils import get_neuron_accuracy, ablation_effect
import plotting_utils
from plotting_utils import plot_neuron_acts, color_binned_histogram

# Select the plotly renderers used when figures are displayed in the notebook.
pio.renderers.default = "notebook_connected+notebook"
# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only runs inference, so gradient tracking is switched off globally.
# NOTE(review): both calls below disable autograd; the second looks redundant — confirm and drop one.
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)
%reload_ext autoreload
%autoreload 2

In [2]:
# Load the TinyStories-1M model onto the device chosen in the setup cell
# (previously hard-coded to "cuda", which ignored the CPU fallback).
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

# Read the raw training text; each element of full_data is one line of the file
# (trailing newline included, as returned by readlines()).
with open('data/TinyStories-train.txt', 'r') as f:
    full_data = f.readlines()

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


In [3]:
# Report how many lines were read from the training file.
print(len(full_data))
# Small working sample: the first 200 lines, each cut to at most 400 characters.
prompts = [line[:400] for line in full_data[:200]]

14815490


In [10]:
# Collect every distinct token id that occurs in the first 20,000 lines.
tokens_set = {
    token_id
    for story in full_data[:20000]
    for token_id in model.to_tokens(story).flatten().tolist()
}


print(len(tokens_set))

7357


In [5]:
import operator

# Target vocabulary size for the (currently disabled) vocabulary restriction below.
N = 10_000

# Resize the tokenizer vocabulary to the top 10,000 tokens
# NOTE(review): the disabled snippet would replace model.tokenizer with a plain dict
# of the N lowest-id vocab entries, not with a tokenizer object — revisit before re-enabling.
# restricted_vocab = sorted(model.tokenizer.get_vocab().items(), key=operator.itemgetter(1))[:N]
# model.tokenizer = {k: v for k, v in restricted_vocab}

In [30]:
def get_all_tokens(model, data):
    """Tokenize every string in ``data`` and return all token ids as one flat 1-D tensor.

    Args:
        model: Object exposing ``to_tokens(str)``; the first row of its result is
            used for each item (presumably a ``(1, seq)`` tensor — TODO confirm).
        data: Iterable of strings to tokenize, in order.

    Returns:
        A 1-D integer tensor with the tokens of each item concatenated in input
        order. It lives on whatever device ``model.to_tokens`` returns its tensors
        on. For empty ``data`` an empty ``torch.long`` tensor is returned (on
        ``model.cfg.device`` when that attribute exists).
    """
    # Tokenize first and concatenate once: calling torch.cat inside the loop
    # re-copies the whole accumulated tensor on every iteration (quadratic cost).
    per_item_tokens = [model.to_tokens(item)[0] for item in data]

    if not per_item_tokens:
        # torch.cat rejects an empty list, so build the empty result explicitly.
        fallback_device = getattr(getattr(model, "cfg", None), "device", None)
        return torch.empty(0, dtype=torch.long, device=fallback_device)

    return torch.cat(per_item_tokens, dim=0).flatten()

In [31]:
def get_neuron_token_variance(
        model: HookedTransformer, data: list[str],
        tokens: torch.Tensor | None=None,
        disable_tqdm=True, hook_pre=False
) -> torch.Tensor:
    '''
    Get, for every layer, the per-token variance of each MLP neuron's activation.

    For each layer the activations are divided by their per-neuron standard
    deviation over the whole dataset, then grouped by the token they belong to.
    The variance is computed from running sums as E[x^2] - E[x]^2.

    Args:
        model: Model whose MLP activations are analysed.
        data: Prompts passed to ``get_mlp_activations``.
        tokens: Optional flat tensor of token ids aligned row-for-row with the
            activations; computed with ``get_all_tokens(model, data)`` when None.
        disable_tqdm: Forwarded to ``get_mlp_activations``.
        hook_pre: Forwarded to ``get_mlp_activations``.

    Returns:
        Tensor of shape (n_layers, d_vocab, d_mlp) holding the variance of the
        normalized activation of every neuron on every token. Tokens that never
        occur have variance 0.

    Raises:
        IndexError: if a layer yields fewer activation rows than there are tokens.
    '''
    n_layers = model.cfg.n_layers
    d_vocab = model.cfg.d_vocab
    d_mlp = model.cfg.d_mlp
    # Follow the model's device instead of hard-coding CUDA.
    device = model.cfg.device

    if tokens is None:
        tokens = get_all_tokens(model, data)
    tokens = tokens.flatten().to(device=device, dtype=torch.long)
    n_tokens = tokens.shape[0]

    # Occurrences of each token, shaped (d_vocab, 1) so it broadcasts across neurons.
    # Clamping to 1 keeps unseen tokens at 0 / 1 = 0 without dividing by zero.
    token_counts = torch.bincount(tokens, minlength=d_vocab).unsqueeze(1).clamp(min=1)

    token_neuron_vars = torch.zeros(n_layers, d_vocab, d_mlp, device=device)

    for layer in trange(n_layers):
        # assumes acts is (n_positions, d_mlp) with row i belonging to tokens[i] — TODO confirm
        # that get_mlp_activations' context cropping keeps this alignment with get_all_tokens.
        acts = get_mlp_activations(data, layer, model, mean=False, disable_tqdm=disable_tqdm, hook_pre=hook_pre,
                                   context_crop_start=0, context_crop_end=400)
        acts = acts.to(device=device, dtype=torch.float32)

        if acts.shape[0] < n_tokens:
            raise IndexError(
                f"layer {layer}: got {acts.shape[0]} activation rows for {n_tokens} tokens"
            )

        # Normalize each neuron by its dataset-wide std; a neuron with zero (or undefined)
        # std is left unscaled rather than turned into NaN/inf.
        std = torch.std(acts, dim=0)
        std = torch.where(std > 0, std, torch.ones_like(std))
        # Only the first n_tokens rows are paired with a token, as in the original per-row loop.
        acts = acts[:n_tokens] / std

        # Scatter-add rows into their token's bucket in one call instead of a Python loop
        # over every position. (On CUDA, index_add_ accumulates in a non-deterministic
        # order, so results may differ in the last float bits between runs.)
        act_sum = torch.zeros(d_vocab, d_mlp, device=device)
        act_sq_sum = torch.zeros(d_vocab, d_mlp, device=device)
        act_sum.index_add_(0, tokens, acts)
        act_sq_sum.index_add_(0, tokens, acts ** 2)

        mean_acts = act_sum / token_counts
        mean_acts_squared = act_sq_sum / token_counts
        # Floating-point cancellation can push the result slightly below zero.
        token_neuron_vars[layer] = (mean_acts_squared - mean_acts ** 2).clamp_min(0)

    return token_neuron_vars

In [32]:
# Tokenize the first 20,000 lines once and reuse the ids for the variance computation.
tokens = get_all_tokens(model, data=full_data[:20000])
token_neuron_acts = get_neuron_token_variance(model, data=full_data[:20000], tokens=tokens)

100%|██████████| 8/8 [40:37<00:00, 304.72s/it]


In [33]:
# Token ids that actually occur in the dataset; rows of valid_neuron_acts follow this order.
valid_token_indices = list(tokens_set)
print(len(valid_token_indices))

valid_neuron_acts = {}
for layer in range(model.cfg.n_layers):
    # Indexing with a list copies the selected rows, so the clean-up below
    # does not modify token_neuron_acts itself.
    layer_vars = token_neuron_acts[layer][valid_token_indices]
    # Remove any weird numbers caused by tokens that never appear in the neuron act dataset.
    # NaN fails every > / < comparison, so non-finite values are tested for explicitly.
    is_weird = ~torch.isfinite(layer_vars) | (layer_vars.abs() > 1e10)
    layer_vars[is_weird] = 0
    valid_neuron_acts[layer] = layer_vars

    # NOTE(review): this saves the full-vocabulary, uncleaned tensor rather than
    # valid_neuron_acts[layer] — confirm which one downstream loaders expect.
    with open(f'data/neuron_token_vars_{layer}', 'wb') as f:
        pickle.dump(token_neuron_acts[layer], f)

valid_neuron_acts[0][0]

7357


tensor([3.3559e-03, 4.0625e-03, 1.0918e-03, 1.4053e-03, 7.7228e-03, 6.1107e-03,
        5.4795e-03, 6.3961e-03, 5.3644e-06, 1.1842e-03, 1.6332e-03, 2.5883e-03,
        5.3342e-03, 6.0445e-02, 7.4473e-03, 5.0735e-03, 5.1242e-03, 2.7890e-03,
        2.0046e-03, 2.2921e-03, 2.4658e-02, 2.8625e-03, 3.0652e-04, 3.4750e-04,
        3.2605e-03, 4.5653e-03, 1.2042e-03, 3.4957e-03, 3.4612e-03, 3.4347e-02,
        1.7056e-02, 5.2084e-04, 2.5506e-03, 2.8339e-03, 9.7350e-03, 1.2945e-03,
        1.5241e-03, 2.6124e-03, 1.9694e-02, 4.6659e-03, 1.9001e-03, 2.6565e-03,
        3.1988e-03, 6.9871e-03, 1.5029e-02, 5.8228e-04, 8.1278e-03, 2.1648e-04,
        1.8340e-04, 2.8374e-03, 8.9211e-03, 1.1770e-02, 5.1433e-04, 4.0309e-02,
        2.7327e-03, 1.0933e-02, 2.0526e-02, 1.0729e-03, 1.1268e-03, 8.2975e-03,
        8.3030e-04, 2.4622e-03, 9.9301e-05, 5.3285e-03, 8.5162e-03, 1.9978e-03,
        1.3846e-03, 9.9334e-04, 8.5476e-04, 6.8040e-03, 2.3564e-03, 4.1972e-03,
        2.6683e-03, 3.0035e-02, 2.1154e-

In [40]:
import gzip
import pickle
import os
import torch.quantization

def quantize_8bit(input):
    """Quantize a tensor to 8 bits with a separate affine mapping per column.

    The minimum and maximum are taken along dim 0, so every index along the
    remaining dims gets its own ``offset`` (the minimum) and ``scale``
    (``(max - min) / 255``).

    Args:
        input (torch.Tensor): Tensor to quantize; must be non-empty along dim 0.

    Returns:
        tuple: ``(quant, offset, scale)`` where ``quant`` is a ``torch.uint8``
        tensor shaped like ``input`` and ``quant * scale + offset`` approximates
        ``input``. A constant column has ``scale == 0`` and quantizes to all
        zeros, which still reconstructs exactly to ``offset``.
    """
    offset = input.min(axis=0).values
    scale = (input.max(axis=0).values - offset) / 255
    # A constant column gives scale == 0; dividing by it would yield NaN, whose cast
    # to uint8 is undefined. Divide by 1 there instead (the numerator is 0 anyway).
    safe_scale = torch.where(scale == 0, torch.ones_like(scale), scale)
    quant = ((input - offset) / safe_scale).float().round().clamp(0, 255).to(torch.uint8)
    return quant, offset, scale

def unquantize_8bit(input, offset, scale):
    """Map an 8-bit quantized tensor back to floating point.

    Args:
        input (torch.Tensor): The quantized values.
        offset (torch.Tensor): Value added back after scaling.
        scale (torch.Tensor): Multiplier applied to the quantized values.

    Returns:
        torch.Tensor: ``input`` cast to float16, multiplied by ``scale`` and
        shifted by ``offset``.
    """
    as_half = input.to(torch.float16)
    rescaled = torch.mul(as_half, scale)
    return torch.add(rescaled, offset)

# for layer in range(model.cfg.n_layers):
#     with gzip.open(f'data/tiny_stories/neuron_vars_{layer}.pkl.gz', 'rb') as f:
#         layer_neuron_acts = pickle.load(f)
#         activations, offset, scale = quantize_8bit(layer_neuron_acts)
#         pickle.dump(f, activations)

AttributeError: 'dict' object has no attribute 'min'

In [None]:
# Scratch demo: concatenate variable-length tensors by writing each one into a
# slice of a single preallocated buffer instead of calling torch.cat repeatedly.
tensors = [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])]
lengths = [tensor.shape[0] for tensor in tensors]
result = torch.empty(sum(lengths), dtype=torch.int64)

start = 0
for tensor, length in zip(tensors, lengths):
    end = start + length
    result[start:end].copy_(tensor)
    start = end

In [None]:
# large_acts_df = get_neuron_moments(model, prompts, [])
# layer_neuron_tuple = large_acts_df.sort_values('skew', ascending=False).iloc[:1][['layer', 'neuron']].values.tolist()[0]
# layer, neuron = layer_neuron_tuple

# hook_name = f'blocks.{layer}.mlp.hook_post'
# with model.hooks([(hook_name, save_activation)]):
#     model(prompts[-1])
# acts = model.hook_dict[hook_name].ctx['activation']
# haystack_utils.clean_print_strings_as_html(model.to_str_tokens(prompts[-1]), acts[0, :, neuron], max_value=1)