In [103]:
import torch
import numpy as np
from torch import einsum
from tqdm.auto import tqdm
import seaborn as sns
from transformer_lens import HookedTransformer, ActivationCache, utils
from datasets import load_dataset
from einops import einsum
import pandas as pd
from transformer_lens import utils
from rich.table import Table, Column
from rich import print as rprint
from jaxtyping import Float, Int, Bool
from torch import Tensor
import einops
import functools
from transformer_lens.hook_points import HookPoint
# import circuitsvis
from IPython.display import HTML
from plotly.express import line
import plotly.express as px
from tqdm.auto import tqdm
import json
import gc
import plotly.graph_objects as go

import sklearn
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from plotly.subplots import make_subplots
# Plotly needs a different renderer for VSCode/Notebooks vs Colab argh
import plotly.io as pio
pio.renderers.default = "colab+vscode"
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.autograd.set_grad_enabled(False)
torch.set_grad_enabled(False)

from haystack_utils import load_txt_data, get_mlp_activations, line, two_histogram
import haystack_utils

%reload_ext autoreload
%autoreload 2

## Set up model, data, and deactivate German neuron hook

In [104]:
# Candidate context neurons per language, each written as (layer, neuron index).
english_neurons = [
    (5, 395), (5, 166), (5, 908), (5, 285), (3, 862),
    (5, 73), (4, 896), (5, 348), (5, 297), (3, 1204),
]
german_neurons = [
    (4, 482), (5, 1039), (5, 407), (5, 1516),
    (5, 1336), (4, 326), (5, 250), (3, 669),
]
french_neurons = [
    (5, 112), (4, 1080), (5, 1293), (5, 455),
    (5, 5), (5, 1901), (5, 486), (4, 975),
]

model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

english_data = haystack_utils.load_txt_data("kde4_english.txt")
german_data = haystack_utils.load_txt_data("wmt_german_large.txt")

# MLP activations for layers 3-5, keyed by layer, computed on the first 200 prompts
# of each corpus. Later cells index these as [rows, neuron].
# NOTE(review): mean=False presumably keeps one row per token - confirm in haystack_utils.
english_activations, german_activations = {}, {}
for layer in range(3, 6):
    english_activations[layer] = get_mlp_activations(
        english_data[:200], layer, model, mean=False
    )
    german_activations[layer] = get_mlp_activations(
        german_data[:200], layer, model, mean=False
    )

# Ablation target: neuron 669 in layer 3 (the "L3N669" German context neuron).
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]
# Mean activation of the target neuron(s) over the German sample ("on" value) and
# over the English sample ("off" value); both are single scalars.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook: overwrite the target neuron(s) with their mean English activation.

    Writes into `value` in place along its last indexed axis and returns the same
    tensor. `hook` is unused; it is part of the hook signature.
    # assumes value is (batch, pos, d_mlp) - TODO confirm
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_INACTIVE
    return value
