In [None]:
import mne
from mne.io import read_raw
import numpy as np
import matplotlib.pyplot as plt
from os.path import join
import os
import pandas as pd
import json
from collections import defaultdict

from functions import ephy_plotting, preprocessing, analysis, io, utils

# 1. Load the dataset #

In [None]:
# --- Paths ---------------------------------------------------------------
# The project root is the parent of the current working directory, so the
# notebook has to be launched from a direct sub-folder of the project.
working_path = os.path.dirname(os.getcwd())
results_path = join(working_path, "results")
behav_results_saving_path = join(results_path, "behav_results")

# Read the JSON file containing the included and excluded subjects,
# based on the behavioral results
included_excluded_file = join(behav_results_saving_path, 'final_included_subjects.json')
with open(included_excluded_file, 'r', encoding='utf-8') as file:
    included_subjects = json.load(file)

# Keep only the entries starting with "sub"
included_subjects = [subj for subj in included_subjects if subj.startswith('sub')]
print(f'Included_subjects: {included_subjects}')
onedrive_path = utils._get_onedrive_path()

# --- Output folder for the cleaned epochs --------------------------------
saving_path = join(results_path, 'lfp_epochs')
os.makedirs(saving_path, exist_ok=True)  # Create the directory if it doesn't exist

# --- Containers filled by the cells below ---------------------------------
sub_dict_epochs = {}  # session ID -> epochs of that session
# sub -> condition -> per-session results; a missing sub gets an empty dict
all_sub_session_dict = defaultdict(dict)


In [None]:
# NOTE: this whole cell is commented out, so running it does nothing.
# If re-enabled, it would load the behavioral data of the included subjects,
# compute statistics with utils.extract_stats and write them to
# <results_path>/stats.json.

# # Load all data for all included subjects
# data = io.load_behav_data(included_subjects, onedrive_path)

# # Compute statistics for each loaded subject
# stats = {}
# stats = utils.extract_stats(data)
# # If no file was found, create a new JSON file
# filename = "stats.json"
# file_path = os.path.join(results_path, filename)
# #if not os.path.isfile(file_path):
# #    with open(file_path, "w", encoding="utf-8") as file:
# #            json.dump({}, file, indent=4)

# # Save the updated or new JSON file
# with open(file_path, "w", encoding="utf-8") as file:
#     json.dump(stats, file, indent=4)

# # remove sub027 DBS OFF mSST from included_subjects because it has not been synchronized yet
# #included_subjects.remove('sub027 DBS OFF mSST')
# included_subjects


# 2. Create full session plots for each subject (Raw traces, TFR plot, PSD) #

In [None]:
# Per-session overview: load the synchronized recording, plot the raw traces,
# compute the Welch PSD and band metrics of both hemispheres, and save the PSD
# and STFT figures. Results are stored in all_sub_session_dict[sub][condition].
for session_ID in included_subjects:
    print(f"Now processing {session_ID}")
    session_dict = {}

    # Session IDs look like "sub006 DBS OFF mSST": the first 6 characters are
    # the subject code, the 2nd and 3rd space-separated tokens the condition.
    sub = session_ID[:6]
    id_tokens = session_ID.split(' ')
    condition = id_tokens[1] + ' ' + id_tokens[2]

    sub_onedrive_path_task = join(onedrive_path, sub, 'synced_data', session_ID)
    # sorted() makes the pick below deterministic if several files match
    # (os.listdir returns entries in arbitrary order).
    filename = sorted(
        f for f in os.listdir(sub_onedrive_path_task)
        if f.endswith('.set') and f.startswith('SYNCHRONIZED_INTRACRANIAL')
    )

    if not filename:
        raise FileNotFoundError(f"No .set file found in {sub_onedrive_path_task}")

    file = join(sub_onedrive_path_task, filename[0])

    # Guards against a directory whose name happens to match the pattern
    if not os.path.isfile(file):
        raise FileNotFoundError(f"File does not exist: {file}")

    print(f"Loading file: {file}")
    raw = read_raw(file, preload=True)

    # All figures of a subject go to results/single_sub/<sub> mSST
    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST')
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    ephy_plotting.plot_raw_stim(session_ID, raw, saving_path_single)

    # Welch PSD of the left and right channel
    psd_left, freqs_left, psd_right, freqs_right = analysis.compute_psd_welch(raw)
    session_dict['psd_left_V^2/Hz'] = psd_left
    session_dict['freqs_left'] = freqs_left
    session_dict['psd_right_V^2/Hz'] = psd_right
    session_dict['freqs_right'] = freqs_right

    # Compute band power for theta, alpha, low-beta and high-beta ranges:
    band_metrics_left = utils.compute_band_metrics(psd_left, freqs_left)
    band_metrics_right = utils.compute_band_metrics(psd_right, freqs_right)
    session_dict['left'] = band_metrics_left
    session_dict['right'] = band_metrics_right

    print(f'Values for Left STN: {band_metrics_left}')
    print(f'Values for Right STN: {band_metrics_right}')

    ephy_plotting.plot_psd_log(
        session_ID, raw, freqs_left, psd_left,
        freqs_right, psd_right, saving_path_single, is_filt=False
    )
    ephy_plotting.plot_stft_stim(
        session_ID, raw, saving_path=saving_path_single, is_filt=False,
        vmin=-18, vmax=-12,  # color scale limits passed to the plotting helper
        fmin=0, fmax=100
    )

    all_sub_session_dict[sub][condition] = session_dict


