In [12]:
# Importing Python and external packages
import os
import sys
import importlib
import json
import csv
import pickle
from itertools import compress
import pandas as pd
import numpy as np
from itertools import product
from scipy.stats import spearmanr

import matplotlib.pyplot as plt


In [2]:
def get_project_path_in_notebook(
    subfolder: str = '',
):
    """
    Finds path of projectfolder from Notebook.

    Walks upwards from the current working directory until a folder
    named exactly 'dyskinesia_neurophys' is found. Start running this
    once to correctly find other modules/functions.

    Args:
        subfolder: optional folder name inside the project folder that is
            joined to the returned path. The default ('') returns the
            project folder itself.

    Returns:
        Path of the project folder (or of `subfolder` inside it).

    Raises:
        FileNotFoundError: if neither the current working directory nor
            any of its parents is named 'dyskinesia_neurophys'.
    """
    project_name = 'dyskinesia_neurophys'
    path = os.getcwd()

    # compare the whole folder name; a suffix check on the last 20
    # characters would also accept e.g. 'old_dyskinesia_neurophys'
    while os.path.basename(path) != project_name:
        parent = os.path.dirname(path)
        if parent == path:
            # reached the filesystem root (dirname of the root is the root
            # itself); without this check the loop would never terminate
            raise FileNotFoundError(
                f"no folder '{project_name}' found above {os.getcwd()}"
            )
        path = parent

    if subfolder:
        path = os.path.join(path, subfolder)

    return path

In [3]:
# define local storage directories
# local storage directories: project root, and the 'code' folder inside it
projectpath = get_project_path_in_notebook()
codepath = os.path.join(
    projectpath,
    'code',
)

In [4]:
# switch the working directory to the project's 'code' folder before the
# project-local imports below (utils, lfpecog_plotting).
# NOTE(review): presumably these packages resolve only because the kernel
# searches the cwd for imports -- confirm; keep this line before the imports
os.chdir(codepath)
# own utility functions
import utils.utils_fileManagement as utilsFiles
# own data exploration / plotting helper functions
from lfpecog_plotting.plotHelpers import get_colors
import lfpecog_plotting.plotHelpers as pltHelp

# plotting functions for prediction results (e.g. predictions over time)
import lfpecog_plotting.plot_pred_standards as plotPreds

### 0) Define settings

In [5]:

# --- data and feature versions ---
# data: v4.0 = new artefact removal, no re-referencing; v3.0 = multiple re-referencings
DATA_VERSION = 'v4.0'
# features: v8 in use; older notes: v4 = broad flanks + bursts, v3 = broad-flanked SSD
FT_VERSION = 'v8'
INCL_PSD_FTS = ['mean_psd', 'variation']
# subject codes passed as IGNORE_PTS to the subject lookup below
IGNORE_PTS = ['011', '104', '106']

# --- dyskinesia-label / analysis settings ---
CDRS_RATER = 'Patricia'
ANALYSIS_SIDE = 'BILAT'
INCL_CORE_CDRS = True
CATEGORICAL_CDRS = False

In [6]:
# get all available subs with features
# collect every subject that has features for the selected data/feature version
SUBS = utilsFiles.get_avail_ssd_subs(
    DATA_VERSION=DATA_VERSION,
    FT_VERSION=FT_VERSION,
    IGNORE_PTS=IGNORE_PTS,
)
print(f'SUBS: n={len(SUBS)} ({SUBS})')


SUBS: n=21 (['017', '021', '023', '107', '019', '014', '020', '110', '105', '109', '013', '016', '010', '102', '008', '101', '009', '108', '022', '012', '103'])


#### Import prediction results (Timon)

In [29]:
# folder containing the stored prediction results
path = os.path.join(utilsFiles.get_project_path('data'),
                    'prediction_data', 'pred_results')

