In [None]:
# Importing Python and external packages
import os
import sys
import importlib
import json
import csv
import pickle
from itertools import compress
import pandas as pd
import numpy as np
from itertools import product
from scipy.stats import spearmanr, ttest_rel

import matplotlib.pyplot as plt


In [None]:
def get_project_path_in_notebook(
    subfolder: str = '',
):
    """
    Find the path of the project folder from within a notebook.

    Walks upwards from the current working directory until a path
    ending in 'dyskinesia_neurophys' is found. Run this once at the
    start so that the other project modules/functions can be located.

    Arguments:
        - subfolder: currently not used by this function; kept so
            existing calls keep working

    Returns:
        - path (str): path ending in 'dyskinesia_neurophys'

    Raises:
        - FileNotFoundError: if the filesystem root is reached
            without finding the project folder (previously this
            case looped forever)
    """
    project_name = 'dyskinesia_neurophys'
    path = os.getcwd()

    while not path.endswith(project_name):
        parent = os.path.dirname(path)
        # dirname of the filesystem root is the root itself: stop
        # here instead of looping endlessly
        if parent == path:
            raise FileNotFoundError(
                f'no "{project_name}" folder found above {os.getcwd()}'
            )
        path = parent

    return path

In [None]:
# Define local storage directories: the project root is located by walking
# up from the current working directory; the code folder lives directly in it
projectpath = get_project_path_in_notebook()
codepath = os.path.join(projectpath, 'code')

In [None]:
# Switch the working directory to the code folder before importing the
# project packages below (presumably so they resolve relative to the
# working directory -- confirm)
os.chdir(codepath)
# own utility functions
import utils.utils_fileManagement as utilsFiles
import lfpecog_preproc.preproc_import_scores_annotations as importClin
import lfpecog_plotting.plotHelpers as pltHelp

# own data exploration functions
from lfpecog_plotting.plotHelpers import get_colors
# NOTE(review): duplicate of the pltHelp import a few lines above
import lfpecog_plotting.plotHelpers as pltHelp

import lfpecog_plotting.plot_pred_standards as plotPreds

### 0) Define settings

In [None]:

# --- data / feature versions ---
# v4.0: new artef-rem, no reref; v3.0 multiple re-ref
DATA_VERSION = 'v4.0'
# v4: broad-flanks, bursts; v3: broad-flanked SSD
FT_VERSION = 'v8'
INCL_PSD_FTS = ['mean_psd', 'variation']
# subjects excluded when collecting the available subjects
IGNORE_PTS = ['011', '104', '106']

# --- clinical score (CDRS) settings ---
CDRS_RATER = 'Patricia'
ANALYSIS_SIDE = 'BILAT'
INCL_CORE_CDRS = True
CATEGORICAL_CDRS = False

In [None]:
# Collect all subjects for which features are available for the selected
# data/feature versions, leaving out the subjects listed in IGNORE_PTS
SUBS = utilsFiles.get_avail_ssd_subs(
    DATA_VERSION=DATA_VERSION,
    FT_VERSION=FT_VERSION,
    IGNORE_PTS=IGNORE_PTS,
)
print(f'SUBS: n={len(SUBS)} ({SUBS})')


#### Import Pred Results Timon

In [None]:
# Folder containing the pickled prediction results
path = os.path.join(
    utilsFiles.get_project_path('data'), 'prediction_data', 'pred_results'
)

# Earlier result files (kept for reference):
#   'pred_results_0322.pickle'
#   'd_out_NEW_FEATURES_REDUCED_WITH_MOVEMENT_offset_10_dim_4.pickle'  (ft_v7, figures end of March)
#   'ftv8_09APR24.pickle'         (v8 n=all (13 or 21))
#   'ftv8_only13_30APR24.pickle'  (v8, n=13 for all, incl 4-class movement aware)
filename_n13 = 'pred_out_May2024_n13.pickle'  # v8 head-to-head only ECOG
filename_n21 = 'pred_out_May2024_n21.pickle'  # v8, max group sizes

# stop early with a readable message if one of the files is missing
for f in (filename_n13, filename_n21):
    assert os.path.exists(os.path.join(path, f)), (
        f'pickle file: "{f}" not found in {path}'
    )

# NOTE: pickle.load executes arbitrary code from the file; only load
# result files that come from a trusted source
# n = 13 results
with open(os.path.join(path, filename_n13), 'rb') as f:
    res13 = pickle.load(f)

# n = 21 results
with open(os.path.join(path, filename_n21), 'rb') as f:
    res21 = pickle.load(f)

In [None]:
# shallow copy of the n=21 results, used by the older plotting cells below
res = res21.copy()

print(f'keys within res-dict: {res.keys()}')

# Two example entries to inspect: (dict-key, printed headline).
# The key 'ECOG_CEBRA_False_binary' (without 4 class pickle) is deprecated
# in the May '24 version; the keys below are from the 4-class pred pickle.
example_entries = (
    ('ECOGSTN_CEBRA_False_binary_AddMovementLabels_Ephys',
     '\n\nexample STNECOG PREDICTION\n'),
    ('STN_CEBRA_False_binary_AddMovementLabels_Ephys',
     '\n\nexample STN-only PREDICTION\n'),
)

for key, headline in example_entries:
    entry = res[key]
    print(headline, entry.keys())

    # overview of content
    print(len(entry['prediction']))  # array per sub
    print(len(entry['performances']))  # one value per sub
    first_pred = entry['prediction'][0]
    print(type(first_pred), len(first_pred))  # array per sub
    print(entry['performances'][0])  # one value per sub


#### Import Acc and Time data

In [None]:
# Load the metadata pickle (ACC and subject codes) for the chosen source
source = 'STN'
metadat_path = os.path.join(
    utilsFiles.get_project_path('data'),
    'prediction_data',
    f"ACC_dataPlus_{source}.pickle",
)
with open(metadat_path, "rb") as f:
    metadat = pickle.load(f)

In [None]:
# Display the loaded metadata dict; the keys used in this notebook are
# 'sub_ids', 'ACC_RMS' and 'ft_times_all'
metadat

### 1) Visualize True and Predicted Labels vs Activity

