# Spotter Benchmark Analysis

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os
import numpy as np

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.font_manager as fm
import matplotlib.transforms as mtransforms
from sklearn.metrics import classification_report

from experiments.collaborative.analysis import (
    load_dataset,
    get_gold_answer_dataset,
    MODEL_DISPLAY_NAMES,
    get_spotter_type_short,
)
from battleship.run_spotter_benchmarks import rebuild_summary_from_results
from battleship.utils import PROJECT_ROOT

In [None]:
%config InlineBackend.figure_format = 'retina'

# Use the "Set2" palette as the default color cycle for every plot in this notebook
sns.set_palette("Set2")

# White background with grid lines; the "talk" context selects seaborn's
# presentation-oriented scaling of fonts and line widths
sns.set_style("whitegrid")
sns.set_context("talk")

In [None]:
from matplotlib import font_manager

# Set the default font to DejaVu Sans
# plt.rcParams['font.family'] = 'DejaVu Sans'

# (printed label, rcParams key) for each font setting reported below
_FONT_REPORT_ROWS = [
    ("Default font family", "font.family"),
    ("Sans-serif fonts", "font.sans-serif"),
    ("Serif fonts", "font.serif"),
    ("Monospace fonts", "font.monospace"),
    ("Cursive fonts", "font.cursive"),
    ("Fantasy fonts", "font.fantasy"),
]

# Print the font configuration matplotlib is currently using
print("Current Font Settings:")
print("=" * 40)
for _font_label, _rc_key in _FONT_REPORT_ROWS:
    print(f"{_font_label}: {plt.rcParams[_rc_key]}")

In [None]:
# Name of the experiment folder under "data" that holds the human game data
EXPERIMENT_NAME = "battleship-final-data"
PATH_DATA = os.path.join("data", EXPERIMENT_NAME)
# NOTE(review): PATH_DATA is relative to the working directory while PATH_EXPORT
# is anchored at PROJECT_ROOT — confirm the notebook is run from the expected directory.

# PATH_EXPORT = os.path.join(PATH_DATA, "export")
PATH_EXPORT = os.path.join(
    PROJECT_ROOT, "..", "battleship-iclr-2026", "iclr2026", "_figures_staging"
)  # Export directly into the paper draft

# Load the experiment dataset with gold annotations enabled (use_gold=True)
df_gold = load_dataset(experiment_path=PATH_DATA, use_gold=True)

## Human results

In [None]:
# Extract the gold answers and the human answers from the gold-annotated dataset
# (presumably parallel sequences, since they are scored against each other below — verify)
gold_labels, human_labels = get_gold_answer_dataset(df_gold)
# Print both lengths as a quick sanity check
print(len(gold_labels), len(human_labels))

In [None]:
# Text report of the human answers scored against the gold answers
print(classification_report(y_true=gold_labels, y_pred=human_labels))

# The same report as a dict; its "accuracy" entry is reused as the human
# baseline in the plots further down
_human_report = classification_report(
    y_true=gold_labels,
    y_pred=human_labels,
    output_dict=True,
)
human_accuracy_baseline = _human_report["accuracy"]
print(f"Human accuracy baseline: {human_accuracy_baseline:.2%}")

## Modeling results

In [None]:
RUN_IDS = [
    "run_2025_07_11_18_32_51",
    "run_2025_08_22_09_53_20",  # GPT-5
]

# Rebuild the summary of every run from its stored results
results = []
for run_id in RUN_IDS:
    run_dir = os.path.join("spotter_benchmarks", run_id)
    results.append(rebuild_summary_from_results(run_dir))

# Stack all runs into a single frame with a fresh 0..n-1 index
_run_frames = [pd.DataFrame(result) for result in results]
df = pd.concat(_run_frames).reset_index(drop=True)

# Add display names and categorizations for analysis
def add_display_fields(df):
    """Add display names and categorizations to the dataframe.

    The columns ``spotter_type_short``, ``llm_display_name`` and ``llm_provider``
    are written onto the passed frame; the returned frame is a copy of it sorted
    by ``llm_display_name`` and then ``spotter_type_short``.

    Parameters:
      - df: benchmark results with ``spotter_type``, ``use_cot`` and ``llm`` columns

    Returns:
      The sorted dataframe with the display columns attached.
    """
    # Add spotter type categorization (fixed display order)
    df["spotter_type_short"] = pd.Categorical(
        [
            get_spotter_type_short(spotter_type, use_cot)
            for spotter_type, use_cot in zip(df["spotter_type"], df["use_cot"])
        ],
        categories=["Base", "CoT", "Code", "CoT + Code"],
        ordered=True,
    )

    # Add model display name; models without an entry keep their raw identifier
    display_names = df["llm"].map(lambda x: MODEL_DISPLAY_NAMES.get(x, x))

    # Provider is the part before the first "/" of the model id, if there is one
    df["llm_provider"] = df["llm"].map(lambda x: x.split("/")[0] if "/" in x else None)

    # Order display names by their position in MODEL_DISPLAY_NAMES. The set of
    # models present is computed once instead of once per dictionary entry.
    present_llms = set(df["llm"].unique())
    known_names = [
        display_name
        for llm, display_name in MODEL_DISPLAY_NAMES.items()
        if llm in present_llms
    ]
    # Names that fell back to the raw identifier are appended after the known
    # ones; leaving them out of the categories would silently turn them into NaN.
    fallback_names = [name for name in display_names.unique() if pd.notna(name)]
    # dict.fromkeys de-duplicates while keeping order (categories must be unique)
    categories = list(dict.fromkeys(known_names + fallback_names))

    df["llm_display_name"] = pd.Categorical(
        display_names, categories=categories, ordered=True
    )
    df = df.sort_values(by=["llm_display_name", "spotter_type_short"])

    return df


