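# Analysis pipeline for Prolific data

Loads the collaborative Battleship games collected on Prolific and analyzes hits over time, moves vs. questions asked, bonuses, timing, per-board timelines, and the gold-label annotations of the questions.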

In [None]:
# Automatically reload imported modules (e.g. battleship, analysis) before each cell runs,
# so edits to those files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from battleship.board import Board
from analysis import load_dataset

In [None]:
# Render inline figures at high DPI.
%config InlineBackend.figure_format = 'retina'

# set seaborn color palette (default colour cycle for all later plots)
sns.set_palette("Set2")

# set seaborn style: white background, "talk" context (larger fonts and line widths)
sns.set_style("white")
sns.set_context("talk")

In [None]:
# Experiment folder under data/. Earlier runs, kept for reference:
#   "battleship-2024-10-03-19-28-28", "battleship-pilot-v2"
EXPERIMENT_NAME = "battleship-final-data"

# data/<experiment>/          -> input directory passed to load_dataset
# data/<experiment>/export/   -> figures and CSVs written by this notebook (created if missing)
PATH_DATA = os.path.join("data", EXPERIMENT_NAME)
PATH_EXPORT = os.path.join("data", EXPERIMENT_NAME, "export")
PATH_BONUS_EXPORT = os.path.join(PATH_EXPORT, EXPERIMENT_NAME + "-bonus.csv")
os.makedirs(PATH_EXPORT, exist_ok=True)

In [None]:
# Load the event-level dataset (one row per message/move) for this experiment.
# use_gold / drop_incomplete are options of analysis.load_dataset — presumably "attach gold
# annotations" and "keep incomplete games"; TODO confirm against analysis.py.
df = load_dataset(PATH_DATA, use_gold=True, drop_incomplete=False)

In [None]:
# Inspect the available columns.
df.columns

In [None]:
# How many rows of each message type (e.g. "move", "question") were logged.
df["messageType"].value_counts()

In [None]:
# Row counts per value of the "name" column.
df["name"].value_counts()

In [None]:
# Number of rows logged on each board.
df["board_id"].value_counts()

In [None]:
# Frequency of each distinct message text, most frequent first.
df["messageText"].value_counts()

In [None]:
# Number of distinct rounds recorded on each board.
df.groupby(["board_id"])["roundID"].nunique()

# Visualizations

## Hits

In [None]:
# Raw hit count along the action sequence: colour = pair, line style = board.
ax = sns.lineplot(data=df, x="index", y="hits", hue="pairID", style="board_id")
ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# Hits as a fraction (0-1) of the board's ship tiles, so boards with different fleet sizes
# are comparable. Adds df["hits_pct"], which later cells rely on.
df["hits_pct"] = df["hits"].div(df["total_ship_tiles"])

ax = sns.lineplot(data=df, x="index", y="hits_pct", hue="pairID", style="board_id")
ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# Raw hit count per pair, without separating boards by line style
# (seaborn aggregates each pair's boards at every action index).
ax = sns.lineplot(data=df, x="index", y="hits", hue="pairID")
ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

## Number of moves to win

In [None]:
# Number of "move" events per (pairID, board_id) game, as a one-column frame.
df_move_counts = (
    df.loc[df["messageType"].eq("move")]
    .groupby(["pairID", "board_id"])
    .size()
    .rename("move_count")
    .to_frame()
)
df_move_counts

In [None]:
# Number of "question" events per (pairID, board_id) game, as a one-column frame.
df_question_counts = (
    df.loc[df["messageType"].eq("question")]
    .groupby(["pairID", "board_id"])
    .size()
    .rename("question_count")
    .to_frame()
)
df_question_counts

In [None]:
# One row per game with its move and question counts. The join is anchored on the move
# counts, so games with no moves are absent; games with no questions get question_count 0.
df_counts = (
    df_move_counts.join(df_question_counts)
    .fillna(0)
    .astype({"question_count": int})
    .sort_values(["pairID", "board_id"])
    .reset_index()
)
df_counts

In [None]:
with sns.plotting_context("talk"), sns.axes_style("whitegrid"):
    # Distribution of moves per board, one box per participant pair.
    ax = sns.boxplot(
        data=df_counts,
        y="move_count",
        hue="pairID",
        hue_order=df_counts["pairID"].unique(),
    )
    ax.set_ylabel("Moves per board")

    # legend outside the plot area, to the right
    ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# One jittered dot per game: questions asked (x) vs. moves made (y), coloured by pair.
ax = sns.stripplot(
    data=df_counts,
    x="question_count",
    y="move_count",
    hue="pairID",
    hue_order=df_counts["pairID"].unique(),
    jitter=0.2,
    size=10.0,
)
ax.set_xlabel("Questions")
ax.set_ylabel("Moves")

# legend outside the plot area, to the right
ax.legend(title="Participant pair ID", bbox_to_anchor=(1.05, 1), loc="upper left")

ax.set_title("Questions asked vs. moves")

In [None]:
# Pooled linear fit of moves on questions asked (ignores pair/board grouping).
ax = sns.regplot(data=df_counts, x="question_count", y="move_count")
ax.set_xlabel("Questions")
ax.set_ylabel("Moves")
ax.set_title("Questions asked vs. moves")

# Bonuses

In [None]:
# Bonus along the action sequence: colour = pair, line style = board.
ax = sns.lineplot(data=df, x="index", y="bonus", hue="pairID", style="board_id")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
ax.set_xlabel("Action #")
ax.set_ylabel("Bonus ($)")
ax.set_title("Bonus over time")

In [None]:
# End-of-game snapshot: the row with the highest "index" in each (pairID, board_id) game,
# restricted to identifiers and final scores.
FINAL_STAGE_COLUMNS = [
    "pairID",
    "gameID",
    "roundID",
    "board_id",
    "bonus",
    "hits_pct",
    "precision",
    "recall",
    "f1_score",
]
last_event_idx = df.groupby(["pairID", "board_id"])["index"].idxmax()
df_final_stage = df.loc[last_event_idx, FINAL_STAGE_COLUMNS]

# Bonuses above $5 are suspected to be a logging bug: mask them as missing (rows are kept).
df_final_stage.loc[df_final_stage["bonus"] > 5.0, "bonus"] = np.nan

df_final_stage

In [None]:
# Histogram of end-of-game bonus across all games (masked outliers are excluded as NaN).
sns.displot(data=df_final_stage, x="bonus", kind="hist")
plt.xlabel("Bonus ($)")
plt.title("Bonus distribution")

In [None]:
# Histogram of end-of-game completion: hits_pct at each game's last event, i.e. the
# fraction of the board's ship tiles that were hit.
sns.displot(data=df_final_stage, x="hits_pct", kind="hist")
plt.xlabel("Hits (fraction of ship tiles)")
plt.title("Final completion distribution")

In [None]:
# Summary statistics of the final bonus; describe() skips NaN, so masked outliers are excluded.
df_final_stage["bonus"].describe()

In [None]:
# Re-display per-game move/question counts before merging with the final-stage scores.
df_counts

In [None]:
# Attach each game's end-of-game scores (bonus, hits_pct, precision/recall/F1) to its
# move/question counts; games without a final-stage row keep NaN scores.
df_counts_with_bonus = pd.merge(
    df_counts,
    df_final_stage,
    how="left",
    on=["pairID", "board_id"],
)
df_counts_with_bonus

In [None]:
# One jittered dot per game: questions asked (x) vs. final bonus (y), coloured by pair.
ax = sns.stripplot(
    data=df_counts_with_bonus,
    x="question_count",
    y="bonus",
    hue="pairID",
    hue_order=df_counts_with_bonus["pairID"].unique(),
    jitter=0.2,
    size=10.0,
)

# legend outside the plot area, to the right
ax.legend(title="Participant pair ID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# Pooled linear fit of final bonus on the number of questions asked.
sns.regplot(
    data=df_counts_with_bonus,
    x="question_count",
    y="bonus",
)

In [None]:
# Re-display the end-of-game snapshot table.
df_final_stage

In [None]:
# F1 score vs. number of questions asked, one dot per game, coloured by pair.
sns.scatterplot(
    data=df_counts_with_bonus, x="question_count", y="f1_score", hue="pairID"
)

In [None]:
import statsmodels.api as sm
import statsmodels.formula.api as smf
from scipy import stats

# Linear mixed model: end-of-game score ~ number of questions asked, with a random
# intercept per board (group_col).
x_col = "question_count"
y_col = "f1_score"
# y_col = "move_count"
group_col = "board_id"

# Drop games missing the outcome, predictor or group (f1_score can be null), and reset the
# index. This keeps the formula rows, the `groups` array, and the y used for R^2 below aligned.
_df = df_counts_with_bonus.dropna(subset=[x_col, y_col, group_col]).reset_index(drop=True)

model = smf.mixedlm(
    f"{y_col} ~ {x_col}",
    _df,
    groups=_df[group_col],
)
# REML fit (the statsmodels default) for the reported estimates.
result = model.fit()

# Print the summary of the model
print(result.summary())

# R^2 = 1 - RSS / TSS.
# - result.predict() for MixedLM uses the fixed effects only -> marginal-style R^2 (r2).
# - result.fittedvalues also include the predicted random effects -> conditional-style R^2.
y = _df[y_col]
tss = np.sum((y - np.mean(y)) ** 2)
r2 = 1 - np.sum((y - result.predict()) ** 2) / tss
r2_conditional = 1 - np.sum((y - result.fittedvalues) ** 2) / tss

print(f"\nR-squared: {r2:.4f}")
print(f"R-squared (conditional): {r2_conditional:.4f}")

# Likelihood-ratio test for the effect of x_col against an intercept-only model.
# REML log-likelihoods are not comparable between models with different fixed effects,
# so both models are refit by maximum likelihood for the test.
null_model = smf.mixedlm(
    f"{y_col} ~ 1",  # Only intercept
    _df,
    groups=_df[group_col],
)
full_result_ml = model.fit(reml=False)
null_result = null_model.fit(reml=False)

# 2 * (log likelihood full - log likelihood null); clip tiny negatives from optimizer noise.
lr_stat = max(2 * (full_result_ml.llf - null_result.llf), 0.0)

# The full model has one extra fixed-effect parameter (the slope) -> chi-square with 1 df.
# sf() is the upper tail directly, which is more accurate than 1 - cdf() for small p-values.
p_value = stats.chi2.sf(lr_stat, df=1)

print(f"P-value: {p_value:.4f}")

In [None]:
with sns.plotting_context("talk"), sns.axes_style("white"):

    # Create figure and axis
    fig, ax = plt.subplots(figsize=(8, 5))

    # One dot per game, coloured by pair (legend suppressed)
    sns.scatterplot(
        data=_df,
        x=x_col,
        y=y_col,
        hue="pairID",
        alpha=0.8,
        legend=False,
        edgecolor=None,  # Remove borders from dots
        ax=ax,
    )

    sns.despine(ax=ax)

    # A single pooled regression line over all games (not the mixed-model fit)
    sns.regplot(
        data=_df,
        x=x_col,
        y=y_col,
        scatter=False,
        color="#007bff",
        line_kws={"linewidth": 2, "linestyle": "--"},
        ax=ax,
    )

    # Integer ticks from 0 to at least 15 (the previous fixed range), extended when a game
    # has more questions so no point is clipped; half a tick of padding on both sides.
    observed_max = _df[x_col].max()
    x_tick_max = 15 if pd.isna(observed_max) else max(15, int(np.ceil(observed_max)))
    ax.set_xlim(-0.5, x_tick_max + 0.5)
    ax.set_xticks(range(0, x_tick_max + 1))

    # Stats box: left edge at x = 0 (data coordinates), 90% of the way up the y-axis.
    y_min, y_max = ax.get_ylim()
    x_pos = 0.0
    y_pos = y_max - 0.1 * (y_max - y_min)

    r2_text = f"$R^2 = {r2:.3f}$"
    p_value_text = f"$p = {p_value:.3f}$"
    stats_text = f"{r2_text}\n{p_value_text}"
    ax.text(
        x_pos,
        y_pos,
        stats_text,
        fontsize=12,
        bbox=dict(facecolor="white", alpha=0.5),
        horizontalalignment="left",
    )

    # Labels (y keeps the column name set by seaborn)
    ax.set_xlabel("Number of questions asked")

    fig.savefig(
        os.path.join(PATH_EXPORT, f"{y_col}_vs_{x_col}.pdf"),
        bbox_inches="tight",
        dpi=300,
    )

In [None]:
x_col = "question_count"
y_col = "f1_score"

# Average each pair's games first, so every pair contributes exactly one point to the fit.
per_pair_means = (
    df_counts_with_bonus.groupby("pairID")[[x_col, y_col]].mean().reset_index()
)
sns.regplot(data=per_pair_means, x=x_col, y=y_col)

In [None]:
# Pooled linear fit of F1 score on the number of moves made.
sns.regplot(data=df_counts_with_bonus, x="move_count", y="f1_score")

In [None]:
# Games with perfect precision but F1 below 1 — with precision = 1 this means recall < 1.
df_counts_with_bonus.query("precision == 1.0 and f1_score < 1.0")

In [None]:
# NOTE(review): redundant — statsmodels.api is already imported as sm in the mixed-model cell;
# this cell can be removed.
import statsmodels.api as sm

In [None]:
# Games whose f1_score is missing.
df_counts_with_bonus.loc[df_counts_with_bonus["f1_score"].isnull()]

### Export final bonus information

In [None]:
# NOTE(review): disabled export — it reads df_player (a table with gameID and
# participantIdentifier columns), which is not defined anywhere in this notebook.
# Load that table before re-enabling. When enabled, it would sum each participant's
# final bonuses and write them to PATH_BONUS_EXPORT.
# PLAYER_COLUMNS = ["gameID", "participantIdentifier"]
# df_final_stage_export = df_final_stage.merge(df_player[PLAYER_COLUMNS], on="gameID")

# df_final_stage_export = df_final_stage_export.groupby("participantIdentifier")[
#     "bonus"
# ].sum()
# df_final_stage_export.to_csv(PATH_BONUS_EXPORT, header=False, index=False)
# display(df_final_stage_export)

## Timing

In [None]:
# Running total of messageTime within each (pairID, roundID), accumulated in event ("index")
# order and divided by 1000 (presumably ms -> s; assumes messageTime is a per-event duration
# in milliseconds — TODO confirm). Rows with NaN messageTime get NaN here.
events_in_order = df.sort_values("index")
df["cumulativeStageTime"] = (
    events_in_order.groupby(["pairID", "roundID"])["messageTime"].cumsum() / 1000
)
df

In [None]:
with sns.axes_style("whitegrid"), sns.plotting_context("talk"):
    # Cumulative time along the action sequence: colour = pair, line style = board.
    ax = sns.lineplot(
        data=df,
        x="index",
        y="cumulativeStageTime",
        hue="pairID",
        style="board_id",
    )
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# cumulativeStageTime at the highest-"index" event of each (pairID, board_id) game.
last_event_idx_time = df.groupby(["pairID", "board_id"])["index"].idxmax()
df_final_stage_time = df.loc[
    last_event_idx_time, ["pairID", "board_id", "cumulativeStageTime"]
]

df_final_stage_time

In [None]:
# Final cumulative time per board, one bar per pair.
sns.barplot(
    data=df_final_stage_time, y="cumulativeStageTime", x="board_id", hue="pairID"
)

In [None]:
# Rows with no messageTime; they get NaN in cumulativeStageTime. Show every column for this
# one display only, instead of changing the pandas display option for the rest of the notebook.
with pd.option_context("display.max_columns", None):
    display(df[df["messageTime"].isna()])

# Timeline visualization
Shows, for each game, how board completion (hits %) progresses over actions or time, with the questions asked marked along each trajectory.

In [None]:
# Render the observed board state of one example event (row 20 of df). Reads from df:
# the `_df` in scope at this point is the per-game summary built for the mixed model,
# which has no occTiles column.
Board.from_occ_tiles(df["occTiles"].iloc[20])

In [None]:
# Facet grid for board B02, limited to the first 3 pairs in pairID order: completion
# (hits_pct, a 0-1 fraction despite the "Hits %" axis label) over the action index. Each
# question is drawn as a black dot, and its text is annotated on the facet.
_df = df[df["board_id"] == "B02"].sort_values(["pairID", "board_id"])
first_n_pairs = _df["pairID"].unique()[:3]
_df = _df[_df["pairID"].isin(first_n_pairs)]

g = sns.relplot(
    kind="line",
    row="board_id",
    col="pairID",
    # col="board_id",
    # row="pairID",
    aspect=2.0,
    data=_df,
    x="index",
    y="hits_pct",
    hue="pairID",
    linewidth=6,
)

# Plot a marker for each question
# (the lambda ignores all keyword arguments, including the board_id/pairID passed below)
g.map_dataframe(
    lambda data, **kws: sns.scatterplot(
        data=data[data["messageType"] == "question"],
        x="index",
        y="hits_pct",
        s=10,
        marker="o",
        color="black",
        zorder=10,  # Set zorder to be on top
    ),
    board_id="board_id",
    pairID="pairID",
)


# Annotate each facet (keys are (row, col) = (board_id, pairID)) with its question texts.
# Labels are stacked in index order: each is placed at least y_offset above the previous
# label. When a label is lifted above its data point, a dashed grey leader line joins them.
for (board_id, pairID), ax in g.axes_dict.items():
    y_max, y_offset = -np.inf, 0.05
    for _, row in (
        df[
            (df["messageType"] == "question")
            & (df["board_id"] == board_id)
            & (df["pairID"] == pairID)
        ]
        .sort_values("index", ascending=True)
        .iterrows()
    ):
        y = row["hits_pct"]
        y = max(y, y_max + y_offset)
        y_max = y

        if y > row["hits_pct"]:
            ax.plot(
                [row["index"], row["index"]],
                [y, row["hits_pct"]],
                color="gray",
                linestyle="--",
                linewidth=1,
                alpha=0.5,
            )

        ax.text(
            row["index"],
            y,
            row["messageText"],
            horizontalalignment="left",
            size=10,
            color="black",
        )

g.set_axis_labels("Action #", "Hits %")
g.set_titles(col_template="Pair ID: {col_name}", row_template="Board ID: {row_name}")

# Saves the current figure, which is the FacetGrid's figure here.
plt.savefig(
    os.path.join(PATH_EXPORT, f"hits_pct_vs_questions.pdf"),
    bbox_inches="tight",
    dpi=300,
)

In [None]:
# Export board B02 (as loaded by Board.from_trial_id) as a standalone figure.
board_b02_fig = Board.from_trial_id("B02").to_figure()
board_b02_fig.savefig(
    os.path.join(PATH_EXPORT, "board_B02.pdf"),
    dpi=300,
    bbox_inches="tight",
)

In [None]:
# Prototype: first 3 pairs with timelines, question markers, and board thumbnails.
# One panel per pair. Each panel shows one completion curve (hits_pct vs.
# cumulativeStageTime) per round, the question events, a thumbnail of each round's final
# board, and a thumbnail of the pair's earliest board.
import io, base64
import warnings
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.offsetbox import OffsetImage, AnnotationBbox

from battleship.board import Board

# Work on a copy so the fallback columns below never leak into df.
_df = df.copy()

# hits_pct fallback if needed (df normally already has it from the Hits section)
if "hits_pct" not in _df.columns:

    def compute_hits_pct(row):
        """Cells > Board.water on the observed board over those on the true board; NaN on failure."""
        try:
            occ_board = Board.from_occ_tiles(row["occTiles"])
            hits = int(np.sum(occ_board.board > Board.water))
            # prefer trueTiles if available; else compute by board_id
            if "trueTiles" in _df.columns and isinstance(
                row.get("trueTiles"), (list, np.ndarray, str)
            ):
                true_board = Board.from_occ_tiles(row["trueTiles"])
            else:
                true_board = Board.from_trial_id(row["board_id"])
            total_ship = int(np.sum(true_board.board > Board.water))
            return hits / total_ship if total_ship > 0 else 0.0
        except Exception:
            return np.nan

    _df["hits_pct"] = _df.apply(compute_hits_pct, axis=1)

# cumulativeStageTime fallback if needed (seconds)
if "cumulativeStageTime" not in _df.columns and "messageTime" in _df.columns:
    _df = _df.sort_values("index")
    _df["cumulativeStageTime"] = (
        _df.groupby(["pairID", "roundID"])["messageTime"].cumsum().div(1000)
    )

# Pick the first 3 pairs in sorted pairID order
pairs = sorted(_df["pairID"].dropna().unique())[:3]
_df = _df[_df["pairID"].isin(pairs)].copy()


def board_to_img_array(occ_tiles):
    """Render an occTiles board to an image array (PNG produced by Board.to_base64)."""
    b64 = Board.from_occ_tiles(occ_tiles).to_base64()
    return plt.imread(io.BytesIO(base64.b64decode(b64)), format="png")


fig, axes = plt.subplots(
    nrows=len(pairs), ncols=1, figsize=(12, 9), sharex=False, constrained_layout=True
)
if len(pairs) == 1:
    axes = [axes]

for ax, pair in zip(axes, pairs):
    pair_df = _df[_df["pairID"] == pair].copy()
    pair_df = pair_df.sort_values(["roundID", "index"])

    # Plot per-round completion timelines and mark questions
    for round_id, s in pair_df.groupby("roundID"):
        s = s.sort_values("index")
        # line
        ax.plot(
            s["cumulativeStageTime"],
            s["hits_pct"],
            lw=1.6,
            alpha=0.85,
            label=f"{round_id}",
        )
        # question markers
        q = s[s["messageType"] == "question"]
        if not q.empty:
            ax.scatter(
                q["cumulativeStageTime"],
                q["hits_pct"],
                s=28,
                c="tab:orange",
                marker="o",
                edgecolor="white",
                linewidth=0.7,
                alpha=0.95,
                zorder=3,
            )

        # Thumbnail: final state for this round (best effort: skipped with a warning on failure)
        try:
            snap = s.iloc[-1]
            img_arr = board_to_img_array(snap["occTiles"])
            img = OffsetImage(img_arr, zoom=0.18)
            ab = AnnotationBbox(
                img,
                (float(snap["cumulativeStageTime"]), float(snap["hits_pct"])),
                frameon=True,
                box_alignment=(0.0, 1.0),
                pad=0.2,
                zorder=4,
            )
            ax.add_artist(ab)
        except Exception as exc:
            warnings.warn(f"Skipping final thumbnail for pair {pair}, round {round_id}: {exc}")

    # Thumbnail: start-of-pair snapshot (earliest time across all rounds)
    try:
        first_snap = pair_df.sort_values("cumulativeStageTime").iloc[0]
        img_arr = board_to_img_array(first_snap["occTiles"])
        img = OffsetImage(img_arr, zoom=0.22)
        ab = AnnotationBbox(
            img,
            (float(first_snap["cumulativeStageTime"]), float(first_snap["hits_pct"])),
            frameon=True,
            box_alignment=(1.0, 0.0),
            pad=0.2,
            zorder=4,
        )
        ax.add_artist(ab)
    except Exception as exc:
        warnings.warn(f"Skipping start thumbnail for pair {pair}: {exc}")

    ax.set_title(f"{pair}", loc="left")
    ax.set_ylabel("Completion")
    ax.set_ylim(-0.02, 1.02)
    ax.grid(axis="x", alpha=0.25)
    # `ncol` (not `ncols`) works on all Matplotlib versions and matches the other cells.
    ax.legend(title="roundID", fontsize=8, ncol=3, loc="upper right")

axes[-1].set_xlabel("Time (s)")
plt.show()

In [None]:
# Single-board prototype: B02
# - One main axis with completion trajectories (hits_pct) for multiple pairIDs
# - One top "image lane" with board thumbnails aligned by time, staggered vertically
# - Question markers on the main axis

import io, base64
import warnings
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from matplotlib.offsetbox import OffsetImage, AnnotationBbox
from battleship.board import Board

# Work on a copy so the fallback columns below never leak into df.
_df = df.copy()

# Ensure hits_pct exists (df normally already has it from the Hits section)
if "hits_pct" not in _df.columns:

    def compute_hits_pct(row):
        """Cells > Board.water on the observed board over those on the true board; NaN on failure."""
        try:
            occ_b = Board.from_occ_tiles(row["occTiles"])
            hits = int(np.sum(occ_b.board > Board.water))
            # prefer trueTiles if available; else from board_id
            if "trueTiles" in _df.columns and isinstance(
                row.get("trueTiles"), (list, np.ndarray, str)
            ):
                true_b = Board.from_occ_tiles(row["trueTiles"])
            else:
                true_b = Board.from_trial_id(row["board_id"])
            total = int(np.sum(true_b.board > Board.water))
            return hits / total if total > 0 else 0.0
        except Exception:
            return np.nan

    _df["hits_pct"] = _df.apply(compute_hits_pct, axis=1)

# Ensure cumulativeStageTime (seconds) exists
if "cumulativeStageTime" not in _df.columns and "messageTime" in _df.columns:
    _df = _df.sort_values("index")
    _df["cumulativeStageTime"] = (
        _df.groupby(["pairID", "roundID"])["messageTime"].cumsum().div(1000)
    )

# Focus on a single board
board_focus = "B02"
bdf = _df[_df["board_id"] == board_focus].copy()

# Keep the first MAX_PAIRS pairs (sorted by pairID) that played this board; None keeps all.
MAX_PAIRS = None
pairs = sorted(bdf["pairID"].dropna().unique())[:MAX_PAIRS]
bdf = bdf[bdf["pairID"].isin(pairs)]


# Helper for images
def board_img_from_occ(occ_tiles, zoom=0.18):
    """Render an occTiles board (PNG produced by Board.to_base64) as a matplotlib OffsetImage."""
    b64 = Board.from_occ_tiles(occ_tiles).to_base64()
    img = plt.imread(io.BytesIO(base64.b64decode(b64)), format="png")
    return OffsetImage(img, zoom=zoom)


# Build figure with a thin top lane for images and a main lane for lines
fig = plt.figure(figsize=(12, 6))
gs = GridSpec(nrows=2, ncols=1, height_ratios=[1, 4], hspace=0.05, figure=fig)
ax_img = fig.add_subplot(gs[0, 0])
ax = fig.add_subplot(gs[1, 0], sharex=ax_img)

# Color per pair
colors = plt.cm.tab10.colors
pair_to_color = {p: colors[i % len(colors)] for i, p in enumerate(pairs)}

# Plot lines and questions in the main axis
for pair in pairs:
    g = bdf[bdf["pairID"] == pair].sort_values(["roundID", "index"])
    # If there are multiple rounds for this board/pair, plot each. Only the first round's
    # line is labelled, so each pair appears once in the legend.
    for round_num, (_, s) in enumerate(g.groupby("roundID")):
        s = s.sort_values("index")
        ax.plot(
            s["cumulativeStageTime"],
            s["hits_pct"],
            lw=2.0,
            color=pair_to_color[pair],
            alpha=0.9,
            label=f"{pair}" if round_num == 0 else None,
        )
        q = s[s["messageType"] == "question"]
        if not q.empty:
            ax.scatter(
                q["cumulativeStageTime"],
                q["hits_pct"],
                s=30,
                c=[pair_to_color[pair]],
                marker="o",
                edgecolor="white",
                linewidth=0.8,
                alpha=0.95,
                zorder=3,
            )

# Determine x-limits from data
xmin = float(bdf["cumulativeStageTime"].min()) if not bdf.empty else 0.0
xmax = float(bdf["cumulativeStageTime"].max()) if not bdf.empty else 1.0
if xmin == xmax:
    xmax = xmin + 1.0
ax.set_xlim(xmin, xmax)

# Place one thumbnail per pair at the pair's final snapshot time, in the top lane, with
# vertical staggering to reduce (not eliminate) overlap.
y_slots = np.linspace(0.2, 0.8, max(3, len(pairs)))  # keep spacing generous
for i, pair in enumerate(pairs):
    g = bdf[bdf["pairID"] == pair].sort_values(["roundID", "index"])
    # final snapshot across all rounds for this pair on this board
    snap = g.iloc[-1]
    try:
        img = board_img_from_occ(snap["occTiles"], zoom=0.18)
        y_pos = y_slots[i]
        ab = AnnotationBbox(
            img,
            (float(snap["cumulativeStageTime"]), y_pos),
            frameon=True,
            box_alignment=(0.0, 0.5),
            pad=0.2,
            zorder=4,
        )
        ax_img.add_artist(ab)
        # small label next to thumbnail
        ax_img.text(
            float(snap["cumulativeStageTime"]) + 2,  # slight offset, in seconds
            y_pos,
            pair,
            va="center",
            ha="left",
            fontsize=10,
            color=pair_to_color[pair],
            weight="bold",
        )
    except Exception as exc:
        # Best effort: skip a thumbnail that cannot be rendered, but report it.
        warnings.warn(f"Skipping thumbnail for pair {pair}: {exc}")

# Style the axes
ax.set_ylim(-0.02, 1.02)
ax.set_ylabel("Completion (hits / total)")
ax.set_xlabel("Time (s)")
ax.grid(axis="x", alpha=0.3)
ax.set_title(f"Board {board_focus}: completion over time by pairID")
# ncol must be >= 1 even when no pair played this board.
ax.legend(title="pairID", ncol=max(1, min(3, len(pairs))), loc="lower right", frameon=True)

# Style the image lane
ax_img.set_ylim(0, 1)
ax_img.set_yticks([])
ax_img.set_ylabel("Board\nthumbnails", rotation=0, labelpad=40, va="center")
ax_img.spines["top"].set_visible(False)
ax_img.spines["right"].set_visible(False)
ax_img.spines["left"].set_visible(False)
ax_img.spines["bottom"].set_visible(False)
ax_img.set_xticks([])

plt.show()

In [None]:
# IDEAS:
# - Show F1 score in the legend
# - Include sampled question text in the timeline
# - Show markers for hits and misses in the timeline


def plot_board_progress_with_thumbnails(
    df,
    board_id="B02",
    n_pairs=None,
    thumb_zoom=0.08,  # half-size thumbnails
    x_stretch=1.0,  # >1.0 stretches the x-axis horizontally
):
    """Plot per-pair completion over time on one board, with a lane of board thumbnails.

    The main axis draws one line of hits_pct vs. cumulativeStageTime per (pair, round).
    Each question event is drawn as a hollow circle containing "?". A thin lane above
    holds one thumbnail per pair: its last recorded board state, placed at that event's
    time. The figure is shown with plt.show() before being returned.

    Parameters
    ----------
    df : pandas.DataFrame
        Collaborative event-level dataframe with columns 'pairID', 'roundID',
        'board_id', 'index', 'messageType', 'occTiles', optionally 'trueTiles', and
        'messageTime'. 'hits_pct' and 'cumulativeStageTime' are derived when missing.
        The input frame is not modified.
    board_id : str
        Board to plot, e.g. 'B02'.
    n_pairs : int or None
        If not None, keep only the first n pairIDs (in sorted order) for this board.
    thumb_zoom : float
        Zoom factor applied to each thumbnail image.
    x_stretch : float
        Multiplies the base figure width of 12 inches (e.g. 1.5 -> 50% wider).

    Returns
    -------
    fig, (ax_img, ax)
        The figure, the thumbnail-lane axis, and the main axis.
    """
    # Imported here so the function does not depend on earlier prototype cells having run.
    import base64
    import io

    from matplotlib.gridspec import GridSpec
    from matplotlib.offsetbox import AnnotationBbox, OffsetImage

    _df = df.copy()

    # Ensure hits_pct
    if "hits_pct" not in _df.columns:

        def compute_hits_pct(row):
            """Cells > Board.water on the observed board over those on the true board; NaN on failure."""
            try:
                occ_b = Board.from_occ_tiles(row["occTiles"])
                hits = int(np.sum(occ_b.board > Board.water))
                true_b = (
                    Board.from_occ_tiles(row["trueTiles"])
                    if "trueTiles" in _df.columns
                    and isinstance(row.get("trueTiles"), (list, np.ndarray, str))
                    else Board.from_trial_id(row["board_id"])
                )
                total = int(np.sum(true_b.board > Board.water))
                return hits / total if total > 0 else 0.0
            except Exception:
                return np.nan

        _df["hits_pct"] = _df.apply(compute_hits_pct, axis=1)

    # Ensure cumulativeStageTime (seconds)
    if "cumulativeStageTime" not in _df.columns and "messageTime" in _df.columns:
        _df = _df.sort_values("index")
        _df["cumulativeStageTime"] = (
            _df.groupby(["pairID", "roundID"])["messageTime"].cumsum().div(1000)
        )

    # Focus on board; sorting makes pairs come out in pairID order and rounds in order.
    bdf = (
        _df[_df["board_id"] == board_id]
        .copy()
        .sort_values(["pairID", "roundID", "index"])
    )
    pairs = list(bdf["pairID"].dropna().unique())
    if n_pairs is not None:
        pairs = pairs[:n_pairs]
        bdf = bdf[bdf["pairID"].isin(pairs)]

    def board_img_from_occ(occ_tiles, zoom=thumb_zoom):
        """Render an occTiles board (PNG produced by Board.to_base64) as an OffsetImage."""
        b64 = Board.from_occ_tiles(occ_tiles).to_base64()
        img = plt.imread(io.BytesIO(base64.b64decode(b64)), format="png")
        return OffsetImage(img, zoom=zoom)

    # Figure: top lane for thumbnails (single y-level), bottom for lines
    base_w, base_h = 12.0, 6.0
    fig = plt.figure(figsize=(base_w * float(x_stretch), base_h))
    gs = GridSpec(nrows=2, ncols=1, height_ratios=[1, 4], hspace=0.06, figure=fig)
    ax_img = fig.add_subplot(gs[0, 0])
    ax = fig.add_subplot(gs[1, 0], sharex=ax_img)

    sns.despine(fig=fig, ax=ax)

    # Colors. plt.get_cmap replaces plt.cm.get_cmap, which was removed in Matplotlib 3.9.
    cmap = plt.get_cmap("tab20", max(20, len(pairs)))
    pair_to_color = {p: cmap(i % cmap.N) for i, p in enumerate(pairs)}

    # Lines + question markers (x = time); only each pair's first round is labelled so
    # every pair appears once in the legend.
    for pair in pairs:
        g = bdf[bdf["pairID"] == pair]
        for round_num, (_, s) in enumerate(g.groupby("roundID")):
            s = s.sort_values("index")
            ax.plot(
                s["cumulativeStageTime"],
                s["hits_pct"],
                lw=1.8,
                color=pair_to_color[pair],
                alpha=0.9,
                label=pair if round_num == 0 else None,
            )
            q = s[s["messageType"] == "question"]
            if not q.empty:
                # hollow circle
                ax.scatter(
                    q["cumulativeStageTime"],
                    q["hits_pct"],
                    s=100,
                    facecolors="white",
                    edgecolors=pair_to_color[pair],
                    marker="o",
                    linewidth=1.2,
                    alpha=1.0,
                    zorder=4,
                )
                # question-mark text centered inside each circle
                for xt, yt in zip(q["cumulativeStageTime"], q["hits_pct"]):
                    ax.text(
                        xt,
                        yt,
                        "?",
                        ha="center",
                        va="center",
                        fontsize=6,
                        weight="bold",
                        color=pair_to_color[pair],
                        zorder=5,
                        clip_on=True,
                    )

    # X-limits from data
    xmin = float(bdf["cumulativeStageTime"].min()) if not bdf.empty else 0.0
    xmax = float(bdf["cumulativeStageTime"].max()) if not bdf.empty else 1.0
    if xmin == xmax:
        xmax = xmin + 1.0
    ax.set_xlim(xmin, xmax)

    # Thumbnails: final snapshot per pair, all at the same y position
    y_pos = 0.5
    for pair in pairs:
        g = bdf[bdf["pairID"] == pair]
        snap = g.iloc[-1]
        img = board_img_from_occ(snap["occTiles"])
        x_pos = float(snap["cumulativeStageTime"])
        ab = AnnotationBbox(
            img,
            (x_pos, y_pos),
            frameon=True,
            bboxprops=dict(linewidth=0.5, edgecolor="gray"),
            box_alignment=(0.0, 0.5),
            pad=0.12,
            zorder=4,
        )
        ax_img.add_artist(ab)

    # Styling
    ax.set_ylim(-0.02, 1.02)
    ax.set_ylabel("Completion (hits / total)")
    ax.set_xlabel("Time (s)")
    ax.grid(axis="x", alpha=0.3)
    ax.legend(
        loc="upper left",
        fontsize=10,
    )

    ax_img.set_ylim(0, 1)
    ax_img.set_yticks([])
    for side in ["top", "right", "left", "bottom"]:
        ax_img.spines[side].set_visible(False)

    ax.tick_params(axis="x", which="both", labelbottom=True, length=4)

    plt.show()
    return fig, (ax_img, ax)


# Parameters for the exported per-board timeline figure.
BOARD_ID = "B02"
N_PAIRS = None  # None = every pair that played BOARD_ID
THUMB_ZOOM = 0.10
X_STRETCH = 1.5

fig, _ = plot_board_progress_with_thumbnails(
    df,
    board_id=BOARD_ID,
    n_pairs=N_PAIRS,
    thumb_zoom=THUMB_ZOOM,
    x_stretch=X_STRETCH,
)

fig.savefig(
    os.path.join(PATH_EXPORT, "board_timeline_" + BOARD_ID + ".pdf"),
    dpi=300,
    bbox_inches="tight",
)

# Gold label analysis

In [None]:
from analysis import parse_answer, GOLD_ANSWER_LABEL, GOLD_CATEGORY_LABELS

# Parsed gold answer stored in its own column (expected to hold True / False).
GOLD_ANSWER_LABEL_BOOL = "gold_answer_bool"
df[GOLD_ANSWER_LABEL_BOOL] = df[GOLD_ANSWER_LABEL].map(parse_answer)

# Total number of question events; denominator for the category percentages below.
N_QUESTIONS = int(df["messageType"].eq("question").sum())
print(N_QUESTIONS)

In [None]:
# Share of questions in each gold category: each category column is summed over all rows
# of df, relabelled with its display name, and divided by the number of question events.
category_cols = list(GOLD_CATEGORY_LABELS.keys())
counts = df[category_cols].sum().rename_axis("category").reset_index(name="count")
counts["category"] = counts["category"].map(GOLD_CATEGORY_LABELS)

# Convert to percentage of total questions
counts["percentage"] = counts["count"] / N_QUESTIONS * 100

# Plot percentages and export next to the other figures
plt.figure(figsize=(8, 5))
sns.barplot(data=counts, x="category", y="percentage", palette="Set2")
plt.xticks(rotation=45)
plt.xlabel("Gold Category")
plt.ylabel("Percentage of Questions (%)")
plt.title("Percentage of Each Gold Category")

export_path = os.path.join(PATH_EXPORT, "gold_category_percentage.pdf")
plt.savefig(export_path, dpi=300, bbox_inches="tight")

In [None]:
# Pairwise (Pearson) correlation between the gold category columns, relabelled with their
# display names. NOTE(review): an earlier comment said this includes gold_answer; that holds
# only if GOLD_CATEGORY_LABELS contains it — confirm in analysis.py.
df_categories = df[list(GOLD_CATEGORY_LABELS)].rename(columns=GOLD_CATEGORY_LABELS)
corr = df_categories.corr()

# Plot heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(corr, annot=True, fmt=".2f", cmap="coolwarm", vmin=-1, vmax=1)
plt.title("Correlation Matrix of Gold Categories")
plt.xticks(rotation=45)
plt.yticks(rotation=0)

# Export alongside the other figures. OUTPUT_DIR was never defined (NameError);
# PATH_EXPORT is the notebook's export directory.
plt.savefig(
    os.path.join(PATH_EXPORT, "gold_categories_correlation_matrix.pdf"),
    bbox_inches="tight",
    dpi=300,
)

In [None]:
from sklearn.metrics import cohen_kappa_score

# Each gold category has one column per annotator: <category>_annotator_1 / _annotator_2.
annotator1_cols = [f"{key}_annotator_1" for key in GOLD_CATEGORY_LABELS.keys()]
annotator2_cols = [f"{key}_annotator_2" for key in GOLD_CATEGORY_LABELS.keys()]

# Per category: Cohen's kappa and raw agreement, over rows where both annotators gave a label.
# "percent_agreement" is stored as a fraction in [0, 1].
results = []
for col_first, col_second in zip(annotator1_cols, annotator2_cols):
    category_key = col_first.replace("_annotator_1", "")
    both_labelled = df[[col_first, col_second]].dropna()
    labels_first = both_labelled[col_first].values.astype(int)
    labels_second = both_labelled[col_second].values.astype(int)
    results.append(
        {
            "category": GOLD_CATEGORY_LABELS[category_key],
            "cohen_kappa": cohen_kappa_score(labels_first, labels_second),
            "percent_agreement": (labels_first == labels_second).mean(),
        }
    )

agreement_df = pd.DataFrame(results).set_index("category")
display(agreement_df)

In [None]:
# Grouped bars of Cohen's kappa and percent agreement (as a fraction) for each gold category.
# agreement_df is indexed by "category", so reset_index() restores it as a column.
ag = agreement_df.reset_index()
ag_melt = pd.melt(
    ag,
    id_vars="category",
    value_vars=["cohen_kappa", "percent_agreement"],
    var_name="metric",
    value_name="value",
)

plt.figure(figsize=(8, 5))
sns.barplot(data=ag_melt, x="category", y="value", hue="metric", palette="Set1")
plt.xticks(rotation=45)
plt.ylabel("Agreement Value")
plt.title("Inter-Annotator Agreement by Gold Category")
plt.legend(title="Metric")
plt.tight_layout()