In [None]:
def get_sub_pred_dicts(
    pred_result_pickle,
    OUT_PARAM='binary',
    source='nan',
    incl_ECOG=True,
    incl_STN=True,
    CEBRA_bool = False,  # True -> CEBRA model, False -> linear model
    pred_LIDandMOVE = False,
):
    """
    Combine prediction results and metadata into one dict per subject.

    Results in a dict with one dict per sub, for example:
    - '008'
    --- 'y_true', 'y_pred', 'perf', 'acc', 'times'

    Arguments:
        - pred_result_pickle: dict loaded from a prediction result
            pickle; the selected entry needs the keys 'prediction',
            'y_test_true' and 'performances' (one item per subject)
        - OUT_PARAM: 'binary' or 'scale'
        - source: 'STN', 'ECOG' or 'ECOGSTN'; if left at 'nan' the
            source is derived from incl_ECOG and incl_STN
        - incl_ECOG, incl_STN: only used when source == 'nan'
        - CEBRA_bool: True -> CEBRA model, False -> linear model
        - pred_LIDandMOVE: True selects the '..._EphysMov' entry
            (4-state LID-Move prediction), False the '..._Ephys' entry

    Returns:
        - dat_dict: {sub: {'y_true', 'y_pred', 'perf', 'acc', 'times'}}

    Raises:
        - AssertionError: invalid OUT_PARAM or source, composed key
            missing in pred_result_pickle, or prediction / metadata
            lengths that do not match for a subject
        - ValueError: neither STN nor ECOG included

    Note: the per-subject lists are paired with the SORTED unique
    subject ids from the metadata (np.unique), and zip() stops at the
    shortest input, so surplus metadata subjects are silently skipped.
    """

    assert OUT_PARAM in ['binary', 'scale'], 'incorrect outcome parameter'

    # define SOURCE of Features
    if source != 'nan':
        assert source in ['STN', 'ECOG', 'ECOGSTN'], f'incorrect SOURCE: {source}'

    else:
        if incl_STN and incl_ECOG: source = 'ECOGSTN'
        elif incl_STN: source = 'STN'
        elif incl_ECOG: source = 'ECOG'
        else: raise ValueError('at least STN or ECOG has to be inlcuded')

    pred_key = f'{source}_CEBRA_{str(CEBRA_bool)}_{OUT_PARAM}'

    # define only LID pred, or 4-state LID-Move prediction
    if pred_LIDandMOVE:
        target = '_AddMovementLabels_EphysMov'
    else:
        target = '_AddMovementLabels_Ephys'

    pred_key += target

    assert pred_key in pred_result_pickle.keys(), (
        f'composed dict-key {pred_key} not in PICKLE-keys'
        f': {pred_result_pickle.keys()}'
    )

    # Load arrays for Acc and Sub-codes
    meta_dir = os.path.join(utilsFiles.get_project_path('data'),
                            'prediction_data')
    with open(os.path.join(meta_dir, f"ACC_dataPlus_{source}.pickle"),
              "rb") as f:
        metadat = pickle.load(f)

    # get dicts for sub plotting
    dat_dict = {}

    for l_pred, l_true, l_perf, s in zip(
        pred_result_pickle[pred_key]['prediction'],
        pred_result_pickle[pred_key]['y_test_true'],
        pred_result_pickle[pred_key]['performances'],
        np.unique(metadat['sub_ids'])
    ):
        sub_sel = metadat['sub_ids'] == s
        sub_acc = metadat['ACC_RMS'][sub_sel]
        sub_times = metadat['ft_times_all'][sub_sel]
        n_sel = int(np.sum(sub_sel))  # number of metadata samples of this sub

        if len(l_pred) != n_sel:
            print(f'### WARNING: sub-{s}, {source} mismatch y_pred ({len(l_pred)})'
                  f' and metadat  ({n_sel}) -> try correction...')
            # Known mismatch for sub-016 (n=13 results): take the metadata
            # from the ECOG file instead when the source is 'STN'.
            # NOTE(review): the original comment mentioned ECOGSTN while the
            # condition checks 'STN' -- confirm which one is intended
            if s == '016' and source == 'STN':
                with open(os.path.join(meta_dir, "ACC_dataPlus_ECOG.pickle"),
                          "rb") as f:
                    rescue_meta = pickle.load(f)
                # select values from correct dict
                sub_sel = rescue_meta['sub_ids'] == s
                sub_acc = rescue_meta['ACC_RMS'][sub_sel]
                sub_times = rescue_meta['ft_times_all'][sub_sel]
                n_sel = int(np.sum(sub_sel))

        # report the number of SELECTED metadata samples (not the length
        # of the full boolean mask) so the message matches the comparison
        assert len(l_pred) == len(l_true) == n_sel, (
            f'sub-{s} mismatch y_pred ({len(l_pred)}), y_true ({len(l_true)})'
            f' and metadat  ({n_sel})'
        )

        dat_dict[s] = {
            'y_true': l_true,
            'y_pred': l_pred,
            'perf': l_perf,
            'acc': sub_acc,
            'times': sub_times
        }

    return dat_dict


In [None]:
# pick up edits made to the plotting module without restarting the kernel
importlib.reload(plotPreds)

# settings for data import
incl_ECOG, incl_STN = True, True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'

# per-subject predictions of the STN model from the n=21 results
# (the source is given explicitly; incl_ECOG / incl_STN / OUT_PARAM
# above are not passed to this call)
dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res21, source='STN', CEBRA_bool=CEBRA_bool,
)


Plot head to head comparison n=13

In [None]:
def axplot_group_perf_boxes(ax, res_dict):
    """
    Draw one box of per-subject performances for every combination
    of feature source (STN, ECOG, ECOGSTN) and model (LM, CEBRA).

    Arguments:
        - ax: matplotlib axis to draw on
        - res_dict: prediction result dict, passed on to
            get_sub_pred_dicts() as pred_result_pickle

    Returns:
        - ax: the same axis, with six boxes at positions
            0, 1, 3, 4, 6, 7 (grouped per source)

    Prints mean and std of the performance per combination.
    """
    sources = ['STN', 'ECOG', 'ECOGSTN']
    models = ['LM', 'CEBRA']

    box_arrays, box_labels = [], []

    for src, mod in product(sources, models):
        sub_results = get_sub_pred_dicts(
            pred_result_pickle=res_dict,
            source=src,
            CEBRA_bool=(mod == 'CEBRA'),
        )

        # one performance value per subject
        perf = [sub_res['perf'] for sub_res in sub_results.values()]

        box_arrays.append(perf)
        box_labels.append(f'{src}\n{mod}')

        print(f'{src}, {mod}: mean perf: {np.mean(perf)} (+/- {np.std(perf)})')

    ax.boxplot(box_arrays, positions=[0,1,3,4,6,7], labels=box_labels)

    return ax

In [None]:
FIG_NAME = '0912_accuracy_boxes_n13_n21'

fsize = 14

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# left panel: n=13 results, right panel: n=21 results
panels = (
    (res13, 'Head-to-Head, n=13 only'),
    (res21, 'Full Groups, STN: n=21, ECOG: n=13'),
)
for i_ax, (res_panel, title) in enumerate(panels):
    axes[i_ax] = axplot_group_perf_boxes(ax=axes[i_ax], res_dict=res_panel)
    axes[i_ax].set_title(title, size=fsize,)

for ax in axes:
    ax.set_ylabel('Balanced Accuracy', size=fsize,)
    # dashed reference line at 0.5
    ax.axhline(.5, ls='--', alpha=.5, c='gray',)

