In [44]:
import torch
from tqdm.auto import tqdm
from transformer_lens import HookedTransformer, ActivationCache, utils, patching
from jaxtyping import Float, Int, Bool
from torch import Tensor
from tqdm.auto import tqdm
import plotly.io as pio
import ipywidgets as widgets
from IPython.display import display, clear_output
from einops import einsum
import plotly.express as px
import numpy as np
import pandas as pd

# Render plotly figures with plotly.js loaded from a CDN inside the notebook.
pio.renderers.default = "notebook_connected"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Analysis only: turn autograd off globally so forward passes and cached
# activations never build computation graphs. torch.set_grad_enabled and
# torch.autograd.set_grad_enabled are the same switch, so one call suffices.
torch.set_grad_enabled(False)

from haystack_utils import get_mlp_activations
import haystack_utils

# Re-import edited modules (e.g. haystack_utils) automatically before each cell runs,
# so changes to helper code take effect without restarting the kernel.
%reload_ext autoreload
%autoreload 2

In [2]:
# Pythia-70m, loaded with LayerNorm folded into the following weights and with
# centred unembedding / residual-writing weights.
model = HookedTransformer.from_pretrained(
    "EleutherAI/pythia-70m",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device=device,
)

# Keep only the first 200 Europarl examples per language
# (the load log shows each file holds 2000).
german_data = haystack_utils.load_json_data("data/german_europarl.json")[:200]
english_data = haystack_utils.load_json_data("data/english_europarl.json")[:200]


# Layer and neuron(s) whose MLP post-activations the hooks below overwrite.
LAYER_TO_ABLATE = 3
NEURONS_TO_ABLATE = [669]

# Per-token MLP activations (mean=False) on each language, keyed by layer.
# The layers collected are derived from LAYER_TO_ABLATE, so changing the
# constant above cannot leave the mean computations below without data.
english_activations = {}
german_activations = {}
for layer in [LAYER_TO_ABLATE]:
    english_activations[layer] = get_mlp_activations(english_data, layer, model, mean=False)
    german_activations[layer] = get_mlp_activations(german_data, layer, model, mean=False)

# Values the hooks write: mean activation on the German data ("active") and on
# the English data ("inactive").
# NOTE: .mean() pools every listed neuron into a single scalar, so with several
# neurons in NEURONS_TO_ABLATE each one is set to the pooled mean.
MEAN_ACTIVATION_ACTIVE = german_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()
MEAN_ACTIVATION_INACTIVE = english_activations[LAYER_TO_ABLATE][:, NEURONS_TO_ABLATE].mean()

def deactivate_neurons_hook(value, hook):
    """Forward hook pinning the ablated neurons to their English-data mean.

    Overwrites ``value[:, :, NEURONS_TO_ABLATE]`` in place with
    ``MEAN_ACTIVATION_INACTIVE`` (both read at call time) and returns the same
    tensor. ``hook`` is the hook-point argument and is not used.
    """
    ablation_index = (slice(None), slice(None), NEURONS_TO_ABLATE)
    value[ablation_index] = MEAN_ACTIVATION_INACTIVE
    return value


deactivate_neurons_fwd_hooks = [
    (f"blocks.{LAYER_TO_ABLATE}.mlp.hook_post", deactivate_neurons_hook),
]

def activate_neurons_hook(value, hook):
    """Forward hook pinning the ablated neurons to their German-data mean.

    Overwrites ``value[:, :, NEURONS_TO_ABLATE]`` in place with
    ``MEAN_ACTIVATION_ACTIVE`` (both read at call time) and returns the same
    tensor. ``hook`` is the hook-point argument and is not used.
    """
    ablation_index = (slice(None), slice(None), NEURONS_TO_ABLATE)
    value[ablation_index] = MEAN_ACTIVATION_ACTIVE
    return value


activate_neurons_fwd_hooks = [
    (f"blocks.{LAYER_TO_ABLATE}.mlp.hook_post", activate_neurons_hook),
]

# Two token groups from haystack_utils; in the cells shown here only all_ignore
# is used (as the exclusion list for get_common_tokens). plot_norms=False
# presumably suppresses the helper's diagnostic plot.
all_ignore, not_ignore = haystack_utils.get_weird_tokens(
    model,
    plot_norms=False,
)

