# Captain Analysis

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json
import glob
import re

import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

from battleship.run_captain_benchmarks import rebuild_captain_summary_from_results
from battleship.utils import resolve_project_path
from battleship.agents import EIGCalculator, CodeQuestion, Question
from battleship.game import Board

from analysis import CAPTAIN_TYPE_LABELS, MODEL_DISPLAY_NAMES
from analysis import human_round_summaries

In [None]:
%config InlineBackend.figure_format = 'retina'

# Global seaborn appearance for every figure in this notebook.
# Default color cycle (individual plots may still pass their own palette).
sns.set_palette("tab10")

# White background with grid lines; "talk" context selects seaborn's larger
# font/line scaling preset.
sns.set_style("whitegrid")
sns.set_context("talk")

In [None]:
# Name of the human experiment; also the sub-directory of "data" that is read.
HUMAN_EXPERIMENT_NAME = "battleship-final-data"
# Passed to human_round_summaries() as the experiment path.
PATH_DATA = os.path.join("data", HUMAN_EXPERIMENT_NAME)
# Directory that the figure-export cells write their PDFs into.
PATH_EXPORT = os.path.join(PATH_DATA, "export")

# Prefix that the individual model run directories are joined onto before
# being handed to resolve_project_path().
CAPTAIN_EXPERIMENT_PATH = (
    "experiments/collaborative/captain_benchmarks/"
)

## Data loading

### Human data

In [None]:
# Load the per-round human summaries and tag every row with llm="Human" so the
# frame can later be stacked with the model results.
human_df = pd.DataFrame(
    human_round_summaries(experiment_path=PATH_DATA)
).assign(llm="Human")
human_df

### Model data

In [None]:
# (model name, run directory) pairs. Run directories are relative to
# CAPTAIN_EXPERIMENT_PATH.
model_round_data_unresolved_paths = [
    ("gpt-4o", "run_2025_08_25_16_28_19"),
    ("gpt-5", "run_2025_08_25_22_02_29"),
    ("llama-4-scout", "run_2025_08_26_17_56_46"),
    ("Baseline", "run_2025_08_26_17_23_23"),
]

# Pass every run directory through resolve_project_path().
model_round_data_paths = [
    (name, resolve_project_path(os.path.join(CAPTAIN_EXPERIMENT_PATH, path)))
    for name, path in model_round_data_unresolved_paths
]

# Report all missing run directories up front, before any loading starts.
for _model_name, _run_path in model_round_data_paths:
    if not os.path.exists(_run_path):
        print(f"The path {_run_path} does not exist.")

# One summary frame per run, tagged with the model name.
# A comprehension is used on purpose: the previous loop bound its temporary to
# the name `df`, which is also the notebook-wide analysis frame, so re-running
# this cell after the merge cell silently replaced `df` with the last run.
dfs = [
    pd.DataFrame(rebuild_captain_summary_from_results(run_path)).assign(llm=model_name)
    for model_name, run_path in model_round_data_paths
]

model_df = pd.concat(dfs, ignore_index=True)
model_df

In [None]:
# Stack the human rounds and the model rounds into one analysis frame.
df = pd.concat([human_df, model_df], ignore_index=True)

primary_columns = ["captain_type_display", "llm_display", "board_id", "seed"]

# Ordered categorical of captain display labels; category order follows the
# (de-duplicated) value order of CAPTAIN_TYPE_LABELS.
captain_label_order = list(dict.fromkeys(CAPTAIN_TYPE_LABELS.values()))
df["captain_type_display"] = pd.Categorical(
    df["captain_type"].map(CAPTAIN_TYPE_LABELS),
    categories=captain_label_order,
    ordered=True,
)

# Ordered categorical of LLM names: Human and Baseline first, then every
# MODEL_DISPLAY_NAMES value that actually occurs in the data.
observed_llms = df["llm"].unique()
llm_display_order = ["Human", "Baseline"] + [
    x for x in MODEL_DISPLAY_NAMES.values() if x in observed_llms
]
df["llm_display"] = pd.Categorical(
    df["llm"],
    categories=llm_display_order,
    ordered=True,
)

# Primary columns first, then sort rows by them and renumber the index.
remaining_columns = [col for col in df.columns if col not in primary_columns]
df = (
    df[primary_columns + remaining_columns]
    .sort_values(by=primary_columns, ascending=True)
    .reset_index(drop=True)
)

df

## Precision/Recall Stats

In [None]:
# List which LLMs contributed rounds to each captain type.
print("\nBreakdown by captain_type_display:")
for captain_type in df['captain_type_display'].cat.categories:
    in_category = df['captain_type_display'] == captain_type
    llms = df.loc[in_category, 'llm'].unique()
    print(f"{captain_type}: {llms}")


# Fixed color per LLM, drawn from the colorblind-friendly Okabe–Ito palette.
llm_palette = {
    "Human": "#009E73",          # green
    "Baseline": "#0072B2",       # blue
    "llama-4-scout": "#CC79A7",  # purple
    "gpt-4o": "#E69F00",         # orange (deliberately close to gpt-5)
    "gpt-5": "#D55E00",          # vermillion
}