path = os.path.join(
    utilsFiles.get_project_path('figures'),
    'final_Q1_2024', 'prediction', 'group_v8', 'binary',
)
# saving is currently disabled:
# plt.savefig(os.path.join(path, FIG_NAME),
#             facecolor='w', dpi=300,)

plt.close()

Plot CEBRA with LID and Move labeling

Plot similar boxplots for the 4-state prediction, including small boxes representing the dyskinesia and movement performance.

In [None]:
# Print the mean (+/- std) of 'performances_dys' for the 4-state CEBRA
# models: first ECOG + STN, then ECOG only
for i_mod, mod in enumerate(['ECOGSTN_CEBRA_True', 'ECOG_CEBRA_True']):
    temp = res21[f'{mod}_binary_AddMovementLabels_EphysMov']

    # show the available result keys once
    if i_mod == 0:
        print(temp.keys())

    dys_perf = temp['performances_dys']
    print(f"{mod} mean LID-pred: {np.mean(dys_perf)}"
          f" (+/- {np.std(dys_perf)})")

In [None]:
FIG_NAME = '0917_accuracy_boxes_4state_LM'

# classifier to plot: 'LM' or 'CEBRA'
CLS = 'LM'
cls_str = 'CEBRA_True' if CLS == 'CEBRA' else 'CEBRA_False'

fsize = 14

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# per panel: (result key, label used in the printed check,
#             title suffix, height of the dashed reference line)
panel_specs = (
    ('performances', 'LID x MOV-pred', '4-state-accuracy', .25),
    ('performances_dys', 'LID-4state -pred', 'Dyskinesia-accuracy', .5),
)

for ax, (perf_key, check_lab, title_lab, ref_line) in zip(axes, panel_specs):

    box_arrays, box_labels = [], []

    for mod_src in ['STN', 'ECOGSTN', 'ECOG']:
        # get results for source model
        mod = f'{mod_src}_{cls_str}'
        temp = res21[f'{mod}_binary_AddMovementLabels_EphysMov']
        perf = temp[perf_key]
        box_arrays.append(perf)
        box_labels.append(mod_src)
        # print as check
        print(f"{mod} mean {check_lab}: {np.mean(perf)}"
              f" (+/- {np.std(perf)})")

    ax.boxplot(box_arrays, positions=[0,1,2], labels=box_labels)
    ax.set_title(f'LID x MOV ({CLS}): {title_lab}', size=fsize,)
    ax.axhline(ref_line, ls='--', alpha=.5, c='gray',)

for ax in axes:
    ax.set_ylabel('Balanced Accuracy', size=fsize,)

path = os.path.join(
    utilsFiles.get_project_path('figures'),
    'final_Q1_2024', 'prediction', 'group_v8', 'binary',
)
# saving is currently disabled:
# plt.savefig(os.path.join(path, FIG_NAME),
#             facecolor='w', dpi=300,)

plt.close()

### FIG7 panels

Compare Ratio vs Multivariate Model

In [None]:
# Load Ratio performances
fname = os.path.join(utilsFiles.get_project_path('data'), 'prediction_data', 'ratio_pred_results.json')

with open(fname, 'r') as f:
    ratio_perf = json.load(f)

In [None]:
# Load arrays for Acc and Sub-codes
with open(os.path.join(utilsFiles.get_project_path('data'),
                'prediction_data',
                f"ACC_dataPlus_STN.pickle"),
    "rb"
) as f:
    metadat = pickle.load(f)

mv_subs = metadat['sub_ids']

In [None]:
# list all prediction entries (dict keys) available in the n=21 results
res21.keys()

In [None]:
# per-subject performances of the STN linear model (CEBRA_False) and the
# STN CEBRA model (CEBRA_True), plus accuracies and subject codes of the
# ratio model
stn_key_template = 'STN_CEBRA_{}_binary_AddMovementLabels_Ephys'
lm_acc = res21[stn_key_template.format(False)]['performances']
cebra_acc = res21[stn_key_template.format(True)]['performances']
ratio_acc, ratio_subs = ratio_perf['acc'], ratio_perf['sub']

In [None]:
# show which result fields are stored for the STN linear-model entry
res21[f'STN_CEBRA_False_binary_AddMovementLabels_Ephys'].keys()

In [None]:
# Unique subject codes in order of first appearance (dict keys keep
# insertion order).
# NOTE(review): get_sub_pred_dicts() pairs the per-subject results with
# np.unique(sub_ids), i.e. SORTED codes; indexing the performances with
# this list assumes both orders are the same -- confirm
mv_subs_single = list(dict.fromkeys(mv_subs))


In [None]:
# Align the three models per subject: for every subject of the ratio model,
# look up its position in the multivariate subject list and take the
# linear-model and CEBRA accuracy at that position
paired_acc = {'ratio': [], 'lm': [], 'cebra': [], 'subs': []}
mv_sub_codes = np.array(mv_subs_single)

for s, r in zip(ratio_subs, ratio_acc):
    paired_acc['subs'].append(s)
    paired_acc['ratio'].append(r)

    # first (and only expected) position of this subject
    matches = np.where(mv_sub_codes == s)[0]
    i = matches[0]
    paired_acc['cebra'].append(cebra_acc[i])
    paired_acc['lm'].append(lm_acc[i])


In [None]:
# Paired t-tests between all three model pairs (same subjects per list)
print('Paired t-test on accuracy:')
model_pairs = [('ratio', 'lm'), ('ratio', 'cebra'), ('lm', 'cebra')]

for x, y in model_pairs:
    S, p = ttest_rel(paired_acc[x], paired_acc[y])
    print(f'\t {x} vs {y}, Stat: {S.round(3)}, p = {p.round(4)}')


Plot Group boxplots SOURCES (use cebra based on FIG S5)

In [None]:
FIGNAME = '1018_box_acc_3sources_cebra'
fs = 14
colors = ['khaki', 'lightseagreen', 'goldenrod']

# CEBRA performances per source from the n=13 (head-to-head) results;
# plot label -> source name used in the result keys
source_names = {'STN': 'STN', 'ECoG': 'ECOG', 'STN-ECoG': 'ECOGSTN'}
cebra_models = {
    lab: res13[f'{src}_CEBRA_True_binary_AddMovementLabels_Ephys']['performances']
    for lab, src in source_names.items()
}

print('Paired t-test on accuracy:')
source_pairs = [('STN', 'ECoG'), ('STN', 'STN-ECoG'), ('ECoG', 'STN-ECoG')]
for x, y in source_pairs:
    S, p = ttest_rel(cebra_models[x], cebra_models[y])
    print(f'\t {x} vs {y}, Stat: {S.round(3)}, p = {p.round(4)}')

fig, ax = plt.subplots(1, 1, figsize=(4, 4),)

bplot = ax.boxplot(cebra_models.values(), patch_artist=True,)