# (hook point name, hook function) pairs in the form passed as `fwd_hooks` elsewhere in this notebook.
deactivate_neurons_fwd_hooks=[(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
kde4_english.txt: Loaded 1007 examples with 501 to 5295 characters each.
wmt_german_large.txt: Loaded 2459 examples with 800 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

## Check classification accuracy of German neurons

Sanity check: reproduce sparse probe results on the German neuron with Pythia-v1

In [105]:
def run_single_neuron_lr(layer, neuron, num_samples=5000, random_state=None):
    """Fit a one-feature logistic regression that separates German from English activations.

    Reads the module-level ``german_activations`` and ``english_activations`` dicts.

    Args:
        layer: Key into the activation dicts.
        neuron: Column index of the neuron within that layer's activations.
        num_samples: Maximum number of rows taken from each language.
        random_state: Optional seed forwarded to ``train_test_split``. The default
            ``None`` keeps the previous (unseeded) behaviour; pass an int for a
            reproducible split.

    Returns:
        Tuple ``(train_acc, test_acc, f1)``; ``f1`` is computed on the test split
        with German as the positive class (label 1).
    """
    german_sample = german_activations[layer][:num_samples, neuron]
    english_sample = english_activations[layer][:num_samples, neuron]
    A = torch.concat([german_sample, english_sample]).view(-1, 1).cpu().numpy()
    # Size the labels from the slices actually obtained: if either language has fewer
    # than `num_samples` rows, hard-coding `num_samples` would make A and y disagree.
    y = torch.concat([torch.ones(len(german_sample)), torch.zeros(len(english_sample))]).cpu().numpy()
    A_train, A_test, y_train, y_test = train_test_split(A, y, test_size=0.2, random_state=random_state)
    lr_model = LogisticRegression()
    lr_model.fit(A_train, y_train)
    test_acc = lr_model.score(A_test, y_test)
    train_acc = lr_model.score(A_train, y_train)
    f1 = sklearn.metrics.f1_score(y_test, lr_model.predict(A_test))
    return train_acc, test_acc, f1
    

def get_neuron_accuracy(layer, neuron, plot=False):
    """Print single-neuron probe metrics and mean activations for one neuron.

    Reads the module-level ``english_activations`` / ``german_activations`` dicts.
    When ``plot`` is true, the two activation distributions are also passed to
    ``two_histogram``. Nothing is returned; results are printed.
    """
    english_acts = english_activations[layer][:, neuron]
    german_acts = german_activations[layer][:, neuron]

    if plot:
        two_histogram(
            english_acts,
            german_acts,
            "English",
            "German",
            "Activation",
            "Frequency",
            f"L{layer}N{neuron} activations on English vs German text",
        )

    train_acc, test_acc, f1 = run_single_neuron_lr(layer, neuron)
    print(f"\nL{layer}N{neuron}: F1={f1:.2f}, Train acc={train_acc:.2f}, and test acc={test_acc:.2f}")
    print(f"Mean activation English={english_acts.mean():.2f}, German={german_acts.mean():.2f}")


In [106]:
# Print the single-neuron probe metrics for every candidate German neuron.
for layer, neuron in german_neurons:
    get_neuron_accuracy(layer, neuron)


L4N482: F1=0.90, Train acc=0.90, and test acc=0.91
Mean activation English=-0.07, German=1.21

L5N1039: F1=0.85, Train acc=0.84, and test acc=0.83
Mean activation English=1.02, German=-0.06

L5N407: F1=0.64, Train acc=0.64, and test acc=0.64
Mean activation English=5.23, German=3.70

L5N1516: F1=0.77, Train acc=0.76, and test acc=0.78
Mean activation English=2.31, German=1.02

L5N1336: F1=0.96, Train acc=0.96, and test acc=0.96
Mean activation English=-0.06, German=1.40

L4N326: F1=0.82, Train acc=0.83, and test acc=0.83
Mean activation English=0.03, German=0.81

L5N250: F1=0.76, Train acc=0.77, and test acc=0.78
Mean activation English=-0.00, German=-0.04

L3N669: F1=0.99, Train acc=0.99, and test acc=0.99
Mean activation English=-0.07, German=3.82


## Check loss increase from disabling each German neuron on German data

In [107]:
# Compare loss on the first 1000 German prompts with and without the L3N669 ablation hook.
# NOTE(review): the helper's return values are taken at face value from their names - confirm in haystack_utils.
mean_original_loss, mean_ablated_loss, percent_increase = haystack_utils.get_ablated_performance(german_data[:1000], model, deactivate_neurons_fwd_hooks)
print(f"Mean original loss={mean_original_loss:.2f}, mean ablated loss={mean_ablated_loss:.2f}, percent increase={percent_increase:.2f}")

  0%|          | 0/1000 [00:00<?, ?it/s]

Mean original loss=3.18, mean ablated loss=3.55, percent increase=11.57


In [108]:
def get_deactivate_single_neuron_hook(layer, neuron, english_activations):
    """Build a forward hook that pins one neuron to its mean English activation.

    Args:
        layer: Key into ``english_activations``.
        neuron: Index of the neuron (last axis of the hooked activation).
        english_activations: Mapping from layer to a 2-D activation tensor
            indexed as ``[row, neuron]``.

    Returns:
        A ``hook(value, hook)`` callable that overwrites ``value[:, :, neuron]``
        in place and returns ``value``.
    """
    # Computed once here rather than inside the hook: the hook fires on every
    # forward pass, and the mean over the stored activations does not change between calls.
    mean_english_activation = english_activations[layer][:, neuron].mean()

    def deactivate_single_neuron_hook(value, hook):
        value[:, :, neuron] = mean_english_activation
        return value

    return deactivate_single_neuron_hook

In [109]:
# Loss increase from ablating a single neuron. Measures how useful each neuron is when all others are enabled (total effect).
for layer, neuron in tqdm(german_neurons):
    fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_deactivate_single_neuron_hook(layer, neuron, english_activations))]
    mean_original_loss, mean_ablated_loss, percent_increase = haystack_utils.get_ablated_performance(german_data[:1000], model, fwd_hooks, display_tqdm=False)
    # `:+.2f` makes the format spec emit the sign, so a loss decrease prints as "-0.01%" instead of "+-0.01%".
    print(f"L{layer}N{neuron}: Loss original={mean_original_loss:.2f}, ablated={mean_ablated_loss:.2f} ({percent_increase:+.2f}%)")

  0%|          | 0/8 [00:00<?, ?it/s]

