In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import pathlib
import ast
import scipy.stats

In [2]:
from bci_plot.utils import data_util
from bci_plot.metadata import sessions_info_w_day
from bci_plot.gen_fig.adaptation import gen_fitts_fig

In [3]:
# Directory the per-session pickle files are read from (one `<session>.pickle`
# per session). The path is relative, so it resolves against the notebook's
# current working directory.
src_dir = pathlib.Path('../../data/adaptation')

In [4]:
# Build stats[subject][subject_day] -> list of per-session stat dicts.
# Each loaded dict is tagged with a 'header' tuple describing its session.
stats = {}
for session, decoder_name, fold, subject, subject_day in sessions_info_w_day.sessions_info:
    subject_days = stats.setdefault(subject, [])

    # Pad with empty day lists so that `subject_day` is a valid index.
    # (Original author's note: assumes the metadata is already ordered!)
    n_missing_days = subject_day + 1 - len(subject_days)
    subject_days.extend([] for _ in range(n_missing_days))

    # NOTE(review): pickle.load can execute arbitrary code while unpickling;
    # only point `src_dir` at files from a trusted source.
    with open(src_dir / f'{session}.pickle', 'rb') as f:
        session_stats = pickle.load(f)

    # The file is closed by now; the remaining work is pure bookkeeping.
    session_stats['header'] = (session, decoder_name, fold, subject, subject_day)
    subject_days[subject_day].append(session_stats)

In [5]:
def cat(x):
    """Concatenate a sequence of array-likes, tolerating an empty sequence.

    `np.concatenate` raises on an empty sequence, so that case is mapped to
    an empty array (`np.zeros(0)`) instead.

    Args:
        x: sized sequence of array-likes accepted by `np.concatenate`.

    Returns:
        The concatenation of the items in `x`, or a zero-length array when
        `x` has no items.
    """
    return np.concatenate(x) if len(x) else np.zeros(0)

In [6]:
# (output key, per-session source key) pairs, in the order the output keys
# are inserted into each subject's compiled dict.
_DAY_FIELDS = (
    ('day_bps', 'block_bps'),
    ('day_bps_ftt', 'block_bps_ftt'),
    ('day_success', 'block_success'),
    ('day_ttt', 'block_ttt'),
    ('day_is_eval', 'block_is_eval'),
    ('day_td', 'block_td'),
)


def _compile_subject(subject_stats):
    """Flatten one subject's per-session stats into per-day arrays.

    Args:
        subject_stats: list indexed by day; each entry is a list of
            per-session stat dicts for that day.

    Returns:
        Dict mapping each output key to a list with one concatenated array
        per day (a zero-length array for a day with no sessions).
    """
    compiled = {}
    for day_key, block_key in _DAY_FIELDS:
        compiled[day_key] = [
            cat([session_stats[block_key] for session_stats in day_stats])
            for day_stats in subject_stats
        ]
    # Repeat each session's decoder name (header[1]) once per 'block_bps'
    # entry so the labels line up element-wise with 'day_bps'.
    compiled['decoder_name'] = [
        cat([
            [session_stats['header'][1]] * len(session_stats['block_bps'])
            for session_stats in day_stats
        ])
        for day_stats in subject_stats
    ]
    return compiled


compiled_stats = {}
for subject, subject_stats in stats.items():
    compiled_stats[subject] = _compile_subject(subject_stats)

In [24]:
# Index into each subject's per-day lists selecting the day to summarise.
day_idx = 4

# Per-subject means for the selected day, kept so the "H avg" row can reuse
# them instead of recomputing.
acq_mean_by_subject = {}
ftt_mean_by_subject = {}

print('subject | BPS (acq) | BPS (ftt)')
for subject in ['S2', 'H1', 'H2', 'H4']:
    subject_compiled = compiled_stats[subject]
    bps = subject_compiled['day_bps'][day_idx].mean()
    bps_f = subject_compiled['day_bps_ftt'][day_idx].mean()
    acq_mean_by_subject[subject] = bps
    ftt_mean_by_subject[subject] = bps_f
    print(f'{subject:7} |    {bps:5.4f} |    {bps_f:5.4f}')

# Unweighted mean of the per-subject means over the H subjects only.
h_subjects = ['H1', 'H2', 'H4']
bps_avg = np.mean([acq_mean_by_subject[name] for name in h_subjects])
bps_f_avg = np.mean([ftt_mean_by_subject[name] for name in h_subjects])
print(f'{"H avg":7} |    {bps_avg:5.4f} |    {bps_f_avg:5.4f}')

subject | BPS (acq) | BPS (ftt)
S2      |    0.1590 |    0.1760
H1      |    0.2454 |    0.3538
H2      |    0.2057 |    0.2384
H4      |    0.2040 |    0.2823
H avg   |    0.2184 |    0.2915
