In [1]:
from transformer_lens.cautils.notebook import *
from transformer_lens.rs.callum.generate_bag_of_words_quad_plot import get_effective_embedding, lock_attn, fwd_pass_lock_attn0_to_self, get_EE_QK_circuit
from transformer_lens.rs.callum.keys_fixed import get_effective_embedding_2



In [2]:
# MODEL_NAME = "gpt2-small"
# MODEL_NAME = "solu-10l"
MODEL_NAME = "stanford-gpt2-small-e"
model = HookedTransformer.from_pretrained(
    MODEL_NAME,
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    # refactor_factored_attn_matrices=True,
)

model.set_use_attn_result(False)
model.set_use_split_qkv_input(True)

clear_output()

In [3]:
model.cfg.model_name

'expanse-gpt2-small-x777'

In [4]:
LAYER_IDX, HEAD_IDX = {
    "SoLU_10L1280W_C4_Code": (9, 18), # (9, 18) is somewhat cheaty
    "gpt2": (10, 7),
    "expanse-gpt2-small-x777": (11, 2),
}[model.cfg.model_name]

W_U = model.W_U
W_Q_negative = model.W_Q[LAYER_IDX, HEAD_IDX]
W_K_negative = model.W_K[LAYER_IDX, HEAD_IDX]

W_E = model.W_E

# ! question - what's the approximation of GPT2-small's embedding?
# lock attn to 1 at current position
# lock attn to average
# don't include attention

In [5]:
full_QK_circuit = FactoredMatrix.FactoredMatrix(W_U.T @ W_Q_negative, W_K_negative.T @ W_E.T)

indices = t.randint(0, model.cfg.d_vocab, (250,))
full_QK_circuit_sample = full_QK_circuit.A[indices, :] @ full_QK_circuit.B[:, indices]

full_QK_circuit_sample_centered = full_QK_circuit_sample - full_QK_circuit_sample.mean(dim=1, keepdim=True)

imshow(
    full_QK_circuit_sample_centered,
    labels={"x": "Source / key token (embedding)", "y": "Destination / query token (unembedding)"},
    title="Full QK circuit for negative name mover head",
    width=700,
)

In [6]:
raw_dataset = load_dataset("stas/openwebtext-10k")
train_dataset = raw_dataset["train"]
dataset = [train_dataset[i]["text"] for i in range(len(train_dataset))]

Found cached dataset openwebtext-10k (/root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b)


  0%|          | 0/1 [00:00<?, ?it/s]

In [7]:
for i, s in enumerate(dataset):
    loss_hooked = fwd_pass_lock_attn0_to_self(model, s)
    print(f"Loss with attn locked to self: {loss_hooked:.2f}")
    loss_hooked_0 = fwd_pass_lock_attn0_to_self(model, s, ablate=True)
    print(f"Loss with attn locked to zero: {loss_hooked_0:.2f}")
    loss_orig = model(s, return_type="loss")
    print(f"Loss with attn free: {loss_orig:.2f}\n")

    # gc.collect()

    if i == 5:
        break

Loss with attn locked to self: 6.16
Loss with attn locked to zero: 4.01
Loss with attn free: 2.75

Loss with attn locked to self: 5.88
Loss with attn locked to zero: 3.84
Loss with attn free: 2.57

Loss with attn locked to self: 6.57
Loss with attn locked to zero: 4.83
Loss with attn free: 3.54

Loss with attn locked to self: 5.88
Loss with attn locked to zero: 4.14
Loss with attn free: 2.70

Loss with attn locked to self: 6.25
Loss with attn locked to zero: 3.66
Loss with attn free: 2.19

Loss with attn locked to self: 6.38
Loss with attn locked to zero: 4.35
Loss with attn free: 3.15



In [8]:
if "gpt" in model.cfg.model_name: # sigh, tied embeddings
    # sanity check this is the same 

    def remove_pos_embed(z, hook):
        return 0.0 * z

    # setup a forward pass that 
    model.reset_hooks()
    model.add_hook(
        name="hook_pos_embed",
        hook=remove_pos_embed,
        level=1, # ???
    ) 
    model.add_hook(
        name=utils.get_act_name("pattern", 0),
        hook=lock_attn,
    )
    logits, cache = model.run_with_cache(
        torch.arange(1000).to(device).unsqueeze(0),
        names_filter=lambda name: name=="blocks.1.hook_resid_pre",
        return_type="logits",
    )


    W_EE_test = cache["blocks.1.hook_resid_pre"].squeeze(0)
    W_EE_prefix = W_EE_test[:1000]

    assert torch.allclose(
        W_EE_prefix,
        W_EE_test,
        atol=1e-4,
        rtol=1e-4,
    )

