# Combining multi-modal data for EMA validation with UPDRS and Ephys

## 0. Import packages

- Document package versions for reproducibility.

In [None]:
# import packages
import pandas as pd
import numpy as np
import os
import sys
import csv
import json
import importlib
from itertools import product, compress
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, spearmanr
from scipy.signal import welch

In [None]:
# Print the versions of the packages this notebook relies on, for reproducibility.
import matplotlib
import scipy

print('Python sys', sys.version)
print('pandas', pd.__version__)
print('numpy', np.__version__)
print('scipy', scipy.__version__)  # spearmanr, welch
print('matplotlib', matplotlib.__version__)  # all plotting
# print('mne_bids', mne_bids.__version__)
# print('mne', mne.__version__)
# print('sci-kit learn', sk.__version__)

# Output recorded in an earlier run (before the scipy/matplotlib lines were added):
"""
Python sys 3.11.5 | packaged by Anaconda, Inc. | (main, Sep 11 2023, 13:26:23) [MSC v.1916 64 bit (AMD64)]
pandas 2.1.1
numpy 1.26.0
"""

In [None]:
from utils import load_utils, load_data, prep_data
# from PerceiveImport.classes import main_class

In [None]:
# FOR DEBUGGING: re-execute the project modules so source edits take effect
# without restarting the kernel (the cell displays the last reloaded module)
importlib.reload(load_data)
importlib.reload(load_utils)

## 1. Import Data

### 1.1 Import EMA and UPDRS

In [None]:
# # SINGLE CONDITION
# CONDITION = 'm0s0'

# ema_df, updrs_df = load_data.get_EMA_UPDRS_data(condition=CONDITION)


In [None]:
importlib.reload(load_data)
importlib.reload(load_utils)

# Load EMA and UPDRS data separately for each of the four conditions;
# both dicts are keyed by condition code.
EMA, UPDRS = {}, {}
for COND in ('m0s0', 'm0s1', 'm1s0', 'm1s1'):
    EMA[COND], UPDRS[COND] = load_data.get_EMA_UPDRS_data(condition=COND)


In [None]:
# Inspect the 'Q11_medState' EMA column for condition m0s0 (notebook display)
EMA['m0s0']['Q11_medState']
# EMA['m0s0']['Q12_medIntake']

# UPDRS['m0s1']

### 1.2 Import LFP data

To-dos:
- double-check that selecting the "rest" task does not exclude data
- include the stimulation-amplitude data rows to double-check s0 vs s1

In [None]:

ids = load_data.get_ids()

# Conditions to skip per EMA id: either a list of condition codes, or the
# string 'all' to skip every condition of that subject.
SKIP_LFPs = {
    'ema03': ['m0s1'],
    'ema07': ['m1s0', 'm1s1'],  # no m1 done: always ['m1s0', 'm1s1']
    'ema09': ['m1s0', 'm1s1'],  # no m1 done: always ['m1s0', 'm1s1']
    'ema10': ['m1s0', 'm1s1'],  # no m1 done: always ['m1s0', 'm1s1']
    'ema12': ['m1s0', 'm1s1'],  # no m1 done: always ['m1s0', 'm1s1']
    'ema14': 'all',  # no m1 done, m0s1 not found in motherfolder
    # 'ema14': ['m1s0', 'm1s1', 'm0s1'],  # ONLY m0s0; EXCLUDE?!
    'ema15': ['m1s0', 'm1s1'],  # no m1 done: always ['m1s0', 'm1s1']
    # NOTE(review): this entry was annotated "ONLY m0s0; EXCLUDE?!", yet with
    # it both m0s0 and m0s1 are loaded -- confirm which m0 sessions are valid
    'ema16': ['m1s0', 'm1s1'],
    # 'ema16': 'all'  # no m1 done: always ['m1s0', 'm1s1']
}

# LFP arrays keyed by '{ema_id}_{COND}', holding the first two data rows
lfp_data = {}

