In [None]:
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor
import plotly.io as pio
import numpy as np
import pandas as pd
from tqdm import trange
from collections import defaultdict, Counter

# Render plotly figures through the combined notebook renderers.
pio.renderers.default = "notebook_connected+notebook"

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Inference-only notebook: switch gradient tracking off globally.
# torch.autograd.set_grad_enabled is the same callable as torch.set_grad_enabled,
# so a single call is sufficient.
torch.set_grad_enabled(False)

import haystack_utils
import hook_utils

# Have IPython re-import edited modules before each cell is executed, so that
# changes to the local helper modules imported above (haystack_utils,
# hook_utils) are picked up without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [None]:
# Load the pretrained Pythia-70m checkpoint with centred unembed / writing
# weights and LayerNorm folded in, placed on the device selected above.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

# Keep only the first 200 entries of each Europarl JSON file.
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

# Layer and neuron index of the context neuron analysed in this notebook.
LAYER = 3
NEURON = 669

### Utils

In [23]:
def batched_dot_product(x: torch.Tensor, y: torch.Tensor):
    '''Dot each row of ``x`` with the matching row of ``y``.

    ``torch.dot`` is mapped over the leading dimension of both inputs, so the
    result holds one dot product per row.
    '''
    rowwise_dot = torch.vmap(torch.dot)
    return rowwise_dot(x, y)
    
