## Set up the run parameters

In [None]:
###use these parameters for testing this notebook outside of the automated loop of q1k_automated_reports.ipynb
#subject_id_in = "100162_P"
#subject_id_out = "100162P"
#task_id_in = "PLR"
#task_id_in_et = "PLR" 
#task_id_out = "PLR"
#run_id = "1"
#session_id = "01"
#project_path = "/project/def-emayada/q1k/experimental/HSJ/"
#dataset_group = "experimental"
#site_code = "HSJ" #'MHC' or 'HSJ'
#et_sync = True
#html_figures = True

# Use these empty parameters when executing this notebook from an automation script.
# Every name below is read by later cells, so all of them must be defined here.
subject_id_in = ""
subject_id_out = ""
task_id_in = ""
task_id_in_et = ""
task_id_out = ""  # was misspelled `ask_id_out`, which left task_id_out undefined for the print below
run_id = ""
session_id = ""
project_path = ""
dataset_group = ""
site_code = ""
et_sync = True
html_figures = False

# Echo the run parameters so they appear in the executed notebook output.
# f-strings are used so a non-string value does not break the echo.
print(f'subject_id_in: {subject_id_in}')
print(f'subject_id_out: {subject_id_out}')
print(f'task_id_in: {task_id_in}')
print(f'task_id_in_et: {task_id_in_et}')
print(f'task_id_out: {task_id_out}')
print(f'run_id: {run_id}')
print(f'session_id: {session_id}')
print(f'project_path: {project_path}')
print(f'dataset_group: {dataset_group}')
print(f'site_code: {site_code}')


In [None]:
# import packages
import pandas as pd
import numpy as np
import os
import mne
import mne_bids
from matplotlib import pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
import plotly.offline as py
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook"
import q1k_init_tools as qit
import glob
import re
import warnings
warnings.filterwarnings('ignore')
import seaborn as sns

In [None]:
# Define the DIN events associated with the events of interest based on the task_id.
# qit.set_din_str also returns event_dict_offset, which is handed to qit.get_event_dict
# when the EEG event dictionary is built further down.
din_str, event_dict_offset = qit.set_din_str(task_id_out)
print("DIN strings for " + task_id_out)
# bare expression: shown as the cell output
din_str

## Read the EEG file

In [None]:
# bare expression: show the input subject ID as the cell output
subject_id_in

In [None]:
# Generate the input paths.
# Both results are indexed with [0] by later cells
# (presumably lists of matching EEG / ET session files — TODO confirm in q1k_init_tools).
session_file_name_eeg, session_file_name_et = qit.generate_session_ids(dataset_group, project_path, site_code, task_id_in, subject_id_in, run_id)

In [None]:
# report the first entry of the EEG session file name result
print("EEG session file name: " + session_file_name_eeg[0])

In [None]:
# Without a usable first ET session entry the ET sync is switched off;
# otherwise the ET session file name is reported.
if not session_file_name_et or not session_file_name_et[0]:
    print("Could not find the session ET file.. abandoning ET sync portion of the initiation process.")
    et_sync = False
else:
    print("ET session file name: " + session_file_name_et[0])
    

In [None]:
# Read the input EEG session (EGI format) from the first EEG session file entry.
print('Reading: ' + session_file_name_eeg[0])
eeg_raw = mne.io.read_raw_egi(session_file_name_eeg[0])
eeg_raw_fresh=eeg_raw.copy() #make a fresh copy for later
device_info=eeg_raw.info['device_info'] #kept so it can be passed to qit.write_eeg when the output is written

In [None]:
# peek at the EEG channel types and channel names
channel_types = eeg_raw.get_channel_types()
print("EEG Channel Types:", channel_types)
print("EEG Channel Names:", eeg_raw.info['ch_names'])


In [None]:
# show the sensor topography with the channel names drawn next to each sensor
fig = eeg_raw.plot_sensors(show_names=True)

## Get and modify the EEG event structures

In [None]:
# Create the EEG event structures.
# shortest_event = 1 lets mne.find_events accept events that last a single sample.
eeg_events = mne.find_events(eeg_raw, shortest_event = 1)
# Build the label -> event value dictionary, using the offset returned by qit.set_din_str.
eeg_event_dict = qit.get_event_dict(eeg_raw, eeg_events, event_dict_offset)

In [None]:
print('EEG event dict:')
# bare expression: shown as the cell output
eeg_event_dict