# NOTE(review): `main_class` comes from `PerceiveImport.classes`, whose import
# is commented out in the import cell; restore it or this loop raises NameError.
for ema_id, COND in product(ids.index,
                            ['m0s0', 'm0s1', 'm1s0', 'm1s1']):
    skip = SKIP_LFPs.get(ema_id, [])
    # check for 'all' first: on a str value, `COND in skip` would be a
    # substring test rather than a list-membership test
    if skip == 'all' or COND in skip:
        print(f'\n#### SKIP {ema_id} {COND}, not percept ready ####\n')
        continue

    sub = ids.loc[ema_id]['prc_id']
    ses = ids.loc[ema_id]['prc_ses']

    print(f'\nGET LFP {ema_id}, {sub}, {ses}, {COND}')

    # load the session that corresponds to the current selection
    # TODO: 'rest' is hardcoded; check for issues with tasks like rest&tap
    sub_data = main_class.PerceiveData(
        sub=sub,
        incl_modalities=['streaming'],
        incl_session=[ses],
        incl_condition=[COND,],
        incl_task=["rest"],
        import_json=False,  # set True to additionally load the corresponding JSON source files
        warn_for_metaNaNs=True,  # True warns with metadata-table rows containing NaNs; fill out all columns of the file you want to load
        allow_NaNs_in_metadata=True,
    )

    dat = getattr(sub_data.streaming, ses)
    # only keep the first two data rows (left and right STN signal)
    dat = getattr(dat, COND).rest.run1.data.get_data()[:2, :]
    # TODO: include stimulation-amplitude data streams to double-check s0 vs s1
    lfp_data[f'{ema_id}_{COND}'] = dat

## 2. Preprocess data

#### Get (mean-corrected) EMA and UPDRS values per symptom subtype

In [None]:
# re-execute prep_data so source edits take effect without restarting the kernel
importlib.reload(prep_data)

In [None]:
# Combine the per-condition EMA and UPDRS dicts into one table; its
# EMA_SUM_* / UPDRS_SUM_* columns are used for the correlations in section 3
sumdf = prep_data.get_sum_df(EMA_dict=EMA, UPDRS_dict=UPDRS)

# display the table (last expression of the cell)
sumdf

#### Get Beta powers

In [None]:
# re-execute load_utils so source edits take effect without restarting the kernel
importlib.reload(load_utils)

In [None]:
# FIG_PATH = os.path.join(os.path.dirname(os.getcwd()), 'figures', 'lfp_preprocess')
# Root folder for the figures; the preprocessing plots go into its
# 'lfp_preprocess' subfolder. Print whether it exists as a sanity check.
FIG_PATH = load_utils.get_onedrive_path('emaval_fig')
print(f'CHECK FIG_PATH: {FIG_PATH}, exists? -> {os.path.exists(FIG_PATH)}')

# TODO: finish correction for Rest&Tap timings!!