# Process the dataframe
df = add_display_fields(df)

### Completion status

In [None]:
with pd.option_context('display.max_rows', None):
    # Number of benchmark rows for every (model, spotter type) pair
    count_df = (
        df.groupby(["llm", "spotter_type_short"], observed=False)
        .size()
        .to_frame(name="count")
    )
    display(count_df)

    # Pairs with fewer than 948 rows are reported as incomplete
    _is_incomplete = count_df["count"] < 948
    filtered_count_df = count_df[_is_incomplete]
    if filtered_count_df.empty:
        print("All models complete!")
    else:
        print("Incomplete models:")
        display(filtered_count_df)

In [None]:
" ".join(filtered_count_df.reset_index().llm.unique().tolist())

In [None]:
# Create a visualization of completion status
count_df = df.groupby(["llm", "spotter_type_short"], observed=False).size().to_frame(name="count")

# Create a pivot table for heatmap (rows: model, columns: spotter type)
pivot_df = count_df.reset_index().pivot(index="llm", columns="spotter_type_short", values="count")

# Create completion status (1 for complete, 0 for incomplete)
completion_df = (pivot_df == 948).astype(int)

# Create the heatmap
plt.figure(figsize=(8, 12))
sns.heatmap(
    completion_df,
    annot=pivot_df,  # Show actual counts as annotations
    fmt='d',
    cmap='RdYlGn',
    # Pin the color range to the 0/1 status values. Left to be inferred from the
    # data, a matrix where every cell is complete has a degenerate range and is
    # drawn entirely in the lowest color (red), contradicting the title.
    vmin=0,
    vmax=1,
    cbar_kws={'label': 'Completion Status'},
    linewidths=0.5,
    linecolor='white'
)

plt.title('Spotter Benchmark Completion Status\n(Green = Complete [948], Red = Incomplete)', fontsize=14)
plt.xlabel('Spotter Type', fontsize=12)
plt.ylabel('LLM Model', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()


### Answer value distribution

In [None]:
# Check distribution of raw answer text
df["answer_text"].value_counts(dropna=False).plot(kind="bar")

In [None]:
# Plot from a copy so the "No Answer" placeholder never ends up in df itself
df_plot = df.copy()
df_plot["answer_value"] = df_plot["answer_value"].fillna("No Answer")

# Fixed colors per answer value and the model order used on the x axis
_answer_palette = {True: "green", False: "red", "No Answer": "gray"}
_llm_order = df["llm_display_name"].cat.categories

# Count of each answer value per LLM
plt.figure(figsize=(12, 6))
sns.countplot(
    data=df_plot,
    x="llm_display_name",
    hue="answer_value",
    order=_llm_order,
    palette=_answer_palette,
)

plt.title("Distribution of Answer Values by LLM")
plt.xlabel("LLM Display Name")
plt.ylabel("Count")
plt.xticks(rotation=90, ha="right")
plt.legend(title="Answer Value", bbox_to_anchor=(1.05, 1), loc='upper left')
plt.tight_layout()
plt.show()


### Accuracy

In [None]:
show_legend = False

# Error-bar styling and output location for the per-model accuracy figure
_errorbar_style = {"color": "gray", "linewidth": 1}
_accuracy_by_model_path = os.path.join(PATH_EXPORT, "spotter_accuracy_by_model.pdf")

with sns.axes_style("whitegrid"):
    plt.figure(figsize=(6, 6))

    # Horizontal bars: mean correctness per model, split by spotter type, 95% CI
    sns.barplot(
        data=df,
        x="is_correct",
        y="llm_display_name",
        hue="spotter_type_short",
        errorbar=("ci", 95),
        err_kws=_errorbar_style,
        capsize=0.2,
        legend=show_legend,
    )

    # Dashed vertical reference line at the human accuracy baseline
    plt.axvline(
        human_accuracy_baseline,
        color="#4b4f73",
        linestyle="--",
        linewidth=2.0,
        label="Human Performance",
    )

    plt.ylabel("")
    plt.xlabel("Accuracy")
    plt.xlim(0.5, 1.0)
    plt.yticks(fontsize=12)

    if show_legend:
        plt.legend(title="Spotter Models", bbox_to_anchor=(1.05, 1), loc='upper left')

    plt.savefig(_accuracy_by_model_path, bbox_inches="tight", dpi=300)

### Spotter Type Performance Across All Models

In [None]:
# Bar plot showing mean accuracy with confidence intervals for each spotter type.
# Accuracy is first averaged per (spotter type, model), so each bar and its CI
# summarize the per-model means.
# `observed=False` is passed explicitly: both keys are categorical and the pandas
# default for `observed` differs between versions, so leaving it implicit emits a
# FutureWarning and makes the output version-dependent.
spotter_accuracy = (
    df.groupby(["spotter_type_short", "llm_display_name"], observed=False)["is_correct"]
    .mean()
    .reset_index()
)

# plt.figure(figsize=(10, 6))
sns.barplot(
    data=spotter_accuracy,
    x="spotter_type_short",
    y="is_correct",
    hue="spotter_type_short",
    errorbar=("ci", 95),
    err_kws={
        "linewidth": 2,
    },
    capsize=0.1,
)
plt.axhline(
    human_accuracy_baseline,
    color="#4b4f73",
    linestyle="--",
    linewidth=2.0,
    label="Human Performance"
)

# x in axes coordinates (0..1), y in data coordinates: the label stays horizontally
# centered while sitting just above the baseline line
ax = plt.gca()
trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)

