# Analysis pipeline for Prolific data

In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [107]:
import os

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from battleship.board import Board
from analysis import load_dataset

In [108]:
# Render inline figures in the 'retina' (high-DPI) format.
%config InlineBackend.figure_format = 'retina'

# set seaborn color palette
sns.set_palette("Set2")

# set seaborn style
sns.set_style("whitegrid")
# Use seaborn's "talk" context preset for element/font scaling.
sns.set_context("talk")

In [109]:
# Experiment to analyse. Earlier runs are kept here so they can be re-selected
# by swapping which line is uncommented.
# EXPERIMENT_NAME = "battleship-2024-10-03-19-28-28"
# EXPERIMENT_NAME = "battleship-pilot-v2"
EXPERIMENT_NAME = "battleship-final-data"

# Input data is read from data/<experiment>; derived files are written to its
# "export" subfolder, which is created up front if it does not exist yet.
PATH_DATA = os.path.join("data", EXPERIMENT_NAME)
PATH_EXPORT = os.path.join(PATH_DATA, "export")
os.makedirs(PATH_EXPORT, exist_ok=True)

# Destination of the per-participant bonus CSV.
PATH_BONUS_EXPORT = os.path.join(PATH_EXPORT, EXPERIMENT_NAME + "-bonus.csv")

In [None]:
# Load the selected experiment into `df`, the table every cell below reads from.
# NOTE(review): the meaning of `use_gold` and `drop_incomplete` is defined in
# `analysis.load_dataset`, which is not visible here — confirm that keeping
# incomplete games (drop_incomplete=False) is what the summaries below expect.
df = load_dataset(PATH_DATA, use_gold=True, drop_incomplete=False)

In [None]:
# Sanity check: list the columns available in the loaded dataset.
df.columns

In [None]:
# Row count per message type (later cells filter on "move" and "question").
df["messageType"].value_counts()

In [None]:
# Row count per distinct value of the "name" column.
df["name"].value_counts()

In [None]:
# Row count per board.
df["board_id"].value_counts()

In [None]:
# Frequency of each distinct message text.
df["messageText"].value_counts()

In [None]:
# Number of distinct rounds recorded for each board.
df.groupby(["board_id"])["roundID"].nunique()

# Visualizations

## Hits

In [None]:
# Hits against action index: one colour per pair, one line style per board.
ax = sns.lineplot(x="index", y="hits", hue="pairID", style="board_id", data=df)
# Anchor the legend just outside the right edge of the axes.
ax.legend(title="pairID", loc="upper left", bbox_to_anchor=(1.05, 1))

In [None]:
# Same layout as the raw-hits plot, but with the hit percentage on the y axis.
ax = sns.lineplot(
    data=df,
    x="index",
    y="hits_pct",
    hue="pairID",
    style="board_id",
)
# Anchor the legend just outside the right edge of the axes.
ax.legend(title="pairID", loc="upper left", bbox_to_anchor=(1.05, 1))

In [None]:
# Hits per pair without splitting by board, so rows from all of a pair's boards
# feed the same line. Add style="board_id" to split by board again.
ax = sns.lineplot(data=df, x="index", y="hits", hue="pairID")
# Anchor the legend just outside the right edge of the axes.
ax.legend(title="pairID", loc="upper left", bbox_to_anchor=(1.05, 1))

## Number of moves to win

In [None]:
# Number of "move" rows for every (pair, board) combination.
df_move_counts = (
    df.loc[df["messageType"].eq("move")]
    .groupby(["pairID", "board_id"])
    .size()
    .rename("move_count")
    .to_frame()
)
df_move_counts

In [None]:
# Number of "question" rows for every (pair, board) combination.
# Combinations with no question rows do not appear in this table at all.
df_question_counts = (
    df.loc[df["messageType"].eq("question")]
    .groupby(["pairID", "board_id"])
    .size()
    .to_frame("question_count")
)
df_question_counts

In [None]:
# Attach question counts to move counts. The join keeps every (pair, board)
# that has moves; where no questions were recorded the count arrives as NaN,
# so it is filled with 0 and cast back to int before flattening the index.
df_counts = (
    df_move_counts.join(df_question_counts)
    .fillna(0)
    .astype({"question_count": int})
    .sort_values(["pairID", "board_id"])
    .reset_index()
)
df_counts

In [None]:
# Distribution of moves per board, one box per pair.
with sns.plotting_context("talk"), sns.axes_style("whitegrid"):
    ax = sns.boxplot(
        y="move_count",
        hue="pairID",
        hue_order=df_counts["pairID"].unique(),
        data=df_counts,
    )
    ax.set_ylabel("Moves per board")

    # Anchor the legend just outside the right edge of the axes.
    ax.legend(title="pairID", loc="upper left", bbox_to_anchor=(1.05, 1))

