### Setup

In [1]:
import torch
from tqdm.auto import tqdm, trange
from transformer_lens import HookedTransformer
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
import pandas as pd
import numpy as np
import plotly.express as px 
from collections import defaultdict
import matplotlib.pyplot as plt
import re
from IPython.display import display, HTML
from datasets import load_dataset
from collections import Counter
import pickle
import os
import haystack_utils
from transformer_lens import utils
from fancy_einsum import einsum
import einops
import json
import ipywidgets as widgets
from IPython.display import display
from datasets import load_dataset
import random
import math
import random
import neel.utils as nutils
from neel_plotly import *
import circuitsvis as cv

import hook_utils
import haystack_utils

SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

%reload_ext autoreload
%autoreload 2

In [2]:
# GPT-2 Large wrapped by TransformerLens.
model: HookedTransformer = HookedTransformer.from_pretrained("gpt2-large")

# OpenWebText 10k sample; only documents longer than 2000 characters are kept.
data = load_dataset("stas/openwebtext-10k", split="train")
strings = [text for text in data["text"] if len(text) > 2000]
len(strings)

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2-large into HookedTransformer


Repo card metadata block was not found. Setting CardData to empty.


7083

### Methods

In [71]:
# Dot product of every row of a 2-D tensor with one shared 1-D vector.
batched_dot_product = torch.vmap(torch.dot, in_dims=(0, None))
# Row-wise version of haystack_utils.get_collinear_component: the first argument is
# mapped over its leading axis, the second is shared by every row.
# NOTE(review): presumably returns the part of each row collinear with the shared
# direction — confirm against haystack_utils.
batched_projection = torch.vmap(haystack_utils.get_collinear_component, in_dims=(0, None))

def neuron_to_context_neuron_DLA(
        model: HookedTransformer,
        prompt: str | list[str],
        pos=np.s_[-1:],
        context_neuron: tuple[int, int] = (0, 0)
) -> tuple[Float[Tensor, "pos component"], list[str]]:
    '''Attribute a context neuron's input to the individual upstream MLP neurons. Unbatched.

    Runs `prompt` with a cache, stacks the per-neuron results via
    `cache.stack_neuron_results(layer, ...)`, and for every selected position measures
    how much of each neuron's output lies along the context neuron's input direction
    `model.W_in[layer, :, neuron]`.

    Args:
        model: the hooked transformer to run.
        prompt: the input text (a single prompt; the batch axis is squeezed away).
        pos: position slice forwarded as `pos_slice`; defaults to the last token.
        context_neuron: `(layer, neuron)` of the neuron whose input is being explained.
            Previously the *type* `tuple[int, int]` was used as the default value, so
            omitting the argument failed when it was unpacked.

    Returns:
        `(attributions, labels)`: `attributions[p, c]` is the norm of the projection of
        component `c` at the p-th selected position; `labels` names the components.
        The norm is non-negative, so the sign of a contribution is not recoverable.
    '''
    _, cache = model.run_with_cache(prompt)
    layer, neuron = context_neuron
    neuron_attrs, neuron_labels = cache.stack_neuron_results(layer, apply_ln=True, return_labels=True, pos_slice=pos)
    # Axis 1 is squeezed away (assumed to be the size-1 batch axis — TODO confirm),
    # leaving positions on axis 1.
    neuron_attrs = neuron_attrs.squeeze(1)
    answer_residual_direction = model.W_in[layer, :, neuron]

    per_position = [
        batched_projection(neuron_attrs[:, i], answer_residual_direction).norm(dim=-1)
        for i in range(neuron_attrs.shape[1])
    ]
    return torch.stack(per_position), neuron_labels