plt.text(
    s="Human Performance",
    x=0.5,
    y=human_accuracy_baseline + 0.01,
    fontsize=14,
    ha="center",
    va="bottom",
    transform=trans,
)
# plt.title("Mean Accuracy by Spotter Type\n(Averaged Across All Models)", fontsize=14)
plt.xlabel("", fontsize=16)
plt.ylabel("Accuracy", fontsize=16)
plt.ylim(0.5, 1.0)
# plt.legend()
# plt.xticks(rotation=45)

# sns.despine()

plt.tight_layout()

plt.savefig(os.path.join(PATH_EXPORT, "spotter_accuracy_overall.pdf"), bbox_inches="tight", dpi=300)

# Print summary statistics (computed over individual rows, not per-model means)
print("Summary Statistics by Spotter Type:")
print("=" * 50)
spotter_stats = df.groupby("spotter_type_short", observed=False)["is_correct"].agg([
    'count', 'mean', 'std', 'min', 'max'
]).round(3)
spotter_stats['mean_pct'] = (spotter_stats['mean'] * 100).round(1)
print(spotter_stats)

In [None]:
# Pairwise statistical significance testing between spotter types
from scipy import stats
from itertools import combinations
import pandas as pd

print("Pairwise Statistical Comparisons (Mann-Whitney U Test):")
print("=" * 60)

# Spotter types in their categorical order
spotter_types = df["spotter_type_short"].cat.categories

# (upper bound on p, stars, printed level), checked from strictest to loosest
_SIGNIFICANCE_LEVELS = (
    (0.001, "***", "p < 0.001"),
    (0.01, "**", "p < 0.01"),
    (0.05, "*", "p < 0.05"),
)

# One dict per comparison; turned into a table at the end
results_data = []

for type1, type2 in combinations(spotter_types, 2):
    # Correctness values of all rows belonging to each spotter type
    data1 = df.loc[df["spotter_type_short"] == type1, "is_correct"]
    data2 = df.loc[df["spotter_type_short"] == type2, "is_correct"]

    # Two-sided Mann-Whitney U test (non-parametric)
    statistic, p_value = stats.mannwhitneyu(data1, data2, alternative='two-sided')

    # Group means and their difference (second type minus first type)
    mean1, mean2 = data1.mean(), data2.mean()
    mean_diff = mean2 - mean1

    # First level whose bound exceeds the p-value; "not significant" otherwise
    significance, sig_level = next(
        (
            (stars, level)
            for bound, stars, level in _SIGNIFICANCE_LEVELS
            if p_value < bound
        ),
        ("", "n.s."),
    )

    results_data.append({
        'Comparison': f"{type1} vs {type2}",
        'Mean_1': mean1,
        'Mean_2': mean2,
        'Difference': mean_diff,
        'P_value': p_value,
        'Significance': significance,
        'Sig_Level': sig_level
    })

    # Detailed printout for this pair
    print(f"{type1} vs {type2}:")
    print(f"  Mean accuracy: {mean1:.3f} vs {mean2:.3f} (diff: {mean_diff:+.3f})")
    print(f"  Sample sizes: {len(data1)} vs {len(data2)}")
    print(f"  p-value: {p_value:.4f} {significance} ({sig_level})")
    print()

# Summary table of all comparisons
results_df = pd.DataFrame(results_data)
print("\nSummary Table of Pairwise Comparisons:")
print("=" * 60)
print(results_df.round(4).to_string(index=False))


In [None]:
# Pairwise lookup tables indexed [row type, column type] in category order
n_types = len(spotter_types)
significance_matrix = np.full((n_types, n_types), "", dtype=object)
mean_diff_matrix = np.zeros((n_types, n_types))

