### Load precomputed attribution results and the dataset

In [2]:
# Automatically reload edited imported modules (e.g. utils.*) before each cell runs.
%load_ext autoreload
%autoreload 2


In [3]:
# config

# Dataset attribute passed to load_dataset below.
attribute: str = "gender"
# Prompt indices left out when averaging head attributions (make_mean_head_attrs)
# and passed to the contribution-plotting helpers.
skipped_prompts: list[int] = [2, 49]
# Attribution scores cover layers 0..probe_layer (see LAYERS further down).
probe_layer: int = 26

In [4]:
import pickle
from pathlib import Path

# Directory with the precomputed attribution results, relative to the kernel's
# working directory.
RESULTS_DIR = Path("notebooks/results")


def _load_pickle(name: str):
    """Unpickle ``RESULTS_DIR / name`` and close the file handle afterwards.

    NOTE: unpickling can execute arbitrary code — only load result files that
    this project produced itself.
    """
    with open(RESULTS_DIR / name, "rb") as f:
        return pickle.load(f)


all_clean_attn_attrs = _load_pickle("all_clean_attn_attrs_thresholded.pkl")

all_attn_AP_scores = _load_pickle("all_attn_AP_scores.pkl")
all_labelled_tokens = _load_pickle("labelled_tokens.pkl")

all_head_path_attrs = _load_pickle("head_path_attrs.pkl")
all_head_out_attrs = _load_pickle("head_out_attrs.pkl")
head_labels = _load_pickle("head_labels.pkl")

start_labels = head_labels["start_labels"]
end_labels = head_labels["end_labels"]

In [5]:
from utils.probes import load_dataset
# Texts and labels for the configured `attribute`; both are bound into the
# contribution-plotting helpers further down.
texts, labels = load_dataset(attribute)

  from .autonotebook import tqdm as notebook_tqdm


In [6]:
import torch
import einops

def get_sum_head_vector_attrs(
    all_attn_AP_scores,
    layers: list[int],
    pos_slice=None,
    n_kv_heads: int = 8,
    n_q_heads: int = 16,
):
    """Sum each prompt's attribution scores over positions into a (layer, head) grid.

    Args:
        all_attn_AP_scores: one dict per prompt, mapping a component key
            (e.g. "q", "k", "v") to a 2-D tensor whose rows are the flattened
            ``(layer, head)`` pairs (layer-major) and whose columns are positions.
        layers: the layers covered by the scores; only ``len(layers)`` is used.
        pos_slice: optional index applied to each score tensor before summing
            (callers pass e.g. ``Ix[:, :-9].as_index``); ``None`` keeps every position.
        n_kv_heads: number of heads for the "k" and "v" components.
        n_q_heads: number of heads for every other component.

    Returns:
        dict mapping each key to a tensor of shape ``(n_prompts, len(layers), n_heads)``.
        Empty input gives an empty dict.

    Raises:
        ValueError: if a (sliced) score tensor is not 2-D.
        RuntimeError: if its row count is not ``len(layers) * n_heads``.
    """
    n_layers = len(layers)
    per_key: dict[str, list[torch.Tensor]] = {}
    for attn_score in all_attn_AP_scores:
        for key, value in attn_score.items():
            n_heads = n_kv_heads if key in ("k", "v") else n_q_heads
            scores = value if pos_slice is None else value[pos_slice]
            if scores.ndim != 2:
                raise ValueError(
                    f"expected 2-D (layer*head, pos) scores for key {key!r}, "
                    f"got shape {tuple(scores.shape)}"
                )
            # Same as einops.reduce(scores, "(layer head) pos -> layer head", "sum");
            # the explicit (layer, head, pos) shape fails loudly unless the row
            # count equals n_layers * n_heads.
            summed = scores.reshape(n_layers, n_heads, scores.shape[-1]).sum(dim=-1)
            per_key.setdefault(key, []).append(summed)
    return {key: torch.stack(tensors, dim=0) for key, tensors in per_key.items()}

In [7]:
from typing import Literal
from tqdm import tqdm
from utils.index import Ix

