In [1]:
from transformer_lens.cautils.notebook import *

from transformer_lens.rs.callum2.ioi_and_bos.ioi_functions import get_effective_embedding_2

# Discard anything printed while the imports above ran. `clear_output` is
# provided by the star import — presumably IPython.display.clear_output; verify.
clear_output()

In [2]:
# Load GPT-2 small through TransformerLens, on the CPU. The flags ask the
# loader to fold LayerNorm into the neighbouring weights and to centre the
# writing and unembedding weights, so the weight matrices used in later cells
# are the processed (folded / centred) versions.
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device="cpu",
    fold_ln=True,
    center_writing_weights=True,
    center_unembed=True,
    # refactor_factored_attn_matrices=True,
)
# Switch off the model's "use_attn_result" setting (per-head attention results).
model.set_use_attn_result(False)

# Hide whatever the loader printed.
clear_output()

In [34]:
# Effective embedding: the "W_E (including MLPs)" entry of the dict returned by
# get_effective_embedding_2(model). Later cells right-multiply it by the head's
# W_V / W_K, so its last axis is presumably d_model (rows = tokens) — TODO confirm.
W_EE = get_effective_embedding_2(model)["W_E (including MLPs)"]
# The model's W_U attribute (unembedding); used below both as `W_O @ W_U` and,
# transposed, on the query side.
W_U = model.W_U

In [35]:
# Normalise both matrices to unit standard deviation: W_EE along its last axis,
# W_U along axis 0. keepdim=True keeps the reduced axis so the division
# broadcasts back over the full matrix.
# NOTE(review): Tensor.std defaults to the Bessel-corrected estimator and does
# not re-centre explicitly beyond the std computation, whereas LayerNorm uses
# the uncorrected variance — the mismatch should be tiny at this width, but
# confirm it is acceptable.
W_EE_std = W_EE.std(dim=-1, keepdim=True)
W_U_std = W_U.std(dim=0, keepdim=True)

W_EE_scaled = W_EE / W_EE_std
W_U_scaled = W_U / W_U_std

# Explanation for the scale factors

### QK

For keys, we should use the effective embedding divided by its std (because it'll have been layernormed).

For queries, I'm not totally sure. I think we should scale the unembedding in the same way, because we're pretending that the token is predicted in the residual stream as strongly as it could possibly be.

### OV

Things are a little more subtle here. The output of `W_EE_scaled @ W_V @ W_O` gets scaled before we apply the logit lens. So we need to compute this matrix, find the standard deviation of each of its rows, and then divide the corresponding rows of `W_EE_scaled @ W_V` by this. `W_O @ W_U` is kept as is, because this is meant to represent the logit lens.

In [40]:
# Display the full OV product for a quick sanity check.
# NOTE(review): within this notebook W_EE_V and W_U_O are first assigned in the
# cell *below* (run as In [36]); this cell was run afterwards as In [40], so in
# its current position it fails on a fresh top-to-bottom run.
# The result is presumably d_vocab x d_vocab, i.e. very large — TODO confirm
# that materialising it in full is intended.
(W_EE_V @ W_U_O)

tensor([[-13.8703,  -4.8223,  -6.6856,  ...,   9.8913,  -5.1538,  -6.4966],
        [ -4.6475,  -8.4365,  -2.8866,  ...,  -0.6939,  -5.3109,  -4.9171],
        [-10.0955,  -8.3327, -26.0346,  ...,   9.3406,   4.4333,  -4.9627],
        ...,
        [  3.1475,   2.1202,  12.1503,  ..., -31.8955,   2.2347,   5.6589],
        [  0.0972,   1.8421,  16.0855,  ...,   2.1897, -22.1990,  -0.5207],
        [ -0.2282,   1.1538,   3.2172,  ...,  -3.3892,  -5.6115,  -1.7882]])

In [36]:
# Every weight slice in this cell is taken at index [10, 7] of the model's
# stacked attention weights (layer 10, head 7 in TransformerLens indexing).
LAYER, HEAD = 10, 7

# ---- OV side ----------------------------------------------------------------
# Scaled effective embedding pushed through the head's value matrix.
W_EE_V = W_EE_scaled @ model.W_V[LAYER, HEAD]
# The full OV output is only needed for its row-wise standard deviation.
W_EE_V_O = W_EE_V @ model.W_O[LAYER, HEAD]
W_EE_V_O_scale = W_EE_V_O.std(dim=-1)
# Rescale each row of W_EE_V by the std of the matching OV-output row
# (unsqueeze turns the 1-D scale vector into a column for broadcasting).
W_EE_V = W_EE_V / W_EE_V_O_scale.unsqueeze(-1)

# Output matrix composed with the *unscaled* unembedding.
W_U_O = model.W_O[LAYER, HEAD] @ W_U

# ---- QK side ----------------------------------------------------------------
# Queries come from the scaled unembedding (transposed so tokens index rows),
# keys from the scaled effective embedding.
W_U_Q = W_U_scaled.T @ model.W_Q[LAYER, HEAD]
W_EE_K = W_EE_scaled @ model.W_K[LAYER, HEAD]

# Bundle of objects written to disk by the next cell; biases are left out.
dict_to_store_less = {
    "tokenizer": model.tokenizer,
    "W_EE_V": W_EE_V,
    "W_U_O": W_U_O,
    "W_U_Q": W_U_Q,
    "W_EE_K": W_EE_K,
    # "b_Q": model.b_Q[10, 7],
    # "b_K": model.b_K[10, 7],
}

In [37]:
# Serialise dict_to_store_less as a gzip-compressed pickle inside `path`.
# NOTE(review): `path` is an absolute, machine-specific directory; the write
# fails anywhere that directory does not exist. Consider making it configurable.
# Readers of this file must trust it: unpickling executes arbitrary code.
path = "/home/ubuntu/SERI-MATS-2023-Streamlit-pages/transformer_lens/rs/callum2/st_page/media/"
out_file = path + "OV_QK_circuits_less.pkl"
with gzip.open(out_file, "wb") as f:
    pickle.dump(dict_to_store_less, f)