In [None]:
# F1 by captain type, one box per LLM (seaborn's default dodged layout).
fig, ax = plt.subplots(figsize=(6, 4))

f1_boxplot_kwargs = dict(
    x="captain_type_display",
    y="f1_score",
    hue="llm",
    palette=llm_palette,
)
sns.boxplot(data=df, ax=ax, **f1_boxplot_kwargs)
sns.despine()

ax.set_xlabel("Captain Type")
ax.set_ylabel("Firing Accuracy (F1)")

plt.xticks(rotation=90)

# Legend goes outside the axes, anchored to the top-right corner.
ax.legend(title="", loc="upper left", bbox_to_anchor=(1, 1))

plt.show()

In [None]:
# Hand-placed F1 boxplot: for every captain type only the LLMs that actually
# have data get a box, and the boxes are centered on the captain's x position.
fig, ax = plt.subplots(figsize=(6, 4))

# Ordered captain categories that actually appear in the data
captain_categories = [
    c for c in df["captain_type_display"].cat.categories
    if c in df["captain_type_display"].values
]

# Largest number of LLMs present for any captain; used so every box has the
# same width. observed=True ignores unused categories, and the max(..., 1)
# guard keeps box_width finite even when there is nothing to plot.
llm_counts = df.groupby("captain_type_display", observed=True)["llm"].nunique()
max_llms = max(int(llm_counts.max()), 1) if len(llm_counts) > 0 else 1

# Base positions for each captain on the x axis
x_positions = np.arange(len(captain_categories))

# Box width: leave some padding between captain groups
group_width = 0.5  # total width occupied by boxes for one captain
box_width = group_width / max_llms

# Grid lines go below the artists; only horizontal grid lines are shown
ax.set_axisbelow(True)
ax.xaxis.grid(False)
ax.yaxis.grid(True)

# Use the llm_display categorical ordering so order is consistent across plots
llm_order = list(df["llm_display"].cat.categories) if "llm_display" in df.columns else sorted(df["llm"].unique())

# Color used for any LLM that has no entry in llm_palette
fallback_color = "#808080"

for i, captain in enumerate(captain_categories):
    captain_rows = df[df["captain_type_display"] == captain]
    present_llms_unsorted = captain_rows["llm"].unique()
    # Preserve the display order
    present_llms = [llm for llm in llm_order if llm in present_llms_unsorted]

    m = len(present_llms)
    if m == 0:
        continue

    # Offsets to center m boxes around the captain x position
    offsets = (np.arange(m) - (m - 1) / 2.0) * box_width

    for j, llm in enumerate(present_llms):
        subset = captain_rows.loc[captain_rows["llm"] == llm, "f1_score"].dropna()
        if subset.empty:
            continue

        pos = x_positions[i] + offsets[j]
        color = llm_palette.get(llm, fallback_color)

        # matplotlib's boxplot lets each box sit at an explicit numeric position
        bp = ax.boxplot(subset.values,
                        positions=[pos],
                        widths=box_width * 0.9,
                        patch_artist=True,
                        manage_ticks=False)

        # Style the box elements
        for element in ["boxes", "whiskers", "caps", "medians"]:
            plt.setp(bp[element], color=color)
        for patch in bp["boxes"]:
            patch.set(facecolor=color, alpha=0.6)

        # De-emphasize outliers: small, translucent, same color as the box
        if "fliers" in bp:
            for f in bp["fliers"]:
                f.set(marker='o', markersize=3, markerfacecolor=color, markeredgecolor=color, alpha=0.35, markeredgewidth=0)

# Legend handles for the LLMs present in the full DataFrame, in display order.
# The palette lookup uses the same fallback as the boxes, so an LLM missing
# from llm_palette no longer raises KeyError here.
from matplotlib.patches import Patch
llms_in_data = set(df["llm"].unique())
all_present_llms = [llm for llm in llm_order if llm in llms_in_data]
legend_handles = [
    Patch(facecolor=llm_palette.get(k, fallback_color), label=k, alpha=0.6)
    for k in all_present_llms
]

ax.legend(handles=legend_handles, loc="upper left", bbox_to_anchor=(1, 1), title="")

# Final formatting
ax.set_xticks(x_positions)
ax.set_xticklabels(captain_categories, rotation=90)
ax.set_xlabel("Captain Type")
ax.set_ylabel("Firing Accuracy (F1)")
ax.set_xlim(-0.5, len(captain_categories) - 0.5)

sns.despine()
# plt.tight_layout()
# plt.show()

# Create the export directory first (the other export cells do the same);
# savefig fails if it does not exist.
os.makedirs(PATH_EXPORT, exist_ok=True)
plt.savefig(os.path.join(PATH_EXPORT, "captain_f1_boxplot.pdf"), dpi=300, bbox_inches="tight")

In [None]:
# Every round as a jittered point: F1 by captain type, colored by LLM.
fig, ax = plt.subplots(figsize=(6, 4))