# color boxes
for patch, c in zip(bplot['boxes'], colors):
    patch.set_facecolor(c)
    patch.set_edgecolor(c)
    patch.set_alpha(.7)
for med in bplot['medians']:
    med.set_color('k')

# Hard-coded significance bars (xmin / xmax are fractions of the axis width).
# NOTE(review): with the default box positions 1-3 the first bar spans the
# 2nd and 3rd box (ECoG, STN-ECoG) and the second bar the 1st and 2nd box
# (STN, ECoG), while the original comments labeled them the other way
# around ('STN vs ECoG p < .0001' and 'ECOG vs STN-ECOG p = .0008') --
# confirm against the printed t-test results
ax.axhline(1, xmin=0.5, xmax=.83, color='k', alpha=1, )
ax.axhline(1.05, xmin=0.15, xmax=.5, color='k', alpha=1, )

# y-axis: accuracy between 0 and 1, dashed reference line at 0.5
ax.set_ylim(0,1.07)
ax.set_yticks([0, .25, .5, .75, 1])
ax.set_yticklabels(['0', '0.25', '0.5', '0.75', '1'])
ax.set_ylabel('Performance (bal. accuracy)', size=fs)
ax.axhline(.5, color='gray', alpha=.3, ls='--',)

ax.tick_params(axis='both', labelsize=fs, size=fs,)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

ax.set_xticklabels(cebra_models.keys(), weight='bold', size=fs-2,)

plt.tight_layout()

fig_dir = os.path.join(utilsFiles.get_project_path('figures'),
                       'final_Q1_2024', 'prediction', 'FIG7')
plt.savefig(os.path.join(fig_dir, FIGNAME), dpi=300, facecolor='w',)

plt.close()

Plot Group accuracy boxplots MODELS (STN-ratio - STN-lm - STN-cebra)

In [None]:
FIGNAME = '1018_box_acc_ratio_vs_lm_vs_cebra'
fs = 14
colors = np.array(pltHelp.get_colors('Jacoba'))[[4, 0, 2]]

# accuracies of the three STN models (unpaired: all available subjects
# per model)
stn_key_template = 'STN_CEBRA_{}_binary_AddMovementLabels_Ephys'
rat = ratio_perf['acc']
lm = res21[stn_key_template.format(False)]['performances']
cebra = res21[stn_key_template.format(True)]['performances']

fig, ax = plt.subplots(1, 1, figsize=(4, 4),)

bplot = ax.boxplot(
    [rat, lm, cebra],
    labels=['Ratio', 'LM', 'CEBRA'],
    patch_artist=True,
)

# color boxes
for patch, c in zip(bplot['boxes'], colors):
    patch.set_facecolor(c)
    patch.set_edgecolor(c)
    patch.set_alpha(.7)
for med in bplot['medians']:
    med.set_color('k')

# hard-coded significance bars (xmin / xmax are fractions of the axis width)
ax.axhline(1, xmin=0.5, xmax=.83, color='k', alpha=1, )
ax.axhline(1.05, xmin=0.15, xmax=.83, color='k', alpha=1, )

# y-axis: accuracy between 0 and 1, dashed reference line at 0.5
ax.set_ylim(0,1.07)
ax.set_yticks([0, .25, .5, .75, 1])
ax.set_yticklabels(['0', '0.25', '0.5', '0.75', '1'])
ax.set_ylabel('Performance (bal. accuracy)', size=fs)
ax.axhline(.5, color='gray', alpha=.3, ls='--',)

ax.tick_params(axis='both', labelsize=fs, size=fs,)
for side in ('top', 'right'):
    ax.spines[side].set_visible(False)

# final tick labels (replace the labels given to boxplot above)
ax.set_xticklabels(['STN-ratio', 'STN-LM', 'STN-CEBRA'], weight='bold', size=fs-2,)

plt.tight_layout()

fig_dir = os.path.join(utilsFiles.get_project_path('figures'),
                       'final_Q1_2024', 'prediction', 'FIG7')
plt.savefig(os.path.join(fig_dir, FIGNAME), dpi=300, facecolor='w',)

plt.close()

Plot individual accuracies vs CDRS severity

In [None]:
# display the 'Jacoba' color list; its entries 4, 0 and 2 are used as
# model colors in the figures of this section
pltHelp.get_colors('Jacoba')

In [None]:
FIGNAME = '1018_acc_vs_cdrs_3models'
fs=14
colors = np.array(pltHelp.get_colors('Jacoba'))[[4, 0, 2]]

# Accuracies per model, ALIGNED on paired_acc['subs'].
# lm_acc / cebra_acc are ordered like the multivariate subject list (and
# can hold more subjects), so pairing them directly with the ratio-model
# subjects attributes accuracies to the wrong subject; paired_acc holds
# the values that were matched per subject.
panel_values = {
    'STN ratio-biomarker': paired_acc['ratio'],
    'STN LM (multivariate)': paired_acc['lm'],
    'STN CEBRA (multivariate)': paired_acc['cebra'],
}

# x-position per subject: maximum CDRS score plus a small shift so that
# subjects with the same maximum do not overlap. This is identical for
# all panels, so the clinical scores are loaded once per subject.
# NOTE(review): the shift is .2 * (number of occurrences so far), so the
# first subject with a given score is already shifted by .2 -- confirm
# whether .2 * (w - 1) was intended
bars_x, all_x = [], []
for s in paired_acc['subs']:
    cdrs_times, cdrs_scores = importClin.get_cdrs_specific(sub=s, INCL_CORE_CDRS=True,)
    max_lid = max(cdrs_scores)
    all_x.append(max_lid)
    # get x-deviation if value doubles
    w = sum(np.array(all_x) == max_lid)
    bars_x.append(max_lid + (.2 * w))

fig, axes = plt.subplots(
    3, 1, figsize=(6, 8),
    # gridspec_kw={'width_ratios': [1, 4]},
)

for i, (lab, values) in enumerate(panel_values.items()):

    # every subject needs exactly one accuracy value
    assert len(values) == len(bars_x), (
        f'{lab}: {len(values)} accuracies for {len(bars_x)} subjects'
    )
    bars_y = list(values)

    axes[i].bar(bars_x, height=bars_y, alpha=.5, facecolor=colors[i],
                width=.5,)

    axes[i].set_title(lab, size=fs, weight='bold',)


for ax in axes:
    ax.set_xlabel('Maximum LID severity (sum CDRS)', size=fs,)
    ax.set_ylim(0, 1)
    ax.set_ylabel('Performance\n(bal. accuracy)', size=fs,)
    ax.axhline(0.5, ls='--', c='gray', alpha=.5,)
    ax.set_yticks([0, .25, .5, .75, 1])
    ax.set_yticklabels(['0', '0.25', '0.5', '0.75', '1'])

    ax.tick_params(axis='both', labelsize=fs, size=fs,)
    for side in ['top','right']:
        ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig(os.path.join(utilsFiles.get_project_path('figures'),
            'final_Q1_2024', 'prediction', 'FIG7', FIGNAME),
            dpi=300, facecolor='w',)

