# Data preprocessing II: ICA

The workhorse functions `clean_with_ar_local` and `get_ica_weights` are defined in `..\preprocess.py`. This also allows calling them from bash scripts that run in the background to precalculate their outputs (which - depending on your setup - will take a while). 
You can then set all `*_from_disc` arguments to `True` here and load the precalculated values directly. 
This script has some interactive steps:
- marking stretches of clean EOG data (can be skipped if loaded from disk) 
- inspecting ICA components
- inspecting results of the ICA cleaning

So it won't run over all participants without your interaction. Good to look at some data anyway.
I recommend precalculating the ICA weights, incl. the first round of Autoreject, in a separate script (use `..\utils\03_preprocess-ica-ar.py` for this) and then just loading the values here. 

In [None]:
from os import path as op
import sys
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne.preprocessing import create_eog_epochs
import autoreject

from library import config, helpers, preprocess

In [None]:
%matplotlib qt

%load_ext autoreload
%autoreload 2

In [None]:
def _eog_epochs_for_channel(raw_eog, ch_name, thresh):
    """Create epochs locked to EOG peaks of one channel and report how many were found."""
    epochs = create_eog_epochs(raw_eog,
                               ch_name=ch_name,
                               baseline=(None, None),
                               thresh=thresh,
                               verbose=False)

    print(f'Created {len(epochs)} {ch_name} epochs.')
    if len(epochs) < 50:
        print('Not really a lot. This might be a problem! Consider marking longer stretches of clean EOG data.')
        print('########################')
    return epochs


