# Pipeline for BIAPT lab EEG Preprocessing: 
#### inspired by: https://github.com/hoechenberger/pybrain_mne/
#### adapted by: Beatrice PDK, Victoria Sus and Charlotte Maschke, 
#### This pipeline uses MNE-Python to preprocess EEG data. Please see
####                                https://mne.tools/stable/overview/index.html
####  for more documentation on MNE-Python.

## Setup and import

In [None]:
# Standard library
import os
import os.path as op
import pathlib

# MNE-Python and matplotlib
import matplotlib
import mne
from mne import viz
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,
                               corrmap)

# Optional BIDS / OpenNeuro tooling (currently disabled):
# import mne_bids
# from mne_bids import BIDSPath, read_raw_bids, print_dir_tree, make_report
# import openneuro

# Select the Qt5Agg matplotlib backend, which works best with MNE-Python's
# interactive plotting functions; the backend is chosen before pyplot is
# imported.
matplotlib.use('Qt5Agg')

import matplotlib.pyplot as plt

### Enter the information of the recording you want to preprocess

In [None]:
# Identify the recording to preprocess: subject ID, session and task label,
# as they appear in the sub-/ses-/task- paths used below.
ID, session, task = "007", "07", "post7"

In [None]:
# Location of the EGI .mff recording for this subject/session/task.
raw_path = (
    f"./Data/coma_tacs/sub-{ID}/ses-{session}/eeg/"
    f"sub-{ID}_ses-{session}_task-{task}_eeg.mff"
)
raw_path

## Load the raw data!

In [None]:
# Open the EGI recording. Only the metadata is read at this point; the
# samples are loaded into memory later.
raw = mne.io.read_raw_egi(raw_path, preload=False)
raw

## Resample the data to 250 Hz

In [None]:
# Bring the recording to a common 250 Hz sampling rate (nothing to do when it
# is already sampled at 250 Hz).
target_sfreq = 250
if raw.info['sfreq'] != target_sfreq:
    raw = raw.resample(target_sfreq)

### Keep only the EEG channels

In [None]:
# Keep only the EEG channels (for other channel types, see the MNE
# documentation). MNE's pick_types modifies `raw` in place and returns it, so
# `eeg` and `raw` refer to the same object from here on.
eeg = raw.pick_types(eeg=True)
n_eeg_channels = len(eeg.ch_names)
print('Number of channels in EEG:')
n_eeg_channels

## Apply filtering

In [None]:
# Load the samples into memory (until now only the metadata was read);
# filtering needs preloaded data.
eeg.load_data()

# Band-pass filter between 1 and 55 Hz.
# NOTE(review): an earlier note said the low-pass had to be lowered "to 50"
# because of a machine artifact, but the cutoff actually used is 55 Hz —
# confirm which value is intended.
# Raw.filter works in place, so eeg_filtered and eeg are the same object.
eeg_filtered = eeg.filter(l_freq=1, h_freq=55)

# Notch filter at 60 Hz (presumably power-line noise) on a copy, so that
# eeg_filtered keeps the band-passed-only signal.
eeg_notch = eeg_filtered.copy().notch_filter(freqs=60)