for i, row_type in enumerate(spotter_types):
    for j, col_type in enumerate(spotter_types):
        if i == j:
            continue  # a type is never compared with itself

        # A pair is stored under one of the two possible key orders
        forward_key = f"{row_type} vs {col_type}"
        backward_key = f"{col_type} vs {row_type}"
        matches = results_df[results_df['Comparison'].isin([forward_key, backward_key])]
        if matches.empty:
            continue

        hit = matches.iloc[0]
        significance_matrix[i, j] = hit['Significance']

        # Keep the stored sign when the pair was stored as "column vs row",
        # flip it when it was stored as "row vs column"
        if hit['Comparison'] == backward_key:
            mean_diff_matrix[i, j] = hit['Difference']
        else:
            mean_diff_matrix[i, j] = -hit['Difference']

# Flip rows and columns so the spotter types appear in reverse order
spotter_types_reversed = spotter_types[::-1]
reverse_idx = {i: n_types - 1 - i for i in range(n_types)}
_flipped = [reverse_idx[i] for i in range(n_types)]
mean_diff_matrix_rev = mean_diff_matrix[_flipped][:, _flipped]
significance_matrix_rev = significance_matrix[_flipped][:, _flipped]


def _annotation_for(i, j):
    """Return the cell text: "--" on the diagonal, else the signed difference plus stars."""
    if i == j:
        return "--"
    stars = significance_matrix_rev[i, j]
    diff_text = f"{mean_diff_matrix_rev[i, j]:+.3f}"
    return f"{diff_text}\n{stars}" if stars else diff_text


# Cell annotations combining the mean difference and its significance marker
annotations = np.full((n_types, n_types), "", dtype=object)
for i in range(n_types):
    for j in range(n_types):
        annotations[i, j] = _annotation_for(i, j)

# Hide the strict lower triangle; the diagonal stays visible
mask = np.tril(np.ones_like(mean_diff_matrix_rev, dtype=bool), k=-1)

# Figure sized with room for the significance legend
fig, ax = plt.subplots(1, 1, figsize=(10, 8))

# Upper-triangular heatmap of the mean accuracy differences
sns.heatmap(
    mean_diff_matrix_rev,
    mask=mask,
    annot=annotations,
    fmt='',
    xticklabels=spotter_types_reversed,
    yticklabels=spotter_types_reversed,
    cmap='RdBu',
    center=0,
    ax=ax,
    cbar_kws={'label': 'Mean Accuracy Difference'},
    square=True
)

ax.set_title("Spotter Accuracy Differences", fontsize=16)
ax.set_xlabel('', fontsize=12)
ax.set_ylabel('', fontsize=12)

# Key explaining the significance stars
legend_text = """Statistical Significance:
*** p < 0.001 (highly significant)
**  p < 0.01 (very significant)
*   p < 0.05 (significant)
    p ≥ 0.05 (not significant)"""

# Place the key in the lower-left corner of the axes (the masked triangle)
ax.text(
    0.02,
    0.02,
    legend_text,
    transform=ax.transAxes,
    fontsize=12,
    verticalalignment='bottom',
    bbox=dict(facecolor='white', alpha=0.9, edgecolor='gray'),
    fontfamily='monospace',
)

plt.tight_layout()

_differences_path = os.path.join(PATH_EXPORT, "spotter_accuracy_differences.pdf")
plt.savefig(_differences_path, bbox_inches="tight", dpi=300)


In [None]:
# Display the processed benchmark dataframe for manual inspection
df

In [None]:
# Shared helper to build cleaned human_df from df_gold
from typing import Dict

# Default mapping from a readable gold-label name to the dataframe column holding
# that label; build_human_df falls back to it when no label_map is passed.
LABEL_MAP_DEFAULT: Dict[str, str] = {
    "Discourse": "gold_discourse",
    "Stateful": "gold_stateful",
    "Vague": "gold_vague",
    "Ambiguous": "gold_ambiguous",
}

def build_human_df(df_gold: pd.DataFrame, label_map: Dict[str, str] | None = None) -> pd.DataFrame:
    """
    Build the cleaned table of human answers used by the question-type analyses.

    Only rows with messageType "answer" and with both a gold answer and a message
    text are kept. Both texts are parsed with `Answer.parse`, rows whose parsed
    human answer is neither True nor False are dropped, and a boolean column
    `_has_any_label` marks rows where at least one gold label column is set.

    Parameters:
      - df_gold: raw dataset containing human answers and gold labels
      - label_map: mapping from friendly label name to boolean column name
                   (defaults to LABEL_MAP_DEFAULT)

    Returns:
      A filtered copy of df_gold with the columns `gold_answer_value`,
      `human_answer_value` and `_has_any_label` added.
    """
    from battleship.agents import Answer

    columns_by_label = LABEL_MAP_DEFAULT if label_map is None else label_map
    gold_label_cols = list(columns_by_label.values())

    # Answer messages that carry both a gold answer and the human's text
    is_answer = df_gold["messageType"] == "answer"
    has_both_texts = df_gold["gold_answer"].notna() & df_gold["messageText"].notna()
    human_df = df_gold[is_answer & has_both_texts].copy()

    # Parse the gold and human answer texts
    human_df["gold_answer_value"] = human_df["gold_answer"].apply(Answer.parse)
    human_df["human_answer_value"] = human_df["messageText"].apply(Answer.parse)

    # Drop rows whose human answer did not parse to True/False
    parsed_ok = human_df["human_answer_value"].isin([True, False])
    human_df = human_df[parsed_ok].copy()

    # Flag rows that have at least one gold label set (missing values count as unset)
    label_flags = human_df[gold_label_cols].fillna(False)
    human_df["_has_any_label"] = label_flags.any(axis=1)
    return human_df

