In [1]:
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor
import plotly.io as pio
import numpy as np
from tqdm import trange

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils
import hook_utils

%reload_ext autoreload
%autoreload 2

In [2]:
data = haystack_utils.load_txt_data('data/TinyStories-train.txt')[:500]
filtered_prompts = [prompt for prompt in data if not prompt.startswith("Once")]

data/TinyStories-train.txt: Loaded 14815490 examples with 0 to 4928 characters each.


In [3]:
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda")

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


### Unigrams

In [None]:
# Similar words stored near in the embed
tokens = model.to_tokens("Once upon a time", prepend_bos=False)
logits = model.W_E[tokens].squeeze(0) @ model.W_U

for i in range(4):
    values, indices = torch.topk(logits[i], 10)
    print(model.to_str_tokens(indices))

['Once', 'One', 'Yesterday', 'At', 'There', 'Today', 'On', 'Em', 'When', 'In']
[' upon', 'bie', 'ancer', 'itting', 'll', 'aj', 'packs', 'upon', 'uttering', ' haven']
[' a', ' an', ' the', ' some', ' two', ' something', ' another', ' one', ' very', ' many']
[' time', ' day', ' week', ' evening', ' morning', ' afternoon', ' night', ' Sunday', ' way', ' year']


In [None]:
# Associations
tokens = model.to_tokens("cow", prepend_bos=False)
logits = (model.W_E[tokens].squeeze(0) @ model.W_U).squeeze(0)
mean_logit = logits.median() # mean is lower

associated_tokens = []
for string in "grass brown field moo meat milk udders animal".split(' '):
    token = model.to_tokens(string, prepend_bos=False)
    associated_tokens.append(token[0, 0])
associated_tokens = torch.stack(associated_tokens)
mean_associated_logit = logits[associated_tokens].median() # mean is similar

print(mean_logit, mean_associated_logit)

values, indices = torch.topk(logits, 100)
print(model.to_str_tokens(indices))

tensor(0.1094, device='cuda:0') tensor(0.2133, device='cuda:0')
['cow', 'ns', 'ridge', 'wo', ' seventeen', 'eight', 'umbers', 'please', 'ucks', 'icians', 'oxy', 'bread', 'Seven', 'uns', 'bang', 'uggets', 'oster', 'wy', 'Ta', 'arella', 'irteen', ' reins', 'vals', ' fries', 'heddar', 'nine', 'aff', 'zee', 'rots', 'six', 'ban', 'bet', 'Six', 'Four', 'umbs', 'arians', 'Twenty', 'gary', 'cards', 'urger', 'workers', 'Sal', 'apple', 'rito', 'ulent', 'haw', 'cream', 'ction', 'leaf', 'load', 'agna', 'shirts', 'ksh', 'icy', 'keys', 'Apple', 'ants', 'trained', 'orted', 'ender', 'gob', 'chini', 'ocol', ' ounces', 'talk', 'rob', 'Dragon', 'chip', 'amer', 'plets', 'fed', 'iri', 'pill', 'mons', 'int', ' cents', 'reath', 'cker', 'pper', 'war', 'rol', 'Ah', 'osaurus', 'Ten', 'dollar', 'girls', 'seven', 'mur', 'affles', 'Offic', 'ws', 'ounced', 'ucker', 'iw', 'eating', 'ates', 'Ar', 'Chuck', 'Eight', 'mus']


In [None]:

tokens = model.to_tokens("Once upon a time", prepend_bos=False).squeeze(0)
logits = (model.W_E[tokens] @ model.W_U)
median_logit = logits.median()
print(median_logit)

next_logits = []
for i in range(tokens.shape[0] - 1):
    next_logits.append(logits[i, tokens[i + 1]])
next_logits = torch.stack(next_logits)
print(next_logits)

# 'a' token seems to boost some sensible completions
values, indices = torch.topk(logits[-2], 100)
print(model.to_str_tokens(indices))