In [None]:
# Peek at the original EEG events: first events column (x) against third events column (y).
fig = (
    px.scatter(x=eeg_events[:, 0], y=eeg_events[:, 2])
    .update_layout(title='Original EEG event times')
    .update_xaxes(title_text='Time of event(ms)')
    .update_yaxes(title_text='Event index')
)
py.iplot(fig)
# Optionally save the scatter plot as an html file for easy exploration.
if html_figures:
    fig.write_html("html_figures/eeg_event_times.html")

In [None]:
# Check the DIN events of interest and plot the distance between them.
if not din_str:
    print('Required EEG DIN events are missing... skipping EEG DIN check and DIN distance display')
else:
    # check that the din_str events exist in the eeg_event_dict..
    din_str = qit.din_check(eeg_event_dict, din_str)
    # get the distance between the DIN events of interest..
    din_diffs, din_diffs_time = qit.get_din_diff(eeg_events, eeg_event_dict, din_str)
    # build the figure...
    fig = px.scatter(x=din_diffs_time, y=din_diffs)
    fig.update_layout(title='Time between EEG DIN events of interest')
    fig.update_xaxes(title_text='Time of event(ms)')
    fig.update_yaxes(title_text='Inter event interval')
    # The figure lives inside an if/else block, so the notebook does not render it
    # automatically; show it explicitly like the other plotting cells do.
    fig.show()
    # print the scatterplot to an html file for easy exploration.
    if html_figures:
        fig.write_html("html_figures/eeg_din_diffs.html")

In [None]:
# Task specific interpretation of the EEG events.
# NOTE: eeg_stims, eeg_iti, eeg_din_offset and new_events are only assigned in the else
# branch, so they do not exist when din_str is empty; later cells that use them must
# check din_str first.
if not din_str:
    print('Required EEG DIN events are missing... skipping EEG stimulus onset DIN process')
else:
    # handle task specific EEG event interpretation..
    eeg_events, eeg_stims, eeg_iti, eeg_din_offset, eeg_event_dict, new_events = qit.eeg_event_test(eeg_events, eeg_event_dict, din_str, task_name=task_id_out)

In [None]:
# Summarise the event IDs held in the third column of eeg_events.
column_values = eeg_events[:, 2]
unique_values, counts = np.unique(column_values, return_counts=True)

# One line per distinct event ID with the number of times it occurs.
print("Counts of each event ID value:")
for event_id, n_events in zip(unique_values, counts):
    print(f"Event ID: {event_id}, Count: {n_events}")

# Tab separated table of the event labels and their values.
print("Label\tValue")
for label in eeg_event_dict:
    print(f"{label}\t{eeg_event_dict[label]}")

In [None]:
# Peek at the EEG events again, now including any new *_d DIN events that were generated:
# first events column (x) against third events column (y).
fig = (
    px.scatter(x=eeg_events[:, 0], y=eeg_events[:, 2])
    .update_layout(title='DIN updated EEG event times')
    .update_xaxes(title_text='Time of event(ms)')
    .update_yaxes(title_text='Event index')
)
fig.show()
if html_figures:
    fig.write_html("html_figures/eeg_update_event_times.html")

In [None]:
# Peek at the distance between consecutive *_d stim DIN events (needs the DIN events).
if din_str:
    fig = (
        px.scatter(x=eeg_stims[1:, 0], y=eeg_iti)
        .update_layout(title='Stim DIN event Inter Trial Intervals (ITI)')
        .update_xaxes(title_text='Time of event(ms)')
        .update_yaxes(title_text='Stim DIN event ITI (ms)')
    )
    fig.show()
    if html_figures:
        fig.write_html("html_figures/eeg_din_iti.html")
else:
    print('Required DIN events are missing... skipping stimulus DIN ITI display')

In [None]:
# Peek at the distance between the stim events and their *_d stim DIN events (needs the DIN events).
if din_str:
    fig = (
        px.scatter(x=eeg_stims[:, 0], y=eeg_din_offset)
        .update_layout(title='Stim DIN offsets')
        .update_xaxes(title_text='Time of event(ms)')
        .update_yaxes(title_text='Stim DIN offset (ms)')
    )
    fig.show()
    if html_figures:
        fig.write_html("html_figures/eeg_stim_din_offset.html")
else:
    print('Required DIN events are missing... skipping stimulus DIN event offset display')

## Read the eye-tracking data