In [None]:
# Helper to build long-form dataset for question type analysis (incl. Simple/Complex)
from typing import Tuple, List

def _expand_gold_label_records(is_correct_values, spotter_values, llm_values, has_any_label, label_values, na_label):
    """Expand per-question rows into one record per gold-label bucket the row belongs to.

    All inputs are parallel sequences read by position, so the index of the source
    frame never matters. `label_values` maps each specific label name to that
    label's per-row values. Every row yields an "Overall" record, then either an
    `na_label` record (no label set) or a "Complex" record followed by one record
    per specific label that is set.
    """
    records = []
    rows = zip(is_correct_values, spotter_values, llm_values, has_any_label)
    for pos, (is_correct, spotter, llm, has_any) in enumerate(rows):
        buckets = ["Overall"]
        if bool(has_any):
            buckets.append("Complex")
            for nice_name, values in label_values.items():
                val = values[pos]
                # Missing values count as "label not set"
                if pd.notna(val) and bool(val):
                    buckets.append(nice_name)
        else:
            buckets.append(na_label)

        for bucket in buckets:
            records.append({
                "gold_label": bucket,
                "spotter_type_short": spotter,
                "llm_display_name": llm,
                "is_correct": is_correct,
            })
    return records


def build_combined_long_df(df: pd.DataFrame, df_gold: pd.DataFrame, NA_LABEL: str = "Simple") -> Tuple[pd.DataFrame, List[str], List[str]]:
    """
    Build a combined long-form dataframe with per-question correctness for:
      - Models (from df)
      - Humans (from df_gold)
    and gold-label buckets: Overall, Simple (N/A), Complex (any label True), and each specific label.

    A question contributes one row to "Overall", one row to either `NA_LABEL`
    (no gold label set) or "Complex" (at least one set), and one row to every
    specific label that is set, so a single question can appear several times.

    Parameters:
      - df: model results with `is_correct`, `spotter_type_short` (categorical),
            `llm_display_name` and the four gold label columns
      - df_gold: raw human dataset, cleaned via build_human_df
      - NA_LABEL: bucket name for questions without any gold label

    Returns:
      combined_long_df, label_order, spotter_order
    """
    LABEL_MAP = {
        "Discourse": "gold_discourse",
        "Stateful": "gold_stateful",
        "Vague": "gold_vague",
        "Ambiguous": "gold_ambiguous",
    }
    gold_label_cols = list(LABEL_MAP.values())
    # Passing the columns explicitly keeps the frames well-formed even when an
    # input has no rows (otherwise the result would have no columns at all)
    record_columns = ["gold_label", "spotter_type_short", "llm_display_name", "is_correct"]

    # -----------------------------
    # Models
    # -----------------------------
    model_records = _expand_gold_label_records(
        is_correct_values=df["is_correct"].tolist(),
        spotter_values=df["spotter_type_short"].tolist(),
        llm_values=df["llm_display_name"].tolist(),
        has_any_label=df[gold_label_cols].fillna(False).any(axis=1).tolist(),
        label_values={name: df[col].tolist() for name, col in LABEL_MAP.items()},
        na_label=NA_LABEL,
    )
    model_long_df = pd.DataFrame(model_records, columns=record_columns)

    # -----------------------------
    # Humans (via shared helper)
    # -----------------------------
    human_df = build_human_df(df_gold, dict(LABEL_MAP))
    # A human answer is correct when its parsed value equals the parsed gold value
    human_correct = [
        bool(human_value == gold_value)
        for human_value, gold_value in zip(
            human_df["human_answer_value"], human_df["gold_answer_value"]
        )
    ]
    human_tags = ["Human"] * len(human_correct)
    human_records = _expand_gold_label_records(
        is_correct_values=human_correct,
        spotter_values=human_tags,
        llm_values=human_tags,
        has_any_label=human_df["_has_any_label"].tolist(),
        label_values={name: human_df[col].tolist() for name, col in LABEL_MAP.items()},
        na_label=NA_LABEL,
    )
    human_long_df = pd.DataFrame(human_records, columns=record_columns)

    # -----------------------------
    # Combine and order
    # -----------------------------
    combined_long_df = pd.concat([model_long_df, human_long_df], ignore_index=True)

    label_order = ["Overall", NA_LABEL, "Complex", "Discourse", "Stateful", "Vague", "Ambiguous"]
    combined_long_df["gold_label"] = pd.Categorical(combined_long_df["gold_label"], categories=label_order, ordered=True)

    # Humans are appended after the model spotter types
    existing_order = list(df["spotter_type_short"].cat.categories)
    spotter_order = existing_order + ["Human"]
    combined_long_df["spotter_type_short"] = pd.Categorical(combined_long_df["spotter_type_short"], categories=spotter_order, ordered=True)

    return combined_long_df, label_order, spotter_order

