In [2]:
import torch
import json
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import plotly.graph_objects as go
from scipy.stats import norm

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

# Set Plotly's default renderer(s) for figures displayed from this notebook.
pio.renderers.default = "notebook_connected+notebook"
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only runs inference, so autograd is switched off globally.
# (torch.autograd.set_grad_enabled is the same callable as torch.set_grad_enabled,
# so a single call is sufficient.)
torch.set_grad_enabled(False)

import haystack_utils
from haystack_utils import get_mlp_activations
from hook_utils import get_ablate_neuron_hook, save_activation
from pythia_160m_utils import get_neuron_accuracy, ablation_effect
import plotting_utils
from plotting_utils import plot_neuron_acts, color_binned_histogram


# Load (or reload) IPython's autoreload extension; mode 2 re-imports edited
# modules before each cell runs, so changes to the local helper modules
# are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [5]:
# Load the prompt dataset. The encoding is given explicitly so that the file
# is decoded the same way on every platform instead of with the locale default.
with open('data/tiny_stories_chatgpt.json', 'r', encoding='utf-8') as f:
    prompts = json.load(f)

In [3]:
# Load TinyStories-1M into a HookedTransformer. The model is placed on
# `device` (chosen in the setup cell) rather than a hard-coded "cuda", so the
# notebook also runs on machines without a GPU.
model = HookedTransformer.from_pretrained(
    "roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

# large_acts_df = plotting_utils.get_neuron_moments(model, prompts,
#                                                   [[i, j] for i in range(8) for j in range(256)], hook_pre=True)

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


In [6]:
# Sanity check: let the unmodified model continue the prompt and show the text.
sample_continuation = model.generate('Once')
print(sample_continuation)

  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a time, there was a big bear who


In [None]:
# Generate heaps of token sequences and average their cache, then find the difference with the OUAT neurons
# Ablate components until OUAT fails
# n-l AND?

# Reference activations: every prompt, cropped to its first 40 token positions.
_, cache = model.run_with_cache(model.to_tokens(prompts)[:, :40])


def rand_hook(value, hook):
    """Mean-ablate a hooked activation.

    Replaces ``value`` with the cached activation of the same hook point,
    averaged over all cached prompts at each position.

    Args:
        value: Activation handed to the hook; indexed here as
            (batch, position, feature).
        hook: Hook point; ``hook.name`` selects the matching ``cache`` entry.

    Returns:
        The per-position mean activation, broadcast to the shape of ``value``.

    Raises:
        ValueError: If ``value`` spans more positions than were cached.
    """
    reference = cache[hook.name]
    n_positions = value.size(1)
    if n_positions > reference.size(1):
        raise ValueError(
            f"{hook.name}: got {n_positions} positions but only "
            f"{reference.size(1)} were cached"
        )
    # Average over the *whole* prompt batch. Cropping the batch axis to
    # value.size(0) would keep a single prompt whenever one prompt is being
    # generated from, in which case nothing is actually averaged.
    mean_val = reference[:, :n_positions, :value.size(2)].mean(dim=0)
    return mean_val.unsqueeze(0).expand_as(value)


prompt = "Once"
# rand_hook aligns cached positions with the positions of `value` starting at
# index 0, so every forward pass has to see the full sequence. With the KV
# cache enabled, generate() feeds only the newest token after the first step,
# which would make the hook substitute position-0 activations everywhere.
ablation_generate_kwargs = dict(temperature=0, use_past_kv_cache=False)

print(model.generate(prompt, 20, temperature=0))
for layer in range(model.cfg.n_layers):
    components = (
        ("Attn", f'blocks.{layer}.hook_attn_out'),
        ("MLP", f'blocks.{layer}.hook_mlp_out'),
    )
    for label, hook_name in components:
        with model.hooks([(hook_name, rand_hook)]):
            print(f"{label} {layer}", model.generate(prompt, 20, **ablation_generate_kwargs))

# Necessary components (recorded from a run of the earlier version of this
# cell; re-run to confirm they still hold):
# MLP0, MLP1, MLP2, MLP3, MLP4, MLP5, MLP7 (every MLP but 6)

In [12]:
# Figure out which neurons directly write to each vocab
# And see if we can ablate everything else

# Number of best-aligned neurons to keep for each token of the phrase.
TOP_NEURONS_PER_TOKEN = 30

tokens = model.to_tokens("Once upon a time", prepend_bos=False)  # [1, n_tokens] (= [1, 4])
token_dirs = model.tokens_to_residual_directions(tokens)[0]  # [n_tokens, d_model] (= [4, 64])
token_dirs_reshaped = token_dirs.unsqueeze(1).unsqueeze(1)  # [n_tokens, 1, 1, d_model]
W_out_reshaped = model.W_out.unsqueeze(0)  # [1, n_layers, d_mlp, d_model] (= [1, 8, 256, 64])

# Cosine similarity between each token's residual direction and each neuron's
# output weights, broadcast over tokens x layers x neurons.
cosine_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
result = cosine_sim(token_dirs_reshaped, W_out_reshaped)  # [n_tokens, n_layers, d_mlp]

# Take the (layer, neuron) grid from the similarity tensor itself instead of
# hard-coding (8, 256), so the cell keeps working for other model sizes.
neuron_grid_shape = tuple(result.shape[1:])

layer_neuron_tuples = []
for token_index in range(result.size(0)):
    flat_similarities = result[token_index].flatten()
    # Never ask topk for more entries than there are neurons.
    k = min(TOP_NEURONS_PER_TOKEN, flat_similarities.numel())
    values, indices = torch.topk(flat_similarities, k, dim=-1)
    # Map flat indices back to (layer, neuron) coordinates.
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), neuron_grid_shape)
    layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))

# NOTE(review): a neuron that ranks in the top k for several tokens appears
# more than once in layer_neuron_tuples — confirm get_neurons_by_layer copes
# with duplicates.
layer_neuron_dict = haystack_utils.get_neurons_by_layer(layer_neuron_tuples)


In [24]:
# Keep only the stories that do not begin with "Once".
filtered_prompts = [story for story in prompts if not story.startswith("Once")]

In [27]:
import hook_utils

# Parallel lists: entry i of sorted_acts is the activation returned for the
# (layer, neuron) pair at entry i of sorted_layer_neuron_tuples.
sorted_layer_neuron_tuples = []
sorted_acts = []

for layer, neurons in layer_neuron_dict.items():
    # Activations of the selected neurons in this layer over the filtered prompts.
    mean_acts = haystack_utils.get_mlp_activations(
        filtered_prompts,
        layer,
        model,
        context_crop_start=2,
        hook_pre=False,
        neurons=neurons,
        disable_tqdm=True,
    )
    sorted_layer_neuron_tuples.extend((layer, neuron) for neuron in neurons)
    sorted_acts.extend(mean_acts)
    # The two lists must stay aligned one-to-one after every layer.
    assert len(sorted_layer_neuron_tuples) == len(sorted_acts)

# Build the ablation hooks from the aligned neuron / activation lists.
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_layer_neuron_tuples, sorted_acts)

In [28]:
# Greedy generation (temperature 0) from "Once" while the hooks built from
# the selected neurons are attached to the model.
with model.hooks(hooks):
    hooked_continuation = model.generate("Once", 10, temperature=0)
    print(hooked_continuation)

  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a time, there
Once there, there
