In [1]:
import torch
import json
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import plotly.graph_objects as go
from scipy.stats import norm
from einops import einsum

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils
from haystack_utils import get_mlp_activations
from hook_utils import get_ablate_neuron_hook, save_activation
from pythia_160m_utils import get_neuron_accuracy, ablation_effect
import plotting_utils
from plotting_utils import plot_neuron_acts, color_binned_histogram
import hook_utils

%reload_ext autoreload
%autoreload 2

In [2]:
with open('data/tiny_stories_chatgpt.json', 'r') as f:
    prompts = json.load(f)

filtered_prompts = [prompt for prompt in prompts if not prompt.startswith("Once")]

In [2]:
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda")

# large_acts_df = plotting_utils.get_neuron_moments(model, prompts,
#                                                   [[i, j] for i in range(8) for j in range(256)], hook_pre=True)

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


### Unigrams

In [39]:
# Similar words stored near in the embed
tokens = model.to_tokens("Once upon a time", prepend_bos=False)
logits = model.W_E[tokens].squeeze(0) @ model.W_U

for i in range(4):
    values, indices = torch.topk(logits[i], 10)
    print(model.to_str_tokens(indices))

['Once', 'One', 'Yesterday', 'At', 'There', 'Today', 'On', 'Em', 'When', 'In']
[' upon', 'bie', 'ancer', 'itting', 'll', 'aj', 'packs', 'upon', 'uttering', ' haven']
[' a', ' an', ' the', ' some', ' two', ' something', ' another', ' one', ' very', ' many']
[' time', ' day', ' week', ' evening', ' morning', ' afternoon', ' night', ' Sunday', ' way', ' year']


In [17]:
# Associations
tokens = model.to_tokens("cow", prepend_bos=False)
logits = (model.W_E[tokens].squeeze(0) @ model.W_U).squeeze(0)
mean_logit = logits.median() # mean is lower

associated_tokens = []
for string in "grass brown field moo meat milk udders animal".split(' '):
    token = model.to_tokens(string, prepend_bos=False)
    associated_tokens.append(token[0, 0])
associated_tokens = torch.stack(associated_tokens)
mean_associated_logit = logits[associated_tokens].median() # mean is similar

print(mean_logit, mean_associated_logit)

values, indices = torch.topk(logits, 100)
print(model.to_str_tokens(indices))

tensor(0.1094, device='cuda:0') tensor(0.2133, device='cuda:0')
['cow', 'ns', 'ridge', 'wo', ' seventeen', 'eight', 'umbers', 'please', 'ucks', 'icians', 'oxy', 'bread', 'Seven', 'uns', 'bang', 'uggets', 'oster', 'wy', 'Ta', 'arella', 'irteen', ' reins', 'vals', ' fries', 'heddar', 'nine', 'aff', 'zee', 'rots', 'six', 'ban', 'bet', 'Six', 'Four', 'umbs', 'arians', 'Twenty', 'gary', 'cards', 'urger', 'workers', 'Sal', 'apple', 'rito', 'ulent', 'haw', 'cream', 'ction', 'leaf', 'load', 'agna', 'shirts', 'ksh', 'icy', 'keys', 'Apple', 'ants', 'trained', 'orted', 'ender', 'gob', 'chini', 'ocol', ' ounces', 'talk', 'rob', 'Dragon', 'chip', 'amer', 'plets', 'fed', 'iri', 'pill', 'mons', 'int', ' cents', 'reath', 'cker', 'pper', 'war', 'rol', 'Ah', 'osaurus', 'Ten', 'dollar', 'girls', 'seven', 'mur', 'affles', 'Offic', 'ws', 'ounced', 'ucker', 'iw', 'eating', 'ates', 'Ar', 'Chuck', 'Eight', 'mus']


In [24]:

tokens = model.to_tokens("Once upon a time", prepend_bos=False).squeeze(0)
logits = (model.W_E[tokens] @ model.W_U)
median_logit = logits.median()
print(median_logit)

next_logits = []
for i in range(tokens.shape[0] - 1):
    next_logits.append(logits[i, tokens[i + 1]])
next_logits = torch.stack(next_logits)
print(next_logits)

# 'a' token seems to boost some sensible completions
values, indices = torch.topk(logits[-2], 100)
print(model.to_str_tokens(indices))