# select ONE results file; earlier versions are kept for reference:
# filename = 'pred_results_0322.pickle'
# filename = 'd_out_NEW_FEATURES_REDUCED_WITH_MOVEMENT_offset_10_dim_4.pickle'  # ft_v7, figures end of March
# filename = 'ftv8_09APR24.pickle'  # v8 n=all (13 or 21)
filename = 'ftv8_only13_30APR24.pickle'  # v8, n=13 for all, incl 4-class movement aware
pred_file = os.path.join(path, filename)

assert os.path.exists(pred_file), f'prediction pickle file not found: {pred_file}'

# NOTE: pickle.load can execute arbitrary code; only load trusted,
# locally produced result files here
with open(pred_file, 'rb') as f:
    res = pickle.load(f)

In [31]:
print(res.keys())

# result set to inspect
# key = 'ECOG_CEBRA_False_binary'  # key format of pickles without 4-class results
key = 'ECOGSTN_CEBRA_False_binary_AddMovementLabels_Ephys'  # key format of the 4-class prediction pickle
print(res[key].keys())

# overview of the content: 'prediction' holds one array per subject,
# 'performances' one value per subject
for entry in ('prediction', 'performances'):
    print(len(res[key][entry]))
print(type(res[key]['prediction'][0]), len(res[key]['prediction'][0]))
print(res[key]['performances'][0])