def neuron_to_context_neuron_DLA(
        model: HookedTransformer, 
        prompt: str | list[str], 
        pos=np.s_[-1:], 
        context_neuron=tuple[int, int]
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Gets full resid decomposition including all neurons. Unbatched.'''
    tokens = model.to_tokens(prompt)
    _, cache = model.run_with_cache(prompt)
    layer, neuron = context_neuron
    neuron_attrs, neuron_labels = cache.stack_neuron_results(layer, apply_ln=True, return_labels=True, pos_slice=pos)
    neuron_attrs = neuron_attrs.squeeze(1)
    
    answer_residual_direction = model.W_in[layer, :, neuron].unsqueeze(0)  # [1 d_model]

    results = []
    for i in range(neuron_attrs.shape[1]):
        results.append(batched_dot_product(neuron_attrs[:, i], answer_residual_direction.repeat(neuron_attrs.shape[0], 1)))
    return torch.stack(results), neuron_labels

def get_neuron_mean_acts(model: HookedTransformer, data: list[str], layer_neuron_dict: dict[int, list[int]]) -> tuple[list[tuple[int, int]], list]:
    '''Collect activations for the requested neurons, layer by layer.

    Args:
        model: model passed through to ``haystack_utils.get_mlp_activations``.
        data: prompts passed through to ``haystack_utils.get_mlp_activations``.
        layer_neuron_dict: maps a layer index to the neuron indices of interest
            in that layer.

    Returns:
        Two parallel Python lists (not tensors): the ``(layer, neuron)`` pairs in
        the iteration order of ``layer_neuron_dict``, and one activation entry
        per pair. The entries are whatever iterating the result of
        ``get_mlp_activations`` yields - presumably one mean activation per
        requested neuron; confirm against haystack_utils.
    '''
    layer_neuron_tuples: list[tuple[int, int]] = []
    acts: list = []

    for layer, neurons in layer_neuron_dict.items():
        mean_acts = haystack_utils.get_mlp_activations(data, layer, model, context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True)
        layer_neuron_tuples.extend((layer, neuron) for neuron in neurons)
        acts.extend(mean_acts)
        # Both lists must stay aligned: one activation entry per requested neuron.
        assert len(layer_neuron_tuples) == len(acts)

    return layer_neuron_tuples, acts

def get_unspecified_neurons(model: HookedTransformer, layer_neuron_dict: dict[int, list[int]]):
    '''List every ``(layer, neuron)`` of the model that is absent from ``layer_neuron_dict``.

    Args:
        model: only ``model.cfg.n_layers`` and ``model.cfg.d_mlp`` are read.
        layer_neuron_dict: maps a layer index to the neuron indices that are
            already specified for that layer. Layers without an entry are
            treated as having no specified neurons.

    Returns:
        ``(layer, neuron)`` tuples ordered by layer, then by neuron index.
    '''
    unspecified = []
    for layer in range(model.cfg.n_layers):
        # .get() avoids a KeyError on plain dicts that lack this layer and keeps
        # defaultdicts from gaining empty entries as a side effect of the lookup.
        # int() normalises numpy / tensor integer scalars so set membership works.
        specified = {int(neuron) for neuron in layer_neuron_dict.get(layer, ())}
        unspecified.extend(
            (layer, neuron)
            for neuron in range(model.cfg.d_mlp)
            if neuron not in specified
        )
    return unspecified

def get_neuron_loss_increases(model: HookedTransformer, data: list[str], prompt: str, positionwise: bool=False) -> torch.Tensor:
    '''Ablate each MLP neuron in turn and record the resulting loss change on ``prompt``.

    For every layer, per-neuron activations are obtained from
    ``haystack_utils.get_mlp_activations`` on the first 200 entries of ``data``
    and handed to ``hook_utils.get_ablate_neuron_hook`` (presumably the value
    the neuron is set to while ablated - confirm against hook_utils).

    Args:
        model: model to evaluate and hook.
        data: prompts used to compute the per-neuron activations.
        prompt: the single prompt on which the loss is measured.
        positionwise: when True the loss is kept per predicted token instead of
            being reduced to a single value.

    Returns:
        ``ablated_loss - original_loss`` for every neuron, neurons ordered
        layer-major (flat index ``layer * d_mlp + neuron``). The shape is
        ``[n_layers * d_mlp]`` when ``positionwise`` is False and
        ``[n_predicted_tokens, n_layers * d_mlp]`` when it is True.
    '''
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)

    loss_increases = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(data[:200], layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
            loss_increase = ablated_loss - original_loss
            # Per-token losses carry a leading batch dimension of 1 for the single
            # prompt; the reduced loss is a scalar and must not be indexed.
            loss_increases.append(loss_increase[0] if positionwise else loss_increase)

    # Stack along a new trailing axis so neurons end up on the last dimension.
    # Reshaping a [neuron, token] stack to [token, neuron] would reinterpret the
    # memory layout and mix values from different neurons.
    return torch.stack(loss_increases, dim=-1)

def compare_dla_and_ablation(model: HookedTransformer, dla_attrs_by_neuron: torch.Tensor, ablation_losses_by_neuron: torch.Tensor, num_neurons=20):
    '''Print the top ``num_neurons`` neurons ranked by DLA and by ablation loss increase.

    For each ranking the flat neuron indices are converted into
    ``(layer, neuron)`` pairs on the model's ``(n_layers, d_mlp)`` grid and
    printed, followed by the entries of ``dla_attrs_by_neuron`` at those indices.

    NOTE(review): the "Ablation" section prints values taken from
    ``dla_attrs_by_neuron`` rather than the loss increases themselves - confirm
    that this is the intended comparison.
    '''
    print("DLA:")
    top_dla = torch.topk(dla_attrs_by_neuron, num_neurons, dim=-1).indices
    grid = (model.cfg.n_layers, model.cfg.d_mlp)
    dla_layers, dla_neurons = np.unravel_index(top_dla.cpu().numpy(), grid)
    dla_pairs = list(zip(dla_layers.tolist(), dla_neurons.tolist()))
    print(dla_pairs)
    print(dla_attrs_by_neuron[top_dla.tolist()])

    print("Ablation:")
    top_ablation = torch.topk(ablation_losses_by_neuron, num_neurons).indices
    ablation_flat = top_ablation.cpu().numpy()[:num_neurons]
    ablation_layers, ablation_neurons = np.unravel_index(ablation_flat, grid)
    ablation_pairs = list(zip(ablation_layers.tolist(), ablation_neurons.tolist()))
    print(ablation_pairs)
    print(dla_attrs_by_neuron[top_ablation.tolist()])

def get_hook_inputs_for_token_index(model: HookedTransformer, data: list[str], loss_increases_by_neuron: torch.Tensor, k=40):
    '''Gather activations for the ``k`` neurons with the largest loss increase.

    The top-``k`` flat indices of ``loss_increases_by_neuron`` are mapped onto
    the model's ``(n_layers, d_mlp)`` grid and grouped per layer, after which
    ``haystack_utils.get_mlp_activations`` is queried once per layer.

    Returns two parallel lists: the ``(layer, neuron)`` pairs grouped by layer
    (layers in order of first appearance among the top-``k``), and one
    activation entry per pair.
    '''
    top_flat = torch.topk(loss_increases_by_neuron, k).indices
    grid = (model.cfg.n_layers, model.cfg.d_mlp)
    top_layers, top_neurons = np.unravel_index(top_flat.cpu().numpy(), grid)

    # Group the selected neurons by layer, preserving first-seen layer order.
    neurons_by_layer = {}
    for layer_idx, neuron_idx in zip(top_layers, top_neurons):
        neurons_by_layer.setdefault(layer_idx, []).append(neuron_idx)

    layer_neuron_pairs = []
    activations = []
    for layer_idx, layer_neurons in neurons_by_layer.items():
        layer_acts = haystack_utils.get_mlp_activations(
            data,
            layer_idx,
            model,
            context_crop_start=0,
            neurons=layer_neurons,
            disable_tqdm=True,
        )
        for neuron_idx in layer_neurons:
            layer_neuron_pairs.append((layer_idx, neuron_idx))
        activations.extend(layer_acts)
        # One activation entry is expected per selected neuron.
        assert len(layer_neuron_pairs) == len(activations)

    return layer_neuron_pairs, activations

def unravel_top_k(neuron_attrs: torch.Tensor, k: int=10):
    values, indices = torch.topk(neuron_attrs, k)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    return list(zip(layer_indices.tolist(), neuron_indices.tolist()))

In [39]:
def upstream_for_prompt(prompt, context_neuron=None, k=10):
    '''Count how often each upstream neuron is a top-``k`` DLA contributor to the context neuron.

    Args:
        prompt: prompt to analyse with the notebook-level ``model``.
        context_neuron: ``(layer, neuron)`` whose input is decomposed. Defaults
            to the notebook constants ``(LAYER, NEURON)``.
        k: number of top contributing neurons tallied per token position.

    Returns:
        ``Counter`` mapping ``(layer, neuron)`` to the number of token positions
        at which that neuron was among the top ``k`` contributors.
    '''
    if context_neuron is None:
        # Resolved at call time so the notebook constants are the single source
        # of truth for the default context neuron.
        context_neuron = (LAYER, NEURON)

    n_tokens = model.to_tokens(prompt).shape[1]
    # Slice covering the last n_tokens positions, i.e. every token of the prompt.
    neuron_attrs_by_token, _ = neuron_to_context_neuron_DLA(model, prompt, np.s_[-n_tokens:], context_neuron)

    counter = Counter()
    for position_attrs in neuron_attrs_by_token:
        counter.update(unravel_top_k(position_attrs, k=k))
    return counter

# Prompt made of the single character "ü" repeated with spaces in between;
# print the tally of top upstream neurons over all of its token positions.
test_prompt = "ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü"
print(upstream_for_prompt(test_prompt))

# NOTE(review): this prompt looks identical to the one above. Confirm whether
# the two strings differ (e.g. in Unicode normalisation) or whether this second
# run is an unintended duplicate.
test_prompt = "ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü ü"
print(upstream_for_prompt(test_prompt))

Counter({(0, 1259): 79, (2, 306): 74, (2, 983): 73, (1, 347): 72, (1, 13): 67, (2, 1449): 66, (2, 1236): 59, (2, 181): 58, (0, 264): 52, (2, 819): 42, (2, 237): 28, (0, 1186): 20, (2, 1658): 19, (1, 1911): 15, (2, 1166): 15, (2, 950): 14, (0, 637): 10, (2, 743): 10, (0, 1715): 7, (0, 1758): 5, (2, 689): 4, (1, 1765): 3, (2, 1149): 2, (0, 1987): 1, (0, 563): 1, (2, 1888): 1, (0, 336): 1, (0, 1874): 1, (1, 835): 1, (1, 2035): 1, (1, 737): 1, (1, 1034): 1, (1, 1414): 1, (2, 1153): 1, (2, 1087): 1, (1, 1109): 1, (1, 1028): 1, (0, 1318): 1, (1, 509): 1})


### First run


# Ablate each neuron in turn and look at how it affects the context neuron value (averaged over all prompts)

# Check whether it activates for German words in an English context, both for single common German characters and for a full word
# Ablate each neuron in turn and look at how it affects the context neuron value (averaged over a German token position within an English context)

# Collect a list

# Look for head that moves German tokens
# Look for head that moves German n-grams
# Look for n-gram detector and see if it moves from there



In [None]:
# Measure context neuron activation for many German tokens that never coalesce into German words
# Measure context neuron activation for English words with a single German token mixed in
# Measure above but make it semantically clear in the English that a German token is about to appear
# Measure above but with a full German word
# Measure above but semantically clear in English

# Optional (if time permits):
# Measure above but with common German unigrams

# Need a way to measure the context neuron activation at a specific token position

with model.hooks(hook_utils.get)