f1_strip_kwargs = dict(
    x="captain_type_display",
    y="f1_score",
    hue="llm",
    palette=llm_palette,
    alpha=0.7,
)
sns.stripplot(data=df, ax=ax, **f1_strip_kwargs)
sns.despine()

ax.set_xlabel("Captain Type")
ax.set_ylabel("Firing Accuracy (F1)")

plt.xticks(rotation=90)

# Legend goes outside the axes, anchored to the top-right corner.
ax.legend(title="", loc="upper left", bbox_to_anchor=(1, 1))

plt.show()

In [None]:
# Same data as the strip plot, laid out as a non-overlapping swarm.
fig, ax = plt.subplots(figsize=(8, 6))

f1_swarm_kwargs = dict(
    x="captain_type_display",
    y="f1_score",
    hue="llm",
    palette=llm_palette,
    # alpha=0.7,
)
sns.swarmplot(data=df, ax=ax, **f1_swarm_kwargs)
sns.despine()

ax.set_xlabel("Captain Type")
ax.set_ylabel("Firing Accuracy (F1)")

plt.xticks(rotation=90)

# Legend goes outside the axes, anchored to the top-right corner.
ax.legend(title="", loc="upper left", bbox_to_anchor=(1, 1))

plt.show()


In [None]:
# Empirical CDF of the F1 score, one curve per captain type (all LLMs pooled).
sns.displot(
    data=df,
    kind="ecdf",
    x="f1_score",
    hue="captain_type_display",
)

In [None]:
# Moves per round = hits + misses; this column is reused by the win-rate cells.
df["move_count"] = df["hits"].add(df["misses"])

# Bar height is seaborn's default aggregate of move_count per raw captain_type.
sns.barplot(x="captain_type", y="move_count", hue="captain_type", data=df)
plt.xticks(rotation=45)

## Win rates

In [None]:
# Rounds available per (captain type, LLM) pair; pairs without data are dropped.
counts = df.groupby(["captain_type_display", "llm_display"])["board_id"].count()
counts_nonzero = counts.loc[counts.gt(0)].reset_index(name="count")
counts_nonzero

In [None]:
df

In [None]:
# --- Win Rate Computation Utilities ---
from itertools import product, combinations
from functools import lru_cache


# Each (llm_display, captain_type_display) pair is one competitor, labelled
# "<llm> | <captain type>" with the LLM first.
llm_as_text = df["llm_display"].astype(str)
captain_as_text = df["captain_type_display"].astype(str)
df["competitor"] = llm_as_text.str.cat(captain_as_text, sep=" | ")

# Category order: LLMs first (primary), captain types second, each following
# its existing categorical order (or sorted order if the column is not categorical).
if isinstance(df["llm_display"].dtype, pd.CategoricalDtype):
    llm_categories = list(df["llm_display"].cat.categories)
else:
    llm_categories = sorted(df["llm_display"].dropna().unique())

if isinstance(df["captain_type_display"].dtype, pd.CategoricalDtype):
    cap_categories = list(df["captain_type_display"].cat.categories)
else:
    cap_categories = sorted(df["captain_type_display"].dropna().unique())

# Keep only the pairs that have at least one row.
competitor_order = [
    f"{llm} | {cap}"
    for llm in llm_categories
    for cap in cap_categories
    if ((df["llm_display"] == llm) & (df["captain_type_display"] == cap)).any()
]

df["competitor"] = pd.Categorical(df["competitor"], categories=competitor_order, ordered=True)


def compute_board_level_win_rate(a_vals, b_vals, higher_is_better=True):
    """Score every (a, b) pairing of two collections of metric values.

    A pairing counts as one win for A when A's value is strictly better
    (larger if ``higher_is_better``, smaller otherwise) and as half a win when
    the two values are equal. Pairings in which either value is missing are
    ignored.

    Returns:
        ``(win_rate, wins, comparisons)``; ``win_rate`` is NaN when no pairing
        could be compared.
    """
    score = 0.0
    n_pairs = 0
    for left, right in product(a_vals, b_vals):
        if pd.isna(left) or pd.isna(right):
            continue
        n_pairs += 1
        if left == right:
            score += 0.5
        elif (left > right) if higher_is_better else (left < right):
            score += 1
    if n_pairs == 0:
        return np.nan, score, n_pairs
    return score / n_pairs, score, n_pairs


