In [4]:
# Handy snippet to get repo root from anywhere in the repo
import sys
from subprocess import check_output
# Ask git for the repository's top-level directory. Passing the command as an
# argument list runs git directly, so no intermediate shell is spawned.
ROOT = check_output(["git", "rev-parse", "--show-toplevel"]).decode("utf-8").strip()
# Put the repo root on the import path; the membership check keeps re-running
# this cell from appending duplicate entries.
if ROOT not in sys.path:
    sys.path.append(ROOT)

In [38]:
import torch as t
import einops

from datasets import load_dataset
from tqdm.auto import trange
from circuitsvis.activations import text_neuron_activations
from llama2 import load_model

from jaxtyping import Float


# Turn gradient tracking off globally for the rest of the session.
t.set_grad_enabled(False)

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [2]:
# Build the model with `load_model` (imported from `llama2`), placing it on the
# device string chosen above.
model = load_model(device=device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



In [153]:
# Load the saved directions tensor from the repo's directions/ folder.
directions = t.load(f"{ROOT}/directions/llama2-7b_cities_mm.pt")
# NOTE(review): this bare expression is never displayed, because the string
# literal below is the cell's last expression.
directions.shape
# Bare string used as an axis-name annotation; as the final expression it is
# what the cell actually shows.
"shape: (layer pos d_model)"

'shape: (layer pos d_model)'

In [113]:
# Load the saved activations tensor from the repo's activations/ folder.
city_activations = t.load(f"{ROOT}/activations/llama2-7b_cities.pt")
# Displayed as the cell output; the einsum pattern in a later cell names these
# axes (statement, layer, pos, d_model).
city_activations.shape

torch.Size([1496, 32, 2, 4096])

In [162]:
# Dot every cached city activation with the unit-length direction that shares
# its (layer, pos) slot, then reorder the axes.
city_projs = einops.einsum(
    city_activations,
    directions / directions.norm(dim=-1, keepdim=True),
    "statement layer pos d_model, layer pos d_model -> pos layer statement",
)  # shape: (pos layer statement)

# Reduce over the trailing statement axis and move the statistics to `device`.
city_means = t.mean(city_projs, dim=-1).to(device)  # shape: (pos layer)
city_stds = t.std(city_projs, dim=-1).to(device)  # shape: (pos layer)
"shape: (pos layer)"

'shape: (pos layer)'

In [180]:
def get_direction_activations(model, prompt, scale=None, device=device):
    """Project the residual stream of `prompt` onto the unit-normalised `directions`.

    Runs `model` on `prompt`, caching only hooks whose name contains
    "resid_post", and takes the dot product of each layer's residual with the
    module-level `directions` entries of the same layer.

    Args:
        model: Object exposing `run_with_cache(...)` and `cfg.n_layers`
            (presumably a TransformerLens-style hooked model -- confirm
            against `load_model`).
        prompt: Input forwarded unchanged to `model.run_with_cache`.
        scale: Optional standardisation applied to the projections:
            * None -- return the raw projections.
            * "by_tokens" -- subtract the mean and divide by the std taken
              over the token axis of this prompt.
            * "by_city" -- subtract the module-level `city_means` and divide
              by the module-level `city_stds`.
        device: Device that the normalised `directions` are moved to before
            the projection.

    Returns:
        Tensor with axes (tok_pos, dir_pos, layer).

    Raises:
        ValueError: If `scale` is not one of the values listed above.
    """
    # Fail fast, before the forward pass, so a misspelt mode cannot silently
    # fall through to "no scaling".
    valid_scales = (None, "by_tokens", "by_city")
    if scale not in valid_scales:
        raise ValueError(
            f"Unknown scale {scale!r}; expected one of {valid_scales}"
        )

    _, cache = model.run_with_cache(
        prompt,
        return_type="loss",
        names_filter=lambda name: "resid_post" in name,
    )
    # `[0]` keeps the first entry along the leading axis of each cached tensor
    # (assumed to be a batch of one -- TODO confirm).
    resids = t.stack(
        [
            cache[f"blocks.{layer}.hook_resid_post"][0]
            for layer in range(model.cfg.n_layers)
        ],
        dim=0,
    )

    unit_directions = (directions / directions.norm(dim=-1, keepdim=True)).to(device)
    activations = einops.einsum(
        resids,  # shape: (layer tok_pos d_model)
        unit_directions,
        "layer tok_pos d_model, layer dir_pos d_model -> tok_pos dir_pos layer",
    )

    if scale == "by_tokens":
        mean = activations.mean(dim=0, keepdim=True)
        std = activations.std(dim=0, keepdim=True)
        activations = (activations - mean) / std
    elif scale == "by_city":
        # The city statistics are (pos layer); add a leading axis to broadcast
        # over tokens, and make sure they sit on the activations' device.
        mean = city_means[None, :, :].to(activations.device)
        std = city_stds[None, :, :].to(activations.device)
        activations = (activations - mean) / std

    return activations

In [164]:
# Alternative corpus (disabled):
# dataset = load_dataset('c4', 'en', split='train', streaming=True)
# Open the 'train' split of 'wikipedia' / '20220301.en' with streaming=True.
dataset = load_dataset('wikipedia', '20220301.en', split='train', streaming=True)
# Keep a single module-level iterator so repeated next(iter_dataset) calls
# continue from where the previous one stopped.
iter_dataset = iter(dataset)

In [186]:
# Alternative prompts (disabled); uncomment one to swap the input.
# prompt = next(iter_dataset)["text"]
# prompt = "The city of Paris is in Australia."
# prompt = "The city of Rome is in Italy."
prompt = "The city of Paris is in France. The city of Beijing is in Australia. The city of Rome is in Italy."

# Per-token strings for the prompt, passed to the visualisation as its tokens.
tokens_str = model.to_str_tokens(prompt)
# Projections with axes (tok_pos, dir_pos, layer), standardised over this
# prompt's tokens via scale="by_tokens".
activations = get_direction_activations(model, prompt, scale="by_tokens")
# NOTE(review): the trailing positional arguments are presumably the names of
# the two non-token axes followed by labels for the first of them -- confirm
# against the circuitsvis signature.
text_neuron_activations(
    tokens_str,
    activations,
    "Direction Pos",
    "Layer",
    ["Penultimate", "Final"],
)