In [11]:
# %% Import modules
import torch
import einops

from transformer_lens import HookedTransformer, ActivationCache
from transformer_lens.utils import get_act_name


import plotly.express as px
import seaborn as sns
import pandas as pd
import numpy as np
from tqdm.notebook import trange, tqdm

from torch import Tensor
from jaxtyping import Float, Int, Bool
from typing import List, Callable, Optional
from plotting import (
    get_fig_head_to_mlp_neuron,
    get_fig_head_to_mlp_neuron_by_layer,
    get_fig_head_to_selected_mlp_neuron,
    ntensor_to_long,
)
from load_data import get_prompts_t
from utils import (
    projection,
    cos_similarity,
)

In [4]:
# Global settings and variables
sns.set()
torch.set_grad_enabled(False)
device = "cpu"

IPSUM = "Lorem Ipsum is simply dummy text of the printing and typesetting industry. Lorem Ipsum has been the industry's standard dummy text ever since the 1500s, when an unknown printer took a galley of type and scrambled it to make a type specimen book. It has survived not only five centuries, but also the leap into electronic typesetting, remaining essentially unchanged. It was popularised in the 1960s with the release of Letraset sheets containing Lorem Ipsum passages, and more recently with desktop publishing software like Aldus PageMaker including versions of Lorem Ipsum."
# IPSUM = ""

In [5]:
#%% Setup model & load data
model = HookedTransformer.from_pretrained('gelu-4l')
model.cfg.use_attn_result = True
model.to(device)

prompts_t = get_prompts_t()

Loaded pretrained model gelu-4l into HookedTransformer
Moving model to device:  cpu
Loading 80 prompts from c4-tokenized-2b...


100%|██████████| 80/80 [00:02<00:00, 27.82it/s]


Loading 20 prompts from code-tokenized...


100%|██████████| 20/20 [00:02<00:00,  6.69it/s]


In [6]:
examples = [
    {
        "text": 4 * IPSUM + " It's in the cupboard, either on the top or the",
        "correct": " bottom",
        "incorrect": " top",
    },
    {
        "text": 5 * IPSUM + " I went to university at Michigan",
        "correct": " State",
        "incorrect": " University",
    },
    {
        "text": IPSUM + " class MyClass:\n\tdef",
        "correct": " __",
        "incorrect": " on",
    },
    {
        "text": 6 * IPSUM + "The church I go to is the Seventh-day Adventist",
        "correct": " Church",
        "incorrect": " Advent",
    }
]

In [8]:
cache = model.run_with_cache(examples[0]["text"])

In [9]:
cache

(tensor([[[ 6.3408, -6.8591, -6.8589,  ..., -6.8638, -6.8522, -6.8653],
          [ 3.7090, -5.5066, -5.5378,  ..., -5.5730, -5.5053, -5.5130],
          [ 3.1071, -6.9010, -6.8954,  ..., -6.9342, -6.9581, -6.9972],
          ...,
          [ 4.4170, -6.8391, -6.9054,  ..., -6.8645, -6.8689, -6.8825],
          [ 1.5867, -6.7227, -6.7652,  ..., -6.7718, -6.8076, -6.7640],
          [ 2.6618, -6.1671, -6.2009,  ..., -6.1731, -6.2137, -6.2051]]]),
 ActivationCache with keys ['hook_embed', 'hook_pos_embed', 'blocks.0.hook_resid_pre', 'blocks.0.ln1.hook_scale', 'blocks.0.ln1.hook_normalized', 'blocks.0.attn.hook_q', 'blocks.0.attn.hook_k', 'blocks.0.attn.hook_v', 'blocks.0.attn.hook_attn_scores', 'blocks.0.attn.hook_pattern', 'blocks.0.attn.hook_z', 'blocks.0.attn.hook_result', 'blocks.0.hook_attn_out', 'blocks.0.hook_resid_mid', 'blocks.0.hook_mlp_in', 'blocks.0.ln2.hook_scale', 'blocks.0.ln2.hook_normalized', 'blocks.0.mlp.hook_pre', 'blocks.0.mlp.hook_post', 'blocks.0.hook_mlp_out', 'bl

In [None]:
def get_results(model, example):
    token_ids = model.to_tokens(example["text"])
    correct_token_id = model.to_single_token(example["correct"])
    incorrect_token_id = model.to_single_token(example["incorrect"])
    logit_diff_direction = (
        model.W_U[:, correct_token_id] - model.W_U[:, incorrect_token_id]
    )  # (d_model,)

    _, cache = model.run_with_cache(
        token_ids,
        names_filter=lambda name: (
            name in resid_names
            or name == "blocks.0.attn.hook_result"
            or name == "blocks.2.attn.hook_result"
            or name == "ln_final.hook_scale"
            or name == "ln_final.hook_normalized"
        ),
    )

    # Get activations from the cache
    resids = torch.stack(
        [cache[name] for name in resid_names], dim=0
    )  # (resid, batch, pos, d_model)
    L0H2 = cache["blocks.0.attn.hook_result"][:, :, 2, :]  # (batch, pos, d_model)
    L2HX = einops.reduce(
        cache["blocks.2.attn.hook_result"],
        "batch pos head d_model -> batch pos d_model",
        "sum",
    )
    scale = cache["ln_final.hook_scale"]

    apply_ln = lambda x: (x - x.mean(dim=-1, keepdim=True)) / scale
    resids_ln = apply_ln(resids)
    L0H2_ln = apply_ln(L0H2)
    L2HX_ln = apply_ln(L2HX)
    
    # Check that manual layernorming is correct
    assert torch.allclose(cache["ln_final.hook_normalized"], resids_ln[-1], atol=1e-5)

    return (
        resids_ln[:, 0, -1, :] @ logit_diff_direction,
        projection_ratio(resids[:, 0, -1, :], L0H2[0, -1, :].unsqueeze(0)),
        L0H2_ln[0, -1, :] @ logit_diff_direction,
        L2HX_ln[0, -1, :] @ logit_diff_direction,
    )