# Analysis pipeline for Prolific data

In [1]:
# Automatically reload edited project modules (e.g. battleship.board) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os

import json
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from battleship.board import Board

In [3]:
%config InlineBackend.figure_format = 'retina'

# set seaborn color palette
sns.set_palette("Set2")

# set seaborn style
sns.set_style("whitegrid")
sns.set_context("talk")

In [4]:
# Earlier experiment runs, kept for reference:
#   "battleship-2024-10-03-19-28-28", "battleship-pilot-v2"
EXPERIMENT_NAME = "battleship-final-data"

PATH_DATA = os.path.join("data", EXPERIMENT_NAME)

# Raw exports: one CSV per table, all under PATH_DATA.
PATH_BATCH, PATH_GAME, PATH_GLOBAL, PATH_PLAYER, PATH_ROUND, PATH_STAGE = (
    os.path.join(PATH_DATA, f"{table}.csv")
    for table in ("batch", "game", "global", "player", "round", "stage")
)

# Derived outputs (gold sample, bonus file) are written here; created up front.
PATH_EXPORT = os.path.join(PATH_DATA, "export")
PATH_BONUS_EXPORT = os.path.join(PATH_EXPORT, f"{EXPERIMENT_NAME}-bonus.csv")
os.makedirs(PATH_EXPORT, exist_ok=True)

In [None]:
# Stage table; columns used later include messageType, messageText, occTiles,
# bonus, messageTime and index.
df_stage = (
    pd.read_csv(PATH_STAGE)
    # Give the primary key an unambiguous name before merging with other tables.
    .rename(columns={"id": "stageID"})
    # Drop the *LastChangedAt bookkeeping columns.
    .pipe(lambda frame: frame.loc[:, ~frame.columns.str.endswith("LastChangedAt")])
)

df_stage

In [None]:
# Round table; board_id and the JSON-encoded trueTiles are merged onto stages later.
df_round = pd.read_csv(PATH_ROUND).rename(columns={"id": "roundID"})

# Discard every *LastChangedAt bookkeeping column.
bookkeeping_columns = [col for col in df_round.columns if col.endswith("LastChangedAt")]
df_round = df_round.drop(columns=bookkeeping_columns)

df_round

In [None]:
df_player = pd.read_csv(PATH_PLAYER)

# Preview without participantIdentifier so Prolific IDs are not stored in the
# notebook's saved output; the column remains in df_player for the bonus export.
df_player.drop(columns=["participantIdentifier"], errors="ignore")

In [None]:
# Distinct end states per game. A row counts as completed when `ended` is
# "game ended" and `timeoutGameEnd` is False.
# NOTE(review): if the players of one game report different end states, that
# game appears here more than once (one row per distinct combination).
df_ended = df_player.loc[:, ["gameID", "ended", "timeoutGameEnd"]].drop_duplicates()
ended_normally = df_ended["ended"].eq("game ended")
not_timed_out = df_ended["timeoutGameEnd"].eq(False)
df_ended["gameCompleted"] = ended_normally & not_timed_out

# Incomplete games are dropped from the analysis later, so flag them here.
if not df_ended["gameCompleted"].all():
    print("WARNING: Some games were not completed.")

df_ended

In [None]:
# Merge stage and round tables, keep only gameplay events, and keep only completed games.

ROUND_COLUMNS = ["roundID"] + ["board_id", "trueTiles"]
df = df_stage.merge(df_round[ROUND_COLUMNS], on="roundID")

# Keep only gameplay events: move, question, answer, and decision messages
df = df[df["messageType"].isin(["move", "question", "answer", "decision"])]

# Drop all rows belonging to games that were not completed (per df_ended.gameCompleted)
df = df.merge(df_ended, on="gameID")
# Explicit copy: new columns are assigned to this filtered frame below.
df = df[df["gameCompleted"]].copy()

# Decode the JSON-encoded boards into nested Python lists
df["occTiles"] = df["occTiles"].apply(json.loads)
df["trueTiles"] = df["trueTiles"].apply(json.loads)

# Map each gameID to a unique, 1-based pairID (pair_01, pair_02, ...), in sorted gameID order
pair_id_by_game = {
    game_id: f"pair_{i:02}"
    for i, game_id in enumerate(sorted(df["gameID"].unique()), start=1)
}
df["pairID"] = df["gameID"].map(pair_id_by_game)

# Sort by pairID and roundID
df = df.sort_values(by=["pairID", "roundID"])

df

# Sampling rounds for gold-standard annotation

In [None]:
# Draw a reproducible random sample of (gameID, roundID) rounds for gold-standard annotation.
N_GOLD_ROUNDS = 20
GOLD_RANDOM_STATE = 123

# Candidate pool: every round kept in `df` (narrow this array to restrict the pool).
gold_round_ids = df["roundID"].unique()

df_gold_pool = df.loc[df["roundID"].isin(gold_round_ids), ["gameID", "roundID"]].drop_duplicates()