tensor(-0.0134, device='cuda:0')
tensor([ 0.8695, -0.3610,  1.9775], device='cuda:0')
[' a', ' an', ' the', ' some', ' two', ' something', ' another', ' one', ' very', ' many', ' big', ' her', ' so', ' his', ' all', ' in', ' lots', ' to', ' more', ' three', ' four', ' it', ' their', ' on', ' out', ' there', ' small', ' too', ' that', ' and', ' not', '.', ' no', ' good', ' with', ' right', ' really', ' both', ' nice', ' different', ' only', ' five', ',', '\n', ' like', ' hard', ' A', ' bright', ' just', ' when', ' much', ' little', ' long', ' cool', ' brave', ' someone', ' fast', ' strong', ' Lily', ' funny', ' this', ' look', ' warm', ' made', ' fun', ' wide', ' new', ' as', ' silly', ' red', ' hot', ' each', ' close', ' open', ' pretty', ' Mom', ' what', ' Tom', ' other', ' its', ' even', ' far', ' tall', ' The', ' shiny', ' dark', ' loud', ' white', ' deep', ' soft', ' fair', ' black', ' j', ' outside', ' kind', ' your', 'A', ' for', ' help', ' happy']


In [None]:
# Generate heaps of token sequences and average their cache, then find the difference with the OUAT neurons
# Ablate components until OUAT fails
# n-l AND?

_, cache = model.run_with_cache(model.to_tokens(prompts)[:, :40])

def rand_hook(value, hook):
    cache_val = cache[hook.name][:value.size(0), :value.size(1), :value.size(2)]
    mean_val = cache_val.mean(dim=0)
    broadcasted_val = mean_val.unsqueeze(0).expand_as(value)
    value = broadcasted_val
    return value

prompt = "Once"
print(model.generate(prompt, 20, temperature=0, use_past_kv_cache=False))
for layer in range(model.cfg.n_layers):
    with model.hooks([(f'blocks.{layer}.hook_attn_out', rand_hook)]):
        print(f"Attn {layer}", model.generate(prompt, 20, temperature=0, use_past_kv_cache=False))
    with model.hooks([(f'blocks.{layer}.hook_mlp_out', rand_hook)]):
        print(f"MLP {layer}", model.generate(prompt, 20, temperature=0, use_past_kv_cache=False))

# Necessary components:
# MLP0, MLP1, MLP2, MLP3, MLP4, MLP5, MLP7 (every MLP but 6)

NameError: name 'prompts' is not defined

### Ablate by cosine sim - degrades around 80

In [None]:
# Figure out which neurons directly write to each vocab
# And see if we can ablate everything else

tokens = model.to_tokens("Once upon a time", prepend_bos=False)  # [1, 4]
token_dirs = model.tokens_to_residual_directions(tokens)[0]  # [4, 64]
token_dirs_reshaped = token_dirs.unsqueeze(1).unsqueeze(1)  # [4, 1, 1, 64]
W_out_reshaped = model.W_out.unsqueeze(0)  # [1, 8, 256, 64]

cosine_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
result = cosine_sim(token_dirs_reshaped, W_out_reshaped)  # [4, 8, 256]

_, cache = model.run_with_cache(tokens)
acts = [cache[f'blocks.{layer}.mlp.hook_post'][0] for layer in range(model.cfg.n_layers)] # [[batch pos]]*n_layers
acts = torch.stack(acts, dim=1) # 


layer_neuron_tuples = []
for token_index in range(result.size(0)):
    values, indices = torch.topk(result[token_index].view(-1), 20, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (8, 256))
    layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))


In [None]:
layer_neuron_dict = haystack_utils.get_neurons_by_layer(layer_neuron_tuples)

sorted_tuples = []
sorted_acts = []

for layer in layer_neuron_dict.keys():
    neurons = layer_neuron_dict[layer]
    mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
    sorted_tuples.extend([(layer, neuron) for neuron in neurons])
    sorted_acts.extend(mean_acts)
    assert len(sorted_tuples) == len(sorted_acts)

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