def get_eog_epochs_clean(data_raw,
                         t0=None, t1=None,
                         manual_mode=False,
                         eogannot_from_disc=False,
                         save_eogannot_to_disc=True,
                         eogannot_path=None,
                         subID=None):
    """Robust way to get epochs centered around EOG events.

    Produce EOG epochs from raw data by epoching to "EOG events" (peaks in the EOG). Based on
    `mne.preprocessing.create_eog_epochs`, but more robust.
    Specify either a (list of) start time(s) t0 and a (list of) end time(s) t1 of stretches in the data with somewhat
    clean EOG signal. Alternatively, open a plot of the data and mark it manually by annotating it as "clean_eog"
    (`manual_mode`). You can also save these manual annotations to disc and load them in again later.

    Parameters
    ----------
    data_raw : MNE raw object.
        Must contain `VEOG` and `HEOG` channels of type `eog`.
    t0 : float | list | tuple | array
        Start time(s) of clean EOG interval(s).
        Ignored if `manual_mode` is `True`.
    t1 : float | list | tuple | array
        End time(s) of clean EOG interval(s). Must have as many entries as `t0`.
        End times beyond the end of the recording are clipped to the last sample.
        Ignored if `manual_mode` is `True`.
    manual_mode : bool
        Set to `True` if you want to manually mark stretches of clean EOG data or load according annotations from disc.
    eogannot_from_disc : bool
        Load annotations from disc.
        Ignored if `manual_mode` is `False`.
    save_eogannot_to_disc : bool
        Save annotations for later reuse. Only takes effect if the annotations were marked manually
        (i.e., not loaded from disc).
        Ignored if `manual_mode` is `False`.
    eogannot_path : str
        Folder to/from which the annotations shall be saved/loaded.
        Ignored if `manual_mode` is `False`.
    subID : str | None
        Subject identifier used to build the annotation file name (`<subID>-eog-annot.fif`).
        If `None`, a module-level variable called `subID` is used (e.g., the loop variable of the
        calling notebook cell), which keeps older calls without this argument working.

    Returns
    -------
    epochs_veog : mne Epochs
        Epochs of length 1s, centered around VEOG events.
    epochs_heog : mne Epochs
        Epochs of length 1s, centered around HEOG events.

    Raises
    ------
    ValueError
        If required arguments are missing or inconsistent (no `t0`/`t1` in non-manual mode, different
        numbers of start and end times, no `eogannot_path`/`subID` although annotations shall be read or
        written) or if no annotation named "clean_eog" is found in manual mode.
    """

    if manual_mode:
        fname = None
        # A file name is only needed if annotations are read from or written to disc.
        if eogannot_from_disc or save_eogannot_to_disc:
            if eogannot_path is None:
                raise ValueError('`eogannot_path` is required to load or save EOG annotations.')
            if subID is None:
                # Backward compatibility: fall back to the global `subID` of the calling notebook/script.
                subID = globals().get('subID')
            if subID is None:
                raise ValueError('`subID` is required to build the file name of the EOG annotations.')
            fname = op.join(eogannot_path, subID + '-eog-annot.fif')

        mark_data = data_raw.copy()

        if eogannot_from_disc:
            mark_data.set_annotations(mne.read_annotations(fname))
        else:
            # Mark stretches (at least 1) of data with clean EOG. The plot blocks until it is closed.
            mark_data.pick_types(eog=True, eeg=False)
            mark_data.filter(l_freq=1, h_freq=None, picks=['eog'], verbose=False)
            mark_data.plot(duration=180,
                           scalings={'eog': 850e-6},
                           block=True)
            if save_eogannot_to_disc:
                mark_data.annotations.save(fname, overwrite=True)

        # Collect the marked stretches (!name the annotation `clean_eog`!):
        clean_spans = [annot for annot in mark_data.annotations
                       if annot['description'] == 'clean_eog']
        if not clean_spans:
            raise ValueError('No annotation named "clean_eog" found. '
                             'Mark at least one stretch of clean EOG data.')
        t0 = [span['onset'] for span in clean_spans]
        t1 = [span['onset'] + span['duration'] for span in clean_spans]
    else:
        if t0 is None or t1 is None:
            raise ValueError('Specify `t0` and `t1` or set `manual_mode=True`.')
        # Accept scalars as well as any sequence type.
        t0 = list(np.atleast_1d(t0))
        t1 = list(np.atleast_1d(t1))
        if len(t0) != len(t1):
            raise ValueError(f'`t0` and `t1` must have the same length (got {len(t0)} and {len(t1)}).')

    # Concatenate the clean stretches. `crop` rejects end times beyond the recording, so clip them
    # (manually drawn annotations can easily overshoot the end of the data).
    t_max = data_raw.times[-1]
    raw_eog = data_raw.copy().crop(t0[0], min(t1[0], t_max))
    for t0_, t1_ in zip(t0[1:], t1[1:]):
        raw_eog.append(data_raw.copy().crop(t0_, min(t1_, t_max)))

    # Heuristically determine a robust threshold for EOG peak detection
    tmp = raw_eog.copy().load_data().\
                         pick_channels(['VEOG', 'HEOG']).\
                         filter(l_freq=1, h_freq=10, picks=['eog'], verbose=False)

    tmp_epo = mne.make_fixed_length_epochs(tmp, preload=True)
    epo_data = tmp_epo.get_data()  # (n_epochs, n_channels, n_times)

    if epo_data.shape[0] > 1:
        # Quarter of the peak-to-peak range across epochs, averaged over time: one value per channel.
        # `np.ptp` is used as the `ndarray.ptp` method is no longer available in NumPy >= 2.0.
        mean_threshs = np.mean(np.ptp(epo_data, axis=0) / 4, axis=1)
    else:
        # The heuristic needs at least 2 epochs; let MNE estimate the threshold itself.
        print('Too little clean EOG data for the threshold heuristic. Using the default threshold of MNE.')
        mean_threshs = [None] * len(tmp_epo.ch_names)
    thresh_dict = dict(zip(tmp_epo.ch_names, mean_threshs))

    epochs_veog = _eog_epochs_for_channel(raw_eog, 'VEOG', thresh_dict['VEOG'])
    epochs_heog = _eog_epochs_for_channel(raw_eog, 'HEOG', thresh_dict['HEOG'])

    return epochs_veog, epochs_heog


def vis_compare_ica(data_before, data_after, show_data_before=False, block=True):
    """Plot the data after (and optionally before) ICA cleaning for visual comparison.

    Both inputs are copied and baseline-corrected over their full time range before
    plotting, so the objects passed in are left untouched.

    Parameters
    ----------
    data_before : object with `copy`, `apply_baseline` and `plot` (presumably MNE Epochs)
        Data before the ICA cleaning. Only plotted if `show_data_before` is `True`.
    data_after : object with `copy`, `apply_baseline` and `plot` (presumably MNE Epochs)
        Data after the ICA cleaning.
    show_data_before : bool
        Also open a plot of `data_before`.
    block : bool
        Passed on to the plot of `data_after` only.
    """
    full_range = (None, None)
    # Settings shared by both plots.
    plot_settings = dict(picks=['eeg', 'eog'],
                         n_channels=30,
                         n_epochs=15,
                         scalings=dict(eeg=50e-6))

    if show_data_before:
        baselined_before = data_before.copy().apply_baseline(full_range)
        baselined_before.plot(**plot_settings)

    baselined_after = data_after.copy().apply_baseline(full_range)
    baselined_after.plot(block=block, **plot_settings)