def components_to_context_neuron_DLA(
        model: HookedTransformer,
        prompt: str | list[str],
        pos=np.s_[-1:],
        context_neuron: tuple[int, int] = (0, 0)
) -> tuple[Float[Tensor, "pos component"], list[str]]:
    '''Attribute a context neuron's input to whole upstream components. Unbatched.

    Same measurement as the per-neuron variant, but the residual stream is decomposed
    with `cache.get_full_resid_decomposition(layer, ..., expand_neurons=False)`, so MLPs
    are not split into individual neurons.

    Args:
        model: the hooked transformer to run.
        prompt: the input text (a single prompt; the batch axis is squeezed away).
        pos: position slice forwarded as `pos_slice`; defaults to the last token.
        context_neuron: `(layer, neuron)` of the neuron whose input is being explained.
            Previously the *type* `tuple[int, int]` was used as the default value, so
            omitting the argument failed when it was unpacked.

    Returns:
        `(attributions, labels)`: `attributions[p, c]` is the norm of the projection of
        component `c` at the p-th selected position onto `model.W_in[layer, :, neuron]`;
        `labels` names the components.
    '''
    _, cache = model.run_with_cache(prompt)
    layer, neuron = context_neuron
    attrs, labels = cache.get_full_resid_decomposition(layer, apply_ln=True, return_labels=True, pos_slice=pos, expand_neurons=False)
    # Axis 1 is squeezed away (assumed to be the size-1 batch axis — TODO confirm),
    # leaving positions on axis 1.
    attrs = attrs.squeeze(1)
    answer_residual_direction = model.W_in[layer, :, neuron]

    per_position = [
        batched_projection(attrs[:, i], answer_residual_direction).norm(dim=-1)
        for i in range(attrs.shape[1])
    ]
    return torch.stack(per_position), labels

