In [None]:
import mne
from mne.io import read_raw
import numpy as np
import matplotlib.pyplot as plt
#import matplotlib.colors as mcolors
from os.path import join
import os
import pandas as pd
#from scipy.signal import hilbert
import json
from collections import defaultdict

from functions import ephy_plotting, preprocessing, analysis, io, utils

In [None]:
#from functions.utils import _get_onedrive_path, _update_and_save_multiple_params, extract_stats, create_1_Hz_wide_bands
#from functions.ephy_plotting import plot_raw_stim, plot_psd_log, plot_stft_stim, plot_av_freq_power_by_trial, plot_tfr_success_vs_unsuccess, plot_power_comparison_between_conditions, plot_amplitude_and_difference_from_json
#from functions.preprocessing import create_epochs, create_epochs_subsets_from_behav, create_epochs_specific_freq_band
#from functions.analysis import compute_psd_welch
#from functions.io import save_epochs, load_behav_data

# 1. Load the dataset #

In [None]:
# Project root = parent of the current working directory (the notebook is presumably run
# from a sub-folder of the project — confirm)
working_path = os.path.dirname(os.getcwd())
results_path = join(working_path, "results")
behav_results_saving_path = join(results_path, "behav_results")

# Sessions retained after the behavioral screening; only entries starting with "sub" are kept
included_excluded_file = join(behav_results_saving_path, 'final_included_subjects.json')
with open(included_excluded_file, 'r') as file:
    included_subjects = [session for session in json.load(file) if session.startswith('sub')]

print(f'Included_subjects: {included_subjects}')
onedrive_path = utils._get_onedrive_path()

# Group-level and ON-vs-OFF output folders (created if missing).
# NOTE(review): the folders are named 'morlet_low_freq', but when the notebook runs top to
# bottom the `tfr_args` handed to the plotting helpers come from the multitaper parameter
# cell — confirm the folder naming.
saving_path_group = join(results_path, 'group_level', 'lfp_perc_sig_change', 'morlet_low_freq')
saving_path_on_off = join(results_path, 'ON_vs_OFF', 'morlet_low_freq')
for output_folder in (saving_path_group, saving_path_on_off):
    os.makedirs(output_folder, exist_ok=True)

In [None]:
# Per-session containers, keyed by session ID (e.g. 'sub006 DBS ON mSST')
sub_dict_epochs_subsets = {}  # epochs of every trial-type subset
sub_dict_lm_GO = {}  # epochs of the lm_GO trials
sub_dict_RT = {}  # mean reaction time per trial type
sub_dict_stats = {}  # behavioral statistics

# Cluster-test results; looking up a missing key creates (and returns) an empty dict.
# 'All subjects' references the included_subjects list object itself (not a copy).
cluster_results_dict = defaultdict(dict, {'All subjects': included_subjects})

In [None]:
# Load the behavioral data of every included session and compute per-session statistics
data = io.load_behav_data(included_subjects, onedrive_path)
stats = utils.extract_stats(data)

# Save the statistics (overwrites any existing stats.json)
filename = "stats.json"
file_path = os.path.join(results_path, filename)
with open(file_path, "w", encoding="utf-8") as file:
    json.dump(stats, file, indent=4)

# Exclude sub023 from the electrophysiology analyses. stats.json above is written before
# the exclusion. Slice assignment mutates the existing list, so
# cluster_results_dict['All subjects'] (which references this same list) stays in sync,
# and re-running the cell no longer raises ValueError (list.remove did).
excluded_sessions = {'sub023 DBS ON mSST', 'sub023 DBS OFF mSST'}
included_subjects[:] = [session for session in included_subjects if session not in excluded_sessions]
included_subjects


# 2. Create full-session plots for each subject (raw signal, PSD, STFT) and epoch the data #

In [None]:
from scipy.interpolate import interp1d

In [None]:
# Store the Welch PSD of every session, keyed as psd_dict[subject][condition]
psd_dict = defaultdict(dict)  # each missing subject key gets an empty dictionary

# Channel names applied to every session so that they are consistent across subjects.
# NOTE(review): the renaming is positional (zip with raw.ch_names), so it assumes every
# synchronized file stores its channels in this order; zip silently ignores extra channels.
new_channel_names = [
    "Left_STN",
    "Right_STN",
    "left_peak_STN",
    "right_peak_STN",
    "STIM_Left_STN",
    "STIM_Right_STN"
]


