# Pipeline for BIAPT lab EEG Preprocessing: 
#### inspired by: https://github.com/hoechenberger/pybrain_mne/
#### adapted by: Charlotte Maschke, Beatrice PDK and Victoria Sus
#### This pipeline uses MNE-Python to preprocess EEG data. Please see
####                                https://mne.tools/stable/overview/index.html
#### for more documentation on MNE-Python.

## Setup and imports

In [1]:
import matplotlib
#import mne_bids
import pathlib
import mne
import os
import os.path as op
# NOTE(review): sklearn is not referenced by name anywhere in this notebook;
# presumably imported so a missing scikit-learn install fails early — TODO
# confirm, or remove.
import sklearn 
# viz provides plot_raw_psd, used for the PSD figures below.
from mne import viz

# ICA tooling; only ICA itself is used in this notebook.
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,
                               corrmap)
#import openneuro

#from mne_bids import BIDSPath, read_raw_bids, print_dir_tree, make_report

# Ensure Matplotlib uses the Qt5Agg backend,
# which is the best choice for MNE-Python's
# interactive plotting functions.
# The backend is selected before pyplot is imported below, so pyplot starts
# with it.
matplotlib.use('Qt5Agg')

import matplotlib.pyplot as plt

### Which subject do you want to preprocess? 

In [2]:
# Identifiers of the recording to preprocess; they are substituted into every
# input and output path below.
ID, session, task = "001", "02", "post60"

In [3]:
# BIDS-like location of the EGI .mff recording for the selected subject.
raw_path = (
    f"./Data/coma_tacs/sub-{ID}/ses-{session}/eeg/"
    f"sub-{ID}_ses-{session}_task-{task}_eeg.mff"
)
raw_path

'./Data/coma_tacs/sub-001/ses-02/eeg/sub-001_ses-02_task-post60_eeg.mff'

## Load the raw data!

In [4]:
# Read the EGI .mff recording; only metadata is read here, the samples are
# loaded into memory later with load_data().
raw = mne.io.read_raw_egi(input_fname=raw_path)
raw

FileNotFoundError: input_fname does not exist: /home/bea/Bureau/BIAPT/EEG_Preprocessing_TACS/Data/coma_tacs/sub-001/ses-02/eeg/sub-001_ses-02_task-post60_eeg.mff

## Resample the data to 250 Hz

In [None]:
# Resample to 250 Hz; later cells hard-code this rate (e.g. 30*250 in the ICA
# overlay window).
TARGET_SFREQ = 250
if raw.info['sfreq'] != TARGET_SFREQ:
    # Raw.resample operates on the samples themselves, so they must be in
    # memory first; read_raw_egi above does not preload them.
    raw.load_data()
    raw = raw.resample(TARGET_SFREQ)

### Keep only the EEG

In [None]:
# Keep only the EEG channels. To keep other channel types, refer to the MNE
# documentation. pick_types modifies `raw` in place and returns it, so `eeg`
# and `raw` refer to the same object afterwards.
eeg = raw.pick_types(eeg=True)
print('Number of channels in EEG:')
eeg.info['nchan']

## Filter the data

In [None]:
# Load the actual samples into memory (until now only the metadata was read).
eeg.load_data()

# Band-pass filter between 1 and 55 Hz.
# NOTE: an earlier comment here said the low-pass had been lowered to 50 Hz
# because of a machine artifact, but the cutoff actually applied is 55 Hz.
# Lower H_FREQ if that artifact is present in a recording.
# filter() works in place, so `eeg_filtered` is the same object as `eeg`.
L_FREQ, H_FREQ = 1, 55
eeg_filtered = eeg.filter(l_freq=L_FREQ, h_freq=H_FREQ)

# Notch out 60 Hz line noise on a copy, leaving `eeg_filtered` untouched.
LINE_FREQ = 60
eeg_notch = eeg_filtered.copy().notch_filter(freqs=LINE_FREQ)


In [None]:
%matplotlib qt
viz.plot_raw_psd(eeg_notch, exclude = ['E129'], fmax = 70)
plt.savefig('./out_figures/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_raw_filtered.png'.format(ID,session,task,ID,session,task))

## Crop the data

Depending on the data, we may need to crop the beginning and/or the end of the recording.

In [None]:
# Browse the filtered (band-pass + notch) data to decide where to crop.
eeg_notch.plot(title='raw', n_channels=60, duration=10)

Decide whether the beginning and/or the end should be cropped, and adapt the next line accordingly.