In [None]:
%matplotlib qt
viz.plot_raw_psd(eeg_notch, exclude = ['E129'], fmax = 70)
if not os.path.exists('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task)) :
    os.makedirs('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task))
plt.savefig('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_raw_filtered.png'.format(ID,session,task,ID,session,task))

## Crop the data

Depending on the data, we may need to crop the beginning and/or the end.

In [None]:
# Browse the filtered recording in 10 s windows to decide where to crop.
eeg_notch.plot(title='raw', n_channels=100, duration=10)

Decide whether the beginning and/or the end should be cropped, and adapt the next line accordingly.

In [None]:
# Keep only the usable part of the recording (times in seconds; adapt them
# for each recording). Raw.crop works in place, so crop a copy: cropping
# eeg_notch itself would make this cell fail (or crop further) when re-run,
# and would also alter the signal shown in the earlier plots.
eeg_cropped = eeg_notch.copy().crop(tmin=4.0, tmax=401.0)

In [None]:
# Check the cropped recording.
eeg_cropped.plot(title='cropped', duration=10, n_channels=120)

## Visualize raw data to identify bad channels

In [None]:
# Browse the cropped data in 20 s windows and mark noisy channels as bad;
# the marks are read back from eeg_cropped.info['bads'] below.
eeg_cropped.plot(duration=20, n_channels=60)

Verify that the bad channels are labelled correctly

In [None]:
# Channels marked as bad in the browser above.
# NOTE: this is the same list object as eeg_cropped.info['bads'] (not a copy),
# so later changes to eeg_cropped's bad channels show up here as well.
marked_bad = eeg_cropped.info['bads']
marked_bad

In [None]:
#eeg_cropped.plot(n_channels =30, duration=20)

In [None]:
import json

# Save the manually marked bad channels next to the preprocessed data.
bads_dir = './eeg_output_pre_only/coma_tacs/sub-{}/ses-{}/eeg'.format(ID, session)
os.makedirs(bads_dir, exist_ok=True)
bads_file = op.join(
    bads_dir,
    'sub-{}_ses-{}_task-{}_marked_bads.json'.format(ID, session, task))
with open(bads_file, 'w', encoding='utf-8') as outfile:
    # Write a JSON array of channel names so the content matches the .json
    # extension (newline-joined names are not valid JSON).
    json.dump(list(marked_bad), outfile)


## Remove bad channels

In [None]:
# Drop the channels marked bad above. Work on a copy so eeg_cropped keeps all
# channels together with their bad marks.
channels_to_drop = eeg_cropped.info['bads']
eeg_cropped_removed = eeg_cropped.copy().drop_channels(channels_to_drop)

In [None]:
# Check the data after removing the bad channels.
eeg_cropped_removed.plot(duration=20, n_channels=60)

## Segment into 10-sec epochs

In [None]:
# Cut the continuous recording into consecutive, non-overlapping 10 s epochs.
epochs = mne.make_fixed_length_epochs(
    eeg_cropped_removed,
    duration=10,
    overlap=0,
)

In [None]:
epochs  # check the initial number of epochs

### Reject epochs with a peak-to-peak amplitude above 2000 µV

Epochs whose peak-to-peak amplitude on the scalp exceeds 2000 µV cannot have a physiological origin (physiological amplitudes are accepted below 800 µV), so they are rejected.

In [None]:
# Drop every epoch whose peak-to-peak EEG amplitude exceeds 2000 µV (the
# threshold is given in volts), then show which epochs were dropped and why.
ptp_threshold = {'eeg': 2000 * 1e-6}
epochs_clean = epochs.copy()
epochs_clean.load_data()
epochs_clean.drop_bad(ptp_threshold)
epochs_clean.plot_drop_log()

In [None]:
# Epochs left after the amplitude-based rejection
epochs_clean

Inspect the remaining epochs and manually remove those of bad quality

In [None]:
# Browse the epochs (3 at a time) to remove the remaining bad ones by hand.
epochs_clean.plot(scalings=20e-6, n_channels=60, n_epochs=3, title='bad_epochs_remaining')

In [None]:
# Epochs left after the manual review above
epochs_clean

In [None]:
# Review the cleaned epochs with up to 100 channels on screen.
epochs_clean.plot(scalings=20e-6, n_channels=100, n_epochs=3)

In [None]:
%matplotlib qt
epochs_clean.plot_psd(fmax=70, exclude = ['E129'])
#if not os.path.exists('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task)) :
#    os.makedirs('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task))
#plt.savefig('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_epochs_filtered.png'.format(ID,session,task,ID,session,task))

## Drop additional channels afterwards if necessary

In [None]:
# Mark any additional bad channels in this browser; they are read back from
# epochs_clean.info['bads'] in the next cells.
epochs_clean.plot(scalings=20e-6, n_channels=100, n_epochs=3)

In [None]:
# Channels marked bad in the epochs browser above
epochs_clean.info['bads']

In [None]:
# Drop any channels newly marked bad in the epochs browser.
newly_bad = epochs_clean.info['bads']
epochs_clean = epochs_clean.copy().drop_channels(newly_bad)

In [None]:
# Check the epochs after the extra channels were dropped.
epochs_clean.plot(scalings=20e-6, n_channels=100, n_epochs=3)

# RUN ICA

## Manual selection of ICA components

We use 30 components. In principle it is better to have as many components as possible (up to the number of electrodes, about 100 after cleaning and rejection of non-brain channels),
but 30 remains a good compromise.

In [None]:
# (Re-)import the ICA tools so this cell can also be run on its own.
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,
                               corrmap)

# Fit a 30-component ICA on the cleaned epochs; the fixed random_state makes
# the decomposition reproducible across runs.
ica_settings = dict(n_components=30, max_iter='auto', random_state=97)
ica = ICA(**ica_settings)
ica.fit(epochs_clean)

In [None]:
# Time courses of the ICA components for the cleaned epochs.
ica.plot_sources(inst=epochs_clean)

In [None]:
# Scalp topographies of the ICA components; epochs_clean is passed as `inst`
# so the components can be explored interactively.
ica.plot_components(inst=epochs_clean)

Main artifacts to look for: eye artifacts, jaw muscle contractions, and clearly non-physiological artifacts (for example, equipment in the ICU). In the variance plot, x corresponds to epochs and y to variance. It shows whether a component affects the variance throughout the whole recording or only in specific epochs, which tells us whether we should remove the component or rather reject those epochs.

In [None]:
# Inspect the properties (including the per-epoch variance) of the components
# currently selected for exclusion. With an empty ica.exclude there is nothing
# to inspect, so skip the call instead of passing an empty pick list.
if ica.exclude:
    ica.plot_properties(epochs_clean, picks=ica.exclude)
else:
    print('No ICA components selected yet (ica.exclude is empty).')

In [None]:
# Summary of the fitted ICA object
ica

Double-check which components will be removed:

In [None]:
# Indices of the components currently marked for exclusion
ica.exclude

In [None]:
# Number of components that will be removed
len(ica.exclude)

## Remove the selected components definitively

In [None]:
# Remove the components listed in ica.exclude. ICA is applied to a copy so
# that epochs_clean stays unchanged for the before/after comparison.
epochs_clean.load_data()
epochs_to_correct = epochs_clean.copy()
eeg_postica = ica.apply(epochs_to_correct)


In [None]:
# Plot the signal before and after the ICA correction to compare them.
browse_kwargs = dict(n_epochs=3, n_channels=60, scalings=20e-6)
epochs_clean.plot(title='raw', **browse_kwargs)
eeg_postica.plot(title='ICA correction', **browse_kwargs)

In [None]:
%matplotlib qt
eeg_postica.plot_psd(exclude = ['E129'], fmax = 70)
#if not os.path.exists('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task)) :
#    os.makedirs('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task))
#plt.savefig('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_epochs_post_ica.png'.format(ID,session,task,ID,session,task))

## Average Reference the data

In [None]:
# Re-reference to the average of all channels. set_eeg_reference modifies the
# data in place (per the MNE docs), so eeg_avg_ref and eeg_postica refer to the
# same object afterwards.
eeg_avg_ref = eeg_postica.set_eeg_reference(ref_channels='average')

In [None]:
#eeg_avg_ref.plot(title='Avg ref', n_epochs=3, n_channels=100, scalings=20e-6)

In [None]:
%matplotlib tk
eeg_avg_ref.plot_psd(fmax=70)
#if not os.path.exists('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task)) :
#    os.makedirs('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}'.format(ID,session,task))
#plt.savefig('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_avg_ref.png'.format(ID,session,task,ID,session,task))

## Remove Non-Brain Electrodes 

In [None]:
# Electrodes treated as non-brain channels; they are removed in the next
# steps. NOTE(review): presumably the face/neck electrodes of the EGI net —
# confirm against the montage.
non_brain_el = ['E127', 'E126', 'E17', 'E21', 'E14', 'E25', 'E8', 'E128', 'E125', 'E43', 'E120', 'E48',
                'E119', 'E49', 'E113', 'E81', 'E73', 'E88', 'E68', 'E94', 'E63', 'E99', 'E56', 'E107']

# Mark every non-brain electrode still present in the data as bad. The check
# is made against the *current* bad list: this avoids duplicate entries in
# info['bads'], and a non-brain channel that is still in the data is always
# marked (checking the older `marked_bad` list could skip such a channel).
current_channels = set(eeg_avg_ref.info['ch_names'])
for e in non_brain_el:
    if e in current_channels and e not in eeg_avg_ref.info['bads']:
        eeg_avg_ref.info['bads'].append(e)
    


In [None]:
# Channels that will be dropped in the next cell (noisy + non-brain)
print(eeg_avg_ref.info['bads'])

In [None]:
# remove non-brain channels
# Remove every channel currently marked bad (including the non-brain ones)
# from a copy of the re-referenced data.
channels_to_remove = eeg_avg_ref.info['bads']
eeg_brainonly = eeg_avg_ref.copy().drop_channels(channels_to_remove)

## Verify removal

In [None]:
# Verify that only brain channels remain; extra bad channels can be marked here.
eeg_brainonly.plot(scalings=20e-6, n_channels=100, n_epochs=3, title='brain only')

In [None]:
# Drop any extra channels marked bad while browsing the brain-only data.
extra_bads = eeg_brainonly.info['bads']
eeg_brainonly = eeg_brainonly.copy().drop_channels(extra_bads)

Verify the PSD

In [None]:
%matplotlib qt
eeg_brainonly.plot_psd(fmax=70)
plt.savefig('./out_figures2/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_brainonly.png'.format(ID,session,task,ID,session,task))

## Save the final brain-only data

In [None]:
# Save the cleaned, brain-only epochs as a .fif file, overwriting any previous
# output. Create the output folder first in case it does not exist yet.
# NOTE(review): MNE expects epochs file names to end in -epo.fif / _epo.fif and
# warns about this name; it is kept unchanged for downstream compatibility.
out_dir = './eeg_output_pre_only/coma_tacs/sub-{}/ses-{}/eeg'.format(ID, session)
os.makedirs(out_dir, exist_ok=True)
eeg_brainonly.save(
    op.join(out_dir, 'sub-{}_ses-{}_task-{}_{}_eeg.fif'.format(ID, session, task, 'epoch')),
    overwrite=True)
# TODO: save into the "derivatives" folder instead, in a "clean" subfolder
# (with a README explaining the cleaning), and name the file following the
# BIDS format.