In [None]:
# Read the eye-tracking session, but only when an ET session file was located.
# Indexing session_file_name_et[0] unconditionally would fail in exactly the case the
# earlier "Could not find the session ET file" check is meant to survive.
if session_file_name_et and session_file_name_et[0]:
    et_raw, et_raw_df, et_events, et_event_dict = qit.et_read(
        session_file_name_et[0], blink_interp=False, fill_nans=False, resamp=False
    )
else:
    print("Could not find the session ET file.. skipping the ET read.")
    # Placeholders so that later bare references to these names do not raise NameError.
    et_raw = et_raw_df = et_events = et_event_dict = None
    # No ET data means there is nothing to synchronise with.
    et_sync = False

In [None]:
# Peek at the ET channel types and names, unless the ET sync has been switched off.
if not et_sync:
    print("et_sync = False: not printing ET channle types")
else:
    channel_types = et_raw.get_channel_types()
    print("ET Channel Types:", channel_types)
    print("ET Channel Names:", et_raw.info['ch_names'])

In [None]:
# Browse the raw ET signals (20 s window) with per-channel-type scalings.
# Guarded by et_sync like the neighbouring ET display cells, so the notebook does not
# stop here when the ET portion has been abandoned.
if et_sync:
    et_raw.plot(duration=20, scalings=dict(eyegaze=1e2, pupil=1e3))
else:
    print("et_sync = False: not plotting the raw ET signals")

## Handle the Eye-Tracking events

In [None]:
# Show the ET event dictionary, unless the ET sync has been switched off.
if not et_sync:
    print("et_sync = False: not printing ET event dict")
else:
    print("ET event dict:", et_event_dict)

In [None]:
# Peek at the original ET events: first events column (x) against third events column (y).
if not et_sync:
    print("et_sync = False: not plotting the original ET events")
else:
    fig = (
        px.scatter(x=et_events[:, 0], y=et_events[:, 2])
        .update_layout(title='Original ET event times')
        .update_xaxes(title_text='Time of event(ms)')
        .update_yaxes(title_text='Event index')
    )
    py.iplot(fig)
    # Optionally save the scatter plot as an html file for easy exploration.
    if html_figures:
        fig.write_html("html_figures/et_event_times.html")

In [None]:
##DIN testing


## fill NaNs in DIN channel with zeros
#et_raw_df['DIN']=et_raw_df['DIN'].fillna(0)

## Correct blips to zero for a single sample while DIN8 is on.
#for ind, row in et_raw_df.iterrows():
#    if ind < len(et_raw_df)-1:
#        if ind > 0:
#            if et_raw_df['DIN'][ind] == 0:
#                if et_raw_df['DIN'][ind-1] == 8:
#                    if et_raw_df['DIN'][ind+1] == 8:
#                        et_raw_df['DIN'].loc[ind] = 8

## convert the ET DIN channel into ET events
## find when the DIN channel changes values
#et_raw_df['DIN_diff']=et_raw_df['DIN'].diff()
## select all non-zero DIN changes
#et_din_events=et_raw_df.loc[et_raw_df['DIN_diff']>0]



In [None]:
#et_din_events

In [None]:
## perform the anomalous DIN conversion
#et_din_events = et_din_events.copy()
#et_din_events['DIN'].loc[et_din_events['DIN'].isin([2,18,26])] = 2
#et_din_events['DIN'].loc[et_din_events['DIN'].isin([4,20,28])] = 4

#et_din_events = et_din_events.copy()
#et_din_events=et_din_events.loc[et_raw_df['DIN'].isin([2,4])]
#et_din_events = et_din_events.reset_index()
#et_din_events['DIN_diff'] = et_din_events['DIN_diff'].astype(int)
#et_din_events    


In [None]:
#    #convert DIN_diff to integers
#    et_din_events['DIN_diff'] = et_din_events['DIN_diff'].astype(int)

#    #add DIN events to et_annot_event_dict with the next available small integer
#    existing_indices = set(et_event_dict.values())
#    next_index = max(existing_indices) + 1

#    for din_diff in et_din_events['DIN_diff']:
#        din_key = f'DIN{din_diff}'
#        if din_key not in et_event_dict:
#            et_event_dict[din_key] = next_index
#            next_index += 1

#    #create new rows for et_annot_events based on et_din_events
#    #map DIN_diff to the new dictionary indices
#    et_din_events['mapped_value'] = et_din_events['DIN_diff'].map(lambda x: et_event_dict[f'DIN{x}'])