def _list_files(folder, prefix="", suffix=""):
    """Return the sorted names of the entries of `folder` starting with `prefix` and ending with `suffix`."""
    return sorted(f for f in os.listdir(folder) if f.startswith(prefix) and f.endswith(suffix))


def _fill_nans(signal):
    """Return a copy of a 1-D signal with its NaN samples linearly interpolated.

    NaNs before the first or after the last valid sample take the value of the nearest
    valid sample (np.interp holds the edges constant) instead of being linearly
    extrapolated, which over a long NaN stretch would create large artificial ramps.

    Returns None when the signal has fewer than 2 valid samples.
    """
    filled = np.array(signal, copy=True)
    nans = np.isnan(filled)
    if not nans.any():
        return filled
    valid = ~nans
    if valid.sum() < 2:
        return None
    filled[nans] = np.interp(np.flatnonzero(nans), np.flatnonzero(valid), filled[valid])
    return filled


# Start a loop through sessions
for session_ID in included_subjects:
    print(f"Now processing {session_ID}")
    session_dict = {}
    sub = session_ID[:6]
    session_parts = session_ID.split(' ')
    subject_ID = session_parts[0]
    condition = session_parts[1] + ' ' + session_parts[2]  # e.g. 'DBS ON'
    sub_onedrive_path = join(onedrive_path, subject_ID)
    sub_onedrive_path_task = join(onedrive_path, subject_ID, 'synced_data', session_ID)

    # --- Load the synchronized intracranial recording (.set) ---
    set_files = _list_files(sub_onedrive_path_task, prefix='SYNCHRONIZED_INTRACRANIAL', suffix='.set')
    if not set_files:
        raise FileNotFoundError(f"No .set file found in {sub_onedrive_path_task}")
    file = join(sub_onedrive_path_task, set_files[0])
    if not os.path.isfile(file):
        raise FileNotFoundError(f"File does not exist: {file}")
    print(f"Loading file: {file}")
    raw = read_raw(file, preload=True)

    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST', 'lfp_perc_sig_change', 'morlet_low_freq')
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    ephy_plotting.plot_raw_stim(session_ID, raw, saving_path_single)

    # --- PSD: interpolate NaNs on a copy (resampling cannot handle them), then resample to 200 Hz ---
    raw_copy = raw.copy()
    raw_array = raw_copy.get_data()  # shape: (n_channels, n_times)
    for ch in range(raw_array.shape[0]):
        filled = _fill_nans(raw_array[ch])
        if filled is None:
            print(f"Channel {ch} has too few valid points for interpolation.")
            continue
        raw_array[ch] = filled
    raw_copy._data = raw_array  # private attribute; acceptable because the data were preloaded

    raw_resampled = raw_copy.copy().resample(sfreq=200)
    psd_left, freqs_left, psd_right, freqs_right = analysis.compute_psd_welch(raw_resampled)
    session_dict['psd_left'] = psd_left
    session_dict['freqs_left'] = freqs_left
    session_dict['psd_right'] = psd_right
    session_dict['freqs_right'] = freqs_right

    # Save the PSD data in the dictionary
    psd_dict[sub][condition] = {
        'psd_left': psd_left,
        'freqs_left': freqs_left,
        'psd_right': psd_right,
        'freqs_right': freqs_right}

    ephy_plotting.plot_psd_log(
        session_ID, raw, freqs_left, psd_left,
        freqs_right, psd_right, saving_path_single, is_filt=False
        )
    ephy_plotting.plot_stft_stim(
        session_ID, raw, saving_path=saving_path_single, is_filt=False,
        vmin=-3, vmax=1, fmin=0, fmax=100
        )

    # --- Rename channels consistently across subjects (positional mapping) ---
    session_dict['CHANNELS'] = raw.ch_names
    raw.rename_channels(dict(zip(raw.ch_names, new_channel_names)))
    session_dict['RENAMED_CHANNELS'] = raw.ch_names

    # Band-pass 1-95 Hz at the original sampling rate: downsampled data must not be
    # epoched because it introduces jitter in the event timing.
    filtered_data = raw.copy().filter(l_freq=1, h_freq=95)
    # Keep only the LFP channels (the first two channels)
    filtered_data_lfp = filtered_data.copy().pick_channels([filtered_data.ch_names[0], filtered_data.ch_names[1]])

    epochs, filtered_event_dict = preprocessing.create_epochs(filtered_data_lfp, session_ID)

    # --- Behavioral log of this session ---
    mSST_raw_behav_session_data_path = join(
            onedrive_path, subject_ID, "raw_data", 'BEHAVIOR', condition, 'mSST'
            )
    csv_files = _list_files(mSST_raw_behav_session_data_path, suffix=".csv")
    if not csv_files:
        # Fail loudly instead of reusing the file name found for a previous session
        raise FileNotFoundError(f"No .csv file found in {mSST_raw_behav_session_data_path}")
    if len(csv_files) > 1:
        print(f"Warning: {len(csv_files)} .csv files in {mSST_raw_behav_session_data_path}, using {csv_files[-1]}")
    fname = csv_files[-1]
    filepath_behav = join(mSST_raw_behav_session_data_path, fname)
    df = pd.read_csv(filepath_behav)

    # First row of the main task = first row where 'blocks.thisRepN' is not NaN
    start_task_index = df['blocks.thisRepN'].first_valid_index()
    if start_task_index is None:
        # Otherwise iloc[None:-1] would silently keep the training rows as well
        raise ValueError(f"No main-task rows ('blocks.thisRepN' is all NaN) in {filepath_behav}")
    # Keep the main task only (rows before start_task_index belong to the training); the
    # last row is also dropped — presumably a non-trial end-of-experiment row, TODO confirm
    df_maintask = df.iloc[start_task_index:-1]

    # Remove the trials with early presses, as in these trials the cues were not presented
    early_presses = df_maintask[df_maintask['early_press_resp.corr'] == 1]
    early_presses_trials = list(early_presses.index)
    number_early_presses = len(early_presses_trials)
    df_maintask_copy = df_maintask.drop(early_presses_trials)

    # Split the epochs into successful / unsuccessful trial-type subsets from the behavior
    (epochs_subsets, epochs_lm, mean_RT_dict) = preprocessing.create_epochs_subsets_from_behav(
            df_maintask_copy,
            epochs,
            filtered_event_dict
            )

    sub_dict_epochs_subsets[session_ID] = epochs_subsets
    sub_dict_lm_GO[session_ID] = epochs_lm
    sub_dict_RT[session_ID] = mean_RT_dict
    sub_dict_stats[session_ID] = stats[session_ID]



