In [1]:
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor
import plotly.io as pio
import numpy as np
from tqdm import trange
from collections import defaultdict

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
# This notebook only does inference/analysis, so disable autograd globally.
# torch.set_grad_enabled and torch.autograd.set_grad_enabled are the same switch; one call is enough.
torch.set_grad_enabled(False)

import haystack_utils
import hook_utils

%reload_ext autoreload
%autoreload 2

In [2]:
# GPT-2 small with LayerNorm folded and writing weights / unembed centred.
# Use the `device` chosen in the config cell rather than hard-coding "cuda",
# so the notebook also runs on CPU-only machines.
model = HookedTransformer.from_pretrained("gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

data = haystack_utils.load_json_data('data/english_europarl.json')

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-small into HookedTransformer
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


In [3]:
# Short English idioms used as probe prompts in the cells below.
# Every entry begins with a space, presumably so the first word is tokenized the way it
# would appear mid-sentence — NOTE(review): confirm this is the intent.
idioms = [
    " a blessing in disguise",
    " beat around the bush",
    " better late than never",
    " break a leg",
    " call it a day",
    " cut somebody some slack",
    " cutting corners",
    " easy does it",
    " get out of hand",
    " get something out of your system",
    " give someone the benefit of the doubt",
    " go back to the drawing board",
    " hang in there",
    " hit the sack",
    " it's not rocket science",
    " let the cat out of the bag",
    " miss the boat",
    " no pain, no gain",
    " on the ball",
    " pull yourself together",
    " so far so good",
    " speak of the devil",
    " that's the last straw",
    " the best of both worlds",
    " the ball is in your court",
    " the real deal",
    " time flies when you're having fun",
    " to get bent out of shape",
    " to make matters worse",
    " under the weather",
    " we'll cross that bridge when we come to it",
    " wrap your head around something",
    " you can't judge a book by its cover",
    " a penny for your thoughts",
    " add insult to injury",
    " bite the bullet",
    " don't cry over spilt milk",
    " every cloud has a silver lining",
    " hit the nail on the head",
    " it takes one to know one",
    " kill two birds with one stone",
    " make a long story short",
    " not playing with a full deck",
    " an arm and a leg",
    " put something on ice",
    " see eye to eye",
    " take it with a grain of salt",
    " the whole nine yards",
    " you can say that again",
    " your guess is as good as mine"
]


In [30]:
def batched_dot_product(x: torch.Tensor, y: torch.Tensor):
    '''Row-wise dot product of two 2-D tensors: out[i] = dot(x[i], y[i]).'''
    # torch.dot only accepts 1-D inputs; vmap maps it over the leading (row) dimension.
    row_dot = torch.vmap(torch.dot)
    return row_dot(x, y)
    
def neuron_DLA(model: HookedTransformer, prompt: str, pos=np.s_[-1:]) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Direct logit attribution of every MLP neuron to the token that actually comes next.

    The prompt's tokens are split into inputs (all but the last) and targets (all but the first),
    so input position p is scored against the real token at p + 1. The final-layer residual
    decomposition (neurons expanded, LayerNorm scale applied) is dotted with each target token's
    residual direction from `model.tokens_to_residual_directions`.

    Args:
        model: model to run.
        prompt: a single prompt (the size-1 batch dim is squeezed out).
        pos: positions to attribute at; defaults to the last input position.

    Returns:
        (attrs, labels): attrs stacks one row per position in `pos`, one column per neuron
        component; labels are the component labels containing 'N'.
    '''
    all_tokens = model.to_tokens(prompt)
    targets = all_tokens[:, 1:]
    inputs = all_tokens[:, :-1]

    _, cache = model.run_with_cache(inputs)
    attrs, labels = cache.get_full_resid_decomposition(-1, expand_neurons=True, apply_ln=True, return_labels=True, pos_slice=pos)

    target_dirs = model.tokens_to_residual_directions(targets)
    # Normalise to [pos, d_model]. A 1-D result is presumably what comes back for a single
    # target token, and a 3-D one carries a leading batch dim of 1 — confirm against TransformerLens.
    if target_dirs.ndim == 1:
        target_dirs = target_dirs.unsqueeze(0)
    elif target_dirs.ndim == 3:
        target_dirs = target_dirs[0]
    target_dirs = target_dirs[pos]

    neuron_indices = [idx for idx, label in enumerate(labels) if 'N' in label]
    neuron_labels = [labels[idx] for idx in neuron_indices]
    # Assumes a single prompt, so dim 1 (batch) has size 1 and is dropped.
    neuron_attrs = attrs[neuron_indices, :].squeeze(1)
    n_neurons = neuron_attrs.shape[0]

    per_position = [
        batched_dot_product(neuron_attrs[:, p], target_dirs[[p]].repeat(n_neurons, 1))
        for p in range(neuron_attrs.shape[1])
    ]
    return torch.stack(per_position), neuron_labels

def neuron_to_context_neuron_DLA(
        model: HookedTransformer, 
        prompt: str | list[str], 
        pos=np.s_[-1:], 
        context_neuron: tuple[int, int] = (0, 0)
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Attribute one MLP neuron's input to individual MLP neurons. Unbatched.

    Each neuron output stacked by `cache.stack_neuron_results(layer)` (presumably the neurons in
    layers before `layer` — confirm) is LayerNorm-scaled and dotted with the target neuron's
    input weights W_in[layer, :, neuron].

    Args:
        model: model to run.
        prompt: a single prompt (the size-1 batch dim is squeezed out).
        pos: positions to attribute at; defaults to the last position.
        context_neuron: (layer, neuron) of the target neuron. Previously the default was the
            type object `tuple[int, int]` (an annotation typo); it is now (0, 0), matching
            resid_to_context_neuron_DLA.

    Returns:
        (attrs, labels): attrs stacks one row per position in `pos`, one column per stacked
        neuron; labels name those neurons.
    '''
    _, cache = model.run_with_cache(prompt)
    layer, neuron = context_neuron
    neuron_attrs, neuron_labels = cache.stack_neuron_results(layer, return_labels=True, pos_slice=pos)
    # The target neuron reads the MLP input of `layer`, so normalise with that layer's MLP-input
    # (ln2) scale. NOTE(review): per TransformerLens, stack_neuron_results(apply_ln=True) uses the
    # attention-input scale instead — verify against the installed version.
    neuron_attrs = cache.apply_ln_to_stack(neuron_attrs, layer, mlp_input=True, pos_slice=pos)
    neuron_attrs = neuron_attrs.squeeze(1)  # drop the size-1 batch dim

    target_direction = model.W_in[layer, :, neuron].unsqueeze(0)  # [1 d_model]

    results = []
    for i in range(neuron_attrs.shape[1]):
        results.append(batched_dot_product(neuron_attrs[:, i], target_direction.repeat(neuron_attrs.shape[0], 1)))
    return torch.stack(results), neuron_labels

def resid_to_context_neuron_DLA(
        model: HookedTransformer, 
        prompt: str | list[str], 
        pos=np.s_[-1:], 
        context_neuron:tuple[int, int]=(0,0)
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Attribute one MLP neuron's input to every residual-stream component. Unbatched.

    The residual stream feeding the MLP of `layer` is decomposed into components (embeddings,
    heads, neurons, biases), LayerNorm-scaled, and dotted with W_in[layer, :, neuron].

    Args:
        model: model to run.
        prompt: a single prompt (the size-1 batch dim is squeezed out).
        pos: positions to attribute at; defaults to the last position.
        context_neuron: (layer, neuron) of the target neuron.

    Returns:
        (attrs, labels): attrs stacks one row per position in `pos`, one column per component.
    '''
    _, cache = model.run_with_cache(prompt)
    layer, neuron = context_neuron
    # mlp_input=True decomposes the stream the MLP actually reads: it includes this layer's
    # attention heads and applies the MLP-input (ln2) LayerNorm scale. Without it, the
    # decomposition describes the attention input of `layer` instead.
    all_attrs, labels = cache.get_full_resid_decomposition(layer, mlp_input=True, apply_ln=True, return_labels=True, pos_slice=pos)
    all_attrs = all_attrs.squeeze(1)  # drop the size-1 batch dim

    target_direction = model.W_in[layer, :, neuron].unsqueeze(0)  # [1 d_model]

    results = []
    for i in range(all_attrs.shape[1]):
        results.append(batched_dot_product(all_attrs[:, i], target_direction.repeat(all_attrs.shape[0], 1)))
    return torch.stack(results), labels

def get_neuron_mean_acts(model: HookedTransformer, data: list[str], layer_neuron_dict: dict[int, list[int]]) -> tuple[list[tuple[int, int]], list]:
    '''Collect mean MLP activations over `data` for a chosen set of neurons.

    Args:
        model: passed through to haystack_utils.get_mlp_activations.
        data: prompts to average over.
        layer_neuron_dict: layer index -> neuron indices in that layer.

    Returns:
        (layer_neuron_tuples, mean_acts): parallel lists (the old annotation wrongly claimed
        tensors). mean_acts[i] is the element get_mlp_activations returned for
        layer_neuron_tuples[i]. Pairs are grouped by layer in the dict's iteration order, not sorted.
    '''
    sorted_layer_neuron_tuples = []
    sorted_acts = []

    for layer, neurons in layer_neuron_dict.items():
        mean_acts = haystack_utils.get_mlp_activations(data, layer, model, context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True)
        sorted_layer_neuron_tuples.extend((layer, neuron) for neuron in neurons)
        sorted_acts.extend(mean_acts)
        # Expect exactly one mean activation per requested neuron.
        assert len(sorted_layer_neuron_tuples) == len(sorted_acts)

    return sorted_layer_neuron_tuples, sorted_acts

def get_unspecified_neurons(model: HookedTransformer, layer_neuron_dict: dict[int, list[int]]):
    '''Return every (layer, neuron) pair in the model that is NOT listed in `layer_neuron_dict`.

    Args:
        model: supplies the grid size via model.cfg.n_layers and model.cfg.d_mlp.
        layer_neuron_dict: layer index -> neuron indices to exclude. Layers missing from the
            mapping are treated as having no excluded neurons (previously a KeyError for a plain
            dict, and a silent insertion of empty lists for a defaultdict).

    Returns:
        List of (layer, neuron) tuples in ascending layer-major order.
    '''
    unspecified = []
    for layer in range(model.cfg.n_layers):
        # .get() never mutates the mapping; a set gives O(1) membership instead of a list scan per neuron.
        excluded = set(layer_neuron_dict.get(layer, ()))
        unspecified.extend((layer, neuron) for neuron in range(model.cfg.d_mlp) if neuron not in excluded)
    return unspecified

def get_neuron_loss_increases(model: HookedTransformer, data: list[str], prompt: str, positionwise: bool=False, n_mean_samples: int = 200) -> torch.Tensor:
    '''Ablate each MLP neuron in turn and measure how much the loss on `prompt` increases.

    Each neuron is ablated via hook_utils.get_ablate_neuron_hook with its mean activation over
    the first `n_mean_samples` prompts of `data` (presumably replacing the activation with that
    mean — confirm in hook_utils).

    Args:
        model: model to evaluate.
        data: prompts used to compute mean activations.
        prompt: prompt whose loss is measured.
        positionwise: if True, keep the per-token loss instead of the mean.
        n_mean_samples: how many prompts of `data` to average over (default 200, as before).

    Returns:
        positionwise=True:  [n_tokens, n_layers * d_mlp] loss increase per predicted token.
        positionwise=False: [n_layers * d_mlp] increase in mean loss.
        Neuron columns are flattened layer-major, i.e. index = layer * d_mlp + neuron.
    '''
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)

    losses = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(data[:n_mean_samples], layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
            increase = ablated_loss - original_loss
            # Per-token loss has a leading batch dim of 1 to drop; the mean loss is assumed to be
            # 0-d, where the old unconditional `[0]` raised.
            losses.append(increase[0] if increase.ndim > 0 else increase)

    stacked = torch.stack(losses)  # [n_layers * d_mlp, n_tokens] or [n_layers * d_mlp]
    # Rows were collected neuron-major: transpose (not reshape, which scrambled values) so rows are tokens.
    return stacked.T if stacked.ndim == 2 else stacked

def compare_dla_and_ablation(model: HookedTransformer, dla_attrs_by_neuron: torch.Tensor, ablation_losses_by_neuron: torch.Tensor, num_neurons=20):
    '''Print the top-`num_neurons` neurons ranked by DLA and ranked by ablation loss increase.

    Both tensors are indexed by flattened (layer, neuron). For both rankings the printed values
    are the DLA attributions of the selected neurons. NOTE(review): this looks intended for a
    same-scale comparison; confirm the ablation loss values themselves aren't wanted.
    '''
    grid = (model.cfg.n_layers, model.cfg.d_mlp)

    print("DLA:")
    _, dla_top = torch.topk(dla_attrs_by_neuron, num_neurons, dim=-1)
    dla_layers, dla_neurons = np.unravel_index(dla_top.cpu().numpy(), grid)
    print(list(zip(dla_layers.tolist(), dla_neurons.tolist())))
    print(dla_attrs_by_neuron[dla_top.tolist()])

    print("Ablation:")
    _, ablation_top = torch.topk(ablation_losses_by_neuron, num_neurons)
    ablation_layers, ablation_neurons = np.unravel_index(ablation_top.cpu().numpy()[:num_neurons], grid)
    print(list(zip(ablation_layers.tolist(), ablation_neurons.tolist())))
    print(dla_attrs_by_neuron[ablation_top.tolist()])

def get_hook_inputs_for_token_index(model: HookedTransformer, data: list[str], loss_increases_by_neuron: torch.Tensor, k=40):
    '''Pick the k neurons with the largest loss increase and fetch their mean activations.

    Args:
        model: model whose cfg gives the (n_layers, d_mlp) grid; also passed to get_mlp_activations.
        data: prompts to average activations over.
        loss_increases_by_neuron: flat tensor indexed by layer * d_mlp + neuron.
        k: number of neurons to select.

    Returns:
        (layer_neuron_tuples, mean_acts): parallel lists, grouped by layer in first-seen order.
    '''
    _, flat_top = torch.topk(loss_increases_by_neuron, k)
    top_layers, top_neurons = np.unravel_index(flat_top.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))

    neurons_by_layer = defaultdict(list)
    for layer_idx, neuron_idx in zip(top_layers, top_neurons):
        neurons_by_layer[layer_idx].append(neuron_idx)

    tuples, acts = [], []
    for layer_idx, layer_neurons in neurons_by_layer.items():
        layer_means = haystack_utils.get_mlp_activations(data, layer_idx, model, context_crop_start=0, neurons=layer_neurons, disable_tqdm=True)
        tuples.extend((layer_idx, n) for n in layer_neurons)
        acts.extend(layer_means)
        assert len(tuples) == len(acts)

    return tuples, acts

def unravel_top_k(neuron_attrs: torch.Tensor, k: int=10):
    values, indices = torch.topk(neuron_attrs, k)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    return list(zip(layer_indices.tolist(), neuron_indices.tolist()))

def resid_to_head_DLA(
        model: HookedTransformer,
        prompt: str | list[str],
        head: tuple[int, int],
        pos=np.s_[-1:],
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Project every residual-stream component onto each column of one head's W_K. Unbatched.

    Args:
        model: model to run.
        prompt: a single prompt (the size-1 batch dim is squeezed out).
        head: (layer, head_index) of the attention head.
        pos: positions to attribute at; defaults to the last position.

    Returns:
        (results, labels): results is a CPU tensor [n_pos, n_components, W_K.shape[1]];
        labels name the components.
    '''
    tokens = model.to_tokens(prompt)
    _, cache = model.run_with_cache(prompt)
    layer, head_index = head
    all_attrs, labels = cache.get_full_resid_decomposition(layer, apply_ln=True, return_labels=True, pos_slice=pos)
    all_attrs = all_attrs.squeeze(1)  # drop the size-1 batch dim

    key_weights = model.W_K[layer, head_index, :]
    n_components, n_pos = all_attrs.shape[0], all_attrs.shape[1]
    n_key_dims = key_weights.shape[1]
    # Allocated with torch.zeros' default device/dtype (CPU), so each slice assignment copies results there.
    results = torch.zeros(n_pos, n_components, n_key_dims)
    for pos_idx in range(n_pos):
        component_vecs = all_attrs[:, pos_idx]
        for dim_idx in range(n_key_dims):
            key_dir = key_weights[:, dim_idx].unsqueeze(0).repeat(n_components, 1)
            results[pos_idx, :, dim_idx] = batched_dot_product(component_vecs, key_dir)
    return results, labels

In [19]:
# test_prompts = [" An eye for", " An eye for an", " By clicking register, you agree to Etsy", " Beat around the bush", " Break a leg", " Bite the bullet", " Through thick and thin", " "]
# tokens = model.to_tokens([prompt for prompt in test_prompts])

In [69]:
# gooduns = idioms[7, 12, 17, 19, 30, 45, 47]
# eh = idioms[8, 23, 31, 33, 36, 43, 49]
# Per-neuron DLA at the final position of idioms[7] (" easy does it").
prompt = idioms[7]
last_position = np.s_[-1:]
attrs, labels = neuron_DLA(model, prompt, pos=last_position)
fig = haystack_utils.line(
    attrs[0].cpu().numpy(),
    xlabel="Neuron index",
    ylabel="Logit attribution",
    title=f"DLA by neuron for \"{prompt}\"",
    height=400,
)

Tried to stack head results when they weren't cached. Computing head results now


' Fay'

In [75]:
# layer_neuron_dict = defaultdict(list)
# for token_index in range(1):
#     values, indices = torch.topk(attrs[token_index], 3, dim=-1)
#     layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
#     for layer_index, neuron_index in zip(layer_indices.tolist(), neuron_indices.tolist()):
#         layer_neuron_dict[layer_index].append(neuron_index)
# Flat neuron index 33_836 -> (layer, neuron). Ablate it (hooks built from its mean activation
# over data[:200]) and compare greedy continuations with and without the hook.
flat_neuron_index = 33_836
layer, neuron = np.unravel_index(flat_neuron_index, (model.cfg.n_layers, model.cfg.d_mlp))
sorted_dla_tuples, sorted_acts = get_neuron_mean_acts(model, data[:200], {layer: [neuron]})
hooks = hook_utils.get_ablate_context_neurons_hooks(sorted_dla_tuples, sorted_acts)

generation_prompt = " Easy does"
greedy_kwargs = dict(temperature=0, use_past_kv_cache=False, verbose=False)
print(model.generate(generation_prompt, 10, **greedy_kwargs))
with model.hooks(hooks):
    print(model.generate(generation_prompt, 10, **greedy_kwargs))

 Easy does it.

The first thing you need to
 Easy does not mean easy.

The first thing you


In [5]:
test_prompt = " Random string of"
layer, neuron = np.unravel_index(33_836, (model.cfg.n_layers, model.cfg.d_mlp))
# Same neuron, but the hook gets a fixed value of 10.0 instead of a data mean
# (presumably forcing the activation to 10.0 — confirm in hook_utils).
hooks = hook_utils.get_ablate_context_neurons_hooks([(layer, neuron)], [10.0])

greedy_kwargs = dict(temperature=0, use_past_kv_cache=False, verbose=False)
print(model.generate(test_prompt, 10, **greedy_kwargs))
with model.hooks(hooks):
    print(model.generate(test_prompt, 10, **greedy_kwargs))

 Random string of characters to use for the character "I"

 Random string of characters it is it it it it it it it


In [28]:
it_phrases = [
    " Go for it",
    " Deal with it",
    " Worth it",
    " Think about it",
    " Forget about it",
    " Can't help it",
    " Make the most of it",
    " Get over it",
    " That's it",
    " Keep at it",
    " Don't worry, I'm on it"
]

# Compare L{layer}N{neuron} (layer/neuron come from an earlier cell) at the second-to-last position
# of each phrase ending in " it" against every position of the idioms that don't contain " it".
post_hook = f'blocks.{layer}.mlp.hook_post'

it_acts = []
for phrase in it_phrases:
    _, phrase_cache = model.run_with_cache(phrase)
    it_acts.append(phrase_cache[post_hook][0, -2, neuron].item())

non_it_idioms = [idiom for idiom in idioms if ' it' not in idiom]
non_it_acts = []
for non_it_idiom in non_it_idioms:
    _, idiom_cache = model.run_with_cache(non_it_idiom)
    non_it_acts.extend(idiom_cache[post_hook][0, :, neuron].cpu())

fig = haystack_utils.two_histogram(torch.tensor(it_acts), torch.tensor(non_it_acts), 'it activations', 'non it activations', title=f'L{layer}N{neuron} activations')

In [112]:
# Flat neuron indices noted while inspecting this DLA plot: 33_836, 33_752, 35_473, 34_870, 35_217.
prompt = it_phrases[4]
attrs, labels = neuron_DLA(model, prompt, pos=np.s_[-1:])
fig = haystack_utils.line(
    attrs[0].cpu().numpy(),
    xlabel="Neuron index",
    ylabel="Logit attribution",
    title=f"DLA by neuron for \"{prompt}\"",
    height=400,
)

# Convert the noted flat indices to (layers, neurons).
candidate_flat_indices = [33_836, 33_752, 35_473, 34_870, 35_217]
np.unravel_index(candidate_flat_indices, (model.cfg.n_layers, model.cfg.d_mlp))

Tried to stack head results when they weren't cached. Computing head results now


(array([11, 10, 11, 11, 11]), array([  44, 3032, 1681, 1078, 1425]))

In [46]:
# Which earlier neurons write into W_in of L11N44, for each of the first five " ... it" phrases.
layer, neuron = 11, 44
for phrase in it_phrases[:5]:
    # Analyse the loop's phrase; the old code always passed it_phrases[4], giving five identical plots.
    attrs, labels = neuron_to_context_neuron_DLA(model, phrase, context_neuron=(layer, neuron))
    haystack_utils.line(attrs[0].cpu(), height=300, title=f"Components writing to W_in of L{layer}N{neuron} for \"{phrase}\"")

In [41]:
layers, neurons = np.unravel_index([33_836, 33_752, 35_473, 34_870, 35_217], (model.cfg.n_layers, model.cfg.d_mlp))

# Run each prompt once, caching only the MLP post-activation hooks we need. The old loop re-ran
# (and fully cached) every prompt once per candidate neuron.
post_hook_names = sorted({f'blocks.{l}.mlp.hook_post' for l in layers.tolist()})
non_it_idioms = [idiom for idiom in idioms if ' it' not in idiom]
it_caches = [model.run_with_cache(p, names_filter=post_hook_names)[1] for p in it_phrases]
non_it_caches = [model.run_with_cache(p, names_filter=post_hook_names)[1] for p in non_it_idioms]

for layer, neuron in zip(layers, neurons):
    print(layer, neuron)
    hook_name = f'blocks.{layer}.mlp.hook_post'
    # Second-to-last position of each phrase ending in " it".
    it_acts = [c[hook_name][0, -2, neuron].item() for c in it_caches]
    # Every position of each idiom that doesn't contain " it".
    non_it_acts = [act for c in non_it_caches for act in c[hook_name][0, :, neuron].cpu()]
    haystack_utils.two_histogram(torch.tensor(it_acts), torch.tensor(non_it_acts), 'it activations', 'non it activations', title=f'L{layer}N{neuron} activations')

11 44


10 3032


11 1681


11 1078


11 1425