In [None]:
# Build the band-power comparison table from the per-session results collected
# in all_sub_session_dict, then export it to Excel without the row index.
df = analysis.compare_band_power(all_sub_session_dict)
df.to_excel(
    join(results_path, "band_power_comparison.xlsx"),
    index=False,
)

In [None]:
# NOTE: disabled cell (everything below is commented out).
# If re-enabled, it would run scipy.stats.wilcoxon on DBS OFF vs DBS ON power
# for every band x hemisphere subset of `df` and print the p-value.
# NOTE(review): relies on `df` having 'band', 'hemisphere', 'DBS OFF_power_uV2'
# and 'DBS ON_power_uV2' columns — confirm against analysis.compare_band_power.

# from scipy.stats import wilcoxon

# for band in df['band'].unique():
#     for hemi in df['hemisphere'].unique():
#         subset = df[(df['band'] == band) & (df['hemisphere'] == hemi)]
#         stat, p_val = wilcoxon(subset['DBS OFF_power_uV2'], subset['DBS ON_power_uV2'])
#         print(f"{band} - {hemi}: Wilcoxon p={p_val:.4f}")

# 3. Work with epochs #

In [None]:
# Build long epochs (-3.5 s to +3.5 s, no baseline) for every included session
# and attach the behavioral trial information to them as `epochs.metadata`.
# The resulting epochs are stored in sub_dict_epochs[session_ID].

# "continue"/"stop" events inherit the behavioral row of the trial-onset event
# (GC / GS respectively) that precedes them.
SIGNAL_PARENT_EVENT = {"continue": "GC", "stop": "GS"}