In [None]:
# Per-subject PSD: DBS ON (blue) vs DBS OFF (orange), left and right STN side by side
for sub, sub_psds in psd_dict.items():
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    for condition, psd in sub_psds.items():
        color = 'blue' if condition == 'DBS ON' else 'orange'
        ax[0].plot(psd['freqs_left'], np.log(psd['psd_left']), color=color, label=condition)
        ax[1].plot(psd['freqs_right'], np.log(psd['psd_right']), color=color, label=condition)

    # Titles and labels are set once per figure. Both conditions share the figure, so the
    # title names the subject and the legend identifies the conditions.
    fig.suptitle(f'Power Spectral Density - {sub}')
    ax[0].set_title('Left STN')
    ax[1].set_title('Right STN')
    ax[0].set_xlabel('Frequency (Hz)')
    ax[1].set_xlabel('Frequency (Hz)')
    # np.log is the natural logarithm, so the values are not in dB (that would be 10 * log10)
    ax[0].set_ylabel('Log power (natural log of PSD)')
    ax[1].legend(loc='upper right')
    fig.savefig(join(results_path, f'{sub}_psd.png'))
    plt.show()
    plt.close(fig)  # free the figure when looping over many subjects

In [None]:
# Average PSD across subjects, DBS ON vs DBS OFF, with a paired t-test at every frequency
from scipy.stats import ttest_rel
from statsmodels.stats.multitest import fdrcorrection

# Only subjects with both sessions can enter a *paired* comparison. Indexing
# psd_dict[sub][condition] for a missing session would raise KeyError, and the ON/OFF
# arrays must be aligned subject by subject for ttest_rel.
psd_conditions = ['DBS ON', 'DBS OFF']
subs_with_both = [sub for sub in psd_dict if all(c in psd_dict[sub] for c in psd_conditions)]
subs_missing_one = [sub for sub in psd_dict if sub not in subs_with_both]
if subs_missing_one:
    print(f"Excluded from the ON vs OFF average (only one session): {subs_missing_one}")