In [None]:
def plot_single_lfp_preprocess(
    DAT,
    SUB = 'emaXX',
    COND = 'm0s0',
    N_STD_OUTLIER = 3,
    LOWPASS = 2,
    HIGHPASS = 45,
    SFREQ=250,
    SHOWPLOTS=False,
    SAVEPLOTS=True,
):
    """Plot a preprocessing check (time series + PSD) for one LFP recording.

    Creates a 2x2 figure with one column per hemisphere (left, right STN).
    The top row shows the signal before and after outlier removal. The
    bottom row shows the Welch power spectrum of the cleaned, filtered
    signal.

    Parameters
    ----------
    DAT : dict
        Maps ``f'{SUB}_{COND}'`` to the recording. Its first two rows are
        used as left and right STN (presumably channels x samples -- confirm).
    SUB : str
        EMA id, looked up in the global ``ids`` table to get ``prc_id``.
    COND : str
        Condition code, e.g. 'm0s0'.
    N_STD_OUTLIER : float
        Samples with ``|value| > N_STD_OUTLIER * std`` are dropped.
    LOWPASS, HIGHPASS : float
        Passed as ``low`` / ``high`` to ``prep_data.lfp_filter``. Despite
        their names, these are presumably the lower / upper band edges in
        Hz -- confirm.
    SFREQ : int or float
        Sampling rate in Hz. Used to convert rest timings to samples, to
        place the minute ticks, and for the Welch PSD.
    SHOWPLOTS : bool
        Show the figure; otherwise it is closed.
    SAVEPLOTS : bool
        Save as ``FIG_PATH/lfp_preprocess/PSD_check_{SUB}_{COND}``. The
        subfolder is created if it is missing.

    Relies on the notebook globals ``ids``, ``prep_data`` and ``FIG_PATH``.
    """
    lfp_times = prep_data.get_lfp_times()
    prc_id = ids.loc[SUB]['prc_id']

    fig, axes = plt.subplots(2, 2)
    for i, (arr, side) in enumerate(
        zip(DAT[f'{SUB}_{COND}'][:2], ['left', 'right'])
    ):
        arr = arr.copy()  # do not overwrite original dict data

        # crop to the rest task if timings are annotated for this subject
        if prc_id in lfp_times.keys():
            t_start, t_end = lfp_times[prc_id][COND]['rest']
            # convert with SFREQ (not a fixed 250) and cast to int, because
            # slice bounds must be integers
            i_start = int(round(t_start * SFREQ))
            i_end = int(round(t_end * SFREQ))
            arr = arr[i_start:i_end]

        ### plot raw signal
        axes[0, i].plot(arr, color='blue', alpha=.3, label='raw filtered',)

        ### handle outliers: drop samples beyond +/- N_STD_OUTLIER * std
        # NOTE(review): the band is centred on 0, not on the mean, so this
        # assumes an (approximately) zero-mean signal -- confirm
        threshold = N_STD_OUTLIER * np.std(arr)
        sel = np.abs(arr) > threshold
        # arr[sel] = np.nan  # alternative: replace outliers with NaNs
        arr = arr[~sel]  # drop outliers; remaining samples are concatenated

        ### plot cleaned signal
        axes[0, i].plot(arr, color='blue', label='cleaned',)
        axes[0, i].set_title(f'{SUB} {COND} {side} STN', weight='bold')
        axes[0, i].set_ylabel(f'{side}-STN activity (yVolt)')
        # one tick per minute of samples (approximate once outliers are dropped)
        xticks = np.arange(0, len(arr), SFREQ * 60)
        axes[0, i].set_xticks(xticks)
        axes[0, i].set_xticklabels(np.arange(len(xticks)))
        axes[0, i].set_xlabel('Time (minutes)')
        axes[0, i].set_ylim(-50, 50)

        ### plot PSD of the filtered, cleaned signal
        # NOTE(review): lfp_filter does not receive SFREQ -- check that its
        # internal sampling rate matches
        arr = prep_data.lfp_filter(signal=arr, low=LOWPASS, high=HIGHPASS,)
        f, psx = welch(arr, fs=SFREQ,)
        axes[1, i].plot(f, psx)
        axes[1, i].set_ylabel(f'{side}-STN Power (a.u.)')
        axes[1, i].set_xlim(0, 45)
        axes[1, i].set_xlabel('Freq (Hz)')

    fig.tight_layout()

    if SAVEPLOTS:
        save_dir = os.path.join(FIG_PATH, 'lfp_preprocess')
        os.makedirs(save_dir, exist_ok=True)  # savefig does not create folders
        fig.savefig(os.path.join(save_dir, f'PSD_check_{SUB}_{COND}'),
                    facecolor='w', dpi=150,)
    if SHOWPLOTS:
        plt.show()
    else:
        plt.close(fig)

Check missing LFP sessions

Check the mother folder:
- ema16 (sub105): too many runs? UPDRS tasks? 3 rest runs in m0s0, 2 rest runs in m0s1?
- ema14: only m0s0 is available; leave it out (only one state)

In [None]:
# EMA ids with at least one loaded LFP condition (key prefix before '_')
lfp_done = np.unique([key.split('_')[0] for key in lfp_data])
done_lookup = set(lfp_done)

# EMA ids from the id table that have no loaded LFP data at all
lfp_todo = [ema_id for ema_id in ids.index if ema_id not in done_lookup]

print(lfp_todo)



In [None]:
# For every subject without loaded LFP data, construct PerceiveData once per
# condition (the result is discarded), presumably to inspect the loader's
# output for that session.
# NOTE(review): needs `main_class` from PerceiveImport, whose import is
# commented out in the import cell.
for sub in lfp_todo:
    prc_id = ids.loc[sub]["prc_id"]
    prc_ses = ids.loc[sub]["prc_ses"]

    print(f'\n{sub}  -> sub-{prc_id} @ {prc_ses}')
    for COND in ('m0s0', 'm0s1', 'm1s0', 'm1s1'):
        print(f'\t{COND}')
        sub_data = main_class.PerceiveData(
            sub=prc_id,
            incl_modalities=['streaming'],
            incl_session=[prc_ses],
            incl_condition=[COND,],
            incl_task=["rest"],
            import_json=False,  # set True to additionally load the corresponding JSON source files
            warn_for_metaNaNs=True,  # True warns with metadata-table rows containing NaNs; fill out all columns of the file you want to load
            allow_NaNs_in_metadata=True,
        )

