In [1]:
import os, sys
from pathlib import Path
# If the project checkout exists on this machine, make it the working
# directory and make sure it is on the import path (added at most once).
p = Path(r"/home/ubuntu/SERI-MATS-2023-Streamlit-pages")
str_p = str(p.resolve())
repo_present = os.path.exists(str_p)
if repo_present:
    os.chdir(str_p)
if repo_present and str_p not in sys.path:
    sys.path.append(str_p)

from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum2.utils import get_effective_embedding

# Clear this cell's output after the imports have run.
# NOTE(review): `clear_output` is presumably IPython's, pulled in by the star import above — confirm.
clear_output()

In [2]:
# Load the pretrained "gpt2-small" checkpoint on the CPU with the
# LayerNorm-folding and weight-centering options switched on.
pretrained_options = dict(
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    device="cpu",
    # refactor_factored_attn_matrices=True,
)
model = HookedTransformer.from_pretrained("gpt2-small", **pretrained_options)

# Explicitly switch off `use_attn_result` on the loaded model.
model.set_use_attn_result(False)

# Drop the loading messages from the cell output.
clear_output()

In [3]:
# Take the "W_E (only MLPs)" entry from the effective-embedding helper's
# result, and keep a handle on the model's unembedding matrix.
effective_embeddings = get_effective_embedding(
    model,
    use_codys_without_attention_changes=False,
)
W_EE = effective_embeddings["W_E (only MLPs)"]
W_U = model.W_U

In [4]:
# Divide each matrix by its standard deviation: taken along the last axis for
# the effective embedding and along the first axis for the unembedding.
W_EE_std = W_EE.std(dim=-1, keepdim=True)
W_U_std = W_U.std(dim=0, keepdim=True)
W_EE_scaled = W_EE / W_EE_std
W_U_scaled = W_U / W_U_std

# Explanation for the scale factors

### QK

For keys, we should use the effective embedding divided by its std (because it'll have been layernormed).

For queries, I'm not totally sure. I think we should scale them as well, because we're pretending that the token is predicted in the residual stream as strongly as it could possibly be.

### OV

Things are a little more subtle here. `W_EE_scaled @ W_V @ W_O` gets scaled before we extract the logit lens. So we need to compute this matrix, find its standard deviation, and then divide `W_EE_scaled @ W_V` by that standard deviation. `W_O @ W_U` is kept as is, because this is meant to represent the logit lens.

In [6]:
# Bundle to be pickled: the tokenizer plus, for every head of interest, the
# four factored matrices (two for the OV circuit, two for the QK circuit).
mega_dict = {"tokenizer": model.tokenizer}

for layer, head in [(10, 7)]: # (11, 10)
    head_W_O = model.W_O[layer, head]

    # OV, input side: scaled effective embedding through W_V, then divided
    # row-wise by the std of the full `W_EE_scaled @ W_V @ W_O` product.
    W_EE_V = W_EE_scaled @ model.W_V[layer, head]
    W_EE_V_O = W_EE_V @ head_W_O
    W_EE_V_O_scale = W_EE_V_O.std(dim=-1)
    W_EE_V = W_EE_V / W_EE_V_O_scale.unsqueeze(1)

    # OV, output side: left unscaled.
    W_U_O = head_W_O @ W_U

    # QK: scaled unembedding on the query side, scaled effective embedding
    # on the key side.
    W_U_Q = W_U_scaled.T @ model.W_Q[layer, head]
    W_EE_K = W_EE_scaled @ model.W_K[layer, head]

    # Store independent copies under the "layer.head" string key.
    head_matrices = (
        ("W_EE_V", W_EE_V),
        ("W_U_O", W_U_O),
        ("W_U_Q", W_U_Q),
        ("W_EE_K", W_EE_K),
        # Biases are currently left out:
        # ("b_Q", model.b_Q[10, 7]),
        # ("b_K", model.b_K[10, 7]),
    )
    mega_dict[f"{layer}.{head}"] = {
        name: matrix.clone() for name, matrix in head_matrices
    }

In [8]:
# Show which entries ended up in the bundle ("tokenizer" plus one "layer.head" key per head).
# NOTE(review): the saved output below also lists '10.1' and '11.10', which a loop over
# only [(10, 7)] would not produce — the output looks stale; re-run the cells in order.
mega_dict.keys()

dict_keys(['tokenizer', '10.1', '10.7', '11.10'])

In [7]:
# Directory that receives the pickled bundles. Kept as a plain string with a
# trailing slash because file names are appended by string concatenation.
path = "/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/st_page/media/"

# Local copy: every entry in the bundle.
with gzip.open(path + "OV_QK_circuits_less_local.pkl", "wb") as f:
    pickle.dump(mega_dict, f)

# Public copy: everything except head 10.1. The keys of `mega_dict` are strings
# of the form "layer.head", so the exclusion must compare against "10.1".
# Comparing against the tuple (10, 1) never matches a string key, which meant
# that head was written to the public file as well.
with gzip.open(path + "OV_QK_circuits_less_public.pkl", "wb") as f:
    pickle.dump({k: v for k, v in mega_dict.items() if k != "10.1"}, f)

In [16]:
# Load a saved bundle back and list its top-level keys as a sanity check.
# The file is opened in a `with` block so the gzip handle is closed as soon as
# unpickling finishes instead of being left open.
# NOTE(review): this reads "OV_QK_circuits_less.pkl", which is not one of the
# "_local" / "_public" files written by the save cell — confirm the intended file.
# Only unpickle files you trust: `pickle.load` can execute arbitrary code.
with gzip.open(path + "OV_QK_circuits_less.pkl", "rb") as f:
    saved_circuits = pickle.load(f)
saved_circuits.keys()

dict_keys(['tokenizer', 'W_EE_V', 'W_U_O', 'W_U_Q', 'W_EE_K'])