if not subs_with_both:
    raise ValueError("No subject has both a DBS ON and a DBS OFF PSD.")

# Averaging requires a shared frequency grid across sessions (all sessions are resampled
# to 200 Hz before the Welch PSD)
freqs_lefts_on = [psd_dict[sub]['DBS ON']['freqs_left'] for sub in subs_with_both]
freqs_rights_on = [psd_dict[sub]['DBS ON']['freqs_right'] for sub in subs_with_both]
freqs_lefts_off = [psd_dict[sub]['DBS OFF']['freqs_left'] for sub in subs_with_both]
freqs_rights_off = [psd_dict[sub]['DBS OFF']['freqs_right'] for sub in subs_with_both]
freqs_avg_left = freqs_lefts_on[0]
freqs_avg_right = freqs_rights_on[0]
assert all(np.allclose(f, freqs_avg_left) for f in freqs_lefts_on + freqs_lefts_off), \
    "Left STN frequency grids differ across sessions"
assert all(np.allclose(f, freqs_avg_right) for f in freqs_rights_on + freqs_rights_off), \
    "Right STN frequency grids differ across sessions"

# One row per subject, same subject order for ON and OFF (paired)
psd_lefts_on = np.array([psd_dict[sub]['DBS ON']['psd_left'] for sub in subs_with_both])
psd_rights_on = np.array([psd_dict[sub]['DBS ON']['psd_right'] for sub in subs_with_both])
psd_lefts_off = np.array([psd_dict[sub]['DBS OFF']['psd_left'] for sub in subs_with_both])
psd_rights_off = np.array([psd_dict[sub]['DBS OFF']['psd_right'] for sub in subs_with_both])

subject_colors = {'DBS ON': 'lightblue', 'DBS OFF': 'peachpuff'}
average_colors = {'DBS ON': 'blue', 'DBS OFF': 'orange'}
psds_by_condition = {
    'DBS ON': (psd_lefts_on, psd_rights_on),
    'DBS OFF': (psd_lefts_off, psd_rights_off),
}

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
for condition, (psds_left, psds_right) in psds_by_condition.items():
    for axis, psds, side_freqs in ((ax[0], psds_left, freqs_avg_left), (ax[1], psds_right, freqs_avg_right)):
        axis.plot(side_freqs, np.log(psds).T, color=subject_colors[condition])  # one line per subject
        axis.plot(side_freqs, np.log(np.nanmean(psds, axis=0)), color=average_colors[condition],
                  zorder=5, label=condition)

# Paired t-test ON vs OFF per frequency, FDR-corrected across frequencies (Benjamini-Hochberg)
t_vals_left, p_vals_left = ttest_rel(psd_lefts_on, psd_lefts_off, axis=0)
reject_left, pvals_corrected_left = fdrcorrection(p_vals_left, alpha=0.05)
t_vals_right, p_vals_right = ttest_rel(psd_rights_on, psd_rights_off, axis=0)
reject_right, pvals_corrected_right = fdrcorrection(p_vals_right, alpha=0.05)

sig_mask_left = reject_left
print(f"Significant differences in left STN: {np.sum(sig_mask_left)} frequencies")
sig_mask_right = reject_right
print(f"Significant differences in right STN: {np.sum(sig_mask_right)} frequencies")

# Shade significant frequencies over the full height of each panel: x is in data
# coordinates, y in axes coordinates (0 = bottom, 1 = top), independent of each panel's ylim
for axis, side_freqs, sig_mask in ((ax[0], freqs_avg_left, sig_mask_left), (ax[1], freqs_avg_right, sig_mask_right)):
    axis.fill_between(side_freqs, 0, 1, where=sig_mask, transform=axis.get_xaxis_transform(),
                      color='grey', alpha=0.3, label='Significant Difference')

fig.suptitle(f'Average Power Spectral Density - n = {len(subs_with_both)} subjects')
ax[0].set_title('Left STN')
ax[1].set_title('Right STN')
ax[0].set_xlabel('Frequency (Hz)')
ax[1].set_xlabel('Frequency (Hz)')
# np.log is the natural logarithm, so the values are not in dB
ax[0].set_ylabel('Log power (natural log of PSD)')
ax[1].legend(loc='upper right')
plt.show()