Using pad_token, but it is not set yet.


Loaded pretrained model EleutherAI/pythia-70m into HookedTransformer
data/german_europarl.json: Loaded 2000 examples with 152 to 2000 characters each.
data/english_europarl.json: Loaded 2000 examples with 165 to 2000 characters each.


  0%|          | 0/200 [00:00<?, ?it/s]

  0%|          | 0/200 [00:00<?, ?it/s]

In [12]:
# Target string for the analysis; to_single_token requires it to encode to
# exactly one token. The cell displays the resulting token id.
answer_str = "gen"
answer_token = model.to_single_token(answer_str)

answer_token

1541

In [14]:
# For each neuron of the last MLP layer: dot product of its output weights with
# the answer token's unembedding column, i.e. roughly how much one unit of that
# neuron's activation directly pushes the answer logit (ignoring the final
# LayerNorm scale).
mlp5_out = model.W_out[-1]
answer_unembed = model.W_U[:, answer_token]
dot_product = einsum(
    mlp5_out,
    answer_unembed,
    "mlp res, res -> mlp",
)

In [19]:
# Indices of last-layer neurons whose output weights point toward the answer
# unembedding (positive dot product). torch.nonzero is argwhere's equivalent.
positive_boost_neurons = (dot_product > 0).nonzero().flatten()

positive_boost_neurons

tensor([   2,    3,    4,  ..., 2041, 2042, 2044], device='cuda:0')

In [21]:
# Note: our top neurons are selected by how much the context neuron changes
# their boost to the answer token; many other neurons boost it more per unit
# of activation. Cosine similarity factors out each neuron's output-weight norm
# and compares directions only.
cosine_sim = torch.nn.CosineSimilarity(dim=1)
# Use answer_str (not a literal) so this stays in sync with the dot-product cell.
answer_residual_direction = model.tokens_to_residual_directions(answer_str)
# Layer-5 MLP output weights, read directly instead of materialising the full state_dict.
neuron_weights = model.blocks[5].mlp.W_out
cosine_sims = cosine_sim(neuron_weights, answer_residual_direction.unsqueeze(0))
positive_cosine_neurons = torch.argwhere(cosine_sims > 0).flatten()

In [30]:
import random

# generate_random_prompts is stochastic. Seed every RNG it might draw from
# (Python, NumPy, torch) so the prompts, and everything computed from them
# below, are reproducible across kernel restarts.
PROMPT_SEED = 0
random.seed(PROMPT_SEED)
np.random.seed(PROMPT_SEED)
torch.manual_seed(PROMPT_SEED)

# Presumably: the k=50 most frequent German tokens (excluding all_ignore) are
# used as filler for 100 random prompts of length 20 that end in " Vorschlägen".
# TODO confirm against haystack_utils.
common_tokens = haystack_utils.get_common_tokens(german_data, model, all_ignore, k=50)
prompts = haystack_utils.generate_random_prompts(" Vorschlägen", model, common_tokens, 100, length=20)

  0%|          | 0/200 [00:00<?, ?it/s]

In [53]:
# Run the same prompts twice, with the ablated neuron(s) pinned to their English
# mean (deactivated) or German mean (activated), caching all activations.
# NOTE: run_with_cache returns (logits, cache) by default, so the *_loss names
# below actually hold logits.
with model.hooks(fwd_hooks=deactivate_neurons_fwd_hooks):
    deactivated_loss, deactivated_cache = model.run_with_cache(prompts)
with model.hooks(fwd_hooks=activate_neurons_fwd_hooks):
    activated_loss, activated_cache = model.run_with_cache(prompts)

# Layer-5 MLP post-activations at the second-to-last position, averaged over
# prompts. Presumably this is the position that predicts the final answer token.
deactivated_neuron_activations = deactivated_cache["blocks.5.mlp.hook_post"][:, -2].mean(dim=0)
activated_neuron_activations = activated_cache["blocks.5.mlp.hook_post"][:, -2].mean(dim=0)