L4N482: Loss original=3.18, ablated=3.20 (+0.60%)
L5N1039: Loss original=3.18, ablated=3.18 (+0.16%)
L5N407: Loss original=3.18, ablated=3.19 (+0.22%)
L5N1516: Loss original=3.18, ablated=3.18 (+0.09%)
L5N1336: Loss original=3.18, ablated=3.20 (+0.59%)
L4N326: Loss original=3.18, ablated=3.18 (+0.12%)
L5N250: Loss original=3.18, ablated=3.18 (+0.01%)
L3N669: Loss original=3.18, ablated=3.55 (+11.57%)


The huge loss increase when L3N669 is disabled implies either that the other neurons aren't backups for L3N669, or that the components in layer 4 — which can't read from neurons in later layers — are significant. 

TODO: Correlational analyses of these neurons would let us narrow down whether they just activate when L3N669 does (downstream neurons with a different function specific to German) or whether they're also independent German-detectors.
We could also compare the loss from disabling all context neurons with the loss from enabling just one of the minor context neurons. If the loss decrease is greater than 0.6%, we can say that it acts as a backup for other context neurons.

In [110]:
# Sanity check: also calculate loss on English
print("English loss impact (should be close to 0)")
for layer, neuron in tqdm(german_neurons):
    fwd_hooks=[(f'blocks.{layer}.mlp.hook_post', get_deactivate_single_neuron_hook(layer, neuron, english_activations))]
    mean_original_loss, mean_ablated_loss, percent_increase = haystack_utils.get_ablated_performance(english_data[:1000], model, fwd_hooks, display_tqdm=False)
    # `:+.2f` makes the format spec emit the sign; the previous "(+{x:.2f}%)" printed "+-0.01%" for decreases.
    print(f"L{layer}N{neuron}: Loss original={mean_original_loss:.2f}, ablated={mean_ablated_loss:.2f} ({percent_increase:+.2f}%)")

English loss impact (should be close to 0)


  0%|          | 0/8 [00:00<?, ?it/s]

L4N482: Loss original=3.96, ablated=3.96 (+0.02%)
L5N1039: Loss original=3.96, ablated=3.96 (+0.09%)
L5N407: Loss original=3.96, ablated=3.96 (+0.02%)
L5N1516: Loss original=3.96, ablated=3.95 (+-0.01%)
L5N1336: Loss original=3.96, ablated=3.96 (+-0.00%)
L4N326: Loss original=3.96, ablated=3.96 (+-0.00%)
L5N250: Loss original=3.96, ablated=3.95 (+-0.02%)
L3N669: Loss original=3.96, ablated=3.96 (+0.00%)


## Check loss breakdown by component for L3N669

In [111]:
def activate_neuron_hook(value, hook):
    """Forward hook: overwrite the target neuron(s) with their mean German activation.

    Writes into `value` in place and returns it. `hook` is unused; it is part of
    the hook signature.
    """
    value[:, :, NEURONS_TO_ABLATE] = MEAN_ACTIVATION_ACTIVE
    return value