# 3. Work with epochs #

In [None]:
#  sub_dict_epochs_subsets['sub006 DBS ON mSST']['GO_successful']

## 3.1. Set the parameters for the plots ## 

In [None]:
######################
### TFR PARAMETERS ###
######################
# NOTE(review): `freqs`, `n_cycles` and `tfr_args` are redefined by the multitaper
# parameter cell that follows, so when the notebook runs top to bottom these Morlet
# settings are not the ones handed to the plotting helpers — confirm the intended method.

decim = 1  # keep every sample (no decimation)
freqs = np.arange(1, 35, 1)  # 1, 2, ..., 34 Hz
n_cycles = np.clip(freqs / 2.0, 3, 20)  # freq / 2 cycles, bounded to [3, 20]
tfr_args = {
    "method": "morlet",
    "freqs": freqs,
    "n_cycles": n_cycles,
    "decim": decim,
    "return_itc": False,
    "average": False,
}

tmin_tmax = [-500, 1500]  # plotted time window — presumably ms around the event, TODO confirm
vmin_vmax = [-70, 70]  # colour-scale limits of the power-change plots

# Unique subject IDs (first 6 characters of each session ID), in order of first appearance
sub_nums = list(dict.fromkeys(session[:6] for session in included_subjects))

In [None]:
#############################
### MULTITAPER PARAMETERS ###
#############################

decim = 1  # no decimation; set here too so this cell does not depend on the previous one
fmin, fmax = 2, 50
n_freqs = 50  # only used by the log-spaced alternative below
# Alternative log-spaced grid: freqs = np.logspace(np.log10(fmin), np.log10(fmax), n_freqs)
freqs = np.arange(fmin, fmax, 1)  # 1 Hz steps: 2, 3, ..., 49 Hz (fmax excluded)
n_cycles = freqs  # window length n_cycles / freq = 1 s at every frequency
# In MNE the number of DPSS tapers is floor(time_bandwidth - 1), so 2.0 gives a single
# taper (frequency smoothing of time_bandwidth / window length = 2 Hz here)
time_bandwidth = 2.0

# Define baseline
baseline = (-0.5, 0)  # seconds

tfr_args = dict(
    method="multitaper",
    freqs=freqs,
    n_cycles=n_cycles,
    # Must be passed explicitly: otherwise MNE's default (4.0, i.e. 3 tapers) is used
    time_bandwidth=time_bandwidth,
    decim=decim,
    return_itc=False,
    average=False
)

## 3.2. Plot power changes at the single-subject level ##

### 3.2.1. Plot the mean signal per channel and trial type ###

In [None]:
#  Plot the power change of every trial type, for each single subject and each DBS condition
# Trial types plotted (the GO/GF/GC unsuccessful subsets are left out)
single_sub_conds = [
    'GO_successful',
    'GF_successful',
    'GC_successful',
    'GS_successful',
    'GS_unsuccessful',
    'STOP_successful',
    'STOP_unsuccessful',
    'CONTINUE_successful'
]

for sub in sub_nums:
    print(f"Now processing sub: {sub}")
    # Restrict each session-keyed dictionary to the sessions of this subject
    single_sub_dict_subsets, single_sub_dict_lm_GO, single_sub_RT_dict, single_sub_stats_dict = (
        {key: value for key, value in source.items() if sub in key}
        for source in (sub_dict_epochs_subsets, sub_dict_lm_GO, sub_dict_RT, sub_dict_stats)
    )
    print(single_sub_dict_subsets.keys())
    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST', 'lfp_perc_sig_change')
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    for dbs_status in ['DBS OFF', 'DBS ON']:
        # Skip DBS conditions in which this subject has no session
        if not any(dbs_status in key for key in single_sub_dict_subsets):
            continue
        shared_args = dict(
            RT_dict=single_sub_RT_dict,
            stats_dict=single_sub_stats_dict,
            dbs_status=dbs_status,
            tfr_args=tfr_args,
            t_min_max=tmin_tmax,
            vmin_vmax=vmin_vmax,
            baseline_correction=True,
            saving_path=saving_path_single,
            show_fig=True
        )
        for cond in single_sub_conds:
            ephy_plotting.tfr_pow_change_cond(sub_dict=single_sub_dict_subsets, epoch_cond=cond, **shared_args)
        ephy_plotting.tfr_pow_change_cond(sub_dict=single_sub_dict_lm_GO, epoch_cond="lm_GO", **shared_args)
            