#    #add new rows to et_annot_events
#    new_events = np.array([[row['index'], 0, row['mapped_value']] for _, row in et_din_events.iterrows()])
#    et_annot_events = np.vstack((et_events, new_events))

#    #sort the updated et_annot_events array by the first column (timestamps)
#    et_annot_events = et_annot_events[np.argsort(et_annot_events[:, 0])]
#    et_annot_events = et_annot_events.astype(int)


In [None]:
#et_event_dict

In [None]:
# Clean the ET events, apply the task specific modifications and pull out the
# stimulus onset DIN events (et_stims) that the EEG/ET alignment uses.
if et_sync:
    # do event cleaning..
    et_event_dict, et_events = qit.et_clean_events(et_event_dict, et_events)
    # do task specific event modifications..
    et_event_dict, et_events, et_raw_df = qit.et_task_events(et_raw_df, et_event_dict, et_events, task_id_out)
    print("updated ET event dict:", et_event_dict)
    # Extract the value for 'STIM_d' from the dictionary
    stim_d_value = et_event_dict['STIM_d']
    # Keep the rows whose third column matches the 'STIM_d' value
    et_stims = et_events[et_events[:, 2] == stim_d_value]
    print('Number of stimulus onset DIN events: ' + str(len(et_stims)))
else:
    # The previous message was copied from a plotting cell and described the wrong step.
    print("et_sync = False: skipping ET event cleaning and stimulus onset DIN extraction")

In [None]:
if et_sync:
    # peak... at the ET event scatter plot.. event time stamp by label index
    fig=px.scatter(x=et_events[:,0],y=et_events[:,2])
    fig.update_layout(title='Updated ET event times')
    fig.update_xaxes(title_text='Time of event(ms)')
    fig.update_yaxes(title_text='Event index')
    py.iplot(fig)
    # print the scatterplot to an html file for easy exploration.
    if html_figures:
        fig.write_html("html_figures/et_updated_event_times.html")
else:
    print("et_sync = False: not plotting the updated ET events")

## Examine the synchronization between the EEG and ET events

In [None]:

# Align the EEG and ET stimulus events with qit.eeg_et_align, which returns the updated
# event dictionaries, the updated event arrays and the matched eeg_times / et_times.
# (The earlier inline alignment code that used to be commented out here has been
# removed; presumably qit.eeg_et_align now covers it.)

# eeg_stims is only created by the EEG stimulus onset DIN cell when din_str is non-empty,
# so the alignment cannot run without it: abandon the sync instead of raising NameError.
if et_sync and not din_str:
    print("Required EEG DIN events are missing... abandoning ET sync portion of the initiation process.")
    et_sync = False

if et_sync:
    eeg_event_dict, et_event_dict, eeg_events, et_events, eeg_times, et_times = qit.eeg_et_align(
        eeg_event_dict, et_event_dict,
        eeg_events, et_events,
        eeg_stims, et_stims,
        eeg_raw.info["sfreq"], et_raw.info["sfreq"])
else:
    print("et_sync = False: not checking eeg_times and et_times alignment")

In [None]:
# bare expression: show the ET event dictionary as the cell output
et_event_dict

In [None]:
# Plot the aligned EEG stim times against the aligned ET stim times.
if not et_sync:
    print("et_sync = False: not plotting the EEG by ET event times")
else:
    fig = (
        px.scatter(x=eeg_times, y=et_times)
        .update_layout(title='EEG by ET stim times')
        .update_xaxes(title_text='EEG stim times')
        .update_yaxes(title_text='ET stim times')
    )
    py.iplot(fig)
    # Optionally save the scatter plot as an html file for easy exploration.
    if html_figures:
        fig.write_html("html_figures/eeg_et_times.html")

In [None]:
# Plot the offset between each aligned EEG stim time and its ET stim time.
if not et_sync:
    print("et_sync = False: not plotting the EEG by ET offset times")
else:
    eeg_et_offset = eeg_times - et_times
    fig = (
        px.scatter(y=eeg_et_offset)
        .update_layout(title='EEG ET stim event offset times')
        .update_xaxes(title_text='EEG ET stim times')
        .update_yaxes(title_text='EEG ET stim event offsets')
    )
    fig.show()
    if html_figures:
        fig.write_html("html_figures/eeg_et_sync_offsets.html")

## Insert ET signals into the EEG raw object

