In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import itertools
import json
import os
import sys
import typing
from collections import defaultdict

import numpy as np
import pandas as pd
import seaborn as sns
import tabulate
import torch
from IPython.display import Markdown, display
from loguru import logger
from tqdm.auto import tqdm

# Disable autograd globally: this notebook only loads and analyzes precomputed results.
torch.set_grad_enabled(False)

# NOTE(review): star import presumably provides the constants (ORDERED_MODELS, ICL, BOTH, ALL, SHORT, LONG, ...)
# and helpers (compute_top_heads, load_and_combine_raw_results, ...) used below; explicit imports would make
# their origin clearer.
from shared_definitions import *

# NOTE(review): this runs *after* `from shared_definitions import *`, so it cannot help that import; it only
# affects imports made later. Move it above the import if shared_definitions needs the parent directory.
sys.path.insert(0, os.path.abspath(".."))

sns.set_theme(style="white", context="notebook", rc={"figure.figsize": (14, 10)})


In [3]:
USE_SAME_TEST_SETS_RESULTS = True

# When enabled, ICL results are read from ICL_SAME_TEST_SETS_RESULTS_ROOT instead of the default location
# (presumably the ICL runs evaluated on the same test sets as the prompt runs — see shared_definitions).
alternative_paths = {ICL: ICL_SAME_TEST_SETS_RESULTS_ROOT} if USE_SAME_TEST_SETS_RESULTS else None

(
    result_df,
    indirect_effects_by_model_and_dataset,
    top_heads_by_model_and_dataset,
) = load_and_combine_raw_results(alternative_paths)

result_df.head()