def make_mean_head_attrs(
    all_head_path_attrs,
    all_head_out_attrs,
    half: Literal["first", "second", "both"] = "both",
    skipped=None,
):
    """Average head-path and head-output attributions over the non-skipped prompts.

    For every kept prompt, the attribution is densified (``.to_dense()``),
    indexed to the positions selected by ``half`` and summed over its last
    dimension; the per-prompt results are then averaged.

    Args:
        all_head_path_attrs: one (possibly sparse) tensor per prompt, indexed with
            the three-axis ``Ix[:, :, <positions>]`` index.
        all_head_out_attrs: one (possibly sparse) tensor per prompt, indexed with
            the same index minus its first axis.
        half: ``"first"`` keeps all but the last 9 positions, ``"second"`` keeps
            the last 10 positions, ``"both"`` keeps every position.
        skipped: prompt indices to leave out; defaults to the notebook-level
            ``skipped_prompts`` config.

    Returns:
        ``(mean_head_path_attrs, mean_head_out_attrs)``.

    Raises:
        ValueError: if ``half`` is not an accepted value, or if no prompt remains
            after skipping.
    """
    if half == "first":
        pos_slice = Ix[:, :, :-9].as_index
    elif half == "second":
        pos_slice = Ix[:, :, -10:].as_index
    elif half == "both":
        pos_slice = Ix[:, :, :].as_index
    else:
        raise ValueError(f"half must be 'first', 'second' or 'both', got {half!r}")
    # NOTE(review): "first" drops the last 9 positions while "second" keeps the
    # last 10, so position -10 is counted in both halves — confirm this is intended.

    skipped_set = set(skipped_prompts if skipped is None else skipped)

    def _mean_over_prompts(attrs, index):
        """Densify, index and last-dim-sum each kept prompt, then average over prompts."""
        per_prompt = [
            attr.to_dense()[index].sum(-1)
            for i, attr in enumerate(tqdm(attrs))
            if i not in skipped_set
        ]
        if not per_prompt:
            raise ValueError("no prompts left to average after skipping")
        return torch.stack(per_prompt, dim=0).mean(dim=0)

    mean_head_path_attrs = _mean_over_prompts(all_head_path_attrs, pos_slice)
    # Head-output attributions are indexed without the leading axis of the
    # head-path index (presumably one fewer leading dimension — TODO confirm).
    mean_head_out_attrs = _mean_over_prompts(all_head_out_attrs, pos_slice[1:])

    return mean_head_path_attrs, mean_head_out_attrs

# Prompt-averaged attributions for the "first" and "second" position ranges,
# and for all positions.
head_path_attrs_first_half, head_out_attrs_first_half = make_mean_head_attrs(
    all_head_path_attrs, all_head_out_attrs, half="first"
)
head_path_attrs_second_half, head_out_attrs_second_half = make_mean_head_attrs(
    all_head_path_attrs, all_head_out_attrs, half="second"
)
head_path_attrs, head_out_attrs = make_mean_head_attrs(
    all_head_path_attrs, all_head_out_attrs, half="both"
)

50it [00:02, 17.95it/s]
50it [00:00, 2204.67it/s]
50it [00:02, 18.53it/s]
50it [00:00, 10024.63it/s]
50it [00:02, 17.34it/s]
50it [00:00, 19676.79it/s]


In [8]:
from utils.index import Ix

# Uses get_sum_head_vector_attrs defined in the earlier cell; it only needs the
# number of layers, i.e. layers 0..probe_layer.
LAYERS = list(range(probe_layer + 1))
# Summed attributions over all positions, over all but the last 9 positions
# ("first half"), and over the last 10 positions ("second half").
all_sum_head_vector_attrs = get_sum_head_vector_attrs(all_attn_AP_scores, LAYERS)
sum_head_vectors_first_half = get_sum_head_vector_attrs(all_attn_AP_scores, LAYERS, pos_slice=Ix[:, :-9].as_index)
sum_head_vectors_second_half = get_sum_head_vector_attrs(all_attn_AP_scores, LAYERS, pos_slice=Ix[:, -10:].as_index)

### Plot attribution results

In [9]:
from utils.plots import (
    plot_top_head_path_attrs,
    plot_qkv_head_path_attrs,
    plot_interesting_heads,
    get_contributions,
    get_contributions_for_each_layer
)
from utils.neel_plotly import imshow
from functools import partial

# Pre-bind notebook-level data so later cells only pass the attributions to plot.
plot_top_head_path_attrs = partial(plot_top_head_path_attrs, start_labels=start_labels, end_labels=end_labels)
plot_interesting_heads = partial(plot_interesting_heads, probe_layer=probe_layer)
get_contributions = partial(
    get_contributions,
    probe_layer=probe_layer, skipped_prompts=skipped_prompts,
    all_labelled_tokens=all_labelled_tokens,
    texts=texts, labels=labels,
)