### 3.2.2. Plot the power change contrasting different trial types ###

In [None]:
#  Plot the power change for different contrasts, for each single subject and each DBS condition
#  Contrasts of interest are:
#       - GS_successful - lm_GO  --> reactive inhibition contrast
#       - GO_successful - GF_successful  --> proactive inhibition contrast
#       - GS_successful - GS_unsuccessful  --> stopping contrast
#       - STOP_successful - STOP_unsuccessful --> stopping contrast but aligned to the STOP cue
#       - STOP_successful - CONTINUE_successful --> attentional contrast aligned to the unexpected stimuli

# (epoch_cond1, epoch_cond2) pairs whose epochs both come from the trial-type subsets
single_sub_contrasts = [
    ("GO_successful", "GF_successful"),
    ("GS_successful", "GS_unsuccessful"),
    ("STOP_successful", "STOP_unsuccessful"),
    ("STOP_successful", "CONTINUE_successful"),
]

for sub in sub_nums:
    print(f"Now processing sub: {sub}")
    # Restrict each session-keyed dictionary to the sessions of this subject
    single_sub_dict_subsets, single_sub_dict_lm_GO, single_sub_RT_dict, single_sub_stats_dict = (
        {key: value for key, value in source.items() if sub in key}
        for source in (sub_dict_epochs_subsets, sub_dict_lm_GO, sub_dict_RT, sub_dict_stats)
    )
    print(single_sub_dict_subsets.keys())
    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST', 'lfp_perc_sig_change')
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    for dbs_status in ['DBS OFF', 'DBS ON']:
        # Skip DBS conditions in which this subject has no session
        if not any(dbs_status in key for key in single_sub_dict_subsets):
            continue
        shared_args = dict(
            RT_dict=single_sub_RT_dict,
            stats_dict=single_sub_stats_dict,
            dbs_status=dbs_status,
            tfr_args=tfr_args,
            t_min_max=tmin_tmax,
            vmin_vmax=vmin_vmax,
            saving_path=saving_path_single,
            show_fig=True
        )

        # Reactive inhibition: the lm_GO epochs live in their own dictionary, hence the _cond2 helper
        condition = f"{dbs_status} GS_successful - lm_GO {sub}"
        cluster_results_dict = ephy_plotting.perc_pow_diff_cond2(
            sub_dict=single_sub_dict_subsets,
            sub_dict_lm_GO=single_sub_dict_lm_GO,
            epoch_cond1="GS_successful",
            epoch_cond2="lm_GO",
            cluster_results_dict=cluster_results_dict,
            condition=condition,
            **shared_args
        )

        for cond1, cond2 in single_sub_contrasts:
            condition = f"{dbs_status} {cond1} - {cond2} {sub}"
            cluster_results_dict = ephy_plotting.perc_pow_diff_cond(
                sub_dict=single_sub_dict_subsets,
                epoch_cond1=cond1,
                epoch_cond2=cond2,
                cluster_results_dict=cluster_results_dict,
                condition=condition,
                ADD_RT=True,
                **shared_args
            )

## 3.3. Plot power changes at the group level ##

### 3.3.1. Plot average power changes for each trial type ###

In [None]:
# Group-level power change of every trial type, for each DBS condition
group_conds = [
    'GO_successful',
    'GF_successful',
    'GC_successful',
    'GS_successful',
    'GS_unsuccessful',
    'STOP_successful',
    'STOP_unsuccessful',
    'CONTINUE_successful'
]  # the GO/GF/GC unsuccessful subsets are left out

for dbs_status in ['DBS OFF', 'DBS ON']:
    shared_args = dict(
        RT_dict=sub_dict_RT,
        stats_dict=sub_dict_stats,
        dbs_status=dbs_status,
        tfr_args=tfr_args,
        t_min_max=tmin_tmax,
        vmin_vmax=vmin_vmax,
        baseline_correction=True,
        saving_path=saving_path_group,
        show_fig=True,
        ADD_RT=True
    )
    for cond in group_conds:
        print(f"Now processing: {dbs_status} - {cond} ")
        ephy_plotting.tfr_pow_change_cond(sub_dict=sub_dict_epochs_subsets, epoch_cond=cond, **shared_args)

    print(f"Now processing: {dbs_status} - lm_GO ")
    ephy_plotting.tfr_pow_change_cond(sub_dict=sub_dict_lm_GO, epoch_cond="lm_GO", **shared_args)

