In [None]:
# Importing Python and external packages
import os
import sys
import json
import importlib
from itertools import product

import numpy as np
import pandas as pd
import scipy
import sklearn as sk
from scipy import stats, signal
import matplotlib.pyplot as plt

In [None]:
# check some package versions for documentation and reproducibility
# (`scipy` must be imported as a module for `scipy.__version__`)
print('Python sys', sys.version)
print('pandas', pd.__version__)
print('numpy', np.__version__)
print('sci-py', scipy.__version__)
print('sci-kit learn', sk.__version__)

In [None]:
import retap_utils.utils_dataManagement as utils_dataMn


In [None]:
# DEFINE RETAP OUTCOME (FEATURE) DIRECTORY AND FIGURE DIRECTORY
ft_path = os.path.join(utils_dataMn.get_local_proj_dir(),
                       'aDBS_tapping',
                       'retap_results', 'features')
# ft_path is listed with os.listdir later on, so it must be a directory
assert os.path.isdir(ft_path), f'defined ft_path does not exist: {ft_path}'
fig_path = os.path.join(utils_dataMn.get_local_proj_dir(),
                        'aDBS_tapping', 'figures')
# exist_ok replaces the separate existence check (no check-then-create race)
os.makedirs(fig_path, exist_ok=True)

In [None]:
# DEFINE FEATURES OF INTEREST
# (keys looked up in each loaded feature JSON)
FEAT_SEL = [
    'trace_RMSn',
    'coefVar_impactRMS',
    'coefVar_intraTapInt',
    'mean_raise_velocity',
]

# normalisation applied to the OFF/ON values: 'norm' or 'std'
NORM_METHOD = 'norm'

## Explore aDBS data content


In [None]:
def get_unique_subs(ft_path):
    """Return the unique subject IDs found in a feature directory.

    The subject ID is taken as the second underscore-separated token of
    each file name (``'<prefix>_<sub>_...'`` -> ``'<sub>'``).

    Hidden entries (names starting with '.', e.g. macOS ``.DS_Store``) and
    names without any underscore are skipped. Otherwise they would yield a
    bogus subject ID (``'.DS_Store'`` -> ``'Store'``) or raise an IndexError.

    Parameters
    ----------
    ft_path : str
        Directory containing the feature files.

    Returns
    -------
    numpy.ndarray
        Sorted array of unique subject IDs (as returned by ``np.unique``).
    """
    sub_ids = [
        fname.split('_')[1]
        for fname in os.listdir(ft_path)
        if not fname.startswith('.') and '_' in fname
    ]

    return np.unique(sub_ids)

In [None]:
def _find_cdbs_file(files, part, stim, sub):
    """Return the first cDBS file name matching recording part and stim state."""
    matches = [f for f in files
               if part in f
               and 'cdbs' in f.lower()
               and f'stim{stim}' in f.lower()]
    if not matches:
        raise FileNotFoundError(
            f'no Med-OFF cDBS stim{stim} "{part}" feature file found for {sub}'
        )
    return matches[0]


def load_on_off_first10(subs, ft_path, sel_30sec_part='first'):
    """Load Med-OFF continuous-DBS OFF/ON feature JSON files per subject.

    Parameters
    ----------
    subs : iterable of str
        Subject IDs; a file belongs to a subject when ``f'{sub}_'`` occurs
        in its file name.
    ft_path : str
        Directory containing the feature JSON files.
    sel_30sec_part : {'first', 'last', 'full', 'diff'}
        Part of the recording to load, matched as a (case-sensitive)
        substring of the file name. 'diff' loads both 'first' and 'last'.

    Returns
    -------
    dict
        For 'first'/'last'/'full': ``{sub: {'off': ..., 'on': ...}}``.
        For 'diff': ``{sub: {'off_first', 'on_first', 'off_last',
        'on_last'}}``. Values are the parsed JSON contents.

    Raises
    ------
    AssertionError
        If ``sel_30sec_part`` is not an allowed option.
    FileNotFoundError
        If no matching file exists for a subject / part / stim combination.
    """
    allowed_part_sel = ['first', 'last', 'full', 'diff']

    assert sel_30sec_part in allowed_part_sel, (
        f'sel_30sec_part should be in {allowed_part_sel}'
    )

    diff_mode = sel_30sec_part == 'diff'
    # 'diff' needs the first AND last part; every other option loads one part
    parts = ['first', 'last'] if diff_mode else [sel_30sec_part]

    files_dict = {}
    for sub in subs:
        # only consider Med OFF recordings
        files = [f for f in os.listdir(ft_path)
                 if f'{sub}_' in f
                 and 'medoff' in f.lower()]

        files_dict[sub] = {}
        for part, stim in product(parts, ['off', 'on']):
            json_file = _find_cdbs_file(files, part, stim, sub)
            key = f'{stim}_{part}' if diff_mode else stim
            # load json with features
            with open(os.path.join(ft_path, json_file), 'r') as f:
                files_dict[sub][key] = json.load(f)

    return files_dict

