In [None]:
import mne
import numpy as np
import os
import json

from functions import ephy_plotting, utils

# 1. Load datasets

In [None]:
# --- Project paths ------------------------------------------------------------
# The project root is taken to be the parent of the current working directory.
working_path = os.path.dirname(os.getcwd())
results_path = os.path.join(working_path, "results")
behav_results_saving_path = os.path.join(results_path, "behav_results")

# --- Sessions to analyse ------------------------------------------------------
# The JSON file lists the sessions retained after the behavioural analysis.
included_excluded_file = os.path.join(behav_results_saving_path, 'final_included_subjects.json')
with open(included_excluded_file, 'r') as file:
    included_subjects = json.load(file)

# Only entries whose name begins with "sub" are actual subject sessions.
included_subjects = list(filter(lambda subj: subj.startswith('sub'), included_subjects))
print(f'Included_subjects: {included_subjects}')
onedrive_path = utils._get_onedrive_path()

# --- Output folders -----------------------------------------------------------
freq_wide = 'all frequencies'
fmax = 70  # upper (exclusive) bound of the TFR frequency grid
saving_path_group = os.path.join(results_path, 'group_level', 'eeg_perc_sig_change', freq_wide)
saving_path_on_off = os.path.join(results_path, 'ON_vs_OFF', 'EEG', freq_wide)
for output_dir in (saving_path_group, saving_path_on_off):
    os.makedirs(output_dir, exist_ok=True)  # no-op when the folder already exists

# --- Epoch input --------------------------------------------------------------
epochs_path = os.path.join(results_path, 'eeg_epochs')

sub_dict_epochs = {}  # session ID -> mne Epochs, filled by the loading cell


In [None]:
# Sessions left out of the EEG analysis.
# TODO(review): the reason for excluding each of these sessions is not recorded here.
# Filtering with a comprehension, rather than calling list.remove, keeps this cell idempotent:
# re-running it no longer raises ValueError after the sessions have been dropped.
excluded_sessions = {'sub006 DBS OFF mSST', 'sub006 DBS ON mSST', 'sub028 DBS OFF mSST'}
included_subjects = [session for session in included_subjects if session not in excluded_sessions]


In [None]:
# Read the cleaned "long" EEG epochs for every included session and index them by session ID.
for session_ID in included_subjects:
    fif_file = os.path.join(epochs_path, session_ID + "_EEG_cleaned-long-epo.fif")
    epochs = mne.read_epochs(fif_file, preload=True)
    sub_dict_epochs[session_ID] = epochs

# 2. Work with epochs #

## 2.1. Set the parameters for TFR plots ##

In [None]:
######################
### TFR PARAMETERS ###
######################

# Morlet TFR settings shared by every single-subject and group-level plot below.
decim = 1  # keep every time sample of the TFR output
# 1 Hz steps from 1 Hz up to fmax - 1 Hz (np.arange excludes the stop value).
freqs = np.arange(1, fmax, 1)
# Number of cycles = freqs / 2, clipped to [2, 20]:
#   f <= 4 Hz   -> 2 cycles
#   4 - 40 Hz   -> f / 2 cycles
#   f >= 40 Hz  -> 20 cycles
# n_cycles cycles at f Hz last n_cycles / f seconds (2 s at 1 Hz, ~0.29 s at 69 Hz).
# The epochs therefore need to be long enough for the lowest-frequency wavelets.
n_cycles = np.minimum(np.maximum(freqs / 2.0, 2), 20)

# Presumably forwarded to MNE's Morlet TFR by ephy_plotting.
# average=False keeps per-trial power; return_itc=False skips inter-trial coherence.
tfr_args = dict(
    method="morlet",
    freqs=freqs,
    n_cycles=n_cycles,
    decim=decim,
    return_itc=False,
    average=False
)

tmin_tmax = [-500, 1500]  # plotting window; presumably in ms (consumed by ephy_plotting)
vmin_vmax = [-5, 5]       # colour-scale limits of the power-change maps

# Unique subject IDs (first 6 characters of each session ID, e.g. 'sub006'), in first-seen order.
sub_nums = list(dict.fromkeys(session_ID[:6] for session_ID in included_subjects))

## 2.2. Plot power changes at the single-subject level at electrode C3 (motor cortex) ##

### 2.2.1. Plot the power change at C3 per trial type ###