Once in a time, a little girl named Amy was


### Ablate by DLA

In [30]:
def batched_dot_product(x, y):
    return torch.vmap(torch.dot)(x, y)
    

def neuron_DLA(prompt: str, model: HookedTransformer, pos=np.s_[-1:]) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Gets full resid decomposition including all neurons'''
    tokens = model.to_tokens(prompt)
    answers = tokens[:, 1:]
    tokens = tokens[:, :-1]
    
    _, cache = model.run_with_cache(tokens)
    attrs, labels = cache.get_full_resid_decomposition(-1, expand_neurons=True, apply_ln=True, return_labels=True, pos_slice=pos)
    answer_residual_directions = model.tokens_to_residual_directions(answers)
    if answer_residual_directions.ndim == 1:
        answer_residual_directions = answer_residual_directions.unsqueeze(0)  # [1 d_model]
    elif answer_residual_directions.ndim == 3:
        answer_residual_directions = answer_residual_directions[0]  # [pos d_model]
    answer_residual_directions = answer_residual_directions[pos]  # [pos d_model]
    neuron_indices = [i for i in range(len(labels)) if 'N' in labels[i]]
    neuron_labels = [labels[i] for i in neuron_indices]
    neuron_attrs = attrs[neuron_indices, :].squeeze(1)
    results = []
    for i in range(neuron_attrs.shape[1]):
        results.append(batched_dot_product(neuron_attrs[:, i], answer_residual_directions[[i]].repeat(neuron_attrs.shape[0], 1)))
    return torch.stack(results), neuron_labels

def get_neuron_mean_acts(dla_layer_neuron_tuples: list[tuple[int, int]]) -> tuple[torch.Tensor, torch.Tensor]:
    layer_neuron_dict = haystack_utils.get_neurons_by_layer(dla_layer_neuron_tuples)

    sorted_dla_layer_neuron_tuples = []
    sorted_acts = []

    for layer in layer_neuron_dict.keys():
        neurons = layer_neuron_dict[layer]
        mean_acts = haystack_utils.get_mlp_activations(filtered_prompts[:200], layer, model, context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True)
        sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
        sorted_acts.extend(mean_acts)
        assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

    return sorted_dla_layer_neuron_tuples, sorted_acts

def get_unspecified_neurons(model: HookedTransformer, neurons: list[tuple[int, int]]):
    layer_dict = haystack_utils.get_neurons_by_layer(neurons)
    unspecified = []
    for layer in range(model.cfg.n_layers):
        for neuron in range(model.cfg.d_mlp):
            if not neuron in layer_dict[layer]:
                unspecified.append((layer, neuron))
    return unspecified

In [22]:
# High level viz

# x, y = haystack_utils.DLA(["Once upon"], model)
# haystack_utils.line(x.cpu().squeeze(0))

attrs, labels = neuron_DLA("Once upon a time", model, pos=np.s_[-4:])
# haystack_utils.line(attrs[0].cpu().numpy(), xlabel="Correct logit", ylabel="", title="DLA per neuron in layer")

# print(attrs.sum())
# px.histogram(attrs.flatten().cpu().numpy())

Tried to stack head results when they weren't cached. Computing head results now


In [32]:
# Ablate the top 3 DLA neurons for each index and check that it messes things up

# Get top neurons
dla_layer_neuron_tuples = []
for token_index in range(4):
    values, indices = torch.topk(attrs[token_index], 3, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    dla_layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))

print(dla_layer_neuron_tuples)

sorted_dla_tuples, sorted_acts = get_neuron_mean_acts(dla_layer_neuron_tuples)
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

[(7, 132), (2, 125), (7, 219), (0, 153), (0, 6), (7, 32), (0, 234), (1, 39), (2, 75), (7, 59), (7, 127), (0, 60)]


  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a mood, there was a little girl named


In [None]:
# Ablate for each token individually 

from transformer_lens import utils
strings = ["Once", " upon", " a", " time"]

# Minimal top DLA neurons
for token_index, n_ablations in [(0, 10), (1, 20), (2, 6)]:
    values, indices = torch.topk(attrs[token_index], n_ablations, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    token_dla_layer_neuron_tuples = zip(layer_indices.tolist(), neuron_indices.tolist())
    sorted_dla_tuples, sorted_acts = get_neuron_mean_acts(token_dla_layer_neuron_tuples)
    hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_tuples, sorted_acts)
    with model.hooks(hooks):
        test_prompt = "".join(strings[:token_index + 1])
        # print(utils.test_prompt(test_prompt, strings[token_index + 1] or 'butterfly', model))
        # print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False, verbose=False))

### Ablating the disjoint set and ablating to a different prompt - both cursed

In [None]:
# # Try ablating the disjoint set
# unspecified_neuron_tuples = get_unspecified_neurons(model, dla_layer_neuron_tuples)
# layer_neuron_dict = haystack_utils.get_neurons_by_layer(unspecified_neuron_tuples)

# sorted_dla_layer_neuron_tuples = []
# sorted_acts = []

# for layer in layer_neuron_dict.keys():
#     neurons = layer_neuron_dict[layer]
#     mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
#     sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
#     sorted_acts.extend(mean_acts)
#     assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

# hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_layer_neuron_tuples, sorted_acts)
# with model.hooks(hooks):
#     print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

OnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnce


In [None]:
# def get_resample_neurons_hooks(neurons: list[tuple[int, int]], resampled_cache):
#     layer_neurons = haystack_utils.get_neurons_by_layer(neurons)
#     hooks = []
#     for layer, neurons in layer_neurons.items():
#         def resample_neurons_hook(value, hook):
#             resample_cache_slice = np.s_[:value.shape[0], :value.shape[1], neurons]
#             value[:, :, neurons] = resampled_cache[hook.name][resample_cache_slice]
#             return value
#         hooks.append((f'blocks.{layer}.mlp.hook_post', resample_neurons_hook))
#     return hooks

# # Try again with resampled activations
# _, resample_cache = model.run_with_cache(["Once upon a time"])
# hooks = get_resample_neurons_hooks(sorted_dla_layer_neuron_tuples, resample_cache)
# with model.hooks(hooks):
#     print(model.generate("Once", 3, temperature=0, use_past_kv_cache=False))

  0%|          | 0/3 [00:00<?, ?it/s]

Once upon a time


### Causal study

In [33]:
def get_neuron_loss_increases(prompt: str, positionwise=False):
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
    
    losses = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(data[:200], layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
                losses.append((ablated_loss - original_loss)[0])
    return losses

losses = get_neuron_loss_increases("Once upon a time", positionwise=True)

  0%|          | 0/8 [00:00<?, ?it/s]

100%|██████████| 8/8 [01:03<00:00,  7.93s/it]


In [40]:
def get_hook_inputs_for_token_index(loss_increases_by_neuron, model=model, filtered_prompts=filtered_prompts):
    values, indices = torch.topk(loss_increases_by_neuron, 40)

    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    causal_layer_neuron_tuples = list(zip(layer_indices.tolist(), neuron_indices.tolist()))
    layer_neuron_dict = haystack_utils.get_neurons_by_layer(causal_layer_neuron_tuples)

    sorted_dla_layer_neuron_tuples = []
    sorted_acts = []

    for layer in layer_neuron_dict.keys():
        neurons = layer_neuron_dict[layer]
        mean_acts = haystack_utils.get_mlp_activations(filtered_prompts[:200], layer, model, context_crop_start=0, neurons=neurons, disable_tqdm=True)
        sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
        sorted_acts.extend(mean_acts)
        assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

    return sorted_dla_layer_neuron_tuples, sorted_acts

In [41]:
# Get "upon" neurons
num_neurons = 20

token_index = 1
loss_increases = torch.tensor([loss[token_index] for loss in losses])
sorted_ablation_tuples, sorted_acts = get_hook_inputs_for_token_index(loss_increases)
print(sorted_ablation_tuples[:num_neurons], sorted_acts[:num_neurons])

[(4, 175), (4, 98), (4, 152), (4, 122), (4, 219), (4, 80), (4, 49), (4, 129), (4, 61), (3, 60), (3, 118), (3, 75), (3, 211), (3, 87), (3, 146), (3, 158), (6, 236), (6, 126), (6, 223), (6, 19)] [tensor(-0.0236, device='cuda:0'), tensor(-0.0060, device='cuda:0'), tensor(-0.0116, device='cuda:0'), tensor(-0.0268, device='cuda:0'), tensor(-0.0180, device='cuda:0'), tensor(-0.0217, device='cuda:0'), tensor(-0.0229, device='cuda:0'), tensor(-0.0370, device='cuda:0'), tensor(-0.0272, device='cuda:0'), tensor(-0.0124, device='cuda:0'), tensor(-0.0243, device='cuda:0'), tensor(-0.0105, device='cuda:0'), tensor(-0.0090, device='cuda:0'), tensor(-0.0152, device='cuda:0'), tensor(-0.0273, device='cuda:0'), tensor(0.0013, device='cuda:0'), tensor(-0.0016, device='cuda:0'), tensor(-0.0352, device='cuda:0'), tensor(0.0206, device='cuda:0'), tensor(0.0329, device='cuda:0')]


In [36]:
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_ablation_tuples[:num_neurons], sorted_acts[:num_neurons])
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

Once, the little little little little little little little little


In [49]:
def compare_dla_and_ablation(token_index, attrs, losses, num_neurons=20):
    print("DLA:")
    values, indices = torch.topk(attrs[token_index], num_neurons, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    dla_layer_neuron_tuples = list(zip(layer_indices.tolist(), neuron_indices.tolist()))
    indices_1d_dla = [np.ravel_multi_index(index_2d, (model.cfg.n_layers, model.cfg.d_mlp)) for index_2d in dla_layer_neuron_tuples[:num_neurons]]
    
    print(dla_layer_neuron_tuples[:num_neurons])
    print(attrs[token_index][indices_1d_dla])

    print("Ablation:")
    loss_increases_by_neuron = torch.tensor([loss[token_index] for loss in losses])
    values, indices = torch.topk(loss_increases_by_neuron, num_neurons)
    indices = indices.numpy()
    layer_indices, neuron_indices = np.unravel_index(indices[:num_neurons], (model.cfg.n_layers, model.cfg.d_mlp))
    
    print(list(zip(layer_indices.tolist(), neuron_indices.tolist())))
    print(attrs[token_index][indices.tolist()[:num_neurons]])

compare_dla_and_ablation(token_index=0, attrs=attrs, losses=losses, num_neurons=10)

DLA:
[(7, 132), (2, 125), (7, 219), (0, 155), (3, 90), (0, 166), (7, 86), (0, 135), (7, 55), (6, 136)]
tensor([4.9803, 2.3184, 1.6035, 1.5476, 1.3802, 1.3253, 1.2454, 1.0499, 1.0420,
        1.0340], device='cuda:0')
Ablation:
[(7, 132), (3, 22), (6, 179), (4, 178), (6, 181), (5, 11), (4, 36), (7, 219), (4, 38), (4, 239)]
tensor([ 4.9803,  0.5572, -0.7496, -0.4230, -0.5537,  0.5289,  0.1620,  1.6035,
         0.0813, -0.3386], device='cuda:0')


In [None]:
important_neurons = [(4, 175), (4, 98)]
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_ablation_tuples[:2], sorted_acts[:2])
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

Once for a little little bunny named Benny who loved to