# Cap the sample size: `sample(n=...)` without replacement raises ValueError when
# fewer than n rows exist, so a smaller export simply samples every round.
df_gold = (
    df_gold_pool
    .sample(n=min(N_GOLD_ROUNDS, len(df_gold_pool)), replace=False, random_state=GOLD_RANDOM_STATE)
    .reset_index(drop=True)
)
display(df_gold)

df_gold.to_csv(os.path.join(PATH_EXPORT, f"{EXPERIMENT_NAME}-gold.csv"), index=False)

In [11]:
def compute_hits(board_array) -> int:
    """Count the cells of a board whose value is strictly positive.

    Applied in this notebook both to revealed boards (``occTiles``) and to the
    ground-truth boards (``trueTiles``). Both arrive as nested lists decoded by
    ``json.loads``, so any array-like accepted by ``np.array`` works.
    Presumably positive values mark ship tiles (TODO confirm the tile encoding
    against ``battleship.board``).

    Args:
        board_array: Array-like (nested list or ndarray) of numeric tile values.

    Returns:
        Number of tiles with value > 0, as a plain Python ``int``.
    """
    board = np.array(board_array)
    return int(np.sum(board > 0))

In [12]:
# Per-event progress: positive tiles on the revealed board vs. on the true board,
# and their ratio.
df["hits"] = df["occTiles"].map(compute_hits)
df["totalShipTiles"] = df["trueTiles"].map(compute_hits)
df["hits_pct"] = df["hits"].div(df["totalShipTiles"])

In [None]:
# Column inventory of the merged event table.
df.keys()

In [None]:
# Number of retained events of each messageType.
df.loc[:, "messageType"].value_counts()

In [None]:
# Number of retained events per value of the `name` column.
df.loc[:, "name"].value_counts()

In [None]:
# Number of retained events per board.
df.loc[:, "board_id"].value_counts()

In [None]:
# Frequency of each distinct messageText.
df.loc[:, "messageText"].value_counts()

# Visualizations

## Hits

In [None]:
# Raw hit count after each event: one line color per pair, line style per board.
ax = sns.lineplot(data=df, x="index", y="hits", hue="pairID", style="board_id")
ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# Fraction of ship tiles hit (hits / totalShipTiles) after each event.
ax = sns.lineplot(
    data=df,
    x="index",
    y="hits_pct",
    hue="pairID",
    style="board_id",
)
ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# Raw hit count after each event, one line per pair (boards pooled, no line-style split).
ax = sns.lineplot(data=df, x="index", y="hits", hue="pairID")
ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

## Moves and questions per board

In [None]:
# Number of "move" events for each pair and board.
df_move_counts = (
    df.query("messageType == 'move'")
    .groupby(["pairID", "board_id"])
    .size()
    .to_frame("move_count")
)
df_move_counts

In [None]:
# Number of "question" events for each pair and board; boards where no question
# was asked have no row here.
is_question_event = df["messageType"].eq("question")
df_question_counts = (
    df.loc[is_question_event]
    .groupby(["pairID", "board_id"])
    .size()
    .to_frame("question_count")
)
df_question_counts

In [None]:
# Moves and questions side by side per (pairID, board_id). The left join leaves NaN
# where a board had moves but no questions, which becomes a count of 0.
df_counts = (
    df_move_counts.join(df_question_counts)
    .fillna(0)
    .astype({"question_count": int})
    .sort_values(["pairID", "board_id"])
    .reset_index(drop=False)
)
df_counts

In [None]:
# Distribution of moves per board, one box per pair.
pair_order = df_counts["pairID"].unique()