### Performance by Gold Label (Across All Models)


In [None]:
from battleship.agents import Answer

NA_LABEL = "Simple"
show_legend = False

# Long-form per-question correctness for models and humans, one row per label bucket
combined_long_df, label_order, spotter_order = build_combined_long_df(
    df, df_gold, NA_LABEL=NA_LABEL
)

_gold_label_plot_path = os.path.join(PATH_EXPORT, "spotter_accuracy_by_gold_label.pdf")

# Grouped bars: accuracy per question type and spotter type, with 95% CIs computed
# over individual questions (no per-model aggregation)
plt.figure(figsize=(8, 6))
with sns.axes_style("whitegrid"):
    sns.barplot(
        data=combined_long_df,
        x="gold_label",
        y="is_correct",
        hue="spotter_type_short",
        errorbar=("ci", 95),
        err_kws={"linewidth": 1.5},
        capsize=0.1,
        legend=show_legend,
    )

# plt.title("Accuracy by Question Type (Per-question 95% CI, incl. Human)", fontsize=14)
plt.xlabel("")
plt.ylabel("Accuracy")
plt.xticks(rotation=0)
if show_legend:
    plt.legend(title="Spotter Models", bbox_to_anchor=(1.05, 1), loc="upper left")

sns.despine()
plt.tight_layout()

# Save figure
plt.savefig(_gold_label_plot_path, bbox_inches="tight", dpi=300)

# Count / mean / std per (question type, spotter type), kept for quick inspection
_correct_by_bucket = combined_long_df.groupby(
    ["gold_label", "spotter_type_short"], observed=False
)["is_correct"]
summary_table = _correct_by_bucket.agg(["count", "mean", "std"]).round(3)
summary_table["mean_pct"] = (summary_table["mean"] * 100).round(1)
print(summary_table)

In [None]:
# Appendix: Mean accuracy by question type with MultiIndex rows (Provider, LLM, Spotter Type)
# Supersedes prior master_table
# Additions:
# - Include llm_provider as top-level index
# - Include Valid column (prop of non-missing answer_text)
# - Rename gold_label axis to 'Question Type' to avoid LaTeX issues
# - Preserve original LLM ordering
# - Column order places Complex immediately after Simple

# Build combined dataset (ensures Complex bucket exists and consistent ordering)
combined_long_df, label_order, spotter_order = build_combined_long_df(df, df_gold, NA_LABEL="Simple")

# Reorder label_order to ensure Complex is right after Simple
if "Simple" in label_order and "Complex" in label_order:
    lbls = [lbl for lbl in label_order if lbl not in ("Simple", "Complex")]
    label_order = ["Overall", "Simple", "Complex"] + [lbl for lbl in lbls if lbl not in ("Overall",)]

# 1. Map providers to combined_long_df (Human gets its own provider label)
provider_map = (
    df.drop_duplicates(subset=["llm_display_name"])
      .set_index("llm_display_name")["llm_provider"].to_dict()
)
combined_long_df = combined_long_df.copy()
combined_long_df["llm_provider"] = combined_long_df["llm_display_name"].map(provider_map)
combined_long_df.loc[combined_long_df["llm_display_name"] == "Human", "llm_provider"] = "Human"

# 2. Build accuracy pivot: mean correctness per (provider, llm, spotter type) across question types
appendix_accuracy_table = (
    combined_long_df
    .groupby(["llm_provider", "llm_display_name", "spotter_type_short", "gold_label"], observed=True)["is_correct"]
    .mean()
    .unstack("gold_label")
)

# 3. Ensure column order matches adjusted label_order and that Complex follows Simple
appendix_accuracy_table = appendix_accuracy_table.reindex(columns=[c for c in label_order if c in appendix_accuracy_table.columns])

# 4. Compute Valid metric from original df (models). Human assumed fully valid (1.0)
valid_series = (
    df.groupby(["llm_provider", "llm_display_name", "spotter_type_short"], observed=True)["answer_text"]
      .apply(lambda x: 1 - x.isna().mean())
)
# Append Valid column (align on index); fill Human rows with 1.0.
# The filled column is assigned back explicitly: an in-place fillna on the
# selected column is a chained update that does not reach the table under
# pandas Copy-on-Write.
appendix_accuracy_table["Valid"] = valid_series
appendix_accuracy_table["Valid"] = appendix_accuracy_table["Valid"].fillna(1.0)

