In [4]:
# Handy snippet to get repo root from anywhere in the repo
import sys
from subprocess import check_output
ROOT = check_output('git rev-parse --show-toplevel', shell=True).decode("utf-8").strip()
if ROOT not in sys.path: sys.path.append(ROOT)

In [38]:
import torch as t
import einops

from datasets import load_dataset
from tqdm.auto import trange
from circuitsvis.activations import text_neuron_activations
from llama2 import load_model

from jaxtyping import Float


t.set_grad_enabled(False)
device = "cuda" if t.cuda.is_available() else "cpu"

In [2]:
model = load_model(device=device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [5]:
directions = t.load(f"{ROOT}/directions/llama2-7b_cities_mm.pt")
directions.shape
"shape: (layer pos d_model)"

'shape: (layer pos d_model)'

In [113]:
city_activations = t.load(f"{ROOT}/activations/llama2-7b_cities.pt")
city_activations.shape

torch.Size([1496, 32, 2, 4096])

In [116]:
projections = einops.einsum(
    city_activations,
    directions / directions.norm(dim=-1, keepdim=True),
    "statement layer pos d_model, layer pos d_model -> pos layer statement"
)
"shape: (pos layer statement)"

def city_proj_mean(pos, layer):
    return projections[pos, layer].mean().item()

def city_proj_std(pos, layer):
    return projections[pos, layer].std().item()

'shape: (pos layer statement)'

In [118]:
def get_direction_activations(model, prompt, layer, dir_pos, scale=True, device=device):
    _, cache = model.run_with_cache(
        prompt,
        return_type="loss",
        names_filter=f"blocks.{layer}.hook_resid_post",
    )
    chosen_dir = directions[layer, dir_pos].to(device)
    activations = cache[f'blocks.{layer}.hook_resid_post'][0] @ chosen_dir
    if scale:
        activations = (activations - city_proj_mean(dir_pos, layer)) / city_proj_std(dir_pos, layer)
    return activations

In [83]:
# dataset = load_dataset('c4', 'en', split='train', streaming=True)
dataset = load_dataset('wikipedia', '20220301.en', split='train', streaming=True)
iter_dataset = iter(dataset)

In [91]:
import pandas as pd
df = pd.DataFrame()
text = []
top1_acts = []
top2_acts = []
top3_acts = []
bot1_acts = []
bot2_acts = []
bot3_acts = []

In [92]:
for _ in trange(50):
    sample = next(iter_dataset)
    tokens = model.to_tokens(sample["text"])[..., :256]

    _, cache = model.run_with_cache(
        tokens,
        return_type="loss",
        names_filter="blocks.11.hook_resid_post",
    )
    activations = cache['blocks.11.hook_resid_post'][0] @ chosen_dir

    top1, top2, top3 = activations.topk(3).values.tolist()
    bot1, bot2, bot3 = activations.topk(3, largest=False).values.tolist()

    text.append(sample["text"])
    top1_acts.append(top1)
    top2_acts.append(top2)
    top3_acts.append(top3)
    bot1_acts.append(bot1)
    bot2_acts.append(bot2)
    bot3_acts.append(bot3)


  0%|          | 0/50 [00:00<?, ?it/s]

In [93]:
df["text"] = text
df["top1"] = top1_acts
df["top2"] = top2_acts
df["top3"] = top3_acts
df["bot1"] = bot1_acts
df["bot2"] = bot2_acts
df["bot3"] = bot3_acts

In [98]:
# df.sort_values("top1", ascending=False).head(10)
df.sort_values("bot1", ascending=True).head(10)

Unnamed: 0,text,top1,top2,top3,bot1,bot2,bot3
47,Arraignment is a formal reading of a criminal ...,1347.252563,458.780518,51.564903,-28.007977,-21.545948,-8.347266
48,"""America the Beautiful"" is a patriotic America...",1253.946777,458.780518,51.258194,-23.840515,-13.100718,-9.871751
20,Agricultural science (or agriscience for short...,1381.436768,458.780518,61.355854,-21.421581,-7.659249,-4.730247
29,"The Austroasiatic languages , also known as Mo...",1348.820923,458.780518,52.505398,-17.96032,-5.729125,-5.578624
40,"Aberdeen is a city in Scotland, United Kingdom...",1251.796631,458.780518,78.915894,-17.826765,-17.741428,-16.025068
28,"Andre Kirk Agassi ( ; born April 29, 1970) is ...",1308.811401,458.780518,55.57719,-16.197714,-7.686803,-6.584397
25,Austin is the capital of Texas in the United S...,1260.507935,458.780518,83.717796,-15.112831,-13.808863,-13.666355
39,Ada may refer to:\n\nPlaces\n\nAfrica\n Ada Fo...,1163.383911,458.780518,48.575985,-14.517107,-12.535902,-10.518171
35,"Amphibians are ectothermic, tetrapod vertebrat...",1311.817993,458.780518,37.890564,-14.172328,-3.547485,-3.216492
27,Apollo is one of the Olympian deities in clas...,1322.135254,458.780518,42.068859,-14.035411,-7.777155,-3.137388


In [141]:
# prompt = df.text.iloc[27]
# prompt = next(iter_dataset)["text"]
prompt = "The city of Paris is in Australia. The city of Rome is in Italy."
# prompt = "The city of Rome is in Italy. The city of Paris is in Australia."

In [146]:
tokens_str = model.to_str_tokens(prompt)
# activations = get_direction_activations(model, prompt, 24, 0)
activations = get_direction_activations(model, prompt, 24, 1)
text_neuron_activations(tokens_str, activations[:, None, None])