with sns.plotting_context("talk"), sns.axes_style("whitegrid"):
    ax = sns.boxplot(data=df_counts, y="move_count", hue="pairID", hue_order=pair_order)
    ax.set_ylabel("Moves per board")
    # Legend to the right of the axes so it does not cover the boxes.
    ax.legend(title="pairID", bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# One dot per pair x board: number of questions asked vs. number of moves made.
ax = sns.stripplot(
    data=df_counts,
    x="question_count",
    y="move_count",
    hue="pairID",
    hue_order=df_counts["pairID"].unique(),
    size=10.0,
    jitter=0.2,
)
ax.set_xlabel("Questions")
ax.set_ylabel("Moves")
# Legend to the right of the axes.
ax.legend(title="Participant pair ID", bbox_to_anchor=(1.05, 1), loc="upper left")
ax.set_title("Questions asked vs. moves")

# Bonuses

In [None]:
# Running bonus after each action, per pair (color) and board (line style).
ax = sns.lineplot(data=df, x="index", y="bonus", hue="pairID", style="board_id")
ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
ax.set_xlabel("Action #")
ax.set_ylabel("Bonus ($)")
ax.set_title("Bonus over time")

In [None]:
# Group by pairID and board_id, and get the highest-index stage for each group
df_final_bonus = df.loc[df.groupby(["pairID", "board_id"])["index"].idxmax(), ["pairID", "gameID", "roundID", "board_id", "bonus"]]

df_final_bonus

In [None]:
# Histogram of final per-board bonuses, stacked by pair.
bonus_grid = sns.displot(
    data=df_final_bonus,
    x="bonus",
    hue="pairID",
    kind="hist",
    multiple="stack",
)
bonus_grid.ax.set_xlabel("Bonus ($)")
bonus_grid.ax.set_title("Bonus distribution")

In [None]:
# Summary statistics of the final per-board bonus.
df_final_bonus.loc[:, "bonus"].describe()

### Export final bonus information

In [None]:
# Attach every participant to the games they played, then total each participant's
# final per-board bonuses.
PLAYER_COLUMNS = ["gameID", "participantIdentifier"]
# drop_duplicates: a repeated player row would otherwise double-count that player's bonus.
df_final_bonus_export = df_final_bonus.merge(
    df_player[PLAYER_COLUMNS].drop_duplicates(), on="gameID"
)

df_final_bonus_export = (
    df_final_bonus_export.groupby("participantIdentifier")["bonus"]
    .sum()
    .round(2)  # bonuses are in dollars; avoid float artifacts like 0.30000000000000004
)

# The bulk-bonus file needs one "<participantIdentifier>,<amount>" line per participant
# with no header. The IDs live in the Series index, so index=True is required;
# index=False would write bare amounts that cannot be matched to anyone.
df_final_bonus_export.to_csv(PATH_BONUS_EXPORT, header=False, index=True)
display(df_final_bonus_export)

## Timing

In [None]:
# Running total of messageTime within each (pairID, roundID), accumulated in event
# ("index") order and divided by 1000 (presumably ms -> s; TODO confirm units).
events_in_order = df.sort_values("index")
running_time = events_in_order.groupby(["pairID", "roundID"])["messageTime"].cumsum()
df["cumulativeStageTime"] = running_time / 1000
df

In [None]:
# Cumulative elapsed time after each event, per pair (color) and board (line style).
with sns.plotting_context("talk"), sns.axes_style("whitegrid"):
    ax = sns.lineplot(
        data=df,
        x="index",
        y="cumulativeStageTime",
        hue="pairID",
        style="board_id",
    )
    ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")

In [None]:
# Cumulative time at the last event (highest "index") of each pair x board.
last_event_labels = df.groupby(["pairID", "board_id"])["index"].idxmax()
df_final_stage_time = df.loc[last_event_labels, ["pairID", "board_id", "cumulativeStageTime"]]

df_final_stage_time

In [None]:
# Total time per board, one bar per pair.
sns.barplot(
    data=df_final_stage_time,
    x="board_id",
    y="cumulativeStageTime",
    hue="pairID",
)

In [None]:
# Inspect events with no recorded messageTime; cumsum skips these, so their
# cumulativeStageTime is NaN.
# Show every column for this table only: pd.set_option would silently change the
# display of every later (or re-run earlier) cell.
with pd.option_context("display.max_columns", None):
    display(df[df["messageTime"].isna()])

# Timeline visualization
Shows, for each pair (rows) and board (columns), the hit percentage after every event, with each question marked by a black dot and labeled with its text.

In [None]:
# Facet grid: one row per pair, one column per board; each panel traces the hit
# percentage after every event (x = event "index").
g = sns.relplot(
    data=df.sort_values(["pairID", "board_id"]),
    kind="line",
    x="index",
    y="hits_pct",
    hue="pairID",
    row="pairID",
    col="board_id",
    aspect=2.0,
    linewidth=6,
)


def _mark_questions(data, **kws):
    """Dot every question event of one facet in black, drawn above the line.

    FacetGrid passes styling keywords (e.g. ``color``) through ``kws``; they are
    intentionally ignored so the markers stay black.
    """
    sns.scatterplot(
        data=data[data["messageType"] == "question"],
        x="index",
        y="hits_pct",
        s=10,
        marker="o",
        color="black",
        zorder=10,  # keep markers on top of the thick line
    )


g.map_dataframe(_mark_questions)


def _label_questions(ax, questions, min_gap=0.05):
    """Write each question's text next to its point on ``ax``.

    Labels are placed in ``index`` order. Each label is pushed up to at least
    ``min_gap`` above the previous label so they do not overlap. A label that was
    pushed gets a dashed gray leader line back down to its data point.
    """
    prev_label_y = -np.inf
    for _, question in questions.sort_values("index", ascending=True).iterrows():
        x = question["index"]
        point_y = question["hits_pct"]
        label_y = max(point_y, prev_label_y + min_gap)
        prev_label_y = label_y

        if label_y > point_y:
            ax.plot(
                [x, x],
                [label_y, point_y],
                color="gray",
                linestyle="--",
                linewidth=1,
                alpha=0.5,
            )

        ax.text(
            x,
            label_y,
            question["messageText"],
            horizontalalignment="left",
            size=10,
            color="black",
        )


# axes_dict keys are (row value, col value) = (pairID, board_id).
is_question = df["messageType"] == "question"
for (pairID, board_id), ax in g.axes_dict.items():
    _label_questions(
        ax,
        df[is_question & (df["board_id"] == board_id) & (df["pairID"] == pairID)],
    )