In [52]:
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor
import plotly.io as pio
import numpy as np
from tqdm import trange

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils
import hook_utils

%reload_ext autoreload
%autoreload 2

In [53]:
data = haystack_utils.load_txt_data('data_large/TinyStories-train.txt')[:500]
filtered_prompts = [prompt for prompt in data if not prompt.startswith("Once")]

data_large/TinyStories-train.txt: Loaded 935512 examples with 0 to 4287 characters each.


In [54]:
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-1M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda")

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-1M into HookedTransformer


### Unigrams

In [55]:
# Similar words stored near in the embed
tokens = model.to_tokens("Once upon a time", prepend_bos=False)
logits = model.W_E[tokens].squeeze(0) @ model.W_U

for i in range(4):
    values, indices = torch.topk(logits[i], 10)
    print(model.to_str_tokens(indices))

['Once', 'One', 'Yesterday', 'At', 'There', 'Today', 'On', 'Em', 'When', 'In']
[' upon', 'bie', 'ancer', 'itting', 'll', 'aj', 'packs', 'upon', 'uttering', ' haven']
[' a', ' an', ' the', ' some', ' two', ' something', ' another', ' one', ' very', ' many']
[' time', ' day', ' week', ' evening', ' morning', ' afternoon', ' night', ' Sunday', ' way', ' year']


In [56]:
# Associations
tokens = model.to_tokens("cow", prepend_bos=False)
logits = (model.W_E[tokens].squeeze(0) @ model.W_U).squeeze(0)
mean_logit = logits.median() # mean is lower

associated_tokens = []
for string in "grass brown field moo meat milk udders animal".split(' '):
    token = model.to_tokens(string, prepend_bos=False)
    associated_tokens.append(token[0, 0])
associated_tokens = torch.stack(associated_tokens)
mean_associated_logit = logits[associated_tokens].median() # mean is similar

print(mean_logit, mean_associated_logit)

values, indices = torch.topk(logits, 100)
print(model.to_str_tokens(indices))

tensor(0.1094, device='cuda:0') tensor(0.2133, device='cuda:0')
['cow', 'ns', 'ridge', 'wo', ' seventeen', 'eight', 'umbers', 'please', 'ucks', 'icians', 'oxy', 'bread', 'Seven', 'uns', 'bang', 'uggets', 'oster', 'wy', 'Ta', 'arella', 'irteen', ' reins', 'vals', ' fries', 'heddar', 'nine', 'aff', 'zee', 'rots', 'six', 'ban', 'bet', 'Six', 'Four', 'umbs', 'arians', 'Twenty', 'gary', 'cards', 'urger', 'workers', 'Sal', 'apple', 'rito', 'ulent', 'haw', 'cream', 'ction', 'leaf', 'load', 'agna', 'shirts', 'ksh', 'icy', 'keys', 'Apple', 'ants', 'trained', 'orted', 'ender', 'gob', 'chini', 'ocol', ' ounces', 'talk', 'rob', 'Dragon', 'chip', 'amer', 'plets', 'fed', 'iri', 'pill', 'mons', 'int', ' cents', 'reath', 'cker', 'pper', 'war', 'rol', 'Ah', 'osaurus', 'Ten', 'dollar', 'girls', 'seven', 'mur', 'affles', 'Offic', 'ws', 'ounced', 'ucker', 'iw', 'eating', 'ates', 'Ar', 'Chuck', 'Eight', 'mus']


In [57]:

tokens = model.to_tokens("Once upon a time", prepend_bos=False).squeeze(0)
logits = (model.W_E[tokens] @ model.W_U)
median_logit = logits.median()
print(median_logit)

next_logits = []
for i in range(tokens.shape[0] - 1):
    next_logits.append(logits[i, tokens[i + 1]])
next_logits = torch.stack(next_logits)
print(next_logits)

# 'a' token seems to boost some sensible completions
values, indices = torch.topk(logits[-2], 100)
print(model.to_str_tokens(indices))

