In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import itertools
import json
import os
import sys
import typing
from collections import Counter, defaultdict
from pathlib import Path

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import tabulate
import torch
from IPython.display import Markdown, display
from loguru import logger
from tqdm.auto import tqdm

torch.set_grad_enabled(False)

from shared_definitions import *
from shared_visualization_utils import *

sys.path.insert(0, os.path.abspath(".."))

sns.set_theme(style="white", context="notebook", rc={"figure.figsize": (14, 10)})


In [None]:
# Unpack the three objects returned by the shared loader; only the frame is previewed here.
(
    result_df,
    indirect_effects_by_model_and_dataset,
    top_heads_by_model_and_dataset,
) = load_and_combine_raw_results()

# Last expression of the cell, so the first rows are rendered as the cell output.
result_df.head()

In [None]:
# Toggle: when True, everything whose name/key contains "OLMo" is dropped below.
SKIP_OLMO = False

# Work on copies so the shared ordering constants themselves are never modified.
RELEVANT_MODELS = ORDERED_MODELS[:]
RELEVANT_SCATTER_ORDERED_MODELS = SCATTER_ORDERED_MODELS[:]

if SKIP_OLMO:
    # Results table: keep rows whose `model` column does not mention OLMo.
    result_df = result_df[~result_df.model.str.contains("OLMo")]

    # Lookup dicts: keep entries whose key does not contain "OLMo".
    # NOTE(review): `in` is a substring test only if the keys are strings -- confirm the key type.
    indirect_effects_by_model_and_dataset = {
        key: value for key, value in indirect_effects_by_model_and_dataset.items() if "OLMo" not in key
    }
    top_heads_by_model_and_dataset = {
        key: value for key, value in top_heads_by_model_and_dataset.items() if "OLMo" not in key
    }

    # Model orderings used by later cells.
    RELEVANT_MODELS = [name for name in RELEVANT_MODELS if "OLMo" not in name]
    RELEVANT_SCATTER_ORDERED_MODELS = [name for name in RELEVANT_SCATTER_ORDERED_MODELS if "OLMo" not in name]

# RQ2: How similar are the sets of top heads used?


In [None]:
# Directory scanned for the per-model top-heads JSON files.
top_heads_path = Path("/checkpoint/guyd/function_vectors/full_results_top_heads")

# model name -> {"prompt" | "icl" -> parsed JSON contents of the matching file}
top_heads_by_model_and_type = defaultdict(dict)

for heads_type, glob_pattern in (
    ("prompt", "*both_all_top_heads.json"),
    ("icl", "*icl_same_test_sets_top_heads.json"),
):
    # Path.glob yields entries in arbitrary order; sorting keeps the model order (and therefore
    # the order of every table built from this dict) identical across runs and machines.
    for top_heads_file in sorted(top_heads_path.glob(glob_pattern)):
        # The model name is everything before the first underscore of the file name.
        model, _ = top_heads_file.name.split("_", 1)

        logger.debug(f"Loading {heads_type} heads for {model} from: {top_heads_file}")

        with open(top_heads_file, "r", encoding="utf-8") as f:
            top_heads_by_model_and_type[model][heads_type] = json.load(f)

# A missing or empty directory makes glob yield nothing; say so here instead of failing
# later with a less obvious error.
if not top_heads_by_model_and_type:
    logger.warning(f"No top-heads files found under {top_heads_path}")

In [None]:
from scipy.stats.mstats import gmean

lines = []  # markdown bullet lines, rendered at the end of this block
rows = []  # one record per (model, n) pair, turned into a DataFrame afterwards

for model, model_results in top_heads_by_model_and_type.items():
    # The model bullet is emitted first, so a model skipped below still shows up in the list.
    lines.append(f"- {model}")

    icl_heads = model_results.get("icl")
    prompt_heads = model_results.get("prompt")
    if prompt_heads is None or icl_heads is None:
        logger.warning(f"Missing data for {model}")
        continue

    for n in (10, 20):
        # Heads are converted to tuples so they are hashable and can be put into sets.
        iclh = list(map(tuple, icl_heads["top_heads"][:n]))
        ph = list(map(tuple, prompt_heads["top_heads"][:n]))

        # First component of each head is treated as its layer index.
        icl_layers = [head[0] for head in iclh]
        prompt_layers = [head[0] for head in ph]

        lines.append(
            f"   - N={n}: {len(set(iclh) & set(ph))} shared | ICL layers: {np.mean(icl_layers):.2f} | Prompt layers: {np.mean(prompt_layers):.2f}"
        )

        rows.append(
            {
                "model": model,
                "n_heads": n,
                "n_layers": MODEL_TO_N_LAYERS[model],
                "icl_heads": iclh,
                "prompt_heads": ph,
                "icl_effects": icl_heads["top_head_effects"][:n],
                "prompt_effects": prompt_heads["top_head_effects"][:n],
            }
        )

display(Markdown("\n".join(lines)))

top_heads_summary_df = pd.DataFrame(rows)

# Stage 1: per-row aggregates of the effect lists and of the layer index (first component)
# of every selected head.
top_heads_summary_df = top_heads_summary_df.assign(
    icl_effect_mean=lambda df: df.icl_effects.map(lambda effects: np.mean(effects)),
    icl_effect_geom_mean=lambda df: df.icl_effects.map(lambda effects: gmean(effects)),
    prompt_effect_mean=lambda df: df.prompt_effects.map(lambda effects: np.mean(effects)),
    prompt_effect_geom_mean=lambda df: df.prompt_effects.map(lambda effects: gmean(effects)),
    icl_layer_mean=lambda df: df.icl_heads.map(lambda heads: np.mean([head[0] for head in heads])),
    icl_layer_std=lambda df: df.icl_heads.map(lambda heads: np.std([head[0] for head in heads])),
    prompt_layer_mean=lambda df: df.prompt_heads.map(lambda heads: np.mean([head[0] for head in heads])),
    prompt_layer_std=lambda df: df.prompt_heads.map(lambda heads: np.std([head[0] for head in heads])),
)

# Stage 2: the same layer statistics divided by the model's layer count, so that models of
# different depth can be compared on a common scale.
top_heads_summary_df = top_heads_summary_df.assign(
    icl_layer_mean_depth=lambda df: df.icl_layer_mean / df.n_layers,
    icl_layer_std_depth=lambda df: df.icl_layer_std / df.n_layers,
    prompt_layer_mean_depth=lambda df: df.prompt_layer_mean / df.n_layers,
    prompt_layer_std_depth=lambda df: df.prompt_layer_std / df.n_layers,
)


def row_summary(row: pd.Series):
    """Flatten one row of the top-heads summary frame into a plain dict.

    The row is read via attribute access and must provide ``model``, ``n_heads``,
    ``n_layers``, ``icl_heads``, ``prompt_heads`` and the ``*_mean`` / ``*_std`` /
    ``*_depth`` columns referenced below.

    Returns:
        A dict whose ``key`` entry is ``"<model> @ <n_heads>"``, followed by the number and
        fraction of heads present in both lists, effect and layer statistics, the per-head
        layer indices (raw and divided by ``n_layers``), and the two head lists themselves.
    """
    n_shared = len(set(row.icl_heads).intersection(row.prompt_heads))

    # The first component of every head is its layer index.
    icl_layers = [head[0] for head in row.icl_heads]
    prompt_layers = [head[0] for head in row.prompt_heads]

    # Key order is kept as-is: it determines the row order of the table built from these dicts.
    summary = {
        "key": f"{row.model} @ {row.n_heads}",
        "shared_heads": n_shared,
        "shared_head_fraction": n_shared / row.n_heads,
        "icl_effect_mean": row.icl_effect_mean,
        "prompt_effect_mean": row.prompt_effect_mean,
        "mean_diff": row.icl_effect_mean - row.prompt_effect_mean,
        "geom_mean_diff": row.icl_effect_geom_mean - row.prompt_effect_geom_mean,
        "icl_layer": row.icl_layer_mean,
        "icl_layer_std": row.icl_layer_std,
        "icl_layer_depth": row.icl_layer_mean_depth,
        "icl_layer_depth_std": row.icl_layer_std_depth,
        "prompt_layer": row.prompt_layer_mean,
        "prompt_layer_std": row.prompt_layer_std,
        "prompt_layer_depth": row.prompt_layer_mean_depth,
        "prompt_layer_depth_std": row.prompt_layer_std_depth,
        "layer_depth_diff": row.prompt_layer_mean_depth - row.icl_layer_mean_depth,
        "icl_layers": icl_layers,
        "icl_layer_depths": [layer / row.n_layers for layer in icl_layers],
        "prompt_layers": prompt_layers,
        "prompt_layer_depths": [layer / row.n_layers for layer in prompt_layers],
        "prompt_heads": row.prompt_heads,
        "icl_heads": row.icl_heads,
    }
    return summary


# One summary dict per (model, n_heads) row of the summary frame.
rows = top_heads_summary_df.apply(row_summary, axis=1).tolist()

# Same summaries, indexed by their "<model> @ <n_heads>" key; the key itself is left out of
# each stored dict (the dicts in `rows` keep it).
layer_key_dicts = {
    summary["key"]: {field: value for field, value in summary.items() if field != "key"}
    for summary in rows
}

# Transposed table: one column per "<model> @ <n_heads>" key, one row per summary field.
d = pd.DataFrame(rows).set_index("key").T
display(Markdown(tabulate.tabulate(d, headers="keys", tablefmt="github")))

# RQ2.1 Top heads -- first panel


In [None]:
from matplotlib.gridspec import GridSpec

# from palettable.colorbrewer.qualitative import Dark2_4 as heads_cmap
# DEFAULT_TOP_HEADS_COLORS = [(1, 1, 1), heads_cmap.mpl_colors[2], heads_cmap.mpl_colors[3], heads_cmap.mpl_colors[0]]
# from palettable.colorbrewer.qualitative import Accent_6 as heads_cmap
# DEFAULT_TOP_HEADS_COLORS = [(1, 1, 1), heads_cmap.mpl_colors[1], heads_cmap.mpl_colors[5], heads_cmap.mpl_colors[0]]
from palettable.colorbrewer.qualitative import Set1_3 as heads_cmap

# Colours indexed by the cell code used in plot_shared_top_heads: index 0 (white) for cells in
# neither top-heads list, followed by the colours of the Set1_3 palette for codes 1, 2 and 3.
DEFAULT_TOP_HEADS_COLORS = [(1, 1, 1), *heads_cmap.mpl_colors]


def get_model_shape(model_name: str) -> typing.Tuple[int, int]:
    """Return the ``MODEL_TO_N_LAYERS_HEADS`` entry for a model.

    The full name is tried first; if it is not registered, the part of the name before the
    first underscore is tried instead (a hack for an appendix plot).

    Raises:
        ValueError: If neither the full name nor its underscore-free prefix is registered.
    """
    for candidate in (model_name, model_name.split("_")[0]):
        if candidate in MODEL_TO_N_LAYERS_HEADS:
            return MODEL_TO_N_LAYERS_HEADS[candidate]

    raise ValueError(f"Model {model_name} not found in MODEL_TO_N_LAYERS_HEADS")


def plot_shared_top_heads(
    heads_df: pd.DataFrame,
    model_names: typing.List[str],
    n_heads: int | typing.Dict[str, int] = 20,
    limits_from_model_layers: bool = True,
    draw_mean_lines: bool = False,
    shape: typing.Tuple[int, int] | None = None,
    panel_width: int = 4,
    panel_height: int = 4,
    min_step: int = 3,
    max_step: int = 10,
    legend_panel_height: float = 0.2,
    fontsize: int = 12,
    font_inc: int = 4,
    fontfamily: str | None = None,
    textfontweight: str | None = None,
    top_heads_colors: typing.List[tuple] = DEFAULT_TOP_HEADS_COLORS,
    save_name: typing.Optional[str] = None,
    show_legend: bool = True,
    legend_outside: bool = True,
    legend_loc: str | typing.Tuple[int, int] = None,
    legend_ax_index: int = -1,
    legend_ncol: int = 1,
    legend_bbox_to_anchor: typing.Tuple[float, float] = (1.05, 1),
    add_colorbar: bool = False,
    minimal_mode: bool = False,
    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
    annotate_panels: bool = False,
    annotate_panels_start: str = "A",
    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),
    annotate_font_inc: int = 4,
    vmin: float = 0,
    vmax: float = 3,
    suptitle: str | None = None,
    model_titles: typing.Dict[str, str] | None = None,
):
    """Draw one layer-by-head grid per model, colouring the top heads by list membership.

    Every cell of a model's grid carries a code: 0 (in neither list), 1 (only in the row's
    ``prompt_heads``), 2 (only in ``icl_heads``) or 3 (in both). Codes are rendered with
    ``top_heads_colors``.

    Args:
        heads_df: Frame with ``model``, ``n_heads``, ``prompt_heads`` and ``icl_heads``
            columns; the first row matching ``(model, n_heads)`` is used for each panel.
        model_names: Models to plot, one panel each, filling the grid row by row.
        n_heads: Value of the ``n_heads`` column to select, shared or per model.
        limits_from_model_layers: Outside minimal mode, set axis limits and ticks from the
            model shape returned by ``get_model_shape``.
        draw_mean_lines: Draw a dashed vertical line at the mean layer of each non-empty
            category (prompt only, ICL only, shared).
        shape: ``(rows, cols)`` of the panel grid. Defaults to two columns for an even number
            of models and a single column otherwise.
        panel_width, panel_height: Figure size contributed by each panel.
        min_step, max_step: Candidate tick steps; the first step in this range that divides
            the axis size is used, otherwise ``max_step``.
        legend_panel_height: Relative height of the extra legend row (or relative width of the
            colorbar column).
        fontsize, fontfamily, textfontweight: Styling of labels, titles and legend text.
        font_inc: Extra font size used for the suptitle.
        top_heads_colors: Either a list of four colours indexed by cell code, or a ready-made
            matplotlib colormap.
        save_name: Passed to ``save_plot`` when given.
        show_legend: Add a legend in an extra grid row below the panels.
        legend_outside, legend_loc, legend_ncol, legend_bbox_to_anchor: Legend placement.
        legend_ax_index: Currently unused.
        add_colorbar: Add a colorbar in an extra grid column instead of a legend.
        minimal_mode: Suppress axis labels, titles and ticks.
        grid_spec_kwargs: Forwarded to ``GridSpec``; defaults to ``hspace=0.2, wspace=0.25``.
        annotate_panels, annotate_panels_start, annotate_panel_position, annotate_font_inc:
            Per-panel letter annotation (consecutive characters from ``annotate_panels_start``).
        vmin, vmax: Colour-scale limits for the cell codes.
        suptitle: Optional figure title.
        model_titles: Optional per-model panel titles replacing the default title.

    Raises:
        ValueError: If both ``add_colorbar`` and ``show_legend`` are set.
    """
    if add_colorbar and show_legend:
        raise ValueError("Cannot add colorbar and legend at the same time")

    if model_titles is None:
        model_titles = dict()

    if isinstance(n_heads, int):
        n_heads = {model_name: n_heads for model_name in model_names}

    if shape is None:
        if len(model_names) % 2 == 0:
            shape = (len(model_names) // 2, 2)
        else:
            shape = (len(model_names), 1)

    if isinstance(top_heads_colors, list):
        custom_cmap = matplotlib.colors.ListedColormap(top_heads_colors, name="head_colors")
        # Colours for codes 1, 2, 3 (prompt only, ICL only, shared).
        category_colors = top_heads_colors[1:]
    else:
        # A ready-made colormap cannot be sliced; look the three category colours up through
        # the same normalisation that imshow applies to the cell codes.
        custom_cmap = top_heads_colors
        color_norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
        category_colors = [custom_cmap(color_norm(code)) for code in (1, 2, 3)]

    if grid_spec_kwargs is None:
        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)

    if show_legend:
        gs = GridSpec(shape[0] + 1, shape[1], height_ratios=[1] * shape[0] + [legend_panel_height], **grid_spec_kwargs)
    elif add_colorbar:
        gs = GridSpec(shape[0], shape[1] + 1, width_ratios=[1] * shape[1] + [legend_panel_height], **grid_spec_kwargs)
    else:
        gs = GridSpec(shape[0], shape[1], **grid_spec_kwargs)

    fig = plt.figure(
        layout="constrained",
        figsize=(shape[1] * panel_width, shape[0] * panel_height + int(show_legend)),
    )

    for i, model_name in enumerate(model_names):
        n = n_heads[model_name]
        r = i // shape[1]
        c = i % shape[1]
        ax = fig.add_subplot(gs[r, c])

        if not minimal_mode:
            # Label the x axis on every panel that has no panel below it. The previous check
            # compared the row index with the number of columns, which only coincided with
            # the last row on a 2x2 grid.
            if i + shape[1] >= len(model_names):
                ax.set_xlabel("Layer", fontsize=fontsize, fontfamily=fontfamily, fontweight=textfontweight)
            if c == 0:
                ax.set_ylabel("Head Index", fontsize=fontsize, fontfamily=fontfamily, fontweight=textfontweight)

        mdf = heads_df[(heads_df.model == model_name) & (heads_df.n_heads == n)]
        prompt_heads = mdf.prompt_heads.values[0]
        icl_heads = mdf.icl_heads.values[0]

        # Cell codes: +1 for a prompt head, +2 for an ICL head, hence 3 for a shared head.
        head_array = np.zeros(get_model_shape(model_name), dtype=int)
        for head in prompt_heads:
            head_array[head] += 1

        for head in icl_heads:
            head_array[head] += 2

        # Transposed so that layers run along x and head indices along y.
        ax.imshow(head_array.T, cmap=custom_cmap, vmin=vmin, vmax=vmax)

        if draw_mean_lines:
            prompt_head_set = set(prompt_heads)
            icl_head_set = set(icl_heads)
            shared_head_set = prompt_head_set & icl_head_set
            prompt_only_head_set = prompt_head_set - shared_head_set
            icl_only_head_set = icl_head_set - shared_head_set

            for head_set, color in zip(
                (prompt_only_head_set, icl_only_head_set, shared_head_set),
                category_colors,
            ):
                if len(head_set) > 0:
                    mean_layer = np.mean([head[0] for head in head_set])
                    ax.axvline(mean_layer, color=color, linestyle="--", linewidth=2, alpha=0.75)

        if not minimal_mode:
            if model_name in model_titles:
                title = model_titles[model_name]
            else:
                n_shared = np.sum(head_array == 3)
                title = f"{model_name}\n{n} top heads, {n_shared} shared"

            ax.set_title(title, fontsize=fontsize, fontfamily=fontfamily, fontweight=textfontweight)

        if minimal_mode:
            ax.set_xticks([])
            ax.set_yticks([])

        elif limits_from_model_layers:
            model_layers, model_heads = get_model_shape(model_name)
            ax.set_xlim(0, model_layers)
            ax.set_ylim(0, model_heads)

            for max_value, set_method in (
                (model_layers, ax.set_xticks),
                (model_heads, ax.set_yticks),
            ):
                # First step in [min_step, max_step] that divides the axis size; if none
                # does, the loop ends with step == max_step. The initial value only matters
                # when the range is empty.
                step = 1
                for step in range(min_step, max_step + 1):
                    if max_value % step == 0:
                        break

                set_method(np.arange(0, max_value + step, step))

        if annotate_panels:
            ax.text(
                *annotate_panel_position,
                chr(ord(annotate_panels_start) + i),
                ha="left",
                va="top",
                fontsize=fontsize + annotate_font_inc,
                fontweight="bold",
                fontfamily=fontfamily,
                transform=ax.transAxes,
            )

    if show_legend:
        legend_entries = [
            matplotlib.patches.Rectangle((0, 0), 1, 1, color=color, label=label)
            for (color, label) in zip(
                category_colors,
                ["Instruction only", "Demonstration only", "Shared"],
            )
        ]

        legend_kwargs = dict(
            handles=legend_entries, ncol=legend_ncol, prop=dict(family=fontfamily, size=fontsize), handlelength=0.75
        )
        if legend_outside:
            legend_kwargs["bbox_to_anchor"] = legend_bbox_to_anchor
            legend_kwargs["loc"] = legend_loc
        elif legend_loc is not None:
            legend_kwargs["loc"] = legend_loc

        # The legend lives on its own invisible axes spanning the extra bottom grid row.
        legend_ax = fig.add_subplot(gs[-1, :])
        legend_ax.axis("off")

        legend_ax.legend(**legend_kwargs)

    elif add_colorbar:
        cbar_ax = fig.add_subplot(gs[:, -1])
        cbar = plt.colorbar(
            matplotlib.cm.ScalarMappable(cmap=custom_cmap, norm=matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)),
            cax=cbar_ax,
            orientation="vertical",
        )
        cbar.ax.tick_params(labelsize=fontsize)

    if suptitle is not None:
        fig.suptitle(
            suptitle,
            fontsize=fontsize + font_inc,
            fontweight=textfontweight,
            fontfamily=fontfamily,
        )

    if save_name is not None:
        save_plot(save_name)
    plt.show()


# Models shown in the main-paper figure, in panel order.
MAIN_PAPER_PLOT_MODELS = [
    "Llama-3.2-3B",
    "Llama-3.2-3B-Instruct",
    "Llama-3.1-8B",
    "Llama-3.1-8B-Instruct",
]

plot_shared_top_heads(
    heads_df=top_heads_summary_df,
    model_names=MAIN_PAPER_PLOT_MODELS,
    # Panel content
    draw_mean_lines=True,
    annotate_panels=True,
    annotate_panels_start="A",
    # Typography
    fontsize=14,
    fontfamily="monospace",
    textfontweight="bold",
    # Layout
    grid_spec_kwargs=dict(hspace=0.5),
    # Legend
    legend_outside=False,
    legend_loc=(-0.05, 0),
    legend_ncol=3,
    legend_ax_index=2,
    legend_panel_height=0.05,
    # Output
    save_name="finding_3_shared_top_heads.pdf",
)

# Appendix version of the above plot for more models

In [None]:
# Every ordered model that is neither in the main-paper figure nor has "13b" in its name.
APPENDIX_MODELS = [
    name for name in ORDERED_MODELS if not (name in MAIN_PAPER_PLOT_MODELS or "13b" in name)
]

plot_shared_top_heads(
    heads_df=top_heads_summary_df,
    model_names=APPENDIX_MODELS,
    # Panel content
    draw_mean_lines=True,
    annotate_panels=True,
    annotate_panels_start="A",
    # Typography
    fontsize=14,
    fontfamily="monospace",
    textfontweight="bold",
    # Layout (slightly taller panels and more vertical spacing than the main figure)
    panel_height=4.5,
    grid_spec_kwargs=dict(hspace=0.6),
    # Legend
    legend_outside=False,
    legend_loc=(-0.05, 0),
    legend_ncol=3,
    legend_ax_index=2,
    legend_panel_height=0.05,
    # Output
    save_name="appendix_finding_3_shared_top_heads.pdf",
)

In [None]:
from palettable.colorbrewer.qualitative import Set1_3 as heads_cmap

# White for empty cells, then the palette colours for the three head categories.
FIG_1_TOP_HEAD_COLORS = [(1, 1, 1)] + list(heads_cmap.mpl_colors)

# Figure 1 uses the last model of the main-paper list.
FIG_1_MODEL = MAIN_PAPER_PLOT_MODELS[-1]

# Single bare panel: no legend, labels, ticks, mean lines or panel letters.
plot_shared_top_heads(
    heads_df=top_heads_summary_df,
    model_names=[FIG_1_MODEL],
    top_heads_colors=FIG_1_TOP_HEAD_COLORS,
    shape=(1, 1),
    panel_width=8,
    panel_height=8,
    grid_spec_kwargs=dict(hspace=0.5),
    minimal_mode=True,
    show_legend=False,
    draw_mean_lines=False,
    annotate_panels=False,
    fontfamily="monospace",
    save_name="figure_1_top_heads.png",
)

In [None]:
# Dataset whose function vectors are loaded for figure 1. It is used for both the directory and
# the file-name prefix below, so changing it here keeps the two in sync (the file name
# previously hard-coded "country-capital").
FIG_1_DATASET = "country-capital"

prompt_fv = torch.load(
    f"/checkpoint/guyd/function_vectors/full_results_prompt_based_short/{FIG_1_MODEL}/{FIG_1_DATASET}/{FIG_1_DATASET}_20_universal_fv.pt"
)

icl_fv = torch.load(
    f"/checkpoint/guyd/function_vectors/full_icl_results/{FIG_1_MODEL}/{FIG_1_DATASET}/{FIG_1_DATASET}_20_universal_fv.pt"
)

# model -> {function-vector type -> loaded object}
FIG_1_FVS = {
    FIG_1_MODEL: {
        BOTH: prompt_fv,
        ICL: icl_fv,
    }
}

In [None]:
def plot_top_head_vectors(
    heads_df: pd.DataFrame,
    model_names: typing.List[str],
    vector_cmaps: typing.Tuple,
    n_heads: int | typing.Dict[str, int] = 20,
    limits_from_model_layers: bool = True,
    draw_mean_lines: bool = False,
    shape: typing.Tuple[int, int] | None = None,
    panel_width: int = 2,
    panel_height: int = 6,
    legend_panel_height: float = 0.2,
    fontsize: int = 12,
    font_inc: int = 4,
    fontfamily: str | None = None,
    save_name: typing.Optional[str] = None,
    show_legend: bool = True,
    legend_outside: bool = True,
    legend_loc: str | typing.Tuple[int, int] = None,
    legend_ax_index: int = -1,
    legend_ncol: int = 1,
    legend_bbox_to_anchor: typing.Tuple[float, float] = (1.05, 1),
    minimal_mode: bool = False,
    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
    annotate_panels: bool = False,
    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),
    annotate_font_inc: int = 4,
):
    """Draw two column-vector heat maps per model, derived from its top-head grid.

    For each model a grid of cell codes is built (+1 for every head in ``prompt_heads``,
    +2 for every head in ``icl_heads``). The first panel zeroes the ICL-only cells (code 2),
    the second zeroes the prompt-only cells (code 1); the remaining codes are summed over the
    first grid axis and shown as a single column using the matching entry of ``vector_cmaps``.

    Args:
        heads_df: Frame with ``model``, ``n_heads``, ``prompt_heads`` and ``icl_heads``
            columns; the first row matching ``(model, n_heads)`` is used.
        model_names: Models to plot, one grid row each.
        vector_cmaps: Colormaps for the two panels (prompt panel first, ICL panel second).
        n_heads: Value of the ``n_heads`` column to select, shared or per model.
        shape: Ignored; the grid is always ``(len(model_names), 2)`` because every model
            needs exactly two panels.
        panel_width, panel_height: Figure size contributed by each panel.
        show_legend: Only adds one unit to the figure height; no legend is drawn.
        grid_spec_kwargs: Forwarded to ``GridSpec``; defaults to ``hspace=0.2, wspace=0.25``.
        save_name: Passed to ``save_plot`` when given.

    The remaining parameters are accepted for call-site compatibility with
    ``plot_shared_top_heads`` and are currently unused.
    """
    if isinstance(n_heads, int):
        n_heads = {model_name: n_heads for model_name in model_names}

    # One row per model and always two columns, regardless of the `shape` argument.
    shape = (len(model_names), 2)

    if grid_spec_kwargs is None:
        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)

    gs = GridSpec(shape[0], shape[1], **grid_spec_kwargs)

    fig = plt.figure(
        layout="constrained",
        figsize=(shape[1] * panel_width, shape[0] * panel_height + int(show_legend)),
    )

    for r, model_name in enumerate(model_names):
        n = n_heads[model_name]

        mdf = heads_df[(heads_df.model == model_name) & (heads_df.n_heads == n)]
        prompt_heads = mdf.prompt_heads.values[0]
        icl_heads = mdf.icl_heads.values[0]

        # Cell codes: +1 for a prompt head, +2 for an ICL head, hence 3 for a shared head.
        head_array = np.zeros(MODEL_TO_N_LAYERS_HEADS[model_name], dtype=int)
        for head in prompt_heads:
            head_array[head] += 1

        for head in icl_heads:
            head_array[head] += 2

        for c, (ignore_value, cmap) in enumerate(
            zip(
                (2, 1),
                vector_cmaps,
            )
        ):
            ax = fig.add_subplot(gs[r, c])

            head_vec = np.copy(head_array)
            # Drop the category that does not belong to this panel. This used to be a `==`
            # comparison whose result was discarded, so nothing was actually removed.
            head_vec[head_vec == ignore_value] = 0
            head_vec = head_vec.sum(axis=0)

            # Render the 1-D vector as a single column.
            ax.imshow(head_vec.T[:, None], cmap=cmap)
            ax.set_xticks([])
            ax.set_yticks([])

    if save_name is not None:
        save_plot(save_name)
    plt.show()


# Dimensions kept from the function-vector variant of this figure; the call below does not
# use them.
WIDTH = 32
D_MODEL = 4096

# Two bare vector panels (red and blue colormaps) for the last main-paper model.
plot_top_head_vectors(
    heads_df=top_heads_summary_df,
    model_names=[MAIN_PAPER_PLOT_MODELS[-1]],
    vector_cmaps=(plt.cm.Reds, plt.cm.Blues),
    shape=(1, 1),
    panel_width=1,
    panel_height=6,
    grid_spec_kwargs=dict(hspace=0.5),
    minimal_mode=True,
    show_legend=False,
    draw_mean_lines=False,
    annotate_panels=False,
    fontfamily="monospace",
    save_name="figure_1_fvs.png",
)

# Compare activation similarity only in shared heads


In [None]:
# SHORT and LONG map to the same "<BOTH>_<ALL>" string; ICL maps to the empty string.
DEFAULT_UNIVERSAL_FV_TYPES = {
    **dict.fromkeys((SHORT, LONG), f"{BOTH}_{ALL}"),
    ICL: "",
}

In [None]:
# Spot check: heads that appear in both top-20 lists for Llama-3.2-3B.
mld = layer_key_dicts["Llama-3.2-3B @ 20"]
set.intersection(set(mld["prompt_heads"]), set(mld["icl_heads"]))

In [None]:
from pathlib import Path


def load_shared_head_activations(
    models: typing.List[str],
    key_dicts: typing.Dict[str, typing.Dict[str, typing.Any]],
    n_top_heads: int = 20,
    skip_datasets: typing.List[str] = SKIP_DATASETS,
):
    """Collect, per model and dataset, the stored mean activations of the shared top heads.

    A head is "shared" when it appears in both the ``prompt_heads`` and ``icl_heads`` lists of
    ``key_dicts["<model> @ <n_top_heads>"]``. For every root in ``RESULT_ROOTS`` and every
    dataset directory under ``<root>/<model>`` that is not in ``skip_datasets``, the file
    ``<dataset>_mean_head_activations.pt`` is loaded and indexed with ``[L, H, -1]`` for each
    shared head ``(L, H)``. Missing model directories and missing files are logged as
    warnings and skipped.

    Returns:
        ``{model: {dataset_name: {(L, H): {result_type: activation}}}}``, with the inner
        levels being defaultdicts.
    """
    results_by_model = {}

    for model in tqdm(models):
        head_lists = key_dicts[f"{model} @ {n_top_heads}"]
        shared_heads = set(head_lists["prompt_heads"]) & set(head_lists["icl_heads"])
        per_dataset = defaultdict(lambda: defaultdict(dict))

        for result_type, root_str in RESULT_ROOTS.items():
            model_dir = Path(root_str) / model
            if not model_dir.exists():
                logger.warning(f"Model results path {model_dir} does not exist.")
                continue

            for dataset_dir in model_dir.iterdir():
                dataset_name = dataset_dir.name
                if dataset_name in skip_datasets:
                    continue

                activations_file = dataset_dir / f"{dataset_name}_mean_head_activations.pt"
                if activations_file.exists():
                    mean_activations = torch.load(activations_file)
                    for layer, head in shared_heads:
                        per_dataset[dataset_name][(layer, head)][result_type] = mean_activations[layer, head, -1]
                else:
                    logger.warning(f"Mean activations path {activations_file} does not exist.")

        results_by_model[model] = per_dataset

    return results_by_model


# Shared-head activations for every plotted model (main-paper models first, then appendix).
shared_head_activations_by_model = load_shared_head_activations(
    models=MAIN_PAPER_PLOT_MODELS + APPENDIX_MODELS,
    key_dicts=layer_key_dicts,
)


In [None]:
from torch.nn.functional import cosine_similarity

# Every unordered pair of result types to compare.
key_pairs = list(itertools.combinations(RESULT_ROOTS.keys(), 2))


# One table row per model: mean ± std cosine similarity over all (dataset, shared head)
# combinations for which both result types of a pair are available.
rows = []
for model, model_activations in shared_head_activations_by_model.items():
    similarities_by_rt = defaultdict(list)
    for dataset_name, dataset_activations in model_activations.items():
        for result_type_activations in dataset_activations.values():
            for rt1, rt2 in key_pairs:
                if rt1 not in result_type_activations or rt2 not in result_type_activations:
                    continue
                # .item() stores a plain Python float, so the NumPy statistics below also work
                # for tensors that do not live on the CPU, and match the per-head cell that
                # follows.
                similarity = cosine_similarity(
                    result_type_activations[rt1], result_type_activations[rt2], dim=0
                ).item()
                similarities_by_rt[(rt1, rt2)].append(similarity)

    rows.append(
        dict(
            model=model,
            **{
                f"{rt1} vs. {rt2}": f"{np.mean(sims):.4f} ± {np.std(sims):.4f} (n = {len(sims)})"
                for (rt1, rt2), sims in similarities_by_rt.items()
            },
        )
    )


display(Markdown(tabulate.tabulate(rows, headers="keys", tablefmt="github")))

In [None]:
from torch.nn.functional import cosine_similarity

# Every unordered pair of result types to compare.
key_pairs = list(itertools.combinations(RESULT_ROOTS.keys(), 2))


# model -> head -> (rt1, rt2) -> list of cosine similarities, one per dataset
rows = []
similarities_by_model_head_rt = dict()
for model, model_activations in shared_head_activations_by_model.items():
    similarities_by_model_head_rt[model] = defaultdict(lambda: defaultdict(list))
    for dataset_name, dataset_activations in model_activations.items():
        for head, result_type_activations in dataset_activations.items():
            for rt1, rt2 in key_pairs:
                if rt1 not in result_type_activations or rt2 not in result_type_activations:
                    continue
                similarity = cosine_similarity(result_type_activations[rt1], result_type_activations[rt2], dim=0).item()
                similarities_by_model_head_rt[model][head][(rt1, rt2)].append(similarity)

    # One table row per shared head, in sorted head order.
    for head in sorted(similarities_by_model_head_rt[model].keys()):
        head_results = similarities_by_model_head_rt[model][head]
        rows.append(
            dict(
                model=model,
                head=head,
                **{
                    f"{rt1} vs. {rt2}": f"{np.mean(sims):.4f} ± {np.std(sims):.4f} (n = {len(sims)})"
                    for (rt1, rt2), sims in head_results.items()
                },
            )
        )


# Rows may carry different sets of "<rt1> vs. <rt2>" columns (a pair is only present when both
# result types were available), so collect the headers over all rows in order of first
# appearance and align every row to them instead of relying on the first row's keys.
headers = list(dict.fromkeys(column for row_dict in rows for column in row_dict))
rows = [[row_dict.get(column, "") for column in headers] for row_dict in rows]

# Insert a separator after each model's group of rows (rows of one model are contiguous).
# Inserting from the back keeps the earlier insertion indices valid.
model_counts = list(Counter([row[0] for row in rows]).values())
sep_rows = list(np.cumsum(model_counts))[:-1]
for i in sorted(sep_rows, reverse=True):
    rows.insert(i, tabulate.SEPARATING_LINE)


display(Markdown(tabulate.tabulate(rows, headers=headers, tablefmt="github")))

In [None]:
# Inspect which result-type pairs were recorded for head (17, 5) of Llama-3.1-8B. The per-model
# mapping is a defaultdict, so plain indexing would silently insert an empty entry when the
# head is not present; .get() leaves the data untouched.
similarities_by_model_head_rt["Llama-3.1-8B"].get((17, 5), {}).keys()

In [None]:
from palettable.colorbrewer.qualitative import Paired_6 as heads_similarity_cmap
from palettable.colorbrewer.qualitative import Set1_3 as heads_cmap

# Channel-wise difference between the first two colours of the Paired palette.
light_to_dark_diff = [
    c1 - c2 for c1, c2 in zip(heads_similarity_cmap.mpl_colors[0], heads_similarity_cmap.mpl_colors[1])
]
# Second Set1 colour shifted by 75% of that difference.
# NOTE(review): this colour is not referenced by the styles below ("short" vs "icl" uses
# 'cyan') -- confirm whether it was meant to be used there.
set1_light_blue = [c + (d * 0.75) for c, d in zip(heads_cmap.mpl_colors[1], light_to_dark_diff)]

# Legend label and colour for each compared pair of result types.
SIMILARITY_KEY_STYLES = {
    ("short", "long"): {"label": "Short & Long\nInstructions", "color": heads_cmap.mpl_colors[0]},
    ("short", "icl"): {"label": "Short Instructions\n& Demonstrations", "color": 'cyan'},
    # Label typo fixed ("Demonsrations" -> "Demonstrations").
    ("long", "icl"): {"label": "Long Instructions\n& Demonstrations", "color": heads_cmap.mpl_colors[1]},
}

# Marker settings applied to every similarity series; merged over GLOBAL_PLOT_STYLE by
# plot_head_activation_similarities.
SIMILARITY_GLOBAL_PLOT_STYLE = {
    "markersize": 10,
    "alpha": 0.6,
}


def plot_head_activation_similarities(
    similarities_data,
    model_groups: typing.List[typing.List[str]],
    model_plot_styles: typing.Dict[str, typing.Dict[str, typing.Any]] = MODEL_PLOT_STYLES,
    similarity_key_styles: typing.Dict[typing.Tuple[str, str], typing.Dict[str, typing.Any]] = SIMILARITY_KEY_STYLES,
    global_plot_style: typing.Dict[str, typing.Any] = SIMILARITY_GLOBAL_PLOT_STYLE,
    shape: typing.Tuple[int, int] | None = None,
    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
    show_error_bars: bool = True,
    limits_from_model_layers: bool = True,
    jitter_range: float = 0.5,
    fontsize: int = 16,
    font_inc: int = 4,
    text_font_inc: int = 2,
    fontfamily: str | None = None,
    textfontweight: str = "semibold",
    annotate_panels: bool = True,
    annotate_panels_start: str = "A",
    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),
    annotate_font_inc: int = 8,
    panel_width: float = 6,
    panel_height: float = 6,
    xlabel: str = "Layer",
    ylabel: str = "Head Cosine Similarity",
    ylabel_first_ax_only: bool = True,
    first_yicks_only: bool = False,
    show_legend: bool = True,
    legend_ax_index: int | None = None,
    legend_outside: bool = True,
    legend_loc: str | typing.Tuple[float, float] | None = None,
    legend_bbox_to_anchor: typing.Tuple[float, float] = (1.05, 1),
    legend_fontsize: int | None = None,
    legend_width: float = 2.0,
    legend_ncol: int = 5,
    legend_panel_height: float = 0.2,
    legend_order: typing.List[int] | None = None,
    subplots_adjust: typing.Dict[str, float] | None = None,
    save_name: str | None = None,
):
    """Scatter per-head activation cosine similarities against layer, one panel per model group.

    For every model in a group and every key of ``similarity_key_styles``, one point is
    drawn per ``(layer, head)`` entry of ``similarities_data[model]``: x is the layer plus
    uniform jitter, y is the mean of the recorded similarities, and the optional error bar
    is the standard error (std / sqrt(n)).

    Args:
        similarities_data: Mapping ``model -> {(layer, head) -> {similarity_key -> list of floats}}``.
        model_groups: One list of model names per panel; the first name of a group is used
            in the panel title and the last one determines the x-axis range.
        model_plot_styles: Per-model plot kwargs, merged over ``global_plot_style``.
        similarity_key_styles: Per-comparison style dicts; each must provide ``"label"``
            and ``"color"`` (both are also used to build the legend).
        global_plot_style: Plot kwargs applied to every series (merged over ``GLOBAL_PLOT_STYLE``).
        shape: ``(rows, cols)`` of the panel grid; defaults to a single row. Panels are
            placed in row-major order.
        grid_spec_kwargs: Extra kwargs for ``GridSpec``.
        show_error_bars: Whether to draw standard-error bars.
        limits_from_model_layers: If set, the x-axis spans ``[0, n_layers]`` of the last
            model of the group, with ticks at the largest divisor (<= 10) of the layer count.
        jitter_range: Half-width of the uniform x jitter.
        fontsize, font_inc, text_font_inc, fontfamily, textfontweight: Text styling.
        annotate_panels, annotate_panels_start, annotate_panel_position, annotate_font_inc:
            Panel-letter annotation; letters count up from ``annotate_panels_start``.
        panel_width, panel_height: Size of one panel in inches.
        xlabel, ylabel, ylabel_first_ax_only: Axis labels; the y label may be limited to panel 0.
        first_yicks_only: If set, y ticks are removed from every panel but the first.
        show_legend, legend_ax_index, legend_outside, legend_loc, legend_bbox_to_anchor,
        legend_fontsize, legend_ncol, legend_order: Legend placement and styling; the legend
            is attached to the panel with index ``legend_ax_index`` (default: last panel).
        legend_width, legend_panel_height, subplots_adjust: Accepted for call compatibility;
            not used by the current implementation.
        save_name: If given, the figure is passed to ``save_plot`` under this name.
    """
    if global_plot_style is None:
        global_plot_style = {}

    global_plot_style = {**GLOBAL_PLOT_STYLE, **global_plot_style}
    n_model_groups = len(model_groups)

    if shape is None:
        shape = (1, n_model_groups)

    if grid_spec_kwargs is None:
        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)

    gs = GridSpec(shape[0], shape[1], **grid_spec_kwargs)

    fig = plt.figure(
        layout="constrained",
        # One extra inch of height is reserved when a legend is requested.
        figsize=(shape[1] * panel_width, shape[0] * panel_height + int(show_legend)),
    )

    if legend_ax_index is None:
        legend_ax_index = n_model_groups - 1

    for i, models in enumerate(model_groups):
        # Row-major placement. This reproduces the previous behavior for single-row and
        # single-column grids and additionally supports true 2-D grids.
        grid_row, grid_col = divmod(i, shape[1])
        ax = fig.add_subplot(gs[grid_row, grid_col])

        for model in models:
            model_sims = similarities_data[model]
            model_style = model_plot_styles[model]

            for similarity_key, similarity_style in similarity_key_styles.items():
                # (jittered layer, mean similarity, standard error) for every head of this model.
                model_sim_key_data = [
                    (
                        layer + np.random.uniform(-jitter_range, jitter_range),
                        np.mean(sims[similarity_key]),
                        np.std(sims[similarity_key]) / (len(sims[similarity_key]) ** 0.5),
                    )
                    for (layer, _head), sims in model_sims.items()
                ]

                x, y, err = zip(*model_sim_key_data)
                style = {**global_plot_style, **model_style, **similarity_style}
                # The error bars keep the merged style; the scatter itself must not connect points.
                error_bar_style = {**style}
                style["linewidth"] = 0
                # Comparisons involving short instructions are drawn on top of the others.
                ax.plot(x, y, **style, zorder=1 if SHORT in similarity_key else -1)
                if show_error_bars:
                    ax.errorbar(x, y, yerr=err, fmt="none", capsize=5, elinewidth=2, **error_bar_style)

        ax.set_ylim(0, 1.02)
        if limits_from_model_layers:
            # Explicitly use the last model of the group (previously this relied on the
            # loop variable leaking out of the inner loop).
            model_layers, _ = MODEL_TO_N_LAYERS_HEADS[models[-1]]
            ax.set_xlim(0, model_layers)

            # Largest divider <= 10 of the layer count gives evenly spaced integer ticks.
            step = 1
            for divider in range(10, 0, -1):
                if model_layers % divider == 0:
                    step = model_layers // divider
                    break

            ax.set_xticks(np.arange(0, model_layers + step, step))

        if first_yicks_only and i > 0:
            ax.set_yticks([])

        ax.tick_params(axis="both", labelsize=fontsize - text_font_inc, labelfontfamily=fontfamily)
        ax.set_xlabel(xlabel, fontsize=fontsize + font_inc, fontfamily=fontfamily, fontweight=textfontweight)
        if i == 0 or not ylabel_first_ax_only:
            ax.set_ylabel(ylabel, fontsize=fontsize + font_inc, fontfamily=fontfamily, fontweight=textfontweight)
        ax.set_title(
            f"{models[0]} Family", fontsize=fontsize + font_inc, fontfamily=fontfamily, fontweight=textfontweight
        )

        if annotate_panels:
            ax.text(
                *annotate_panel_position,
                chr(ord(annotate_panels_start) + i),
                ha="left",
                va="top",
                fontsize=fontsize + annotate_font_inc,
                fontweight="bold",
                fontfamily=fontfamily,
                transform=ax.transAxes,
            )

        if show_legend and i == legend_ax_index:
            # One colored patch per comparison, then the base/instruct marker key.
            legend_entries = [
                matplotlib.patches.Rectangle((0, 0), 1, 1, color=key_style["color"], label=key_style["label"])
                for key_style in similarity_key_styles.values()
            ]
            legend_entries.append(
                matplotlib.lines.Line2D(
                    [0], [0], marker="o", color="white", markerfacecolor="black", markersize=12, label="Base Models"
                )
            )
            legend_entries.append(matplotlib.patches.Rectangle((0, 0), 1, 1, color="black", label="Instruct Models"))

            legend_labels = [key_style["label"] for key_style in similarity_key_styles.values()]
            legend_labels.append("Base Models")
            legend_labels.append("Instruct Models")

            if legend_order is not None:
                legend_entries = [legend_entries[idx] for idx in legend_order]
                legend_labels = [legend_labels[idx] for idx in legend_order]

            legend_kwargs = dict(
                ncol=legend_ncol,
                prop=dict(family=fontfamily, size=fontsize - font_inc if legend_fontsize is None else legend_fontsize),
                handlelength=0.75,
            )
            if legend_outside:
                legend_kwargs["bbox_to_anchor"] = legend_bbox_to_anchor
                legend_kwargs["loc"] = legend_loc
            elif legend_loc is not None:
                legend_kwargs["loc"] = legend_loc

            ax.legend(legend_entries, legend_labels, **legend_kwargs)

    # NOTE(review): the figure is created with layout="constrained"; calling tight_layout
    # afterwards appears to switch the layout engine -- confirm this combination is intended.
    plt.tight_layout()

    if save_name is not None:
        save_plot(save_name)

    plt.show()


# Main-paper figure: two stacked panels (first two models / remaining models), panels lettered from "E".
# NOTE(review): MAIN_PAPER_PLOT_MODELS is also assigned in a later cell of this notebook;
# confirm it is defined before this cell on a fresh Restart & Run All.
plot_head_activation_similarities(
    similarities_by_model_head_rt,
    model_groups=[MAIN_PAPER_PLOT_MODELS[:2], MAIN_PAPER_PLOT_MODELS[2:]],
    # Layout
    shape=(2, 1),
    panel_width=4,
    panel_height=4,
    grid_spec_kwargs={"hspace": 0.3},
    # Content
    show_error_bars=False,
    ylabel_first_ax_only=False,
    first_yicks_only=False,
    # Text
    fontsize=14,
    font_inc=2,
    fontfamily="monospace",
    annotate_panels=True,
    annotate_panels_start="E",
    # Legend: inside the second panel, single column
    legend_outside=False,
    legend_ax_index=1,
    legend_loc=(0.6, 0),
    legend_ncol=1,
    legend_fontsize=10,
    save_name="finding_3_head_similarities.pdf",
)

In [None]:
# Appendix figure: one row of three panels built from slices of APPENDIX_MODELS, lettered from "A".
plot_head_activation_similarities(
    similarities_by_model_head_rt,
    model_groups=[APPENDIX_MODELS[0:2], APPENDIX_MODELS[2:4], APPENDIX_MODELS[4:8]],
    # Layout
    shape=(1, 3),
    panel_width=6,
    panel_height=4,
    grid_spec_kwargs={"hspace": 0.3, "wspace": 0.3},
    # Content
    show_error_bars=False,
    ylabel_first_ax_only=False,
    first_yicks_only=False,
    # Text
    fontsize=14,
    font_inc=2,
    fontfamily="monospace",
    annotate_panels=True,
    annotate_panels_start="A",
    # Legend: inside the first panel, single column
    legend_outside=False,
    legend_ax_index=0,
    legend_loc="lower left",
    legend_ncol=1,
    legend_fontsize=10,
    save_name="appendix_finding_3_head_similarities.pdf",
)

# Compute mean IE by baseline/length separately to show similarity

In [None]:

N_TOP_HEADS = 20
top_head_rows = []
split_mean_ie_by_model = {}

# Every prompt type and baseline is evaluated on its own and once more pooled
# (the pooled variant is passed as a list and recorded as BOTH / ALL).
prompt_type_options = [SHORT, LONG, [SHORT, LONG]]
baseline_options = RELEVANT_BASELINES + [RELEVANT_BASELINES]
tqdm_total = len(ORDERED_MODELS) * len(prompt_type_options) * len(baseline_options)

# --- Instruction (prompt) results -------------------------------------------------
prompt_datasets_below_chance_acc = defaultdict(set)

for model, prompt_types, baseline in tqdm(
    itertools.product(ORDERED_MODELS, prompt_type_options, baseline_options),
    total=tqdm_total,
    desc="Universal top heads",
):
    top_heads, top_head_effects, mean_ie = compute_top_heads(
        result_df,
        indirect_effects_by_model_and_dataset,
        model,
        prompt_types,
        baseline,
        n_top_heads=N_TOP_HEADS,
        datasets_below_chance_acc=prompt_datasets_below_chance_acc,
        return_mean=True,
    )
    # A string identifies a single setting; anything else is the pooled list.
    pt = prompt_types if isinstance(prompt_types, str) else BOTH
    bl = baseline if isinstance(baseline, str) else ALL
    top_head_rows.append(
        {
            "model": model,
            "prompt_type": pt,
            "baseline": bl,
            "n": N_TOP_HEADS,
            "top_heads": {tuple(t) for t in top_heads},
            "top_head_effects": top_head_effects,
            "top_heads_list": top_heads,
        }
    )
    split_mean_ie_by_model[(model, pt, bl)] = mean_ie

for model, datasets in prompt_datasets_below_chance_acc.items():
    if not datasets:
        continue
    logger.warning(f"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}")


# --- Demonstration (ICL) results --------------------------------------------------
icl_datasets_below_chance_acc = defaultdict(set)

for model in tqdm(ORDERED_MODELS, total=len(ORDERED_MODELS), desc="Universal top heads ICL"):
    # ICL acts as both the prompt type and the baseline label for these rows.
    prompt_type = ICL
    baseline = ICL
    top_heads, top_head_effects, mean_ie = compute_top_heads(
        result_df,
        indirect_effects_by_model_and_dataset,
        model,
        prompt_type,
        baseline,
        n_top_heads=N_TOP_HEADS,
        datasets_below_chance_acc=icl_datasets_below_chance_acc,
        return_mean=True,
    )
    top_head_rows.append(
        {
            "model": model,
            "prompt_type": prompt_type,
            "baseline": baseline,
            "n": N_TOP_HEADS,
            "top_heads": {tuple(t) for t in top_heads},
            "top_head_effects": top_head_effects,
            "top_heads_list": top_heads,
        }
    )
    split_mean_ie_by_model[(model, prompt_type, baseline)] = mean_ie

for model, datasets in icl_datasets_below_chance_acc.items():
    if not datasets:
        continue
    logger.warning(f"Model {model} datasets below chance accuracy: {', '.join(sorted(datasets))}")


universal_top_heads_df = pd.DataFrame(top_head_rows)

# Independent per-model copies with a fresh 0..n-1 index.
split_universal_top_heads_dfs = {
    model: universal_top_heads_df.loc[universal_top_heads_df["model"] == model]
    .copy(deep=True)
    .reset_index(drop=True)
    for model in ORDERED_MODELS
}


universal_top_heads_df

In [None]:
def model_pt_baseline_rename(row):
    """Return the per-panel identifier ``<model>_<prompt type>_<baseline>`` for one row."""
    return f"{row.model}_{row.prompt_type}_{row.baseline}"


for model_name in MAIN_PLOT_MODELS:
    # Keep only rows for a single prompt type and a single (non-ICL, non-pooled) baseline.
    is_model = universal_top_heads_df.model == model_name
    is_single_baseline = ~universal_top_heads_df.baseline.isin([ICL, ALL])
    is_single_prompt_type = universal_top_heads_df.prompt_type != BOTH
    model_top_heads_df = universal_top_heads_df[is_model & is_single_baseline & is_single_prompt_type]

    # Reshape into the columns plot_shared_top_heads is called with: each
    # (prompt type, baseline) split poses as its own "model" with no ICL heads.
    model_top_heads_df = model_top_heads_df.rename(columns={"n": "n_heads", "top_heads": "prompt_heads"})
    model_top_heads_df = model_top_heads_df.assign(
        model=model_top_heads_df.apply(model_pt_baseline_rename, axis=1),
        icl_heads=model_top_heads_df.apply(lambda row: [], axis=1),
    )

    # Panel titles derived from the identifier: second token is the prompt type,
    # everything after the second underscore is the baseline name.
    model_titles = {
        m: f"{m.split('_', 2)[1].capitalize()} instructions\n{m.split('_', 2)[2]} baseline"
        for m in model_top_heads_df.model.values
    }

    plot_shared_top_heads(
        model_top_heads_df,
        list(model_titles.keys()),
        model_titles=model_titles,
        suptitle=model_name,
        # Layout
        shape=(2, 3),
        panel_height=4,
        grid_spec_kwargs={"hspace": 0.2, "wspace": 0.25, "top": 0.9},
        draw_mean_lines=False,
        # Legend (disabled; placement arguments kept as in the other figures)
        show_legend=False,
        legend_ax_index=2,
        legend_outside=False,
        legend_loc=(-0.05, 0),
        legend_ncol=3,
        legend_panel_height=0.05,
        # Text
        fontsize=14,
        fontfamily="monospace",
        textfontweight="bold",
        annotate_panels=True,
        annotate_panels_start="A",
        save_name=f"appendix_finding_3_split_top_heads_{model_name}.pdf",
    )


And one that's more of a heatmap by how many times each head was counted

In [None]:
from palettable.colorbrewer.sequential import Reds_6 as heads_heatmap_cmap

# White prepended to the Reds palette (presumably the color for heads that were never
# selected -- confirm against plot_shared_top_heads).
heads_heatmap_colors = [(1, 1, 1), *heads_heatmap_cmap.mpl_colors]

heatmap_df_rows = []

for model_name in MAIN_PLOT_MODELS:
    # Same split rows as the per-split plots: single prompt type, single non-ICL baseline.
    model_top_heads_df = universal_top_heads_df[
        (universal_top_heads_df.model == model_name)
        & ~(universal_top_heads_df.baseline.isin([ICL, ALL]))
        & (universal_top_heads_df.prompt_type != BOTH)
    ]

    # Concatenate the top-head lists of all splits; duplicates are kept on purpose so a
    # head appears once per split in which it was selected.
    all_top_heads = [tuple(th) for th in itertools.chain.from_iterable(model_top_heads_df.top_heads_list)]
    heatmap_df_rows.append(
        dict(
            model=model_name,
            n_heads=N_TOP_HEADS,
            prompt_heads=all_top_heads,
            icl_heads=[],
        )
    )

heatmap_df = pd.DataFrame(heatmap_df_rows)

# Titles keyed by the plain model names used in this frame. The previous version passed
# the `model_titles` dict left over from the last iteration of the preceding cell, whose
# keys have the form "<model>_<prompt type>_<baseline>" and therefore never match here.
heatmap_model_titles = {name: name for name in heatmap_df.model.values}

plot_shared_top_heads(
    heatmap_df,
    list(heatmap_df.model.values),
    top_heads_colors=heads_heatmap_colors,
    shape=(2, 2),
    draw_mean_lines=False,
    show_legend=False,
    legend_ax_index=2,
    legend_outside=False,
    legend_loc=(-0.05, 0),
    legend_ncol=3,
    legend_panel_height=0.05,
    panel_width=4.5,
    grid_spec_kwargs=dict(hspace=0.3, wspace=0.3),
    fontfamily="monospace",
    textfontweight="bold",
    save_name="appendix_finding_3_top_heads_heatmap.pdf",
    annotate_panels=True,
    annotate_panels_start="A",
    fontsize=14,
    model_titles=heatmap_model_titles,
    # presumably the number of (prompt type, baseline) splits per model -- TODO confirm
    vmax=6,
    add_colorbar=True,
)

# RQ2.1 Top heads -- second panel


In [None]:
def repeat_elements(items, n):
    """Return a list in which every element of ``items`` appears ``n`` times in a row."""
    repeated = []
    for item in items:
        repeated.extend([item] * n)
    return repeated


# Metric keys read from each model's layer-info dict when none are given explicitly.
DEFAULT_METRIC_NAMES = ["prompt_heads", "icl_heads"]
# Label assigned to values that occur under every requested metric.
SHARED = "shared"

# Default marker styling for annotations (gold star with a black outline).
# The misspelled name is kept because other code in this cell refers to it.
DEFUALT_ANNOTATE_KWARGS = {
    "marker": "*",
    "s": 200,
    "color": "gold",
    "edgecolor": "black",
}


def take_first_element(t):
    """Return the item at index 0 of ``t`` (for a ``(layer, head)`` tuple, the layer)."""
    first = t[0]
    return first


def plot_layer_rows_individual_dots(
    layer_info_dicts: typing.Dict[str, typing.Dict[str, float]],
    n_top_heads: int,
    metric_names: str | typing.Sequence[str] = DEFAULT_METRIC_NAMES,
    models: typing.Sequence[str] | None = None,
    colors: typing.Sequence[str] | None = None,
    title: str | None = None,
    ylabel: str | None = None,
    ylabel_first_ax_only: bool = True,
    fontsize: int = 12,
    font_inc: int = 4,
    fontfamily: str | None = None,
    metric_labels: typing.Dict[str, str] = None,
    metric_name_to_plot_kwargs: typing.Dict[str, typing.Dict[str, typing.Any]] = None,
    metric_name_to_err_metric: typing.Dict[str, str] = None,
    tuple_to_plot_values: typing.Callable = take_first_element,
    annotate_values: typing.Dict[str, typing.Dict[str, float]] | None = None,
    annotate_kwargs: typing.Dict[str, typing.Any] = None,
    ylim: typing.Tuple[float, float] | None = None,
    ylim_from_model_layers: bool = False,
    panel_width: int = 6,
    panel_height: int = 6,
    annotate_panels: bool = False,
    annotate_panel_position: typing.Tuple[float, float] = (1.01, 1.0),
    annotate_font_inc: int = 4,
    save_name: str | None = None,
    **global_plot_kwargs,
):
    """Swarm-plot individual metric values per model, one panel per base-model family.

    For each model the entry ``layer_info_dicts[f"{model} @ {n_top_heads}"]`` is read.
    Scalar metric values yield one point; list values yield one point per element.
    Elements present under every requested metric are plotted once, relabelled as
    ``SHARED``. Tuple elements are reduced to a plottable number by ``tuple_to_plot_values``.
    A horizontal black line marks the mean of every group.

    Args:
        layer_info_dicts: Mapping ``"<model> @ <n_top_heads>" -> {metric name -> value(s)}``.
        n_top_heads: Number of top heads; used for the lookup key and the default title.
        metric_names: Metric key or keys to plot.
        models: Models to include; defaults to a copy of ``RELEVANT_MODELS``. Models are
            grouped into panels by their name with "-Instruct" / "-chat" removed.
        colors: Flat color sequence, consumed in consecutive chunks (one chunk per panel,
            one color per hue level present in that panel). ``None`` uses seaborn's default palette.
        title: Panel title; defaults to ``"<base model> family (top <n> heads)"``.
        ylabel: Y label; defaults to a prettified version of the last metric name.
        ylabel_first_ax_only: Only label the y axis of the first panel.
        fontsize, font_inc, fontfamily: Text styling.
        metric_labels: Mapping from metric key (or ``SHARED``) to display label. Its value
            order defines the hue order; when empty, seaborn's default order is used.
        tuple_to_plot_values: Maps a tuple value to the number that is plotted.
        ylim: Explicit y limits (ignored when ``ylim_from_model_layers`` is set).
        ylim_from_model_layers: Use ``[0, MODEL_TO_N_LAYERS[base_model]]`` as y range.
        panel_width, panel_height: Size of one panel in inches.
        annotate_panels, annotate_panel_position, annotate_font_inc: Panel-letter annotation.
        save_name: If given, the figure is passed to ``save_plot`` under this name.
        metric_name_to_plot_kwargs, metric_name_to_err_metric, annotate_values, annotate_kwargs:
            Accepted for call compatibility; not used by the current implementation.
        **global_plot_kwargs: Forwarded to ``sns.swarmplot``.

    Returns:
        The long-format ``pandas.DataFrame`` that was plotted.
    """
    if models is None:
        models = RELEVANT_MODELS[:]

    if isinstance(metric_names, str):
        metric_names = [metric_names]

    if metric_labels is None:
        metric_labels = dict()

    rows = []
    for model in models:
        base_model = model.replace("-Instruct", "").replace("-chat", "")
        is_instruct = ("-Instruct" in model) or ("-chat" in model)

        all_metric_values_by_model = {
            metric_name: layer_info_dicts[f"{model} @ {n_top_heads}"].get(metric_name, np.nan)
            for metric_name in metric_names
        }

        # Values that appear under every metric are plotted once as SHARED.
        shared_values = set()
        if isinstance(all_metric_values_by_model[metric_names[0]], (tuple, list)):
            shared_values = set(all_metric_values_by_model[metric_names[0]])
            for metric_name in metric_names[1:]:
                shared_values &= set(all_metric_values_by_model[metric_name])

        # Shared values already emitted for this model (under an earlier metric).
        skip_values = set()

        for metric_name in metric_names:
            model_metric_val = all_metric_values_by_model[metric_name]
            if isinstance(model_metric_val, (int, float)):
                rows.append(
                    dict(
                        model=model,
                        base_model=base_model,
                        is_instruct=is_instruct,
                        metric_name=metric_name,
                        model_metric=f"{model}_{metric_name}",
                        value=model_metric_val,
                    )
                )
            elif isinstance(model_metric_val, list):
                for i, val in enumerate(model_metric_val):
                    if val in skip_values:
                        continue

                    d = dict(
                        model=model,
                        base_model=base_model,
                        is_instruct=is_instruct,
                        metric_name=metric_name,
                        model_metric=f"{model}_{metric_name}",
                        value=val,
                        index=i,
                        shared=False,
                    )
                    if val in shared_values:
                        skip_values.add(val)
                        d["shared"] = True
                        d["metric_name"] = SHARED

                    if isinstance(val, tuple):
                        d["value"] = tuple_to_plot_values(val)

                    rows.append(d)

    all_values_df = pd.DataFrame(rows)
    all_values_df = all_values_df.assign(
        metric_name=all_values_df.metric_name.map(lambda x: metric_labels.get(x, x)),
    )

    base_models = all_values_df.base_model.unique()
    n_panels = len(base_models)
    # squeeze=False keeps `axes` a 2-D array even for a single panel, so flatten() always works.
    fig, axes = plt.subplots(1, n_panels, figsize=(panel_width * n_panels, panel_height), squeeze=False)
    axes = axes.flatten()

    # An empty hue order would drop every point, so fall back to seaborn's default order.
    hue_order = list(metric_labels.values()) or None
    # Same fallback the code previously obtained through the leaked loop variable.
    default_ylabel = metric_names[-1].replace("_", " ").capitalize()

    for b, base_model in enumerate(base_models):
        base_model_df = all_values_df[all_values_df.base_model == base_model]
        n_metrics = base_model_df.metric_name.nunique()
        ax = axes[b]

        # Each panel consumes the next `n_metrics` colors; None defers to seaborn's palette.
        base_model_colors = None if colors is None else colors[b * n_metrics : (b + 1) * n_metrics]

        sns.swarmplot(
            data=base_model_df,
            x="model",
            y="value",
            hue="metric_name",
            ax=ax,
            palette=base_model_colors,
            hue_order=hue_order,
            **global_plot_kwargs,
        )

        # Boxplot with everything but the mean line hidden: draws one mean marker per group.
        sns.boxplot(
            data=base_model_df,
            x="model",
            y="value",
            hue="metric_name",
            hue_order=hue_order,
            showmeans=True,
            meanline=True,
            meanprops={"color": "k", "ls": "-", "lw": 2},
            medianprops={"visible": False},
            whiskerprops={"visible": False},
            zorder=10,
            showfliers=False,
            showbox=False,
            showcaps=False,
            palette=base_model_colors,
            legend=None,
            ax=ax,
        )

        ax.set_xlabel("Model", fontsize=fontsize + font_inc, fontfamily=fontfamily)
        if (not ylabel_first_ax_only) or b == 0:
            ax.set_ylabel(
                default_ylabel if ylabel is None else ylabel,
                fontsize=fontsize + font_inc,
                fontfamily=fontfamily,
            )
        else:
            ax.set_ylabel("")

        ax_title = title
        if ax_title is None:
            ax_title = f"{base_model} family (top {n_top_heads} heads)"

        ax.set_title(ax_title, fontsize=fontsize + (2 * font_inc), fontfamily=fontfamily)

        ax.tick_params(axis="both", labelsize=fontsize, labelfontfamily=fontfamily)

        if ylim_from_model_layers:
            model_layers = MODEL_TO_N_LAYERS[base_model]
            ax.set_ylim(0, model_layers)
            # Largest divider <= 10 of the layer count gives evenly spaced integer ticks.
            y_step = 1
            for divider in range(10, 0, -1):
                if model_layers % divider == 0:
                    y_step = model_layers // divider
                    break

            ax.set_yticks(np.arange(0, model_layers + y_step, y_step))

        elif ylim is not None:
            ax.set_ylim(ylim)

        if annotate_panels:
            ax.text(
                *annotate_panel_position,
                chr(ord("A") + b),
                ha="left",
                va="top",
                fontsize=fontsize + annotate_font_inc,
                fontweight="bold",
                fontfamily=fontfamily,
                transform=ax.transAxes,
            )

        # Add a legend artist for the black line representing the mean
        mean_line = plt.Line2D([0], [0], color="k", linestyle="-", linewidth=2, label="Mean")
        loc = "lower right" if "Llama-3" in base_model else "best"
        # NOTE(review): matplotlib documents `fontsize` as only used when `prop` is not given,
        # so the size below is likely ignored; left unchanged to keep published figures stable.
        ax.legend(
            handles=ax.get_legend_handles_labels()[0] + [mean_line],
            fontsize=fontsize - font_inc,
            loc=loc,
            prop=dict(family=fontfamily),
        )

    plt.tight_layout()
    if save_name is not None:
        save_plot(save_name)
    plt.show()

    return all_values_df


from palettable.colorbrewer.qualitative import Paired_12 as cmap


def flip_pairs(lst):
    """Swap the two members of every consecutive pair; an unpaired trailing item is dropped."""
    flipped = []
    for second, first in zip(lst[1::2], lst[::2]):
        flipped.append(second)
        flipped.append(first)
    return flipped


n_top_heads = 20
# Three palette entries per model family, in the hue order given by `metric_labels` below.
color_indices = [0, 3, 1, 6, 3, 7]
colors = [cmap.mpl_colors[color_index] for color_index in color_indices]

# Display labels; their order also fixes the hue order inside each panel.
swarm_metric_labels = {
    "icl_heads": "Demonstration",
    SHARED: "Shared",
    "prompt_heads": "Instruction",
}

adf = plot_layer_rows_individual_dots(
    layer_key_dicts,
    n_top_heads,
    models=MAIN_PAPER_PLOT_MODELS,
    colors=colors,
    metric_labels=swarm_metric_labels,
    ylabel="Top head layer",
    ylim_from_model_layers=True,
    # Text
    fontsize=16,
    font_inc=0,
    fontfamily="monospace",
    annotate_panels=True,
    annotate_font_inc=8,
    save_name="rq2_1_top_heads_swarm.pdf",
    # Forwarded to sns.swarmplot
    dodge=True,
    size=10,
)


# Appendix plots with the causal effects for each model


In [None]:
mean_ie_by_model = {}
N_TOP_HEADS = 20

# (result key, prompt types, baselines): demonstrations on their own, and both
# instruction lengths pooled over all relevant baselines.
mean_ie_settings = (
    (ICL, ICL, ICL),
    (BOTH, [SHORT, LONG], RELEVANT_BASELINES),
)

for model, (name, pt, bl) in itertools.product(RELEVANT_MODELS, mean_ie_settings):
    # Only the mean indirect effect is stored; the head lists are recomputed elsewhere.
    top_heads, top_head_effects, mean_ie = compute_top_heads(
        result_df,
        indirect_effects_by_model_and_dataset,
        model,
        pt,
        bl,
        n_top_heads=N_TOP_HEADS,
        return_mean=True,
    )
    mean_ie_by_model[(model, name)] = mean_ie

mean_ie_by_model.keys()

In [None]:
from palettable.scientific.diverging import Vik_20_r as ie_colormap


def plot_mean_ies_by_model_and_type(
    mean_ie_data: typing.Dict[typing.Tuple[str, str], torch.Tensor],
    models: typing.List[str],
    colormap,
    limits_from_model_layers: bool = True,
    force_cmap_zero_middle: bool = True,
    shrink_cmap: bool = False,
    cmap_max: float = 1.0,
    ylabel_first_ax_only: bool = True,
    fontsize: int = 12,
    font_inc: int = 4,
    fontfamily: str | None = None,
    panel_width: int = 6,
    panel_height: int = 6,
    colormap_round: bool = True,
    colormap_round_scale: int = 100,
    colorbar_panel_width: float = 0.1,
    grid_spec_kwargs: typing.Optional[typing.Dict[str, typing.Any]] = None,
    save_name: str | None = None,
):
    """Draw instruction and demonstration mean-indirect-effect heatmaps, one row per model.

    Each row holds three axes: the instruction heatmap (``mean_ie_data[(model, BOTH)]``),
    the demonstration heatmap (``mean_ie_data[(model, ICL)]``) and a colorbar shared by
    the two. Both heatmaps of a row use the same color normalization.

    Args:
        mean_ie_data: Mapping ``(model, BOTH | ICL) -> tensor`` of mean indirect effects.
            The tensor is transposed before plotting (presumably it is (layers, heads) -- TODO confirm).
        models: Models to plot, one grid row each.
        colormap: Callable matplotlib colormap.
        limits_from_model_layers: Place ticks from ``MODEL_TO_N_LAYERS_HEADS[model]``.
        force_cmap_zero_middle: Use a ``TwoSlopeNorm`` centered at 0 instead of a linear norm.
        shrink_cmap: Resample ``colormap`` between a lower bound derived from the
            min/max ratio and ``cmap_max``.
        cmap_max: Upper end (in [0, 1] colormap coordinates) used when ``shrink_cmap`` is set.
        ylabel_first_ax_only: Only label the y axes of the first row.
        fontsize, font_inc, fontfamily: Text styling.
        panel_width, panel_height: Size of one heatmap panel in inches.
        colormap_round: Round the color range outwards to multiples of
            ``1 / colormap_round_scale`` and place one colorbar tick per multiple.
        colormap_round_scale: See ``colormap_round``.
        colorbar_panel_width: Width of the colorbar column relative to a heatmap column.
        grid_spec_kwargs: Extra kwargs for ``GridSpec``.
        save_name: If given, the figure is passed to ``save_plot`` under this name.

    Raises:
        ValueError: If either the instruction or demonstration data of a model is missing.
    """
    shape = (len(models), 2)

    if grid_spec_kwargs is None:
        grid_spec_kwargs = dict(hspace=0.2, wspace=0.25)

    # Two heatmap columns plus one narrow colorbar column.
    gs = GridSpec(shape[0], shape[1] + 1, width_ratios=[1] * shape[1] + [colorbar_panel_width], **grid_spec_kwargs)

    fig = plt.figure(
        layout="constrained",
        figsize=((shape[1] + colorbar_panel_width) * panel_width, shape[0] * panel_height),
    )

    for m, model in enumerate(models):
        model_prompt_ax = fig.add_subplot(gs[m, 0])
        model_icl_ax = fig.add_subplot(gs[m, 1])

        model_prompt_ax.set_title(f"{model} - Instruction", fontsize=fontsize + font_inc, fontfamily=fontfamily)
        model_icl_ax.set_title(f"{model} - Demonstration", fontsize=fontsize + font_inc, fontfamily=fontfamily)

        prompt_data = mean_ie_data.get((model, BOTH), None)
        icl_data = mean_ie_data.get((model, ICL), None)
        if prompt_data is None or icl_data is None:
            raise ValueError(f"Missing data for {model}")

        prompt_data = prompt_data.cpu().numpy()
        icl_data = icl_data.cpu().numpy()
        overall_min = min(prompt_data.min(), icl_data.min())
        overall_max = max(prompt_data.max(), icl_data.max())
        # None lets matplotlib choose the ticks. The previous default, the tuple `(None,)`,
        # was handed to fig.colorbar as an explicit (invalid) tick list when rounding was off.
        colorbar_ticks = None
        if colormap_round:
            overall_min = np.floor(overall_min * colormap_round_scale) / colormap_round_scale
            overall_max = np.ceil(overall_max * colormap_round_scale) / colormap_round_scale
            step = 1 / colormap_round_scale
            colorbar_ticks = np.arange(overall_min, overall_max + step, step)

        print(model, overall_min, overall_max)
        if force_cmap_zero_middle:
            norm = matplotlib.colors.TwoSlopeNorm(vmin=overall_min, vcenter=0, vmax=overall_max)
        else:
            norm = matplotlib.colors.Normalize(vmin=overall_min, vmax=overall_max)

        model_cmap = colormap
        if shrink_cmap:
            # Keep the colormap midpoint (0.5) and use a negative range proportional
            # to |min| / max of the positive range [0.5, cmap_max].
            positive_range = cmap_max - 0.5
            model_cmap = matplotlib.colors.LinearSegmentedColormap.from_list(
                f"{model}_shrunk_cmap",
                colormap(np.linspace(0.5 - (positive_range * abs(overall_min) / overall_max), cmap_max, 256)),
            )

        model_prompt_ax.imshow(
            prompt_data.T,
            cmap=model_cmap,
            norm=norm,
        )
        model_icl_ax.imshow(
            icl_data.T,
            cmap=model_cmap,
            norm=norm,
        )

        model_cbar_ax = fig.add_subplot(gs[m, 2])
        cbar = fig.colorbar(
            matplotlib.cm.ScalarMappable(cmap=model_cmap, norm=norm),
            cax=model_cbar_ax,
            orientation="vertical",
            fraction=0.046,
            shrink=0.5,
            pad=0.04,
            ticks=colorbar_ticks,
        )
        cbar.set_label("Mean Indirect Effect", fontsize=fontsize + font_inc, fontfamily=fontfamily)

        for ax in (model_prompt_ax, model_icl_ax):
            ax.set_xlabel("Layer index", fontsize=fontsize + font_inc, fontfamily=fontfamily)
            if (not ylabel_first_ax_only) or m == 0:
                ax.set_ylabel(
                    "Top head index",
                    fontsize=fontsize + font_inc,
                    fontfamily=fontfamily,
                )
            else:
                ax.set_ylabel("")

            if limits_from_model_layers:
                model_layers, model_heads = MODEL_TO_N_LAYERS_HEADS[model]

                # For each axis, tick at multiples of max_value / (largest divider <= 10).
                for max_value, set_method in (
                    (model_layers, ax.set_xticks),
                    (model_heads, ax.set_yticks),
                ):
                    step = 1
                    for divider in range(10, 0, -1):
                        if max_value % divider == 0:
                            step = max_value // divider
                            break

                    set_method(np.arange(0, max_value + step, step))

    if save_name is not None:
        save_plot(save_name)
    plt.show()


# Models shown in the main-paper figures: each base model is followed by its instruct variant.
# NOTE(review): earlier cells already reference MAIN_PAPER_PLOT_MODELS; confirm it is also
# defined before them so the notebook survives Restart & Run All.
MAIN_PAPER_PLOT_MODELS = [
    "Llama-3.2-3B",
    "Llama-3.2-3B-Instruct",
    "Llama-3.1-8B",
    "Llama-3.1-8B-Instruct",
]


# Linear normalization with a resampled colormap instead of a zero-centered two-slope norm.
plot_mean_ies_by_model_and_type(
    mean_ie_by_model,
    MAIN_PAPER_PLOT_MODELS,
    ie_colormap.mpl_colormap,
    fontfamily="monospace",
    force_cmap_zero_middle=False,
    shrink_cmap=True,
    cmap_max=0.8,
)

# Some calculations for the localizer experiments


In [None]:
N_TOP_HEADS = 20
rows = []


def _mean_ie_over_heads(mean_ie, heads):
    """Average the entries of ``mean_ie`` selected by each index in ``heads``."""
    return np.mean([mean_ie[head] for head in heads])


for model in MAIN_PAPER_PLOT_MODELS:
    model_info = layer_key_dicts[f"{model} @ {N_TOP_HEADS}"]
    model_prompt_heads, model_icl_heads = model_info["prompt_heads"], model_info["icl_heads"]

    model_prompt_mean_ie = mean_ie_by_model[(model, BOTH)]
    model_icl_mean_ie = mean_ie_by_model[(model, ICL)]

    # Naming: mean_<head set>_head_<effect type>_ie.
    mean_prompt_head_prompt_ie = _mean_ie_over_heads(model_prompt_mean_ie, model_prompt_heads)
    mean_icl_head_icl_ie = _mean_ie_over_heads(model_icl_mean_ie, model_icl_heads)
    mean_prompt_head_icl_ie = _mean_ie_over_heads(model_icl_mean_ie, model_prompt_heads)
    mean_icl_head_prompt_ie = _mean_ie_over_heads(model_prompt_mean_ie, model_icl_heads)

    # Key order is the row order of the transposed table printed below.
    row = {"Model": model}
    row["Overall median demonstration CIE"] = model_icl_mean_ie.median()
    row["Demonstration heads / demonstration CIE"] = mean_icl_head_icl_ie
    row["Instruction heads / demonstration CIE"] = mean_prompt_head_icl_ie
    row["Localizer difference"] = mean_prompt_head_icl_ie - mean_icl_head_prompt_ie
    row["Demonstration heads / instruction CIE"] = mean_icl_head_prompt_ie
    row["Instruction heads / instruction CIE"] = mean_prompt_head_prompt_ie
    row["Overall median instruction CIE"] = model_prompt_mean_ie.median()
    rows.append(row)


# Switch to "github" to render the table inline as Markdown instead of printing LaTeX.
table_format = "latex_booktabs"

localizer_table_df = pd.DataFrame(rows).set_index("Model").T
output = tabulate.tabulate(localizer_table_df, headers="keys", tablefmt=table_format, floatfmt=".4e")

if table_format != "github":
    print(output)
else:
    display(Markdown(output))

In [None]:
from scipy.stats import rankdata

# NOTE: this cell relies on model_prompt_mean_ie, model_icl_mean_ie, model_prompt_heads
# and model_icl_heads as left behind by the loop in the previous cell, so the numbers
# below describe the last entry of MAIN_PAPER_PLOT_MODELS only.

# Descending ranks: rankdata assigns rank 1 to the smallest value, so subtracting from
# the total number of heads gives 0 to the head with the largest mean IE.
# np.prod is used because the np.product alias was removed in NumPy 2.0.
prompt_head_ranks = np.prod(model_prompt_mean_ie.shape) - rankdata(model_prompt_mean_ie).reshape(
    model_prompt_mean_ie.shape
)
icl_head_ranks = np.prod(model_icl_mean_ie.shape) - rankdata(model_icl_mean_ie).reshape(model_icl_mean_ie.shape)

# Partition the two top-head lists into shared heads and heads unique to one setting.
shared_heads = set(model_prompt_heads) & set(model_icl_heads)
prompt_only_heads = set(model_prompt_heads) - shared_heads
icl_only_heads = set(model_icl_heads) - shared_heads

# Cell output: (mean prompt-IE rank of the ICL-only heads, mean ICL-IE rank of the prompt-only heads).
np.mean([prompt_head_ranks[th] for th in icl_only_heads]), np.mean([icl_head_ranks[th] for th in prompt_only_heads])

# A table of the overall CIEs

In [None]:

# For each group of models, print a table of the ratio between the top-N ICL indirect
# effects and the top-N prompt indirect effects (mean and median) for N in 10, 20, 100.
for model_set in (MAIN_PAPER_PLOT_MODELS, APPENDIX_MODELS[:4], APPENDIX_MODELS[4:8]):
    rows = []

    for model in model_set:
        # The mean-IE maps do not depend on the head count, so look them up once per model.
        prompt_mean_ie = mean_ie_by_model[(model, BOTH)]
        icl_mean_ie = mean_ie_by_model[(model, ICL)]

        row = dict(model=model)
        # A lower-case loop variable keeps the N_TOP_HEADS constant set in an earlier
        # cell from being silently overwritten (it used to end up as 100 after this cell).
        for n_heads in (10, 20, 100):
            _, prompt_top_values = top_heads_from_indirect_effects_with_values(prompt_mean_ie, n_heads)
            _, icl_top_values = top_heads_from_indirect_effects_with_values(icl_mean_ie, n_heads)
            row[f"{n_heads}_mean_ratio"] = np.mean(icl_top_values) / np.mean(prompt_top_values)
            row[f"{n_heads}_median_ratio"] = np.median(icl_top_values) / np.median(prompt_top_values)

        rows.append(row)

    # Switch to "github" to render the table inline as Markdown instead of printing LaTeX.
    # table_format = "github"
    table_format = "latex_booktabs"

    # Transpose so that models become columns and the ratio metrics become rows.
    output = tabulate.tabulate(
        pd.DataFrame(rows).set_index("model").T, headers="keys", tablefmt=table_format, floatfmt=".3f"
    )
    if table_format == "github":
        display(Markdown(output))
    else:
        print()
        print(output)
        print()
    

In [None]:
# Display the appendix model list; the table cell above slices it as [:4] and [4:8].
APPENDIX_MODELS

## Can we generate some statistical tests for these?

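Let's start by comparing, within each model, the layers of the top prompt heads with the layers of the top ICL heads. Then we compare the base and instruct variants of each model (or every pair of variants for OLMo, which has more than two). All comparisons use two-sided Mann–Whitney U tests with a Bonferroni-corrected significance threshold.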


In [None]:
from scipy.stats import mannwhitneyu

# Bonferroni-corrected threshold: each model is tested once per top-head count (10 and 20).
alpha = 0.05 / (2 * len(RELEVANT_MODELS))

# Per model: two-sided Mann-Whitney U test between the layers of the top ICL heads and
# the layers of the top prompt heads. Significant p-values are flagged with " *".
test_rows = []
for model in RELEVANT_MODELS:
    row_dict = {"model": model}
    for n in (10, 20):
        model_results = layer_key_dicts[f"{model} @ {n}"]
        icl_layers = model_results["icl_layers"]
        prompt_layers = model_results["prompt_layers"]
        stat, p = mannwhitneyu(icl_layers, prompt_layers, alternative="two-sided")
        significance_marker = "" if p > alpha else " *"
        # Insertion order determines the column order of the rendered table.
        row_dict.update(
            {
                f"ICL Layers ({n})": icl_layers,
                f"Prompt Layers ({n})": prompt_layers,
                f"Test statistic ({n})": stat,
                f"p-value ({n})": f"{p:.4f}{significance_marker}",
            }
        )

    test_rows.append(row_dict)


icl_vs_prompt_table = tabulate.tabulate(test_rows, headers="keys", tablefmt="github")
display(Markdown(icl_vs_prompt_table))

In [None]:
from scipy.stats import mannwhitneyu

# NOTE(review): the Bonferroni denominator counts every model in RELEVANT_MODELS, although
# only non-OLMo base/instruct pairs are actually tested below -- confirm this is intended.
alpha = 0.05 / (2 * len(RELEVANT_MODELS))

layers_key = "prompt_layers"

# RELEVANT_MODELS is walked in steps of two, treating each consecutive pair as
# (base, instruct); pairs whose first model is an OLMo model are skipped.
test_rows = []
for base_idx in range(0, len(RELEVANT_MODELS), 2):
    base_model = RELEVANT_MODELS[base_idx]
    instruct_model = RELEVANT_MODELS[base_idx + 1]

    if "OLMo" in base_model:
        continue

    row_dict = {"Base Model": base_model}
    for n in (10, 20):
        base_model_layers = layer_key_dicts[f"{base_model} @ {n}"][layers_key]
        instruct_model_layers = layer_key_dicts[f"{instruct_model} @ {n}"][layers_key]
        stat, p = mannwhitneyu(base_model_layers, instruct_model_layers, alternative="two-sided")
        significance_marker = "" if p > alpha else " *"
        # Insertion order determines the column order of the rendered table.
        row_dict.update(
            {
                f"Base Layers ({n})": base_model_layers,
                f"Instruct Layers ({n})": instruct_model_layers,
                f"Test statistic ({n})": stat,
                f"p-value ({n})": f"{p:.4f}{significance_marker}",
            }
        )

    test_rows.append(row_dict)


base_vs_instruct_table = tabulate.tabulate(test_rows, headers="keys", tablefmt="github")
display(Markdown(base_vs_instruct_table))

In [None]:
from scipy.stats import mannwhitneyu

layers_key = "prompt_layers"
n_top_heads = 20

test_rows = []

olmo_models = [model for model in RELEVANT_MODELS if "OLMo" in model]

# Bonferroni correction over the number of unordered model pairs. With fewer than two
# OLMo models (e.g. when SKIP_OLMO is True) there are no pairs and no tests are run, so
# fall back to the uncorrected threshold instead of dividing by zero.
n_model_pairs = len(olmo_models) * (len(olmo_models) - 1) // 2
alpha = 0.05 / n_model_pairs if n_model_pairs > 0 else 0.05

# Square table: one row and one column per OLMo model, with an empty diagonal. Each
# off-diagonal cell holds a two-sided Mann-Whitney U test between the top-head layers of
# the row model and the column model; significant p-values are flagged with " *".
for first_model in olmo_models:
    row_dict = {"Model": first_model}
    for second_model in olmo_models:
        if first_model == second_model:
            row_dict[second_model] = ""

        else:
            first_model_layers = layer_key_dicts[f"{first_model} @ {n_top_heads}"][layers_key]
            second_model_layers = layer_key_dicts[f"{second_model} @ {n_top_heads}"][layers_key]
            stat, p = mannwhitneyu(
                first_model_layers,
                second_model_layers,
                alternative="two-sided",
            )
            row_dict[second_model] = f"U = {stat}, p = {p:.4f}{'' if p > alpha else ' *'}"

    test_rows.append(row_dict)


display(Markdown(tabulate.tabulate(test_rows, headers="keys", tablefmt="github")))

In [None]:
from matplotlib.ticker import MaxNLocator
from palettable.colorbrewer.qualitative import Paired_12 as cmap

# cmap = colorcet.glasbey_dark


# Default scatter styling for annotation markers (large gold star with a black edge).
# The plotting functions below merge caller-supplied annotate_kwargs on top of these.
# NOTE: "DEFUALT" is a misspelling of "DEFAULT"; the name is kept as-is because those
# functions reference it with this spelling.
DEFUALT_ANNOTATE_KWARGS = dict(
    marker="*",
    s=200,
    color="gold",
    edgecolor="black",
)


def plot_layer_rows_metric_bar_chart(
    layer_info_dicts: typing.Dict[str, typing.Dict[str, float]],
    n_top_heads: int,
    metric_names: str | typing.Sequence[str],
    models: typing.Sequence[str] | None = None,
    colors: typing.Sequence[str] = None,
    title: str | None = None,
    ylabel: str | None = None,
    figsize: typing.Tuple[float, float] = (8, 6),
    fontsize: int = 20,
    font_inc: int = 4,
    bar_width: float = 0.8,
    fake_value: float | None = None,
    fake_value_index: int = 0,
    fake_value_label: str | None = None,
    fake_value_color: str = "gray",
    metric_name_to_bar_kwargs: typing.Dict[str, typing.Dict[str, typing.Any]] = None,
    metric_name_to_err_metric: typing.Dict[str, str] = None,
    err_sem: bool = True,
    annotate_values: typing.Dict[str, typing.Dict[str, float]] | None = None,
    annotate_kwargs: typing.Dict[str, typing.Any] = None,
    ylim: typing.Tuple[float, float] | None = None,
    fontfamily: str | None = None,
):
    """Draw a grouped bar chart with one bar per model for each requested metric.

    Args:
        layer_info_dicts: Mapping from ``f"{model} @ {n_top_heads}"`` to a dict of metric values.
        n_top_heads: Head count used to build the lookup key (and the SEM denominator).
        metric_names: One metric name or a sequence of them; each gets its own bar per model.
        models: Models to plot; defaults to a copy of ``RELEVANT_MODELS``.
        colors: Optional bar colors, one per model. The caller's sequence is never modified.
        title: Optional axes title.
        ylabel: Y-axis label; defaults to a prettified version of the last metric name.
        figsize: Figure size passed to ``plt.subplots``.
        fontsize: Base font size; ``font_inc`` is added/subtracted for the other text elements.
        font_inc: Font size increment.
        bar_width: Total width of a model's bar group (split evenly between the metrics).
        fake_value: Optional reference value (or one value per metric) drawn as an extra bar.
        fake_value_index: Position at which the reference bar is inserted.
        fake_value_label: Legend label of the reference bar (default ``"Fake Value"``).
        fake_value_color: Color of the reference bar when ``colors`` is given.
        metric_name_to_bar_kwargs: Extra ``ax.bar`` keyword arguments per metric name.
        metric_name_to_err_metric: Per metric name, the key holding its error-bar values.
        err_sem: If true, error values are divided by ``sqrt(n_top_heads)``.
        annotate_values: Per metric name, a ``{model: value}`` dict drawn as scatter markers.
        annotate_kwargs: Overrides for ``DEFUALT_ANNOTATE_KWARGS``.
        ylim: Optional y-axis limits.
        fontfamily: Font family for the title, legend and y tick labels; defaults to the
            current ``font.family`` rcParam.

    Returns:
        None. The figure is shown with ``plt.show()``. If any requested metric is missing for
        any model, a warning is logged and nothing is drawn.

    Raises:
        ValueError: If ``colors`` is given and its length differs from the number of models.
    """
    if models is None:
        models = RELEVANT_MODELS[:]

    if colors is not None and len(colors) != len(models):
        raise ValueError(f"Length of colors ({len(colors)}) does not match number of models ({len(models)}).")

    if isinstance(metric_names, str):
        metric_names = [metric_names]

    if metric_name_to_bar_kwargs is None:
        metric_name_to_bar_kwargs = dict()

    if metric_name_to_err_metric is None:
        metric_name_to_err_metric = dict()

    if annotate_values is None:
        annotate_values = dict()

    if annotate_kwargs is None:
        annotate_kwargs = dict()

    if fontfamily is None:
        fontfamily = plt.rcParams["font.family"]

    annotate_kwargs = {**DEFUALT_ANNOTATE_KWARGS, **annotate_kwargs}

    all_metric_values = [
        [layer_info_dicts[f"{model} @ {n_top_heads}"].get(metric_name, np.nan) for model in models]
        for metric_name in metric_names
    ]

    # Work on private copies so inserting the reference bar never mutates the caller's
    # sequences (and so tuples are accepted as well as lists).
    labels = list(models)
    bar_colors = list(colors) if colors is not None else None

    if fake_value is not None:
        if isinstance(fake_value, (int, float)):
            fake_value = [fake_value] * len(metric_names)

        for metric_values, fv in zip(all_metric_values, fake_value):
            metric_values.insert(fake_value_index, fv)

        labels.insert(fake_value_index, fake_value_label if fake_value_label is not None else "Fake Value")
        if bar_colors is not None:
            bar_colors.insert(fake_value_index, fake_value_color)

    # Bail out before creating a figure so that missing data does not leave an empty
    # figure open.
    for metric_name, vals in zip(metric_names, all_metric_values):
        if np.any(np.isnan(vals)):
            logger.warning(f"Missing data for metrics {metric_name}.")
            return

    fig, ax = plt.subplots(figsize=figsize)

    bar_positions = np.arange(len(labels))
    # Each model's group shares the requested total width between the metrics.
    bar_width /= len(metric_names)

    for i, (metric_name, metric_values) in enumerate(zip(metric_names, all_metric_values)):
        x_positions = bar_positions + i * bar_width
        bar_kwargs = dict(width=bar_width, **metric_name_to_bar_kwargs.get(metric_name, {}))
        if bar_colors is not None:
            bar_kwargs["color"] = bar_colors
        # Only the first metric contributes legend entries, one per bar.
        ax.bar(x_positions, metric_values, **bar_kwargs, label=labels if i == 0 else None)

        annotation_values = annotate_values.get(metric_name, None)
        if annotation_values is not None:
            # NaN placeholders (rather than None) keep the array numeric; such points are not drawn.
            ann_vals = [annotation_values.get(model, np.nan) for model in models]
            if fake_value is not None:
                ann_vals.insert(fake_value_index, np.nan)

            ax.scatter(x_positions, ann_vals, **annotate_kwargs)

        err_metric = metric_name_to_err_metric.get(metric_name, None)
        if err_metric is not None:
            err_values = [layer_info_dicts[f"{model} @ {n_top_heads}"].get(err_metric, np.nan) for model in models]
            if fake_value is not None:
                # The reference bar has no uncertainty.
                err_values.insert(fake_value_index, 0)

            if err_sem:
                err_values = np.array(err_values) / np.sqrt(n_top_heads)
            ax.errorbar(x_positions, metric_values, yerr=err_values, fmt="none", capsize=5, color="gray")

    ax.set_xticks([])
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    # The default label is derived from the last metric name.
    default_ylabel = metric_names[-1].replace("_", " ").capitalize()
    ax.set_ylabel(default_ylabel if ylabel is None else ylabel, fontsize=fontsize + font_inc)
    if title is not None:
        ax.set_title(title, fontsize=fontsize + (2 * font_inc), fontfamily=fontfamily)
    # matplotlib ignores ``fontsize`` whenever ``prop`` is given, so the size goes inside ``prop``.
    ax.legend(prop=dict(family=fontfamily, size=fontsize - font_inc))
    ax.tick_params(axis="y", labelsize=fontsize, labelfontfamily=fontfamily)
    if ylim is not None:
        ax.set_ylim(ylim)
    plt.tight_layout()
    plt.show()


# ordered_color_indices = list(range(len(RELEVANT_MODELS)))
# cmap = colorcet.glasbey_dark
# plot_layer_rows_metric_bar_chart(
#     layer_key_dicts,
#     10,
#     "shared_heads",
#     colors=[cmap[i] for i in ordered_color_indices],
#     ylim=(0, 10),
#     fake_value=10,
#     fake_value_label="Maximum possible",
#     title="Prompt & ICL shared heads (top 10)",
# )

# plot_layer_rows_metric_bar_chart(
#     layer_key_dicts,
#     20,
#     "shared_heads",
#     colors=[cmap[i] for i in ordered_color_indices],
#     ylim=(0, 20),
#     fake_value=20,
#     fake_value_label="Maximum possible",
#     title="Prompt & ICL shared heads (top 20)",
# )


# Bar chart of the "shared_heads" metric at 20 top heads for the Llama 3.2-3B and
# 3.1-8B models, next to a reference bar marking the maximum possible value (20).
ordered_color_indices = [0, 1, 6, 7]
shared_heads_models = [
    model for model in RELEVANT_MODELS if any(tag in model for tag in ("3.2-3B", "3.1-8B"))
]
shared_heads_colors = [cmap.mpl_colors[color_idx] for color_idx in ordered_color_indices]

plot_layer_rows_metric_bar_chart(
    layer_key_dicts,
    20,
    "shared_heads",
    models=shared_heads_models,
    colors=shared_heads_colors,
    ylim=(0, 20),
    fake_value=20,
    fake_value_label="Maximum possible",
    fontfamily="monospace",
)

In [None]:
def repeat_elements(items, n):
    """Return a list in which every element of ``items`` appears ``n`` times in a row."""
    repeated = []
    for item in items:
        repeated.extend([item] * n)
    return repeated


def plot_layer_rows_individual_dots(
    layer_info_dicts: typing.Dict[str, typing.Dict[str, float]],
    n_top_heads: int,
    metric_names: str | typing.Sequence[str],
    models: typing.Sequence[str] | None = None,
    colors: typing.Sequence[str] | None = None,
    title: str | None = None,
    ylabel: str | None = None,
    ylabel_first_ax_only: bool = True,
    fontsize: int = 20,
    font_inc: int = 4,
    metric_labels: typing.Dict[str, str] = None,
    metric_name_to_plot_kwargs: typing.Dict[str, typing.Dict[str, typing.Any]] = None,
    metric_name_to_err_metric: typing.Dict[str, str] = None,
    annotate_values: typing.Dict[str, typing.Dict[str, float]] | None = None,
    annotate_kwargs: typing.Dict[str, typing.Any] = None,
    ylim: typing.Tuple[float, float] | None = None,
    ylim_from_model_layers: bool = False,
    **global_plot_kwargs,
):
    """Draw per-value swarm plots of one or more metrics, with one axes per base model.

    Models are grouped by base model (the name with "-Instruct" / "-chat" removed). Each
    axes shows the values of every metric for the models of one group, with a black line
    marking the mean of each swarm.

    Args:
        layer_info_dicts: Mapping from ``f"{model} @ {n_top_heads}"`` to a dict of metric
            values; a value may be a single number or a list of numbers.
        n_top_heads: Head count used to build the lookup key and the default axes title.
        metric_names: One metric name or a sequence of them; metrics are used as the hue.
        models: Models to plot; defaults to a copy of ``RELEVANT_MODELS``.
        colors: Optional colors, one per model. Axes ``b`` uses the slice
            ``colors[b * len(metric_names) : (b + 1) * len(metric_names)]`` as its palette.
            When omitted, seaborn's default palette is used.
        title: Optional title used for every axes; defaults to the base model and head count.
        ylabel: Y-axis label; defaults to a prettified version of the last metric name.
        ylabel_first_ax_only: If true, only the first axes gets a y-axis label.
        fontsize: Base font size; ``font_inc`` is added/subtracted for the other text elements.
        font_inc: Font size increment.
        metric_labels: Display labels for the metrics, as a dict keyed by metric name or a
            list/tuple aligned with ``metric_names``.
        metric_name_to_plot_kwargs: Accepted for interface compatibility; currently unused.
        metric_name_to_err_metric: Accepted for interface compatibility; currently unused.
        annotate_values: Accepted for interface compatibility; currently unused.
        annotate_kwargs: Accepted for interface compatibility; currently unused.
        ylim: Optional y-axis limits (ignored when ``ylim_from_model_layers`` is true).
        ylim_from_model_layers: If true, the y axis spans ``0..MODEL_TO_N_LAYERS[base_model]``.
        **global_plot_kwargs: Forwarded to ``sns.swarmplot``.

    Returns:
        None. The figure is shown with ``plt.show()``. If no values are found for the
        requested metrics, a warning is logged and nothing is drawn.

    Raises:
        ValueError: If ``colors`` or a list-like ``metric_labels`` has the wrong length.
    """
    if models is None:
        models = RELEVANT_MODELS[:]

    if colors is not None and len(colors) != len(models):
        raise ValueError(f"Length of colors ({len(colors)}) does not match number of models ({len(models)}).")

    if isinstance(metric_names, str):
        metric_names = [metric_names]

    if metric_name_to_plot_kwargs is None:
        metric_name_to_plot_kwargs = dict()

    if metric_name_to_err_metric is None:
        metric_name_to_err_metric = dict()

    if annotate_values is None:
        annotate_values = dict()

    if annotate_kwargs is None:
        annotate_kwargs = dict()

    if metric_labels is None:
        metric_labels = dict()

    elif isinstance(metric_labels, (list, tuple)):
        if len(metric_labels) != len(metric_names):
            raise ValueError(
                f"Length of metric_labels ({len(metric_labels)}) does not match number of metrics ({len(metric_names)})."
            )
        metric_labels = dict(zip(metric_names, metric_labels))

    annotate_kwargs = {**DEFUALT_ANNOTATE_KWARGS, **annotate_kwargs}

    # Long-format records: one row per (metric, model, value).
    rows = []
    for metric_name in metric_names:
        for model in models:
            base_model = model.replace("-Instruct", "").replace("-chat", "")
            is_instruct = ("-Instruct" in model) or ("-chat" in model)
            model_metric_vals = layer_info_dicts[f"{model} @ {n_top_heads}"].get(metric_name, np.nan)
            shared_fields = dict(
                model=model,
                base_model=base_model,
                is_instruct=is_instruct,
                metric_name=metric_name,
            )
            if isinstance(model_metric_vals, (int, float)):
                rows.append({**shared_fields, "value": model_metric_vals})
            elif isinstance(model_metric_vals, list):
                rows.extend({**shared_fields, "value": val, "index": i} for i, val in enumerate(model_metric_vals))

    if not rows:
        logger.warning(f"No values found for metrics {list(metric_names)}.")
        return

    all_values_df = pd.DataFrame(rows)
    all_values_df = all_values_df.assign(
        metric_name=all_values_df.metric_name.map(lambda x: metric_labels.get(x, x)),
    )

    # One axes per base model, instead of a fixed grid of five. The width per axes (7.2)
    # reproduces the former (36, 6) figure when there are exactly five base models.
    base_models = list(all_values_df.base_model.unique())
    fig, axes = plt.subplots(1, len(base_models), figsize=(7.2 * len(base_models), 6), squeeze=False)
    axes = axes.flatten()

    # The default y label is derived from the last metric name.
    default_ylabel = metric_names[-1].replace("_", " ").capitalize()

    for b, base_model in enumerate(base_models):
        base_model_df = all_values_df[all_values_df.base_model == base_model]
        ax = axes[b]

        # ``colors`` is optional: fall back to seaborn's default palette when it is omitted.
        base_model_colors = None
        if colors is not None:
            base_model_colors = colors[b * len(metric_names) : (b + 1) * len(metric_names)]

        sns.swarmplot(
            data=base_model_df,
            x="model",
            y="value",
            hue="metric_name",
            ax=ax,
            palette=base_model_colors,
            **global_plot_kwargs,
        )

        # A box plot with everything hidden except the mean line marks each swarm's mean.
        sns.boxplot(
            data=base_model_df,
            x="model",
            y="value",
            hue="metric_name",
            showmeans=True,
            meanline=True,
            meanprops={"color": "k", "ls": "-", "lw": 2},
            medianprops={"visible": False},
            whiskerprops={"visible": False},
            zorder=10,
            showfliers=False,
            showbox=False,
            showcaps=False,
            palette=base_model_colors,
            legend=None,
            ax=ax,
        )

        ax.set_xlabel("Model", fontsize=fontsize + font_inc)
        if (not ylabel_first_ax_only) or b == 0:
            ax.set_ylabel(default_ylabel if ylabel is None else ylabel, fontsize=fontsize + font_inc)

        ax_title = title
        if ax_title is None:
            ax_title = f"{base_model} ({n_top_heads} heads)"

        ax.set_title(ax_title, fontsize=fontsize + (2 * font_inc))

        ax.tick_params(axis="both", labelsize=fontsize)

        if ylim_from_model_layers:
            model_layers = MODEL_TO_N_LAYERS[base_model]
            ax.set_ylim(0, model_layers)
            # Use the largest divisor of the layer count that is at most 10 as the number
            # of tick intervals.
            y_step = 1
            for divider in range(10, 0, -1):
                if model_layers % divider == 0:
                    y_step = model_layers / divider
                    break

            ax.set_yticks(np.arange(0, model_layers + y_step, y_step))

        elif ylim is not None:
            ax.set_ylim(ylim)

        # Add a legend artist for the black line representing the mean
        mean_line = plt.Line2D([0], [0], color="k", linestyle="-", linewidth=2, label="Mean")
        loc = "lower right" if "Llama-3" in base_model else "best"
        ax.legend(handles=ax.get_legend_handles_labels()[0] + [mean_line], fontsize=fontsize - font_inc, loc=loc)

    plt.tight_layout()
    plt.show()


In [None]:
from palettable.colorbrewer.qualitative import Paired_12 as cmap


def flip_pairs(lst):
    """Swap the two elements of every consecutive pair: ``[a, b, c, d] -> [b, a, d, c]``.

    A trailing element without a partner (odd-length input) is dropped.
    """
    flipped = []
    for start in range(0, len(lst) - 1, 2):
        flipped.append(lst[start + 1])
        flipped.append(lst[start])
    return flipped


# Split the models under analysis by whether their name contains "OLMo".
OLMO_MODELS = [name for name in RELEVANT_MODELS if "OLMo" in name]
MODELS_NO_OLMO = [name for name in RELEVANT_MODELS if name not in OLMO_MODELS]

# One color per non-OLMo model, with the two colors of every consecutive pair swapped.
layer_dot_colors = flip_pairs([cmap.mpl_colors[color_idx] for color_idx in range(len(MODELS_NO_OLMO))])

# Swarm plots of the layers of the top ICL / prompt heads, for the top 10 and top 20 heads.
for n_top_heads in (10, 20):
    plot_layer_rows_individual_dots(
        layer_key_dicts,
        n_top_heads,
        ["icl_layers", "prompt_layers"],
        models=MODELS_NO_OLMO,
        colors=list(layer_dot_colors),
        ylabel="Top head layer",
        metric_labels=["ICL", "Prompt"],
        ylim_from_model_layers=True,
        dodge=True,
        size=10,
    )

# for n_top_heads in (10, 20):
#     plot_layer_rows_individual_dots(
#         layer_key_dicts,
#         n_top_heads,
#         # ['icl_layer_depths', 'prompt_layer_depths'],
#         ["icl_layers", "prompt_layers"],
#         models=[model for model in RELEVANT_MODELS if "OLMo" in model],
#         colors=flip_pairs([olmo_cmap.mpl_colors[i] for i in range(4)]),
#         ylabel="Top head layer",
#         # fake_value=10,
#         # fake_value_label='Maximum possible',
#         # title='Prompt & ICL mean head layer depths (top 10)',
#         metric_labels=["ICL", "Prompt"],
#         # ylim=(0, 1),
#         ylim_from_model_layers=True,
#         dodge=True,
#         size=10,
#     )

In [None]:
# Mean layer depth of the top ICL and prompt heads per non-OLMo model, once for the
# top 10 and once for the top 20 heads. Prompt bars are hatched; error bars come from the
# matching "*_std" metrics.
for depth_plot_n_heads in (10, 20):
    plot_layer_rows_metric_bar_chart(
        layer_key_dicts,
        depth_plot_n_heads,
        ["icl_layer_depth", "prompt_layer_depth"],
        models=MODELS_NO_OLMO,
        colors=flip_pairs([cmap.mpl_colors[color_idx] for color_idx in range(len(MODELS_NO_OLMO))]),
        ylabel="Mean layer depth",
        title=f"Prompt & ICL mean head layer depths (top {depth_plot_n_heads})",
        ylim=(0, 1),
        metric_name_to_bar_kwargs={"prompt_layer_depth": {"hatch": "/"}},
        metric_name_to_err_metric={
            "icl_layer_depth": "icl_layer_depth_std",
            "prompt_layer_depth": "prompt_layer_depth_std",
        },
    )

In [None]:
# Markdown summary, per consecutive (base, instruct) pair in RELEVANT_MODELS and per head
# count, of how many top prompt heads the two models share, plus the mean layer and the
# mean position in the top-head list of the shared heads and of the heads unique to each.
lines = []

for base_idx in range(0, len(RELEVANT_MODELS), 2):
    base_model = RELEVANT_MODELS[base_idx]
    instruct_model = RELEVANT_MODELS[base_idx + 1]

    for n_heads in (10, 20):
        n_heads_df = top_heads_summary_df[top_heads_summary_df.n_heads == n_heads]
        base_heads = n_heads_df[n_heads_df.model == base_model].prompt_heads.values[0]
        instruct_heads = n_heads_df[n_heads_df.model == instruct_model].prompt_heads.values[0]

        base_head_set, instruct_head_set = set(base_heads), set(instruct_heads)
        shared_heads = base_head_set & instruct_head_set
        base_only_heads = base_head_set - instruct_head_set
        instruct_only_heads = instruct_head_set - base_head_set
        lines.append(f" - {base_model} & {instruct_model} share {len(shared_heads)} / {n_heads} ")

        # The first entry of each head tuple is read as its layer.
        shared_heads_mean_layer = np.mean([head[0] for head in shared_heads])
        base_heads_mean_layer = np.mean([head[0] for head in base_only_heads])
        instruct_heads_mean_layer = np.mean([head[0] for head in instruct_only_heads])
        lines.append(
            f"   - {base_model} only mean layer: {base_heads_mean_layer:.2f} | shared heads mean layer: {shared_heads_mean_layer:.2f} | {instruct_model} only mean layer: {instruct_heads_mean_layer:.2f}"
        )

        # Position of each head inside its model's top-head list; shared heads contribute
        # their position in both lists.
        base_only_head_indices = [base_heads.index(head) for head in base_only_heads]
        instruct_only_head_indices = [instruct_heads.index(head) for head in instruct_only_heads]
        shared_head_indices = [
            heads.index(head) for heads in (base_heads, instruct_heads) for head in shared_heads
        ]
        lines.append(
            f"   - {base_model} only head mean index: {np.mean(base_only_head_indices):.2f} | shared heads mean index: {np.mean(shared_head_indices):.2f} | {instruct_model} only mean head index: {np.mean(instruct_only_head_indices):.2f}"
        )


shared_heads_summary = "\n".join(lines)
display(Markdown(shared_heads_summary))


# RQ2.5: What can I say about how many attention heads are necessary/helpful?

A couple of angles of attack here:

- Plot the distribution of indirect effects for each model
- Plot the overlap between similar settings in the prompt-based approach


In [None]:
def scatter_plot_top_effects(n_top_heads: int, negative=False, fontsize=20, print_every: int | None = None):
    """Scatter the top indirect effects of each base/instruct model pair on a 2x5 grid.

    The first row of axes shows the prompt-based setting, the second row the ICL setting.
    Consecutive entries of ``RELEVANT_MODELS`` are treated as (base, instruct) pairs, and
    each axes shows the effects returned by ``compute_top_heads`` for both models of a pair.

    Args:
        n_top_heads: Number of heads requested from ``compute_top_heads``.
        negative: Forwarded to ``compute_top_heads``; also flips which sign of effect
            triggers a warning (positive when true, negative otherwise).
        fontsize: Base font size for labels and legend.
        print_every: If given, print every ``print_every``-th effect for each model.
    """
    fig, axes = plt.subplots(2, 5, figsize=(36, 12))
    axes = axes.flatten()

    # Sign that is not expected among the returned effects.
    unexpected_sign = "positive" if negative else "negative"
    x_values = np.arange(n_top_heads)

    for offset_idx, (prompt_types, name) in enumerate(
        zip(
            ((SHORT, LONG), ICL),
            ("Prompt-based", "ICL"),
        )
    ):
        ax_offset = offset_idx * 5
        for i in range(len(RELEVANT_MODELS) // 2):
            model_idx = i * 2
            ax = axes[i + ax_offset]
            ax.set_title(RELEVANT_MODELS[model_idx], fontsize=fontsize + 4)
            ax.set_xlabel("Index", fontsize=fontsize)
            ax.set_ylabel(f"{name} indirect effect", fontsize=fontsize)

            # Same treatment for the base model and its instruct counterpart.
            for model, label, color in (
                (RELEVANT_MODELS[model_idx], "Base", "blue"),
                (RELEVANT_MODELS[model_idx + 1], "Instruct", "orange"),
            ):
                _, effects = compute_top_heads(
                    result_df,
                    indirect_effects_by_model_and_dataset,
                    model,
                    prompt_types,
                    n_top_heads=n_top_heads,
                    negative=negative,
                )
                effects = np.array(effects)
                ax.scatter(x_values, effects, label=label, color=color, alpha=0.5)

                wrong_sign = effects > 0 if negative else effects < 0
                if wrong_sign.any():
                    logger.warning(
                        f"Found {unexpected_sign} entries in {model} ({name}) starting at index {wrong_sign.argmax()}"
                    )
                if print_every is not None:
                    print(f"{model}: {effects[::print_every]}")

            ax.tick_params(axis="x", labelsize=fontsize - 4)
            ax.tick_params(axis="y", labelsize=fontsize - 4)
            leg = ax.legend(fontsize=fontsize)
            # Legend markers are drawn fully opaque even though the points are translucent.
            for lh in leg.legend_handles:
                lh.set_alpha(1)

    plt.tight_layout()
    plt.show()


# Top 200 effects first with the default (negative=False), then with negative=True.
# (Pass print_every=25 to also print a sample of the plotted values.)
for plot_negative_effects in (False, True):
    scatter_plot_top_effects(n_top_heads=200, negative=plot_negative_effects)

In [None]:
from palettable.cartocolors.qualitative import Bold_4 as cmap


def scatter_plot_top_effects_same_axes(
    n_top_heads: int,
    negative=False,
    flip_negative=False,
    fontsize=20,
    log=False,
    print_every: int | None = None,
    colors: typing.Sequence[str] = None,
):
    """Scatter prompt-based and ICL top indirect effects of each model pair on shared axes.

    Consecutive entries of ``RELEVANT_MODELS`` are treated as (base, instruct) pairs. Each
    of the five axes shows four series for one pair: base and instruct effects for the
    prompt-based setting, and base and instruct effects for the ICL setting.

    Args:
        n_top_heads: Number of heads requested from ``compute_top_heads``.
        negative: Forwarded to ``compute_top_heads``; also flips which sign of effect
            triggers a warning (positive when true, negative otherwise).
        flip_negative: If true together with ``negative``, the effects are negated before
            plotting and the sign warnings are suppressed.
        fontsize: Base font size for labels and legend.
        log: If true, use a logarithmic y axis.
        print_every: If given, print every ``print_every``-th effect for each model.
        colors: Optional colors indexed as ``2 * setting + variant`` (setting 0 = prompt,
            1 = ICL; variant 0 = base, 1 = instruct), so four entries are read.
    """
    fig, axes = plt.subplots(1, 5, figsize=(36, 6))
    axes = axes.flatten()

    # Sign that is not expected among the returned effects.
    unexpected_sign = "positive" if negative else "negative"
    x_values = np.arange(n_top_heads)

    for offset_idx, (prompt_types, name) in enumerate(
        zip(
            ((SHORT, LONG), ICL),
            ("Prompt-based", "ICL"),
        )
    ):
        setting_label = "prompt" if offset_idx == 0 else "ICL"

        for i in range(len(RELEVANT_MODELS) // 2):
            model_idx = i * 2
            ax = axes[i]
            ax.set_title(RELEVANT_MODELS[model_idx], fontsize=fontsize + 4)
            ax.set_xlabel("Index", fontsize=fontsize)
            # Both settings share these axes, so the label must not name just one of them
            # (it previously ended up as "ICL indirect effect").
            ax.set_ylabel("Indirect effect", fontsize=fontsize)

            # Same treatment for the base model and its instruct counterpart.
            for variant_idx, variant in enumerate(("Base", "Instruct")):
                model = RELEVANT_MODELS[model_idx + variant_idx]
                series_color = None
                if colors is not None:
                    series_color = colors[(2 * offset_idx) + variant_idx]

                _, effects = compute_top_heads(
                    result_df,
                    indirect_effects_by_model_and_dataset,
                    model,
                    prompt_types,
                    n_top_heads=n_top_heads,
                    negative=negative,
                )
                effects = np.array(effects)
                if negative and flip_negative:
                    effects = -effects

                ax.scatter(
                    x_values,
                    effects,
                    label=f"{variant} ({setting_label})",
                    color=series_color,
                    alpha=0.25,
                )

                wrong_sign = effects > 0 if negative else effects < 0
                if wrong_sign.any() and not flip_negative:
                    logger.warning(
                        f"Found {unexpected_sign} entries in {model} ({name}) starting at index {wrong_sign.argmax()}"
                    )
                if print_every is not None:
                    print(f"{model}: {effects[::print_every]}")

            ax.tick_params(axis="x", labelsize=fontsize - 4)
            ax.tick_params(axis="y", labelsize=fontsize - 4)
            leg = ax.legend(fontsize=fontsize)
            # Legend markers are drawn fully opaque even though the points are translucent.
            for lh in leg.legend_handles:
                lh.set_alpha(1)
            if log:
                ax.set_yscale("log")

    plt.tight_layout()
    plt.show()


ordered_color_indices = np.arange(4)
same_axes_colors = [cmap.mpl_colors[color_idx] for color_idx in ordered_color_indices]

# Three views of the top 200 effects: default sign on a log axis, negative=True on a
# linear axis, and negative=True with flipped sign on a log axis.
for same_axes_options in (
    dict(log=True),
    dict(negative=True),
    dict(negative=True, flip_negative=True, log=True),
):
    scatter_plot_top_effects_same_axes(
        n_top_heads=200,
        colors=list(same_axes_colors),
        **same_axes_options,
    )

In [None]:
def histogram_plot_all_effects(bins: int = 50, fontsize=20, **hist_kwargs):
    """Plot histograms of all mean indirect effects for each base/instruct model pair.

    The figure is a 2x5 grid: the first row shows the prompt-based setting, the second row
    the ICL setting. Consecutive entries of ``RELEVANT_MODELS`` are treated as
    (base, instruct) pairs and share one axes.

    Args:
        bins: Number of histogram bins.
        fontsize: Base font size for labels and legend.
        **hist_kwargs: Forwarded to ``ax.hist``.
    """
    fig, axes = plt.subplots(2, 5, figsize=(36, 12))
    axes = axes.flatten()

    settings = zip(((SHORT, LONG), ICL), ("Prompt-based", "ICL"))
    n_model_pairs = len(RELEVANT_MODELS) // 2

    for row_idx, (prompt_types, name) in enumerate(settings):
        for pair_idx in range(n_model_pairs):
            ax = axes[pair_idx + 5 * row_idx]
            base_model = RELEVANT_MODELS[2 * pair_idx]
            instruct_model = RELEVANT_MODELS[2 * pair_idx + 1]

            ax.set_title(base_model, fontsize=fontsize + 4)
            ax.set_xlabel(f"{name} indirect effect", fontsize=fontsize)
            ax.set_ylabel("Proportion", fontsize=fontsize)

            # Flattened mean effects, base model first and instruct model second.
            flat_effects = []
            for model in (base_model, instruct_model):
                _, _, mean_effects = compute_top_heads(
                    result_df,
                    indirect_effects_by_model_and_dataset,
                    model,
                    prompt_types,
                    return_mean=True,
                )
                flat_effects.append(mean_effects.flatten().numpy())

            ax.hist(
                flat_effects,
                bins=bins,
                label=["Base", "Instruct"],
                color=["blue", "orange"],
                alpha=0.5,
                **hist_kwargs,
            )
            for axis_name in ("x", "y"):
                ax.tick_params(axis=axis_name, labelsize=fontsize - 4)
            ax.legend(fontsize=fontsize)

    plt.tight_layout()
    plt.show()


# density=True and log=True are forwarded unchanged to ax.hist via **hist_kwargs.
histogram_plot_all_effects(density=True, log=True)

In [None]:
def box_plot_all_effects(fontsize=20, **boxplot_kwargs):
    """Draw per-model strip plots of mean indirect effects, one panel per prompt family.

    The left panel uses the prompt types ``(SHORT, LONG)`` ("Prompt-based") and
    the right panel uses ``ICL``. Within a panel there is one strip per entry of
    ``RELEVANT_MODELS``, labelled with the model name. Despite the function
    name, the plot is produced with ``sns.stripplot``.

    Args:
        fontsize: Base font size; panel titles use ``fontsize + 4``.
        **boxplot_kwargs: Extra keyword arguments forwarded to ``sns.stripplot``.
    """

    def model_name_formatter(x):
        # seaborn sees the list positions as category values; map them to names.
        return RELEVANT_MODELS[x]

    _, axes = plt.subplots(1, 2, figsize=(12, 12))
    panels = zip(axes.flatten(), ((SHORT, LONG), ICL), ("Prompt-based", "ICL"))

    for ax, prompt_types, name in panels:
        ax.set_title(f"{name} indirect effects", fontsize=fontsize + 4)
        ax.set_xlabel("Indirect effect", fontsize=fontsize)
        ax.set_ylabel("Model", fontsize=fontsize)

        # Third return value (with return_mean=True) is the mean-effects object.
        per_model_effects = []
        for model in RELEVANT_MODELS:
            _, _, model_mean_effects = compute_top_heads(
                result_df,
                indirect_effects_by_model_and_dataset,
                model,
                prompt_types,
                return_mean=True,
            )
            per_model_effects.append(model_mean_effects)

        # Report each model's effect shape before flattening discards it.
        print([(model, effects.shape) for model, effects in zip(RELEVANT_MODELS, per_model_effects)])

        flat_effects = [effects.flatten().numpy() for effects in per_model_effects]

        sns.stripplot(flat_effects, ax=ax, formatter=model_name_formatter, **boxplot_kwargs)

    plt.tight_layout()
    plt.show()


# `orient="h"` and `size=3` are forwarded to `sns.stripplot` via `**boxplot_kwargs`.
box_plot_all_effects(orient="h", size=3)

In [None]:
def moving_average(data_set, periods=3):
    """Return the equally weighted moving average of a 1-D sequence.

    Args:
        data_set: 1-D sequence (list or array) of numbers.
        periods: Window length; every sample in the window has weight
            ``1 / periods``.

    Returns:
        NumPy array from a "valid"-mode convolution: only positions where the
        window lies fully inside the input are kept, so for
        ``len(data_set) >= periods`` the result has
        ``len(data_set) - periods + 1`` entries.
    """
    uniform_kernel = np.ones(periods)
    uniform_kernel /= periods
    return np.convolve(data_set, uniform_kernel, mode="valid")


def _shared_top_head_fractions(first_top_heads, second_top_heads, min_n_top_heads, max_n_top_heads):
    """Return, for each n in [min_n_top_heads, max_n_top_heads], the fraction of top-n heads present in both rankings."""
    first_top_head_set = {tuple(head) for head in first_top_heads[:min_n_top_heads]}
    second_top_head_set = {tuple(head) for head in second_top_heads[:min_n_top_heads]}
    shared_fractions = []

    for n in range(min_n_top_heads, max_n_top_heads + 1):
        # Grow both sets incrementally to the top-n heads. On the first pass
        # entry n - 1 is already in the seed slice, so the add is a no-op.
        first_top_head_set.add(tuple(first_top_heads[n - 1]))
        second_top_head_set.add(tuple(second_top_heads[n - 1]))
        shared_fractions.append(len(first_top_head_set & second_top_head_set) / n)

    return shared_fractions


def scatter_plot_top_heads_overlap(
    comparison_groups: typing.Dict[str, typing.Dict[str, typing.Any]],
    min_n_top_heads: int = 10,
    max_n_top_heads: int | float = 100,
    moving_average_periods: int | None = None,
    fontsize=20,
    title=None,
):
    """Plot how much the top-n head sets of different configurations overlap, per model.

    One subplot is drawn per entry of ``RELEVANT_SCATTER_ORDERED_MODELS`` on a
    2x5 grid. For every pair of groups in ``comparison_groups`` a line shows
    the fraction of heads shared between the two groups' top-n heads, for n
    from ``min_n_top_heads`` up to the maximum.

    Args:
        comparison_groups: Maps a legend label to the keyword arguments passed
            to ``compute_top_heads`` for that group (in addition to the model
            and ``n_top_heads``).
        min_n_top_heads: Smallest n for which the overlap is computed.
        max_n_top_heads: Largest n. A ``float`` is interpreted as a fraction of
            the model's total head count (layers * heads taken from
            ``MODEL_TO_N_LAYERS_HEADS``); an ``int`` is used as-is.
        moving_average_periods: If given, smooth each curve with
            ``moving_average`` over this many points.
        fontsize: Base font size; subplot titles use ``fontsize + 4``, tick and
            legend labels ``fontsize - 4``, the figure title ``fontsize + 8``.
        title: Optional figure title; the smoothing window is appended to it
            when ``moving_average_periods`` is set.
    """
    fig, axes = plt.subplots(2, 5, figsize=(36, 12))
    axes = axes.flatten()

    for i, model in enumerate(RELEVANT_SCATTER_ORDERED_MODELS):
        ax = axes[i]
        ax.set_title(model, fontsize=fontsize + 4)
        ax.set_xlabel("Index", fontsize=fontsize)
        ax.set_ylabel("% top heads shared", fontsize=fontsize)
        ax.tick_params(axis="x", labelsize=fontsize - 4)
        ax.tick_params(axis="y", labelsize=fontsize - 4)

        model_max_n_top_heads = max_n_top_heads
        if isinstance(model_max_n_top_heads, float):
            layers, heads = MODEL_TO_N_LAYERS_HEADS[model]
            total_heads = layers * heads
            model_max_n_top_heads = int(total_heads * model_max_n_top_heads)

        # Rank the heads once per group; every pair below reuses these lists
        # instead of recomputing them for each combination.
        top_heads_by_group = {}
        for group_key, group_kwargs in comparison_groups.items():
            top_heads_by_group[group_key], _ = compute_top_heads(
                result_df,
                indirect_effects_by_model_and_dataset,
                model,
                **group_kwargs,
                n_top_heads=model_max_n_top_heads,
            )

        all_x_values = np.arange(min_n_top_heads, model_max_n_top_heads + 1)

        for first_key, second_key in itertools.combinations(comparison_groups.keys(), 2):
            x_values = all_x_values
            shared_fractions = _shared_top_head_fractions(
                top_heads_by_group[first_key],
                top_heads_by_group[second_key],
                min_n_top_heads,
                model_max_n_top_heads,
            )

            if moving_average_periods is not None:
                shared_fractions = moving_average(np.array(shared_fractions), moving_average_periods)
                # The "valid" convolution shortens the series. Slice x by the
                # smoothed length so x and y always match: `[missing:-missing]`
                # is empty for a window of 1 and one element short for even
                # windows.
                missing = moving_average_periods // 2
                x_values = x_values[missing : missing + len(shared_fractions)]

            ax.plot(x_values, shared_fractions, label=f"{first_key} vs {second_key}", alpha=0.5)

        # A legend is only needed when more than one line shares the axes.
        if len(comparison_groups) > 2:
            ax.legend(fontsize=fontsize - 4)

    # Hide grid cells that have no model assigned to them.
    for unused_ax in axes[len(RELEVANT_SCATTER_ORDERED_MODELS) :]:
        unused_ax.set_visible(False)

    if title is not None:
        if moving_average_periods is not None:
            title = f"{title} ({moving_average_periods}-step moving average)"
        fig.suptitle(title, fontsize=fontsize + 8)
    plt.tight_layout()
    plt.show()


# Overlap between the top heads found with SHORT prompts only and with LONG
# prompts only.
scatter_plot_top_heads_overlap(
    comparison_groups={
        SHORT: dict(prompt_types=[SHORT]),
        LONG: dict(prompt_types=[LONG]),
    },
    # A float is read as a fraction of each model's total head count.
    max_n_top_heads=0.25,
    moving_average_periods=5,
    title="Short vs. long prompts top head overlap",
)

In [None]:
# One group per entry of BASELINES, each using both SHORT and LONG prompts;
# the function plots one curve for every pair of groups.
scatter_plot_top_heads_overlap(
    comparison_groups={b: dict(prompt_types=[SHORT, LONG], baselines=[b]) for b in BASELINES},
    # A float is read as a fraction of each model's total head count.
    max_n_top_heads=0.25,
    moving_average_periods=5,
    title="Top head overlap by baseline",
)

In [None]:
# Overlap between the top heads found with prompts (SHORT and LONG together)
# and those found with ICL.
scatter_plot_top_heads_overlap(
    comparison_groups={
        "Prompts": dict(prompt_types=[SHORT, LONG]),
        ICL: dict(prompt_types=[ICL]),
    },
    # A float is read as a fraction of each model's total head count.
    max_n_top_heads=0.25,
    moving_average_periods=5,
    title="Prompts vs. ICL top head overlap",
)