tensor(-0.0134, device='cuda:0')
tensor([ 0.8695, -0.3610,  1.9775], device='cuda:0')
[' a', ' an', ' the', ' some', ' two', ' something', ' another', ' one', ' very', ' many', ' big', ' her', ' so', ' his', ' all', ' in', ' lots', ' to', ' more', ' three', ' four', ' it', ' their', ' on', ' out', ' there', ' small', ' too', ' that', ' and', ' not', '.', ' no', ' good', ' with', ' right', ' really', ' both', ' nice', ' different', ' only', ' five', ',', '\n', ' like', ' hard', ' A', ' bright', ' just', ' when', ' much', ' little', ' long', ' cool', ' brave', ' someone', ' fast', ' strong', ' Lily', ' funny', ' this', ' look', ' warm', ' made', ' fun', ' wide', ' new', ' as', ' silly', ' red', ' hot', ' each', ' close', ' open', ' pretty', ' Mom', ' what', ' Tom', ' other', ' its', ' even', ' far', ' tall', ' The', ' shiny', ' dark', ' loud', ' white', ' deep', ' soft', ' fair', ' black', ' j', ' outside', ' kind', ' your', 'A', ' for', ' help', ' happy']


In [4]:
# Generate heaps of token sequences and average their cache, then find the difference with the OUAT neurons
# Ablate components until OUAT fails
# n-l AND?

_, cache = model.run_with_cache(model.to_tokens(prompts)[:, :40])

def rand_hook(value, hook):
    cache_val = cache[hook.name][:value.size(0), :value.size(1), :value.size(2)]
    mean_val = cache_val.mean(dim=0)
    broadcasted_val = mean_val.unsqueeze(0).expand_as(value)
    value = broadcasted_val
    return value

prompt = "Once"
print(model.generate(prompt, 20, temperature=0))
for layer in range(model.cfg.n_layers):
    with model.hooks([(f'blocks.{layer}.hook_attn_out', rand_hook)]):
        print(f"Attn {layer}", model.generate(prompt, 20, temperature=0))
    with model.hooks([(f'blocks.{layer}.hook_mlp_out', rand_hook)]):
        print(f"MLP {layer}", model.generate(prompt, 20, temperature=0))

# Necessary components:
# MLP0, MLP1, MLP2, MLP3, MLP4, MLP5, MLP7 (every MLP but 6)

  0%|          | 0/20 [00:00<?, ?it/s]

Once upon a time, there was a little girl named Lily. She loved to play outside in the park


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 0 Once upon a time, there was a little girl named Lily. She loved to play with her toys and


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 0 Once time time time time time time time time time time time time time time time time time time time time


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 1 Once upon a time, there was a little girl named Lily and she went on a big day playing outside


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 1 Once,,,,,,,,,,,,,,,,,,,,


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 2 Once upon a time, there was a big little little little little little little little little little little little little


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 2 Once, to a pretty pretty pretty pretty pretty pretty pretty pretty pretty. to look to the blue blue blue


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 3 Once upon a time, in a big, big, big, big, big, big, big,


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 3 Once,, little little little little little little little little little little little little little little little little little little


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 4 Once upon a time, there was a little girl named Lilymymymymymymymymymy


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 4 Once were her mommy asked to pack inside,"Mommy" oh oh oh oh oh oh oh


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 5 Once upon a time, a little girl named Lily. She loved to play outside in the forest. One


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 5 Once mom,,,,, but then,,, but then,,,, but then,


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 6 Once upon a time, there was a little girl named Lily. She was very sad and her favorite toy


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 6 Once upon a time, upon a a a a a a a a a and princess. She and princess


  0%|          | 0/20 [00:00<?, ?it/s]

Attn 7 Once upon a time, there was a little girl named Timmy. Timmy was very sad because he


  0%|          | 0/20 [00:00<?, ?it/s]

MLP 7 Once?





















### Ablate by cosine sim - degrades around 80

In [30]:
# Figure out which neurons directly write to each vocab
# And see if we can ablate everything else

tokens = model.to_tokens("Once upon a time", prepend_bos=False)  # [1, 4]
token_dirs = model.tokens_to_residual_directions(tokens)[0]  # [4, 64]
token_dirs_reshaped = token_dirs.unsqueeze(1).unsqueeze(1)  # [4, 1, 1, 64]
W_out_reshaped = model.W_out.unsqueeze(0)  # [1, 8, 256, 64]

cosine_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
result = cosine_sim(token_dirs_reshaped, W_out_reshaped)  # [4, 8, 256]

_, cache = model.run_with_cache(tokens)
acts = [cache[f'blocks.{layer}.mlp.hook_post'][0] for layer in range(model.cfg.n_layers)] # [[batch pos]]*n_layers
acts = torch.stack(acts, dim=1) # 


layer_neuron_tuples = []
for token_index in range(result.size(0)):
    values, indices = torch.topk(result[token_index].view(-1), 20, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (8, 256))
    layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))


In [31]:
layer_neuron_dict = haystack_utils.get_neurons_by_layer(layer_neuron_tuples)

sorted_dla_layer_neuron_tuples = []
sorted_acts = []

