In [2]:
# Gross code to allow for importing from parent directory
import os, sys
from pathlib import Path

# Put the directory above the working directory on the import path (once),
# so the project modules imported below can be resolved.
parent_path = str(Path.cwd().parent)
already_on_path = parent_path in sys.path
if not already_on_path:
    sys.path.append(parent_path)

# Imports
import gc

import torch
import numpy as np
import pandas as pd
import einops

from tqdm.auto import trange
from transformer_lens import HookedTransformer
from load_data import get_prompts_t
from jamesd_utils import projection_ratio
from plotting import ntensor_to_long


# Global settings and variables

# Name of the pretrained model to load. Transformer Lens model names:
# https://github.com/neelnanda-io/TransformerLens/blob/3cd943628b5c415585c8ef100f65989f6adc7f75/transformer_lens/loading_from_pretrained.py#L127
MODEL_NAME = "gelu-4l"

# Device string handed to the model, the prompts and the activation cache.
device = "cpu"

# No gradients are needed anywhere in this notebook, so autograd is
# switched off globally.
torch.set_grad_enabled(False)

In [3]:
# Load the pretrained checkpoint named by MODEL_NAME onto `device`.
model = HookedTransformer.from_pretrained(MODEL_NAME, device=device)
# Switch on the attention-result option. A later cell reads
# "blocks.0.attn.hook_result" from the activation cache and indexes it by head.
# NOTE(review): presumably that hook is only populated when this flag is set;
# confirm against the TransformerLens documentation.
model.cfg.use_attn_result = True

Loaded pretrained model gelu-4l into HookedTransformer


In [4]:
# Request 240 text prompts and 60 code prompts, then move the result to `device`.
prompt_counts = {"n_text_prompts": 240, "n_code_prompts": 60}
prompts = get_prompts_t(**prompt_counts).to(device)

Loading 240 prompts from c4-tokenized-2b...


  0%|          | 0/240 [00:00<?, ?it/s]

Loading 60 prompts from code-tokenized...


  0%|          | 0/60 [00:00<?, ?it/s]

In [5]:
# Attention head whose direct logit attribution (DLA) is measured: head HEAD
# of layer LAYER. The local names below keep the original "L0H2" spelling.
LAYER = 0
HEAD = 2

result_hook = f"blocks.{LAYER}.attn.hook_result"
scale_hook = "ln_final.hook_scale"
hook_names = [result_hook, scale_hook]
minibatch_size = 2

# The unembed is a plain matrix product, hence linear:
#     sum_{b,p} (x[b, p] @ W_U) == (sum_{b,p} x[b, p]) @ W_U
# Accumulating the d_model-sized sum and unembedding once after the loop gives
# the same total (up to float rounding) without building a
# (batch, pos, d_vocab) logit tensor for every minibatch.
L0H2_normed_sum = np.zeros(model.cfg.d_model)  # float64 accumulator
count = 0  # number of (prompt, position) pairs accumulated

for i in trange(0, prompts.shape[0], minibatch_size):
    # Cache only the two activations needed below.
    _, cache = model.run_with_cache(
        prompts[i:i+minibatch_size],
        names_filter=lambda name: name in hook_names,
        device=device,
    )

    L0H2 = cache[result_hook][:, :, HEAD, :]  # (batch, pos, d_model)
    scale = cache[scale_hook]  # (batch, pos, 1)
    # Centre over d_model and divide by the final layer norm's cached scale.
    L0H2_normed = (L0H2 - L0H2.mean(keepdim=True, dim=-1)) / scale  # (batch, pos, d_model)

    count += L0H2_normed.shape[0] * L0H2_normed.shape[1]
    L0H2_normed_sum += L0H2_normed.sum(dim=(0, 1)).detach().cpu().numpy()

# Single unembed of the accumulated vector -> one summed logit per vocab entry.
# W_U is detached before the NumPy conversion because it may require grad.
direct_logits_sum = L0H2_normed_sum @ model.W_U.detach().cpu().numpy().astype(np.float64)

  0%|          | 0/150 [00:00<?, ?it/s]

In [6]:
# One row per vocabulary id:
#   dla_mean  - accumulated direct logits divided by the number of positions counted
#   token_str - string form of that token id
token_strings = [model.to_single_str_token(i) for i in range(model.cfg.d_vocab)]
df = pd.DataFrame(
    {
        "dla_mean": direct_logits_sum / count,
        "token_str": token_strings,
    }
)

In [8]:
# Show the tokens with the highest mean DLA, largest first. The pandas row
# limit is raised to match so the table is not truncated.
n_rows_shown = 500
pd.set_option("display.max_rows", n_rows_shown)
df_by_dla = df.sort_values(by="dla_mean", ascending=False)
df_by_dla.head(n_rows_shown)

Unnamed: 0,dla_mean,token_str
4623,0.402633,__
26735,0.364378,ved
22991,0.363347,duced
40104,0.361942,ocene
31744,0.358101,Schedule
39134,0.35572,vier
6240,0.35466,Church
733,0.346785,vers
44968,0.344089,kot
14270,0.343727,lich