tensor(-0.0134, device='cuda:0')
tensor([ 0.8695, -0.3610,  1.9775], device='cuda:0')
[' a', ' an', ' the', ' some', ' two', ' something', ' another', ' one', ' very', ' many', ' big', ' her', ' so', ' his', ' all', ' in', ' lots', ' to', ' more', ' three', ' four', ' it', ' their', ' on', ' out', ' there', ' small', ' too', ' that', ' and', ' not', '.', ' no', ' good', ' with', ' right', ' really', ' both', ' nice', ' different', ' only', ' five', ',', '\n', ' like', ' hard', ' A', ' bright', ' just', ' when', ' much', ' little', ' long', ' cool', ' brave', ' someone', ' fast', ' strong', ' Lily', ' funny', ' this', ' look', ' warm', ' made', ' fun', ' wide', ' new', ' as', ' silly', ' red', ' hot', ' each', ' close', ' open', ' pretty', ' Mom', ' what', ' Tom', ' other', ' its', ' even', ' far', ' tall', ' The', ' shiny', ' dark', ' loud', ' white', ' deep', ' soft', ' fair', ' black', ' j', ' outside', ' kind', ' your', 'A', ' for', ' help', ' happy']


In [58]:
# Generate heaps of token sequences and average their cache, then find the difference with the OUAT neurons
# Ablate components until OUAT fails
# n-l AND?

_, cache = model.run_with_cache(model.to_tokens(prompts)[:, :40])

def rand_hook(value, hook):
    cache_val = cache[hook.name][:value.size(0), :value.size(1), :value.size(2)]
    mean_val = cache_val.mean(dim=0)
    broadcasted_val = mean_val.unsqueeze(0).expand_as(value)
    value = broadcasted_val
    return value

prompt = "Once"
print(model.generate(prompt, 20, temperature=0, use_past_kv_cache=False))
for layer in range(model.cfg.n_layers):
    with model.hooks([(f'blocks.{layer}.hook_attn_out', rand_hook)]):
        print(f"Attn {layer}", model.generate(prompt, 20, temperature=0, use_past_kv_cache=False))
    with model.hooks([(f'blocks.{layer}.hook_mlp_out', rand_hook)]):
        print(f"MLP {layer}", model.generate(prompt, 20, temperature=0, use_past_kv_cache=False))

# Necessary components:
# MLP0, MLP1, MLP2, MLP3, MLP4, MLP5, MLP7 (every MLP but 6)

NameError: name 'prompts' is not defined

### Ablate by cosine sim - degrades around 80

In [None]:
# Figure out which neurons directly write to each vocab
# And see if we can ablate everything else

tokens = model.to_tokens("Once upon a time", prepend_bos=False)  # [1, 4]
token_dirs = model.tokens_to_residual_directions(tokens)[0]  # [4, 64]
token_dirs_reshaped = token_dirs.unsqueeze(1).unsqueeze(1)  # [4, 1, 1, 64]
W_out_reshaped = model.W_out.unsqueeze(0)  # [1, 8, 256, 64]

cosine_sim = torch.nn.CosineSimilarity(dim=-1, eps=1e-6)
result = cosine_sim(token_dirs_reshaped, W_out_reshaped)  # [4, 8, 256]

_, cache = model.run_with_cache(tokens)
acts = [cache[f'blocks.{layer}.mlp.hook_post'][0] for layer in range(model.cfg.n_layers)] # [[batch pos]]*n_layers
acts = torch.stack(acts, dim=1) # 


layer_neuron_tuples = []
for token_index in range(result.size(0)):
    values, indices = torch.topk(result[token_index].view(-1), 20, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (8, 256))
    layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))


In [None]:
layer_neuron_dict = haystack_utils.get_neurons_by_layer(layer_neuron_tuples)

sorted_tuples = []
sorted_acts = []

for layer in layer_neuron_dict.keys():
    neurons = layer_neuron_dict[layer]
    mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
    sorted_tuples.extend([(layer, neuron) for neuron in neurons])
    sorted_acts.extend(mean_acts)
    assert len(sorted_tuples) == len(sorted_acts)

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

Once in a time, a little girl named Amy was


### Utils

In [128]:
def batched_dot_product(x, y):
    return torch.vmap(torch.dot)(x, y)
    