# Label each layer-5 neuron by how the active context neuron shifts its mean
# activation: +1 if raised by more than boost_threshold, -1 if lowered by more,
# 0 otherwise.
boost_threshold = 0.1
boosted_by_context = activated_neuron_activations > deactivated_neuron_activations + boost_threshold
deboosted_by_context = activated_neuron_activations < deactivated_neuron_activations - boost_threshold
# With a non-negative threshold the two masks are disjoint, so their difference
# is exactly +1 / -1 / 0 (as int64).
context_impact = boosted_by_context.long() - deboosted_by_context.long()

# Overlayed distributions of layer-5 mean activations with vs. without the context neuron.
haystack_utils.two_histogram(
    activated_neuron_activations,
    deactivated_neuron_activations,
    "Activated",
    "Deactivated",
    "Neuron Activation",
)

In [80]:
# Layer-5 MLP activations over 100 German + 100 English examples.
# NOTE: despite the name, mean=False means these are not averaged
# (presumably one row per token; columns are neurons).
mean_all_activation = haystack_utils.get_mlp_activations(
    german_data[:100] + english_data[:100],
    layer=5,
    model=model,
    mean=False,
)

  0%|          | 0/200 [00:00<?, ?it/s]

In [88]:
# Typical "on" level of each neuron: the mean of its strictly positive
# activations in mean_all_activation (NaN if a neuron is never positive).
# Despite its name, non_zero_sum holds a mean over positive values, not a sum.
# Vectorised over all neurons; the neuron count comes from the data instead of
# a hard-coded 2048, and the result is a float32 CPU tensor.
_activations = mean_all_activation.float()
non_zero_sum = (
    _activations
    .masked_fill(~(_activations > 0), float("nan"))  # drop non-positive (and NaN) entries
    .nanmean(dim=0)                                   # per-neuron mean of what remains
    .cpu()
)
del _activations

In [89]:
# Distribution over layer-5 neurons of their mean positive activation.
fig = px.histogram(
    non_zero_sum.cpu().numpy(),
    title="Mean positive activation per layer-5 MLP neuron",
)
# The single wide-form series needs no legend; label both axes explicitly.
fig.update_layout(
    xaxis_title="Mean activation when positive",
    yaxis_title="Number of neurons",
    showlegend=False,
)
fig

In [91]:
# One row per last-layer neuron, collecting the per-neuron statistics computed above.
# The row count follows the data rather than a hard-coded 2048.
df = pd.DataFrame({
    "Neuron": list(range(len(dot_product))),
    "DotProduct": dot_product.tolist(),
    "MeanGenActivationContextActive": activated_neuron_activations.tolist(),
    "MeanGenActivationContextInactive": deactivated_neuron_activations.tolist(),
    "ActiveBaseline": non_zero_sum.tolist(),
    "ContextImpact": context_impact.tolist(),
})

# Readable label for ContextImpact: +1 -> Boosted, -1 -> Deboosted, anything
# else -> Neutral. Uses a vectorised map instead of a row-wise apply.
df["ContextImpactLabel"] = (
    df["ContextImpact"]
    .map({1: "Boosted", -1: "Deboosted"})
    .fillna("Neutral")
)

In [92]:
# Dot product with the answer unembedding, split by how the context neuron shifts each neuron.
px.histogram(
    df,
    x="DotProduct",
    color="ContextImpactLabel",
    title="Context neuron impact on neurons with high correct token dot product",
    width=800,
)

In [96]:
# A neuron counts as "active" at the answer position (context active) when its
# mean activation there exceeds half of its ActiveBaseline. NaN baselines compare as False.
df["IsActiveOnGenWithContext"] = df["MeanGenActivationContextActive"].gt(df["ActiveBaseline"].div(2))
px.histogram(
    df,
    x="DotProduct",
    color="IsActiveOnGenWithContext",
    title="Active Neurons against their boost on 'gen'",
    width=800,
)

