### Load everything

In [1]:
# Enable IPython's autoreload extension in mode 2 so edits to imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [2]:
# config

# Attribute name handed to `load_dataset` to select the dataset.
attribute = "gender"
# Prompt indices left out of the per-prompt averages; also forwarded to `get_contributions`.
skipped_prompts = [2, 49]
# Layers 0..probe_layer (inclusive) are used to group heads; the value is also forwarded to the plotting helpers.
probe_layer = 26

In [3]:
import pickle

# NOTE: unpickling can execute arbitrary code; only load result files that were
# produced by this project.
RESULTS_DIR = "notebooks/results"


def _load_result(filename):
    """Unpickle ``RESULTS_DIR/filename`` and close the file handle afterwards."""
    with open(f"{RESULTS_DIR}/{filename}", "rb") as handle:
        return pickle.load(handle)


all_clean_attn_attrs = _load_result("all_clean_attn_attrs_thresholded.pkl")

all_attn_AP_scores = _load_result("all_attn_AP_scores.pkl")
all_labelled_tokens = _load_result("labelled_tokens.pkl")

all_head_path_attrs = _load_result("head_path_attrs.pkl")
all_head_out_attrs = _load_result("head_out_attrs.pkl")
head_labels = _load_result("head_labels.pkl")

# Label lists later bound into `plot_top_head_path_attrs`.
start_labels = head_labels["start_labels"]
end_labels = head_labels["end_labels"]

In [4]:
from utils.probes import load_dataset
# Texts and labels for the configured attribute; both are later forwarded to `get_contributions`.
texts, labels = load_dataset(attribute)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
import torch
import einops

def get_sum_head_vector_attrs(all_attn_AP_scores, layers: list[int], pos_slice: slice | None = None):
    all_sum_head_vector_attrs = {}
    for attn_score in all_attn_AP_scores:
        for key, value in attn_score.items():
            n_heads = 8 if key in ["k", "v"] else 16
            if key not in all_sum_head_vector_attrs:
                all_sum_head_vector_attrs[key] = []
            sum_head_vector_attr = einops.reduce(
                value if pos_slice is None else value[pos_slice],
                "(layer head) pos -> layer head",
                "sum",
                layer=len(layers),
                head=n_heads,
            )
            all_sum_head_vector_attrs[key].append(sum_head_vector_attr)

    for key, value in all_sum_head_vector_attrs.items():
        all_sum_head_vector_attrs[key] = torch.stack(value, dim=0)
    return all_sum_head_vector_attrs

In [6]:
from typing import Literal
from tqdm import tqdm
from utils.index import Ix

def make_mean_head_attrs(all_head_path_attrs, all_head_out_attrs, half: Literal["first", "second", "both"] = "both"):
    if half == "first":
        pos_slice = Ix[:, :, :-9].as_index
    elif half == "second":
        pos_slice = Ix[:, :, -10:].as_index
    else:
        pos_slice = Ix[:, :, :].as_index
    mean_head_path_attrs = []
    for i, head_path_attr in tqdm(enumerate(all_head_path_attrs)):
        if i in skipped_prompts:
            continue
        mean_head_path_attrs.append(head_path_attr.to_dense()[pos_slice].sum(-1))

    mean_head_path_attrs = torch.stack(mean_head_path_attrs, dim=0)
    mean_head_path_attrs = mean_head_path_attrs.mean(dim=0)
    mean_head_out_attrs = []
    for i, head_out_attr in tqdm(enumerate(all_head_out_attrs)):
        if i in skipped_prompts:
            continue
        mean_head_out_attrs.append(head_out_attr.to_dense()[pos_slice[1:]].sum(-1))

    mean_head_out_attrs = torch.stack(mean_head_out_attrs, dim=0)
    mean_head_out_attrs = mean_head_out_attrs.mean(dim=0)

    return mean_head_path_attrs, mean_head_out_attrs

# Averaged attributions for each position window, computed in the order: first, second, both.
(
    (head_path_attrs_first_half, head_out_attrs_first_half),
    (head_path_attrs_second_half, head_out_attrs_second_half),
    (head_path_attrs, head_out_attrs),
) = (
    make_mean_head_attrs(all_head_path_attrs, all_head_out_attrs, which_half)
    for which_half in ("first", "second", "both")
)

50it [00:03, 16.35it/s]
50it [00:00, 14463.12it/s]
50it [00:02, 20.46it/s]
50it [00:00, 10711.23it/s]
50it [00:02, 19.61it/s]
50it [00:00, 11671.59it/s]


In [7]:
from utils.index import Ix

# `get_sum_head_vector_attrs` is defined once, in the earlier cell that imports
# torch and einops; the identical copy that was repeated here only shadowed it.
LAYERS = list(range(probe_layer + 1))

