### Setup

In [11]:
import torch
from transformer_lens import HookedTransformer
from jaxtyping import Float
from torch import Tensor
import plotly.io as pio
import numpy as np
import pandas as pd
from tqdm import trange
from collections import defaultdict, Counter
from torchmetrics.regression import SpearmanCorrCoef
import plotly_express as px

pio.renderers.default = "notebook_connected+notebook"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

import haystack_utils
import hook_utils

%reload_ext autoreload
%autoreload 2

In [2]:
model = HookedTransformer.from_pretrained("EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device)

german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]

LAYER, NEURON = 3, 669

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


### Utils

In [3]:
def batched_dot_product(x: torch.Tensor, y: torch.Tensor):
    return torch.vmap(torch.dot)(x, y)
    
def neuron_to_context_neuron_DLA(
        model: HookedTransformer, 
        prompt: str | list[str], 
        pos=np.s_[-1:], 
        context_neuron=tuple[int, int]
) -> tuple[Float[Tensor, "component"], list[str]]:
    '''Gets full resid decomposition including all neurons. Unbatched.'''
    tokens = model.to_tokens(prompt)
    _, cache = model.run_with_cache(prompt)
    layer, neuron = context_neuron
    neuron_attrs, neuron_labels = cache.stack_neuron_results(layer, apply_ln=True, return_labels=True, pos_slice=pos)
    neuron_attrs = neuron_attrs.squeeze(1)
    
    answer_residual_direction = model.W_in[layer, :, neuron].unsqueeze(0)  # [1 d_model]

    results = []
    for i in range(neuron_attrs.shape[1]):
        results.append(batched_dot_product(neuron_attrs[:, i], answer_residual_direction.repeat(neuron_attrs.shape[0], 1)))
    return torch.stack(results), neuron_labels

def get_neuron_mean_acts(model: HookedTransformer, data: list[str], layer_neuron_dict: dict[int, list[int]]) -> tuple[torch.Tensor, torch.Tensor]:
    sorted_layer_neuron_tuples = []
    sorted_acts = []

    for layer, neurons in layer_neuron_dict.items():
        mean_acts = haystack_utils.get_mlp_activations(data, layer, model, context_crop_start=0, hook_pre=False, neurons=neurons, disable_tqdm=True)
        sorted_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
        sorted_acts.extend(mean_acts)
        assert len(sorted_layer_neuron_tuples) == len(sorted_acts)

    return sorted_layer_neuron_tuples, sorted_acts

def get_unspecified_neurons(model: HookedTransformer, layer_neuron_dict: dict[int, list[int]]):
    unspecified = []
    for layer in range(model.cfg.n_layers):
        for neuron in range(model.cfg.d_mlp):
            if not neuron in layer_neuron_dict[layer]:
                unspecified.append((layer, neuron))
    return unspecified

def get_neuron_loss_increases(model: HookedTransformer, data: list[str], prompt: str, positionwise: bool=False) -> torch.Tensor:
    n_tokens = model.to_tokens(prompt).shape[1] - 1
    original_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
    
    losses = []
    for layer in trange(model.cfg.n_layers):
        mean_acts = haystack_utils.get_mlp_activations(data[:200], layer, model, disable_tqdm=True, context_crop_start=0)
        for neuron in range(model.cfg.d_mlp):
            hook = hook_utils.get_ablate_neuron_hook(layer, neuron, mean_acts[neuron])
            with model.hooks([hook]):
                ablated_loss = model([prompt], return_type='loss', loss_per_token=positionwise)
                losses.append((ablated_loss - original_loss)[0])
    return torch.stack(losses).reshape(n_tokens, model.cfg.n_layers * model.cfg.d_mlp)

def compare_dla_and_ablation(model: HookedTransformer, dla_attrs_by_neuron: torch.Tensor, ablation_losses_by_neuron: torch.Tensor, num_neurons=20):
    print("DLA:")
    values, indices = torch.topk(dla_attrs_by_neuron, num_neurons, dim=-1)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    print(list(zip(layer_indices.tolist(), neuron_indices.tolist())))
    print(dla_attrs_by_neuron[indices.tolist()])

    print("Ablation:")
    loss_increases_by_neuron = ablation_losses_by_neuron
    values, indices = torch.topk(loss_increases_by_neuron, num_neurons)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy()[:num_neurons], (model.cfg.n_layers, model.cfg.d_mlp))
    print(list(zip(layer_indices.tolist(), neuron_indices.tolist())))
    print(dla_attrs_by_neuron[indices.tolist()])

