In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import yaml

In [None]:
from pathlib import Path

# Work from the project root (the parent of the kernel's starting directory, presumably
# the notebook's folder) so the relative session paths used below resolve.
# The root is captured only once: re-running this cell must not keep walking up the
# directory tree, which the bare `os.chdir(Path().resolve().parent)` did.
if "PROJECT_ROOT" not in globals():
    PROJECT_ROOT = Path().resolve().parent
os.chdir(PROJECT_ROOT)

In [None]:
# Global plotting style.
# `sns.set_theme` re-applies its own context/style/palette defaults (context="notebook",
# font_scale=1), so any `set_context`/`set_style` call placed *before* it is discarded.
# The former `set_context("poster", font_scale=2.0)` and `set_style("white")` lines
# therefore had no effect; the effective settings are now spelled out explicitly.
# TODO(review): if poster-sized fonts were intended for this notebook, use
# `sns.set_theme(context="poster", style="ticks", font_scale=2.0)` instead.
sns.set_theme(context="notebook", style="ticks")
xkcd_palette = ["lightish purple", "mango", "turquoise", "cherry red"]
sns.set_palette(sns.xkcd_palette(xkcd_palette))

In [None]:
# Load, for each session, the per-trial metadata, the binned centre-of-mass (COM)
# trajectories and the binned responses. Each trial is trimmed symmetrically so that its
# length is a multiple of BIN_SIZE frames, then both signals are aggregated per bin.
sessions = ["bex_20230623", "ken_20230614", "ken_20230618"]

BIN_SIZE = 50  # frames per analysis bin
# com.npy is indexed with `frame_idx // COM_DOWNSAMPLE`, i.e. it appears to hold one row
# per COM_DOWNSAMPLE frames — TODO confirm against the pose-extraction pipeline.
COM_DOWNSAMPLE = 10


