In [None]:
import mne
from mne.io import read_raw
import numpy as np
import matplotlib.pyplot as plt
#import matplotlib.colors as mcolors
from os.path import join
import os
import pandas as pd
from scipy.signal import hilbert
import json
from collections import defaultdict



from functions import ephy_plotting, preprocessing, analysis, io, utils

In [None]:
# Project folders are resolved relative to the parent of the current working directory
working_path = os.path.dirname(os.getcwd())
results_path = join(working_path, "results")
behav_results_saving_path = join(results_path, "behav_results")

# JSON file containing the included and excluded subjects, based on the behavioral results
included_excluded_file = join(behav_results_saving_path, 'final_included_subjects.json')
with open(included_excluded_file, 'r') as json_file:
    included_subjects = json.load(json_file)

# Only entries whose name starts with "sub" are kept
included_subjects = list(filter(lambda name: name.startswith('sub'), included_subjects))
print(f'Included_subjects: {included_subjects}')
onedrive_path = utils._get_onedrive_path()

# Output folders for the group-level and ON-vs-OFF results (created if missing)
saving_path_group = join(results_path, 'group_level', 'lfp_perc_sig_change', 'morlet_low_freq')
saving_path_on_off = join(results_path, 'ON_vs_OFF', 'morlet_low_freq')
for output_dir in (saving_path_group, saving_path_on_off):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# Containers filled by the per-session processing loop; each is keyed by session ID
sub_dict_epochs_subsets = {}  # epochs for each condition, per subject/session
sub_dict_lm_GO = {}  # epochs for lm_GO trials, per subject/session
sub_dict_RT = {}  # mean reaction time for each trial type
sub_dict_stats = {}  # behavioral stats, per subject/session

# defaultdict(dict): every missing key gets an empty dictionary on first access.
# (The former `cluster_results_dict = {}` was overwritten immediately and is dropped.)
cluster_results_dict = defaultdict(dict)
# Stores a reference to the list itself, so later in-place edits of
# `included_subjects` are reflected here as well.
cluster_results_dict['All subjects'] = included_subjects

In [None]:
# Load the behavioral data for all included sessions
data = io.load_behav_data(included_subjects, onedrive_path)

# Compute statistics for each loaded session
stats = utils.extract_stats(data)

# Write the statistics to <results_path>/stats.json.
# The file is opened in "w" mode, so an existing file is always overwritten.
filename = "stats.json"
file_path = os.path.join(results_path, filename)
with open(file_path, "w", encoding="utf-8") as file:
    json.dump(stats, file, indent=4)

# Remove the sub023 sessions.
# The list is filtered in place (slice assignment) so that every other reference
# to the same list object sees the change, exactly as `list.remove` did, while
# re-running this cell (or a missing session) no longer raises ValueError.
excluded_sessions = ('sub023 DBS ON mSST', 'sub023 DBS OFF mSST')
included_subjects[:] = [
    session for session in included_subjects if session not in excluded_sessions
]
included_subjects


In [None]:
# Reload the behavioral data for the current `included_subjects` list, so that
# `data` and `stats` only cover the sessions still included at this point.
data = io.load_behav_data(included_subjects, onedrive_path)

# Recompute the statistics for the reloaded sessions
# (the former `stats = {}` placeholder was overwritten immediately and is dropped)
stats = utils.extract_stats(data)

In [None]:
# Per-session pipeline. For every session in `included_subjects`:
#   1. load the synchronized intracranial recording (.set file),
#   2. rename the channels so they are consistent across subjects,
#   3. load the behavioral CSV and drop the early-press trials,
#   4. replace the first two channels by their band-limited Hilbert envelope,
#   5. epoch the envelope data and split the epochs by behavioral outcome,
#   6. store the results in the sub_dict_* containers, keyed by session ID.

# Frequency band (Hz) used for the envelope.
# NOTE(review): the variable is named `theta`, but the limits are 13-20 Hz and the
# original comments referred to "βA" — confirm the intended band label.
theta = [13, 20]

# Channel names applied positionally to the recording's channels
new_channel_names = [
    "Left_STN",
    "Right_STN",
    "left_peak_STN",
    "right_peak_STN",
    "STIM_Left_STN",
    "STIM_Right_STN"
]

