In [None]:
import mne
# from mne.io import read_raw
import numpy as np
# import matplotlib.pyplot as plt
from os.path import join
import os
# import pandas as pd
import json
# from collections import defaultdict

from functions import ephy_plotting, utils

In [None]:
#from functions.utils import _get_onedrive_path, _update_and_save_multiple_params, extract_stats, create_1_Hz_wide_bands
#from functions.ephy_plotting import plot_raw_stim, plot_psd_log, plot_stft_stim, plot_av_freq_power_by_trial, plot_tfr_success_vs_unsuccess, plot_power_comparison_between_conditions, plot_amplitude_and_difference_from_json
#from functions.preprocessing import create_epochs, create_epochs_subsets_from_behav, create_epochs_specific_freq_band
#from functions.analysis import compute_psd_welch
#from functions.io import save_epochs, load_behav_data

# 1. Load the dataset #

In [None]:
# Project root is the parent of the notebook's working directory.
working_path = os.path.dirname(os.getcwd())
results_path = join(working_path, "results")
behav_results_saving_path = join(results_path, "behav_results")

# Subjects retained after the behavioral analysis are listed in this JSON file.
included_excluded_file = join(behav_results_saving_path, 'final_included_subjects.json')
with open(included_excluded_file, 'r') as json_handle:
    included_subjects = json.load(json_handle)

# Keep only the entries whose name starts with "sub".
included_subjects = list(filter(lambda name: name.startswith('sub'), included_subjects))
print(f'Included_subjects: {included_subjects}')
onedrive_path = utils._get_onedrive_path()

# Output folders (created if they do not exist yet).
freq_wide = 'all frequencies'
fmax = 70
saving_path_group = join(results_path, 'group_level', 'lfp_perc_sig_change', freq_wide)
saving_path_on_off = join(results_path, 'ON_vs_OFF', freq_wide)
for out_dir in (saving_path_group, saving_path_on_off):
    os.makedirs(out_dir, exist_ok=True)

# Input folder holding the cleaned epoch files.
epochs_path = join(results_path, 'lfp_epochs')

sub_dict_epochs = {}  # session name -> epochs, filled by the next cell

In [None]:
# Load the cleaned long epochs of every included session, keyed by session name.
for session_name in included_subjects:
    epochs_file = join(epochs_path, f"{session_name}_cleaned-long-epo.fif")
    sub_dict_epochs[session_name] = mne.read_epochs(epochs_file, preload=True)

# 3. Work with epochs #

Each epoch is baseline corrected separately (single-trial baseline correction) using dB normalization. 

## Why? ##
Because when epochs are aligned on the STOP or CONTINUE signals, the baseline period is not always the same (the stop-signal delay, SSD, varies from trial to trial). Therefore, each epoch has to be corrected separately.

## How? ## 
dB normalization is used. Previously, "percent change from baseline" was used, but Hu et al., 2014 showed that this method is only appropriate when baseline correction is applied to averaged epochs, and that it should be avoided when the correction is applied at the single-trial level. They recommend the subtraction method instead. However, that method is also unsuitable when we then want to perform a group analysis, because it depends strongly on the baseline power level, which can differ greatly from one patient to another.

## cf ChatGPT: Why subtraction can be problematic across subjects: ##
Subtraction (power − baseline) preserves the original units (e.g., (µV)²). Different subjects often have very different absolute power levels (different baselines, scalp impedance, electrode coupling, sensor placement, intrinsic amplitude). If you average these absolute differences across subjects, subjects with larger absolute power will dominate the group mean.
Relative normalizations (percent change, dB, or z-score) produce unitless values that place subjects on the same scale and are much safer for group averages and group statistics.

Recommended pipelines (practical):
    Recommendation A — preferred for most TF analyses:
        1. Compute baseline for each trial (baseline mean and — optionally — baseline std if you want z-scores).
        2. Normalize per trial to a unitless measure:
            Percent change: pc = (power - baseline) / baseline * 100
            dB: db = 10 * np.log10(power / baseline) (or 20*log10 if using amplitudes, but for power use 10)
            Z-score: z = (power - baseline_mean) / baseline_std
        3. Average normalized trials within each subject → subject-level TFR.
        4. Average subject-level TFRs across subjects (or run subject-level stats).

    Why: relative or z-scored measures control for between-subject scaling and variance.

    Recommendation B — when subtraction may be acceptable:
        If you have strong reasons to keep absolute changes (e.g., comparing absolute energy in identical recording setups and sensors, and you’ve confirmed baselines are comparable), you can:
            1. Subtract baseline at the single-trial level,
            2. Average trials per subject,
            3. Then average subjects.

    But always check subject baseline distributions (plot baseline means) and consider adding a normalization if they spread widely.



## 3.1. Set the parameters for the plots ## 

In [None]:
######################
### TFR PARAMETERS ###
######################
# Morlet time-frequency settings shared by all plotting cells below.
# NOTE: np.arange excludes its stop value, so the highest analysed frequency is fmax - 1 Hz.