In [None]:
# Subjects to process: all subject numbers except the ones listed as missing.
sub_list = np.setdiff1d(np.arange(1, config.n_subjects_total + 1),
                        config.ids_missing_subjects)
sub_list_str = ['VME_S%02d' % sub for sub in sub_list]

## To run a single subject, modify and uncomment the following line:
# sub_list_str = ['VME_S27']

# Components whose correlation score with an EOG channel passes this threshold are rejected.
threshold = 0.9

for subID in sub_list_str:
    # --- Load the continuous data and the epochs prepared for the ICA ---
    data_raw = helpers.load_data(f'{subID}-prepared',
                                 config.paths['01_prepared'],
                                 append='-raw').load_data()
    ica_epochs_dir = op.join(config.paths['02_epochs'], '0.01', 'ica')
    data_forICA = helpers.load_data(f'{subID}-ica', ica_epochs_dir, append='-epo')

    # --- Autoreject (local) to find bad epochs; they are left out when getting the ICA weights ---
    # AR does not perform well on non-baseline corrected data, so it gets a baselined copy.
    baseline_window = (-config.times_dict['bl_dur_erp'], 0)
    data_forAR = data_forICA.copy().apply_baseline(baseline_window)

    _, ar, reject_log = preprocess.clean_with_ar_local(
        subID,
        data_forAR,
        ar_from_disc=True,
        save_to_disc=False,
        ar_path=config.paths['03_preproc-ica-ar'],
    )

    # --- Get the ICA weights (using only the epochs not marked as bad) ---
    good_epochs = data_forICA[~reject_log.bad_epochs]
    ica = preprocess.get_ica_weights(
        subID,
        good_epochs,
        picks=None,
        reject=None,
        method='picard',
        fit_params=None,
        ica_from_disc=True,
        save_to_disc=False,
        ica_path=config.paths['03_preproc-ica'],
    )

    # --- EOG epochs from the stretches of data annotated as clean ---
    epochs_veog, epochs_heog = get_eog_epochs_clean(
        data_raw,
        manual_mode=True,
        eogannot_from_disc=True,
        save_eogannot_to_disc=False,
        eogannot_path=config.paths['03_preproc-ica-eog'],
    )

    # --- Find ICA components with high correlations with the V/HEOG channel ---
    indices_veog, scores_veog = ica.find_bads_eog(epochs_veog,
                                                  ch_name='VEOG',
                                                  measure='correlation',
                                                  threshold=threshold)
    indices_heog, scores_heog = ica.find_bads_eog(epochs_heog,
                                                  ch_name='HEOG',
                                                  measure='correlation',
                                                  threshold=threshold)

    # Optional inspection of the scores:
#     ica.plot_scores([scores_veog, scores_heog],
#                     labels=['VEOG', 'HEOG'],
#                     exclude=indices_veog + indices_heog)

    # A component flagged for both channels shall only be listed once.
    exclude = list(np.unique(indices_veog + indices_heog))
    ica.exclude = exclude

    # Optional inspection of the component topographies:
#    fig = ica.plot_components(range(25), inst=data_forICA)
#    plt.show(block = True)

    # --- Remove the excluded components from each set of epochs and save the result to disc ---
    for epo_part in ('cue', 'fulllength', 'stimon'):
        epochs_dir = op.join(config.paths['02_epochs'], '0.1', epo_part)
        data_pre = helpers.load_data(f'{subID}-{epo_part}',
                                     epochs_dir,
                                     append='-epo',
                                     verbose=False)

        # Apply the ICA to a copy so that `data_pre` stays available for comparison.
        data_post = ica.apply(data_pre.copy())

        cleaned_dir = op.join(config.paths['03_preproc-ica'], 'cleaneddata', '0.1', epo_part)
        helpers.save_data(data_post,
                          f'{subID}-{epo_part}-postICA',
                          cleaned_dir,
                          append='-epo')

    # Optional inspection of the cleaning result (uses the last `epo_part` of the loop above):
#     evo = data_post.copy().average().apply_baseline((-config.times_dict['bl_dur_erp'], 0))
#     evo.plot()
#     data_post.plot_psd()
#     vis_compare_ica(data_pre, data_post, show_data_before=True, block=True)