# MAIN
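1. Split the main task into stages (planning / execution) and blocks (baseline / adaptation).
2. Epoch motor planning: -500 to 500 ms around target onset (`target_on` event).
3. Epoch motor execution: -500 to 700 ms around movement onset (epochs are locked to the `go_on` event).
4. Time-frequency analysis for individual subjects + group analysis.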


In [1]:
import mne
import os
from utils import check_paths
import numpy as np
import pandas as pd
import pickle

import matplotlib.pyplot as plt
%matplotlib qt

# PLANNING / EXECUTION

1. EPOCHING

In [None]:
# Create planning / execution epochs for subjects.
#
# For every subject in `group`, the continuous recording is split into a
# baseline block and an adaptation block. Each block is epoched around the
# stage-specific event, and epochs containing unwanted events are dropped via
# a metadata query. The epochs are saved together with an ERP plot and a
# band-power topomap. The number of kept epochs per subject/block is written
# to a CSV file at the end.
eeg_data_dir = 'D:\\BonoKat\\research project\\# study 1\\eeg_data\\set'
group = 'O'
task = '_MAIN' # ['_BL', '_MAIN']
task_stage = '_plan' # '_plan' for PLANNING or '_go' for EXECUTION

subs_dir = os.path.join(eeg_data_dir, group)
figs_dir = os.path.join(eeg_data_dir, 'figures', group, 'epochs', task)
check_paths(figs_dir)

# trial information dictionary (one entry per subject x block)
trial_num_data = {'sub': [], 'task_stage': [], 'block_name': [], 'trial_num': []}

# Stage-specific settings. The metadata window equals the epoch window, so the
# metadata row of each epoch lists the events that fall inside that epoch.
if task_stage == '_plan':
    metadata_tmin, metadata_tmax = -0.5, 0.5
    row_events = ['target_on']  # epochs are time-locked to target onset
    epochs_tmin, epochs_tmax = -0.5, 0.5
    # an epoch is kept only if none of these events occurs inside it
    exclude_events = ['trial_start', 'go_on', 'bad_early']
elif task_stage == '_go':
    metadata_tmin, metadata_tmax = -0.5, 0.7
    row_events = ['go_on']  # epochs are time-locked to the go event
    epochs_tmin, epochs_tmax = -0.5, 0.7
    exclude_events = ['trial_start', 'bad_early', 'bad_late']
else:
    raise ValueError(f"task_stage must be '_plan' or '_go', got {task_stage!r}")

baseline_epo = (epochs_tmin, epochs_tmax) # baseline is the average of the whole epoch

# frequency bands for the PSD topomaps (identical for every subject/block)
bands = {'Theta (4-8 Hz)': (4, 8),
         'Alpha (8-12 Hz)': (8, 12),
         'Beta (12-30 Hz)': (12, 30),
         'Gamma (30-50 Hz)': (30, 50),
         'High gamma (50-80 Hz)': (50, 80)}


def _first_event_time(events, event_id, name, sfreq, first_samp):
    """Return the time (s) of the first `name` event, relative to the data start.

    `events` is an events array whose column 0 holds sample indices and
    column 2 event codes; `event_id` maps event names to codes.
    Raises ValueError if the event never occurs.
    """
    samples = events[events[:, 2] == event_id[name], 0]
    if samples.size == 0:
        raise ValueError(f"No '{name}' event found in the events array")
    # MNE event samples include raw.first_samp, whereas Raw.crop() expects
    # times relative to the first sample of the data.
    return (samples[0] - first_samp) / sfreq


def _exclusion_query(exclude, columns):
    """Build a metadata query keeping epochs in which none of `exclude` occurs.

    An event without a metadata column is not in the event_id dict, so it
    cannot occur in any epoch and is left out of the query.
    Returns None when there is nothing to exclude.
    """
    conditions = [f'{name}.isna()' for name in exclude if name in columns]
    return ' & '.join(conditions) if conditions else None


# all subject folders, in a deterministic order
# (use e.g. ['s1_pac_sub00'] to run a single subject; EXCLUDED sub19 - no baseline trigger)
sub_names = sorted(
    name for name in os.listdir(subs_dir)
    if os.path.isdir(os.path.join(subs_dir, name))
)

for sub_name in sub_names:

    print(f'Creating epochs for {task} task in {sub_name}...')

    preproc_dir = os.path.join(subs_dir, sub_name, 'preproc')
    filt_dir = os.path.join(preproc_dir, 'filt')
    analysis_dir = os.path.join(preproc_dir, 'analysis')

    eeg_data_path = os.path.join(analysis_dir, f'{sub_name}{task}_reconst.fif')
    raw = mne.io.read_raw_fif(eeg_data_path, preload=True)
    sf = raw.info['sfreq'] # sampling frequency of data

    # Open events from pickle file: element 0 is the events array,
    # element 1 the event name -> code dictionary
    with open(os.path.join(filt_dir, f'{sub_name}{task}_events.pkl'), 'rb') as pickle_file:
        events_raw = pickle.load(pickle_file)
    events_arr = events_raw[0]
    event_dict = events_raw[1]

    # make event time samples unique: a duplicate is shifted one sample later
    # (in place, so the next comparison already sees the shifted value)
    for i in range(len(events_arr) - 1):
        if events_arr[i, 0] == events_arr[i+1, 0]:
            print(f'Time sample for events_raw {events_arr[i, 2]} and {events_arr[i+1, 2]} in trials {i}-{i+1} are not unique!')
            events_arr[i+1, 0] = events_arr[i, 0] + 1

    # block boundaries in seconds (first occurrence of each block-marker event)
    bl_start_time, adapt_start_time, adapt_finish_time = (
        _first_event_time(events_arr, event_dict, name, sf, raw.first_samp)
        for name in ('baseline', 'adapt1', 'postadapt')
    )

    print(f'''baseline block: {bl_start_time}-{adapt_start_time} sec ~ {(adapt_start_time-bl_start_time)/60:.2f} min
    adaptation block: {adapt_start_time}-{adapt_finish_time} sec ~ {(adapt_finish_time-adapt_start_time)/60:.2f} min''')

    # crop raw file into segments: baseline block and adaptation block.
    # 1 s of padding on each side catches the start and end events; the
    # bounds are clamped because Raw.crop() rejects times outside the data.
    t_last = raw.times[-1]
    bl_raw = raw.copy().crop(tmin=max(bl_start_time - 1, 0),
                             tmax=min(adapt_start_time + 1, t_last))
    adapt_raw = raw.copy().crop(tmin=max(adapt_start_time - 1, 0),
                                tmax=min(adapt_finish_time + 1, t_last))

    # Metadata depends only on the events, so it is built once per subject.
    # Each row describes one row event; the other event columns are NaN when
    # that event does not occur within [metadata_tmin, metadata_tmax].
    # make_metadata also returns the matching events array and event_id dict.
    metadata, events, event_id = mne.epochs.make_metadata(
        events=events_arr,
        event_id=event_dict,
        tmin=metadata_tmin,
        tmax=metadata_tmax,
        sfreq=sf,
        row_events=row_events
    )
    query = _exclusion_query(exclude_events, metadata.columns)

    for block_name, block_raw in (('_baseline', bl_raw), ('_adaptation', adapt_raw)):

        # events whose epoch falls outside the cropped block have no data
        # and are dropped by MNE
        epochs = mne.Epochs(
            raw=block_raw,
            tmin=epochs_tmin,
            tmax=epochs_tmax,
            events=events,
            event_id=event_id,
            baseline=baseline_epo,
            detrend=None,
            metadata=metadata.copy(),  # each Epochs object gets its own copy
            reject_by_annotation=True,
            preload=True,
        )

        if query is not None:
            epochs = epochs[query]

        print(f'TOTAL NUMBER OF EPOCHS: {len(epochs)}')

        # Append data to the dictionary
        trial_num_data['sub'].append(sub_name)
        trial_num_data['task_stage'].append(task_stage)
        trial_num_data['block_name'].append(block_name)
        trial_num_data['trial_num'].append(len(epochs))

        # averaging / saving an empty Epochs object would fail
        if len(epochs) == 0:
            print(f'No epochs left for {task}{block_name} in {sub_name}: nothing saved or plotted')
            continue

        # Save the epochs
        epochs.save(os.path.join(analysis_dir, f"{sub_name}{task}_epochs{task_stage}{block_name}-epo.fif"), overwrite=True)
        print(f'Epochs for {task}_{block_name} in {sub_name} saved SUCCESSFULLY')

        # Plot ERP
        fig_erp = epochs.average().plot(gfp=True, spatial_colors=True)
        # Save the ERP plot
        fig_erp.savefig(os.path.join(figs_dir, f"{sub_name}{task}{task_stage}{block_name}_erp_plot.png"), dpi=300)

        # Band-power topomaps of the epochs' PSD
        spectrum = epochs.compute_psd()
        fig_psd = spectrum.plot_topomap(bands=bands, vlim="joint", normalize=True)
        fig_psd.savefig(os.path.join(figs_dir, f"{sub_name}{task}{task_stage}{block_name}_psd_topomap.png"), dpi=300)

        print(f'Figures for {sub_name} saved SUCCESSFULLY')

        plt.close('all')

# save the trial number data to a CSV file
trial_num_csv = os.path.join(figs_dir, f"{task[1:]}{task_stage}_TRIAL_NUM.csv")
trial_num_df = pd.DataFrame(trial_num_data)
trial_num_df.to_csv(trial_num_csv, index=False)
print(f'TRIAL NUMBER saved SUCCESSFULLY to {trial_num_csv}')

______________________________________