### Preprocess EMU data

In [None]:
import mne, os, ast, logging
import pandas as pd
import numpy as np
from os.path import join
from autoreject import get_rejection_threshold
from matplotlib import pyplot as plt
from datetime import datetime

from expt_params import *
from functions import *


mne.set_log_level('warning')

# ------------------- subject-specific data paths --------------
# `params` and `subj` are presumably provided by the star import from expt_params.
# Outer double quotes keep these f-strings valid on Python < 3.12, where reusing
# the enclosing quote character inside a replacement field is a SyntaxError (PEP 701).
channels_info_path = join(params['channels_info_dir'], f"out_Subject_{subj['id']}.xlsx")
logfile_path = join(params['logfiles_dir'], f"{subj['id']}_Trial_RTs.csv")

# ------------------ output dirs and logging
outdir_logging = join(params['outdir_logging'], subj['emu'])
# makedirs also creates missing parent directories and is a no-op when the
# directory already exists, so the cell can be re-run safely.
os.makedirs(outdir_logging, exist_ok=True)

# force=True replaces any root handlers installed earlier (e.g. by an imported
# library); without it basicConfig silently does nothing and no log file is written.
logging.basicConfig(
    filename=join(outdir_logging, f"{subj['emu']}.log"),
    filemode='a',  # Append mode to keep all logs
    format='%(message)s',
    level=logging.INFO,
    force=True)

# save experiment params to log, prefixed by a timestamp for this run
now = datetime.now().strftime("%d/%m/%Y %H:%M:%S")
logging.info(now)

for key, val in params.items():
    logging.info(f"{key}: {val}")

# ------------------- Load data ----------------------------
# edf_path is presumably provided by the star import from expt_params.
print('reading edf...')
edf = mne.io.read_raw_edf(edf_path, preload=True)
print('Done')

# Optional resampling: a falsy resample_fs (e.g. None or 0) keeps the original rate.
if params['resample_fs']:
    print('resampling...')
    edf.resample(params['resample_fs'])
    print('Done')

# get channel names from the channel-info spreadsheet
df_channel_info = pd.read_excel(channels_info_path)
channel_names = np.array(df_channel_info.Channel)

print('N channels =', len(channel_names))
logging.info(f"N channels: {len(channel_names)}")

# Fail early with an informative message if the spreadsheet lists channels that
# the EDF does not contain (the assert below would only report a count mismatch).
missing_channels = [ch for ch in channel_names if ch not in edf.ch_names]
if missing_channels:
    raise ValueError(
        f"Channels listed in {channels_info_path} but absent from the EDF: {missing_channels}")

# keep only the channels listed in the channel-info spreadsheet
edf.pick_channels(channel_names)
assert len(edf.info['ch_names']) == len(channel_names)

# correct channel types: mark every retained channel as sEEG
edf.set_channel_types({ch_n: 'seeg' for ch_n in edf.ch_names})

# ----------------- Add montage ------------------------------
# get_montage_from_bs is a project helper (presumably from `functions`) that
# builds the montage from the same channel-info table.
montage = get_montage_from_bs(electrode_data=df_channel_info,
                              electrode_space=params['electrode_space'],
                              subjects_dir=params['subjects_dir'])
edf.set_montage(montage)
# _fig = mne.viz.plot_montage(montage, show=False)
# _fig.suptitle(subj['emu'])
# _fig.savefig(join(outdir_logging, f"{subj['emu']}_montage.jpg"))


# ------------------- Filters -------------------------------
print('Applying filters...')
edf = apply_standard_filters(edf, params['high_pass_cutoff'], params['notch_freqs'])
print('Done')

### Reject bad channels

In [None]:
# ---------- Interactive plot to reject bads ------------------ #
%matplotlib qt

# --- PSD plot
psd =  edf.compute_psd(fmax=200)
psd.plot()

# --- raw plot
edf.plot()
print('>> Mark bads on raw plot (popup window)')

In [None]:
# --------------------- dropping bad channels here ------------------ #
print('Bad channels:', edf.info['bads'])

# We will append more bad channels later if we find them.
# Take an independent copy: the names disappear from edf.info['bads'] once the
# channels are dropped, and later cells append to this list, so it must not
# alias the info dict's own list.
all_bad_channels = list(edf.info['bads'])

edf.drop_channels(all_bad_channels)

# ------------------- epochs and baseline correction ------------------ #
logfile = pd.read_csv(logfile_path)

######
# NOTE(review): the last row is dropped unconditionally because it is believed
# to be empty; verify against the logfile format. The drop is also recorded in
# the log file for provenance.
print('WARNING: systematically dropping the last line of logfile')
logging.info('WARNING: systematically dropping the last line of logfile')
logfile = logfile[:-1]  # drop the last row, I think it's meant to be empty ?
#######