plt.close()

Plot similar barplot for acc vs activity

In [None]:
# Per-subject predictions of the STN model (n=21 results); CEBRA_bool is
# taken from the settings cell above
dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res21,
    source='STN',
    CEBRA_bool=CEBRA_bool,
)

# Bin edges for the activity values. Bin 0 holds everything below the
# first edge, bin i the values in [edge i-1, edge i), and the last bin
# everything from the second-to-last edge upwards (the final edge itself
# is not used as a boundary).
ACT_RANGES = np.arange(-2, 4, .35)
n_bins = len(ACT_RANGES)

# one list per activity bin, holding one accuracy value per subject
# that has LID / no-LID samples within that bin
lid_acc = [[] for _ in range(n_bins)]
nolid_acc = [[] for _ in range(n_bins)]

for i, act_x in enumerate(ACT_RANGES):

    for subdat in dat_dict.values():

        sub_act = subdat['acc']

        # half-open bins: a sample exactly on an edge belongs to the
        # upper bin (strict comparisons on both sides dropped it)
        if i == 0:
            act_sel = sub_act < act_x
        elif i == n_bins - 1:
            act_sel = sub_act >= ACT_RANGES[i-1]
        else:
            act_sel = np.logical_and(sub_act < act_x,
                                     sub_act >= ACT_RANGES[i-1])

        if not np.any(act_sel): continue

        # true and predicted labels within this activity bin
        true_in_bin = subdat['y_true'][act_sel]
        pred_in_bin = subdat['y_pred'][act_sel]
        sel_LID = true_in_bin == 1
        sel_noLID = true_in_bin == 0

        # if LID or no-LID samples are present: add the fraction of
        # correctly predicted samples
        if np.any(sel_LID):
            lid_acc[i].append(np.mean(pred_in_bin[sel_LID] == 1))
        if np.any(sel_noLID):
            nolid_acc[i].append(np.mean(pred_in_bin[sel_noLID] == 0))
        

In [None]:
FIGNAME = 'FIG7_acc_lidNolid_vs_act'

# colors[0]: no LID, colors[1]: LID
colors = ['darkgreen', 'darkorchid']

fig, ax = plt.subplots(1,1, figsize=(9, 3))

for i, values in enumerate([nolid_acc, lid_acc]):
    # no-LID boxes left, LID boxes right of every activity-bin position
    xdf = -.25 if i == 0 else .25
    bplot = ax.boxplot(values, patch_artist=True, widths=.4,
                       positions=np.arange(len(values)) + xdf,)
    for b in bplot['boxes']:
        b.set_edgecolor(colors[i])
        b.set_facecolor(colors[i])
        b.set_alpha(.5)
    for m in bplot['medians']:
        m.set_color('k')
    # boxplot returns TWO whisker lines per box; color them in their own
    # loop (zipping them with the medians stopped after half of them)
    for w in bplot['whiskers']:
        w.set_color(colors[i])

ax.set_ylabel('Accuracy (a.u.)', size=fs, weight='bold',)
ax.set_ylim(0, 1.1)
ax.axhline(0.5, ls='--', c='gray', alpha=.5,)
ax.set_yticks([0, .25, .5, .75, 1])
ax.set_yticklabels(['0', '0.25', '0.5', '0.75', '1'])

# only label both ends of the activity axis
# ax.set_xlabel('Detected movement (z-scored acc-rms)', size=fs,)
ax.set_xticks([2, len(ACT_RANGES) - 2],)
ax.set_xticklabels(['No movement', 'Many movements'], weight='bold',)
ax.set_xlim(0, len(ACT_RANGES) + 1)

# Add Legend for Boxes
fig.text(0.98, 0.8, 'no LID', color='k', ha='right',
            weight='bold', size=fs-4,
            bbox={'alpha': .3, 'color': colors[0]},)
fig.text(0.98, 0.65, 'LID', color='k', ha='right',
            weight='bold', size=fs-4,
            bbox={'alpha': .3, 'color': colors[1]},)

ax.tick_params(axis='both', labelsize=fs, size=fs,)
for side in ['top','right']:
    ax.spines[side].set_visible(False)

plt.tight_layout()
plt.savefig(os.path.join(utilsFiles.get_project_path('figures'),
            'final_Q1_2024', 'prediction', 'FIG7', FIGNAME),
            dpi=300, facecolor='w',)

plt.close()

Old Result Plots: Individual and Group

In [None]:
# Old Plotting

# sub = '008'

# if CEBRA_bool:
#     modelname = 'cebra_'
#     model = 'cebra'
# else:
#     modelname = 'linMod'
#     model = 'lm'

# if incl_ECOG:
#     modelname += 'inclEcog'
# else:
#     modelname += 'StnOnly'

# for sub in dat_dict.keys():

#     plotPreds.plot_subPreds_over_time(
#         lid_out_param=OUT_PARAM,
#         sub_dict=dat_dict[sub],
#         SAVE_PLOT=True, SHOW_PLOT=False,
#         fig_name=f'0325_predRows_{modelname}_Lid{OUT_PARAM}_sub{sub}',
#         model=model,
#     )

In [None]:
# pick up edits made to the plotting module without restarting the kernel
importlib.reload(plotPreds)

# settings for data import
incl_ECOG = True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'

# model tag used in figure names, and short model name
modelname, model = ('cebra_', 'cebra') if CEBRA_bool else ('linMod', 'lm')
modelname += 'inclEcog' if incl_ECOG else 'StnOnly'



# dat_dict = get_sub_pred_dicts(
#     pred_result_pickle=res,
#     OUT_PARAM=OUT_PARAM,
#     incl_ECOG=incl_ECOG,
#     CEBRA_bool=CEBRA_bool,
# )



# plotPreds.plot_groupPreds_over_time(
#     lid_out_param=OUT_PARAM,
#     dat_dict=dat_dict,
#     SAVE_PLOT=True, SHOW_PLOT=False,
#     fig_name=f'0326_predGroupRows_{modelname}_Lid{OUT_PARAM}',
#     model=model,
# )

### Plot overall prediction performance

In [None]:
# prefix for the figure names of the performance plots below
PLOT_DATE = '0429'

In [None]:
# Inspect the structure of dat_dict.
# When the notebook is run top to bottom, dat_dict is still the flat
# {sub: {...}} dict at this point; the nested {'STN_only': ..., 'ECOG_incl': ...}
# version is only created in the cell below. Indexing 'STN_only'
# unconditionally therefore raised a KeyError on a fresh run.
print(dat_dict.keys())
if 'STN_only' in dat_dict:
    print(dat_dict['STN_only']['008'].keys())