### 3.3.2. Plot average power changes across contrasts of interest ###

In [None]:
# Group-level contrasts: (label, epoch_cond1, epoch_cond2). The label, prefixed with the
# DBS status, is the `condition` key handed to perc_pow_diff_cond.
group_contrasts = [
    ("GO - GF", "GO_successful", "GF_successful"),
    ("GS successful - GS unsuccessful", "GS_successful", "GS_unsuccessful"),
    ("STOP successful - STOP unsuccessful", "STOP_successful", "STOP_unsuccessful"),
    ("STOP unsuccessful - STOP successful", "STOP_unsuccessful", "STOP_successful"),
    ("STOP successful - CONTINUE successful", "STOP_successful", "CONTINUE_successful"),
    ("GS successful - GC successful", "GS_successful", "GC_successful"),
    ("STOP unsuccessful - CONTINUE successful", "STOP_unsuccessful", "CONTINUE_successful"),
    ("GS unsuccessful - GC successful", "GS_unsuccessful", "GC_successful"),
]

for dbs_status in ['DBS OFF', 'DBS ON']:
    shared_args = dict(
        RT_dict=sub_dict_RT,
        stats_dict=sub_dict_stats,
        dbs_status=dbs_status,
        tfr_args=tfr_args,
        t_min_max=tmin_tmax,
        vmin_vmax=vmin_vmax,
        saving_path=saving_path_group,
        show_fig=True
    )

    # Reactive inhibition: the lm_GO epochs live in their own dictionary, hence the _cond2 helper
    condition = f"{dbs_status} GS_successful - lm_GO"
    print(f"Now processing: {condition}")
    cluster_results_dict = ephy_plotting.perc_pow_diff_cond2(
        sub_dict=sub_dict_epochs_subsets,
        sub_dict_lm_GO=sub_dict_lm_GO,
        epoch_cond1="GS_successful",
        epoch_cond2="lm_GO",
        cluster_results_dict=cluster_results_dict,
        condition=condition,
        **shared_args
    )

    for label, cond1, cond2 in group_contrasts:
        condition = f"{dbs_status} {label}"
        print(f"Now processing: {condition}")
        cluster_results_dict = ephy_plotting.perc_pow_diff_cond(
            sub_dict=sub_dict_epochs_subsets,
            epoch_cond1=cond1,
            epoch_cond2=cond2,
            cluster_results_dict=cluster_results_dict,
            condition=condition,
            ADD_RT=True,
            **shared_args
        )
     


## 3.4. Comparison between DBS ON and DBS OFF ##
Only subjects recorded in both conditions (DBS ON and DBS OFF) are kept here.

### 3.4.1. First, select the data of patients with both sessions and exclude all other patients ###

In [None]:
import re


def _select_subject_entries(source_dict, sub):
    """Return the entries of `source_dict` whose key refers to subject `sub`.

    A plain substring test (`sub in key`) also matches longer subject IDs that
    merely start with `sub` (e.g. "sub-1" inside "sub-10 DBS ON mSST"), which
    would silently mix another subject's data into this one. The match
    therefore also requires that `sub` is not followed by another letter or
    digit.

    Parameters
    ----------
    source_dict : dict
        Dictionary keyed by strings that contain the subject ID.
    sub : str
        Subject identifier to look for.

    Returns
    -------
    dict
        New dict with only the matching key/value pairs (values are not copied).
    """
    subject_pattern = re.compile(re.escape(sub) + r"(?![0-9A-Za-z])")
    return {key: value for key, value in source_dict.items() if subject_pattern.search(key)}


# Subjects that are included in BOTH the DBS ON and the DBS OFF session.
sub_both_cond = [
    sub for sub in sub_nums
    if (sub + " DBS ON mSST" in included_subjects) and (sub + " DBS OFF mSST" in included_subjects)
]