# 5. Round all numeric values to 3 decimals
appendix_accuracy_table = appendix_accuracy_table.round(3)

# 6. Set index level names & rename columns axis
appendix_accuracy_table.index = appendix_accuracy_table.index.set_names(["Provider", "LLM", "Spotter Type"])
appendix_accuracy_table.columns.name = "Question Type"

# 7. Sort index by provider, LLM and spotter type
# NOTE(review): llm_display_name in combined_long_df is rebuilt from plain row
# values, so it may no longer be categorical here — confirm this sort yields the
# intended LLM order rather than an alphabetical one.
appendix_accuracy_table = appendix_accuracy_table.sort_index(level=["Provider", "LLM", "Spotter Type"], sort_remaining=False)

print("Appendix Accuracy Table (mean accuracy by Question Type with Valid):")
print(appendix_accuracy_table)

# 8. Export to LaTeX (basic)
appendix_latex_path = os.path.join(PATH_EXPORT, "spotter_master_table_basic.tex")
appendix_accuracy_table.to_latex(
    appendix_latex_path,
    float_format=lambda x: f"{x:.3f}",
    multirow=True,
    # escape=True,
)
print(f"LaTeX table written to: {appendix_latex_path}")

##########################################################################
# Custom LaTeX export with provider as group header row
##########################################################################
def _latex_escape(text: str) -> str:
    repl = {
        "\\": r"\textbackslash{}",
        "&": r"\&",
        "%": r"\%",
        "$": r"\$",
        "#": r"\#",
        "_": r"\_",
        "{": r"\{",
        "}": r"\}",
        "~": r"\textasciitilde{}",
        "^": r"\textasciicircum{}",
    }
    out = str(text)
    for k, v in repl.items():
        out = out.replace(k, v)
    return out


# Build LaTeX with provider as pseudo-row
provider_grouped_path = os.path.join(PATH_EXPORT, "spotter_master_table.tex")

row_end = "\\\\"  # literal \\ in output

# Pretty names for providers
PROVIDER_DISPLAY = {
    "openai": "OpenAI",
    "anthropic": "Anthropic",
    "google": "Google",
    "meta-llama": "Meta Llama",
    "Human": "Human",
}

df_exp = appendix_accuracy_table.copy()
metric_cols = list(df_exp.columns)
num_metrics = len(metric_cols)
# We'll print columns: LLM, Spotter Type, then metric columns
ncols_total = 2 + num_metrics

lines = []
colspec = "ll" + ("r" * num_metrics)
# Note: Requires \usepackage[table]{xcolor}
lines.append(f"\\begin{{tabular}}{{{colspec}}}")
lines.append("\\toprule")
# Header row
header = ["LLM", "Spotter Type"] + [_latex_escape(c) for c in metric_cols]
lines.append(" & ".join(header) + f" {row_end}")
lines.append("\\midrule")

# Provider order as they appear in the index
provider_order = list(df_exp.index.get_level_values("Provider").unique())

# Iterate by provider in order
for provider in provider_order:
    df_grp = df_exp.xs(provider, level="Provider", drop_level=False)
    prov = _latex_escape(PROVIDER_DISPLAY.get(provider, provider))
    # Full-width shaded band for the provider header row
    lines.append(
        f"\\rowcolor{{gray!40}} \\multicolumn{{{ncols_total}}}{{l}}{{\\textbf{{{prov}}}}} {row_end}"
    )

    # Drop Provider level for iteration over (LLM, Spotter)
    df_sub = df_grp.droplevel("Provider")

    # Alternate shading per LLM block; the first block of a provider is unshaded
    current_llm = None
    shade = True  # reset per provider
    for (llm, spotter), row in df_sub.iterrows():
        # Rows of one LLM are contiguous, so a change of name marks a new block
        starts_block = llm != current_llm
        if starts_block:
            current_llm = llm
            shade = not shade
        prefix = "\\rowcolor{gray!10} " if shade else ""
        # Print the LLM name only on the first row of its block
        llm_s = _latex_escape(llm) if starts_block else ""
        spot_s = _latex_escape(str(spotter))
        vals = [f"{v:.3f}" if pd.notna(v) else "" for v in row.values]
        lines.append(prefix + " & ".join([llm_s, spot_s] + vals) + f" {row_end}")

lines.append("\\bottomrule")
lines.append("\\end{tabular}")

# Explicit encoding so the written file does not depend on the platform default
with open(provider_grouped_path, "w", encoding="utf-8") as f:
    f.write("\n".join(lines))

print(f"LaTeX table with provider headers written to: {provider_grouped_path}")

In [None]:
# Gold label prevalence (share of human answers with each gold label)
# Note: Multilabel, so percentages can sum to >100%. Includes the N/A bucket for none.

# Config
NA_LABEL = "Simple"
LABEL_MAP = {
    "Discourse": "gold_discourse",
    "Stateful": "gold_stateful",
    "Vague": "gold_vague",
    "Ambiguous": "gold_ambiguous",
}