else:
    first_key = next(iter(dat_dict))
    print(f'{first_key}: {dat_dict[first_key].keys()}')

#### Binary Performance

In [None]:
SAVE_PLOT = True
SHOW_PLOT = False
fig_name = f'{PLOT_DATE}_binaryPerfBoxes'
CEBRA = False
if CEBRA: fig_name += '_cebra'
else: fig_name += '_lm'

# binary predictions without ECoG (source resolves to 'STN') and with
# ECoG (source resolves to 'ECOGSTN')
dat_dict = {
    'STN_only': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='binary',
        incl_ECOG=False,
        CEBRA_bool=CEBRA,
    ),
    'ECOG_incl': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='binary',
        incl_ECOG=True,
        CEBRA_bool=CEBRA,
    )
}

# one color per metric: bal. accuracy, accuracy, sensitivity, specificity
clrs = ['peru', 'khaki', 'mediumturquoise', 'plum']
fsize = 14
BOX_WIDTH = .2
BOX_TICKS = np.array([i * BOX_WIDTH for i in np.arange(len(clrs))])

xlabel_ticks, xlabels = [], []

fig, ax = plt.subplots(1, 1, figsize=(6, 3))

# dashed reference line at 0.5 over the full axis width (xmin / xmax of
# axhline are axis fractions between 0 and 1; the defaults span the axis)
ax.axhline(.5, ls='dashed',
            color='gray', lw=1, alpha=.5, zorder=1,)

for i, (key, dat) in enumerate(dat_dict.items()):

    bal_acc, acc, sens, spec = [], [], [], []

    for s in dat.keys():
        y_pred, y_true = dat[s]['y_pred'], dat[s]['y_true']

        # calculate performance numbers
        bal_acc.append(dat[s]['perf'])
        n_pred = len(y_pred)
        tp = np.sum(np.logical_and(y_pred == 1, y_true == 1))  # true pos
        fp = np.sum(np.logical_and(y_pred == 1, y_true == 0))  # false pos
        tn = np.sum(np.logical_and(y_pred == 0, y_true == 0))  # true neg
        fn = np.sum(np.logical_and(y_pred == 0, y_true == 1))  # false neg

        # add metrics
        acc.append((tp + tn) / n_pred)
        if (tp + fn) > 0:
            sens.append(tp / (tp + fn))
        else:
            print(f'\t{key}: no sensitivity calculalted for sub-{s} (no dyskinesia)')
        # same guard as for sensitivity: without true negatives or false
        # positives the ratio is 0 / 0 and would put a NaN in the boxplot
        if (tn + fp) > 0:
            spec.append(tn / (tn + fp))
        else:
            print(f'\t{key}: no specificity calculated for sub-{s}'
                  ' (no samples without dyskinesia)')

    # PLOT BOXES
    boxes = ax.boxplot([bal_acc, acc, sens, spec],
                positions=BOX_TICKS + i, widths=BOX_WIDTH,
                patch_artist=True,)
    xlabel_ticks.append(BOX_TICKS[1] + i)
    xlabels.append(key)
    # make boxplots pretty
    for patch, clr in zip(boxes['boxes'], clrs):
        patch.set_facecolor(clr)
        patch.set_alpha(.7)
    for median in boxes['medians']:
        median.set_color('black')

ax.set_xticks(xlabel_ticks)
ax.set_xticklabels(xlabels)
ax.set_xlim(-BOX_WIDTH, xlabel_ticks[-1] + BOX_WIDTH*8)
ax.set_ylim(-.1, 1.1)
ax.set_yticks([0, .5, 1])
ax.set_yticklabels(['0', '0.5', '1'])
ax.set_ylabel('Performance', size=fsize, weight='bold',)
if CEBRA: xlab = 'CEBRA model'
else: xlab = 'Linear model'
ax.set_xlabel(xlab, size=fsize, weight='bold',)

ax.spines[['right', 'top']].set_visible(False)

ax.tick_params(axis='both', size=fsize, labelsize=fsize)

# Add Legend for Boxes (same order as the metrics / colors above)
legend_labels = ['Bal. accuracy', 'Accuracy', 'Sensitivity', 'Specificity']
for y_txt, leg_lab, clr in zip([0.7, 0.6, 0.5, 0.4], legend_labels, clrs):
    fig.text(0.7, y_txt, leg_lab, color='k',
                weight='bold', size=fsize, va='top',
                bbox={'alpha': .4, 'color': clr, 'lw': 0,},)

plt.tight_layout()


if SAVE_PLOT:
    path = os.path.join(utilsFiles.get_project_path('figures'),
                    'final_Q1_2024', 'prediction', 'group_v8', 'binary')
    plt.savefig(os.path.join(path, fig_name),
                    facecolor='w', dpi=300,)

if not SHOW_PLOT: plt.close()
else: plt.show()

#### Scale Performance

- calculate performance of scaling model (full CDRS score).
- Plot Group Rho and abs mean error; plus individual rho and MAE sorted on Max-CDRS

In [None]:
# show the per-subject fields; needs the nested dat_dict
# ({'STN_only': ..., 'ECOG_incl': ...}) built in the binary-performance cell
dat_dict['ECOG_incl']['008'].keys()

In [None]:
# Scale-model performance figure: per model (STN_only / ECOG_incl) one
# group box plus per-subject bars, for the mean absolute error (upper
# axis) and the Spearman correlation (lower axis).
SAVE_PLOT = True
SHOW_PLOT = False
fig_name = f'{PLOT_DATE}_scalePerfBoxes'
CEBRA = False
if CEBRA: fig_name += '_cebra'
else: fig_name += '_lm'

# scale predictions without ECoG and with ECoG
dat_dict = {
    'STN_only': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=False,
        CEBRA_bool=CEBRA,
    ),
    'ECOG_incl': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=True,
        CEBRA_bool=CEBRA,
    )
}

# only the lists under key 0 are used below, indexed with the model index i
clrs = {0: ['tan', 'steelblue'], 1: ['None', 'None']}
edgeclrs = {0: ['tan', 'steelblue'],
            1: ['tan', 'steelblue']}
metric_labs = ['Mean\nabs. error', 'Correlation\n(rho)']
fsize = 14
BOX_WIDTH = .2
# single tick offset (np.arange(1) -> [0])
BOX_TICKS = np.array([i * BOX_WIDTH for i in np.arange(1)])  # len(clrs)

xlabel_ticks, xlabels = [], []

fig, axes = plt.subplots(2, 1, figsize=(8, 4))

# ax.axhline(.5, xmin=-BOX_WIDTH, xmax=10, ls='dashed',
#             color='gray', lw=1, alpha=.5, zorder=1,)