for session_ID in included_subjects:
    print(f"Now processing {session_ID}")
    # NOTE(review): session_dict is filled below but not stored anywhere in this loop
    session_dict = {}
    sub = session_ID[:6]

    # Session IDs are space separated, e.g. 'sub023 DBS ON mSST':
    # first token = subject, second and third tokens = condition
    session_parts = session_ID.split(' ')
    subject_ID = session_parts[0]
    condition = session_parts[1] + ' ' + session_parts[2]

    sub_onedrive_path = join(onedrive_path, subject_ID)
    sub_onedrive_path_task = join(sub_onedrive_path, 'synced_data', session_ID)

    # Locate the synchronized intracranial recording; the first match is used
    set_files = [
        f for f in os.listdir(sub_onedrive_path_task)
        if f.endswith('.set') and f.startswith('SYNCHRONIZED_INTRACRANIAL')
    ]
    if not set_files:
        raise FileNotFoundError(f"No .set file found in {sub_onedrive_path_task}")

    file = join(sub_onedrive_path_task, set_files[0])
    if not os.path.isfile(file):
        raise FileNotFoundError(f"File does not exist: {file}")

    print(f"Loading file: {file}")
    raw = read_raw(file, preload=True)

    saving_path_single = join(results_path, 'single_sub', f'{sub} mSST', 'freq_response')
    os.makedirs(saving_path_single, exist_ok=True)  # Create the directory if it doesn't exist

    session_dict['CHANNELS'] = raw.ch_names

    # Rename channels to be consistent across subjects.
    # zip() pairs names by position and stops at the shorter of the two lists.
    old_channel_names = raw.ch_names
    rename_dict = dict(zip(old_channel_names, new_channel_names))
    raw.rename_channels(rename_dict)

    session_dict['RENAMED_CHANNELS'] = raw.ch_names

    # Filter between 1 and 95 Hz:
    filtered_data = raw.copy().filter(l_freq=1, h_freq=95)

    # Only keep the first two channels (the LFP channels after renaming).
    # `pick` is used instead of the legacy `pick_channels`, as elsewhere in this file.
    filtered_data_lfp = filtered_data.copy().pick(
        [filtered_data.ch_names[0], filtered_data.ch_names[1]]
    )

    # Locate the behavioral CSV of this session
    mSST_raw_behav_session_data_path = join(
        onedrive_path, subject_ID, "raw_data", 'BEHAVIOR', condition, 'mSST'
    )
    csv_files = [
        f for f in os.listdir(mSST_raw_behav_session_data_path) if f.endswith(".csv")
    ]
    if not csv_files:
        # Previously `fname` stayed undefined here (NameError on the first session)
        # or silently kept the file name of the previous session.
        raise FileNotFoundError(
            f"No .csv file found in {mSST_raw_behav_session_data_path}"
        )
    # When several CSV files exist, the last one in directory-listing order is
    # used, which is what the original loop ended up selecting.
    fname = csv_files[-1]
    filepath_behav = join(mSST_raw_behav_session_data_path, fname)
    df = pd.read_csv(filepath_behav)

    # Index of the first row where 'blocks.thisRepN' is not NaN
    start_task_index = df['blocks.thisRepN'].first_valid_index()
    # Keep the rows from that index on, excluding the very last row.
    # NOTE(review): presumably the rows before it are training trials and the last
    # row is not a trial — confirm against the behavioral file format.
    df_maintask = df.iloc[start_task_index:-1]

    # Remove the trials with early presses, as in these trials the cues were not presented
    early_presses = df_maintask[df_maintask['early_press_resp.corr'] == 1]
    early_presses_trials = list(early_presses.index)
    number_early_presses = len(early_presses_trials)
    df_maintask_copy = df_maintask.drop(early_presses_trials)

    # Band-pass filtered copy of the LFP channels.
    # NOTE(review): `filtered_theta` is not used further in this loop; the envelope
    # below is computed from `raw` directly.
    filtered_theta = filtered_data_lfp.copy().filter(l_freq=theta[0], h_freq=theta[1])

    # Copy of the recording whose first two channels are overwritten by the envelope
    band_data = raw.copy()
    channels_to_process = raw.ch_names[:2]
    for channel in channels_to_process:
        print(f"Processing channel: {channel}")

        # Band-pass filter the single-channel signal (FIR, firwin design).
        # get_data(picks=...) reads the channel without copying the whole recording.
        filtered = mne.filter.filter_data(
            raw.get_data(picks=[channel]).flatten(),
            sfreq=raw.info['sfreq'],
            l_freq=theta[0],
            h_freq=theta[1],
            method='fir',
            verbose=False,
            fir_design='firwin',
            l_trans_bandwidth=0.5,
            h_trans_bandwidth=0.5
        )

        # Amplitude envelope = magnitude of the analytic signal (Hilbert transform)
        analytic_signal = hilbert(filtered)
        envelope = np.abs(analytic_signal)

        # No normalization is applied: the raw envelope is used as is
        normalized_envelope = envelope

        # Replace the corresponding channel's data in `band_data`
        band_data._data[raw.ch_names.index(channel), :] = normalized_envelope

    # Create epochs from the envelope data
    epochs, filtered_event_dict = preprocessing.create_epochs(band_data, session_ID)

    # Split the epochs into subsets using the behavioral data
    (epochs_subsets, epochs_lm, mean_RT_dict) = preprocessing.create_epochs_subsets_from_behav(
        df_maintask_copy,
        epochs,
        filtered_event_dict
    )

    sub_dict_epochs_subsets[session_ID] = epochs_subsets
    sub_dict_lm_GO[session_ID] = epochs_lm
    sub_dict_RT[session_ID] = mean_RT_dict
    sub_dict_stats[session_ID] = stats[session_ID]