gold_label_cols = list(LABEL_MAP.values())

# Cleaned human answers, including the `_has_any_label` flag
human_df = build_human_df(df_gold, LABEL_MAP)

total_answers = len(human_df)
any_label_mask = human_df["_has_any_label"]


def _share_percent(n_rows):
    """Percentage of all human answers, rounded to 1 decimal (0.0 when there are none)."""
    if not total_answers:
        return 0.0
    return round(100.0 * n_rows / total_answers, 1)


# Row counts: all answers, answers without any label, answers per specific label
overall_count = int(total_answers)
na_count = int((~any_label_mask).sum())
_label_counts = {
    nice_name: int(human_df[col].fillna(False).sum())
    for nice_name, col in LABEL_MAP.items()
}

# "Overall" first, then the no-label bucket, then each specific label
label_prevalence_records = [
    {"gold_label": "Overall", "count": overall_count, "percent": 100.00},
    {"gold_label": NA_LABEL, "count": na_count, "percent": _share_percent(na_count)},
]
label_prevalence_records.extend(
    {"gold_label": nice_name, "count": n_rows, "percent": _share_percent(n_rows)}
    for nice_name, n_rows in _label_counts.items()
)

prevalence_df = pd.DataFrame(label_prevalence_records)

# Sort rows into the desired label order, keeping only labels that are present
_present_labels = set(prevalence_df["gold_label"].unique())
_desired_order = ["Overall", NA_LABEL, "Discourse", "Stateful", "Vague", "Ambiguous"]
order = [lbl for lbl in _desired_order if lbl in _present_labels]
prevalence_df["gold_label"] = pd.Categorical(prevalence_df["gold_label"], categories=order, ordered=True)
prevalence_df = prevalence_df.sort_values("gold_label").reset_index(drop=True)

print("\nGold Label Prevalence (human data):")
print(prevalence_df.to_string(index=False))

# Save prevalence to LaTeX with 1 decimal precision
latex_str = prevalence_df.to_latex(index=False, float_format=lambda x: f"{x:.1f}")
_prevalence_path = os.path.join(PATH_EXPORT, "gold_label_prevalence.tex")
with open(_prevalence_path, "w") as f:
    f.write(latex_str)

In [None]:
SIMPLE_LABEL = "Simple"
COMPLEX_LABEL = "Complex"
show_legend = False

# Long-form per-question correctness, including the Simple and Complex buckets
combined_long_df, label_order, _ = build_combined_long_df(
    df, df_gold, NA_LABEL=SIMPLE_LABEL
)

# Keep only the Simple and Complex rows and call the bucket column "difficulty"
difficulty_order = [SIMPLE_LABEL, COMPLEX_LABEL]
_is_simple_or_complex = combined_long_df["gold_label"].isin(difficulty_order)
collapsed_long_df = combined_long_df[_is_simple_or_complex].copy()
collapsed_long_df = collapsed_long_df.rename(columns={"gold_label": "difficulty"})

# Restrict the difficulty categories to the two remaining levels, in plot order
collapsed_long_df["difficulty"] = pd.Categorical(
    collapsed_long_df["difficulty"], categories=difficulty_order, ordered=True
)

# Spotter types in their original order, with Human last
existing_order = list(df["spotter_type_short"].cat.categories)
spotter_order = existing_order + ["Human"]
collapsed_long_df["spotter_type_short"] = pd.Categorical(
    collapsed_long_df["spotter_type_short"], categories=spotter_order, ordered=True
)

_simple_vs_complex_path = os.path.join(
    PATH_EXPORT, "spotter_accuracy_simple_vs_complex.pdf"
)

# Point plot: one line per spotter type going from Simple to Complex, 95% CIs
plt.figure(figsize=(4, 6))
with sns.axes_style("whitegrid"):
    ax = sns.pointplot(
        data=collapsed_long_df,
        x="difficulty",
        y="is_correct",
        hue="spotter_type_short",
        errorbar=("ci", 95),
        err_kws={"linewidth": 1.5},
        linewidth=3,
        linestyle="-",
        markers="o",
        markersize=5,
        capsize=0.1,
        legend=show_legend,
    )

ax.set_ylim(0.5, 1.0)
ax.set_xlabel("")
ax.set_ylabel("Accuracy")

if show_legend:
    ax.legend(title="Spotter Models", bbox_to_anchor=(1.05, 1), loc="upper left")

# sns.despine()
plt.tight_layout()

# Save figure
plt.savefig(_simple_vs_complex_path, bbox_inches="tight", dpi=300)

# Compact count / mean / std table per (difficulty, spotter type)
_correct_by_difficulty = collapsed_long_df.groupby(
    ["difficulty", "spotter_type_short"], observed=False
)["is_correct"]
summary_table = _correct_by_difficulty.agg(["count", "mean", "std"]).round(3)
summary_table["mean_pct"] = (summary_table["mean"] * 100).round(1)
print(summary_table)