dict_keys(['ECOGSTN_CEBRA_False_categ_AddMovementLabels_Ephys', 'ECOGSTN_CEBRA_True_categ_AddMovementLabels_Ephys', 'ECOGSTN_CEBRA_False_binary_AddMovementLabels_Ephys', 'ECOGSTN_CEBRA_True_binary_AddMovementLabels_Ephys', 'ECOGSTN_CEBRA_False_scale_AddMovementLabels_Ephys', 'ECOGSTN_CEBRA_True_scale_AddMovementLabels_Ephys', 'ECOG_CEBRA_False_categ_AddMovementLabels_Ephys', 'ECOG_CEBRA_True_categ_AddMovementLabels_Ephys', 'ECOG_CEBRA_False_binary_AddMovementLabels_Ephys', 'ECOG_CEBRA_True_binary_AddMovementLabels_Ephys', 'ECOG_CEBRA_False_scale_AddMovementLabels_Ephys', 'ECOG_CEBRA_True_scale_AddMovementLabels_Ephys', 'STN_CEBRA_False_categ_AddMovementLabels_Ephys', 'STN_CEBRA_True_categ_AddMovementLabels_Ephys', 'STN_CEBRA_False_binary_AddMovementLabels_Ephys', 'STN_CEBRA_True_binary_AddMovementLabels_Ephys', 'STN_CEBRA_False_scale_AddMovementLabels_Ephys', 'STN_CEBRA_True_scale_AddMovementLabels_Ephys'])
dict_keys(['prediction', 'performances', 'cm', 'y_test_pred', 'y_test_true', 'X

In [34]:
# number of per-subject prediction arrays in the selected result set
selected_preds = res[key]['prediction']
len(selected_preds)

13

#### Import Acc and Time data

In [None]:
# # Load arrays for Acc and Sub-codes
# source = 'STN'
# with open(os.path.join(utilsFiles.get_project_path('data'),
#                  'prediction_data',
#                  f"ACC_dataPlus_{source}.pickle"),
#     "rb"
# ) as f:
#     metadat = pickle.load(f)

In [None]:
# metadat

### 1) Visualize True and Predicted Labels vs Activity

In [9]:
def _load_acc_metadata(source):
    """Load the metadata pickle 'ACC_dataPlus_<source>.pickle' (ACC, times, sub-ids)."""
    meta_path = os.path.join(utilsFiles.get_project_path('data'),
                             'prediction_data',
                             f"ACC_dataPlus_{source}.pickle")
    with open(meta_path, "rb") as f:
        return pickle.load(f)


def get_sub_pred_dicts(
    pred_result_pickle,
    OUT_PARAM='scale',
    incl_ECOG=False,
    CEBRA_bool = False,  # True -> CEBRA model, False -> linear model
    key_suffix: str = '',
):
    """
    Combine stored predictions with the per-subject accelerometer metadata.

    Args:
        pred_result_pickle: dict loaded from a prediction-result pickle. The
            entry under the composed key must hold the lists 'prediction',
            'y_test_true' and 'performances' (one item per subject).
        OUT_PARAM: outcome parameter in the key: 'binary', 'scale' or 'categ'.
        incl_ECOG: True -> 'ECOG' results and metadata, False -> 'STN'.
        CEBRA_bool: True -> CEBRA model, False -> linear model.
        key_suffix: appended to the composed key, e.g.
            '_AddMovementLabels_Ephys' for pickles whose keys carry it.
            The default '' keeps the original key format.

    Returns:
        dict per subject id with 'y_true', 'y_pred', 'perf', 'acc' and
        'times'. Predictions are paired by position with the sorted unique
        metadata 'sub_ids'. If there are fewer predictions than subjects,
        only the first subjects in that sorted order are used.

    Raises:
        AssertionError: unknown OUT_PARAM, missing key, inconsistent list
            lengths, more predictions than metadata subjects, or a sample
            count mismatch for a subject.
    """
    assert OUT_PARAM in ['binary', 'scale', 'categ'], 'incorrect outcome parameter'

    source = 'ECOG' if incl_ECOG else 'STN'
    pred_key = f'{source}_CEBRA_{str(CEBRA_bool)}_{OUT_PARAM}{key_suffix}'

    assert pred_key in pred_result_pickle.keys(), (
        f'composed dict-key not in PICKLE: {pred_key}'
    )
    preds = pred_result_pickle[pred_key]

    metadat = _load_acc_metadata(source)
    sub_ids = np.unique(metadat['sub_ids'])  # sorted unique subject ids

    # zip() below would silently drop surplus predictions; check explicitly
    n_pred = len(preds['prediction'])
    assert n_pred == len(preds['y_test_true']) == len(preds['performances']), (
        f'{pred_key}: prediction, y_test_true and performances differ in length'
    )
    assert n_pred <= len(sub_ids), (
        f'{pred_key}: {n_pred} predicted subjects but only '
        f'{len(sub_ids)} subjects in metadata'
    )

    # get dicts for sub plotting
    dat_dict = {}
    for l_pred, l_true, l_perf, s in zip(
        preds['prediction'], preds['y_test_true'], preds['performances'], sub_ids
    ):
        sub_sel = metadat['sub_ids'] == s
        assert len(l_pred) == len(l_true) == np.count_nonzero(sub_sel), (
            f'sub-{s} mismatch y_pred and metadat'
        )
        dat_dict[s] = {
            'y_true': l_true,
            'y_pred': l_pred,
            'perf': l_perf,
            'acc': metadat['ACC_RMS'][sub_sel],
            'times': metadat['ft_times_all'][sub_sel]
        }

    return dat_dict


In [10]:
importlib.reload(plotPreds)

# settings for data import
incl_ECOG = True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'scale'

# name parts for figure names, e.g. 'cebra_inclEcog' or 'linMod_StnOnly'.
# Both model prefixes end with '_' so the names are composed consistently.
if CEBRA_bool:
    modelname = 'cebra_'
    model = 'cebra'
else:
    modelname = 'linMod_'
    model = 'lm'
modelname += 'inclEcog' if incl_ECOG else 'StnOnly'

dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

# per-subject plots of predictions over time (disabled)
# for sub in dat_dict.keys():
#     plotPreds.plot_subPreds_over_time(
#         lid_out_param=OUT_PARAM,
#         sub_dict=dat_dict[sub],
#         SAVE_PLOT=True, SHOW_PLOT=False,
#         fig_name=f'0325_predRows_{modelname}_Lid{OUT_PARAM}_sub{sub}',
#         model=model,
#     )


Plot Group Results

In [None]:
importlib.reload(plotPreds)

# settings for data import
incl_ECOG = True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'

# name parts for figure names, e.g. 'cebra_inclEcog' or 'linMod_StnOnly'.
# Both model prefixes end with '_' so the names are composed consistently.
if CEBRA_bool:
    modelname = 'cebra_'
    model = 'cebra'
else:
    modelname = 'linMod_'
    model = 'lm'
modelname += 'inclEcog' if incl_ECOG else 'StnOnly'

# group plot of predictions over time (disabled)
# dat_dict = get_sub_pred_dicts(
#     pred_result_pickle=res,
#     OUT_PARAM=OUT_PARAM,
#     incl_ECOG=incl_ECOG,
#     CEBRA_bool=CEBRA_bool,
# )
#
# plotPreds.plot_groupPreds_over_time(
#     lid_out_param=OUT_PARAM,
#     dat_dict=dat_dict,
#     SAVE_PLOT=True, SHOW_PLOT=False,
#     fig_name=f'0326_predGroupRows_{modelname}_Lid{OUT_PARAM}',
#     model=model,
# )

### Plot overall prediction performance

In [14]:
# date prefix prepended to every figure name saved below (presumably MMDD)
PLOT_DATE = '0429'

In [None]:
print(dat_dict.keys())
# dat_dict is nested per input source ({'STN_only': {sub: ...}, ...}) after
# the group-performance cells, but flat ({sub: ...}) after the single-model
# cells above; select the per-subject level in both cases. All subject
# entries are built with the same fields, so the first one is representative.
sub_level_dict = dat_dict['STN_only'] if 'STN_only' in dat_dict else dat_dict
print(next(iter(sub_level_dict.values())).keys())

#### Binary Performance

In [22]:
SAVE_PLOT = True
SHOW_PLOT = False
CEBRA = False
fig_name = f'{PLOT_DATE}_binaryPerfBoxes' + ('_cebra' if CEBRA else '_lm')

# binary predictions of the same model type, STN only vs. STN + ECoG input
dat_dict = {
    'STN_only': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='binary',
        incl_ECOG=False,
        CEBRA_bool=CEBRA,
    ),
    'ECOG_incl': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='binary',
        incl_ECOG=True,
        CEBRA_bool=CEBRA,
    )
}