def compute_pairwise_win_rates(
    df,
    metric="f1_score",
    higher_is_better=True,
    captain_col="captain_type_display",
    board_col="board_id",
    round_col="round_id",
):
    """Compute pairwise win rates between all competitors (``captain_col`` values).

    For each pair of competitors (A, B) and each board, every non-missing
    ``metric`` value of A on that board is compared with every non-missing
    value of B on the same board via ``compute_board_level_win_rate``
    (ties count as 0.5 wins).

    Two aggregate views are produced per pair:
      * mean_board_win_rate: unweighted mean of the per-board win rates over
        boards where both competitors have at least one value.
      * weighted_all_pairs_win_rate: total wins / total comparisons, pooled
        over those boards.

    Args:
        df: DataFrame with at least ``metric``, ``captain_col`` and ``board_col``.
        metric: Column holding the value to compare.
        higher_is_better: Whether a larger metric value wins.
        captain_col: Column identifying the competitor.
        board_col: Column identifying the board.
        round_col: Accepted for interface compatibility; not read.

    Returns:
        dict with keys
          * "detailed": one row per (pair, board) plus one aggregate row per
            pair (the aggregate rows have ``board_id`` missing),
          * "aggregate": the aggregate rows only,
          * "mean_board_win_rate_matrix" / "weighted_win_rate_matrix": square
            DataFrames where entry [row, col] is the rate at which row beats
            col. The diagonal is 0.5 for every competitor; pairs that never
            share a board stay NaN.

    Raises:
        ValueError: if ``metric`` is not a column of ``df``.
    """
    if metric not in df.columns:
        raise ValueError(f"Metric '{metric}' not found in DataFrame columns.")

    competitors = list(df[captain_col].dropna().unique())
    boards = sorted(df[board_col].dropna().unique())

    # Non-missing metric values per (competitor, board), collected once so the
    # pair loop below is a plain dict lookup. observed=True skips unused
    # categorical combinations.
    metric_values = {
        key: group[metric].dropna().to_numpy()
        for key, group in df.groupby([captain_col, board_col], observed=True)
    }
    no_values = np.array([])

    records = []
    for ca, cb in combinations(competitors, 2):
        board_results = []
        total_wins = 0.0
        total_comparisons = 0
        for board in boards:
            a_vals = metric_values.get((ca, board), no_values)
            b_vals = metric_values.get((cb, board), no_values)
            if len(a_vals) == 0 or len(b_vals) == 0:
                continue

            board_win_rate, wins, comps = compute_board_level_win_rate(
                a_vals, b_vals, higher_is_better=higher_is_better
            )
            if np.isnan(board_win_rate):
                continue
            board_results.append((board, board_win_rate, wins, comps))
            total_wins += wins
            total_comparisons += comps

        boards_considered = len(board_results)
        if board_results:
            mean_board_win_rate = np.nanmean([br for _, br, _, _ in board_results])
            weighted_all_pairs_win_rate = (
                total_wins / total_comparisons if total_comparisons > 0 else np.nan
            )
        else:
            mean_board_win_rate = np.nan
            weighted_all_pairs_win_rate = np.nan

        # Per-board rows for this pair ...
        for board, br, wins, comps in board_results:
            records.append(
                {
                    "competitor_a": ca,
                    "competitor_b": cb,
                    "metric": metric,
                    "higher_is_better": higher_is_better,
                    "board_id": board,
                    "board_win_rate": br,
                    "board_wins": wins,
                    "board_comparisons": comps,
                }
            )

        # ... followed by the aggregate row, marked by board_id=None.
        records.append(
            {
                "competitor_a": ca,
                "competitor_b": cb,
                "metric": metric,
                "higher_is_better": higher_is_better,
                "board_id": None,
                "board_win_rate": mean_board_win_rate,
                "board_wins": total_wins,
                "board_comparisons": total_comparisons,
                "weighted_all_pairs_win_rate": weighted_all_pairs_win_rate,
                "boards_considered": boards_considered,
            }
        )

    # Explicit columns keep the frame well-formed when there are fewer than two
    # competitors (no records at all).
    detailed_df = pd.DataFrame(
        records,
        columns=[
            "competitor_a",
            "competitor_b",
            "metric",
            "higher_is_better",
            "board_id",
            "board_win_rate",
            "board_wins",
            "board_comparisons",
            "weighted_all_pairs_win_rate",
            "boards_considered",
        ],
    )
    aggregate_df = detailed_df[detailed_df["board_id"].isna()].copy()
    aggregate_df = aggregate_df.assign(
        mean_board_win_rate=aggregate_df["board_win_rate"],
        weighted_all_pairs_win_rate=aggregate_df["weighted_all_pairs_win_rate"].fillna(
            aggregate_df["board_wins"] / aggregate_df["board_comparisons"].replace({0: np.nan})
        ),
    )

    # Preserve the categorical order if there is one
    if isinstance(df[captain_col].dtype, pd.CategoricalDtype):
        present = set(competitors)
        competitors_sorted = [c for c in df[captain_col].cat.categories if c in present]
    else:
        competitors_sorted = sorted(competitors)

    mean_matrix = pd.DataFrame(np.nan, index=competitors_sorted, columns=competitors_sorted)
    weighted_matrix = pd.DataFrame(np.nan, index=competitors_sorted, columns=competitors_sorted)

    # A competitor against itself is a coin flip. Set this for every
    # competitor, not only those that appear first in some pair.
    for competitor in competitors_sorted:
        mean_matrix.loc[competitor, competitor] = 0.5
        weighted_matrix.loc[competitor, competitor] = 0.5

    for _, row in aggregate_df.iterrows():
        ca, cb = row["competitor_a"], row["competitor_b"]
        mean_ab = row["mean_board_win_rate"]
        weighted_ab = row["weighted_all_pairs_win_rate"]
        if ca in mean_matrix.index and cb in mean_matrix.columns:
            mean_matrix.loc[ca, cb] = mean_ab
            weighted_matrix.loc[ca, cb] = weighted_ab
            # The mirrored entry is the complement: win(B, A) = 1 - win(A, B)
            if not pd.isna(mean_ab):
                mean_matrix.loc[cb, ca] = 1 - mean_ab
            if not pd.isna(weighted_ab):
                weighted_matrix.loc[cb, ca] = 1 - weighted_ab

    return {
        "detailed": detailed_df,
        "aggregate": aggregate_df,
        "mean_board_win_rate_matrix": mean_matrix,
        "weighted_win_rate_matrix": weighted_matrix,
    }