In [10]:
from utils.neel_plotly import imshow

# Mean over prompts of the summed key-head attributions for each position range.
# Rows are layers; columns are the key heads within each layer.
for half_name, half_vectors in (
    ("First", sum_head_vectors_first_half),
    ("Second", sum_head_vectors_second_half),
):
    imshow(
        half_vectors["k"].mean(dim=0),
        xaxis="Head",
        yaxis="Layer",
        title=f"Mean Attribution of Key Heads ({half_name} Half)",
    )

In [11]:
# Plot the top first-half head-path attributions, then bind the selected
# indices/labels into the Q/K/V path plotter used in the next cell.
top_end_indices, top_start_indices, top_end_labels, top_start_labels = plot_top_head_path_attrs(
    head_path_attrs_first_half, head_out_attrs_first_half
)
plot_qkv_head_path_attrs = partial(
    plot_qkv_head_path_attrs,
    top_end_indices=top_end_indices,
    top_start_indices=top_start_indices,
    top_end_labels=top_end_labels,
    top_start_labels=top_start_labels,
)

In [12]:
# Plot first-half head-path attributions using the top heads/labels bound in the previous cell.
plot_qkv_head_path_attrs(head_path_attrs_first_half)

In [13]:
# (layer, head) pairs whose first-half head-path attributions are plotted below.
# Earlier candidate set: (26, 9), (25, 14), (22, 9), (25, 8).
# Every entry, commented or not, ends with a comma so that toggling a line can
# never fuse two tuples into a call such as `(25, 6)(26, 2)`.
interesting_heads = [
    (8, 9),
    # (10, 9),
    # (19, 3),
    # (25, 6),
    (26, 2),
]

plot_interesting_heads(head_path_attrs_first_half, interesting_heads)


In [14]:
# Head to inspect, as (layer, head); the same pair as the first interesting head.
layer, head = 8, 9
# get_contributions(all_attn_AP_scores, layer, head, max_prompts_to_show=15)

In [15]:
# Sanity-check the score layout: one row per (layer, key head) pair, one column per position.
all_attn_AP_scores[3]["k"].shape

torch.Size([216, 333])

In [54]:
from utils.plots import get_contributions_a_prompt
# Layers whose query ("q") contributions are inspected for a single prompt.
layers_to_plot = [5, 8, 10, 15, 20, 25]
# NOTE(review): score index 3 is paired with labelled-token index 4 — presumably
# because all_attn_AP_scores omits skipped prompt 2 while all_labelled_tokens does not; confirm.
contribs, labelled_tokens = get_contributions_a_prompt(all_attn_AP_scores[3], probe_layer, all_labelled_tokens[4], key="q", layers=layers_to_plot)

In [51]:
def _top_k_by_magnitude(values, tokens, k):
    """Return the ``k`` largest-|value| entries of a 1-D tensor and their token texts.

    ``tokens[i]`` labels ``values[i]``; only the text before the first "/" is kept.
    Returned values keep their sign. If there are fewer than ``k`` values, all of
    them are returned instead of raising IndexError.
    """
    k = min(k, values.numel())
    top_indices = values.abs().topk(k).indices
    top_tokens = [tokens[idx].split("/")[0] for idx in top_indices.tolist()]
    return values[top_indices], top_tokens


labelled_tokens = all_labelled_tokens[4]
k = 10
top_k_per_layer = []
top_k_labelled_tokens_per_layer = []
for contribs_per_layer in contribs:
    top_values, top_tokens = _top_k_by_magnitude(contribs_per_layer, labelled_tokens, k)
    top_k_per_layer.append(top_values)
    top_k_labelled_tokens_per_layer.append(top_tokens)

In [52]:
# count occurrences of each token
# from collections import Counter
# counts = []
# for layer in range(len(top_k_labelled_tokens_per_layer)):
#     counts.append(dict(Counter(top_k_labelled_tokens_per_layer[layer])))