In [None]:
# Crop window in seconds; adapt to each recording (tmin=0.0 keeps the start).
# crop() works in place, so `eeg_cropped` is the same object as `eeg_notch`.
crop_tmin, crop_tmax = 0.0, 335.0
eeg_cropped = eeg_notch.crop(tmin=crop_tmin, tmax=crop_tmax)

In [None]:
# Check the result of cropping.
eeg_cropped.plot(title='cropped', n_channels=60, duration=10)

In [None]:
# Display the summary of the cropped recording to check its new duration.
eeg_cropped

## Visualize raw data to identify bad channels

In [None]:
# Browse 20 s windows and mark noisy channels as bad; the marks end up in
# eeg_cropped.info['bads'], which the next cells read.
eeg_cropped.plot(duration=20, n_channels=60)

Verify that the bad channels are labelled correctly.

In [None]:
# Snapshot of the channels marked bad in the interactive plot above. Copying
# the list keeps this record intact even if info['bads'] is modified later in
# the notebook (interpolation, further manual marking).
marked_bad = list(eeg_cropped.info['bads'])
marked_bad

In [None]:
# Save the manually marked bad channels as a JSON array, so the content
# matches the .json file extension.
import json

bads_path = (
    f'./eeg_output/coma_tacs/sub-{ID}/ses-{session}/eeg/'
    f'sub-{ID}_ses-{session}_task-{task}_marked_bads.json'
)
# open() does not create missing folders.
os.makedirs(op.dirname(bads_path), exist_ok=True)
with open(bads_path, 'w', encoding='utf-8') as outfile:
    json.dump(list(marked_bad), outfile)


# Interpolate bad channels

In [None]:
# Interpolate the channels listed in info['bads']. reset_bads=True (the MNE
# default) clears them from info['bads'] afterwards. Works in place.
eeg_interpol = eeg_cropped.interpolate_bads(reset_bads=True)

## Average Reference the data

In [None]:
# Re-reference to the average of all EEG channels (applied in place).
eeg_avg_ref = eeg_interpol.set_eeg_reference('average')

In [None]:
# Inspect the interpolated, average-referenced data.
eeg_avg_ref.plot(n_channels=60, duration=10)

In [None]:
%matplotlib tk
viz.plot_raw_psd(eeg_avg_ref, fmax = 70)
plt.savefig('./out_figures/coma_tacs/sub-{}/ses-{}/task-{}/sub-{}_ses-{}_task-{}_PSD_avg_ref.png'.format(ID,session,task,ID,session,task))

In [None]:
# Inspect the measurement info after interpolation and re-referencing.
print(eeg_avg_ref.info)

# RUN ICA

## Manual selection of ICA components

In [None]:
# pick some channels that clearly show heartbeats and blinks
#eog_channels = ['E8', 'E12','E14','E21', 'E25', 'E126', 'E127']

In [None]:
# Fit a 15-component ICA on the average-referenced data; the fixed
# random_state makes the decomposition reproducible between runs.
from mne.preprocessing import (ICA, create_eog_epochs, create_ecg_epochs,
                               corrmap)

ica = ICA(
    n_components=15,
    random_state=97,
    max_iter='auto',
)
ica.fit(eeg_avg_ref)

In [None]:
# Browse the ICA source time courses (stop=60 limits the plotted span) to
# spot components such as blinks or heartbeats.
ica.plot_sources(eeg_avg_ref, stop=60, show_scrollbars=False)


In [None]:
# Topographic maps of the fitted ICA components, to help choose which ones
# to put in ica.exclude.
ica.plot_components()

In [None]:
# Display the summary of the fitted ICA.
ica

In [None]:
# Blinks: overlay the signals before and after removing the components in
# ica.exclude. The intended window is 30-40 s at 250 Hz.
# NOTE(review): 30*250 / 40*250 assume start/stop are *sample indices*. Newer
# MNE versions interpret ICA.plot_overlay's start/stop in seconds; check the
# installed version's docs (if seconds, use start=30, stop=40).
ica.plot_overlay(eeg_avg_ref, start = 30*250, stop = 40*250)
# Variant with an explicit list of components instead of ica.exclude:
#ica.plot_overlay(eeg_avg_ref, exclude=[0,1,2,3],start = 30*250, stop = 40*250)

Double-check which components will be removed:

In [None]:
# Indices of the ICA components selected for removal; ica.apply below
# removes exactly these.
ica.exclude

Remove the selected components definitively

