### Violin plot info

In [1]:
import numpy as np
import pickle
import matplotlib.pyplot as plt
import pathlib
import ast
import scipy.stats

import pickle

In [2]:
from bci_plot.utils import data_util
from bci_plot.metadata import sessions_info_w_day
from bci_plot.gen_data.adaptation import get_v
from bci_plot.gen_fig.adaptation import violin_util

In [3]:
# Directory holding the per-session pickles ('<session>.pickle') read below;
# the relative path is resolved against the notebook's working directory.
src_dir = pathlib.Path('../../data/adaptation')

In [4]:
# stats[subject][subject_day] -> list of per-session stat dicts.
stats = {}
for session_header in sessions_info_w_day.sessions_info:
    session, decoder_name, fold, subject, subject_day = session_header

    # Grow the subject's per-day list with fresh empty lists until index
    # `subject_day` exists. This keeps the day index correct even when day
    # numbers arrive out of order or are skipped (a skipped day stays an
    # empty list). Sessions within one day keep the metadata's order.
    subject_days = stats.setdefault(subject, [])
    subject_days.extend([] for _ in range(subject_day + 1 - len(subject_days)))

    with open(src_dir / f'{session}.pickle', 'rb') as f:
        session_stats = pickle.load(f)
    session_stats['header'] = (session, decoder_name, fold, subject, subject_day)
    session_stats['v_info'] = get_v.get_v(session)  # valid velocity info.
    subject_days[subject_day].append(session_stats)

In [5]:
def dict_cat(ds, axis=0):
    '''
    Concatenate the arrays stored under each key across several dictionaries.

    ds: iterable of dictionaries mapping key -> array-like. The keys of the
        first dictionary decide which keys appear in the result (in that
        order); every other dictionary must contain those keys, and any extra
        keys they have are ignored.
    axis: axis along which np.concatenate joins the arrays (default 0).

    Returns a new dict {key: np.concatenate([d[key] for d in ds], axis=axis)}.

    Raises ValueError if ds is empty (there is no dictionary to take keys
    from -- e.g. a day with no sessions), and KeyError if a later dictionary
    is missing one of the first dictionary's keys.
    '''
    # Materialize so any iterable (e.g. a generator) works and can be walked
    # once per key.
    ds = list(ds)
    if not ds:
        raise ValueError('dict_cat() needs at least one dictionary')
    return {key: np.concatenate([d[key] for d in ds], axis=axis) for key in ds[0]}

In [6]:
# day_v_stats[subject]['pts'][day] -> one dict whose arrays are the 'v_info'
# arrays of all of that day's sessions, concatenated by dict_cat (axis 0).
day_v_stats = {}
for subject, subject_days in stats.items():
    pooled_per_day = [
        dict_cat([session_stats['v_info'] for session_stats in day_sessions])
        for day_sessions in subject_days
    ]
    day_v_stats[subject] = {'pts': pooled_per_day}

In [7]:
# 2x4 projection applied to 4-column velocity arrays:
#   row 0 -> column0 - column1, row 1 -> column2 - column3.
# NOTE(review): the key name 'valid_rlud_decoder_output' suggests the columns
# are right/left/up/down, making row 0 the x axis and row 1 the y axis --
# confirm the channel order.
w = np.array([
    [1.0, -1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, -1.0],
])

In [8]:
# For every subject and day, project each pooled velocity array onto the two
# rows of `w` and keep element [0] of violin_util.get_violin_points(...) for
# that 1-D projection. Resulting layout:
#   violin_info[subject]['{x,y}_info_{kf,eegnet}'] -> list with one entry per day.
# Output-key suffix -> array key inside each day's pooled v_info dict.
violin_sources = (
    ('kf', 'valid_kf_vel'),
    ('eegnet', 'valid_rlud_decoder_output'),
)

violin_info = {}
for subject, subject_v_stats in day_v_stats.items():
    subject_info = {}
    for suffix, vel_key in violin_sources:
        # Rows of `w` are 1-D, so `@ proj` is a dot product over the last axis
        # (assumes pts[vel_key] has 4 columns -- TODO confirm shape).
        for axis_name, proj in zip(('x', 'y'), w):
            subject_info[f'{axis_name}_info_{suffix}'] = [
                violin_util.get_violin_points(pts[vel_key] @ proj)[0]
                for pts in subject_v_stats['pts']
            ]
    violin_info[subject] = subject_info

# Written next to the per-session pickles (same path as before:
# ../../data/adaptation/violin_info.pickle).
with open(src_dir / 'violin_info.pickle', 'wb') as f:
    pickle.dump(violin_info, f)