for i, (key, dat) in enumerate(dat_dict.items()):
    
    # metric lists and the matching max-true-score lists are filled in
    # parallel (mae_list <-> maxlid_mae, corrs <-> maxlid_corr), so that
    # the metrics can be sorted on the maximum true score below
    # NOTE: subs is filled but not used further in this cell
    mae_list, corrs, subs, maxlid_corr, maxlid_mae = [], [], [], [], []
    xlabel_ticks.append(BOX_TICKS[0] + i)
    xlabels.append(key)

    # always loop over the STN_only subjects, also for the ECOG_incl model
    for s in dat_dict['STN_only'].keys():

        # subjects whose code starts with '1' are skipped for ECOG_incl
        # (presumably subjects without ECoG -- confirm); they get NaN
        # metrics and a max score of 0 as placeholders, which keeps one
        # entry per subject in the lists
        if s.startswith('1') and key == 'ECOG_incl':
            maxlid_mae.append(0)
            maxlid_corr.append(0)
            mae_list.append(np.nan)
            corrs.append(np.nan)

            continue

        # calculate performance numbers
        pred_y = dat[s]['y_pred'].copy()
        true_y = dat[s]['y_true'].copy()
        rho, pval = spearmanr(pred_y, true_y)
        # the stored 'perf' value is used as mean absolute error here
        mae = dat[s]['perf'].copy()
        subs.append(s)

        # NaN values are left out, together with their max-score entry
        if not np.isnan(mae):
            mae_list.append(mae)
            maxlid_mae.append(max(true_y))

        if not np.isnan(rho):
            corrs.append(rho)
            maxlid_corr.append(max(true_y))
    
    
        
    # PLOT BOXES
    # upper axis: mean abs. error, lower axis: correlation
    for i_ax, (ax, metric, maxlid) in enumerate(
        zip(axes, [mae_list, corrs], [maxlid_mae, maxlid_corr])
    ):
        # group box at x = model index, without the NaN placeholders
        nonnan_metrics = [m for m in metric if not np.isnan(m)]
        group_metrics = ax.boxplot(nonnan_metrics, positions=[1 * i], widths=.5,
                                    patch_artist=True,)
        
        # make boxplots pretty (incl sign)
        # NOTE: zipping with the clrs dict iterates over its keys; clr is
        # not used, the colors are taken from clrs[0] / edgeclrs[0]
        for i_box, (patch, clr) in enumerate(zip(group_metrics['boxes'], clrs)):
            patch.set_facecolor(clrs[0][i])
            patch.set_edgecolor(edgeclrs[0][i])
            # if i == 1: patch.set_hatch('OO')
            patch.set_alpha(.7)
            patch.set_linewidth(0)
        for median in group_metrics['medians']:
            median.set_color('black')

        # plot boxes sorted on LID
        # one bar per entry, ordered on ascending maximum true score;
        # bars start at x = 2.5 and the second model is shifted by .2
        idx_lid = np.argsort(maxlid)
        sortedboxes = list(np.array(metric)[idx_lid])
        xpos = np.array([2.5 + (.5 * b) for b in np.arange(len(metric))])
        bar_pms = {'facecolor': clrs[0][i], 'edgecolor': edgeclrs[0][i],
                    'alpha': .7,}
        # if i == 1: bar_pms['hatch'] = 'OO'
        bar = ax.bar(xpos + (i * .2), height=sortedboxes, width=.2, **bar_pms)
        

        
        # ax.set_xlim(-BOX_WIDTH, xlabel_ticks[-1] + BOX_WIDTH*8)

        ax.set_ylabel(metric_labs[i_ax], size=fsize, weight='bold',)
        # NOTE: xlab is only used by the commented-out set_xlabel below
        if CEBRA: xlab = 'CEBRA model'
        else: xlab = 'Linear model'
        # ax.set_xlabel(xlab, size=fsize, weight='bold',)
        ax.spines[['right', 'top']].set_visible(False)
        ax.tick_params(axis='both', size=fsize, labelsize=fsize)

# x-ticks: group boxes at 0 and 1; only the lower axis gets tick labels
axes[0].set_xticks([0, 1])  
axes[0].set_xticklabels([])  
axes[1].set_xticks([0, 1, 6])
axes[1].set_xticklabels(['Group means', ' ', 'subjects sorted on max-dyskinesia-severity'])   


axes[0].set_ylim(0, 7)
axes[0].set_yticks([0, 2.5, 5])
axes[0].set_yticklabels(['0', '2.5', '5'])
axes[1].set_ylim(-.1, 1.1)
axes[1].set_yticks([0, .5, 1])
axes[1].set_yticklabels(['0', '0.5', '1'])
# Add Legend for Boxes
fig.text(0.35, 0.9, list(dat_dict.keys())[0], color='k',
            weight='bold', size=fsize, va='top',
            bbox={'alpha': .4, 'color': clrs[0][0], 'lw': 0,},)
fig.text(0.5, 0.9, list(dat_dict.keys())[1], color='k',
            weight='bold', size=fsize, va='top',
            bbox={'alpha': .4, 'color': clrs[0][1], 'lw': 0,},)

plt.tight_layout()


if SAVE_PLOT:
    path = os.path.join(utilsFiles.get_project_path('figures'),
                    'final_Q1_2024', 'prediction', 'group_v8', 'scale')
    plt.savefig(os.path.join(path, fig_name),
                    facecolor='w', dpi=300,)

if not SHOW_PLOT: plt.close()
else: plt.show()

#### Scale results: scatter with correlation

In [None]:
from matplotlib.gridspec import GridSpec

Plot single scatters

In [None]:
# Per subject: one figure with the 'acc' trace (top row), the true scores
# (second row) and below that a true-vs-predicted scatter per model variant
# (STN-only left, ECoG-included right), annotated with Spearman's R.
SAVE_PLOT = True
SHOW_PLOT = False

# scale-predictions of the CEBRA model, without and with ECoG features
dat_dict = {
    'STN_only': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=False,
        CEBRA_bool=True,
    ),
    'ECOG_incl': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=True,
        CEBRA_bool=True,
    )
}

fsize = 14
# NOTE(review): `clrs` is not used in this cell; the scatters use matplotlib's
# default color. Kept because other cells may read it - confirm or remove.
clrs = ['teal', 'darkkhaki']

# GridSpec resolution (columns, rows); this is NOT the figure size in inches
figw, figh = (16, 6)
half_w = figw // 2

# loop-invariant output folder, resolved once
if SAVE_PLOT:
    path = os.path.join(utilsFiles.get_project_path('figures'),
                        'final_Q1_2024', 'prediction', 'cebra', 'scale', 'scatters')