#### Select relevant ephys epochs based on task timings

In [None]:
# (Re)load the lookups used below. lfp_times is indexed as
# [prc_id][condition][task] and gives (start, end) times, presumably in
# seconds. ids is the EMA-id table with 'prc_id' / 'prc_ses' columns.
lfp_times = prep_data.get_lfp_times()
ids = load_data.get_ids()


In [None]:
Fs = 250  # sampling rate (Hz)
sub = 'ema01'
con = 'm0s0'
lfp_sub = ids.loc[sub]['prc_id']

# convert rest-task start/end times (presumably seconds) to sample indices;
# round and cast to int because slice bounds must be integers
rest_times = lfp_times[lfp_sub][con]['rest']
rest_samples = [int(round(rest_times[0] * Fs)), int(round(rest_times[1] * Fs))]

# plot data row 0 (left STN) cropped to the rest period
plt.plot(lfp_data[f'{sub}_{con}'][0][rest_samples[0]:rest_samples[1]])

### TODO:
# check whether all seconds of the available data are used correctly
# correct 'rest' tasks if troublesome, e.g. rest&tap
# check s0 and s1 against the stim-amplitude time series
# plot individual PSDs
# calculate beta-power x UPDRS correlations
# decide whether (and if so, how) to include movement parts

#### Plot and save spectral preprocessing

In [None]:
# EMA ids that have at least one loaded LFP recording
lfp_subs = np.unique([key.split('_')[0] for key in lfp_data])

# lfp_subs = ['ema01', 'ema08']

# Save one preprocessing-check figure per available subject/condition pair.
for SUB, COND in product(lfp_subs, ['m0s0', 'm0s1', 'm1s0', 'm1s1']):

    print(f'\n### {SUB}, {COND}')
    data_key = f'{SUB}_{COND}'
    if data_key in lfp_data:
        plot_single_lfp_preprocess(
            DAT=lfp_data, SUB=SUB, COND=COND,
            N_STD_OUTLIER=6,
            SHOWPLOTS=False, SAVEPLOTS=True,
        )
    else:
        print(f'...skip {SUB}, {COND}')


## 3. Analyze Correlations

In [None]:
def scatter_EMA_UPDRS(
    dat_df,
    EMA_subscore = 'brady',
    UPDRS_subscore = 'brady',
):
    """Scatter-plot EMA against UPDRS subscores, pooled over all conditions.

    Values come from the columns ``EMA_SUM_{EMA_subscore}_{COND}`` and
    ``UPDRS_SUM_{UPDRS_subscore}_{COND}`` of ``dat_df`` and are pooled over
    COND in m0s0, m0s1, m1s0, m1s1. The title reports the Spearman
    correlation over the pairs in which BOTH the EMA and the UPDRS value
    are present.

    Parameters
    ----------
    dat_df : table-like
        Presumably the DataFrame from ``prep_data.get_sum_df``. Within each
        condition, the EMA and UPDRS columns are paired element by element.
    EMA_subscore : str
        Subscore name used in the EMA column names.
    UPDRS_subscore : str
        Subscore name used in the UPDRS column names.
    """
    x, y = [], []

    for COND in ['m0s0', 'm0s1', 'm1s0', 'm1s1']:
        x.extend(dat_df[f'EMA_SUM_{EMA_subscore}_{COND}'])
        y.extend(dat_df[f'UPDRS_SUM_{UPDRS_subscore}_{COND}'])

    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    plt.scatter(x, y)
    plt.axhline(y=0, c='gray', alpha=0.3)
    plt.axvline(x=0, c='gray', alpha=0.3)

    # Drop NaNs pairwise. Filtering x and y separately would shift the
    # remaining values and pair EMA/UPDRS scores from different rows.
    valid = ~(np.isnan(x) | np.isnan(y))
    R, p = spearmanr(x[valid], y[valid])

    plt.title(f'Spearman R: {R.round(2)}, p={p.round(5)}')
    plt.xlabel(f'EMA {EMA_subscore}\n(higher is less symptoms)')
    plt.ylabel(f'UPDRS {UPDRS_subscore}\n(lower is less symptoms)')
    plt.show()

In [None]:
# Correlate the bradykinesia subscores of EMA and UPDRS across all conditions.
EMA_subscore = UPDRS_subscore = 'brady'

scatter_EMA_UPDRS(sumdf, EMA_subscore=EMA_subscore, UPDRS_subscore=UPDRS_subscore)