decim = 1
freqs = np.arange(1, fmax, 1)
# Number of cycles = freqs / 2, clipped to [2, 20]:
# 2 cycles up to 4 Hz, then growing linearly, capped at 20 cycles from 40 Hz upwards.
n_cycles = np.minimum(np.maximum(freqs / 2.0, 2), 20)

tfr_args = dict(
    method="morlet",
    freqs=freqs,
    n_cycles=n_cycles,
    decim=decim,
    return_itc=False,
    average=False,  # keep single trials so each epoch can be baseline-corrected separately
)

# Plot time window (presumably ms relative to the alignment event — TODO confirm) and colour limits.
tmin_tmax = [-500, 1500]
vmin_vmax = [-5, 5]

# Unique subject IDs (first 6 characters of each session name), in first-seen order.
sub_nums = list(dict.fromkeys(session[:6] for session in included_subjects))

## 3.2. Plot power changes at the single-subject level ##

### 3.2.1. Plot the mean signal per channel and trial type ###

In [None]:
%matplotlib inline
#  Plot power change for each single subject and each condition/trial type
for sub in sub_nums:
    print(f"Now processing sub: {sub}")
    single_sub_dict_subsets = {key: value for key, value in sub_dict_epochs.items() if sub in key}
    print(single_sub_dict_subsets.keys())
    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST','lfp_perc_sig_change', freq_wide)
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    ## single condition            
    for dbs_status in ['DBS OFF', 'DBS ON']:
        if any(dbs_status in key for key in single_sub_dict_subsets.keys()):
            for cond in [
                'GO_successful', 
                # #'GO_unsuccessful', 
                'GF_successful', 
                # #'GF_unsuccessful',
                'GC_successful', 
                # #'GC_unsuccessful',
                'GS_successful', 
                'GS_unsuccessful',
                'stop_successful',
                'stop_unsuccessful',
                'continue_successful'
                ]:
                ephy_plotting.tfr_pow_change_cond(
                    sub_dict = single_sub_dict_subsets, 
                    dbs_status = dbs_status, 
                    epoch_cond = cond,
                    tfr_args = tfr_args, 
                    t_min_max = tmin_tmax, 
                    vmin_vmax = vmin_vmax,
                    baseline_correction=True,
                    saving_path=saving_path_single,
                    show_fig=True,
                    save_as = 'png'
                    )
            

### 3.2.2 Plot the power change contrasting different trial types ###

In [None]:
#  Plot power change for different contrasts, for each single subject and each DBS condition.
#  Contrasts of interest (epoch_cond1 - epoch_cond2):
#       - GS_successful - GO_successful  --> reactive inhibition contrast
#       - GO_successful - GF_successful  --> proactive inhibition contrast
#       - GS_successful - GS_unsuccessful  --> stopping contrast
#       - stop_successful - stop_unsuccessful --> stopping contrast but aligned to the STOP cue
#       - stop_successful - continue_successful --> attentional contrast aligned to the unexpected stimuli
single_sub_contrasts = [
    ("GS_successful", "GO_successful"),
    ("GO_successful", "GF_successful"),
    ("GS_successful", "GS_unsuccessful"),
    ("stop_successful", "stop_unsuccessful"),
    ("stop_successful", "continue_successful"),
]

for sub in sub_nums:
    print(f"Now processing sub: {sub}")
    # Sessions of this subject (matched by subject ID contained in the session name)
    single_sub_dict_subsets = {key: value for key, value in sub_dict_epochs.items() if sub in key}
    print(single_sub_dict_subsets.keys())
    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST', 'lfp_perc_sig_change', freq_wide)
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    for dbs_status in ['DBS OFF', 'DBS ON']:
        # Skip DBS conditions this subject has no session for
        if not any(dbs_status in key for key in single_sub_dict_subsets.keys()):
            continue
        for epoch_cond1, epoch_cond2 in single_sub_contrasts:
            condition = f"{dbs_status} {epoch_cond1} - {epoch_cond2} {sub}"
            ephy_plotting.perc_pow_diff_cond(
                sub_dict=single_sub_dict_subsets,
                dbs_status=dbs_status,
                tfr_args=tfr_args,
                t_min_max=tmin_tmax,
                vmin_vmax=vmin_vmax,
                epoch_cond1=epoch_cond1,
                epoch_cond2=epoch_cond2,
                condition=condition,
                saving_path=saving_path_single,
                show_fig=True,
                add_rt=True,
            )

## 3.3. Plot power changes at the group level ##

### 3.3.1. Plot average power changes for each trial type ###

In [None]:
# Colour limits for the group-level figures
vmin_vmax = [-5, 5]

# Trial types plotted at the group level (unsuccessful GO/GF/GC are disabled)
group_trial_types = (
    'GO_successful',
    # 'GO_unsuccessful',
    'GF_successful',
    # 'GF_unsuccessful',
    'GC_successful',
    # 'GC_unsuccessful',
    'GS_successful',
    'GS_unsuccessful',
    'stop_successful',
    'stop_unsuccessful',
    'continue_successful',
)