In [9]:
NAME_MOVERS = {
    "gpt2": [(9, 9), (10, 0), (9, 6)],
    "SoLU_10L1280W_C4_Code": [(7, 12), (5, 4), (8, 3)],
    "expanse-gpt2-small-x777": [(10, 5),]
}[model.cfg.model_name]

NEGATIVE_NAME_MOVERS = {
    "gpt2": [(LAYER_IDX, HEAD_IDX), (11, 10)],
    "SoLU_10L1280W_C4_Code": [(LAYER_IDX, HEAD_IDX), (9, 15)], # second one on this one IOI prompt only...
    "expanse-gpt2-small-x777": [(11, 2)],
}[model.cfg.model_name]

In [10]:
# Prep some bags of words...
# OVERLY LONG because it really helps to have the bags of words the same length

bags_of_words = []

OUTER_LEN = 50
INNER_LEN = 100

idx = -1
while len(bags_of_words) < OUTER_LEN:
    idx += 1
    cur_tokens = model.tokenizer.encode(dataset[idx])
    cur_bag = []
    
    for i in range(len(cur_tokens)):
        if len(cur_bag) == INNER_LEN:
            break
        if cur_tokens[i] not in cur_bag:
            cur_bag.append(cur_tokens[i])

    if len(cur_bag) == INNER_LEN:
        bags_of_words.append(cur_bag)

Token indices sequence length is longer than the specified maximum sequence length for this model (1094 > 1024). Running this sequence through the model will result in indexing errors


In [11]:
embeddings_dict = get_effective_embedding_2(model)

In [12]:
# embeddings_dict_keys

better_labels = {
    'W_E (including MLPs)': 'Att_0 + W_E + MLP0', 
    'W_E (no MLPs)': 'W_E',
    'W_E (only MLPs)': 'MLP0',
    'W_U': 'W_U',
}

In [13]:
# TODO should really make this outer_len and inner_len, but I forgot
assert all([len(b)==len(bags_of_words[0]) for b in bags_of_words])

In [14]:
# Getting just diag patterns for a single head

from transformer_lens import FactoredMatrix

LAYER = NEGATIVE_NAME_MOVERS[0][0]
HEAD = NEGATIVE_NAME_MOVERS[0][1]
NORM = True

all_results = []
embeddings_dict_keys = sorted(embeddings_dict.keys())
labels = []

data = []
lines = []

for q_side_matrix, k_side_matrix in tqdm(list(itertools.product(embeddings_dict_keys, embeddings_dict_keys))):
    labels.append(f"Q = {better_labels[q_side_matrix]}<br>K = {better_labels[k_side_matrix]}")
    
    log_attentions_to_self = torch.zeros((len(bags_of_words), len(bags_of_words[0])))
    sorted_log_attentions = torch.zeros((len(bags_of_words), len(bags_of_words[0]), len(bags_of_words[0])))

    if "K = W_U" in labels[-1]: 
        labels= labels[:-1]
        continue
    if "Q = W_E" in labels[-1]: 
        labels = labels[:-1]
        continue

    results = []
    for idx in range(OUTER_LEN):
        attn = get_EE_QK_circuit(
            LAYER,
            HEAD,
            model,
            show_plot=False,
            random_seeds=None,
            num_samples=None,
            bags_of_words=bags_of_words[idx: idx+1],
            mean_version=False,
            query_side_bias=False,
            key_side_bias=False,
            W_E_query_side=embeddings_dict[q_side_matrix].clone(),
            W_E_key_side=embeddings_dict[k_side_matrix].clone(),
            apply_softmax=False,
            apply_log_softmax=True,
            norm=NORM,
            # ten_x="W_U" in labels[-1],
        )
        assert len(attn.squeeze().shape) == 2
        # if idx==0:
        #     imshow(
        #         attn,
        #         x = list(enumerate(model.to_str_tokens(torch.tensor(bags_of_words[idx])))),
        #         y = list(enumerate(model.to_str_tokens(torch.tensor(bags_of_words[idx])))),
        #         title = labels[-1],
        #     )

        log_attentions_to_self[idx] = attn[torch.arange(INNER_LEN), torch.arange(INNER_LEN)]
        for j in range(INNER_LEN):
            sorted_log_attentions[idx][j] = torch.tensor(
                sorted(attn[j].tolist(), reverse=True)
            )
    
    data.append(sorted_log_attentions.mean(dim=(0, 1)))
    lines.append(log_attentions_to_self.mean())
    t.cuda.empty_cache()

  0%|          | 0/16 [00:00<?, ?it/s]

In [15]:
fig = make_subplots(
    rows=3, 
    cols=3, 
    start_cell="bottom-left",
    subplot_titles=tuple(labels),
)