In [None]:
#if et_sync:
#    
#    #NOW THAT THE EEG AND ET SESSIONS HAVE SYNC_TIME EVENTS WE DO NOT NEED TO COMBINE UNTIL POST PROCESSING...
#    # combine the EEG and ET recordings given the matched event times...
#    eeg_raw = qit.eeg_et_combine(eeg_raw, et_raw, eeg_times, et_times, eeg_events, eeg_event_dict, et_events, et_event_dict)
#
#    # Identify the channels with specific types (eyegaze and pupil) and rename them to 'misc' (to be fixed after bids compliance updates)
#    channel_types = eeg_raw.get_channel_types()
#    channel_mapping = {
#        ch_name: 'misc'
#        for ch_name, ch_type in zip(eeg_raw.ch_names, channel_types)
#        if ch_type in ['eyegaze', 'pupil']
#    }
#
#    # Update the channel types
#    eeg_raw.set_channel_types(channel_mapping)
#
#    #Update the event times since the sync
#    eeg_events, eeg_event_dict = mne.events_from_annotations(eeg_raw)
#    eeg_events[:,0]=eeg_events[:,0]-eeg_raw.first_samp
#else:
#    print('et_sync = False: skipping the integration of the ET signals into the EEG structure')

In [None]:
# peek at the eeg_raw properties and the EEG event dict before the output is written
# NOTE(review): the cell that merged the ET signals into eeg_raw is commented out above,
# so eeg_raw presumably still holds only the EEG recording here — confirm
channel_types = eeg_raw.get_channel_types()
print("EEG Channel Types:", channel_types)
print("EEG Channel Names:", eeg_raw.info['ch_names'])
print("EEG sampling rate: ", eeg_raw.info["sfreq"])
print('EEG event dict:')
# bare expression: shown as the cell output
eeg_event_dict

In [None]:
## Define a few channel groups of interest and plot the data
#frontal = ["E19", "E11", "E4", "E12", "E5"]
#occipital = ["E61", "E62", "E78", "E67", "E72", "E77"]
#din = ["DIN"]
#pupil = ["pupil_left"]
#x_pos = ["xpos_left"]
#y_pos = ["ypos_left"]
#
#scale_dict = dict(eeg=1e-4, misc=1e3)
#
## picks must be numeric (not string) when passed to `raw.plot(..., order=)`
#picks_idx = mne.pick_channels(eeg_raw.ch_names, din + frontal + occipital + pupil + x_pos + y_pos, ordered=True)
#eeg_raw.plot(start=0,duration=20,order=picks_idx, scalings=scale_dict)

## Write the raw structure to a BIDS directory in the project root.

In [None]:
# Write the EEG recording, its events and event dictionary through qit.write_eeg.
# The returned eeg_bids_path is reused when the ET file is written.
eeg_bids_path = qit.write_eeg(eeg_raw, 
              eeg_event_dict, 
              eeg_events, 
              subject_id_out, 
              session_id, 
              task_id_out, 
              project_path, 
              device_info)
## write the BIDS output files
## specify power line frequency as required by BIDS
#eeg_raw.info["line_freq"] = 60
#eeg_raw.info['device_info']=device_info
#eeg_raw.info['device_info']['type'] = eeg_raw.info['device_info']['type'].replace(' ', '-')

##THIS SHOULD BE MOVED TO QIT.FILLNA if it is needed...
#def fillna(raw, fill_val=0):
#    return mne.io.RawArray(np.nan_to_num(raw.get_data(), nan=fill_val), raw.info)
#eeg_raw=fillna(eeg_raw,fill_val=0)
#
#eeg_bids_path = mne_bids.BIDSPath(
#    subject=subject_id_out, session=session_id, task=task_id_out, run="1", datatype="eeg", root=project_path
#)

#print(eeg_bids_path)
#mne_bids.write_raw_bids(
#    raw=eeg_raw,
#    bids_path=eeg_bids_path,
#    events=eeg_events,
#    event_id=eeg_event_dict,
#    format = "EDF",
#    overwrite=True,
#    allow_preload=True,
#)

# NOW THAT EEG AND ET FILES HAVE SYNC_TIME EVENTS WRITE THE ET TO BIDS AS WELL...