# --- F1 win rates (higher is better), one competitor per (LLM, captain) pair ---
win_results_f1_comp = compute_pairwise_win_rates(
    df, metric="f1_score", higher_is_better=True, captain_col="competitor"
)
for heading, matrix_key in (
    ("Mean board win rate matrix (F1, competitor-level):", "mean_board_win_rate_matrix"),
    ("Weighted all-pairs win rate matrix (F1, competitor-level):", "weighted_win_rate_matrix"),
):
    print(heading)
    display(win_results_f1_comp[matrix_key])

# --- Move-count win rates (lower is better); needs the move_count column ---
if "move_count" not in df.columns:
    print("Column 'move_count' not found; skip move-count win rates.")
else:
    win_results_moves_comp = compute_pairwise_win_rates(
        df, metric="move_count", higher_is_better=False, captain_col="competitor"
    )
    for heading, matrix_key in (
        ("Mean board win rate matrix (Move Count, competitor-level):", "mean_board_win_rate_matrix"),
        ("Weighted all-pairs win rate matrix (Move Count, competitor-level):", "weighted_win_rate_matrix"),
    ):
        print(heading)
        display(win_results_moves_comp[matrix_key])

# Per-pair aggregate table for F1, sorted by competitor names
f1_summary_columns = [
    "competitor_a",
    "competitor_b",
    "mean_board_win_rate",
    "weighted_all_pairs_win_rate",
    "boards_considered",
    "board_wins",
    "board_comparisons",
]
f1_comp_summary = (
    win_results_f1_comp["aggregate"][f1_summary_columns]
    .sort_values(["competitor_a", "competitor_b"])
    .reset_index(drop=True)
)
print("Pairwise aggregate win rates (F1, competitor-level):")
display(f1_comp_summary)


In [None]:
# Heatmap of the pooled (weighted all-pairs) F1 win rates between competitors.
# Work on a float copy so the stored result is left alone.
weighted_matrix = win_results_f1_comp["weighted_win_rate_matrix"].copy()
weighted_matrix = weighted_matrix.astype(float)

# The lower triangle is redundant (win(B,A) = 1 - win(A,B)); set show_full to
# False to hide it.
show_full = True
mask = None if show_full else np.tril(np.ones_like(weighted_matrix, dtype=bool), k=-1)

# Square figure that grows with the number of competitors, capped at 24 inches.
figsize_base = 0.55  # inches added per competitor
n = len(weighted_matrix)
heatmap_side = min(24, 2 + n * figsize_base)
fig, ax = plt.subplots(figsize=(heatmap_side, heatmap_side))

# Annotation font shrinks as the matrix grows, clamped to the range 6..12.
annot_font_size = max(6, min(12, 120 / max(n, 1)))

weighted_heatmap_kwargs = dict(
    mask=mask,
    cmap="coolwarm",
    vmin=0,
    vmax=1,
    center=0.5,
    annot=True,
    fmt=".2f",
    annot_kws={"size": annot_font_size},
    linewidths=0.4,
    linecolor="white",
    cbar_kws={"label": "Win rate (row beats column)"},
)
sns.heatmap(weighted_matrix, ax=ax, **weighted_heatmap_kwargs)

ax.set_title("F1 Weighted All-Pairs Win Rates (Row beats Column)", fontsize=14)
ax.set_xlabel("Opponent (Column)", fontsize=12)
ax.set_ylabel("Competitor (Row)", fontsize=12)
plt.xticks(rotation=90, fontsize=9)
plt.yticks(rotation=0, fontsize=9)
plt.tight_layout()

# Write the figure to the export directory, creating it if necessary.
os.makedirs(PATH_EXPORT, exist_ok=True)
heatmap_path = os.path.join(PATH_EXPORT, "f1_weighted_winrate_heatmap.pdf")
plt.savefig(heatmap_path, dpi=300, bbox_inches="tight")
print(f"Saved heatmap to {heatmap_path}")
plt.show()

In [None]:
# Refined grouped heatmap: clearer LLM separation, captain-only tick labels, centered vertical LLM group labels,
# no legend, tighter colorbar, optional within-group clustering, optional column shading.

import colorsys
from matplotlib.patches import Rectangle

# Float copy of the pooled F1 win-rate matrix that this cell re-plots.
base_matrix = win_results_f1_comp["weighted_win_rate_matrix"].copy().astype(float)