[32m2025-05-13 23:21:11.572[0m | [1mINFO    [0m | [36mshared_definitions[0m:[36mload_and_combine_raw_results[0m:[36m416[0m - [1mLoading cached results from data/full_results.pkl.gz[0m


Unnamed: 0,model,dataset,top_n_prompt_acc,top_1_prompt_acc,top_n_prompts,0_shot_acc,1_shot_acc,2_shot_acc,3_shot_acc,4_shot_acc,...,zs_universal_10_heads_max_acc_layer,zs_universal_20_heads_by_layer_acc,zs_universal_20_heads_max_acc,zs_universal_20_heads_max_acc_layer,fs_shuffled_universal_40_heads_by_layer_acc,fs_shuffled_universal_40_heads_max_acc,fs_shuffled_universal_40_heads_max_acc_layer,zs_universal_40_heads_by_layer_acc,zs_universal_40_heads_max_acc,zs_universal_40_heads_max_acc_layer
0,Llama-3.2-1B-Instruct,choose_first_of_3,0.767429,0.888571,"[Extract the first token, Output the word at t...",0.636667,0.863333,0.896667,0.896667,0.913333,...,,,,,,,,,,
1,Llama-3.2-1B-Instruct,capitalize_first_letter,0.936907,0.963093,"[What is the initial letter of this word?, Det...",0.0,0.688525,0.852459,0.918033,0.95082,...,,,,,,,,,,
2,Llama-3.2-1B-Instruct,capitalize_second_letter,0.110364,0.16,"[What's the special letter in this word?, Dete...",,,,,,...,,,,,,,,,,
3,Llama-3.2-1B-Instruct,english-french,0.6632,0.678224,"[English word in French, English word to Frenc...",0.011348,0.352482,0.587234,0.662411,0.687943,...,,,,,,,,,,
4,Llama-3.2-1B-Instruct,choose_first_of_5,0.737286,0.772857,"[Find the starting word of the input sequence,...",0.616667,0.81,0.87,0.89,0.94,...,,,,,,,,,,


In [4]:
from functools import reduce

N_TOP_HEADS_VALUES = (10, 20, 100)

# Prompt-type options: each single prompt type, plus a list meaning "both combined" (recorded as BOTH).
PROMPT_TYPE_OPTIONS = [SHORT, LONG, [SHORT, LONG]]
# Baseline options: each relevant baseline, plus the full list meaning "all combined" (recorded as ALL).
BASELINE_OPTIONS = RELEVANT_BASELINES + [RELEVANT_BASELINES]

# One row per (model, prompt type, baseline, n) configuration; turned into a DataFrame further below.
top_head_rows = []
# Mean indirect effects keyed by (model, prompt type, baseline). The key does not include n, so the entry is
# rewritten for every n_top_heads value — presumably with the same value (TODO confirm in compute_top_heads).
mean_ie_by_model = {}

# Derive the progress-bar total from the option lists instead of hard-coding their lengths.
tqdm_total = len(ORDERED_MODELS) * len(PROMPT_TYPE_OPTIONS) * len(BASELINE_OPTIONS) * len(N_TOP_HEADS_VALUES)
# Handed to compute_top_heads, which presumably records (per model) the datasets it excluded for
# below-chance accuracy; reported once after the loop.
prompt_datasets_below_chance_acc = defaultdict(set)

for model, prompt_types, baseline, n_top_heads in tqdm(
    itertools.product(ORDERED_MODELS, PROMPT_TYPE_OPTIONS, BASELINE_OPTIONS, N_TOP_HEADS_VALUES),
    total=tqdm_total,
    desc="Universal top heads",
):
    top_heads, top_head_effects, mean_ie = compute_top_heads(
        result_df,
        indirect_effects_by_model_and_dataset,
        model,
        prompt_types,
        baseline,
        n_top_heads=n_top_heads,
        datasets_below_chance_acc=prompt_datasets_below_chance_acc,
        return_mean=True,
    )
    # A list (rather than a single string) means the options were combined; record the aggregate label.
    pt = prompt_types if isinstance(prompt_types, str) else BOTH
    bl = baseline if isinstance(baseline, str) else ALL
    top_head_rows.append(
        dict(
            model=model,
            prompt_type=pt,
            baseline=bl,
            n=n_top_heads,
            # Hashable set of head tuples (presumably (layer, head)) for overlap computations;
            # the original list and its effects are kept alongside.
            top_heads=set(tuple(t) for t in top_heads),
            top_head_effects=top_head_effects,
            top_heads_list=top_heads,
        )
    )

    mean_ie_by_model[(model, pt, bl)] = mean_ie

for model, datasets in prompt_datasets_below_chance_acc.items():
    if len(datasets) > 0:
        logger.warning(f"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}")


# Datasets excluded per model for below-chance ICL accuracy (presumably filled in by compute_top_heads).
icl_datasets_below_chance_acc = defaultdict(set)

for model, n_top_heads in tqdm(
    itertools.product(ORDERED_MODELS, N_TOP_HEADS_VALUES),
    total=len(ORDERED_MODELS) * len(N_TOP_HEADS_VALUES),
    desc="Universal top heads ICL",
):
    # ICL has a single prompt type and is compared against its own ICL baseline.
    prompt_type = baseline = ICL
    top_heads, top_head_effects, mean_ie = compute_top_heads(
        result_df,
        indirect_effects_by_model_and_dataset,
        model,
        prompt_type,
        baseline,
        n_top_heads=n_top_heads,
        datasets_below_chance_acc=icl_datasets_below_chance_acc,
        return_mean=True,
    )
    mean_ie_by_model[(model, prompt_type, baseline)] = mean_ie
    top_head_rows.append(
        {
            "model": model,
            "prompt_type": prompt_type,
            "baseline": baseline,
            "n": n_top_heads,
            "top_heads": {tuple(head) for head in top_heads},
            "top_head_effects": top_head_effects,
            "top_heads_list": top_heads,
        }
    )

for model, datasets in icl_datasets_below_chance_acc.items():
    if datasets:
        logger.warning(f"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}")


universal_top_heads_df = pd.DataFrame(top_head_rows)

# Independent copy of the rows for each (model, n) pair, so split-specific columns can be added safely.
split_universal_top_heads_dfs = {
    (model, n): universal_top_heads_df[(universal_top_heads_df.model == model) & (universal_top_heads_df.n == n)]
    .copy(deep=True)
    .reset_index(drop=True)
    for model in ORDERED_MODELS
    for n in N_TOP_HEADS_VALUES
}


def _top_heads_to_binary_vector(top_heads, head_to_index):
    """Return a uint8 indicator vector with a 1 at ``head_to_index[head]`` for every head in ``top_heads``.

    Args:
        top_heads: set of hashable head identifiers; every one must be a key of ``head_to_index``.
        head_to_index: mapping from head identifier to its coordinate in the vector.

    Returns:
        ``np.ndarray`` of dtype ``np.ubyte`` and length ``len(head_to_index)``.
    """
    vector = np.zeros(len(head_to_index), dtype=np.ubyte)
    if top_heads:
        vector[[head_to_index[head] for head in top_heads]] = 1
    return vector


# Add a binary membership vector per row. The index covers every head that appears in *any* row of the
# split, so vectors within the same (model, n) split share coordinates and are directly comparable.
# (Iterating over values avoids rebinding the notebook-level `model` to a (model, n) tuple.)
for split_df in split_universal_top_heads_dfs.values():
    split_heads = sorted(set().union(*split_df.top_heads))
    head_to_index = {head: i for i, head in enumerate(split_heads)}
    split_df["top_heads_vector"] = split_df.top_heads.apply(
        lambda heads: _top_heads_to_binary_vector(heads, head_to_index)
    )

universal_top_heads_df

Universal top heads:   0%|          | 0/504 [00:00<?, ?it/s]



Universal top heads ICL:   0%|          | 0/42 [00:00<?, ?it/s]



Unnamed: 0,model,prompt_type,baseline,n,top_heads,top_head_effects,top_heads_list
0,Llama-3.2-1B,short,equiprobable,10,"{(15, 11), (12, 15), (8, 29), (11, 4), (7, 2),...","[0.011242678388953209, 0.008070942014455795, 0...","[[12, 15], [15, 12], [8, 12], [15, 2], [10, 22..."
1,Llama-3.2-1B,short,equiprobable,20,"{(14, 4), (7, 26), (14, 25), (9, 14), (8, 12),...","[0.011242678388953209, 0.008070942014455795, 0...","[[12, 15], [15, 12], [8, 12], [15, 2], [10, 22..."
2,Llama-3.2-1B,short,equiprobable,100,"{(7, 17), (7, 26), (5, 1), (8, 9), (8, 18), (1...","[0.011242678388953209, 0.008070942014455795, 0...","[[12, 15], [15, 12], [8, 12], [15, 2], [10, 22..."
3,Llama-3.2-1B,short,real_text,10,"{(15, 11), (9, 13), (8, 29), (8, 12), (11, 4),...","[0.01034275908023119, 0.007226495537906885, 0....","[[12, 15], [15, 12], [10, 22], [15, 2], [9, 7]..."
4,Llama-3.2-1B,short,real_text,20,"{(14, 4), (14, 25), (8, 12), (10, 9), (15, 2),...","[0.01034275908023119, 0.007226495537906885, 0....","[[12, 15], [15, 12], [10, 22], [15, 2], [9, 7]..."
...,...,...,...,...,...,...,...
541,OLMo-2-1124-7B-DPO,icl,icl,20,"{(9, 2), (12, 25), (12, 31), (19, 18), (17, 18...","[0.10074028372764587, 0.09945445507764816, 0.0...","[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1..."
542,OLMo-2-1124-7B-DPO,icl,icl,100,"{(16, 20), (15, 30), (16, 29), (22, 17), (31, ...","[0.10074028372764587, 0.09945445507764816, 0.0...","[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1..."
543,OLMo-2-1124-7B-Instruct,icl,icl,10,"{(15, 11), (12, 24), (20, 1), (15, 0), (10, 12...","[0.1116589605808258, 0.09931758791208267, 0.06...","[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1..."
544,OLMo-2-1124-7B-Instruct,icl,icl,20,"{(12, 31), (19, 18), (17, 18), (10, 12), (15, ...","[0.1116589605808258, 0.09931758791208267, 0.06...","[[15, 0], [15, 11], [17, 18], [15, 18], [10, 1..."


In [None]:
# Destination directory for the exported head-set JSON files (STORAGE_ROOT presumably comes from the
# shared_definitions star import).
TOP_HEADS_DIR = f"{STORAGE_ROOT}/function_vectors/full_results_top_heads"


def write_top_heads_to_files(
    df: pd.DataFrame,
    prompt_type: str = BOTH,
    baseline: str = ALL,
    n_top_heads: int = 100,
    model: str | None = None,
    output_dir: str = TOP_HEADS_DIR,
    overwrite: bool = False,
):
    """Write one JSON head-set file per model for a single (prompt type, baseline, n) configuration.

    Args:
        df: frame with ``model``, ``prompt_type``, ``baseline``, ``n``, ``top_heads_list`` and
            ``top_head_effects`` columns (e.g. ``universal_top_heads_df``).
        prompt_type: prompt-type label to export; defaults to ``BOTH``, the label rows for combined
            prompt types are stored under.
        baseline: baseline label to export; defaults to ``ALL``. Must be ``ICL`` when ``prompt_type`` is ``ICL``.
        n_top_heads: which top-n configuration to export.
        model: if given, only rows whose model name contains this literal substring are written.
        output_dir: destination directory.
        overwrite: replace existing files instead of skipping them.

    Raises:
        ValueError: if ``prompt_type`` is ``ICL`` but ``baseline`` is not.
    """
    # ICL rows are stored with prompt_type == ICL ("icl"); comparing against the literal "ICL" never matched,
    # so the guard was dead.
    if prompt_type == ICL and baseline != ICL:
        raise ValueError("For ICL, baseline must be 'icl'")

    mask = (df.prompt_type == prompt_type) & (df.baseline == baseline) & (df.n == n_top_heads)
    if model is not None:
        # Literal substring match: model names contain '.', which is a wildcard under the default regex mode.
        mask &= df.model.str.contains(model, regex=False)
    filtered_df = df[mask].reset_index(drop=True)

    for row in filtered_df.itertuples(index=False):
        _save_single_top_head_set(
            row.top_heads_list,
            row.top_head_effects,
            row.model,
            row.prompt_type,
            row.baseline,
            output_dir=output_dir,
            overwrite=overwrite,
        )


def _save_single_top_head_set(
    model_top_heads: typing.List[typing.Tuple[int, int]],
    model_top_head_effects: typing.List[float],
    model: str,
    prompt_type: str,
    baseline: str,
    head_type: str = "top",
    output_dir: str = TOP_HEADS_DIR,
    overwrite: bool = False,
):
    """Save one set of heads and their effects to ``output_dir`` as JSON.

    The file is named ``{model}_{prompt_type}_{baseline}_{head_type}_heads.json``. For the ICL baseline the
    prompt type is omitted, and ``same_test_sets_`` is inserted when ``USE_SAME_TEST_SETS_RESULTS`` is set.
    The JSON object has keys ``top_heads`` and ``top_head_effects``.

    Args:
        model_top_heads: head identifiers to save.
        model_top_head_effects: effect value per head, in the same order.
        model: model name used in the file name.
        prompt_type: prompt-type label used in the file name (ignored for the ICL baseline).
        baseline: baseline label used in the file name.
        head_type: file-name tag describing the head set ("top", "bottom", "min_abs", ...).
        output_dir: destination directory; created if missing.
        overwrite: replace an existing file instead of skipping it with a warning.
    """
    if baseline == ICL:
        same_test_sets_tag = "same_test_sets_" if USE_SAME_TEST_SETS_RESULTS else ""
        filename = f"{model}_{baseline}_{same_test_sets_tag}{head_type}_heads.json"
    else:
        filename = f"{model}_{prompt_type}_{baseline}_{head_type}_heads.json"

    filepath = os.path.join(output_dir, filename)

    if os.path.exists(filepath) and not overwrite:
        logger.warning(f"{filepath} already exists, skipping...")
        return

    # Serialize before opening the file: a non-JSON-serializable input (NOTE(review): e.g. a torch tensor
    # of effects) then fails without leaving a truncated file that later runs would silently skip.
    payload = json.dumps(dict(top_heads=model_top_heads, top_head_effects=model_top_head_effects))

    os.makedirs(output_dir, exist_ok=True)
    with open(filepath, "w") as f:
        f.write(payload)

    logger.info(f"Saved {head_type} heads to {filepath}")


# Export the top-100 heads for the function's default (combined) prompt-type/baseline configuration, then
# the ICL top-100 heads. With overwrite=False, files that already exist are skipped with a warning.
write_top_heads_to_files(universal_top_heads_df, overwrite=False)
write_top_heads_to_files(universal_top_heads_df, prompt_type=ICL, baseline=ICL, overwrite=False)

In [5]:
# Size of the other setting's reference head set, and the top-n sizes whose overlap with it is reported.
SHARED_HEADS_REFERENCE_N = 100
SHARED_HEADS_TOP_NS = (10, 20)


def _universal_top_heads(model, prompt_type, baseline, n):
    """Return the ``top_heads`` set of the first ``universal_top_heads_df`` row matching the configuration.

    Raises:
        IndexError: if no row matches.
    """
    is_match = (
        (universal_top_heads_df.model == model)
        & (universal_top_heads_df.prompt_type == prompt_type)
        & (universal_top_heads_df.baseline == baseline)
        & (universal_top_heads_df.n == n)
    )
    return universal_top_heads_df[is_match].top_heads.values[0]


def _mean_ie_over_heads(mean_ie, heads):
    """Average ``mean_ie[head]`` over ``heads``; NaN for an empty set (instead of np.mean([])'s RuntimeWarning)."""
    if not heads:
        return float("nan")
    return np.mean([mean_ie[head] for head in heads])


def _shared_heads_row(model, head_type, top_n, own_top_n, other_reference, own_ie, other_ie):
    """Summarize the overlap between one setting's top-n heads and the other setting's reference heads.

    ``mean_same_ie`` averages the shared heads' mean indirect effect in their own setting,
    ``mean_other_ie`` in the other setting; ``same_other_ratio`` is their quotient.
    """
    shared = own_top_n & other_reference
    mean_same_ie = _mean_ie_over_heads(own_ie, shared)
    mean_other_ie = _mean_ie_over_heads(other_ie, shared)
    return dict(
        model=model,
        head_type=head_type,
        n_top_n=top_n,
        n_shared=len(shared),
        n_shared_perc=f"{len(shared) / top_n:.0%}",
        mean_same_ie=mean_same_ie,
        mean_other_ie=mean_other_ie,
        same_other_ratio=mean_same_ie / mean_other_ie,
    )


table_rows = []

for model in ORDERED_MODELS:
    prompt_ie = mean_ie_by_model[(model, BOTH, ALL)]
    icl_ie = mean_ie_by_model[(model, ICL, ICL)]
    prompt_reference = _universal_top_heads(model, BOTH, ALL, SHARED_HEADS_REFERENCE_N)
    icl_reference = _universal_top_heads(model, ICL, ICL, SHARED_HEADS_REFERENCE_N)

    for top_n in SHARED_HEADS_TOP_NS:
        # Prompt top-n heads that are also ICL reference heads, then the converse.
        prompt_top_n = _universal_top_heads(model, BOTH, ALL, top_n)
        table_rows.append(_shared_heads_row(model, "prompt", top_n, prompt_top_n, icl_reference, prompt_ie, icl_ie))

        icl_top_n = _universal_top_heads(model, ICL, ICL, top_n)
        table_rows.append(_shared_heads_row(model, "icl", top_n, icl_top_n, prompt_reference, icl_ie, prompt_ie))

    # Separator between models in the rendered table.
    table_rows.append(tabulate.SEPARATING_LINE)


headers = table_rows[0].keys()
print_rows = [[row[key] for key in headers] if isinstance(row, dict) else row for row in table_rows]
display(Markdown(tabulate.tabulate(print_rows, headers=headers, tablefmt="github")))

| model                   | head_type   |   n_top_n |   n_shared | n_shared_perc   |   mean_same_ie |   mean_other_ie |   same_other_ratio |
|-------------------------|-------------|-----------|------------|-----------------|----------------|-----------------|--------------------|
| Llama-3.2-1B            | prompt      |        10 |          6 | 60%             |     0.00670754 |      0.00265102 |           2.53017  |
| Llama-3.2-1B            | icl         |        10 |          2 | 20%             |     0.0146377  |      0.00438629 |           3.33715  |
| Llama-3.2-1B            | prompt      |        20 |         14 | 70%             |     0.00462739 |      0.00363957 |           1.27141  |
| Llama-3.2-1B            | icl         |        20 |          9 | 45%             |     0.00665989 |      0.00230238 |           2.89261  |
|  |
| Llama-3.2-1B-Instruct   | prompt      |        10 |          8 | 80%             |     0.00772535 |      0.00592654 |           1.30352  |
| Llama-3.2-1B-Instruct   | icl         |        10 |          2 | 20%             |     0.0169697  |      0.00693619 |           2.44655  |
| Llama-3.2-1B-Instruct   | prompt      |        20 |         15 | 75%             |     0.00596287 |      0.00459147 |           1.29868  |
| Llama-3.2-1B-Instruct   | icl         |        20 |          8 | 40%             |     0.00816929 |      0.00439157 |           1.86022  |
|  |
| Llama-3.2-3B            | prompt      |        10 |          8 | 80%             |     0.00463581 |      0.00374096 |           1.2392   |
| Llama-3.2-3B            | icl         |        10 |          6 | 60%             |     0.0288357  |      0.00315726 |           9.13316  |
| Llama-3.2-3B            | prompt      |        20 |         15 | 75%             |     0.0037405  |      0.00403034 |           0.928085 |
| Llama-3.2-3B            | icl         |        20 |         13 | 65%             |     0.0152964  |      0.00273862 |           5.58544  |
|  |
| Llama-3.2-3B-Instruct   | prompt      |        10 |          8 | 80%             |     0.00972258 |      0.0076092  |           1.27774  |
| Llama-3.2-3B-Instruct   | icl         |        10 |          5 | 50%             |     0.0293172  |      0.005022   |           5.83776  |
| Llama-3.2-3B-Instruct   | prompt      |        20 |         14 | 70%             |     0.00774945 |      0.00635416 |           1.21959  |
| Llama-3.2-3B-Instruct   | icl         |        20 |         10 | 50%             |     0.0177729  |      0.00478754 |           3.71232  |
|  |
| Llama-3.1-8B            | prompt      |        10 |          3 | 30%             |     0.0041534  |      0.00966502 |           0.429735 |
| Llama-3.1-8B            | icl         |        10 |          3 | 30%             |     0.0144727  |      0.00227791 |           6.35349  |
| Llama-3.1-8B            | prompt      |        20 |         12 | 60%             |     0.00272862 |      0.0053503  |           0.509993 |
| Llama-3.1-8B            | icl         |        20 |          8 | 40%             |     0.00874924 |      0.0022177  |           3.94519  |
|  |
| Llama-3.1-8B-Instruct   | prompt      |        10 |          6 | 60%             |     0.00868464 |      0.0022671  |           3.83073  |
| Llama-3.1-8B-Instruct   | icl         |        10 |          4 | 40%             |     0.0348586  |      0.00325419 |          10.7119   |
| Llama-3.1-8B-Instruct   | prompt      |        20 |         14 | 70%             |     0.00659002 |      0.00549487 |           1.1993   |
| Llama-3.1-8B-Instruct   | icl         |        20 |          8 | 40%             |     0.0203032  |      0.00388063 |           5.23192  |
|  |
| Llama-2-7b-hf           | prompt      |        10 |          6 | 60%             |     0.00400909 |      0.00374708 |           1.06992  |
| Llama-2-7b-hf           | icl         |        10 |          2 | 20%             |     0.00920414 |      0.00390423 |           2.35748  |
| Llama-2-7b-hf           | prompt      |        20 |         13 | 65%             |     0.00278093 |      0.00254216 |           1.09392  |
| Llama-2-7b-hf           | icl         |        20 |          9 | 45%             |     0.00453616 |      0.00151827 |           2.98771  |
|  |
| Llama-2-7b-chat-hf      | prompt      |        10 |          7 | 70%             |     0.00974374 |      0.00566441 |           1.72017  |
| Llama-2-7b-chat-hf      | icl         |        10 |          2 | 20%             |     0.0158251  |      0.00930756 |           1.70024  |
| Llama-2-7b-chat-hf      | prompt      |        20 |         14 | 70%             |     0.00687839 |      0.00491496 |           1.39948  |
| Llama-2-7b-chat-hf      | icl         |        20 |          6 | 30%             |     0.0115129  |      0.00465885 |           2.4712   |
|  |
| Llama-2-13b-hf          | prompt      |        10 |          8 | 80%             |     0.00294895 |      0.00254961 |           1.15662  |
| Llama-2-13b-hf          | icl         |        10 |          2 | 20%             |     0.00927796 |      0.00386601 |           2.39988  |
| Llama-2-13b-hf          | prompt      |        20 |         13 | 65%             |     0.0023589  |      0.00220807 |           1.06831  |
| Llama-2-13b-hf          | icl         |        20 |          6 | 30%             |     0.00597015 |      0.0019831  |           3.01052  |
|  |
| Llama-2-13b-chat-hf     | prompt      |        10 |          6 | 60%             |     0.00997004 |      0.0111935  |           0.890699 |
| Llama-2-13b-chat-hf     | icl         |        10 |          1 | 10%             |     0.0246561  |      0.0203341  |           1.21255  |
| Llama-2-13b-chat-hf     | prompt      |        20 |          9 | 45%             |     0.00803173 |      0.00882307 |           0.910311 |
| Llama-2-13b-chat-hf     | icl         |        20 |          5 | 25%             |     0.014739   |      0.00826954 |           1.78232  |
|  |
| OLMo-2-1124-7B          | prompt      |        10 |          9 | 90%             |     0.00897255 |      0.0067609  |           1.32712  |
| OLMo-2-1124-7B          | icl         |        10 |          5 | 50%             |     0.0161641  |      0.00592047 |           2.7302   |
| OLMo-2-1124-7B          | prompt      |        20 |         18 | 90%             |     0.0069243  |      0.00454355 |           1.52398  |
| OLMo-2-1124-7B          | icl         |        20 |         12 | 60%             |     0.00928662 |      0.00614353 |           1.51161  |
|  |
| OLMo-2-1124-7B-SFT      | prompt      |        10 |          8 | 80%             |     0.00864657 |      0.00812169 |           1.06463  |
| OLMo-2-1124-7B-SFT      | icl         |        10 |          6 | 60%             |     0.0289569  |      0.00434066 |           6.6711   |
| OLMo-2-1124-7B-SFT      | prompt      |        20 |         16 | 80%             |     0.00677069 |      0.00667341 |           1.01458  |
| OLMo-2-1124-7B-SFT      | icl         |        20 |         11 | 55%             |     0.0183659  |      0.00531998 |           3.45225  |
|  |
| OLMo-2-1124-7B-DPO      | prompt      |        10 |          6 | 60%             |     0.0137725  |      0.00599133 |           2.29873  |
| OLMo-2-1124-7B-DPO      | icl         |        10 |          5 | 50%             |     0.0298255  |      0.00529881 |           5.62873  |
| OLMo-2-1124-7B-DPO      | prompt      |        20 |         12 | 60%             |     0.010389   |      0.010331   |           1.00561  |
| OLMo-2-1124-7B-DPO      | icl         |        20 |         12 | 60%             |     0.0177274  |      0.00727804 |           2.43574  |
|  |
| OLMo-2-1124-7B-Instruct | prompt      |        10 |          6 | 60%             |     0.0145396  |      0.00806555 |           1.80268  |
| OLMo-2-1124-7B-Instruct | icl         |        10 |          5 | 50%             |     0.0332614  |      0.00553057 |           6.0141   |
| OLMo-2-1124-7B-Instruct | prompt      |        20 |         12 | 60%             |     0.010936   |      0.0152142  |           0.718798 |
| OLMo-2-1124-7B-Instruct | icl         |        20 |         12 | 60%             |     0.019279   |      0.00774668 |           2.48868  |
|  |

# Finding the heads with the smallest absolute causal effect

In [None]:
for model in ORDERED_MODELS:
    prompt_mean_ie = mean_ie_by_model[(model, BOTH, ALL)]
    icl_mean_ie = mean_ie_by_model[(model, ICL, ICL)]

    # Rank heads by the sum of their absolute mean indirect effects in the prompt and ICL settings and keep
    # the 100 with the smallest sum, i.e. heads with little effect of either sign in both settings.
    absolute_total_mean_ie = prompt_mean_ie.abs() + icl_mean_ie.abs()
    min_abs_heads, min_abs_values = top_heads_from_indirect_effects_with_values(
        absolute_total_mean_ie, 100, largest=False
    )

    _save_single_top_head_set(
        min_abs_heads,
        min_abs_values,
        model,
        prompt_type="universal",
        baseline=ALL,
        head_type="min_abs",
    )


And the heads with the smallest signed (not absolute) mean causal effect, separately for prompts and for ICL

In [None]:
# One "bottom" head set per model and per setting (combined prompts, and ICL).
for model, (prompt_type, baseline) in itertools.product(ORDERED_MODELS, ((BOTH, ALL), (ICL, ICL))):
    # Ranked on the signed mean IE (no .abs() here), so presumably the most negative effects come first.
    mean_ie = mean_ie_by_model[(model, prompt_type, baseline)]
    bottom_heads, bottom_values = top_heads_from_indirect_effects_with_values(mean_ie, 100, largest=False)

    _save_single_top_head_set(
        bottom_heads,
        bottom_values,
        model,
        prompt_type=prompt_type,
        baseline=baseline,
        head_type="bottom",
    )



# Summary table: mean and median indirect effect of the top heads

In [15]:
def format_number(x):
    """Format ``x`` in scientific notation (``f"{x:e}"``), except exact zero, which becomes ``"0"``."""
    return "0" if x == 0 else f"{x:e}"


rows = []

for model in ORDERED_MODELS:
    # The mean-IE values do not depend on the number of heads, so look them up once per model.
    prompt_mean_ie = mean_ie_by_model[(model, BOTH, ALL)]
    icl_mean_ie = mean_ie_by_model[(model, ICL, ICL)]

    row = dict(model=model)
    for n_top in N_TOP_HEADS_VALUES:
        _, prompt_top_values = top_heads_from_indirect_effects_with_values(prompt_mean_ie, n_top)
        row[f"Prompt_{n_top}_mean"] = format_number(np.mean(prompt_top_values))
        row[f"Prompt_{n_top}_median"] = format_number(np.median(prompt_top_values))

        _, icl_top_values = top_heads_from_indirect_effects_with_values(icl_mean_ie, n_top)
        row[f"ICL_{n_top}_mean"] = format_number(np.mean(icl_top_values))
        row[f"ICL_{n_top}_median"] = format_number(np.median(icl_top_values))

    rows.append(row)


headers = rows[0].keys()
print_rows = [[row[key] for key in headers] for row in rows]
# format_number returns strings, but tabulate parses numeric strings by default, so floatfmt sets the display.
display(Markdown(tabulate.tabulate(print_rows, headers=headers, tablefmt="github", floatfmt=".3e")))
    


    

| model                   |   Prompt_10_mean |   Prompt_10_median |   ICL_10_mean |   ICL_10_median |   Prompt_20_mean |   Prompt_20_median |   ICL_20_mean |   ICL_20_median |   Prompt_100_mean |   Prompt_100_median |   ICL_100_mean |   ICL_100_median |
|-------------------------|------------------|--------------------|---------------|-----------------|------------------|--------------------|---------------|-----------------|-------------------|---------------------|----------------|------------------|
| Llama-3.2-1B            |        5.816e-03 |          4.929e-03 |     3.156e-02 |       2.275e-02 |        4.407e-03 |          3.871e-03 |     1.819e-02 |       7.146e-03 |         1.786e-03 |           1.040e-03 |      4.521e-03 |        1.155e-03 |
| Llama-3.2-1B-Instruct   |        7.695e-03 |          7.026e-03 |     3.511e-02 |       3.048e-02 |        5.855e-03 |          5.011e-03 |     2.028e-02 |       6.796e-03 |         2.205e-03 |           1.335e-03 |      5.369e-03 |        1.688e-03 |
| Llama-3.2-3B            |        4.434e-03 |          3.729e-03 |     2.411e-02 |       1.018e-02 |        3.559e-03 |          3.372e-03 |     1.388e-02 |       4.996e-03 |         1.378e-03 |           8.242e-04 |      3.536e-03 |        8.542e-04 |
| Llama-3.2-3B-Instruct   |        9.555e-03 |          8.353e-03 |     4.480e-02 |       1.505e-02 |        7.161e-03 |          6.811e-03 |     2.556e-02 |       9.256e-03 |         2.620e-03 |           1.602e-03 |      6.489e-03 |        1.645e-03 |
| Llama-3.1-8B            |        3.446e-03 |          3.283e-03 |     2.239e-02 |       1.703e-02 |        2.848e-03 |          2.484e-03 |     1.365e-02 |       7.189e-03 |         1.149e-03 |           7.368e-04 |      3.887e-03 |        1.363e-03 |
| Llama-3.1-8B-Instruct   |        8.576e-03 |          7.676e-03 |     2.358e-02 |       1.764e-02 |        6.749e-03 |          5.806e-03 |     1.472e-02 |       7.761e-03 |         2.488e-03 |           1.412e-03 |      4.401e-03 |        1.797e-03 |
| Llama-2-7b-hf           |        3.392e-03 |          3.147e-03 |     1.630e-02 |       1.294e-02 |        2.565e-03 |          2.015e-03 |     9.715e-03 |       5.394e-03 |         1.212e-03 |           8.952e-04 |      2.609e-03 |        8.563e-04 |
| Llama-2-7b-chat-hf      |        8.545e-03 |          5.821e-03 |     3.942e-02 |       3.443e-02 |        6.257e-03 |          4.617e-03 |     2.383e-02 |       1.228e-02 |         2.471e-03 |           1.527e-03 |      6.770e-03 |        2.613e-03 |
| Llama-2-13b-hf          |        2.775e-03 |          2.129e-03 |     1.550e-02 |       1.151e-02 |        2.107e-03 |          1.723e-03 |     9.834e-03 |       6.372e-03 |         1.028e-03 |           7.966e-04 |      2.738e-03 |        9.504e-04 |
| Llama-2-13b-chat-hf     |        8.587e-03 |          6.800e-03 |     5.083e-02 |       3.981e-02 |        6.462e-03 |          5.400e-03 |     3.151e-02 |       1.878e-02 |         2.759e-03 |           1.856e-03 |      8.981e-03 |        2.888e-03 |
| OLMo-2-1124-7B          |        8.934e-03 |          7.554e-03 |     2.442e-02 |       1.273e-02 |        6.885e-03 |          6.173e-03 |     1.423e-02 |       6.866e-03 |         2.892e-03 |           1.908e-03 |      3.592e-03 |        9.344e-04 |
| OLMo-2-1124-7B-SFT      |        8.110e-03 |          6.537e-03 |     2.789e-02 |       1.748e-02 |        6.487e-03 |          5.598e-03 |     1.636e-02 |       7.593e-03 |         2.750e-03 |           1.896e-03 |      4.336e-03 |        1.359e-03 |
| OLMo-2-1124-7B-DPO      |        1.238e-02 |          1.102e-02 |     4.227e-02 |       2.768e-02 |        9.759e-03 |          8.215e-03 |     2.550e-02 |       1.171e-02 |         4.022e-03 |           2.668e-03 |      6.932e-03 |        2.170e-03 |
| OLMo-2-1124-7B-Instruct |        1.295e-02 |          1.139e-02 |     4.576e-02 |       3.029e-02 |        1.033e-02 |          8.876e-03 |     2.769e-02 |       1.251e-02 |         4.295e-03 |           2.830e-03 |      7.575e-03 |        2.366e-03 |