## Identify the channels with specific types (eyegaze and pupil) and rename them to 'misc' (to be fixed after bids compliance updates)
#channel_types = et_raw.get_channel_types()
#channel_mapping = {
#    ch_name: 'misc'
#    for ch_name, ch_type in zip(et_raw.ch_names, channel_types)
#    if ch_type in ['eyegaze', 'pupil']
#}
#
## Update the channel types
#et_raw.set_channel_types(channel_mapping)
#
#task_id_out_et = task_id_out + "et"
#et_bids_path = mne_bids.BIDSPath(
#    #subject=subject_id_out, session=session_id, task=task_id_out_et, run="1", datatype="eeg", root=project_path
#    subject=subject_id_out, session=session_id, task=task_id_out_et, run="1", root=project_path
#)
#
#print(et_bids_path)
#mne_bids.write_raw_bids(
#    raw=et_raw,
#    bids_path=et_bids_path,
#    events=et_events,
#    event_id=et_event_dict,
#    format = "EDF",
#    overwrite=True,
#    allow_preload=True,
#)

In [None]:
# Write the ET recording next to the EEG output, but only when an ET session file was
# located; without one there is no ET recording to write.
if session_file_name_et and session_file_name_et[0]:
    et_out_path = qit.write_et(et_raw, et_event_dict, et_events, eeg_bids_path)
else:
    print("Could not find the session ET file.. skipping the ET write.")
    # Keep the name defined so later references fail clearly rather than with NameError.
    et_out_path = None
#event_id_to_name = {v: k for k, v in et_event_dict.items()}
#onsets = et_events[:, 0] / et_raw.info['sfreq']  # Convert sample indices to time (in seconds)
#durations = [0] * len(onsets)  # Example: all durations set to 0
#descriptions = [event_id_to_name[event_id] for event_id in et_events[:, 2]]

#annotations = mne.Annotations(onset=onsets, duration=durations, description=descriptions)

##overwrite the annotations in the Raw object
#et_raw.set_annotations(annotations)

#et_out_path = str(eeg_bids_path)
#et_out_path = et_out_path.replace("/eeg/", "/et/")
#et_out_path = et_out_path.replace("_eeg.", "_et.")
#et_out_path = et_out_path.split(".")[0] + ".fif"
## Extract the directory path
#directory = os.path.dirname(et_out_path)

## Ensure the directory exists
#os.makedirs(directory, exist_ok=True)

##write the updated Raw object to the specified path
#et_raw.save(et_out_path, overwrite=True)


## Read tests...

In [None]:
# read the file at et_out_path (the path returned by qit.write_et) back in as a check
raw = mne.io.read_raw_fif(et_out_path, preload=True)

In [None]:
# bare expression: show the summary of the re-read recording as the cell output
raw

In [None]:
# rebuild an events array and event dictionary from the annotations stored in the re-read file
(events_from_annot, event_dict) = mne.events_from_annotations(raw)

In [None]:
# create the event structures for the re-read file (same calls as used for the EEG recording,
# with a fixed offset of 1)
# NOTE(review): mne.find_events looks for a stim channel — confirm the written file has one
raw_events = mne.find_events(raw, shortest_event = 1)
raw_event_dict = qit.get_event_dict(raw, raw_events, 1)

In [None]:
# bare expression: show the event dictionary of the re-read file as the cell output
raw_event_dict

In [None]:
##Read test
#bids_path = mne_bids.BIDSPath(
#    subject="100134F1", session="01", task="VEP", run="1", datatype="eeg", root="/project/def-emayada/q1k/experimental/HSJ/"
#)
#raw = mne_bids.read_raw_bids(bids_path=bids_path)


In [None]:
## peak... at the eeg_raw properties
#channel_types = raw.get_channel_types()
#print("EEG Channel Types:", channel_types)
#print("EEG Channel Names:", raw.info['ch_names'])
#print("EEG sampling rate: ", raw.info["sfreq"])

In [None]:
## Define a few channel groups of interest and plot the data
#frontal = ["E19", "E11", "E4", "E12", "E5"]
#occipital = ["E61", "E62", "E78", "E67", "E72", "E77"]
#din = ["DIN"]
#pupil = ["pupil_left"]
#x_pos = ["xpos_left"]
#y_pos = ["ypos_left"]

#scale_dict = dict(eeg=1e-4, eyegaze=30, pupil=30)

## picks must be numeric (not string) when passed to `raw.plot(..., order=)`
#picks_idx = mne.pick_channels(raw.ch_names, din + frontal + occipital + pupil + x_pos + y_pos, ordered=True)
#raw.plot(start=0,duration=4,order=picks_idx, scalings=scale_dict)