MAX_LEN = 5
orders = ["1st", "2nd", "3rd", "4th", "5th"]
# MAX_VAL = max([cur_data[:MAX_LEN].max().item() for cur_data in data])
# MAX_VAL += 0.1 * abs(MAX_VAL)
# MAX_VAL = max(0, MAX_VAL)
# MIN_VAL = min([cur_data[:MAX_LEN].min().item() for cur_data in data])
# MIN_VAL = min([MIN_VAL] + [line for line in lines])
# MIN_VAL -= 0.1 * abs(MIN_VAL)
# MIN_VAL = -6

for idx, (cur_data, line) in enumerate(zip(data, lines, strict=True)):
    fig.add_trace(
        go.Bar(
            x=["Self Attention"], #[labels[idx]],
            y=[line],
            marker_color='red'
        ),
        row=1 + (idx//3), # ??? 0-indexed lol
        col=1 + (idx%3),
    )
    fig.add_trace(
        go.Bar(
            x=orders[:MAX_LEN],
            y=cur_data[:MAX_LEN],
            marker_color='blue'
        ),
        row=1 + (idx//3),
        col=1 + (idx%3),
    )

    # fig.update_yaxes(range=[MIN_VAL, MAX_VAL], row=1 + (idx//3), col=1 + (idx%3))

fig.update_layout(
    title=f"Average attention scores for {LAYER}.{HEAD} compared to top 5 attention scores in bag of words{', no normalization' if not NORM else ''}",
    height=1000,
    width=1000,
    yaxis_title="Log attention",
)

fig.show()

# print(log_attentions_to_self.mean())
# imshow(log_attentions_to_self)
# print("sdjkfl")
# imshow(
#     sorted_log_attentions.mean(1),
# )

In [16]:
data[-1]

tensor([-0.0451,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,    -inf,
           -inf,    -inf,    -inf,    -i

In [17]:
min(min(lines), -2)

tensor(-inf)

In [18]:
imshow(
    all_results,
    facet_col=0,
    facet_col_wrap=len(embeddings_dict)-1,
    facet_labels=labels,
    title=f"Sample of average log softmax for attention approximations with different effective embeddings: head {LAYER}.{HEAD}",
    labels={"x": "Key", "y": "Query"},
    height=900, width=900
)

ValueError: zero-size array to reduction operation maximum which has no identity

In [22]:
# embeddings_dict_keys

# better_labels = {
#     'W_E (including MLPs)': 'Att_0 + W_E + MLP0', 
#     'W_E (no MLPs)': 'W_E',
#     'W_E (only MLPs)': 'MLP0',
#     'W_U': 'W_U',
# }

scores = t.zeros(12, 12).float().to(device)

for layer, head in tqdm(list(itertools.product(range(12), range(12)))):
    results = []
    for idx in range(OUTER_LEN):
        softmaxed_attn = get_EE_QK_circuit(
            layer,
            head,
            model,
            show_plot=False,
            random_seeds=None,
            bags_of_words=bags_of_words[idx:idx+1],
            num_samples=None,
            mean_version=False,
            W_E_query_side=embeddings_dict["W_U"],
            W_E_key_side=embeddings_dict["W_E (including MLPs)"],  # "W_E (only MLPs)"
        )

        # now sort each 

        results.append(softmaxed_attn.diag().mean())

    results = sum(results) / len(results)

    scores[layer, head] = results

imshow(scores, width=750, labels={"x": "Head", "y": "Layer"}, title="Prediction-attn scores for bag of words (including MLPs in embedding)")

  0%|          | 0/144 [00:00<?, ?it/s]

In [20]:
scores = t.zeros(12, 12).float().to(device)

for layer, head in tqdm(list(itertools.product(range(12), range(12)))):
    results = []
    for idx in range(OUTER_LEN):
        softmaxed_attn = get_EE_QK_circuit(
            layer,
            head,
            model,
            show_plot=False,
            random_seeds=None,
            bags_of_words=bags_of_words[idx:idx+1],
            mean_version=False,
            W_E_query_side=embeddings_dict["W_U (or W_E, no MLPs)"],
            W_E_key_side=embeddings_dict["W_E (only MLPs)"],  # 
        )
        results.append(softmaxed_attn.diag().mean())

    results = sum(results) / len(results)

    scores[layer, head] = results

imshow(scores, width=750, labels={"x": "Head", "y": "Layer"}, title="Prediction-attn scores for bag of words (only MLPs in embedding)")

  0%|          | 0/144 [00:00<?, ?it/s]

KeyError: 'W_U (or W_E, no MLPs)'

In [None]:
print("Do a thing where we make the softmax denominator the same???")

Do a thing where we make the softmax denominator the same???
