In [None]:
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.ticker import FormatStrFormatter
import seaborn as sns
sns.set_theme(context="poster", font_scale=1)

from pathlib import Path

sys.path.append("/home/acho/Sync/KiddLab/MSM/src")
from utils.stim_tools import *

# sys.path.append("/home/acho/Sync/Python/sigtools")
# from sigtools.representations import *
# from sigtools.sounds import *
# from sigtools.processing import *
# from sigtools.spatialization import *

In [None]:
# Register Times New Roman as the preferred serif font. matplotlib only consults
# "font.serif" when the active font family is "serif", which is left disabled here:
# plt.rcParams["font.family"] = "serif"
plt.rcParams.update({"font.serif": "Times New Roman"})

In [None]:
# Figure / experiment configuration.
run_figsize = (16, 12)
n_runs_to_load = 1          # number of RUN_XXX.csv files loaded below (RUN_000, RUN_001, ...)
n_trials_per_block = 8
n_words_per_trial = 4
chance_level = 1/8          # NOTE(review): presumably 1 of 8 response alternatives per word — confirm

# One marker symbol and one subject label per loaded run (both indexed by run number).
marker_symbols = ["o", "v", "^", "d", "*"]
subjects = ["ayc"]
assert n_runs_to_load <= min(len(marker_symbols), len(subjects)), (
    "need a marker symbol and a subject label for every run to load"
)

# Legend: one open marker per run (labelled with its subject), followed by the
# colour key for the spatial conditions and the chance line.
legend_elements = []
for i in range(n_runs_to_load):
    # Append rather than reassign, so every run keeps its legend entry (not just the last).
    legend_elements.append(
        Line2D([], [], linestyle="", marker=marker_symbols[i], markeredgewidth=3,
               markeredgecolor="k", markerfacecolor="w", markersize=18, label=subjects[i])
    )
# NOTE(review): the +/-90 deg (antipodal) series is currently commented out in the
# plotting cell, so this blue entry has no corresponding data on the figure.
legend_elements.append(Patch(facecolor="b", edgecolor="b", label=r"$\pm 90^{\circ}$"))
legend_elements.append(Patch(facecolor="r", edgecolor="r", label=r"$0^{\circ}$"))
legend_elements.append(Line2D([], [], color="k", linestyle=":", label="chance"))

# Load each run, total the correct words per target oscillation rate, and plot one
# marker series per run plus the co-located (static, 0 deg) baseline in red.
# NOTE(review): STIM_DIR and DATA_DIR are expected to come from the star import of
# utils.stim_tools — confirm.
all_run_nums = range(n_runs_to_load)

# Loop-invariant work: the stimulus database and the per-condition maximum score
# are identical for every run, so read/compute them once.
stim_data = pd.read_csv(STIM_DIR/"stimulus_database.csv")
# Denominator for proportion correct: every word of every trial in one block.
# NOTE(review): assumes each condition spans exactly one block; if a condition spans
# several blocks, scale by that count (cf. n_blocks below).
n_max_correct = n_trials_per_block*n_words_per_trial

fig, ax = plt.subplots(1, 1, figsize=run_figsize)
for run_num in all_run_nums:
    run_file_name = f"RUN_{run_num:03d}.csv"
    run_data = pd.read_csv(DATA_DIR/run_file_name)

    # Attach each trial's stimulus attributes (looked up by stimulus_ID as a row label
    # of stim_data) and drop bookkeeping columns not used below.
    run_stim = (
        stim_data.loc[run_data.stimulus_ID]
        .reset_index(drop=True)
        .drop(columns=["stim_type"])
    )
    run_data = run_data.join(run_stim).drop(columns=["run_num", "subject_ID", "task_type"])

    n_blocks = run_data.block_num.max()  # not used below; kept for inspection

    # Total correct words per oscillation rate. Selecting "correct" before summing avoids
    # aggregating unrelated (possibly non-numeric) columns, and taking rates and totals
    # from the same Series keeps them aligned.
    correct_by_rate = run_data.groupby("target_alt_rate")["correct"].sum()
    # Rate 0 is the static baseline, handled separately below. Filter it explicitly
    # instead of assuming it is present and sorts first.
    correct_by_rate = correct_by_rate[correct_by_rate.index != 0]
    assert len(correct_by_rate) > 0, f"{run_file_name}: no non-zero target_alt_rate conditions"
    rates = correct_by_rate.index.tolist()
    correct = correct_by_rate.to_numpy()

    # Co-located baseline: static target (rate 0) at initial position 0.
    colocated = run_data.loc[
        (run_data["target_alt_rate"] == 0) & (run_data["target_init_position"] == 0),
        "correct",
    ].sum()
    # Either-ear (antipodal) baseline, currently disabled:
    # antipodal = run_data.loc[(run_data["target_alt_rate"] == 0)
    #                          & (run_data["target_init_position"] != 0), "correct"].sum()
    # ax.plot(rates[0] - 1, antipodal/n_max_correct, "b" + marker_symbols[run_num], markersize=18)

    ax.plot(rates, correct/n_max_correct, "k" + marker_symbols[run_num] + "-", markersize=18)
    # Place the co-located baseline just to the right of the highest rate.
    ax.plot(rates[-1] + 1, colocated/n_max_correct, "r" + marker_symbols[run_num], markersize=18)

# Chance line spanning the full axes width, independent of the x-limits.
ax.axhline(chance_level, color="k", linestyle=":")
ax.set_title("Word by word identification performance", fontsize=40)
ax.set_xlabel("Oscillation rate [Hz]\n(linear scale)", fontsize=28)
# Plotted values are proportions in [0, 1], not percentages.
ax.set_ylabel("Proportion correct", fontsize=28)
ax.set_xlim((rates[0] - 1.75, rates[-1] + 1.75))
ax.set_ylim((0, 1))
# Fixed tick positions formatted as "%.1f". Styling goes through tick_params because
# set_*ticklabels would install a FixedFormatter that the formatter below replaces anyway.
ax.set_xticks(rates)
ax.set_yticks(np.linspace(0, 1, 11, endpoint=True))
ax.xaxis.set_major_formatter(FormatStrFormatter("%.1f"))
ax.yaxis.set_major_formatter(FormatStrFormatter("%.1f"))
ax.tick_params(axis="x", labelrotation=90, labelsize=20)
ax.tick_params(axis="y", labelsize=20)
ax.grid(linestyle="-")
ax.legend(handles=legend_elements)

# plt.savefig("b.pdf")
plt.show()