In [None]:
# Moves against questions asked, one point per (pair, board), coloured by pair.
# The jitter spreads out points that share the same question count.
sns.stripplot(
    data=df_counts,
    x="question_count",
    y="move_count",
    hue="pairID",
    hue_order=df_counts["pairID"].unique(),
    size=10.0,
    jitter=0.2,
)

plt.xlabel("Questions")
plt.ylabel("Moves")

# move legend outside of plot
plt.legend(title="Participant pair ID", bbox_to_anchor=(1.05, 1), loc="upper left")

# Plain string: the title has no placeholders, so the f-prefix was redundant.
plt.title("Questions asked vs. moves")

In [None]:
# Scatter of moves against questions asked with a fitted regression line,
# pooled over all pairs and boards.
sns.regplot(
    data=df_counts,
    x="question_count",
    y="move_count",
)

plt.xlabel("Questions")
plt.ylabel("Moves")

# Plain string: the title has no placeholders, so the f-prefix was redundant.
plt.title("Questions asked vs. moves")

# Bonuses

In [None]:
# Bonus against action index: one colour per pair, one line style per board.
sns.lineplot(
    data=df,
    x="index",
    y="bonus",
    hue="pairID",
    style="board_id",
)

# Anchor the legend just outside the right edge of the axes.
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.xlabel("Action #")
plt.ylabel("Bonus ($)")
# Plain string: the title has no placeholders, so the f-prefix was redundant.
plt.title("Bonus over time")

In [None]:
# For every (pair, board), keep only the row with the largest "index" value
# and the summary columns needed downstream.
df_final_stage = df.loc[
    df.groupby(["pairID", "board_id"])["index"].idxmax(),
    [
        "pairID",
        "gameID",
        "roundID",
        "board_id",
        "bonus",
        "hits_pct",
        "precision",
        "recall",
    ],
]

# drop outliers - probably a bug
# (rows with a bonus of 5.0 or more, or a missing bonus, are removed here)
df_final_stage = df_final_stage.loc[df_final_stage["bonus"].lt(5.0)]

df_final_stage

In [None]:
# Histogram of the final bonus over all (pair, board) rows.
# Alternative view, a per-pair density:
#   sns.displot(data=df_final_stage, x="bonus", hue="pairID", kind="kde", fill=True)
# To split the histogram by pair instead, add hue="pairID", multiple="stack".
sns.displot(data=df_final_stage, x="bonus", kind="hist")
plt.xlabel("Bonus ($)")
plt.title("Bonus distribution")

In [None]:
# Histogram of the final hit percentage over all (pair, board) rows.
# To split the histogram by pair, add hue="pairID", multiple="stack".
sns.displot(data=df_final_stage, x="hits_pct", kind="hist")

In [None]:
# Summary statistics of the final bonus (after the outlier filter above).
df_final_stage["bonus"].describe()

In [None]:
# Add the final-stage summary to the per-game counts. The left join keeps every
# row of df_counts; games missing from df_final_stage get NaN summary columns.
df_counts_with_bonus = pd.merge(
    df_counts,
    df_final_stage,
    how="left",
    on=["pairID", "board_id"],
)
df_counts_with_bonus

In [None]:
# Final bonus against questions asked, one point per (pair, board), coloured by pair.
ax = sns.stripplot(
    x="question_count",
    y="bonus",
    hue="pairID",
    hue_order=df_counts_with_bonus["pairID"].unique(),
    size=10.0,
    jitter=0.2,
    data=df_counts_with_bonus,
)

# Anchor the legend just outside the right edge of the axes.
ax.legend(title="Participant pair ID", loc="upper left", bbox_to_anchor=(1.05, 1))

In [None]:
# Regression of final bonus on the number of questions asked, pooled over games.
sns.regplot(x="question_count", y="bonus", data=df_counts_with_bonus)

In [None]:
# Re-display the final-stage table for reference.
df_final_stage

In [None]:
# Regression of final precision on the number of questions asked.
sns.regplot(
    x="question_count",
    y="precision",
    data=df_counts_with_bonus,
)

In [137]:
import statsmodels.api as sm

In [None]:
# Re-display the merged counts/bonus table that the model below is fitted on.
df_counts_with_bonus

In [None]:
import statsmodels.api as sm
import statsmodels.formula.api as smf

# Work on a clean 0..n-1 index to avoid multi-index issues, and add a constant
# "group" column (only needed by the crossed-effects variant sketched below).
df_counts_with_bonus = df_counts_with_bonus.reset_index(drop=True)
df_counts_with_bonus["group"] = 1

# Alternative specification kept for reference: a single group with variance
# components for board and pair.
# vcf = {"board_id": "0 + C(board_id)", "pairID": "0 + C(pairID)"}  # formula
# model = sm.MixedLM.from_formula(
#     "move_count ~ question_count",
#     groups="group",
#     vc_formula=vcf,
#     re_formula="~board_id",
#     data=df_counts_with_bonus,
# )