def neuron_DLA(prompt: str, model: HookedTransformer, pos=np.s_[-1:]) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Gets full resid decomposition including all neurons'''
    tokens = model.to_tokens(prompt)
    answers = tokens[:, 1:]
    tokens = tokens[:, :-1]
    
    _, cache = model.run_with_cache(tokens)
    attrs, labels = cache.get_full_resid_decomposition(-1, expand_neurons=True, apply_ln=True, return_labels=True, pos_slice=pos)
    
    # I think it removes the batch dimension if it's one
    answer_residual_directions = model.tokens_to_residual_directions(answers)
    if answer_residual_directions.ndim == 1:
        answer_residual_directions = answer_residual_directions.unsqueeze(0)  # [1 d_model]
    elif answer_residual_directions.ndim == 3:
        answer_residual_directions = answer_residual_directions[0]  # [pos d_model]
    answer_residual_directions = answer_residual_directions[pos]  # [pos d_model]

    neuron_indices = [i for i in range(len(labels)) if 'N' in labels[i]]
    neuron_labels = [labels[i] for i in neuron_indices]
    neuron_attrs = attrs[neuron_indices, :].squeeze(1)
    results = []
    for i in range(neuron_attrs.shape[1]):
        results.append(batched_dot_product(neuron_attrs[:, i], answer_residual_directions[[i]].repeat(neuron_attrs.shape[0], 1)))
    return torch.stack(results), neuron_labels

def get_neuron_mean_acts(neurons: list[tuple[int, int]], model: HookedTransformer = model) -> tuple[torch.Tensor, torch.Tensor]:
    layer_neuron_dict = haystack_utils.get_neurons_by_layer(neurons)
    sorted_layer_neuron_tuples = []
    sorted_acts = []

    for layer in layer_neuron_dict.keys():
        neurons = layer_neuron_dict[layer]
        mean_acts = haystack_utils.get_mlp_activations(filtered_prompts[:200], layer, model, context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True)
        sorted_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
        sorted_acts.extend(mean_acts)
        assert len(sorted_layer_neuron_tuples) == len(sorted_acts)

    return sorted_layer_neuron_tuples, sorted_acts

def get_unspecified_neurons(model: HookedTransformer, neurons: list[tuple[int, int]]):
    layer_dict = haystack_utils.get_neurons_by_layer(neurons)
    unspecified = []
    for layer in range(model.cfg.n_layers):
        for neuron in range(model.cfg.d_mlp):
            if not neuron in layer_dict[layer]:
                unspecified.append((layer, neuron))
    return unspecified

def get_neuron_loss_increases(prompt: str, positionwise: bool=False, model: HookedTransformer=model) -> torch.Tensor:
    n_tokens = model.to_tokens(prompt).shape[1] - 1
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
    
    losses = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(data[:200], layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
                losses.append((ablated_loss - original_loss)[0])
    return torch.stack(losses).reshape(n_tokens, model.cfg.n_layers * model.cfg.d_mlp)

def compare_dla_and_ablation(token_index: int, dla_attrs_by_neuron: torch.Tensor, ablation_losses_by_neuron: torch.Tensor, num_neurons=20, model: HookedTransformer=model):
    print("DLA:")
    values, indices = torch.topk(dla_attrs_by_neuron[token_index], num_neurons, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    dla_layer_neuron_tuples = list(zip(layer_indices.tolist(), neuron_indices.tolist()))
    indices_1d_dla = [np.ravel_multi_index(index_2d, (model.cfg.n_layers, model.cfg.d_mlp)) for index_2d in dla_layer_neuron_tuples[:num_neurons]]
    
    print(dla_layer_neuron_tuples[:num_neurons])
    print(dla_attrs_by_neuron[token_index][indices_1d_dla])

    print("Ablation:")
    loss_increases_by_neuron = ablation_losses_by_neuron[:, token_index]
    values, indices = torch.topk(loss_increases_by_neuron, num_neurons)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy()[:num_neurons], (model.cfg.n_layers, model.cfg.d_mlp))
    
    print(list(zip(layer_indices.tolist(), neuron_indices.tolist())))
    print(dla_attrs_by_neuron[token_index][indices.tolist()[:num_neurons]])

def get_hook_inputs_for_token_index(loss_increases_by_neuron, model: HookedTransformer=model, filtered_prompts=filtered_prompts, k=40):
    values, indices = torch.topk(loss_increases_by_neuron, k)

    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    causal_layer_neuron_tuples = list(zip(layer_indices.tolist(), neuron_indices.tolist()))
    layer_neuron_dict = haystack_utils.get_neurons_by_layer(causal_layer_neuron_tuples)

    sorted_dla_layer_neuron_tuples = []
    sorted_acts = []

    for layer in layer_neuron_dict.keys():
        neurons = layer_neuron_dict[layer]
        mean_acts = haystack_utils.get_mlp_activations(filtered_prompts[:200], layer, model, context_crop_start=0, neurons=neurons, disable_tqdm=True)
        sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
        sorted_acts.extend(mean_acts)
        assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

    return sorted_dla_layer_neuron_tuples, sorted_acts

### Ablate by DLA

In [129]:
# High level viz

# x, y = haystack_utils.DLA(["Once upon"], model)
# haystack_utils.line(x.cpu().squeeze(0))

attrs, labels = neuron_DLA("Once upon a time", model, pos=np.s_[-4:])
# haystack_utils.line(attrs[0].cpu().numpy(), xlabel="Correct logit", ylabel="", title="DLA per neuron in layer")

# print(attrs.sum())
# px.histogram(attrs.flatten().cpu().numpy())

Tried to stack head results when they weren't cached. Computing head results now


In [130]:
# Ablate the top 3 DLA neurons for each index and check that it messes things up

# Get top neurons
dla_layer_neuron_tuples = []
for token_index in range(4):
    values, indices = torch.topk(attrs[token_index], 3, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    dla_layer_neuron_tuples.extend(zip(layer_indices.tolist(), neuron_indices.tolist()))

print(dla_layer_neuron_tuples)

sorted_dla_tuples, sorted_acts = get_neuron_mean_acts(dla_layer_neuron_tuples)
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_tuples, sorted_acts)
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

[(0, 197), (1, 1221), (0, 2869), (0, 1322), (0, 1701), (0, 2995), (0, 2995), (1, 1221), (1, 944), (1, 191), (3, 2447), (1, 75)]


  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a time, there was a little girl named


In [131]:
# Ablate for each token individually 
from transformer_lens import utils
strings = ["Once", " upon", " a", " time"]

# Minimal top DLA neurons
for token_index, n_ablations in [(0, 10), (1, 20), (2, 6)]:
    values, indices = torch.topk(attrs[token_index], n_ablations, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    token_dla_layer_neuron_tuples = zip(layer_indices.tolist(), neuron_indices.tolist())
    sorted_dla_tuples, sorted_acts = get_neuron_mean_acts(token_dla_layer_neuron_tuples)
    hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_tuples, sorted_acts)
    with model.hooks(hooks):
        test_prompt = "".join(strings[:token_index + 1])
        # print(utils.test_prompt(test_prompt, strings[token_index + 1] or 'butterfly', model))
        # print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False, verbose=False))

### Ablating the disjoint set and ablating to a different prompt - both cursed

In [None]:
# # Try ablating the disjoint set
# unspecified_neuron_tuples = get_unspecified_neurons(model, dla_layer_neuron_tuples)
# layer_neuron_dict = haystack_utils.get_neurons_by_layer(unspecified_neuron_tuples)

# sorted_dla_layer_neuron_tuples = []
# sorted_acts = []

# for layer in layer_neuron_dict.keys():
#     neurons = layer_neuron_dict[layer]
#     mean_acts = haystack_utils.get_mlp_activations(filtered_prompts, layer, model, context_crop_start=2, hook_pre=False, neurons=neurons, disable_tqdm=True)
#     sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
#     sorted_acts.extend(mean_acts)
#     assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

# hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_layer_neuron_tuples, sorted_acts)
# with model.hooks(hooks):
#     print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

OnceOnceOnceOnceOnceOnceOnceOnceOnceOnceOnce


In [None]:
# def get_resample_neurons_hooks(neurons: list[tuple[int, int]], resampled_cache):
#     layer_neurons = haystack_utils.get_neurons_by_layer(neurons)
#     hooks = []
#     for layer, neurons in layer_neurons.items():
#         def resample_neurons_hook(value, hook):
#             resample_cache_slice = np.s_[:value.shape[0], :value.shape[1], neurons]
#             value[:, :, neurons] = resampled_cache[hook.name][resample_cache_slice]
#             return value
#         hooks.append((f'blocks.{layer}.mlp.hook_post', resample_neurons_hook))
#     return hooks

# # Try again with resampled activations
# _, resample_cache = model.run_with_cache(["Once upon a time"])
# hooks = get_resample_neurons_hooks(sorted_dla_layer_neuron_tuples, resample_cache)
# with model.hooks(hooks):
#     print(model.generate("Once", 3, temperature=0, use_past_kv_cache=False))

  0%|          | 0/3 [00:00<?, ?it/s]

Once upon a time


### Causal study

In [132]:
attrs, labels = neuron_DLA("Once upon a time", model, pos=np.s_[-4:])
losses = get_neuron_loss_increases("Once upon a time", positionwise=True)
num_neurons = 20
compare_dla_and_ablation(token_index=0, attrs=attrs, losses=losses, num_neurons=10)

# Get "upon" neurons
# token_index = 1
# loss_increases = losses[token_index]
# sorted_ablation_tuples, sorted_acts = get_hook_inputs_for_token_index(loss_increases)
# print(sorted_ablation_tuples[:num_neurons], sorted_acts[:num_neurons])

Tried to stack head results when they weren't cached. Computing head results now


  0%|          | 0/4 [00:00<?, ?it/s]

100%|██████████| 4/4 [01:30<00:00, 22.70s/it]


TypeError: compare_dla_and_ablation() got an unexpected keyword argument 'attrs'

In [122]:
# Get "upon" neurons
token_index = 1
sorted_ablation_tuples, sorted_acts = get_hook_inputs_for_token_index(losses[token_index])
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_ablation_tuples[:num_neurons], sorted_acts[:num_neurons])
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_ablation_tuples[:2], sorted_acts[:2])
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a time, there was a little girl named


  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a time, there was a little girl named


In [123]:
model = HookedTransformer.from_pretrained("roneneldan/TinyStories-33M",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cuda")

attrs, labels = neuron_DLA("Once upon a time", model, pos=np.s_[-4:])
losses = get_neuron_loss_increases("Once upon a time", positionwise=True, model=model)

Using pad_token, but it is not set yet.


Loaded pretrained model roneneldan/TinyStories-33M into HookedTransformer
Tried to stack head results when they weren't cached. Computing head results now


100%|██████████| 4/4 [01:27<00:00, 21.77s/it]


In [124]:
print(attrs.shape, losses.shape)
num_neurons = 10
compare_dla_and_ablation(token_index=0, attrs=attrs, losses=losses, num_neurons=num_neurons, model=model)

# Get "upon" neurons
# token_index = 1
# loss_increases = losses[token_index]
# sorted_ablation_tuples, sorted_acts = get_hook_inputs_for_token_index(loss_increases, model=model)
# print(sorted_ablation_tuples[:num_neurons], sorted_acts[:num_neurons])

torch.Size([4, 12288]) torch.Size([12288, 4])
DLA:
[(0, 197), (1, 1221), (0, 2869), (3, 2448), (0, 1322), (0, 983), (0, 746), (0, 2170), (0, 2995), (0, 108)]
tensor([1.4333, 0.5989, 0.5042, 0.4689, 0.4635, 0.4530, 0.4415, 0.4353, 0.3970,
        0.3825], device='cuda:0')
Ablation:
[(0, 197), (0, 983), (0, 2037), (1, 2614), (1, 1221), (0, 742), (3, 939), (0, 367), (0, 1794), (0, 1322)]
tensor([ 1.4333,  0.4530, -0.2397,  0.1655,  0.5989, -0.0087,  0.0752,  0.2479,
         0.3282,  0.4635], device='cuda:0')


In [125]:
# Get "upon" neurons
token_index = 3
loss_increases = losses[token_index]
sorted_ablation_tuples, sorted_acts = get_hook_inputs_for_token_index(loss_increases, model, k=300)

In [None]:
num_neurons=110
print(len(sorted_ablation_tuples))
print(sorted_ablation_tuples[:num_neurons])
print(sorted_acts[:num_neurons])

hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_ablation_tuples[:num_neurons], sorted_acts[:num_neurons])
with model.hooks(hooks):
    print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

300
[(0, 2171), (0, 1304), (0, 2869), (0, 1371), (0, 563), (0, 2556), (0, 1613), (0, 321), (0, 797), (0, 372), (0, 367), (0, 2992), (0, 637), (0, 2663), (0, 954), (0, 1284), (0, 916), (0, 1332), (0, 2505), (0, 1701), (0, 2932), (0, 286), (0, 1085), (0, 2317), (0, 1714), (0, 823), (0, 1322), (0, 2814), (0, 493), (0, 1780), (0, 1404), (0, 951), (0, 459), (0, 2128), (0, 2092), (0, 3010), (0, 2961), (0, 983), (0, 269), (0, 2834), (0, 1165), (0, 3024), (0, 1975), (0, 1533), (0, 361), (0, 1563), (0, 2964), (0, 1825), (0, 1033), (0, 2210), (0, 1909), (0, 828), (0, 526), (0, 2616), (0, 1463), (0, 2913), (0, 2930), (0, 2759), (0, 663), (0, 2698), (0, 1089), (0, 1116), (0, 1733), (0, 1633), (0, 739), (0, 965), (0, 1527), (0, 2516), (0, 794), (0, 1084), (0, 2594), (0, 925), (0, 3068), (0, 1879), (0, 2694), (0, 533), (0, 477), (0, 1949), (0, 1428), (0, 1034), (0, 187), (0, 2349), (0, 1699), (0, 549), (0, 100), (0, 2891), (0, 979), (0, 1938), (0, 1736), (0, 955), (0, 811), (0, 2533), (0, 1487), (0,

  0%|          | 0/10 [00:00<?, ?it/s]

Once upon a little girl named Lily. She loved to


In [126]:
dla_layer_neuron_tuples = []

values, indices = torch.topk(attrs[-1], 100, dim=-1)
layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
dla_layer_neuron_tuples = list(zip(layer_indices.tolist(), neuron_indices.tolist()))

print(dla_layer_neuron_tuples)

[(1, 191), (3, 2447), (1, 75), (0, 2995), (1, 202), (3, 1489), (2, 2471), (3, 1009), (0, 1304), (2, 952), (0, 197), (3, 2009), (1, 2984), (1, 1089), (3, 262), (0, 983), (3, 342), (3, 658), (1, 2046), (1, 854), (2, 1515), (3, 1052), (3, 1558), (0, 2505), (2, 1368), (2, 2308), (1, 1901), (1, 1876), (1, 846), (3, 2007), (1, 1221), (1, 33), (3, 901), (0, 459), (1, 1381), (0, 1085), (0, 1322), (3, 485), (0, 1116), (1, 110), (0, 2949), (1, 1917), (1, 569), (3, 624), (2, 3053), (3, 1482), (3, 2215), (1, 2374), (0, 563), (0, 1371), (3, 70), (0, 2556), (0, 2918), (3, 155), (0, 1540), (2, 632), (1, 538), (2, 1477), (3, 465), (1, 2100), (3, 1098), (3, 1638), (3, 929), (0, 1783), (1, 2370), (3, 1652), (1, 3046), (1, 469), (0, 916), (2, 1890), (0, 493), (0, 1423), (1, 1817), (3, 1792), (2, 472), (1, 2086), (3, 2787), (0, 483), (1, 2133), (0, 321), (1, 974), (1, 2837), (3, 667), (3, 2095), (3, 2701), (2, 285), (1, 1073), (3, 1475), (0, 2930), (2, 3045), (1, 88), (2, 138), (1, 947), (0, 2171), (1, 72

In [127]:
sorted_dla_tuples, sorted_acts = get_neuron_mean_acts(dla_layer_neuron_tuples, model)
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_tuples, sorted_acts)

    # print(model.generate("Once", 10, temperature=0, use_past_kv_cache=False))

In [111]:
with model.hooks(hooks):
    utils.test_prompt("Once upon a", "time", model)

Tokenized prompt: ['<|endoftext|>', 'Once', ' upon', ' a']
Tokenized answer: [' time']


Top 0th token. Logit: 15.12 Prob:  5.80% Token: | time|
Top 1th token. Logit: 14.99 Prob:  5.10% Token: | dog|
Top 2th token. Logit: 14.77 Prob:  4.08% Token: | cat|
Top 3th token. Logit: 14.71 Prob:  3.84% Token: | mighty|
Top 4th token. Logit: 14.69 Prob:  3.75% Token: | tim|
Top 5th token. Logit: 14.27 Prob:  2.46% Token: | little|
Top 6th token. Logit: 14.26 Prob:  2.44% Token: | hot|
Top 7th token. Logit: 14.19 Prob:  2.28% Token: | story|
Top 8th token. Logit: 14.17 Prob:  2.24% Token: | car|
Top 9th token. Logit: 14.09 Prob:  2.07% Token: | small|