In [97]:
# Same half-baseline criterion with the context neuron deactivated; neurons that
# pass only with the context neuron active are the ones it switches on.
df["IsActiveOnGenWithoutContext"] = df["MeanGenActivationContextInactive"].gt(df["ActiveBaseline"].div(2))
df["IsOnlyActiveWithContext"] = df["IsActiveOnGenWithContext"] & ~df["IsActiveOnGenWithoutContext"]
px.histogram(
    df,
    x="DotProduct",
    color="IsOnlyActiveWithContext",
    title="Neurons activated by the context neuron against their boost on 'gen'",
    width=800,
)

In [99]:
# Token id of "gen" (the same string as answer_str), used in the Pile search
# below; the cell displays the id.
gen_token = model.to_single_token("gen")

gen_token

1541

In [109]:
from datasets import load_dataset

# General-text corpus (by its name, a 10k-document Pile sample) used below to
# look for natural occurrences of the "gen" token.
dataset = load_dataset("NeelNanda/pile-10k")
data = dataset["train"]

Found cached dataset parquet (/root/.cache/huggingface/datasets/NeelNanda___parquet/NeelNanda--pile-10k-72f566e9f7c464ab/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7)


  0%|          | 0/1 [00:00<?, ?it/s]

In [112]:
# Peek at the text of the first document (direct indexing instead of
# loop-and-break, which also leaked the loop variable into the namespace).
print(data[0]["text"])

It is done, and submitted. You can play “Survival of the Tastiest” on Android, and on the web. Playing on the web works, but you have to simulate multi-touch for table moving and that can be a bit confusing.

There’s a lot I’d like to talk about. I’ll go through every topic, insted of making the typical what went right/wrong list.

Concept

Working over the theme was probably one of the hardest tasks I had to face.

Originally, I had an idea of what kind of game I wanted to develop, gameplay wise – something with lots of enemies/actors, simple graphics, maybe set in space, controlled from a top-down view. I was confident I could fit any theme around it.

In the end, the problem with a theme like “Evolution” in a game is that evolution is unassisted. It happens through several seemingly random mutations over time, with the most apt permutation surviving. This genetic car simulator is, in my opinion, a great example of actual evolution of a species facing a challenge. But is it a game?



In [113]:
# For every document containing gen_token, print a 10-token window around its
# first occurrence. to_tokens truncates long documents to the model context by
# default, so only that prefix is searched.
for example in data:
    prompt = example["text"]
    prompt_tokens = model.to_tokens(prompt)
    token_ids = prompt_tokens[0].tolist()
    if gen_token in token_ids:
        index = token_ids.index(gen_token)
        # Clamp the start: a match within the first 5 tokens would otherwise give
        # a negative start, which slicing reads as "from the end" and yields a
        # wrong (usually empty) window.
        window = prompt_tokens[0, max(index - 5, 0):index + 5]
        print(model.to_str_tokens(window))

['.', '\n', '\n', '-', ' In', 'gen', ' he', 'k', 'se', 'j']
[' Det', ' sid', 'ste', ' for', ' e', 'gen', ' reg', 'ning', '.', '\n']
['loader', '-', 'bas', 'ics', '-', 'gen', '.', 'lua', '\n', 'share']
[' ho', 'opt', ' te', ' k', 'rij', 'gen', '.', ' H', 'ij', ' no']
['\n', 'Using', ' ssh', '-', 'key', 'gen', ' I', ' created', ' a', ' key']
['an', 'st', 'ieg', ' ent', 'ge', 'gen', 'w', 'ir', 'ken', ' können']
['’', 's', ' spectacular', ' G', 'ug', 'gen', 'heim', ' Bil', 'ba', 'o']
[' cells', ' harbour', 'ing', ' a', ' sub', 'gen', 'omic', ' replic', 'on', ' containing']
['",', ' "', 'aa', '")', '\n', 'gen', ' <-', ' data', '.', 'frame']
['->', 'get', '_', 'where', "('", 'gen', 're', "',", 'array', "('"]
['BM', ')', ' stem', '/', 'pro', 'gen', 'itor', ' cells', ' in', ' regulating']
[' throwing', ' around', ' the', ' word', ' "', 'gen', 'ius', ',"', ' but', ' what']
[' territ', 'or', 'ios', ' ind', 'í', 'gen', 'as', ' y', ' z', 'onas']
[' ocular', ' tissues', '.', '\n', 'Hydro', 'gen', '