for session_ID in included_subjects:
    print(f"Now processing {session_ID}")

    # Session IDs look like "sub006 DBS OFF mSST": the first 6 characters are
    # the subject code, the 2nd and 3rd space-separated tokens the condition.
    sub = session_ID[:6]
    id_tokens = session_ID.split(' ')
    condition = id_tokens[1] + ' ' + id_tokens[2]

    # Reuse the entry that an earlier cell may have filled (PSD, band metrics)
    # instead of replacing it with an empty dict.
    session_dict = all_sub_session_dict[sub].setdefault(condition, {})

    # ------------------------------------------------------------------
    # Load the synchronized intracranial recording
    # ------------------------------------------------------------------
    sub_onedrive_path_task = join(onedrive_path, sub, 'synced_data', session_ID)
    # sorted() makes the pick below deterministic if several files match
    filename = sorted(
        f for f in os.listdir(sub_onedrive_path_task)
        if f.endswith('.set') and f.startswith('SYNCHRONIZED_INTRACRANIAL')
    )

    if not filename:
        raise FileNotFoundError(f"No .set file found in {sub_onedrive_path_task}")

    file = join(sub_onedrive_path_task, filename[0])

    if not os.path.isfile(file):
        raise FileNotFoundError(f"File does not exist: {file}")

    print(f"Loading file: {file}")
    raw = read_raw(file, preload=True)
    # list(...) stores a snapshot that cannot be affected by the renaming below
    session_dict['CHANNELS'] = list(raw.ch_names)

    # ------------------------------------------------------------------
    # Rename channels to be consistent across subjects
    # ------------------------------------------------------------------
    new_channel_names = [
        "Left_STN",
        "Right_STN",
        "left_peak_STN",
        "right_peak_STN",
        "STIM_Left_STN",
        "STIM_Right_STN"
    ]

    # Get the existing channel names
    old_channel_names = raw.ch_names

    # Map old to new names by position; zip stops at the shorter list, so
    # surplus channels keep their name.
    # NOTE(review): assumes every file stores its channels in this order — confirm.
    rename_dict = dict(zip(old_channel_names, new_channel_names))

    # Rename the channels
    raw.rename_channels(rename_dict)

    session_dict['RENAMED_CHANNELS'] = list(raw.ch_names)

    # Filter between 1 and 80 Hz:
    filtered_data = raw.copy().filter(l_freq=1, h_freq=80)

    # Only keep the LFP channels (the first two channels of the recording)
    filtered_data_lfp = filtered_data.copy().pick_channels(
        [filtered_data.ch_names[0], filtered_data.ch_names[1]]
    )

    # ------------------------------------------------------------------
    # Load the behavioral file of the session
    # ------------------------------------------------------------------
    mSST_raw_behav_session_data_path = join(
        onedrive_path, sub, "raw_data", 'BEHAVIOR', condition, 'mSST'
    )
    behav_csv_files = sorted(
        f for f in os.listdir(mSST_raw_behav_session_data_path) if f.endswith(".csv")
    )
    if not behav_csv_files:
        # Without this check `fname` would be undefined, or silently carried
        # over from the previous session.
        raise FileNotFoundError(
            f"No .csv behavioral file found in {mSST_raw_behav_session_data_path}"
        )
    # When several .csv files are present the last one (alphabetically) is used
    fname = behav_csv_files[-1]
    filepath_behav = join(mSST_raw_behav_session_data_path, fname)
    print(filepath_behav)
    # Named `behav_df` so that the band-power table `df` is not overwritten
    behav_df = pd.read_csv(filepath_behav)

    # The main task spans the rows in which 'blocks.thisRepN' is filled.
    start_task_index = behav_df['blocks.thisRepN'].first_valid_index()
    stop_task_index = behav_df['blocks.thisRepN'].last_valid_index()
    if start_task_index is None:
        raise ValueError(
            f"'blocks.thisRepN' contains no valid value in {filepath_behav}"
        )
    # first/last_valid_index return index *labels*, hence .loc, whose slices
    # include the end label. The last valid row MUST be included: other
    # scripts slicing on this column should be checked for an off-by-one.
    df_maintask = behav_df.loc[start_task_index:stop_task_index]

    # Remove all useless columns to clean up the dataframe
    column_names = df_maintask.columns
    columns_to_keep = [i for i in [
        'blocks.thisN', 'trial_loop.thisN', 'trial_type',
        'continue_signal_time', 'stop_signal_time',
        'key_resp_experiment.keys', 'key_resp_experiment.corr', 'key_resp_experiment.rt',
        'early_press_resp.keys', 'early_press_resp.rt', 'early_press_resp.corr',
        'late_key_resp1.keys', 'late_key_resp1.rt',
        'late_key_resp2.keys', 'late_key_resp2.rt'
        ] if i in column_names]

    mini_df_maintask = df_maintask[columns_to_keep]
    print(mini_df_maintask.shape)

    # Remove the trials with early presses, as in these trials the cues were
    # not presented (for mSST). A file without the column has none to remove.
    if 'early_press_resp.corr' in mini_df_maintask.columns:
        early_presses = mini_df_maintask[mini_df_maintask['early_press_resp.corr'] == 1]
        early_presses_trials = list(early_presses.index)
    else:
        early_presses_trials = []
    number_early_presses = len(early_presses_trials)
    print(f'Number of early presses: {number_early_presses}')

    # Remove trials with early presses from the dataframe:
    df_maintask_copy = mini_df_maintask.drop(early_presses_trials).reset_index(drop=True)
    print(df_maintask_copy.shape)
    print(df_maintask_copy['blocks.thisN'])

    # ------------------------------------------------------------------
    # Create global epochs (without taking into account success outcome)
    # ------------------------------------------------------------------
    # event_id (label -> code) is used below to label each epoch
    events, event_id = mne.events_from_annotations(filtered_data_lfp)
    epochs, filtered_event_dict = preprocessing.create_epochs(
        filtered_data_lfp,
        sub,
        keys_to_keep=['GC', 'GF', 'GO', 'GS', 'continue', 'stop'],
        tmin=-3.5,
        tmax=3.5,
        baseline=None
    )
    print(epochs)

    # Inverse mapping (event code -> label)
    # NOTE(review): assumes create_epochs keeps the codes returned by
    # mne.events_from_annotations — confirm, otherwise use epochs.event_id.
    inv_event_id = {v: k for k, v in event_id.items()}

    # ------------------------------------------------------------------
    # Build the metadata table: one row per epoch
    # ------------------------------------------------------------------
    metadata = pd.DataFrame(index=np.arange(len(epochs)))
    metadata["event"] = [inv_event_id[e] for e in epochs.events[:, 2]]
    # object dtype: the column starts as NaN and may be filled from the behavioral
    # 'trial_type' column (labels such as "go_trial" are expected, see `mapping`);
    # writing strings into a float column triggers an incompatible-dtype warning
    metadata["trial_type"] = pd.Series(np.nan, index=metadata.index, dtype=object)

    # LFP -> behavioral naming mapping
    mapping = {
        "GC": "go_continue_trial",
        "GO": "go_trial",
        "GF": "go_fast_trial",
        "GS": "stop_trial",
    }

    # Epochs locked to a trial onset: exactly one per behavioral trial
    trial_mask = metadata["event"].isin(mapping.keys())

    assert trial_mask.sum() == len(df_maintask_copy), \
        f"Mismatch: {trial_mask.sum()} LFP trials vs {len(df_maintask_copy)} behavioral trials"

    # Fill directly from the behavioral file (both are in chronological order)
    behav_columns = list(df_maintask_copy.columns)
    for col in behav_columns:
        metadata.loc[trial_mask, col] = df_maintask_copy[col].values

    # "continue"/"stop" epochs copy the behavioral row of the last GC/GS epoch
    # before them. One pass, remembering the latest row of each parent event.
    last_parent_row = {}
    event_labels = metadata["event"].tolist()
    for i, event_label in zip(metadata.index, event_labels):
        if event_label in SIGNAL_PARENT_EVENT.values():
            last_parent_row[event_label] = i
        elif event_label in SIGNAL_PARENT_EVENT:
            parent_label = SIGNAL_PARENT_EVENT[event_label]
            if parent_label not in last_parent_row:
                raise ValueError(
                    f"{session_ID}: '{event_label}' epoch {i} has no preceding "
                    f"'{parent_label}' epoch"
                )
            prev_idx = last_parent_row[parent_label]
            metadata.loc[i, behav_columns] = metadata.loc[prev_idx, behav_columns]

    epochs.metadata = metadata

    sub_dict_epochs[session_ID] = epochs