# --- Feature toggles ---
show_group_separators = True     # thick lines between LLM groups
shade_group_rows = True          # tinted band behind each LLM's rows
shade_group_cols = True          # lighter tinted band behind each LLM's columns
annotate_group_labels = True     # LLM names next to the matrix

# --- Label styling ---
group_label_rotation = 90        # vertical row labels save horizontal space
group_label_color = "#222222"    # used when an LLM has no palette color
captain_tick_fontsize = 6
group_label_fontsize = 10

# --- Shading strength ---
row_shade_alpha = 0.30
col_shade_alpha = 0.18

# Private copy of the LLM palette for group labels and shading.
llm_group_palette = dict(llm_palette)

# Parse competitor into (LLM, Captain)
def split_comp(c):
    """Split a "<llm> | <captain>" competitor label at its first "|".

    Returns ``(llm, captain)`` with surrounding whitespace removed. When the
    label contains no "|", the whole stripped label is returned as the LLM and
    the captain is the empty string.
    """
    llm_part, separator, captain_part = c.partition("|")
    if not separator:
        return c.strip(), ""
    return llm_part.strip(), captain_part.strip()

# Split every competitor label into its LLM and captain parts.
competitors = list(base_matrix.index)
competitor_llm = {}
competitor_captain = {}
for c in competitors:
    competitor_llm[c], competitor_captain[c] = split_comp(c)

# Group consecutive competitors that share an LLM into (llm, members) blocks.
# The matrix index is already LLM-primary, so this only marks the boundaries.
blocks = []
for c in competitors:
    llm = competitor_llm[c]
    if blocks and blocks[-1][0] == llm:
        blocks[-1][1].append(c)
    else:
        blocks.append((llm, [c]))

# Flatten the blocks back into one ordering and index the matrix with it.
new_order = []
for _block_llm, members in blocks:
    new_order.extend(members)
mat = base_matrix.loc[new_order, new_order]

# Helper to lighten color
def lighten(hex_color, factor=0.9):
    """Return a lighter variant of a ``#rrggbb`` color.

    The color is converted to HLS and its distance from full lightness is
    multiplied by ``factor`` (``factor=1`` leaves the lightness unchanged,
    ``factor=0`` gives white). Any input that cannot be processed yields the
    fallback ``"#f0f0f0"``.
    """
    try:
        digits = hex_color.lstrip('#')
        red, green, blue = (int(digits[pos:pos + 2], 16) / 255 for pos in (0, 2, 4))
        hue, lightness, saturation = colorsys.rgb_to_hls(red, green, blue)
        lightness = 1 - (1 - lightness) * factor
        channels = colorsys.hls_to_rgb(hue, lightness, saturation)
        return "#" + "".join(f"{int(channel * 255):02x}" for channel in channels)
    except Exception:
        return "#f0f0f0"

# Pale tint per LLM for the group bands (grey when the LLM has no palette color).
row_shade_colors = {
    group_llm: lighten(llm_group_palette.get(group_llm, '#888888'), 0.93)
    for group_llm, _members in blocks
}

# Square figure that grows with the number of competitors, capped at 24 inches.
n = len(mat)
figsize_base = 0.50
grouped_side = min(24, 1.4 + n * figsize_base)
fig, ax = plt.subplots(figsize=(grouped_side, grouped_side))
# Annotation font shrinks as the matrix grows, clamped to the range 6..12.
annot_font_size = max(6, min(12, 120 / max(n, 1)))

grouped_heatmap_kwargs = dict(
    cmap="coolwarm",
    vmin=0,
    vmax=1,
    center=0.5,
    annot=True,
    fmt=".2f",
    annot_kws={"size": annot_font_size},
    linewidths=0.4,
    linecolor="white",
    cbar_kws={"shrink": 0.6, "pad": 0.02},
)
hm = sns.heatmap(mat, ax=ax, **grouped_heatmap_kwargs)

# Tick labels show the captain type only; ticks sit at the cell centers.
captain_labels = [competitor_captain[c] for c in new_order]
tick_centers = np.arange(n) + 0.5
ax.set_xticks(tick_centers)
ax.set_yticks(tick_centers)
ax.set_xticklabels(captain_labels, rotation=90, ha='center', va='top', fontsize=captain_tick_fontsize)
ax.set_yticklabels(captain_labels, rotation=0, ha='right', va='center', fontsize=captain_tick_fontsize)

# No axis titles; the group labels added afterwards identify the LLMs.
ax.set_xlabel("")
ax.set_ylabel("")

# Row shading: one full-width rectangle per LLM block, spanning that block's rows.
# NOTE(review): zorder=0 presumably places these patches beneath the heatmap
# mesh; confirm the shading is actually visible in the exported figure.
if shade_group_rows:
    for llm, comps in blocks:
        # First and last row index of this block within the plotted order
        start = new_order.index(comps[0])
        end = new_order.index(comps[-1])
        h = end - start + 1
        ax.add_patch(
            Rectangle(
                (0, start),
                width=n,
                height=h,
                facecolor=row_shade_colors[llm],
                edgecolor='none',
                alpha=row_shade_alpha,
                zorder=0
            )
        )