# Linear mixed-effects model: moves explained by questions asked, with the
# observations grouped by board.
model = smf.mixedlm(
    formula="move_count ~ question_count",
    data=df_counts_with_bonus,
    groups=df_counts_with_bonus["board_id"],
)
result = model.fit()

# Report the fitted model.
print(result.summary())

## Export final bonus information

In [None]:
# Build the per-participant bonus file: map each game's final bonus to its
# participants, total the bonus per participant, and write "<id>,<amount>" rows.
#
# NOTE(review): `df_player` is not defined in any visible cell of this notebook.
# It must be loaded (one row per participant, with "gameID" and
# "participantIdentifier" columns) before this cell can run — confirm its source.
PLAYER_COLUMNS = ["gameID", "participantIdentifier"]
df_final_stage_export = df_final_stage.merge(df_player[PLAYER_COLUMNS], on="gameID")

# After this groupby the result is a Series indexed by participantIdentifier.
df_final_stage_export = df_final_stage_export.groupby("participantIdentifier")[
    "bonus"
].sum()

# The participant ID lives in the index, so it has to be written out
# (index=True); with index=False the file contained only the amounts and could
# not be matched back to participants. The header stays off so the file is
# nothing but "<participantIdentifier>,<bonus>" lines.
df_final_stage_export.to_csv(PATH_BONUS_EXPORT, header=False, index=True)
display(df_final_stage_export)

## Timing

In [None]:
# Running total of messageTime within each (pair, round), accumulated in
# "index" order and divided by 1000 (presumably ms -> s; confirm the unit).
# The assignment aligns on df's row labels, so the original row order is kept.
df["cumulativeStageTime"] = (
    df.sort_values("index")
    .groupby(["pairID", "roundID"])["messageTime"]
    .cumsum()
    / 1000
)
df

In [None]:
# Cumulative time against action index: one colour per pair, one line style per board.
with sns.plotting_context("talk"), sns.axes_style("whitegrid"):
    ax = sns.lineplot(
        x="index",
        y="cumulativeStageTime",
        hue="pairID",
        style="board_id",
        data=df,
    )

    # Anchor the legend just outside the right edge of the axes.
    ax.legend(loc="upper left", bbox_to_anchor=(1.05, 1))

In [None]:
# Total time per game: for every (pair, board), take the cumulative time on the
# row with the largest "index" value.
df_final_stage_time = df.loc[
    df.groupby(["pairID", "board_id"])["index"].idxmax()
][["pairID", "board_id", "cumulativeStageTime"]]

df_final_stage_time

In [None]:
# Final cumulative time per board, with one bar per pair.
sns.barplot(
    x="board_id",
    y="cumulativeStageTime",
    hue="pairID",
    data=df_final_stage_time,
)

In [None]:
# Show every column when displaying frames, then list the rows whose
# messageTime is missing.
pd.set_option("display.max_columns", None)
df.loc[df["messageTime"].isna()]

# Timeline visualization
Shows the timeline of moves and questions for each game.

In [None]:
# One facet per (pair, board): hit percentage against action index, drawn as a
# thick line. Rows are pairs, columns are boards.
g = sns.relplot(
    data=df.sort_values(["pairID", "board_id"]),
    kind="line",
    x="index",
    y="hits_pct",
    hue="pairID",
    row="pairID",
    col="board_id",
    aspect=2.0,
    linewidth=6,
)


def _mark_questions(data, **kws):
    """Draw a small black dot at each question row of one facet's data.

    The extra keyword arguments forwarded by FacetGrid.map_dataframe are
    accepted but not used.
    """
    return sns.scatterplot(
        data=data[data["messageType"] == "question"],
        x="index",
        y="hits_pct",
        s=10,
        marker="o",
        color="black",
        zorder=10,  # keep the markers above the line
    )


# Plot a marker for each question
g.map_dataframe(_mark_questions, board_id="board_id", pairID="pairID")


# Write each question's text next to its marker. Within a facet the labels are
# processed in "index" order and every label is placed at least `y_offset`
# above the previous one; a label that had to be lifted off its data point is
# tied back to it with a faint dashed line.
y_offset = 0.05
for (pairID, board_id), ax in g.axes_dict.items():
    facet_questions = df[
        (df["messageType"] == "question")
        & (df["board_id"] == board_id)
        & (df["pairID"] == pairID)
    ].sort_values("index", ascending=True)

    y_max = -np.inf
    for _, row in facet_questions.iterrows():
        x_pos = row["index"]
        y_data = row["hits_pct"]

        # Label height: the data point itself, unless that would sit less than
        # `y_offset` above the previous label.
        y = max(y_data, y_max + y_offset)
        y_max = y

        if y > y_data:
            ax.plot(
                [x_pos, x_pos],
                [y, y_data],
                color="gray",
                linestyle="--",
                linewidth=1,
                alpha=0.5,
            )

        ax.text(
            x_pos,
            y,
            row["messageText"],
            horizontalalignment="left",
            size=10,
            color="black",
        )