# one color per metric box: bal. accuracy, accuracy, sensitivity, specificity
clrs = ['peru', 'khaki', 'mediumturquoise', 'plum']
metric_names = ['Bal. accuracy', 'Accuracy', 'Sensitivity', 'Specificity']
fsize = 14
BOX_WIDTH = .2
BOX_TICKS = np.arange(len(clrs)) * BOX_WIDTH

xlabel_ticks, xlabels = [], []

fig, ax = plt.subplots(1, 1, figsize=(6, 3))

# chance level, drawn across the full axes width
ax.axhline(.5, ls='dashed', color='gray', lw=1, alpha=.5, zorder=1,)

# loop variable is not called `key`, so the global result key is not overwritten
for i, (src_key, dat) in enumerate(dat_dict.items()):

    bal_acc, acc, sens, spec = [], [], [], []

    for s in dat.keys():
        y_pred = np.asarray(dat[s]['y_pred'])
        y_true = np.asarray(dat[s]['y_true'])
        bal_acc.append(dat[s]['perf'])

        # confusion-matrix counts
        tp = np.sum(np.logical_and(y_pred == 1, y_true == 1))  # true pos
        fp = np.sum(np.logical_and(y_pred == 1, y_true == 0))  # false pos
        tn = np.sum(np.logical_and(y_pred == 0, y_true == 0))  # true neg
        fn = np.sum(np.logical_and(y_pred == 0, y_true == 1))  # false neg

        acc.append((tp + tn) / len(y_pred))
        # sensitivity / specificity are undefined without positive /
        # negative true samples; skip instead of dividing by zero
        if (tp + fn) > 0:
            sens.append(tp / (tp + fn))
        else:
            print(f'\t{src_key}: no sensitivity calculated for sub-{s} (no dyskinesia)')
        if (tn + fp) > 0:
            spec.append(tn / (tn + fp))
        else:
            print(f'\t{src_key}: no specificity calculated for sub-{s} (only dyskinesia)')

    # PLOT BOXES (one group of four boxes per input source)
    boxes = ax.boxplot([bal_acc, acc, sens, spec],
                       positions=BOX_TICKS + i, widths=BOX_WIDTH,
                       patch_artist=True,)
    xlabel_ticks.append(BOX_TICKS[1] + i)
    xlabels.append(src_key)
    for patch, clr in zip(boxes['boxes'], clrs):
        patch.set_facecolor(clr)
        patch.set_alpha(.7)
    for median in boxes['medians']:
        median.set_color('black')

ax.set_xticks(xlabel_ticks)
ax.set_xticklabels(xlabels)
ax.set_xlim(-BOX_WIDTH, xlabel_ticks[-1] + BOX_WIDTH*8)
ax.set_ylim(-.1, 1.1)
ax.set_yticks([0, .5, 1])
ax.set_yticklabels(['0', '0.5', '1'])
ax.set_ylabel('Performance', size=fsize, weight='bold',)
ax.set_xlabel('CEBRA model' if CEBRA else 'Linear model',
              size=fsize, weight='bold',)
ax.spines[['right', 'top']].set_visible(False)
ax.tick_params(axis='both', size=fsize, labelsize=fsize)

# legend: one colored label per metric, in box order
for y_txt, metric_name, clr in zip([.7, .6, .5, .4], metric_names, clrs):
    fig.text(0.7, y_txt, metric_name, color='k',
             weight='bold', size=fsize, va='top',
             bbox={'alpha': .4, 'color': clr, 'lw': 0,},)

fig.tight_layout()

if SAVE_PLOT:
    fig_dir = os.path.join(utilsFiles.get_project_path('figures'),
                           'final_Q1_2024', 'prediction', 'group_v8', 'binary')
    os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(os.path.join(fig_dir, fig_name), facecolor='w', dpi=300,)

if SHOW_PLOT:
    plt.show()
else:
    plt.close(fig)

	STN_only: no sensitivity calculalted for sub-010 (no dyskinesia)
	STN_only: no sensitivity calculalted for sub-014 (no dyskinesia)
	STN_only: no sensitivity calculalted for sub-017 (no dyskinesia)
	STN_only: no sensitivity calculalted for sub-101 (no dyskinesia)
	STN_only: no sensitivity calculalted for sub-109 (no dyskinesia)
	ECOG_incl: no sensitivity calculalted for sub-010 (no dyskinesia)
	ECOG_incl: no sensitivity calculalted for sub-014 (no dyskinesia)
	ECOG_incl: no sensitivity calculalted for sub-017 (no dyskinesia)


#### Scale Performance

- Calculate the performance of the scale model (predicting the full CDRS score).
- Plot the group Spearman rho and mean absolute error (MAE), plus the individual rho and MAE per subject, sorted on each subject's maximal CDRS score.

In [13]:
# fields stored per subject (here: sub-008, ECoG-including model)
ecog_sub_dict = dat_dict['ECOG_incl']['008']
ecog_sub_dict.keys()

dict_keys(['y_true', 'y_pred', 'perf', 'acc', 'times'])

In [24]:
SAVE_PLOT = True
SHOW_PLOT = False
CEBRA = False
fig_name = f'{PLOT_DATE}_scalePerfBoxes' + ('_cebra' if CEBRA else '_lm')

# scale predictions of the same model type, STN only vs. STN + ECoG input
dat_dict = {
    'STN_only': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=False,
        CEBRA_bool=CEBRA,
    ),
    'ECOG_incl': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=True,
        CEBRA_bool=CEBRA,
    )
}

clrs = {0: ['tan', 'steelblue'], 1: ['None', 'None']}
edgeclrs = {0: ['tan', 'steelblue'],
            1: ['tan', 'steelblue']}
metric_labs = ['Mean\nabs. error', 'Correlation\n(rho)']
fsize = 14