In [None]:
%matplotlib inline
#  Plot power change for each single subject and each condition/trial type
for sub in sub_nums:
    print(f"Now processing sub: {sub}")
    single_sub_dict_subsets = {key: value for key, value in sub_dict_epochs.items() if sub in key}
    print(single_sub_dict_subsets.keys())
    saving_path_single = os.path.join(results_path, 'single_sub', f'{sub} mSST','eeg_perc_sig_change', freq_wide)
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    ## single condition            
    for dbs_status in ['DBS OFF', 'DBS ON']:
        if any(dbs_status in key for key in single_sub_dict_subsets.keys()):
            for cond in [
                'GO_successful', 
                # #'GO_unsuccessful', 
                'GF_successful', 
                # #'GF_unsuccessful',
                'GC_successful', 
                # #'GC_unsuccessful',
                'GS_successful', 
                'GS_unsuccessful',
                'stop_successful',
                'stop_unsuccessful',
                'continue_successful'
                ]:
                ephy_plotting.eeg_tfr_pow_change_cond(
                    sub_dict = single_sub_dict_subsets, 
                    dbs_status = dbs_status, 
                    epoch_cond = cond,
                    ch_of_interest = 'C3',
                    tfr_args = tfr_args, 
                    t_min_max = tmin_tmax, 
                    vmin_vmax = vmin_vmax,
                    baseline_correction=True,
                    saving_path=saving_path_single,
                    show_fig=True,
                    save_as = 'png'
                    )
            

### 2.2.2. Plot the power-change contrasts between different trial types ###

In [None]:
#  Plot power-change contrasts (epoch_cond1 - epoch_cond2) for each single subject and DBS state.
#  Contrasts of interest are:
#       - GS_successful - GO_successful  --> reactive inhibition contrast
#       - GO_successful - GF_successful  --> proactive inhibition contrast
#       - GS_successful - GS_unsuccessful  --> stopping contrast
#       - STOP_successful - STOP_unsuccessful --> stopping contrast but aligned to the STOP cue
#       - STOP_successful - CONTINUE_successful --> attentional contrast aligned to the unexpected stimuli
single_sub_contrasts = [
    ("GS_successful", "GO_successful"),
    ("GO_successful", "GF_successful"),
    ("GS_successful", "GS_unsuccessful"),
    ("stop_successful", "stop_unsuccessful"),
    ("stop_successful", "continue_successful"),
]

for sub in sub_nums:
    print(f"Now processing sub: {sub}")
    single_sub_dict_subsets = {key: value for key, value in sub_dict_epochs.items() if sub in key}
    print(single_sub_dict_subsets.keys())
    saving_path_single = os.path.join(results_path, 'single_sub', f'{sub} mSST', 'eeg_perc_sig_change', freq_wide)
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    for dbs_status in ['DBS OFF', 'DBS ON']:
        # Skip DBS states for which this subject has no session.
        if not any(dbs_status in key for key in single_sub_dict_subsets):
            continue
        for epoch_cond1, epoch_cond2 in single_sub_contrasts:
            # Label format, e.g. "DBS OFF GS_successful - GO_successful sub006"
            condition = f"{dbs_status} {epoch_cond1} - {epoch_cond2} {sub}"
            ephy_plotting.eeg_perc_pow_diff_cond(
                sub_dict=single_sub_dict_subsets,
                dbs_status=dbs_status,
                tfr_args=tfr_args,
                t_min_max=tmin_tmax,
                vmin_vmax=vmin_vmax,
                epoch_cond1=epoch_cond1,
                epoch_cond2=epoch_cond2,
                condition=condition,
                ch_of_interest='C3',
                saving_path=saving_path_single,
                show_fig=True,
                add_rt=True,
            )

## 2.3. Plot power changes at the group level ##

### 2.3.1. Plot average power changes for each trial type ###

In [None]:
# Reload the project plotting module so that edits to functions.ephy_plotting take effect
# without restarting the kernel. Only modules actually imported by this notebook are reloaded.
import importlib

importlib.reload(ephy_plotting)

In [None]:
# Group level: one baseline-corrected power-change figure at C3 per DBS state and trial type.
vmin_vmax = [-5, 5]  # colour-scale limits; also rebinds the module-level setting
group_conditions = (
    'GO_successful',
    'GF_successful',
    'GC_successful',
    'GS_successful',
    'GS_unsuccessful',
    'stop_successful',
    'stop_unsuccessful',
    'continue_successful',
)