for dbs_status in ('DBS OFF', 'DBS ON'):
    for cond in group_trial_types:
        print(f"Now processing: {dbs_status} - {cond} ")
        ephy_plotting.tfr_pow_change_cond(
            sub_dict=sub_dict_epochs,
            dbs_status=dbs_status,
            epoch_cond=cond,
            tfr_args=tfr_args,
            t_min_max=tmin_tmax,
            vmin_vmax=vmin_vmax,
            baseline_correction=True,
            saving_path=saving_path_group,
            show_fig=True,
            add_rt=True,
        )


### 3.3.2. Plot average power changes across contrasts of interest ###

In [None]:
# Group-level contrasts (epoch_cond1 - epoch_cond2), computed across all sessions for each DBS condition.
group_contrasts = [
    ("GS_successful", "GO_successful"),
    ("GO_successful", "GF_successful"),
    ("GS_successful", "GS_unsuccessful"),
    ("stop_successful", "stop_unsuccessful"),
    # ("stop_unsuccessful", "stop_successful"),
    ("stop_successful", "continue_successful"),
    ("GS_successful", "GC_successful"),
    ("stop_unsuccessful", "continue_successful"),
    ("GS_unsuccessful", "GC_successful"),
]

for dbs_status in ['DBS OFF', 'DBS ON']:
    for epoch_cond1, epoch_cond2 in group_contrasts:
        # The group-level condition label uses spaces instead of underscores (e.g. "GS successful").
        condition = f"{dbs_status} {epoch_cond1.replace('_', ' ')} - {epoch_cond2.replace('_', ' ')}"
        print(f"Now processing: {condition}")
        ephy_plotting.perc_pow_diff_cond(
            sub_dict=sub_dict_epochs,
            dbs_status=dbs_status,
            tfr_args=tfr_args,
            t_min_max=tmin_tmax,
            vmin_vmax=vmin_vmax,
            epoch_cond1=epoch_cond1,
            epoch_cond2=epoch_cond2,
            condition=condition,
            saving_path=saving_path_group,
            show_fig=True,
            add_rt=True,
        )
     


## 3.4. Comparison between DBS ON and DBS OFF ##
Only subjects recorded in both conditions (DBS ON and DBS OFF) are kept here.

### 3.4.1. First, subselect the data of patients with both sessions, remove other patients ###

In [None]:
# Subjects with both a DBS ON and a DBS OFF session among the included sessions
sub_both_cond = [
    sub for sub in sub_nums
    if f"{sub} DBS ON mSST" in included_subjects and f"{sub} DBS OFF mSST" in included_subjects
]

# subject ID -> {session name -> epochs}, restricted to those subjects
sub_dict_subsets_ON_OFF = {
    sub: {key: value for key, value in sub_dict_epochs.items() if sub in key}
    for sub in sub_both_cond
}


### 3.4.2. Plot contrast DBS ON - DBS OFF for each trial type ###

In [None]:
# Trial types for the DBS ON - DBS OFF comparison (unsuccessful GO/GF/GC are disabled)
on_off_trial_types = [
    'GO_successful',
    # 'GO_unsuccessful',
    'GF_successful',
    # 'GF_unsuccessful',
    'GC_successful',
    # 'GC_unsuccessful',
    'GS_successful',
    'GS_unsuccessful',
    'stop_successful',
    'stop_unsuccessful',
    'continue_successful',
]

for cond in on_off_trial_types:
    print(f"Now processing: {cond}")
    ephy_plotting.perc_pow_diff_on_off(
        sub_dict_ON_OFF=sub_dict_subsets_ON_OFF,
        tfr_args=tfr_args,
        cond=cond,
        t_min_max=tmin_tmax,
        vmin_vmax=vmin_vmax,
        condition=f"ON vs OFF - {cond}",
        saving_path=saving_path_on_off,
        show_fig=True,
    )


### 3.4.3. Plot contrast DBS ON - DBS OFF for each contrast of interest ###

In [None]:
# DBS ON - DBS OFF difference of each contrast of interest (cond1 - cond2).
# Third element: extra keyword arguments for that contrast only (otherwise the helper's defaults apply).
on_off_contrasts = [
    ("GS_successful", "GO_successful", {}),
    ("GS_successful", "GS_unsuccessful", {}),
    ("GO_successful", "GF_successful", {}),
    ("stop_successful", "stop_unsuccessful", {}),
    ("stop_unsuccessful", "stop_successful", {}),
    ("GS_unsuccessful", "GC_successful", {}),
    ("stop_unsuccessful", "continue_successful", {}),
    ("GS_successful", "GC_successful", {"add_rt": True, "add_ssd": True}),
    ("stop_successful", "continue_successful", {}),
]

for cond1, cond2, extra_kwargs in on_off_contrasts:
    print(f"Now processing contrast: {cond1} - {cond2}")
    ephy_plotting.perc_pow_diff_on_off_contrast(
        sub_dict_ON_OFF=sub_dict_subsets_ON_OFF,
        tfr_args=tfr_args,
        cond1=cond1,
        cond2=cond2,
        t_min_max=tmin_tmax,
        vmin_vmax=vmin_vmax,
        condition=f"ON - OFF - {cond1} - {cond2}",
        saving_path=saving_path_on_off,
        show_fig=True,
        **extra_kwargs,
    )