## For each session, look at the epochs and remove 'bad' epochs, then save the cleaned file ##

In [None]:
# Session to inspect; must be a key of sub_dict_epochs (filled by the epoching loop).
subject_id = "sub006 DBS OFF mSST"

# Work on a copy so the epochs stored in sub_dict_epochs stay untouched.
cleaned_epochs = sub_dict_epochs[subject_id].copy()

In [None]:
# Display the summary of the selected session's epochs
cleaned_epochs

In [None]:
# Switch matplotlib to the Qt backend so that figures open in separate windows
%matplotlib qt

# Disabled: overview plot showing 10 epochs at a time
# cleaned_epochs.plot(n_epochs=10, n_channels = len(cleaned_epochs.ch_names), events=True)

In [None]:
# nan_bads = [252, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272 ,273, 274, 275, 276, 277, 278, 279, 280]

In [None]:
# cleaned_epochs.drop(nan_bads)

In [None]:
%matplotlib qt
cleaned_epochs.plot(n_epochs=4, n_channels = len(cleaned_epochs.ch_names), events=True)

In [None]:
# Display the metadata table attached to the epochs
cleaned_epochs.metadata

In [None]:
# Display the epochs summary again (e.g. to check how many epochs remain)
cleaned_epochs

In [None]:
# Export the epochs metadata, row index included, in two formats:
#   - csv  : easier for later python import
#   - xlsx : easier to read in excel
metadata_df = pd.DataFrame(cleaned_epochs.metadata)
metadata_df.to_csv(
    join(saving_path, f"{subject_id}_cleaned-long-epo_metadata.csv"),
    index=True,
)
metadata_df.to_excel(
    join(saving_path, f"{subject_id}_cleaned-long-epo_metadata.xlsx"),
    index=True,
)

In [None]:
# Write the cleaned long epochs of the selected session to the epochs folder,
# replacing any file saved previously under the same name.
file_epoch = join(saving_path, f"{subject_id}_cleaned-long-epo.fif")
cleaned_epochs.save(file_epoch, overwrite=True)

In [None]:
#epoch_reload = mne.read_epochs(os.path.join(saving_path, f"sub011 DBS OFF mSST_cleaned-long-epo.fif"), preload=True)

In [None]:
# cropped_epochs = cleaned_epochs.copy().crop(tmin=-0.5, tmax=1.5)
# file_cropped_epoch = os.path.join(saving_path, f"{subject_id}_cleaned-short-epo.fif")
# cropped_epochs.save(file_cropped_epoch, overwrite=True)