In [None]:
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor
import plotly.io as pio
import numpy as np
import pandas as pd
from tqdm import trange

# Plotly renderer string used for every figure drawn in this notebook.
pio.renderers.default = "notebook_connected+notebook"
# Run on the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Turn autograd off globally. `torch.set_grad_enabled` and
# `torch.autograd.set_grad_enabled` are the same callable, so a single call
# is enough.
torch.set_grad_enabled(False)

import haystack_utils
import hook_utils
import pythia_160m_utils
import plotting_utils
import probing_utils

%reload_ext autoreload
%autoreload 2

# Load the pretrained Pythia-70m checkpoint through TransformerLens with
# layer-norm folding and weight centering switched on.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    device=device,
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
)

# Only the first 200 entries of each Europarl file are kept.
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

# (layer, neuron) pair of interest -- presumably the "context neuron" that the
# notes further down refer to; TODO confirm.
LAYER, NEURON = 3, 669

### Utils

In [1]:
def batched_dot_product(x, y):
    '''Row-wise dot product: entry i of the result is dot(x[i], y[i]).

    `torch.dot` is mapped over the leading dimension of both inputs.
    '''
    rowwise_dot = torch.vmap(torch.dot)
    return rowwise_dot(x, y)
    
def neuron_DLA(
        prompt: str, 
        model: HookedTransformer, 
        pos=np.s_[-1:]) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Project each neuron's residual-stream contribution onto the next-token direction.

    The prompt is tokenized and split into inputs (all tokens but the last) and
    answers (the same tokens shifted by one). The final residual stream is
    decomposed with neurons expanded, the components whose label contains 'N'
    are kept, and each is dotted with the residual direction of the answer
    token at the same position.

    Args:
        prompt: text to analyse.
        model: model used for tokenization, the cached forward pass and the
            token -> residual direction lookup.
        pos: index applied both to the decomposition (`pos_slice`) and to the
            answer directions; defaults to the last position only.

    Returns:
        A tensor stacking one row per selected position, each row holding one
        value per kept component, and the labels of the kept components.
        NOTE(review): the return annotation only names a "component" axis even
        though rows are stacked per position.
    '''
    all_tokens = model.to_tokens(prompt)
    # Inputs drop the final token; the answers are the tokens that follow each input.
    input_tokens, answer_tokens = all_tokens[:, :-1], all_tokens[:, 1:]

    _, cache = model.run_with_cache(input_tokens)
    components, component_labels = cache.get_full_resid_decomposition(
        -1,
        expand_neurons=True,
        apply_ln=True,
        return_labels=True,
        pos_slice=pos,
    )

    directions = model.tokens_to_residual_directions(answer_tokens)
    if directions.ndim == 1:
        directions = directions.unsqueeze(0)  # [1 d_model]
    elif directions.ndim == 3:
        directions = directions[0]  # [pos d_model]
    directions = directions[pos]  # [pos d_model]

    # Keep only the components whose label marks them as a neuron.
    neuron_rows = [row for row, label in enumerate(component_labels) if 'N' in label]
    neuron_labels = [component_labels[row] for row in neuron_rows]
    # Drop the singleton at dim 1 (assumed to be the batch dimension -- TODO confirm).
    neuron_components = components[neuron_rows, :].squeeze(1)

    n_neurons = neuron_components.shape[0]
    per_position = [
        batched_dot_product(
            neuron_components[:, position],
            directions[[position]].repeat(n_neurons, 1),
        )
        for position in range(neuron_components.shape[1])
    ]
    return torch.stack(per_position), neuron_labels

def get_neuron_mean_acts(dla_layer_neuron_tuples: list[tuple[int, int]], model=model) -> tuple[list[tuple[int, int]], list]:
    '''Collect activations on `german_data` for the given neurons, grouped by layer.

    The (layer, neuron) pairs are regrouped per layer with
    `haystack_utils.get_neurons_by_layer`, and `haystack_utils.get_mlp_activations`
    is called once per layer (with `hook_pre=False`) for that layer's neurons.

    Args:
        dla_layer_neuron_tuples: (layer, neuron) pairs to look up.
        model: model passed through to `get_mlp_activations`.

    Returns:
        Two parallel lists: the (layer, neuron) pairs in the order the layers
        were processed, and the values yielded by iterating each layer's
        activation result (presumably one mean activation per neuron -- TODO
        confirm against `haystack_utils.get_mlp_activations`).

    Raises:
        AssertionError: if a layer yields a different number of activations
            than it has neurons.
    '''
    layer_neuron_dict = haystack_utils.get_neurons_by_layer(dla_layer_neuron_tuples)

    sorted_dla_layer_neuron_tuples = []
    sorted_acts = []

    for layer, neurons in layer_neuron_dict.items():
        mean_acts = haystack_utils.get_mlp_activations(german_data, layer, model, context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True)
        sorted_dla_layer_neuron_tuples.extend((layer, neuron) for neuron in neurons)
        sorted_acts.extend(mean_acts)
        # The two lists are returned as parallel sequences, so they must stay aligned.
        assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts), \
            f"layer {layer}: number of activations does not match number of neurons"

    return sorted_dla_layer_neuron_tuples, sorted_acts

def get_unspecified_neurons(model: HookedTransformer, neurons: list[tuple[int, int]]):
    '''List every (layer, neuron) pair of the model that is not in `neurons`.

    Args:
        model: supplies `cfg.n_layers` and `cfg.d_mlp`, which bound the search.
        neurons: (layer, neuron) pairs to exclude.

    Returns:
        A list of (layer, neuron) tuples, ordered by layer and then by neuron
        index, covering every neuron not listed in `neurons`.
    '''
    layer_dict = haystack_utils.get_neurons_by_layer(neurons)
    unspecified = []
    for layer in range(model.cfg.n_layers):
        # A layer with no listed neurons may be missing from the mapping;
        # `.get` treats it as "nothing specified" instead of raising KeyError.
        specified = layer_dict.get(layer, ())
        unspecified.extend(
            (layer, neuron)
            for neuron in range(model.cfg.d_mlp)
            if neuron not in specified
        )
    return unspecified

def get_neuron_loss_increases(prompt: str, positionwise=False, model=model, data=german_data):
    '''Measure how much the loss on `prompt` changes when each MLP neuron is ablated.

    For every layer, reference activations are computed once on the first 200
    entries of `data`. Each neuron of the layer is then ablated in turn with
    `hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])`
    (presumably pinning the neuron to that reference value -- TODO confirm)
    and the loss is recomputed.

    Args:
        prompt: text on which the loss is evaluated.
        positionwise: forwarded as `loss_per_token`; when True the loss is
            kept per token instead of averaged.
        model: model to run and hook.
        data: prompts used to compute the reference activations.

    Returns:
        A list with one entry per neuron, ordered layer by layer and then by
        neuron index. Each entry is `ablated_loss - original_loss`: the
        per-token difference for the single prompt when `positionwise` is True,
        otherwise a 0-dim tensor holding the difference of the mean losses.
    '''
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
    reference_data = data[:200]

    losses = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(reference_data, layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
            loss_increase = ablated_loss - original_loss
            # The per-token loss has a leading batch dimension to strip; the
            # averaged loss is 0-dim, and indexing it with [0] raises IndexError.
            losses.append(loss_increase[0] if loss_increase.ndim > 0 else loss_increase)
    return losses

def compare_dla_and_ablation(token_index, attrs, losses, num_neurons=20, model=model):
    '''Print the top neurons by DLA next to the top neurons by ablation loss increase.

    Args:
        token_index: position to inspect; indexes both `attrs` and every entry
            of `losses`.
        attrs: attribution values, indexed as `attrs[token_index][flat_neuron]`,
            where the flat index runs over (n_layers, d_mlp).
        losses: one per-position loss-increase sequence per neuron, in the same
            flat neuron order.
        num_neurons: how many top neurons to print for each method.
        model: supplies `cfg.n_layers` and `cfg.d_mlp` for converting flat
            indices to (layer, neuron) pairs.

    Returns:
        None; results are printed.
    '''
    print("DLA:")
    token_attrs = attrs[token_index]
    grid_shape = (model.cfg.n_layers, model.cfg.d_mlp)
    _, top_dla = torch.topk(token_attrs, num_neurons, dim=-1)
    dla_layers, dla_neurons = np.unravel_index(top_dla.cpu().numpy(), grid_shape)
    dla_layer_neuron_tuples = list(zip(dla_layers.tolist(), dla_neurons.tolist()))
    # Convert the (layer, neuron) pairs back to flat indices to look up their values.
    flat_dla = [
        np.ravel_multi_index(pair, grid_shape)
        for pair in dla_layer_neuron_tuples[:num_neurons]
    ]

    print(dla_layer_neuron_tuples[:num_neurons])
    print(token_attrs[flat_dla])

    print("Ablation:")
    per_neuron_increase = torch.tensor([neuron_loss[token_index] for neuron_loss in losses])
    _, top_ablation = torch.topk(per_neuron_increase, num_neurons)
    top_ablation = top_ablation.numpy()
    abl_layers, abl_neurons = np.unravel_index(top_ablation[:num_neurons], grid_shape)

    print(list(zip(abl_layers.tolist(), abl_neurons.tolist())))
    # NOTE(review): this prints the DLA values of the top-ablation neurons, not
    # their loss increases -- confirm that this is the intended comparison.
    print(token_attrs[top_ablation.tolist()[:num_neurons]])

def get_hook_inputs_for_token_index(loss_increases_by_neuron, model=model, data=german_data, k=40):
    '''Pick the k neurons with the largest loss increase and fetch their activations on `data`.

    Args:
        loss_increases_by_neuron: 1-D tensor of loss increases in flat neuron
            order over (n_layers, d_mlp).
        model: supplies the (n_layers, d_mlp) grid and is passed through to
            `haystack_utils.get_mlp_activations`.
        data: prompts on which the activations are computed.
        k: number of top neurons to keep.

    Returns:
        Two parallel lists: the selected (layer, neuron) pairs grouped by
        layer, and the values yielded by iterating each layer's activation
        result (presumably one mean activation per neuron -- TODO confirm).
    '''
    _, top_flat = torch.topk(loss_increases_by_neuron, k)

    grid_shape = (model.cfg.n_layers, model.cfg.d_mlp)
    top_layers, top_neurons = np.unravel_index(top_flat.cpu().numpy(), grid_shape)
    causal_pairs = list(zip(top_layers.tolist(), top_neurons.tolist()))
    neurons_by_layer = haystack_utils.get_neurons_by_layer(causal_pairs)

    ordered_pairs, ordered_acts = [], []
    for layer, layer_neurons in neurons_by_layer.items():
        layer_acts = haystack_utils.get_mlp_activations(
            data,
            layer,
            model,
            context_crop_start=0,
            neurons=layer_neurons,
            disable_tqdm=True,
        )
        ordered_pairs += [(layer, neuron) for neuron in layer_neurons]
        ordered_acts.extend(layer_acts)
        # Both lists are returned as parallel sequences and must stay aligned.
        assert len(ordered_pairs) == len(ordered_acts)

    return ordered_pairs, ordered_acts

NameError: name 'np' is not defined

### First run


# Ablate each neuron in turn and look at how it affects the context neuron value (averaged over all prompts)

# Check whether it activates for German words in an English context, for both single common German characters and a full word
# Ablate each neuron in turn and look at how it affects the context neuron value (mean taken at a German token position within an English context)

# Collect a list

# Look for head that moves German tokens
# Look for head that moves German n-grams
# Look for n-gram detector and see if it moves from there



In [None]:
# Measure context neuron activation for many German tokens that never coalesce into German words
# Measure context neuron activation for English words with a single German token mixed in
# Measure above but make it semantically clear in the English that a German token is about to appear
# Measure above but with a full German word
# Measure above but semantically clear in English

# Optional (if time permits):
# Measure above but with common German unigrams

# Need a way to measure at position

with model.hooks(hook_utils.get)