# `deactivate_neurons_hook` is already defined in the setup cell with an identical body,
# so it is reused here rather than redefined (a second `def` would silently shadow the first).
ABLATE_HOOK = [(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', deactivate_neurons_hook)]
ACTIVATE_HOOK = [(f'blocks.{LAYER_TO_ABLATE}.mlp.hook_post', activate_neuron_hook)]


In [112]:
# Direct logit attribution on the first 1000 German prompts with the unmodified model.
# NOTE(review): `labels` is used later as per-component tick labels - confirm in haystack_utils.DLA.
logit_attr_original, labels = haystack_utils.DLA(german_data[:1000], model)

# Repeat the attribution with the context neuron pinned to its inactive value (ABLATE_HOOK).
with model.hooks(fwd_hooks=ABLATE_HOOK):
    logit_attr_ablated, _ = haystack_utils.DLA(german_data[:1000], model)

# Original minus ablated attribution, averaged over the first axis.
logit_diffs = (logit_attr_original - logit_attr_ablated).mean(0)

  0%|          | 0/1000 [00:00<?, ?it/s]

  0%|          | 0/1000 [00:00<?, ?it/s]

In [113]:
# The small discrepancies accumulated before the ablation occur during the call to cache.apply_ln_to_stack(). This seems like a bug, but they seem small enough to ignore.
# The x axis has one tick per component (xticks=labels); the plotted value is the correct-logit attribution difference.
haystack_utils.line(logit_diffs.cpu().numpy(), xlabel="Component", ylabel="Correct logit", title="(Original DLA - Ablated DLA) per component", xticks=labels)

In [114]:
# Loss after patching in, one at a time, components taken from the forward pass
# in which the German context neuron is disabled.
#   - Excludes the effect of the patched component on other components.
#   - Includes the effect of the pre-patched component on other components.
#   - Includes the contribution of the pre-patched component to the layer
#     normalizations of the residual stream.
component_names = [
    'embed',
    '0_attn_out', '0_mlp_out',
    '1_attn_out', '1_mlp_out',
    '2_attn_out', '2_mlp_out',
    '3_attn_out', '3_mlp_out',
    '4_attn_out', '4_mlp_out',
    '5_attn_out', '5_mlp_out',
]
components = []
losses = []
# Walk every component from the ablated layer's MLP output (index 8) to the last one (index 12).
for later_component in range(component_names.index('3_mlp_out'), len(component_names)):
    print(f"Component: {component_names[later_component]}")
    original_loss, patched_loss = haystack_utils.get_direct_loss_increase_for_component(
        german_data[:1000],
        model,
        fwd_hooks=deactivate_neurons_fwd_hooks,
        patched_component=later_component,
        disable_progress_bar=True,
    )
    # The unpatched baseline is recorded once, as the first entry.
    if not losses:
        components.append("Original loss")
        losses.append(original_loss)
    components.append(component_names[later_component])
    losses.append(patched_loss)

Component: 3_mlp_out


AttributeError: module 'haystack_utils' has no attribute 'get_loss_increase_for_component'

In [None]:
def line(x, xlabel="", ylabel="", title="", xticks=None, width=800, hover_data=None):
    """Show a Plotly line chart of ``x``.

    Args:
        x: Values handed to ``px.line``.
        xlabel: X-axis title.
        ylabel: Y-axis title.
        title: Figure title.
        xticks: Optional sequence of tick labels; label ``i`` is placed at x = ``i``.
        width: Figure width passed to ``update_layout``.
        hover_data: Optional per-point values exposed as ``customdata`` and shown
            in the hover text as a percentage next to the loss.

    Returns:
        None. The figure is displayed with ``fig.show()``.
    """
    fig = px.line(x, title=title)
    fig.update_layout(xaxis_title=xlabel, yaxis_title=ylabel, width=width)
    # Identity checks instead of `!= None`: comparing a numpy array with `!=` is
    # element-wise, and using the result in `if` raises for arrays longer than one.
    if xticks is not None:
        fig.update_layout(
            xaxis=dict(
                tickmode='array',
                tickvals=list(range(len(xticks))),
                ticktext=xticks,
            )
        )
    if hover_data is not None:
        fig.update(data=[{'customdata': hover_data, 'hovertemplate': "Loss: %{y:.4f} (+%{customdata:.2f}%)"}])
    fig.show()

In [None]:
# Percentage change of every entry relative to losses[0] (the "Original loss" baseline); shown in the hover text.
percent_increase = ((np.array(losses) - losses[0]) / losses[0]) * 100
line(losses, xlabel="Component", ylabel="Loss", title="Loss of patching individual components when ablating L3N669", xticks=components, width=800, hover_data=percent_increase.tolist())

## Tokens boosted by L3 directly

In [None]:
def unembed_residual(cache, layer, apply_ln=True):
    """Map the residual stream after block ``layer`` to logits.

    Looks up ``blocks.{layer}.hook_resid_post`` in ``cache``, optionally passes it
    through the module-level ``model``'s final layer norm, and unembeds it.
    """
    residual = cache[f'blocks.{layer}.hook_resid_post']
    if apply_ln:
        residual = model.ln_final(residual)
    return model.unembed(residual)

In [None]:
# Check if it is done correctly: unembedding the residual after block 5 (with the
# final layer norm applied) must match the logits returned by the model itself.
logits_original, cache_original = model.run_with_cache(german_data[:1])
final_residual_unembed = unembed_residual(cache_original, 5)
torch.testing.assert_close(final_residual_unembed, logits_original)

In [None]:
def get_unembed_differences(prompts: list[str], model):
    """Average per-token logit change caused by the context neuron, read off the residual stream.

    Each prompt is run twice: once under ``ACTIVATE_HOOK`` (neuron pinned to its
    active value) and once under ``ABLATE_HOOK`` (neuron pinned to its inactive
    value). The residual stream after block ``LAYER_TO_ABLATE`` from each run is
    unembedded with ``unembed_residual``.

    Args:
        prompts: Prompts to evaluate.
        model: Model used for the hooked forward passes and for ``cfg.d_vocab``.
            NOTE(review): ``unembed_residual`` reads the module-level ``model``,
            not this argument - the two are assumed to be the same object.

    Returns:
        Tensor of length ``model.cfg.d_vocab`` holding the active-minus-ablated
        logit difference, averaged over batch and position within each prompt and
        then over prompts. Positive entries are tokens whose logit is higher when
        the neuron is active.
    """
    per_token_differences = torch.zeros(model.cfg.d_vocab).to(device)
    for prompt in tqdm(prompts):
        # Set context neuron to activated value
        with model.hooks(fwd_hooks=ACTIVATE_HOOK):
            _, cache_active = model.run_with_cache(prompt)
        # Ablate context neuron
        with model.hooks(fwd_hooks=ABLATE_HOOK):
            _, cache_ablated = model.run_with_cache(prompt)

        logits_active = unembed_residual(cache_active, LAYER_TO_ABLATE)
        logits_ablated = unembed_residual(cache_ablated, LAYER_TO_ABLATE)

        # Active minus ablated, so that a positive value means "boosted by the neuron".
        # This matches the direct-unembed cell (active - inactive) and the
        # "boosted"/"inhibited" top-k labels downstream. The previous ablated-minus-active
        # order inverted those labels.
        # Shape batch pos d_vocab -> d_vocab
        per_token_differences += (logits_active - logits_ablated).mean((0, 1))

    return per_token_differences / len(prompts)


In [None]:
# Per-token logit differences on the first 1000 German prompts, then the 1000 largest
# ("boosted_*") and 1000 smallest ("inhibited_*") entries with their string tokens.
token_differences = get_unembed_differences(german_data[:1000], model)
boosted_values, boosted_tokens = torch.topk(token_differences, 1000)
inhibited_values, inhibited_tokens = torch.topk(token_differences, 1000, largest=False)
boosted_labels = model.to_str_tokens(boosted_tokens)
inhibited_labels = model.to_str_tokens(inhibited_tokens)

  0%|          | 0/1000 [00:00<?, ?it/s]

In [None]:
# Distribution of the per-token logit differences over the whole vocabulary.
px.histogram(token_differences.cpu().numpy(), title="Histogram of L3 logit difference between original and ablated model", labels={"value": "Logit difference"})

In [None]:
# Plot the first `num_tokens` entries of the largest-difference list.
num_tokens = 100
line(boosted_values.cpu().numpy()[:num_tokens], xlabel="Token", ylabel="Logit increase from context neuron", xticks=boosted_labels[:num_tokens], title=f"Top boosted tokens from L3N669", width=1100)


In [None]:
# Plot the first `num_tokens` entries of the smallest-difference list.
num_tokens = 100
# Strip surrounding whitespace from the token strings used as tick labels.
stripped_labels=[x.strip() for x in inhibited_labels[:num_tokens]]
line(inhibited_values.cpu().numpy()[:num_tokens], xlabel="Token", ylabel="Logit increase from context neuron", xticks=stripped_labels, title=f"Top inhibited tokens from L3N669", width=1000)


In [None]:
# Unembed neuron direction directly

# Only works for individual neurons
# Shape batch pos d_resid
neuron_weight = model.W_out[LAYER_TO_ABLATE, NEURONS_TO_ABLATE].view(1, 1, -1)
neuron_direction_active = neuron_weight * MEAN_ACTIVATION_ACTIVE # Set German neuron to activated value (~3)
neuron_direction_inactive = neuron_weight * MEAN_ACTIVATION_INACTIVE # Set German neuron to disabled value (~0)

tokens_active = model.unembed(neuron_direction_active)
tokens_inactive = model.unembed(neuron_direction_inactive)
# Active: German neuron is active - we expect German tokens boosted
# Inactive: German neuron is inactive - we expect no boost to German tokens
# Active - Inactive: If the neuron boosts German tokens, we expect this to be positive
token_differences = (tokens_active - tokens_inactive).flatten()

boosted_values, boosted_tokens = torch.topk(token_differences, 1000)
inhibited_values, inhibited_tokens = torch.topk(token_differences, 1000, largest=False)
boosted_labels = model.to_str_tokens(boosted_tokens)
inhibited_labels = model.to_str_tokens(inhibited_tokens)

#px.histogram(token_differences.cpu().numpy(), title="Histogram of L3N669 direct logit difference between original and ablated model", labels={"value": "Logit difference"})

num_tokens = 100
# The tick labels must come from the same top-k list as the plotted values. They were
# previously built from `inhibited_labels`, which mislabelled every point of the boosted-token plot.
stripped_labels=[x.strip() for x in boosted_labels[:num_tokens]]
line(boosted_values.cpu().numpy()[:num_tokens], xlabel="Token", ylabel="Logit increase from context neuron", xticks=stripped_labels[:num_tokens], title=f"Top boosted tokens from L3N669", width=1100)


## Get German unigram statistics

In [None]:
# Get top unigrams of a corpus
def count_token_occurrences(prompts: list[str]):
    """Count how often each vocabulary token occurs across ``prompts``.

    Each prompt is tokenized with the module-level ``model`` and its first token
    (BOS) is dropped before counting.

    Args:
        prompts: Texts to tokenize and count.

    Returns:
        Float tensor of length ``model.cfg.d_vocab`` on ``device``; entry ``i`` is
        the total number of occurrences of token id ``i``.
    """
    token_counts = torch.zeros(model.cfg.d_vocab).to(device)
    for prompt in tqdm(prompts):
        # Remove BOS
        tokens = model.to_tokens(prompt).flatten()[1:]
        # `token_counts[tokens] += 1` writes each distinct index only once, so a token
        # repeated within a prompt was counted a single time. index_add_ accumulates
        # over repeated indices and gives true occurrence counts.
        token_counts.index_add_(0, tokens, torch.ones_like(tokens, dtype=token_counts.dtype))
    return token_counts

In [None]:
# Token counts over the full German corpus; keep the 100 highest-count tokens and their strings.
german_unigram_counts = count_token_occurrences(german_data)
german_unigram_highest_counts, german_unigram_tokens = torch.topk(german_unigram_counts, 100)
german_unigram_labels = model.to_str_tokens(german_unigram_tokens)

num_tokens = 100
line(german_unigram_highest_counts.cpu().numpy()[:num_tokens], 
     xlabel="Token", ylabel="Counts", 
     xticks=german_unigram_labels[:num_tokens], 
     title=f"Top unigrams in WMT German data", 
     width=1100)


  0%|          | 0/2459 [00:00<?, ?it/s]

In [None]:
# Logit differences restricted to the top German unigram tokens: print their mean and plot their distribution.
# Uses whichever `token_differences` was assigned most recently in the kernel.
top_german_token_differences = token_differences[german_unigram_tokens]
print(top_german_token_differences.mean())
px.histogram(top_german_token_differences.cpu().numpy())

tensor(1.7974, device='cuda:0')


In [None]:
# Token counts over the full English corpus; keep the 100 highest-count tokens and their strings.
english_unigram_counts = count_token_occurrences(english_data)
english_unigram_highest_counts, english_unigram_tokens = torch.topk(english_unigram_counts, 100)
english_unigram_labels = model.to_str_tokens(english_unigram_tokens)

num_tokens = 100
line(english_unigram_highest_counts.cpu().numpy()[:num_tokens], 
     xlabel="Token", ylabel="Counts", 
     xticks=english_unigram_labels[:num_tokens], 
     title=f"Top unigrams in KDE English data", 
     width=1100)

  0%|          | 0/1007 [00:00<?, ?it/s]

In [None]:
# Logit differences restricted to the top English unigram tokens, for comparison with the German ones.
top_english_token_differences = token_differences[english_unigram_tokens]
print(top_english_token_differences.mean())
px.histogram(top_english_token_differences.cpu().numpy(), title="Token differences for top English unigrams")

tensor(0.3300, device='cuda:0')


In [None]:
# Per-token logit difference for the top English unigrams, in descending order of unigram count.
english_labels = model.to_str_tokens(english_unigram_tokens)
# The plotted values are logit differences, not counts, so the y axis is labelled accordingly.
# The tick labels are sliced to the same length as the plotted values.
line(top_english_token_differences.cpu().numpy()[:num_tokens], 
     xlabel="Token", ylabel="Logit difference", 
     xticks=english_labels[:num_tokens], 
     title=f"English token differences", 
     width=1100)