for s in dat_dict['STN_only'].keys():
    fig = plt.figure(figsize=(8, 4),)
    gs = GridSpec(figh, figw, figure=fig)

    # --- meta rows: 'acc' and true scores over time (STN-only dict, i.e.
    # the first variant) ---
    meta = dat_dict['STN_only'][s]
    ax_acc = fig.add_subplot(gs[0, :])
    ax_acc.plot(meta['times'], meta['acc'])
    # share the x-axis at creation; `get_shared_x_axes().join()` was
    # deprecated in matplotlib 3.6 and removed in 3.8
    ax_true = fig.add_subplot(gs[1, :], sharex=ax_acc)
    ax_true.plot(meta['times'], meta['y_true'])
    plt.setp(ax_acc.get_xticklabels(), visible=False)
    ax_true.set_ylabel('LID', rotation=0, labelpad=15,
                       weight='bold', size=fsize,)
    ax_acc.set_ylabel('ACC', rotation=0, labelpad=15,
                      weight='bold', size=fsize,)

    # --- scatter rows: one panel per model variant ---
    scat_axes = []

    for i, (key, dat) in enumerate(dat_dict.items()):
        # NOTE(review): presumably subjects with IDs starting with '1' have
        # no ECoG data - confirm
        if s.startswith('1') and key == 'ECOG_incl': continue

        # correlation is computed on the raw (un-jittered) values
        pred_y = dat[s]['y_pred'].copy()
        true_y = dat[s]['y_true'].copy()
        rho, pval = spearmanr(pred_y, true_y)

        xjitt, yjitt = pltHelp.get_plot_jitter(x_temp=true_y, y_temp=pred_y,
                                               jit_width=.2)
        # out-of-place addition: an in-place `+=` raises a casting error
        # when the scores are stored as integers and the jitter is float
        pred_y = pred_y + yjitt
        true_y = true_y + xjitt

        # left half for the first variant, right half for the second
        if i == 0: scat_axes.append(fig.add_subplot(gs[2:, :half_w]))
        if i == 1: scat_axes.append(fig.add_subplot(gs[2:, half_w:]))

        if all(dat[s]['y_true'] == 0):
            leglab = f'{s}: {key}\n(no dyskinesia)'
        else:
            leglab = f'{s}, R: {rho.round(2)}, p: {round(pval, 3)}'

        scat_axes[i].scatter(true_y, pred_y, alpha=.3, label=leglab)
        if i == 0: scat_axes[0].set_ylabel('Predicted CDRS', weight='bold',
                                           size=fsize, )

    for ax in scat_axes:
        ax.set_xlabel('True CDRS', weight='bold', size=fsize,)
        ax.legend(fontsize=fsize-4, frameon=False, loc='lower right')

    fig.subplots_adjust(wspace=0.3, hspace=2)

    if SAVE_PLOT:
        fig.savefig(os.path.join(path, f'{s}_cebra_scatter'),
                    facecolor='w', dpi=300,)

    # close this specific figure so figures do not pile up across subjects
    if not SHOW_PLOT: plt.close(fig)
    else: plt.show()



## Visualize activity in predictions

In [None]:
import lfpecog_plotting.plot_Spectrals_vs_LID as plotSpectrals
import lfpecog_plotting.plot_activity_in_preds as plotPredAct

In [None]:
# Inspect which fields are stored for one example subject.
# The subject dict is fetched explicitly instead of via `dat_dict['019']`:
# in a top-to-bottom run `dat_dict` is keyed by model variant
# ('STN_only' / 'ECOG_incl') at this point, so indexing it with a subject ID
# raises a KeyError.
# NOTE(review): variant (scale output, STN-only, CEBRA) chosen to match the
# next section - adjust if another variant was intended.
sub_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM='scale',
    incl_ECOG=False,
    CEBRA_bool=True,
)['019']

# last expression of the cell, so the keys are actually displayed
sub_dict.keys()

#### Scatter scale-predictions over Activity and True-CDRS

In [None]:
# re-import the plotting module so that recent edits to it take effect
importlib.reload(plotPredAct)

# --- settings for data import ---
incl_ECOG = False
CEBRA_bool = True    # True -> CEBRA model, False -> linear model
OUT_PARAM = 'scale'  # scatter is only meaningful for scale output, not binary

# model tag '<cebra|lm>_<inclEcog|StnOnly>', used in the figure name and
# handed to the plotting function
model = (f"{'cebra' if CEBRA_bool else 'lm'}_"
         f"{'inclEcog' if incl_ECOG else 'StnOnly'}")

# per-subject prediction results for the selected model variant
dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

fig_name = f'{PLOT_DATE}_predErrorScatter_{model}'

plotPredAct.scatter_predErrors(
    dat_dict,
    SAVE_PLOT=True,
    fig_name=fig_name,
    ROUND_PREDS=False,
    modelname=model,
    out_param=OUT_PARAM,
)



#### Plot activities in binary pos and neg predictions

Plots relative histograms of true and false predictions

In [None]:
# re-import the plotting module so that recent edits to it take effect
importlib.reload(plotPredAct)

# --- settings for data import ---
incl_ECOG = True
CEBRA_bool = True     # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'  # this plot needs binary output, not scale

# model tag '<cebra|lm>_<inclEcog|StnOnly>', used in the figure name and
# handed to the plotting function
model = (f"{'cebra' if CEBRA_bool else 'lm'}_"
         f"{'inclEcog' if incl_ECOG else 'StnOnly'}")

# per-subject prediction results for the selected model variant
dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

plotPredAct.plot_binary_act_distr(
    dat_dict=dat_dict,
    model=model,
    SAVE_PLOT=True,
    SHOW_PLOT=False,
    fig_name=f'{PLOT_DATE}_binary_act_pred_{model}',
)


Plot Predictive Performance over act-bins
- Sensitivity / Specificity
- Accuracies of positive and negative predictions

In [None]:
# re-import the plotting module so that recent edits to it take effect
importlib.reload(plotPredAct)

# --- settings ---
OUT_PARAM = 'binary'  # this plot needs binary output, not scale
SensSpec = True       # forwarded as `SensSpec` to the plotting function
PredAccur = False     # forwarded as `PredAcc` to the plotting function

# one figure per model variant: all four combinations of
# (ECoG included / STN only) x (CEBRA / linear model)
for incl_ECOG in (True, False):
    for CEBRA_bool in (True, False):

        # model tag '<cebra|lm>_<inclEcog|StnOnly>'
        model = (f"{'cebra' if CEBRA_bool else 'lm'}_"
                 f"{'inclEcog' if incl_ECOG else 'StnOnly'}")

        # per-subject prediction results for this variant
        dat_dict = get_sub_pred_dicts(
            pred_result_pickle=res,
            OUT_PARAM=OUT_PARAM,
            incl_ECOG=incl_ECOG,
            CEBRA_bool=CEBRA_bool,
        )

        plotPredAct.plot_predValues_per_ActBin(
            dat_dict=dat_dict,
            model=model,
            SensSpec=SensSpec,
            PredAcc=PredAccur,
            SAVE_PLOT=True,
            SHOW_PLOT=False,
            fig_name=f'{PLOT_DATE}_PredSensSpec_ActBins_{model}',
        )