for layer in layer_neuron_dict.keys():
    neurons = layer_neuron_dict[layer]
    mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
    sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
    sorted_acts.extend(mean_acts)
    assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_layer_neuron_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0))

  0%|          | 0/10 [00:00<?, ?it/s]

Once in a time, a little girl named Amy was


### Ablate by DLA

In [22]:
def batched_dot_product(x, y):
    return torch.vmap(torch.dot)(x, y)
    

def neuron_DLA(prompt: str, model: HookedTransformer, pos=np.s_[-1:]) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Gets full resid decomposition including all neurons'''
    tokens = model.to_tokens(prompt)
    tokens = tokens[:, :-1]
    answers = tokens[:, 1:]
    _, cache = model.run_with_cache(tokens)
    attrs, labels = cache.get_full_resid_decomposition(-1, expand_neurons=True, apply_ln=True, return_labels=True, pos_slice=pos)
    answer_residual_directions = model.tokens_to_residual_directions(answers)[0, pos]  # [pos d_model]

    neuron_indices = [i for i in range(len(labels)) if 'N' in labels[i]]
    neuron_labels = [labels[i] for i in neuron_indices]
    neuron_attrs = attrs[neuron_indices, :].squeeze(1)
    results = []
    for i in range(neuron_attrs.shape[1]):
        results.append(batched_dot_product(neuron_attrs[:, i], answer_residual_directions[[i]].repeat(neuron_attrs.shape[0], 1)))
    return torch.stack(results), neuron_labels


def get_unspecified_neurons(model: HookedTransformer, neurons: list[tuple[int, int]]):
    layer_dict = haystack_utils.get_neurons_by_layer(neurons)
    unspecified = []
    for layer in range(model.cfg.n_layers):
        for neuron in range(model.cfg.d_mlp):
            if not neuron in layer_dict[layer]:
                unspecified.append((layer, neuron))
    return unspecified

attrs, labels = neuron_DLA("Once upon a time", model, pos=np.s_[-1:])

In [67]:
x, y = haystack_utils.DLA(["Once upon"], model)

  0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
haystack_utils.line(x.cpu().squeeze(0))

In [20]:
attrs, labels = neuron_DLA("Once upon a time", model)
haystack_utils.line(attrs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="DLA per neuron in layer")

Tried to stack head results when they weren't cached. Computing head results now


RuntimeError: Expected size for first two dimensions of batch2 tensor to be: [2048, 64] but got: [2048, 1].

In [13]:
animal = 'cows'
attrs, labels = neuron_DLA(f'Once upon a time there was a boy named Bob. Bob loves to learn about different animals. "What do {animal} eat?", his mother asked. "They eat', 
                                 model)
haystack_utils.line(attrs.cpu().numpy(), xlabel="Correct logit", ylabel="", title="DLA per neuron in layer")

Tried to stack head results when they weren't cached. Computing head results now


In [14]:
print(attrs.sum())
px.histogram(attrs.flatten().cpu().numpy())

tensor(31.1651, device='cuda:0')


In [43]:
# Compare top neurons found with DLA vs. those found with cosine sim
dla_layer_neuron_tuples = []
for token_index in range(1):
    values, indices = torch.topk(attrs[token_index], 100, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    dla_layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))

print(dla_layer_neuron_tuples[:5])
print(layer_neuron_tuples[:5])

[(3, 238), (0, 4), (0, 35), (0, 5), (0, 34)]
[(0, 78), (0, 22), (2, 236), (2, 125), (5, 164)]


In [44]:
# Ablate the neurons found above and check that it messes things up
layer_neuron_dict = haystack_utils.get_neurons_by_layer(dla_layer_neuron_tuples)

sorted_dla_layer_neuron_tuples = []
sorted_acts = []

for layer in layer_neuron_dict.keys():
    neurons = layer_neuron_dict[layer]
    mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
    sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
    sorted_acts.extend(mean_acts)
    assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_layer_neuron_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0))

  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a time time, there, in a happy


In [89]:
dla_layer_neuron_tuples = []
for token_index in range(4):
    values, indices = torch.topk(attrs[token_index].view(-1), 5, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (8, 256))
    dla_layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))

unspecified_neuron_tuples = get_unspecified_neurons(model, dla_layer_neuron_tuples)
layer_neuron_dict = haystack_utils.get_neurons_by_layer(unspecified_neuron_tuples)

sorted_dla_layer_neuron_tuples = []
sorted_acts = []

for layer in layer_neuron_dict.keys():
    neurons = layer_neuron_dict[layer]
    mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
    sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
    sorted_acts.extend(mean_acts)
    assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_layer_neuron_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0))

  0%|          | 0/10 [00:00<?, ?it/s]

OnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnce
