In [None]:
# Importing Python and external packages
import os
import sys
import importlib
import json
import csv
import pickle
from dataclasses import dataclass, field, fields
from itertools import compress
import pandas as pd
import numpy as np
from itertools import product
import sklearn as sk
from scipy import signal, stats

import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns


In [None]:
def get_project_path_in_notebook(
    subfolder: str = '',
):
    """
    Find the path of the project folder from a notebook.

    Walks upwards from the current working directory until it reaches a
    directory whose path ends with 'dyskinesia_neurophys'. Run this once at
    the start of the notebook so that other modules/functions can be found.

    Args:
        subfolder: optional folder name inside the project folder. When it
            is non-empty, the project folder joined with `subfolder` is
            returned instead of the project folder itself.

    Returns:
        Path (str) of the project folder, or of `subfolder` within it.

    Raises:
        FileNotFoundError: if the filesystem root is reached without
            finding the project folder.
    """
    project_name = 'dyskinesia_neurophys'
    path = os.getcwd()

    while not path.endswith(project_name):
        parent = os.path.dirname(path)
        # dirname() of the filesystem root returns the root itself; without
        # this check the search would loop forever outside the project tree
        if parent == path:
            raise FileNotFoundError(
                f'no "{project_name}" folder found above {os.getcwd()}'
            )
        path = parent

    # only join when requested: os.path.join(path, '') would append a
    # trailing separator and change the default return value
    if subfolder:
        path = os.path.join(path, subfolder)

    return path

In [None]:
# Resolve the project root once, then derive the code folder from it.
projectpath = get_project_path_in_notebook()
codepath = os.path.join(
    projectpath,
    'code',
)

In [None]:
# Move into the project's code folder before importing the local packages
# below (utils, lfpecog_plotting).
# NOTE(review): presumably this works because the kernel resolves imports
# relative to the cwd; confirm, since it means this cell must run after
# `codepath` is defined.
os.chdir(codepath)
# own utility functions
import utils.utils_fileManagement as utilsFiles
# own data exploration functions
from lfpecog_plotting.plotHelpers import get_colors
import lfpecog_plotting.plotHelpers as pltHelp

# plotting functions for prediction results (reloaded in later cells)
import lfpecog_plotting.plot_pred_standards as plotPreds

### 0) Define settings

In [None]:

# --- data / feature versions ---
# v4.0: new artefact removal, no re-referencing; v3.0: multiple re-referencing
DATA_VERSION = 'v4.0'
# earlier feature versions: v4 = broad flanks + bursts, v3 = broad-flanked SSD
# NOTE(review): the contents of v8 are not described here; confirm
FT_VERSION = 'v8'
INCL_PSD_FTS = ['mean_psd', 'variation']
IGNORE_PTS = ['011', '104', '106']

# --- CDRS / analysis options ---
CDRS_RATER = 'Patricia'
ANALYSIS_SIDE = 'BILAT'
INCL_CORE_CDRS = True
CATEGORICAL_CDRS = False

In [None]:
# Collect every subject that has features for the chosen data and feature
# versions, skipping the excluded patients.
SUBS = utilsFiles.get_avail_ssd_subs(
    DATA_VERSION=DATA_VERSION,
    FT_VERSION=FT_VERSION,
    IGNORE_PTS=IGNORE_PTS,
)
print(f'SUBS: n={len(SUBS)} ({SUBS})')


#### Import prediction results (Timon)

In [None]:
# folder holding the prediction-result pickles
path = os.path.join(
    utilsFiles.get_project_path('data'), 'prediction_data', 'pred_results',
)
# earlier results file: 'pred_results_0322.pickle'
filename = 'd_out_NEW_FEATURES_REDUCED_WITH_MOVEMENT_offset_10_dim_4.pickle'
pred_pickle_path = os.path.join(path, filename)

assert os.path.exists(pred_pickle_path), 'prediction pickle file not found'

# NOTE(review): pickle.load can execute arbitrary code; only load trusted,
# locally produced result files here
with open(pred_pickle_path, 'rb') as f:
    res = pickle.load(f)

In [None]:
print(res.keys())

# one example result entry used for the overview below
overview_key = 'ECOG_CEBRA_False_binary'
print(res[overview_key].keys())

# Overview of content. Only a cell's last expression is displayed, so the
# intermediate values are printed explicitly instead of being discarded.
print(f"n prediction arrays (one per sub): {len(res[overview_key]['prediction'])}")
print('performances (one value per sub):', res[overview_key]['performances'])

res[overview_key]['y_test_pred'][0]

#### Import Acc and Time data

In [None]:
# # Load arrays for Acc and Sub-codes
# source = 'STN'
# with open(os.path.join(utilsFiles.get_project_path('data'),
#                  'prediction_data',
#                  f"ACC_dataPlus_{source}.pickle"),
#     "rb"
# ) as f:
#     metadat = pickle.load(f)

In [None]:
# metadat

### 1) Visualize True and Predicted Labels vs Activity

In [None]:
def get_sub_pred_dicts(
    pred_result_pickle,
    OUT_PARAM='scale',
    incl_ECOG=False,
    CEBRA_bool=False,  # True -> CEBRA model, False -> linear model
):
    """
    Combine per-subject prediction results with accelerometer and time data.

    Args:
        pred_result_pickle: dict loaded from the prediction-results pickle.
            It must contain the key '{source}_CEBRA_{CEBRA_bool}_{OUT_PARAM}',
            whose value holds per-subject sequences under 'prediction' and
            'y_test_true'.
        OUT_PARAM: outcome parameter, 'binary' or 'scale'.
        incl_ECOG: True selects the 'ECOG' results and metadata file,
            False selects the 'STN' ones.
        CEBRA_bool: True -> CEBRA model, False -> linear model.

    Returns:
        dict keyed by subject id. Each value is a dict with 'y_true',
        'y_pred', 'acc' (metadata 'ACC_RMS') and 'times' (metadata
        'ft_times_all'), restricted to that subject's samples.

    Raises:
        AssertionError: on an invalid OUT_PARAM, a missing result key, a
            different number of subjects in results vs. metadata, or a
            per-subject sample-count mismatch.
    """
    assert OUT_PARAM in ['binary', 'scale'], 'incorrect outcome parameter'

    source = 'ECOG' if incl_ECOG else 'STN'
    pred_key = f'{source}_CEBRA_{str(CEBRA_bool)}_{OUT_PARAM}'

    assert pred_key in pred_result_pickle.keys(), 'composed dict-key not in PICKLE'

    # Load arrays for Acc and Sub-codes
    # NOTE(review): pickle.load can execute arbitrary code; trusted files only
    meta_path = os.path.join(
        utilsFiles.get_project_path('data'),
        'prediction_data',
        f"ACC_dataPlus_{source}.pickle",
    )
    with open(meta_path, "rb") as f:
        metadat = pickle.load(f)

    # asarray so a plain list of ids also yields a boolean mask below
    sub_ids = np.asarray(metadat['sub_ids'])
    # np.unique returns sorted ids; the per-subject result lists are assumed
    # to be stored in the same sorted order (TODO confirm with the pipeline)
    unique_subs = np.unique(sub_ids)

    preds = pred_result_pickle[pred_key]['prediction']
    trues = pred_result_pickle[pred_key]['y_test_true']
    # zip() below would silently drop (and misalign) subjects if these differ
    assert len(preds) == len(trues) == len(unique_subs), (
        f'number of subjects differs: {len(preds)} prediction arrays, '
        f'{len(trues)} true-label arrays, {len(unique_subs)} subjects in metadata'
    )

    # get dicts for sub plotting
    dat_dict = {}

    for l_pred, l_true, s in zip(preds, trues, unique_subs):
        sub_sel = sub_ids == s
        assert len(l_pred) == len(l_true) == np.count_nonzero(sub_sel), (
            f'sub-{s} mismatch y_pred and metadat'
        )

        dat_dict[s] = {
            'y_true': l_true,
            'y_pred': l_pred,
            'acc': metadat['ACC_RMS'][sub_sel],
            'times': metadat['ft_times_all'][sub_sel]
        }

    return dat_dict


In [None]:
importlib.reload(plotPreds)

# settings for data import
incl_ECOG = False
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'scale'

# Model tag used in figure names. Both branches end in '_' so the source
# suffix is separated consistently ('cebra_StnOnly', 'linMod_StnOnly').
if CEBRA_bool:
    modelname = 'cebra_'
    model = 'cebra'
else:
    modelname = 'linMod_'
    model = 'lm'

if incl_ECOG:
    modelname += 'inclEcog'
else:
    modelname += 'StnOnly'


dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

# One saved (not shown) figure per subject. To plot a single subject,
# restrict the loop, e.g. to sub = '008'.
for sub, sub_preds in dat_dict.items():

    plotPreds.plot_subPreds_over_time(
        lid_out_param=OUT_PARAM,
        sub_dict=sub_preds,
        SAVE_PLOT=True, SHOW_PLOT=False,
        fig_name=f'0325_predRows_{modelname}_Lid{OUT_PARAM}_sub{sub}',
        model=model,
    )