# One subject order shared by both sources and both metric rows, sorted on
# the max true CDRS score taken from STN_only (which lists every subject).
# Every subject keeps the same bar position everywhere. A missing value
# (e.g. no ECoG, or rho undefined for constant labels) leaves a gap instead
# of shifting the following bars onto other subjects.
all_subs = list(dat_dict['STN_only'].keys())
max_lid = np.array([np.max(dat_dict['STN_only'][s]['y_true']) for s in all_subs])
sub_order = np.argsort(max_lid, kind='stable')
bar_xpos = 2.5 + .5 * np.arange(len(all_subs))

fig, axes = plt.subplots(2, 1, figsize=(8, 4))

for i, (src_key, dat) in enumerate(dat_dict.items()):

    # per-subject metrics in all_subs order; NaN where not available
    mae_per_sub = np.full(len(all_subs), np.nan)
    rho_per_sub = np.full(len(all_subs), np.nan)

    for i_sub, s in enumerate(all_subs):
        # subjects with codes starting with '1' are excluded for ECOG_incl
        # (presumably no ECoG recorded)
        if (s.startswith('1') and src_key == 'ECOG_incl') or s not in dat:
            continue
        rho, _ = spearmanr(dat[s]['y_pred'], dat[s]['y_true'])
        mae_per_sub[i_sub] = float(dat[s]['perf'])
        rho_per_sub[i_sub] = rho

    # PLOT: group box at x = i, and per-subject bars sorted on max-LID
    for i_ax, (ax, metric) in enumerate(zip(axes, [mae_per_sub, rho_per_sub])):
        group_box = ax.boxplot(metric[~np.isnan(metric)], positions=[i],
                               widths=.5, patch_artist=True,)
        for patch in group_box['boxes']:
            patch.set_facecolor(clrs[0][i])
            patch.set_edgecolor(edgeclrs[0][i])
            patch.set_alpha(.7)
            patch.set_linewidth(0)
        for median in group_box['medians']:
            median.set_color('black')

        ax.bar(bar_xpos + (i * .2), height=metric[sub_order], width=.2,
               facecolor=clrs[0][i], edgecolor=edgeclrs[0][i], alpha=.7,)

        ax.set_ylabel(metric_labs[i_ax], size=fsize, weight='bold',)
        ax.spines[['right', 'top']].set_visible(False)
        ax.tick_params(axis='both', size=fsize, labelsize=fsize)

axes[0].set_xticks([0, 1])
axes[0].set_xticklabels([])
axes[1].set_xticks([0, 1, 6])
axes[1].set_xticklabels(['Group means', ' ', 'subjects sorted on max-dyskinesia-severity'])

axes[0].set_ylim(0, 7)
axes[0].set_yticks([0, 2.5, 5])
axes[0].set_yticklabels(['0', '2.5', '5'])
axes[1].set_ylim(-.1, 1.1)
axes[1].set_yticks([0, .5, 1])
axes[1].set_yticklabels(['0', '0.5', '1'])

# legend: one colored label per input source
for x_txt, (src_key, clr) in zip([0.35, 0.5], zip(dat_dict.keys(), clrs[0])):
    fig.text(x_txt, 0.9, src_key, color='k',
             weight='bold', size=fsize, va='top',
             bbox={'alpha': .4, 'color': clr, 'lw': 0,},)

fig.tight_layout()

if SAVE_PLOT:
    fig_dir = os.path.join(utilsFiles.get_project_path('figures'),
                           'final_Q1_2024', 'prediction', 'group_v8', 'scale')
    os.makedirs(fig_dir, exist_ok=True)
    fig.savefig(os.path.join(fig_dir, fig_name), facecolor='w', dpi=300,)

if SHOW_PLOT:
    plt.show()
else:
    plt.close(fig)



#### Scale results: scatter plots with correlation

In [16]:
from matplotlib.gridspec import GridSpec

Plot single scatters

In [None]:
SAVE_PLOT = True
SHOW_PLOT = False

# CEBRA scale predictions, STN only vs. STN + ECoG input
dat_dict = {
    'STN_only': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=False,
        CEBRA_bool=True,
    ),
    'ECOG_incl': get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM='scale',
        incl_ECOG=True,
        CEBRA_bool=True,
    )
}