In [None]:
# Browse the 'GS_successful' epochs of one example session
gs_successful_sub006 = sub_dict_epochs_subsets['sub006 DBS ON mSST']['GS_successful']
gs_successful_sub006.plot()

In [None]:
%matplotlib qt
sub_dict_epochs_subsets['sub006 DBS ON mSST']['GO_successful'].plot_image(picks = (raw.ch_names[0]), combine="mean")
sub_dict_epochs_subsets['sub006 DBS ON mSST']['GO_successful'].plot_image(picks = (raw.ch_names[1]), combine="mean")

In [None]:
# RT = stats['sub011 DBS OFF mSST']['go_trial RTs (ms)']
# RT_df = pd.DataFrame({'RT': RT, 'index': range(len(RT))})
# RT_sec = RT_df['RT'].values / 1000  # Convert to seconds

In [None]:
# # sort the dataframe by RT
# RT_df = RT_df.sort_values(by='RT')
# # add a column with the rank of the RT
# RT_df['rank'] = RT_df['RT'].rank(method='min')
# # sort the RT_df by index
# RT_df = RT_df.sort_values(by='index')

In [None]:
# # extract the rank values and store them as integers
# rank_values = RT_df['rank'].values
# rank_values = rank_values - 1

In [None]:
# rank_values.tolist()
# # Convert rank values to integers
# rank_values = rank_values.astype(int)
# rank_values

In [None]:
# Session used for the RT-sorted epoch image
rt_plot_session = 'sub015 DBS ON mSST'

# 'GO_successful' epochs of that session
epochs = sub_dict_epochs_subsets[rt_plot_session]['GO_successful']

# Go-trial reaction times, converted from milliseconds to seconds.
# NOTE(review): assumes these RTs line up one-to-one with the epochs above — confirm.
RT_sec = np.asarray(stats[rt_plot_session]['go_trial RTs (ms)']) / 1000

# # Rank/order the RTs (ascending order)
# rank_order = np.argsort(RT_sec)

# # Sort RTs to match the rank order
# RT_sec_sorted = np.array(RT_sec)[rank_order]

# # Now plot with consistent ordering
# epochs.plot_image(
#     picks=raw.ch_names[0],
#     combine="mean",
#     order=rank_order,
#     overlay_times=RT_sec_sorted
#)

In [None]:
# # Assuming RT_sec is already aligned with the epochs (i.e., came from behavior logs matched to these trials)
# epochs.metadata = pd.DataFrame({'RT_sec': RT_sec})

In [None]:
# rank_order = np.argsort(epochs.metadata['RT_sec'].values)
# RT_sec_sorted = epochs.metadata['RT_sec'].values[rank_order]

In [None]:
%matplotlib inline
# Add RTs to metadata
epochs.metadata = pd.DataFrame({'RT': RT_sec})

# Get sorted indices
sorted_indices = np.argsort(epochs.metadata['RT'].values)

# Sort the Epochs object using those indices
epochs_sorted = epochs[sorted_indices]

# Also sort RTs for overlay (though they should now match)
RT_sec_sorted = RT_sec[sorted_indices]

# Now plot
epochs_sorted.plot_image(
    picks=raw.ch_names[1],
    combine="mean",
    overlay_times=RT_sec_sorted,
    vmin=-1.5e7,  # set your minimum value (in Volts)
    vmax=1.5e7,
)