for dbs_status in ('DBS OFF', 'DBS ON'):
    for cond in group_conditions:
        print(f"Now processing: {dbs_status} - {cond} ")
        ephy_plotting.eeg_tfr_pow_change_cond(
            sub_dict=sub_dict_epochs,
            dbs_status=dbs_status,
            epoch_cond=cond,
            ch_of_interest='C3',
            tfr_args=tfr_args,
            t_min_max=tmin_tmax,
            vmin_vmax=vmin_vmax,
            baseline_correction=True,
            saving_path=saving_path_group,
            show_fig=True,
            add_rt=True,
        )


### 2.3.2. Plot average power changes across contrasts of interest ###

In [None]:
# Group-level power-change contrasts (epoch_cond1 - epoch_cond2) at C3, for each DBS state.
group_contrasts = [
    ("GS_successful", "GO_successful"),          # reactive inhibition
    ("GO_successful", "GF_successful"),          # proactive inhibition
    ("GS_successful", "GS_unsuccessful"),        # stopping
    ("stop_successful", "stop_unsuccessful"),    # stopping, aligned to the STOP cue
    ("stop_successful", "continue_successful"),  # attentional, aligned to the unexpected stimulus
    ("GS_successful", "GC_successful"),
    ("stop_unsuccessful", "continue_successful"),
    ("GS_unsuccessful", "GC_successful"),
]

for dbs_status in ['DBS OFF', 'DBS ON']:
    for epoch_cond1, epoch_cond2 in group_contrasts:
        # Labels replace underscores with spaces, e.g. "DBS OFF GS successful - GO successful".
        condition = f"{dbs_status} {epoch_cond1.replace('_', ' ')} - {epoch_cond2.replace('_', ' ')}"
        print(f"Now processing: {condition}")
        ephy_plotting.eeg_perc_pow_diff_cond(
            sub_dict=sub_dict_epochs,
            dbs_status=dbs_status,
            tfr_args=tfr_args,
            t_min_max=tmin_tmax,
            vmin_vmax=vmin_vmax,
            epoch_cond1=epoch_cond1,
            epoch_cond2=epoch_cond2,
            condition=condition,
            ch_of_interest='C3',
            saving_path=saving_path_group,
            show_fig=True,
            add_rt=True,
        )
     

# 3. ERP analysis #

In [None]:
# Colour per trial type, passed to ephy_plotting.erp_change_diff_cond for the ERP figures.
condition_color_dict = {
    "GS_successful": "#3ACDDE",
    "GS_unsuccessful": "#FD548C",
    "GO_successful": "#78CD6A",
    "GF_successful": "#F89E4F",
    "GC_successful": "#160096",
}

In [None]:
%matplotlib qt
for dbs_status in [
    'DBS OFF', 
    'DBS ON'
    ]:
    condition = f"{dbs_status} GO successful - GS successful"
    print(f"Now processing: {condition}")
    ephy_plotting.erp_change_diff_cond(
            sub_dict = sub_dict_epochs,
            dbs_status= dbs_status, 
            epoch_cond1= 'GO_successful',
            epoch_cond2= 'GS_successful',
            condition= condition,
            condition_color_dict= condition_color_dict
    )

    condition = f"{dbs_status} GF successful - GO successful"
    print(f"Now processing: {condition}")
    ephy_plotting.erp_change_diff_cond(
            sub_dict = sub_dict_epochs,
            dbs_status= dbs_status, 
            epoch_cond1= 'GF_successful',
            epoch_cond2= 'GO_successful',
            condition= condition,
            condition_color_dict= condition_color_dict
    )

    condition = f"{dbs_status} GS successful - GS unsuccessful"
    print(f"Now processing: {condition}")
    ephy_plotting.erp_change_diff_cond(
            sub_dict = sub_dict_epochs,
            dbs_status= dbs_status, 
            epoch_cond1= 'GS_successful',
            epoch_cond2= 'GS_unsuccessful',
            condition= condition,
            condition_color_dict= condition_color_dict
    )

    condition = f"{dbs_status} GS successful - GC successful"
    print(f"Now processing: {condition}")
    ephy_plotting.erp_change_diff_cond(
            sub_dict = sub_dict_epochs,
            dbs_status= dbs_status, 
            epoch_cond1= 'GS_successful',
            epoch_cond2= 'GC_successful',
            condition= condition,
            condition_color_dict= condition_color_dict
    )