fsize = 14
clrs = ['teal', 'darkkhaki']  # scatter color per input source

for s in dat_dict['STN_only'].keys():
    # 6 x 16 grid: row 0 ACC trace, row 1 true LID trace,
    # rows 2-5 one scatter panel per input source (left / right half)
    figw, figh = (16, 6)
    fig = plt.figure(figsize=(8, 4),)
    gs = GridSpec(figh, figw, figure=fig)
    scat_axes = []
    META_PLOTTED = False

    for i, (src_key, dat) in enumerate(dat_dict.items()):
        if s.startswith('1') and src_key == 'ECOG_incl':
            continue

        pred_y = dat[s]['y_pred']
        true_y = dat[s]['y_true']
        rho, pval = spearmanr(pred_y, true_y)

        # jitter for display only; not added in place, so integer label
        # arrays are upcast to float instead of raising a casting error
        xjitt, yjitt = pltHelp.get_plot_jitter(x_temp=true_y, y_temp=pred_y,
                                               jit_width=.2)
        pred_y = pred_y + yjitt
        true_y = true_y + xjitt

        panel_cols = slice(None, figw // 2) if i == 0 else slice(figw // 2, None)
        scat_ax = fig.add_subplot(gs[2:, panel_cols])
        scat_axes.append(scat_ax)

        if all(dat[s]['y_true'] == 0):
            leglab = f'{s}: {src_key}\n(no dyskinesia)'
        else:
            leglab = f'{s}, R: {rho.round(2)}, p: {round(pval, 3)}'

        scat_ax.scatter(true_y, pred_y, alpha=.3, label=leglab, color=clrs[i])
        if i == 0:
            scat_ax.set_ylabel('Predicted CDRS', weight='bold', size=fsize,)

        if not META_PLOTTED:
            ax_acc = fig.add_subplot(gs[0, :])
            ax_acc.plot(dat[s]['times'], dat[s]['acc'])
            # share the x-axis at creation; get_shared_x_axes().join() is
            # deprecated since Matplotlib 3.6 and removed in later releases
            ax_true = fig.add_subplot(gs[1, :], sharex=ax_acc)
            ax_true.plot(dat[s]['times'], dat[s]['y_true'])
            META_PLOTTED = True
            plt.setp(ax_acc.get_xticklabels(), visible=False)
            ax_true.set_ylabel('LID', rotation=0, labelpad=15,
                               weight='bold', size=fsize,)
            ax_acc.set_ylabel('ACC', rotation=0, labelpad=15,
                              weight='bold', size=fsize,)

    for ax in scat_axes:
        ax.set_xlabel('True CDRS', weight='bold', size=fsize,)
        ax.legend(fontsize=fsize-4, frameon=False, loc='lower right')

    fig.subplots_adjust(wspace=0.3, hspace=2)

    if SAVE_PLOT:
        fig_dir = os.path.join(utilsFiles.get_project_path('figures'),
                               'final_Q1_2024', 'prediction', 'cebra', 'scale', 'scatters')
        os.makedirs(fig_dir, exist_ok=True)
        fig.savefig(os.path.join(fig_dir, f'{s}_cebra_scatter'),
                    facecolor='w', dpi=300,)

    if SHOW_PLOT:
        plt.show()
    else:
        plt.close(fig)



## Visualize activity in predictions

In [17]:
import lfpecog_plotting.plot_Spectrals_vs_LID as plotSpectrals
import lfpecog_plotting.plot_activity_in_preds as plotPredAct

In [None]:
# dat_dict is nested per input source after the scatter cells above
# ({'STN_only': {sub: ...}, 'ECOG_incl': {...}}); use the STN-only
# per-subject dict in that case, otherwise the (flat) dict itself
subs_dict = dat_dict['STN_only'] if 'STN_only' in dat_dict else dat_dict
sub_dict = subs_dict['019']
sub_dict.keys()

# print(np.around(sub_dict['y_pred']))

#### Scatter scale-predictions over Activity and True-CDRS

In [27]:
importlib.reload(plotPredAct)

# settings for data import
incl_ECOG = False
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'scale'   # binary predictions make no sense in this scatter

# model tag for the figure name, e.g. 'cebra_StnOnly'
model = ('cebra_' if CEBRA_bool else 'lm_') + ('inclEcog' if incl_ECOG else 'StnOnly')

dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

fig_name = f'{PLOT_DATE}_predErrorScatter_{model}'

plotPredAct.scatter_predErrors(
    dat_dict,
    SAVE_PLOT=True,
    fig_name=fig_name,
    ROUND_PREDS=False,
    modelname=model,
    out_param=OUT_PARAM,
)



...saved 0429_predErrorScatter_cebra_StnOnly in c:\Users\habetsj\Research\projects\dyskinesia_neurophys\figures\final_Q1_2024\prediction\group_v8\scale


#### Plot activities in binary pos and neg predictions

Plots relative histograms of true and false predictions

In [19]:
importlib.reload(plotPredAct)

# settings for data import
incl_ECOG = True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'   # scale predictions are not used for binary plotting

# model tag for the figure name, e.g. 'cebra_inclEcog'
model = ('cebra_' if CEBRA_bool else 'lm_') + ('inclEcog' if incl_ECOG else 'StnOnly')

dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

plotPredAct.plot_binary_act_distr(
    fig_name=f'{PLOT_DATE}_binary_act_pred_{model}',
    dat_dict=dat_dict,
    model=model,
    SAVE_PLOT=True,
    SHOW_PLOT=False,
)


  return n/db/n.sum(), bin_edges


...saved 0429_binary_act_pred_cebra_inclEcog in c:\Users\habetsj\Research\projects\dyskinesia_neurophys\figures\final_Q1_2024\prediction\group_v8\binary


Plot predictive performance over activity bins
- Sensitivity / Specificity
- Accuracies of positive and negative predictions

In [21]:
importlib.reload(plotPredAct)

# settings for data import; incl_ECOG / CEBRA_bool are set by the loop below
OUT_PARAM = 'binary'   # scale not for binary plotting
SensSpec = True     # sensitivity / specificity per activity bin
PredAccur = False   # accuracies of positive / negative predictions

# one figure per combination of input source and model type
for incl_ECOG, CEBRA_bool in product([True, False],
                                     [True, False]):

    # model tag for the figure name, e.g. 'cebra_inclEcog' or 'lm_StnOnly'
    model = ('cebra_' if CEBRA_bool else 'lm_') + ('inclEcog' if incl_ECOG else 'StnOnly')

    dat_dict = get_sub_pred_dicts(
        pred_result_pickle=res,
        OUT_PARAM=OUT_PARAM,
        incl_ECOG=incl_ECOG,
        CEBRA_bool=CEBRA_bool,
    )

    plotPredAct.plot_predValues_per_ActBin(
        dat_dict=dat_dict,
        model=model,
        SensSpec=SensSpec,
        PredAcc=PredAccur,
        SAVE_PLOT=True, SHOW_PLOT=False,
        fig_name=f'{PLOT_DATE}_PredSensSpec_ActBins_{model}',
    )


...saved 0429_PredSensSpec_ActBins_cebra_inclEcog in c:\Users\habetsj\Research\projects\dyskinesia_neurophys\figures\final_Q1_2024\prediction\group_v8\binary\SensSpec
...saved 0429_PredSensSpec_ActBins_lm_inclEcog in c:\Users\habetsj\Research\projects\dyskinesia_neurophys\figures\final_Q1_2024\prediction\group_v8\binary\SensSpec
...saved 0429_PredSensSpec_ActBins_cebra_StnOnly in c:\Users\habetsj\Research\projects\dyskinesia_neurophys\figures\final_Q1_2024\prediction\group_v8\binary\SensSpec
...saved 0429_PredSensSpec_ActBins_lm_StnOnly in c:\Users\habetsj\Research\projects\dyskinesia_neurophys\figures\final_Q1_2024\prediction\group_v8\binary\SensSpec