In [None]:
def norm_feat_lists(box_lists, feat_names,
                    norm_method):
    """Normalise paired OFF/ON feature value lists against the OFF values.

    ``box_lists`` holds, for the i-th feature in ``feat_names``, the OFF
    values at index ``2 * i`` and the ON values at index ``2 * i + 1``.
    For each feature, a paired t-test (``scipy.stats.ttest_rel``) between
    the normalised OFF and ON values is printed.

    Parameters
    ----------
    box_lists : list of sequences of numbers
        Values are converted to float arrays (``None`` becomes NaN).
    feat_names : list of str
        Feature names, in the order their OFF/ON pairs appear in box_lists.
    norm_method : {'norm', 'std'}
        'norm': divide OFF and ON by the maximum OFF value (NaN ignored).
        'std': z-score OFF and ON with the OFF mean and population SD
        (NaN ignored).

    Returns
    -------
    list
        A new list with the normalised numpy arrays in the same layout;
        the input list is not modified.

    Raises
    ------
    ValueError
        If ``norm_method`` is not 'norm' or 'std'.
    """
    if norm_method not in ('norm', 'std'):
        raise ValueError(
            f"norm_method should be 'norm' or 'std', got {norm_method!r}"
        )

    # work on a copy so the caller's list is left untouched
    normed_lists = list(box_lists)

    for i_ft, ft in enumerate(feat_names):

        off = np.asarray(box_lists[i_ft * 2], dtype=float)
        on = np.asarray(box_lists[i_ft * 2 + 1], dtype=float)

        if norm_method == 'norm':
            norm_value = np.nanmax(off)
            off, on = off / norm_value, on / norm_value
        else:  # 'std'
            norm_m = np.nanmean(off)
            norm_sd = np.nanstd(off)
            off, on = (off - norm_m) / norm_sd, (on - norm_m) / norm_sd

        normed_lists[i_ft * 2] = off
        normed_lists[i_ft * 2 + 1] = on

        # paired-samples t-test OFF vs ON; the statistic is a t-value
        t_stat, p = stats.ttest_rel(off, on)
        print(f'{ft}: t = {round(t_stat, 3)}, p = {round(p, 5)}')

    return normed_lists

In [None]:
def extract_feature_lists(files_dict, feat_names, subs,
                          norm_method=False,):
    """Collect per-subject feature values into paired OFF/ON lists for plotting.

    Parameters
    ----------
    files_dict : dict
        ``{sub: {'off': {feature: value}, 'on': {feature: value}}}``, e.g.
        as returned by ``load_on_off_first10`` with a single-part selection.
    feat_names : list of str
        Features to extract, in output order.
    subs : iterable of str
        Subjects to include; the value order within each list follows subs.
    norm_method : False or {'norm', 'std'}
        If a string, the lists are normalised against the OFF values via
        ``norm_feat_lists``; otherwise raw values are returned.

    Returns
    -------
    box_lists : list
        For each feature, the OFF list followed by the ON list.
    box_labels : list of str
        Matching ``'<feature> <condition>'`` labels.

    Raises
    ------
    AssertionError
        If ``norm_method`` is a string other than 'norm' or 'std'.
    """
    conditions = ['off', 'on']

    box_lists = []
    box_labels = []

    for ft, con in product(feat_names, conditions):
        box_labels.append(f'{ft} {con}')
        box_lists.append([files_dict[s][con][ft] for s in subs])

    if isinstance(norm_method, str):
        assert norm_method in ['norm', 'std'], 'incorrect norm_method'
        # pass the extracted feature names (not a global selection) so the
        # names stay aligned with the OFF/ON pairs in box_lists
        box_lists = norm_feat_lists(box_lists=box_lists,
                                    feat_names=feat_names,
                                    norm_method=norm_method)

    return box_lists, box_labels

In [None]:
subs = get_unique_subs(ft_path=ft_path)
assert len(subs) > 0, f'no feature files found in {ft_path}'

files_dict = load_on_off_first10(subs=subs, ft_path=ft_path,
                                 sel_30sec_part='last')
# TODO: calculate difference first and last!
# TODO: include aDBS
# inspect the loaded structure of one subject; use the first subject found
# instead of a hard-coded ID so the cell runs on any dataset
files_dict[subs[0]].keys()

In [None]:
subs = get_unique_subs(ft_path=ft_path)

files_dict = load_on_off_first10(subs=subs, ft_path=ft_path)

# normalise with the configured NORM_METHOD so the values match the
# normalisation named in the figure file name
box_lists, box_labels = extract_feature_lists(files_dict=files_dict,
                                              feat_names=FEAT_SEL, subs=subs,
                                              norm_method=NORM_METHOD,)


In [None]:
SCATTER = True
SCATTER_LINES = True

SAVE_FIG = False
SHOW_FIG = True

fig_name = f'firstRun_4feats_cDBSOffvsOn_{NORM_METHOD}'

fig, ax = plt.subplots(1, 1, figsize=(8, 6))

fontsize = 14

ax.boxplot(box_lists, positions=np.arange(len(box_lists)),)
ax.set_xticklabels(box_labels, rotation=90,
                   size=fontsize)

# y-label follows NORM_METHOD, the normalisation configured for box_lists
if NORM_METHOD == 'norm':
    ylabel = 'Normalised feature values\n(against max OFF value)'
elif NORM_METHOD == 'std':
    ylabel = 'Standardised feature values\n(against OFF values)'
else:
    ylabel = 'Feature values'

if SCATTER:
    # individual subject values on top of each box
    for i_x, values in enumerate(box_lists):
        ax.scatter([i_x] * len(values), values,
                   alpha=.3)
if SCATTER_LINES:
    # connect each subject's OFF (even index) and ON (odd index) value
    for i_ft in range(len(FEAT_SEL)):
        off_vals = box_lists[i_ft * 2]
        on_vals = box_lists[i_ft * 2 + 1]
        for y1, y2 in zip(off_vals, on_vals):
            ax.plot([i_ft * 2, i_ft * 2 + 1], [y1, y2],
                    c='gray', alpha=.3)

ax.set_ylabel(ylabel,
              size=fontsize)

ax.tick_params(axis='both', labelsize=fontsize, size=fontsize)

fig.tight_layout()

if SAVE_FIG:
    fig.savefig(os.path.join(fig_path, fig_name), dpi=300,
                facecolor='w',)
if SHOW_FIG:
    plt.show()
else:
    plt.close(fig)