def get_hook_inputs_for_token_index(model: HookedTransformer, data: list[str], loss_increases_by_neuron: torch.Tensor, k=40):
    values, indices = torch.topk(loss_increases_by_neuron, k)

    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    layer_neuron_dict = defaultdict(list)
    for layer, neuron in zip(layer_indices, neuron_indices):
        layer_neuron_dict[layer].append(neuron)

    sorted_dla_layer_neuron_tuples = []
    sorted_acts = []
    for layer, neurons in layer_neuron_dict.items():
        mean_acts = haystack_utils.get_mlp_activations(data, layer, model, context_crop_start=0, neurons=neurons, disable_tqdm=True)
        sorted_dla_layer_neuron_tuples.extend([(layer, neuron) for neuron in neurons])
        sorted_acts.extend(mean_acts)
        assert len(sorted_dla_layer_neuron_tuples) == len(sorted_acts)

    return sorted_dla_layer_neuron_tuples, sorted_acts

def unravel_top_k(neuron_attrs: torch.Tensor, k: int=10):
    values, indices = torch.topk(neuron_attrs, k)
    layer_indices, neuron_indices = np.unravel_index(indices.cpu().numpy(), (model.cfg.n_layers, model.cfg.d_mlp))
    return list(zip(layer_indices.tolist(), neuron_indices.tolist()))

### Investigate

In [14]:
def upstream_for_prompt(prompt):
    n_tokens = model.to_tokens(prompt).shape[1]
    neuron_attrs_by_token, labels = neuron_to_context_neuron_DLA(model, prompt, np.s_[-n_tokens:], (3, 669))

    counter = Counter()
    for i in range(n_tokens):
        counter.update(unravel_top_k(neuron_attrs_by_token[i], k=10))
    return counter

german_prompt = "beraten. H\u00f6here Investitionen in Forschung und Entwicklung sowie die Erfassung und \
    Verarbeitung von zuverl\u00e4ssigen Daten w\u00fcrde zu einer solideren und nachhaltigen Gemeinsamen \
    Fischereipolitik f\u00fchren.\nAber obwohl der Satz, den ich von einem Wissenschaftler geh\u00f6rt \
    habe (\"Das Problem ist nicht Geld, sondern Personal\") die Lage gut darstellt, werde ich nicht \
    diejenige sein, die sagt, dass die Fischereiforschung gut mit finanziellen Mitteln ausgestattet \
    ist. Ich werde vielmehr sagen, dass wir ein doppeltes Problem haben.\nAn erster Stelle, Herr \
    Kommissar, die im Siebten Rahmenprogramm f\u00fcr Meeresforschung festgelegten Betr\u00e4ge, \
    die ein horizontales Thema h\u00e4tten sein sollen, scheinen f\u00fcr den integrierten Ansatz, \
    der bei dieser Angelegenheit gegenw\u00e4rtig gew\u00fcnscht wird, unzureichend zu sein.\nAu\u00dferdem, \
    Herr Kommissar, haben Wissenschaftler - und ich kann Ihnen versichern, dass ich vor und w\u00e4hrend \
    der Ausarbeitung dieses Berichts mit vielen gesprochen habe - Probleme bei der Einreichung von Projekten \
    unter dem Siebten Forschungsrahmenprogramm. Diese Probleme sind"

english_prompt = "given the generally greater adeptness of children at using audio-visual resources, in some \
    areas there are dangers of their obtaining access to unsuitable or harmful material. This is most obvious \
    in the fields of overt sexual material and gratuitous violence.\nThe principles which have guided this \
    report are to encourage greater public awareness of these issues and to support parental responsibility \
    and to develop co-operation between the content providers, consumer organisations and the \
    respective authorities, both national and European. Self-regulation is considered to be the \
    main instrument, underpinned by legal requirements where necessary.\nThe report, which \
    analyses the Commission's evaluation report, is primarily concerned with the Internet \
    and with video games, as it was felt important not to anticipate a possible future \
    review of the Television without Frontiers directive. The report calls for user-friendly content filter systems"

# print("sample prompt:", upstream_for_prompt(german_prompt).most_common())
n_tokens = model.to_tokens(german_prompt).shape[1]