event_times = np.asarray(logfile.EMU_OnsetTime)
# A missing onset would otherwise become a garbage sample index when cast to int.
n_missing_onsets = int(pd.isna(event_times).sum())
if n_missing_onsets:
    raise ValueError(
        f'{n_missing_onsets} trial(s) in {logfile_path} have no EMU_OnsetTime')
event_times = event_times.astype(float)

# Convert onset times to sample indices
# (assumes EMU_OnsetTime is in seconds from the start of the EDF -- TODO confirm).
sfreq = edf.info['sfreq']
# Round instead of truncating: e.g. 0.29 s * 100 Hz evaluates to 28.999...,
# which astype(int) alone would floor to the previous sample. MNE event sample
# numbers also include the recording's first_samp offset.
onset_samples = np.round(event_times * sfreq).astype(int) + edf.first_samp
event_id = 1  # same for all triggers
n_events = len(onset_samples)
# MNE events array: columns are (sample, previous value, event id)
events = np.column_stack((onset_samples,
                          np.zeros(n_events, int),
                          np.full(n_events, event_id)))

print('First 10 events:\n', events[:10])

epochs = mne.Epochs(edf,
                    events,
                    tmin=params['epochs_tmin'],
                    tmax=params['epochs_tmax'],
                    baseline=params['epochs_baseline'],
                    decim=params['epochs_decim'],
                    metadata=logfile,
                    reject=None,
                    preload=True)

logging.info(f'Init num of epochs: {len(epochs)}')

# reject noisy epochs using autoreject:
# see docs here https://autoreject.github.io/stable/explanation.html
rejection_threshold = get_rejection_threshold(epochs, decim=2)
print(rejection_threshold)
logging.info(f'epochs_rejection_threshold: {rejection_threshold}')

epochs_rej = epochs.copy()
epochs_rej.drop_bad(reject=rejection_threshold, verbose='warning')

print_epochs_rejection_info(epochs, epochs_rej)

#--- plot a graph showing how many epochs were rejected due to each channel
# Note: epochs can be rejected based on multiple channels, this only shows the # of times a channel caused a bad epoch
%matplotlib inline
fig = plot_epochs_rejcount_by_channel(epochs_rej, method='zscore', threshold=2)

### Manually check for additional bad channels

In [None]:
# Trial run: list candidate channels in `bad_channels`, re-run this cell and
# compare the epoch-rejection counts the helper reports (judging by its name)
# when those channels are removed and the same threshold is applied.
drop_channel_then_print_new_epoch_rejcount(
    epochs=epochs,
    bad_channels=[],
    rejection_threshold=rejection_threshold,
)

In [None]:
# Based on the iterations above, determine which channels to reject for the final cleaning
channels_to_drop = []  # if none, leave empty

# Add them to the running list of bad channels that goes into the log; skip
# names already recorded so re-running this cell does not log duplicates.
for ch in channels_to_drop:
    if ch not in all_bad_channels:
        all_bad_channels.append(ch)

epochs_clean = epochs.copy()  # keep the uncleaned epochs for the rejection summary below
epochs_clean.drop_channels(channels_to_drop)  # drop the outlier channel(s) chosen above
epochs_clean.drop_bad(reject=rejection_threshold, verbose='warning')  # then reject bad epochs using the threshold
if channels_to_drop:
    # The threshold was estimated before these channels were removed, so it may
    # not be optimal for the remaining set; only warn when that actually happened.
    print('WARNING: threshold was calculated using the full set of electrodes, but applied after rejecting some new ones')

num_rej, percent_rej = print_epochs_rejection_info(epochs, epochs_clean)
# plot_epochs_rejcount_by_channel(epochs_clean, method='zscore', threshold=2)


logging.info(f'Bad channels: {all_bad_channels}')
logging.info(f'Num bad epochs: {num_rej} ({round(percent_rej, 2)}%)')
logging.info(f'Num clean epochs: {len(epochs_clean)}')

# save epochs (create the output directory first so a fresh setup does not fail)
os.makedirs(params['outdir_epochs'], exist_ok=True)
epochs_clean.save(join(params['outdir_epochs'], f"{subj['emu']}-epo.fif"))



In [None]:
# plot evoked and savefig
%matplotlib inline
fig_evoked = epochs_clean.average().plot_joint()
fig_evoked.savefig(join(outdir_logging, f'{subj['emu']}_evoked.jpg'))


# finally plot sensors on fsaverage
brain = plot_seeg_freesurfer(epochs_clean, subjects_dir=params['subjects_dir'])
brain.save_image(join(outdir_logging, f'{subj['emu']}_brain.jpg'))
brain.close()