# For each such subject, keep only the entries belonging to that subject.
sub_dict_subsets_ON_OFF = {sub: _select_subject_entries(sub_dict_epochs_subsets, sub) for sub in sub_both_cond}
sub_dict_lm_GO_ON_OFF = {sub: _select_subject_entries(sub_dict_lm_GO, sub) for sub in sub_both_cond}
sub_dict_RT_ON_OFF = {sub: _select_subject_entries(sub_dict_RT, sub) for sub in sub_both_cond}
sub_dict_stats_ON_OFF = {sub: _select_subject_entries(sub_dict_stats, sub) for sub in sub_both_cond}

### 3.4.2. Plot the DBS ON - DBS OFF contrast for each trial type ###

In [None]:
# Trial types compared between DBS ON and DBS OFF. The unsuccessful GO, GF and
# GC trial types are not part of this comparison.
trial_types_on_off = (
    'GO_successful',
    'GF_successful',
    'GC_successful',
    'GS_successful',
    'GS_unsuccessful',
    'STOP_successful',
    'STOP_unsuccessful',
    'CONTINUE_successful',
)

for cond in trial_types_on_off:
    print(f"Now processing: {cond}")
    condition = f"ON vs OFF - {cond}"
    cluster_results_dict = ephy_plotting.perc_pow_diff_on_off(
        sub_dict_ON_OFF=sub_dict_subsets_ON_OFF,
        RT_dict_ON_OFF=sub_dict_RT_ON_OFF,
        tfr_args=tfr_args,
        cond=cond,
        t_min_max=tmin_tmax,
        vmin_vmax=vmin_vmax,
        cluster_results_dict=cluster_results_dict,
        condition=condition,
        saving_path=saving_path_on_off,
        show_fig=True,
    )

# lm_GO epochs are stored in their own dictionary, so they are handled
# separately from the loop above.
lm_go_label = "lm_GO"
print(f"Now processing: {lm_go_label}")
condition = f"ON vs OFF - {lm_go_label}"
cluster_results_dict = ephy_plotting.perc_pow_diff_on_off(
    sub_dict_ON_OFF=sub_dict_lm_GO_ON_OFF,
    RT_dict_ON_OFF=sub_dict_RT_ON_OFF,
    tfr_args=tfr_args,
    cond=lm_go_label,
    t_min_max=tmin_tmax,
    vmin_vmax=vmin_vmax,
    cluster_results_dict=cluster_results_dict,
    condition=condition,
    saving_path=saving_path_on_off,
    show_fig=True,
)

### 3.4.3. Plot the DBS ON - DBS OFF difference for each contrast of interest ###

In [None]:
# Contrasts of interest, processed in this order. Each entry is
# (cond1, cond2, dictionary holding the cond2 epochs); cond1 epochs always come
# from `sub_dict_subsets_ON_OFF`. lm_GO lives in its own dictionary.
on_off_contrasts = [
    ("GS_successful", "lm_GO", sub_dict_lm_GO_ON_OFF),
    ("GS_successful", "GS_unsuccessful", sub_dict_subsets_ON_OFF),
    ("GO_successful", "GF_successful", sub_dict_subsets_ON_OFF),
    ("STOP_successful", "STOP_unsuccessful", sub_dict_subsets_ON_OFF),
    ("STOP_unsuccessful", "STOP_successful", sub_dict_subsets_ON_OFF),
    ("GS_unsuccessful", "GC_successful", sub_dict_subsets_ON_OFF),
    ("STOP_unsuccessful", "CONTINUE_successful", sub_dict_subsets_ON_OFF),
    ("GS_successful", "GC_successful", sub_dict_subsets_ON_OFF),
    ("STOP_successful", "CONTINUE_successful", sub_dict_subsets_ON_OFF),
]

for cond1, cond2, cond2_source in on_off_contrasts:
    condition = f"{cond1} - {cond2}"
    print(f"Now processing contrast: {condition}")
    cluster_results_dict = ephy_plotting.perc_pow_diff_on_off_contrast(
        sub_dict_ON_OFF=sub_dict_subsets_ON_OFF,
        sub_dict_ON_OFF_cond2=cond2_source,
        RT_dict_ON_OFF=sub_dict_RT_ON_OFF,
        tfr_args=tfr_args,
        cond1=cond1,
        cond2=cond2,
        t_min_max=tmin_tmax,
        vmin_vmax=vmin_vmax,
        cluster_results_dict=cluster_results_dict,
        condition=condition,
        saving_path=saving_path_on_off,
        show_fig=True,
    )