# Column shading: same idea with full-height rectangles and a lower alpha
if shade_group_cols:
    for llm, comps in blocks:
        # First and last column index of this block within the plotted order
        start = new_order.index(comps[0])
        end = new_order.index(comps[-1])
        w = end - start + 1
        ax.add_patch(
            Rectangle(
                (start, 0),
                width=w,
                height=n,
                facecolor=row_shade_colors[llm],
                edgecolor='none',
                alpha=col_shade_alpha,
                zorder=0
            )
        )

# Hide the axes spines (this does not redraw or reorder the heatmap mesh)
for spine in ax.spines.values():
    spine.set_visible(False)

# Group separators: a thick horizontal and vertical line after every block
# except the last one
if show_group_separators:
    cum = 0
    for llm, comps in blocks:
        size = len(comps)
        cum += size
        if cum < n:
            ax.axhline(cum, color='black', linewidth=2)
            ax.axvline(cum, color='black', linewidth=2)

# Group labels: the LLM name centered on each block, once along the columns
# and once along the rows
if annotate_group_labels:
    cum_start = 0
    for llm, comps in blocks:
        size = len(comps)
        # Center of the block in data coordinates
        midpoint = cum_start + size / 2
        # Column label, drawn horizontally
        ax.text(
            midpoint,
            n + 1.0,  # one unit past the last row; presumably below the matrix given the heatmap's inverted y-axis (confirm)
            llm,
            ha="center",
            va="top",
            rotation=0,
            fontsize=group_label_fontsize,
            fontweight="bold",
            color=llm_group_palette.get(llm, group_label_color),
            clip_on=False,
        )
        # Row label, one unit left of the first column
        ax.text(
            -1.0,
            midpoint,
            llm,
            ha="right",
            va="center",
            rotation=group_label_rotation,  # rotation comes from the config above
            fontsize=group_label_fontsize,
            fontweight="bold",
            color=llm_group_palette.get(llm, group_label_color),
            clip_on=False,
        )
        cum_start += size

# Colorbar: title and smaller tick labels (the bar was shrunk at creation)
cbar = hm.collections[0].colorbar
cbar.ax.set_title("Win rate\n(row > col)", fontsize=9, pad=6)
cbar.ax.tick_params(labelsize=8)

# ax.set_title(f"F1 Weighted Win Rates (Grouped by LLM)", fontsize=14, pad=12)
plt.tight_layout()

# Export: create the export directory if needed, then save the figure as PDF
os.makedirs(PATH_EXPORT, exist_ok=True)
out_path = os.path.join(PATH_EXPORT, f"f1_weighted_winrate_heatmap_grouped_refined.pdf")
plt.savefig(out_path, dpi=300, bbox_inches="tight")
print(f"Saved refined grouped heatmap to {out_path} (annotation font size={annot_font_size})")

## EIG Stats

In [None]:
# Run directory whose captain logs are analysed.
# Alternative run: "experiments/collaborative/captain_benchmarks/run_combined/run_4o_mapeig_cot_captain"
base_path = resolve_project_path(
    "experiments/collaborative/captain_benchmarks/run_combined/run_4o_llmdecision_captain"
)

# All captain.json files below the run directory. glob returns them in
# arbitrary order, so sort to make the row order of the result reproducible.
captain_files = sorted(
    glob.glob(os.path.join(base_path, '**/captain/captain.json'), recursive=True)
)

# EIG values of the asked questions, keyed by file path relative to base_path
eig_values_by_file = {}
# One dict per asked question; becomes model_eig_df
eig_data_list = []

