In [1]:
# Make the helper modules that live one directory up (load_data, jamesd_utils,
# plotting) importable from this notebook. Assumes the kernel's working
# directory is the notebook's own folder -- TODO confirm if run elsewhere.
# The parent is appended (not prepended) and only once, so re-running this
# cell is harmless and already-importable packages keep precedence.
import os
import sys
from pathlib import Path

parent_path = os.fspath(Path.cwd().parent)
if parent_path not in sys.path:
    sys.path.append(parent_path)

# Imports
import gc

import torch
import numpy as np
import pandas as pd
import einops

from transformer_lens import HookedTransformer
from load_data import get_prompts_t
from jamesd_utils import projection_ratio
from plotting import ntensor_to_long


# Global settings and variables.
# Which model to analyse, and where to run it. Valid TransformerLens model names:
# https://github.com/neelnanda-io/TransformerLens/blob/3cd943628b5c415585c8ef100f65989f6adc7f75/transformer_lens/loading_from_pretrained.py#L127
MODEL_NAME = "gelu-4l"
device = "cpu"

# This notebook only runs forward passes, so switch autograd off for every
# cell that follows (called as a function, this sets the mode globally).
torch.set_grad_enabled(False)

In [2]:
# Load the pretrained model onto `device`, then turn on TransformerLens'
# per-head result computation so `blocks.*.attn.hook_result` exists and can be
# cached (it is used below to isolate a single head's output).
model = HookedTransformer.from_pretrained(MODEL_NAME, device=device)
model.set_use_attn_result(True)

Loaded pretrained model gelu-4l into HookedTransformer


In [3]:
# Analysis batch: 2 natural-language prompts followed by 1 code prompt, moved to
# `device`. Presumably an int token tensor of shape (batch, pos) -- TODO confirm
# against load_data.get_prompts_t.
prompts = get_prompts_t(n_text_prompts=2, n_code_prompts=1).to(device=device)

Loading 2 prompts from c4-tokenized-2b...


  0%|          | 0/2 [00:00<?, ?it/s]

Loading 1 prompts from code-tokenized...


  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Run one forward pass, caching only the two activations the DLA below needs:
# layer 0's per-head attention outputs and ln_final's normalisation scale.
# return_type=None tells HookedTransformer not to compute logits, so the large
# (batch, pos, d_vocab) logits tensor is never built or held in memory.
# ln_final still runs, so ln_final.hook_scale is still cached. `_` is then
# None; it is kept so the cleanup cell's `del _, cache` keeps working.
hook_names = ["blocks.0.attn.hook_result", "ln_final.hook_scale"]

_, cache = model.run_with_cache(
    prompts,
    names_filter=lambda name: name in hook_names,
    device=device,
    return_type=None,
)

In [5]:
# Direct logit attribution setup for head L0H2 (layer 0, head 2).
# Take the head's output and apply ln_final's normalisation to it: centre
# across d_model, then divide by the scale ln_final computed at each position,
# so the result can be multiplied straight into W_U.
# NOTE(review): this ignores ln_final's weight/bias, which is only exact if
# they are folded into W_U (TransformerLens' from_pretrained default) -- confirm.
L0H2 = cache["blocks.0.attn.hook_result"][:, :, 2, :]  # (batch, pos, d_model)
scale = cache["ln_final.hook_scale"]  # (batch, pos, 1)
L0H2_normed = (L0H2 - L0H2.mean(dim=-1, keepdim=True)) / scale  # (batch, pos, d_model)

# The cache (and the forward-pass output `_`) are no longer needed; free them.
# The trailing ';' keeps gc.collect()'s return value out of the cell output.
del _, cache
gc.collect();

229

In [6]:
# Project L0H2's normalised output onto the unembedding: one logit vector per
# (batch, pos), i.e. an array of shape (batch, pos, d_vocab).
direct_logits = (L0H2_normed @ model.W_U).detach().cpu().numpy()

# One row per vocabulary entry (the index is the token id), summarising the
# head's direct logit contribution across every position of every prompt.
# Token strings are decoded in a single batched call rather than d_vocab
# separate to_single_str_token() calls.
df = pd.DataFrame(
    {
        "dla_mean": direct_logits.mean(axis=(0, 1)),
        "dla_std": direct_logits.std(axis=(0, 1)),  # population std (ddof=0), same as np.std
        "token_str": model.to_str_tokens(torch.arange(model.cfg.d_vocab)),
    }
).rename_axis("token_id")

# NOTE: direct_logits is large (batch * pos * d_vocab floats). If memory is
# tight and no later cell needs it, `del direct_logits; gc.collect()`.

In [11]:
# Tokens whose logits head L0H2 most directly boosts, ranked by mean DLA.
# pandas elides the middle of any frame longer than display.max_rows. Driving
# both the row limit and the number of rows shown from one constant keeps them
# from drifting apart.
TOP_K = 200
pd.set_option("display.max_rows", TOP_K)
df.sort_values("dla_mean", ascending=False).head(TOP_K)

Unnamed: 0,dla_mean,dla_std,token_str
4623,0.371288,0.429745,__
22991,0.355631,0.405638,duced
14270,0.346383,0.550052,lich
40104,0.339302,0.355535,ocene
44968,0.334968,0.547025,kot
733,0.334682,0.32656,vers
26735,0.332587,0.30496,ved
39134,0.326535,0.483073,vier
31744,0.325954,0.300805,Schedule
12030,0.321889,0.736661,muc