#### Plot group results

In [None]:
importlib.reload(plotPreds)

# settings for data import
incl_ECOG = False
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'scale'

# Model tag used in figure names. Both branches end in '_' so the source
# suffix is separated consistently ('cebra_StnOnly', 'linMod_StnOnly').
if CEBRA_bool:
    modelname = 'cebra_'
    model = 'cebra'
else:
    modelname = 'linMod_'
    model = 'lm'

if incl_ECOG:
    modelname += 'inclEcog'
else:
    modelname += 'StnOnly'


dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)


# group-level figure (currently disabled):
# plotPreds.plot_groupPreds_over_time(
#     lid_out_param=OUT_PARAM,
#     dat_dict=dat_dict,
#     SAVE_PLOT=True, SHOW_PLOT=False,
#     fig_name=f'0326_predGroupRows_{modelname}_Lid{OUT_PARAM}',
#     model=model,
# )

#### Further exploration

In [None]:
# Explore predicted arrays.
# `pred_key` only exists inside get_sub_pred_dicts, so rebuild it here from
# the current settings (incl_ECOG, CEBRA_bool, OUT_PARAM set in cells above);
# otherwise this cell fails with NameError on a fresh kernel.
pred_key = f"{'ECOG' if incl_ECOG else 'STN'}_CEBRA_{CEBRA_bool}_{OUT_PARAM}"

y_pred_list = res[pred_key]['prediction']  # gives arrays per sub
y_true_list = res[pred_key]['y_test_true']

#### Explore activity in predictions

In [None]:
import lfpecog_plotting.plot_Spectrals_vs_LID as plotSpectrals

In [None]:
# inspect one example subject's prediction dict
sub_dict = dat_dict['019']

# print(np.around(sub_dict['y_pred']))

# last expression of the cell, so the available keys are displayed
sub_dict.keys()

#### Scatter scale-predictions over Activity and True-CDRS

In [None]:
importlib.reload(plotPreds)

# settings for data import
incl_ECOG = True
CEBRA_bool = False  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'scale'   # a binary outcome is not meaningful for a scatter

# model tag for figure names, e.g. 'lm_inclEcog' or 'cebra_StnOnly'
model = ('cebra_' if CEBRA_bool else 'lm_') + ('inclEcog' if incl_ECOG else 'StnOnly')

dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

fig_name = f'0326_predErrorScatter_{model}'

# scatter the (unrounded) prediction errors and save the figure
plotPreds.scatter_predErrors(
    dat_dict,
    SAVE_PLOT=True,
    fig_name=fig_name,
    ROUND_PREDS=False,
    modelname=model,
    out_param=OUT_PARAM,
)



In [None]:
import lfpecog_plotting.plot_activity_in_preds as plotPredAct

#### Plot activity in positive and negative binary predictions

In [None]:
importlib.reload(plotPredAct)

# settings for data import
incl_ECOG = True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'   # this plot needs the binary outcome, not the scale

# model tag for figure names, e.g. 'cebra_inclEcog' or 'lm_StnOnly'
model = ('cebra_' if CEBRA_bool else 'lm_') + ('inclEcog' if incl_ECOG else 'StnOnly')

dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

# activity distributions for positive vs negative binary predictions
plotPredAct.plot_binary_act_distr(
    dat_dict=dat_dict,
    model=model,
    SAVE_PLOT=True,
    SHOW_PLOT=False,
    fig_name=f'0326_binary_act_pred_{model}',
)


#### Plot sensitivity and specificity over activity bins

In [None]:
importlib.reload(plotPredAct)

# settings for data import
incl_ECOG = True
CEBRA_bool = True  # True -> CEBRA model, False -> linear model
OUT_PARAM = 'binary'   # this plot needs the binary outcome, not the scale

# model tag for figure names, e.g. 'cebra_inclEcog' or 'lm_StnOnly'
model = ('cebra_' if CEBRA_bool else 'lm_') + ('inclEcog' if incl_ECOG else 'StnOnly')

dat_dict = get_sub_pred_dicts(
    pred_result_pickle=res,
    OUT_PARAM=OUT_PARAM,
    incl_ECOG=incl_ECOG,
    CEBRA_bool=CEBRA_bool,
)

# prediction values (sensitivity / specificity) per activity bin
plotPredAct.plot_predValues_per_ActBin(
    dat_dict=dat_dict,
    model=model,
    SAVE_PLOT=True,
    SHOW_PLOT=False,
    fig_name=f'0326_sensSpec_ActBins_{model}',
)