# Scores summed over every position.
all_sum_head_vector_attrs = get_sum_head_vector_attrs(all_attn_AP_scores, LAYERS)

# Same reduction restricted to a window on the second (position) axis.
# NOTE(review): `:-9` and `-10:` share one position (index -10) — confirm the overlap is intended.
sum_head_vectors_first_half = get_sum_head_vector_attrs(all_attn_AP_scores, LAYERS, pos_slice=Ix[:, :-9].as_index)
sum_head_vectors_second_half = get_sum_head_vector_attrs(all_attn_AP_scores, LAYERS, pos_slice=Ix[:, -10:].as_index)

### Plot head attributions

In [12]:
from utils.plots import (
    plot_top_head_path_attrs,
    plot_qkv_head_path_attrs,
    plot_interesting_heads,
    get_contributions,
)
from functools import partial

# Pre-bind the arguments that stay fixed for this notebook so later cells only
# pass what varies. Each wrapper deliberately rebinds the imported name.
plot_top_head_path_attrs = partial(
    plot_top_head_path_attrs, start_labels=start_labels, end_labels=end_labels
)
plot_interesting_heads = partial(plot_interesting_heads, probe_layer=probe_layer)
# `get_contributions` additionally needs the dataset and the per-prompt token labels.
get_contributions = partial(
    get_contributions,
    probe_layer=probe_layer,
    skipped_prompts=skipped_prompts,
    all_labelled_tokens=all_labelled_tokens,
    texts=texts,
    labels=labels,
)

In [18]:
# Plot the top path attributions for the first position window and keep the
# selected indices/labels that the call returns.
top_end_indices, top_start_indices, top_end_labels, top_start_labels = plot_top_head_path_attrs(
    head_path_attrs_first_half,
    head_out_attrs_first_half,
)

# Bind that selection into the q/k/v plot so later calls only pass the attributions.
plot_qkv_head_path_attrs = partial(
    plot_qkv_head_path_attrs,
    top_end_indices=top_end_indices,
    top_start_indices=top_start_indices,
    top_end_labels=top_end_labels,
    top_start_labels=top_start_labels,
)

In [19]:
# Uses the top indices/labels bound into `plot_qkv_head_path_attrs` by the previous cell.
plot_qkv_head_path_attrs(head_path_attrs_first_half)

In [21]:
# Hand-picked heads to plot individually.
# NOTE(review): entries look like (layer, head) pairs — confirm against `plot_interesting_heads`.
interesting_heads = [
    (26, 9),
    (25, 14),
    (22, 9),
    (25, 8)
]


# Plotted from the attributions computed over the full position range ("both").
plot_interesting_heads(head_path_attrs, interesting_heads)


In [22]:
# Head and layer passed to the contribution breakdown below.
head = 1
layer = 0
# The remaining arguments (probe layer, skipped prompts, tokens, dataset) were bound via `partial` above.
get_contributions(all_attn_AP_scores, layer, head)

In [None]:
from utils.neel_plotly import imshow
def plot_attention_pattern(attn_pattern, layer, head, max_prompt_index=5):
    """Plot one head's dense scores for the first few prompts.

    ``attn_pattern`` is walked in step with the notebook-level
    ``all_labelled_tokens``: dataset indices listed in ``skipped_prompts`` are
    stepped over, so the k-th entry of ``attn_pattern`` is paired with the
    k-th non-skipped prompt.
    NOTE(review): this assumes ``attn_pattern`` holds no entries for skipped
    prompts — confirm against how the patterns were saved.

    Args:
        attn_pattern: iterable of per-prompt objects; ``pattern[layer, head]``
            must return something exposing ``to_dense()``.
        layer: layer index used to select the head.
        head: head index used together with ``layer``.
        max_prompt_index: only prompts whose dataset index is below this value
            are plotted.

    Each plot is followed by a print of the token labels used on its axes.
    """
    prompt_idx = 0
    for pattern in attn_pattern:
        # Step over skipped prompts before reading labels, so a skipped index
        # is never paired with a pattern (including index 0).
        while prompt_idx in skipped_prompts:
            prompt_idx += 1
        if prompt_idx >= max_prompt_index:
            break

        labeled_tokens = all_labelled_tokens[prompt_idx]
        # convert to dense tensor
        pattern_dense = pattern[layer, head].to_dense()
        imshow(pattern_dense, x=labeled_tokens, y=labeled_tokens, title=f"Layer {layer} Head {head} Attribution Scores")
        # Print the labels of the prompt that was just plotted, not the next one.
        print(labeled_tokens)

        prompt_idx += 1


# head = 1
# layer = 25
# plot_attention_pattern(all_clean_attn_attrs, layer, head)