In [56]:
def plot_mean_token_attribution_by_layer(
    top_k_labelled_tokens_per_layer,
    top_k_per_layer,
    layers: list[int],
    missing_value: float = 0.0,
):
    """Heatmap of each token text's mean top-k attribution, one column per layer.

    Args:
        top_k_labelled_tokens_per_layer: for each layer, the token texts of its top-k entries.
        top_k_per_layer: for each layer, the matching attribution values
            (``.item()`` is called on each).
        layers: layer numbers used as column labels, one per entry of the lists above.
        missing_value: cell value for a token absent from a layer's top-k
            (default 0.0; pass ``float("nan")`` to leave such cells blank).

    Raises:
        ValueError: if the three per-layer sequences differ in length.
    """
    import numpy as np

    n_layers = len(top_k_labelled_tokens_per_layer)
    if len(top_k_per_layer) != n_layers or len(layers) != n_layers:
        raise ValueError(
            f"length mismatch: {n_layers} token lists, {len(top_k_per_layer)} value lists, "
            f"{len(layers)} layer labels"
        )

    # Rows in first-appearance order: iterating a set of str depends on the
    # per-process hash seed, which reshuffled the rows on every kernel restart.
    tokens = list(dict.fromkeys(
        tok for layer_tokens in top_k_labelled_tokens_per_layer for tok in layer_tokens
    ))
    row_of = {tok: row for row, tok in enumerate(tokens)}

    sums = np.zeros((len(tokens), n_layers))
    counts = np.zeros((len(tokens), n_layers), dtype=int)
    for col, (layer_tokens, layer_values) in enumerate(
        zip(top_k_labelled_tokens_per_layer, top_k_per_layer)
    ):
        for tok, value in zip(layer_tokens, layer_values):
            sums[row_of[tok], col] += value.item()
            counts[row_of[tok], col] += 1

    # Mean where the token occurs in the layer's top-k, missing_value elsewhere.
    matrix = np.full_like(sums, missing_value)
    np.divide(sums, counts, out=matrix, where=counts > 0)

    imshow(
        matrix,
        x=[f"Layer {layer}" for layer in layers],
        y=tokens,
        title="Mean Token Attribution by Layer",
    )

# Heatmap of the per-layer top-k token attributions computed above.
plot_mean_token_attribution_by_layer(
    top_k_labelled_tokens_per_layer=top_k_labelled_tokens_per_layer,
    top_k_per_layer=top_k_per_layer,
    layers=layers_to_plot,
)

In [None]:
# Bind notebook-level data and the "k" component into the per-layer contribution
# plotter, then plot up to 8 prompts for the selected layers.
get_contributions_for_each_layer = partial(
    get_contributions_for_each_layer,
    probe_layer=probe_layer, skipped_prompts=skipped_prompts,
    all_labelled_tokens=all_labelled_tokens,
    texts=texts, labels=labels,
    key="k",
)
get_contributions_for_each_layer(
    all_attn_AP_scores,
    max_prompts_to_show=8,
    layers=[3, 5, 8, 15, 25],
)

In [15]:
def plot_attention_pattern(attn_pattern, layer, head, n_prompts: int = 5, labelled_tokens=None, skipped=None):
    """Plot one head's per-prompt attribution matrix with token labels on both axes.

    ``attn_pattern[j]`` is assumed to belong to the j-th prompt that is *not* in
    ``skipped`` (TODO confirm against how the pickled results were built). Each
    entry must support ``entry[layer, head].to_dense()``.

    Args:
        attn_pattern: per-prompt attribution tensors (e.g. ``all_clean_attn_attrs``).
        layer: layer of the head to plot.
        head: head index within ``layer``.
        n_prompts: plot only prompts whose index is below this value (skipped ones excluded).
        labelled_tokens: per-prompt token labels; defaults to ``all_labelled_tokens``.
        skipped: prompt indices absent from ``attn_pattern``; defaults to ``skipped_prompts``.
    """
    if labelled_tokens is None:
        labelled_tokens = all_labelled_tokens
    skipped_set = set(skipped_prompts if skipped is None else skipped)

    # Pair each pattern with its own prompt index, checking every index (including 0)
    # against the skip list.
    kept_prompts = (i for i in range(n_prompts) if i not in skipped_set)
    for prompt_idx, pattern in zip(kept_prompts, attn_pattern):
        tokens = labelled_tokens[prompt_idx]
        pattern_dense = pattern[layer, head].to_dense()
        imshow(
            pattern_dense,
            x=tokens,
            y=tokens,
            title=f"Layer {layer} Head {head} Attribution Scores (prompt {prompt_idx})",
        )


# head = 1
# layer = 25
# plot_attention_pattern(all_clean_attn_attrs, layer, head)