def resid_to_context_neuron_DLA(
        model: HookedTransformer, 
        prompt: str | list[str], 
        pos=np.s_[-1:], 
        context_neuron:tuple[int, int]=(0,0)
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Attribute a context neuron's input to every residual component, neurons included. Unbatched.

    The decomposition is requested at `layer + 1`, one layer past the context neuron's
    own layer. Each component is projected onto `model.W_in[layer, :, neuron]` and the
    norm of that projection is reported, one row per selected position.
    '''
    layer, neuron = context_neuron
    _, cache = model.run_with_cache(prompt)
    decomposition, labels = cache.get_full_resid_decomposition(
        layer+1, apply_ln=True, return_labels=True, pos_slice=pos
    )
    decomposition = decomposition.squeeze(1)
    direction = model.W_in[layer, :, neuron]

    rows = [
        batched_projection(decomposition[:, position], direction).norm(dim=-1)
        for position in range(decomposition.shape[1])
    ]
    return torch.stack(rows), labels

def get_neuron_mean_acts(model: HookedTransformer, data: list[str], layer_neuron_dict: dict[int, list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
    '''Collect activations for the requested neurons, layer by layer.

    For each `layer -> neurons` entry, `haystack_utils.get_mlp_activations` is called on
    `data` with `hook_pre=False`, and its result is appended element by element.

    Returns two parallel Python lists (despite the tensor annotation on the signature):
    the `(layer, neuron)` pairs in dict order, and the matching activation entries.
    '''
    layer_neuron_pairs = []
    activations = []

    for layer in layer_neuron_dict:
        neurons = layer_neuron_dict[layer]
        layer_acts = haystack_utils.get_mlp_activations(
            data, layer, model,
            context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True,
        )
        layer_neuron_pairs += [(layer, neuron) for neuron in neurons]
        activations.extend(layer_acts)
        # One activation entry per requested neuron, checked after every layer.
        assert len(layer_neuron_pairs) == len(activations)

    return layer_neuron_pairs, activations

def get_unspecified_neurons(model: HookedTransformer, layer_neuron_dict: dict[int, list[int]]):
    '''List every `(layer, neuron)` of the model that is NOT named in `layer_neuron_dict`.

    Args:
        model: supplies `cfg.n_layers` and `cfg.d_mlp`.
        layer_neuron_dict: maps a layer index to the neurons specified for that layer.
            Layers without an entry are treated as having no specified neurons.

    Returns:
        A list of `(layer, neuron)` tuples ordered by layer, then neuron.
    '''
    unspecified = []
    for layer in range(model.cfg.n_layers):
        # .get() instead of [layer]: indexing raised KeyError on a plain dict and silently
        # inserted an empty entry into a defaultdict. The set makes each lookup O(1).
        specified = set(layer_neuron_dict.get(layer, ()))
        unspecified.extend(
            (layer, neuron) for neuron in range(model.cfg.d_mlp) if neuron not in specified
        )
    return unspecified

def get_neuron_loss_increases(model: HookedTransformer, data: list[str], prompt: str, positionwise: bool=False) -> torch.Tensor:
    '''Loss increase on `prompt` when each MLP neuron is ablated, one neuron at a time.

    For each layer, activations are gathered from the first 200 entries of `data` with
    `haystack_utils.get_mlp_activations`; each neuron is then ablated with the hook built
    by `hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])` and the loss
    is compared with the un-ablated loss.

    Args:
        model: the hooked transformer to run.
        data: reference texts for the replacement activations (only `data[:200]` is used).
        prompt: the single prompt on which loss is measured.
        positionwise: request per-token loss instead of one loss for the whole prompt.

    Returns:
        Neurons are ordered layer-major (`index = layer * d_mlp + neuron`).
        With `positionwise=True`: a `[n_tokens, n_layers * d_mlp]` tensor whose entry
        `[t, index]` is the loss increase at token `t`.
        With `positionwise=False`: a 1-D tensor of length `n_layers * d_mlp`.
    '''
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)

    losses = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(data[:200], layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
            loss_increase = ablated_loss - original_loss
            # Per-token losses carry a leading batch axis of one prompt; a scalar loss has
            # no axis to strip (indexing it with [0] would raise).
            if loss_increase.dim() > 0:
                loss_increase = loss_increase[0]
            losses.append(loss_increase)

    stacked = torch.stack(losses)
    if stacked.dim() == 2:
        # stacked is [neuron, token]. A transpose is required to get [token, neuron];
        # reshaping to that shape would reinterpret the flat buffer and mix up
        # neurons and tokens.
        return stacked.transpose(0, 1).contiguous()
    return stacked

def compare_dla_and_ablation(model: HookedTransformer, dla_attrs_by_neuron: torch.Tensor, ablation_losses_by_neuron: torch.Tensor, num_neurons=20):
    '''Print the top neurons by DLA next to the top neurons by ablation loss increase.

    Both inputs are indexed by flat neuron id, which is unravelled into
    `(layer, neuron)` over a `(n_layers, d_mlp)` grid. For each ranking the function
    prints the `(layer, neuron)` pairs followed by the DLA values at those indices.
    NOTE(review): the "Ablation" section also prints DLA values, not loss increases —
    presumably intentional, for comparison; confirm.
    '''
    grid_shape = (model.cfg.n_layers, model.cfg.d_mlp)

    print("DLA:")
    top_dla = torch.topk(dla_attrs_by_neuron, num_neurons, dim=-1).indices
    dla_layers, dla_neurons = np.unravel_index(top_dla.cpu().numpy(), grid_shape)
    print(list(zip(dla_layers.tolist(), dla_neurons.tolist())))
    print(dla_attrs_by_neuron[top_dla.tolist()])

    print("Ablation:")
    top_ablation = torch.topk(ablation_losses_by_neuron, num_neurons).indices
    flat_ablation = top_ablation.cpu().numpy()[:num_neurons]
    ablation_layers, ablation_neurons = np.unravel_index(flat_ablation, grid_shape)
    print(list(zip(ablation_layers.tolist(), ablation_neurons.tolist())))
    print(dla_attrs_by_neuron[top_ablation.tolist()])

def get_hook_inputs_for_token_index(model: HookedTransformer, data: list[str], loss_increases_by_neuron: torch.Tensor, k=40):
    '''Pick the `k` neurons with the largest loss increase and fetch their activations.

    Flat neuron indices are unravelled over a `(n_layers, d_mlp)` grid and grouped by
    layer (in order of first appearance), then `haystack_utils.get_mlp_activations` is
    called once per layer for just those neurons.

    Returns two parallel lists: `(layer, neuron)` pairs and the matching activation
    entries. The indices come from NumPy, so layers and neurons are NumPy integers.
    '''
    top_flat_indices = torch.topk(loss_increases_by_neuron, k).indices
    grid_shape = (model.cfg.n_layers, model.cfg.d_mlp)
    layer_indices, neuron_indices = np.unravel_index(top_flat_indices.cpu().numpy(), grid_shape)

    neurons_by_layer = defaultdict(list)
    for layer, neuron in zip(layer_indices, neuron_indices):
        neurons_by_layer[layer].append(neuron)

    layer_neuron_tuples = []
    activations = []
    for layer, neurons in neurons_by_layer.items():
        layer_acts = haystack_utils.get_mlp_activations(
            data, layer, model, context_crop_start=0, neurons=neurons, disable_tqdm=True
        )
        layer_neuron_tuples += [(layer, neuron) for neuron in neurons]
        activations.extend(layer_acts)
        # One activation entry per selected neuron, checked after every layer.
        assert len(layer_neuron_tuples) == len(activations)

    return layer_neuron_tuples, activations

def unravel_top_k(neuron_attrs: torch.Tensor, k: int=10):
    '''Return the `k` highest-scoring neurons as `(layer, neuron)` tuples, best first.

    `neuron_attrs` is indexed by flat neuron id. The grid shape is read from the
    notebook-level `model` (not a parameter), so that global must already exist.
    '''
    top_flat = torch.topk(neuron_attrs, k).indices.cpu().numpy()
    grid_shape = (model.cfg.n_layers, model.cfg.d_mlp)
    layers, neurons = np.unravel_index(top_flat, grid_shape)
    return [*zip(layers.tolist(), neurons.tolist())]

def resid_to_head_DLA(
        model: HookedTransformer,
        prompt: str | list[str],
        head: tuple[int, int],
        pos=np.s_[-1:],
) -> tuple[Float[Tensor, "pos component d_head"], list[str]]:
    '''Measure how each residual component aligns with every column of a head's key matrix. Unbatched.

    The residual stream entering `layer` is decomposed with
    `cache.get_full_resid_decomposition`; each component at each selected position is
    projected onto each column of `model.W_K[layer, head_index]`, and the norm of the
    projection is stored.

    Args:
        model: the hooked transformer to run.
        prompt: the input text (a single prompt; the batch axis is squeezed away).
        head: `(layer, head_index)` of the attention head.
        pos: position slice forwarded as `pos_slice`; defaults to the last token.

    Returns:
        `(results, labels)` where `results[p, c, j]` is the projection norm of component
        `c` at the p-th selected position onto key column `j`. `results` is allocated
        with `torch.zeros` on the default device. `labels` names the components.
    '''
    # The original also tokenized the prompt into an unused local; that call is dropped.
    _, cache = model.run_with_cache(prompt)
    layer, head_index = head
    all_attrs, labels = cache.get_full_resid_decomposition(layer, apply_ln=True, return_labels=True, pos_slice=pos)
    all_attrs = all_attrs.squeeze(1)
    key_directions = model.W_K[layer, head_index, :]

    n_components, n_positions = all_attrs.shape[0], all_attrs.shape[1]
    n_directions = key_directions.shape[1]
    results = torch.zeros(n_positions, n_components, n_directions)
    for i in range(n_positions):  # for each token
        token_attrs = all_attrs[:, i]  # does not depend on j, so sliced once per position
        for j in range(n_directions):  # for each direction in head input
            results[i, :, j] = batched_projection(token_attrs, key_directions[:, j]).norm(dim=-1)
    return results, labels


def mask_scores(attn_scores: Float[Tensor, "query_nctx key_nctx"]):
    '''Apply a causal mask to a square-or-rectangular score matrix.

    Entries on or below the diagonal (key position <= query position) are kept;
    entries above it (later keys) are replaced by -1e6, so a following softmax gives
    them effectively zero weight.
    '''
    # assert attn_scores.shape == (model.cfg.n_ctx, model.cfg.n_ctx)
    keep = torch.ones_like(attn_scores).tril().bool()
    fill_value = torch.tensor(-1.0e6).to(attn_scores.device)
    return torch.where(keep, attn_scores, fill_value)
    
def resid_to_head_DLA_custom(
        model: HookedTransformer, 
        prompt: str | list[str], 
        head: tuple[int, int]
        
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Component-by-component attention scores between the last two positions of `prompt`.

    The residual stream entering `layer` is decomposed (MLPs not expanded into neurons)
    at the final two positions. Entry `[a, b]` of the result is the unscaled, unmasked
    score `x_a @ W_Q @ W_K.T @ y_b`, where `x_a` is component `a` at the second-to-last
    position and `y_b` is component `b` at the last position.
    NOTE(review): the query side is the earlier position and the key side the later
    one; under causal attention the reverse pairing may have been intended — confirm.
    '''
    layer, head_index = head
    _, cache = model.run_with_cache(prompt)

    decomposition, labels = cache.get_full_resid_decomposition(
        layer, apply_ln=True, return_labels=True, pos_slice=np.s_[-2:], expand_neurons=False
    )
    # After the squeeze and permute the position axis comes first, components second.
    by_position = decomposition.squeeze(1).permute(1, 0, 2)
    query_side, key_side = by_position[0], by_position[1]

    W_QK = model.W_Q[layer, head_index] @ model.W_K[layer, head_index].T

    # Scaling by sqrt(d_head), masking and softmax are deliberately not applied here.
    component_scores = query_side @ W_QK @ key_side.T
    return component_scores, labels

### Act

In [7]:
# # Redo the upstream neurons method to include full MLP components in addition to the individual neuron breakdown
# # L0N1595

# attrs, labels = components_to_context_neuron_DLA(model, " An eye for an", pos=np.s_[-1:], context_neuron=(0, 1595))

Tried to stack head results when they weren't cached. Computing head results now


In [13]:
# Hand-written tree/fruit prompts, each ending on the article ("an" or "a") that
# precedes the fruit implied by the second tree. Not referenced by the other cells
# visible in this part of the notebook.
prompts = [
    "I climbed up the pear tree and picked a pear. I climbed up the apple tree and picked an",
    "I climbed up the pear tree and picked a pear. I climbed up the orange tree and picked an",
    "I climbed up the apple tree and picked an apple. I climbed up the pear tree and picked a",
    "I climbed up the apple tree and picked an apple. I climbed up the banana tree and picked a",
    "I climbed up the apple tree and picked an apple. I climbed up the cherry tree and picked a",
]

In [16]:
# Nouns used to fill the "{} tree" slot of prompts that end in the article "a".
# NOTE(review): the list contains repeats (e.g. "disk", "drum", "flag", "fork", "hat",
# "jug", "key") and a few vowel-initial words ("island", "office", "orange", "oven",
# "umbrella") that would normally take "an". Left unchanged because the recorded
# outputs below were produced with exactly this data.
common_objects = [
    "pen", "hat", "cup", "bag", "box", "car", "dog", "cat", "key", "bed",
    "pot", "pan", "jar", "jug", "rug", "bat", "ball", "shoe", "ship", "bike",
    "desk", "door", "bell", "book", "bowl", "coin", "comb", "cord", "disk", "doll",
    "drum", "flag", "fork", "lamp", "lock", "mug", "nail", "pipe", "ring", "rope",
    "seed", "skirt", "spoon", "stamp", "star", "stick", "tent", "tie", "tooth", "toy",
    "tree", "watch", "whip", "bird", "boat", "boot", "cane", "card", "chain", "chair",
    "chalk", "clock", "cloth", "cloud", "coat", "crab", "disk", "dress", "drop", "drum",
    "duck", "dust", "fence", "flag", "floor", "flower", "fly", "fog", "fork", "fruit",
    "glass", "glove", "grass", "hair", "hand", "harp", "hat", "hill", "horn", "horse",
    "house", "island", "jewel", "jug", "kettle", "key", "kite", "knife", "leaf", "leg",
    "library", "light", "line", "loaf", "lock", "machine", "man", "map", "moon", "net",
    "nose", "nut", "office", "orange", "oven", "parcel", "pen", "pencil", "picture", "pig",
    "pin", "pipe", "plane", "plate", "plough", "pocket", "pot", "potato", "prison", "pump",
    "rail", "rat", "receipt", "ring", "rod", "roof", "root", "sail", "school", "scissors",
    "screw", "seed", "sheep", "shelf", "ship", "shirt", "shoe", "skin", "skirt", "snake",
    "sock", "spade", "sponge", "spoon", "spring", "square", "stamp", "star", "station", "stem",
    "stick", "stocking", "stomach", "store", "street", "sun", "table", "tail", "thread", "throat",
    "thumb", "ticket", "toe", "tongue", "tooth", "town", "train", "tray", "tree", "trousers",
    "umbrella", "wall", "watch", "wheel", "whistle", "window", "wire", "wing", "worm", "yarn"
]
# Nouns spelled with an initial vowel, used to fill the "{} tree" slot of prompts that
# end in the article "an".
# NOTE(review): some entries begin with a consonant *sound* and normally take "a"
# (e.g. "unicorn", "uniform", "unit", "utensil", "ukelele"), and several words are
# repeated (e.g. "eel", "atom", "echo", "opal", "arena", "ounce", "oven"). Left
# unchanged because the recorded outputs below were produced with exactly this data.
vowel_objects = [
    "apple", "apron", "arm", "ankle", "arrow", "atom", "ant", "anchor", "album", "axe",
    "ear", "egg", "elbow", "engine", "eagle", "earring", "envelope", "eye", "eel", "earth",
    "ice", "iron", "ink", "island", "ivy", "igloo", "insect", "instrument", "image", "indicator",
    "oak", "oar", "ocean", "octopus", "onion", "orange", "organ", "oven", "owl", "ox",
    "umbrella", "urn", "utensil", "uniform", "ukelele", "unit", "unicorn", "upstairs", "underwear", "urchin",
    "emerald", "end", "elephant", "elm", "easel", "eraser", "eskimo", "entrance", "estate", "echo",
    "ash", "art", "armchair", "air", "arch", "anvil", "alloy", "alley", "atom", "amulet",
    "olive", "opera", "opal", "ottoman", "orchid", "orbit", "ostrich", "oxen", "oil", "ounce",
    "iceberg", "iris", "idea", "iguanodon", "inlet", "icon", "input", "isle", "itch", "issue",
    "udder", "uplift", "update", "upgrade", "undo", "uptake", "upbeat", "upturn", "upload", "upstream",
    "antenna", "almond", "arena", "aorta", "ape", "asteroid", "aster", "auction", "audio", "avocado",
    "edge", "eel", "eel", "equipment", "escalator", "essence", "emblem", "echo", "engineer", "equator",
    "opal", "orchard", "oboe", "oval", "oven", "overcoat", "oyster", "ounce", "outlet", "outline",
    "aerial", "airplane", "awning", "award", "agent", "agate", "arc", "arena", "armadillo", "apricot"
]
# First 50 words of each list, kept only when " <word>" tokenizes to exactly one token.
word_list = common_objects[:50] + vowel_objects[:50]
token_lengths = [
    len(model.to_tokens(f" {word}", prepend_bos=False).squeeze(0))
    for word in word_list
]
word_list = [word for word, n_tokens in zip(word_list, token_lengths) if n_tokens == 1]

# Prompts stop right after "picked", i.e. before the article.
prompt_template = "I climbed up the pear tree and picked a pear. I climbed up the {} tree and picked"
prompt_list = list(map(prompt_template.format, word_list))
tree_tokens = model.to_tokens(prompt_list)

In [17]:

# Prompts ending in "an", one per vowel-initial object.
prompt_template = "I climbed up the pear tree and picked a pear. I climbed up the {} tree and picked an"
vowel_prompt_list = list(map(prompt_template.format, vowel_objects))

# Prompts ending in "a", one per common object.
prompt_template = "I climbed up the pear tree and picked a pear. I climbed up the {} tree and picked a"
common_objects_prompt_list = list(map(prompt_template.format, common_objects))

In [42]:
# Split each prompt list by tokenized shape. (1, 21) is one prompt of 21 tokens —
# presumably the length obtained when the object word is a single token (TODO confirm).
# Every other shape goes to the "multi_token" list.
common_objects_prompts, multi_token_common_objects_prompts = [], []
vowel_prompts, multi_token_vowel_prompts = [], []

for item in common_objects_prompt_list:
    is_single_token = model.to_tokens(item).shape == (1, 21)
    bucket = common_objects_prompts if is_single_token else multi_token_common_objects_prompts
    bucket.append(item)

for item in vowel_prompt_list:
    is_single_token = model.to_tokens(item).shape == (1, 21)
    bucket = vowel_prompts if is_single_token else multi_token_vowel_prompts
    bucket.append(item)

print(len(vowel_prompts), len(common_objects_prompts))
print(len(multi_token_vowel_prompts), len(multi_token_common_objects_prompts))

96 186
44 4


In [46]:
# Neuron seems to work consistently across fruits

LAYER, NEURON = 0, 1595
hook_name = f'blocks.{LAYER}.mlp.hook_post'


# NOTE(review): the parameter names of the two hook functions are kept exactly as they
# were — TransformerLens appears to pass the hook point by keyword; confirm before renaming.
def hook(value, hook):
    '''Zero NEURON at position -5 of the first batch row (in place).'''
    # Position -5 is presumably the object word when it is a single token:
    # "... {word} tree and picked an" — TODO confirm.
    value[0, -5, NEURON] = 0
hooks = [(hook_name, hook)]


def multi_token_hook(value, hook):
    '''Zero NEURON at positions -6 and -5 of the first batch row (in place).'''
    # NOTE(review): covers exactly two positions, while the multi-token prompt list holds
    # every prompt whose token count is not 21, so longer words are only partly ablated.
    value[0, -6:-4, NEURON] = 0
multi_token_hooks = [(hook_name, multi_token_hook)]


def last_token_ablation_effect(prompt_subset, fwd_hooks):
    '''Compare final-token loss with and without `fwd_hooks` for each prompt.

    Returns `(loss_diffs, original_losses)`: per prompt, the ablated-minus-original loss
    at the last token and the original last-token loss.
    '''
    loss_diffs, original_losses = [], []
    for prompt in prompt_subset:
        # Plain forward passes: only the loss is needed, so no activations are cached.
        loss = model(prompt, return_type='loss', loss_per_token=True)
        with model.hooks(fwd_hooks):
            ablated_loss = model(prompt, return_type='loss', loss_per_token=True)
        loss_diffs.append((ablated_loss[0, -1] - loss[0, -1]).item())
        original_losses.append(loss[0, -1].item())
    return loss_diffs, original_losses


vowel_loss_diffs, vowel_original_losses = last_token_ablation_effect(vowel_prompts[:50], hooks)
multi_token_vowel_loss_diffs, multi_token_original_losses = last_token_ablation_effect(
    multi_token_vowel_prompts[:50], multi_token_hooks
)
# Despite its name, `consonant_losses` holds loss *differences*, like the lists above.
consonant_losses, consonant_original_losses = last_token_ablation_effect(common_objects_prompts[:50], hooks)

print(np.mean(vowel_loss_diffs), np.mean(consonant_losses), np.mean(multi_token_vowel_loss_diffs))
print(np.mean(multi_token_original_losses), np.mean(vowel_original_losses))

0.12205264687538148 0.0005269411206245422 0.04601924934170463
1.435411212119189 1.4191379714012147


In [None]:
# NOTE(review): this rebinds `einsum`, which the setup cell imported from fancy_einsum.
# The import is kept because cells outside this view may rely on the rebinding.
from torch import einsum

# Single position -5: presumably the object word of a single-token prompt — TODO confirm.
pos = np.s_[-5:-4]
# Input direction of the context neuron; identical for every prompt, so read once.
answer_residual_direction = model.W_in[LAYER, :, NEURON]
results = []

for prompt in vowel_prompts[:50]:
    _, cache = model.run_with_cache(prompt)
    attrs, labels = cache.get_full_resid_decomposition(LAYER + 1, apply_ln=True, return_labels=True, pos_slice=pos, expand_neurons=False)
    # First squeeze removes axis 1 (assumed to be the size-1 batch axis — TODO confirm).
    attrs = attrs.squeeze(1)
    # Second squeeze removes the length-1 position axis, leaving [component, d_model].
    results.append(einsum("c d, d -> c", attrs.squeeze(1), answer_residual_direction))

# One row per prompt, one column per residual component.
results = torch.stack(results)



In [88]:
# Mean attribution per component across prompts.
values = results.mean(dim=0).cpu().tolist()
# Work on a copy of `labels`: popping from `labels` itself removed one more entry on
# every re-run of this cell, leaving it shorter than `values`.
plot_labels = list(labels)

# Drop the component at index -4 from the plot.
# NOTE(review): presumably the layer's own MLP output — confirm against `labels`.
del values[-4]
del plot_labels[-4]

haystack_utils.line(values, xticks=plot_labels, title=f"DLA for 'pear', {model.to_str_tokens(model.to_tokens(prompt)[0, pos])}")