for file_path in captain_files:
    rel_path = os.path.relpath(file_path, base_path)

    # The round id is whatever follows 'round_' in the relative path
    match = re.search(r'round_([a-zA-Z0-9]+)', rel_path)
    round_id = match.group(1) if match else None

    with open(file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    for idx, datum in enumerate(data):
        # Only entries that carry both an EIG value and a question are kept
        if datum.get('eig') is None or datum.get('question') is None:
            continue

        inner_question = datum['question']['question']
        if inner_question and 'text' in inner_question:
            question_text = inner_question['text']
        else:
            question_text = "No question text"
        eig_value = datum['eig']

        # Candidate questions (if logged) become (text, eig, is_max) triples.
        # None and [] are passed through unchanged.
        eig_questions = datum.get("eig_questions", [])
        if eig_questions:
            candidates = [
                (q['question']['question']['text'], q['eig']) for q in eig_questions
            ]
            # default=None: candidates whose EIGs are all None used to crash
            # max() with "empty sequence"; now none of them is flagged.
            max_eig = max(
                (eig for _, eig in candidates if eig is not None), default=None
            )
            eig_questions = [
                (text, eig, max_eig is not None and eig == max_eig)
                for text, eig in candidates
            ]

        eig_values_by_file.setdefault(rel_path, []).append(eig_value)
        eig_data_list.append({
            'round_id': round_id,
            'question_idx': idx,
            'question': question_text,
            'eig': eig_value,
            'eig_questions': eig_questions,
        })

# One row per asked question
model_eig_df = pd.DataFrame(eig_data_list)

model_eig_df

In [None]:
# Histogram of the EIG values of the questions the model asked.
model_eig_df["eig"].hist()

In [None]:
# -----------------------------------------------------------------------------
# This cell calculates EIG for human questions and caches the result in
# 'human_eig_df.csv'. The path is relative, so the file lands in the current
# working directory (the notebooks directory when the notebook is run from there).
# Caution: this takes 1-2 minutes whenever that CSV does not exist yet.
# -----------------------------------------------------------------------------

# JSON file to pull the code translations of human questions from
input_json_path = resolve_project_path("experiments/collaborative/spotter_benchmarks/o4-mini_CodeSpotterModel_True.json")

def extract_questions_and_boards_to_dataframe(json_path):
    """
    Collect the question records of a JSON file into a pandas DataFrame.

    Only entries that contain both a "question" and an "occTiles" key are used.
    The answers "true" / "false" (any letter case) are rewritten to "yes" / "no"
    before being compared with the entry's "true_answer".

    Args:
        json_path (str): Path to the input JSON file.

    Returns:
        pd.DataFrame: One row per used entry with the columns question,
        program, board_state, answer, true_answer and correct.

    Raises:
        FileNotFoundError: If json_path does not exist.
    """
    if not os.path.exists(json_path):
        raise FileNotFoundError(f"The file {json_path} does not exist.")

    with open(json_path, 'r') as handle:
        entries = json.load(handle)

    answer_aliases = {"true": "yes", "false": "no"}
    rows = []

    for entry in entries:
        if "question" not in entry or "occTiles" not in entry:
            continue

        question = entry["question"]
        program = entry["program"]
        board_state = entry["occTiles"]
        raw_answer = entry["answer"]
        true_answer = entry["true_answer"]

        # Anything other than true/false is kept exactly as given
        answer = answer_aliases.get(raw_answer.lower(), raw_answer)

        rows.append({
            "question": question,
            "program": program,
            "board_state": board_state,
            "answer": answer,
            "true_answer": true_answer,
            "correct": answer == true_answer
        })

    return pd.DataFrame(rows)


# Reuse the cached result when it exists; otherwise compute the EIG of every
# correctly answered human question and cache it.
if os.path.exists('human_eig_df.csv'):
    human_eig_df = pd.read_csv('human_eig_df.csv')
else:
    human_eig_df = extract_questions_and_boards_to_dataframe(input_json_path)
    # Keep only correctly answered questions. .copy() makes this an
    # independent frame, so adding a column below is not a write to a slice.
    human_eig_df = human_eig_df[human_eig_df['correct']].copy()

    eig_calculator = EIGCalculator(samples=1000, timeout=15, epsilon=0)

    eig_values = []
    for question_text, program, board_state in zip(
        human_eig_df["question"], human_eig_df["program"], human_eig_df["board_state"]
    ):
        # Wrap the question text and its program for the calculator
        code_question = CodeQuestion(
            question=Question(question_text),
            fn_text=program,
            translation_prompt="",
            completion={}
        )

        # Convert board_state to a Board instance
        board = Board.from_occ_tiles(board_state)

        eig_values.append(eig_calculator(code_question, board))

    # Store as float so a fresh run yields the same dtype as the cached CSV
    # path (the previous None-initialised column was object dtype).
    human_eig_df["calculated_eig"] = pd.Series(
        eig_values, index=human_eig_df.index, dtype=float
    )

    human_eig_df.to_csv('human_eig_df.csv', index=False)

human_eig_df

In [None]:
import seaborn as sns

import matplotlib.pyplot as plt

# Both EIG columns as float: the human column can be object dtype when it was
# computed in this session instead of being read from the cached CSV.
model_eig_values = model_eig_df["eig"].astype(float)
human_eig_values = human_eig_df["calculated_eig"].astype(float)

# Long-form frame: one row per question, labelled with the frame it came from
plot_data = pd.DataFrame({
    'EIG Values': pd.concat([model_eig_values, human_eig_values], ignore_index=True),
    'Source': ['model_eig_df'] * len(model_eig_values) + ['human_eig_df'] * len(human_eig_values),
})

# One box per source
sns.boxplot(data=plot_data, x='Source', y='EIG Values', palette='Set2')

# Labels and title (the title used to call this boxplot a scatter plot)
plt.title('Distribution of EIG Values by Source')
plt.xlabel('Source')
plt.ylabel('EIG Values')
plt.grid(axis='y', linestyle='--', alpha=0.5)
plt.tight_layout()
plt.show()

# Mean EIG per source; the printed names match the 'Source' labels above
avg_model_eig = model_eig_values.mean()
avg_human_eig = human_eig_values.mean()

print(f"Average EIG for model_eig_df: {avg_model_eig:.4f}")
print(f"Average EIG for human_eig_df: {avg_human_eig:.4f}")