In [None]:
# Remove the excluded components from a copy, so eeg_avg_ref keeps the
# pre-ICA signal for comparison. ica.apply modifies the copy in place and
# returns it.
eeg_postica = ica.apply(eeg_avg_ref.copy())
eeg_postica

In [None]:
# Plot the signals before and after ICA with identical settings to compare
# them.
plot_kwargs = dict(duration=30, n_channels=100, scalings=20e-6)
eeg_avg_ref.plot(title='raw', **plot_kwargs)
eeg_postica.plot(title='ICA correction', **plot_kwargs)

Identify the remaining bad channels

In [None]:

# Mark channels that are still bad after ICA; they are stored in
# eeg_postica.info['bads'] and dropped further below.
eeg_postica.plot(duration=30, n_channels=100, scalings=20e-6,
                 title='Identify bad channels, post ICA and interpolation')

## Remove Non-Brain Electrodes and bad channels 

In [None]:
# EGI electrodes treated as non-brain; they are dropped together with the
# remaining bad channels.
non_brain_el = ['E127', 'E126', 'E17', 'E21', 'E14', 'E25', 'E8', 'E128',
                'E125', 'E43', 'E120', 'E48', 'E119', 'E49', 'E113', 'E81',
                'E73', 'E88', 'E68', 'E94', 'E63', 'E99', 'E56', 'E107']

# Only add non-brain channels that are not already listed as bad. Check the
# same list that is appended to (eeg_postica's). Checking eeg_avg_ref's list
# would skip channels marked bad only there after the copy, so they would
# never be dropped.
postica_bads = eeg_postica.info['bads']
for e in non_brain_el:
    if e not in postica_bads:
        postica_bads.append(e)


In [None]:
#print(eeg_postica.info['bads'])

In [None]:
# remove channels marked as bad and non-brain channels
eeg_brainonly = eeg_postica.copy().drop_channels(eeg_postica.info['bads'])
eeg_brainonly

In [None]:
# Final visual check of the brain-only data.
eeg_brainonly.plot(duration=10, n_channels=60, title='brain only')

## Save the final brain-only data (non-brain channels removed)

In [None]:
#out_dir = pathlib.Path("./Results/sub-{}/ses-{}/eeg/".format(ID,session))

#if not os.path.exists(out_dir):
    #os.makedirs(out_dir)
    
#out_path = pathlib.Path(out_dir, "sub-{}_ses-{}_task-{}.set".format(ID,session,task))

In [None]:
# Save the cleaned, continuous brain-only recording.
# TODO: save into the BIDS "derivatives" tree instead (a "clean" folder with
# a README explaining the cleaning steps), with the file in BIDS format.
brainonly_path = (
    f"./eeg_output/coma_tacs/sub-{ID}/ses-{session}/eeg/"
    f"sub-{ID}_ses-{session}_task-{task}_nonbrainfiltered_noncut_eeg.fif"
)
# save() does not create missing folders.
os.makedirs(op.dirname(brainonly_path), exist_ok=True)
eeg_brainonly.save(brainonly_path, overwrite=True)

# Make epochs of 10 s

In [None]:
# Cut the continuous data into consecutive, non-overlapping 10 s epochs.
epochs = mne.make_fixed_length_epochs(eeg_brainonly, overlap=0, duration=10)

In [None]:
# Number of epochs before any interactive rejection.
epochs  # verify the initial number of epochs

In [None]:
# Browse the epochs three at a time; the next cell re-checks the epoch count,
# presumably after bad epochs have been rejected here.
epochs.plot(scalings=20e-6, n_channels=100, n_epochs=3)

In [None]:
# Number of epochs left after the interactive review above.
epochs #verify how many are left

Verify the PSD

In [None]:
# Power spectral density of the epoched data, up to 55 Hz (the low-pass
# cutoff used when filtering).
%matplotlib qt
epochs.plot_psd(fmax=55)

In [None]:
# Save the 10 s epochs.
# NOTE(review): MNE expects epochs file names to end in -epo.fif / _epo.fif
# and warns otherwise. The existing name is kept so downstream readers still
# find the file.
epochs_path = (
    f"./eeg_output/coma_tacs/sub-{ID}/ses-{session}/eeg/"
    f"sub-{ID}_ses-{session}_task-{task}_epoch_eeg.fif"
)
# save() does not create missing folders.
os.makedirs(op.dirname(epochs_path), exist_ok=True)
epochs.save(epochs_path, overwrite=True)