prompts = [] # german_prompt, english_prompt
for token in [" ä", " ö", " ü", " ß"]:
    prompts.append("".join([token for _ in range(n_tokens)]))
    print(token, upstream_for_prompt(prompts[-1]).most_common())

 ä [((1, 1911), 346), ((1, 835), 333), ((2, 1449), 320), ((2, 181), 309), ((2, 1236), 309), ((2, 983), 308), ((2, 1166), 296), ((2, 1149), 289), ((2, 1003), 249), ((2, 1747), 223), ((0, 751), 139), ((0, 191), 83), ((0, 1452), 59), ((2, 230), 39), ((0, 777), 37), ((2, 621), 30), ((0, 596), 27), ((2, 1299), 14), ((0, 146), 9), ((2, 720), 9), ((1, 1034), 8), ((1, 961), 8), ((1, 309), 7), ((1, 1690), 6), ((2, 1489), 6), ((1, 1388), 4), ((2, 145), 4), ((1, 1109), 3), ((1, 1308), 3), ((1, 70), 2), ((2, 467), 2), ((0, 1987), 1), ((0, 563), 1), ((2, 1888), 1), ((0, 336), 1), ((0, 1874), 1), ((1, 1414), 1), ((1, 1032), 1), ((0, 1435), 1), ((2, 1188), 1)]
 ö [((2, 1449), 335), ((1, 707), 330), ((2, 181), 329), ((2, 1236), 329), ((2, 1931), 326), ((2, 1149), 324), ((2, 1003), 324), ((2, 1166), 322), ((1, 29), 320), ((2, 983), 273), ((2, 863), 73), ((1, 703), 26), ((1, 1966), 23), ((2, 1575), 22), ((0, 736), 20), ((0, 1849), 19), ((0, 645), 19), ((1, 1610), 18), ((1, 1911), 10), ((2, 315), 7), ((2

In [20]:
# Get the rank correlation within long prompts of different types
# Hopefully it's highly correlated
# Get the rank correlation between samples or average of prompts of different types
spearman = SpearmanCorrCoef()

prompt_mean_rhos = torch.zeros(len(prompts))
for prompt_n, prompt in enumerate(prompts):
    n_tokens = model.to_tokens(prompt).shape[1]
    neuron_attrs_by_token, _ = neuron_to_context_neuron_DLA(model, prompt, np.s_[-n_tokens:], (3, 669)) # tokens d_mlp
    average_neuron_attrs = neuron_attrs_by_token.mean(dim=0) # d_mlp

    rhos = torch.zeros(n_tokens)
    for i in range(n_tokens):
        rhos[i] = spearman(neuron_attrs_by_token[i], average_neuron_attrs)
    prompt_mean_rhos[prompt_n] = rhos.mean()

print(prompt_mean_rhos)

average_neuron_attrs = []
for prompt_n, prompt in enumerate(prompts):
    n_tokens = model.to_tokens(prompt).shape[1]
    neuron_attrs_by_token, _ = neuron_to_context_neuron_DLA(model, prompt, np.s_[-n_tokens:], (3, 669))
    average_neuron_attrs.append(neuron_attrs_by_token.mean(dim=0))
    
rhos = []
for i in range(len(average_neuron_attrs)):
    for j in range(i + 1, len(average_neuron_attrs)):
        if i == j:
            continue
        rhos.append(f'{spearman(average_neuron_attrs[i], average_neuron_attrs[j]).item():2f}')

print(rhos)


### First run


# Ablate each neuron in turn and look at how it affects the context neuron value (meaned over all prompts)

# Check whether it activates for German words in an English context, both single common german chars and a full word
# Ablate each neuron in turn and look at how it affects the context neuron value (meaned over a German token position within an English)

# Collect a list

# Look for head that moves German tokens
# Look for head that moves German n-grams
# Look for n-gram detector and see if it moves from there



In [None]:
# Measure context neuron activation for many German tokens that never coalesce into German words
# Measure context neuron activation for English words with a single German token mixed in
# Measure above but make it semantically clear in the English that a German token is about to appear
# Measure above but with a full German word
# Measure above but semantically clear in English

# Optional (if time permits):
# Measure above but with common German unigrams

# Need a way to measure at position

with model.hooks(hook_utils.get)