def _trimmed_bounds(first_frame_idx, num_frames, bin_size=BIN_SIZE):
    """Return ``(start, end)`` of a trial trimmed to a multiple of ``bin_size`` frames.

    The ``num_frames % bin_size`` surplus frames are dropped from both ends; when the
    surplus is odd, the extra frame is dropped from the end.
    """
    surplus = num_frames % bin_size
    start = first_frame_idx + surplus // 2
    end = first_frame_idx + num_frames - (surplus - surplus // 2)
    return int(start), int(end)


def _load_trials(session):
    """Read every file in ``<session>/trials`` and return ``(choices, starts, ends)``.

    Files are processed in sorted filename order; ``starts``/``ends`` are the trimmed
    trial bounds from :func:`_trimmed_bounds`.
    """
    trials_dir = os.path.join(session, "trials")
    choices, starts, ends = [], [], []
    # NOTE(review): lexicographic order — if trial filenames are not zero-padded numbers,
    # this may differ from chronological order; confirm if order matters downstream.
    for trial in sorted(os.listdir(trials_dir)):
        with open(os.path.join(trials_dir, trial), "rb") as f:
            meta = yaml.safe_load(f)
        choices.append(meta["choice"])
        start, end = _trimmed_bounds(meta["first_frame_idx"], meta["num_frames"])
        starts.append(start)
        ends.append(end)
    return choices, starts, ends


def _binned_com(session, starts, ends):
    """Average the session's COM track over BIN_SIZE-frame bins; one array per trial."""
    with open(os.path.join(session, "poses", "meta", "com.npy"), "rb") as f:
        com = np.load(f)
    samples_per_bin = BIN_SIZE // COM_DOWNSAMPLE
    # Each trial slice is reshaped to (n_bins, samples_per_bin, 3) — the COM track is
    # expected to have 3 columns — and averaged within each bin.
    return [
        com[b // COM_DOWNSAMPLE : e // COM_DOWNSAMPLE, :]
        .reshape((e - b) // BIN_SIZE, samples_per_bin, 3)
        .mean(axis=1)
        for b, e in zip(starts, ends)
    ]


def _binned_responses(session, starts, ends):
    """Sum the session's memory-mapped responses over BIN_SIZE-frame bins; one array per trial."""
    resp_dir = os.path.join(session, "responses")
    with open(os.path.join(resp_dir, "meta.yml"), "r") as f:
        resp_meta = yaml.safe_load(f)
    responses = np.memmap(
        os.path.join(resp_dir, "data.mem"),
        dtype=resp_meta["dtype"],
        mode="r",
        shape=(resp_meta["n_timestamps"], resp_meta["n_signals"]),
    )
    return [
        responses[b:e].reshape((e - b) // BIN_SIZE, BIN_SIZE, -1).sum(axis=1)
        for b, e in zip(starts, ends)
    ]


trial_start, trial_end, choice = {}, {}, {}
choice2idx, com_dict, responses_dict = {}, {}, {}

for session in sessions:
    choice[session], trial_start[session], trial_end[session] = _load_trials(session)
    # Trial indices grouped by the "choice" value stored in each trial's metadata.
    choices_arr = np.array(choice[session])
    choice2idx[session] = {
        side: np.where(choices_arr == side)[0] for side in ("L", "R")
    }
    com_dict[session] = _binned_com(session, trial_start[session], trial_end[session])
    responses_dict[session] = _binned_responses(
        session, trial_start[session], trial_end[session]
    )

In [None]:
# Two-row summary figure, one column per session:
#   row 0 — binned COM trajectories (column 0 vs column 1), one black line per trial;
#   row 1 — boxplots of `merged_df` correlations by area, coloured by alignment.
# The figure is saved as a PDF under plots/nonparametric_models/.

# NOTE(review): `merged_df` is not created anywhere in this notebook, so a fresh
# "Restart & Run All" cannot reach this figure. TODO add the cell that builds it.
if "merged_df" not in globals():
    raise NameError(
        "merged_df is not defined; build the per-session correlation table first."
    )

# Apply the global style BEFORE creating the figure: many rcParams (axes face colour,
# spines, tick style) are read when the Axes are constructed, so changing the theme
# after `plt.subplots` only partially restyles this figure.
sns.set_theme(style="white")
sns.set_context("poster", font_scale=1)
sns.set_palette(sns.xkcd_palette(xkcd_palette))

SESSION_TITLES = {
    "bex_20230623": "Monkey B",
    "ken_20230614": "Monkey K, Session 1",
    "ken_20230618": "Monkey K, Session 2",
}

fig, axs = plt.subplots(
    2, len(sessions), figsize=(22, 11), sharex="row", sharey="row"
)

### **First Row: Trajectories Plot**
for i, session in enumerate(sessions):
    ax = axs[0, i]
    # `com_dict` comes from the loading cell (COM averaged per 50-frame bin).
    for trajectory in com_dict[session]:
        ax.plot(trajectory[:, 0], trajectory[:, 1], color="black", lw=2)

    ax.set_title(SESSION_TITLES.get(session, session), pad=15)

    # Horizontal grid lines separating the five tile bands labelled below.
    for y in [0.2, 0.4, 0.6, 0.8]:
        ax.axhline(y=y, color="black", linestyle="-", alpha=0.2)

    ax.set_xticks([])  # Remove x-axis ticks

# Tile labels at the centre of each band; the y-axis is shared across the row, so
# configuring the first column is enough.
axs[0, 0].set_yticks([0.1, 0.3, 0.5, 0.7, 0.9])
axs[0, 0].set_yticklabels(["Tile 1", "Tile 2", "Tile 3", "Tile 4", "Tile 5"])
axs[0, 0].set_ylim(0, 1)

### **Second Row: Boxplots**
for i, session in enumerate(sessions):
    ax = axs[1, i]
    sns.boxplot(
        x="area",
        y="correlation",
        hue="alignment",
        data=merged_df[merged_df["session"] == session],
        ax=ax,
        linewidth=2.5,
        flierprops={"marker": "o", "markersize": 8, "markeredgewidth": 2.5},
    )
    ax.set_title("")  # Remove subplot titles

    # Per-axes legends are replaced by a single figure legend below.
    if ax.get_legend() is not None:
        ax.get_legend().remove()

# Despine ONLY the second row, after all boxplots are drawn: the y-axis is shared, so a
# later boxplot could still change the limits that `trim=True` uses.
for ax in axs[1]:
    ax.set_xlabel("")  # Remove x-axis labels
    sns.despine(ax=ax, trim=True)

# Y-axis label for the first column only
axs[1, 0].set_ylabel("Correlation", labelpad=15)

# Single shared legend below the figure (handles are still available after the
# per-axes legend was removed).
handles, labels = axs[1, -1].get_legend_handles_labels()
fig.legend(
    handles,
    [label.capitalize() for label in labels],
    loc="center",
    bbox_to_anchor=(0.5, -0.02),
    ncol=4,
    title="",
    frameon=False,
)

fig.tight_layout()
output_path = "plots/nonparametric_models/cosyne_plot.pdf"
os.makedirs(os.path.dirname(output_path), exist_ok=True)  # savefig does not create dirs
fig.savefig(
    output_path,
    format="pdf",
    bbox_inches="tight",
)
plt.show()