## Set up

In [1]:
import numpy as np
import os.path as op
from pprint import pformat
from typing import Tuple, Iterator
import time

# EEG utilities
import mne
from mne.preprocessing import ICA, create_eog_epochs
from pyprep.prep_pipeline import PrepPipeline
from autoreject import get_rejection_threshold, validation_curve

# BIDS utilities
from mne_bids import BIDSPath, read_raw_bids
from util.io.bids import DataSink
from bids import BIDSLayout

In [2]:
# Constants
BIDS_ROOT = '../data/bids' # root of the raw BIDS dataset
DERIV_ROOT = '../data/bids/derivatives' # preprocessing outputs are written below this folder
# CHANMAP_FPATH = '../data/captrak/CACS-64_NO_REF.bvef'
# CHANPOS_ROOT = '../data/captrak'
LOWPASS = 300 # Hz; exclusive upper bound for the line-noise harmonics handed to PREP
FS = 2000 # Hz; sampling rate the recordings are resampled to
# Trial rejection threshold for EEG channels, in volts (the report multiplies it
# by 1e6 to print microvolts). The previous value 5e-7 was 0.5 microvolts, not
# the 50 microvolts stated in its comment.
REJECT_THRES = 50e-6 # 50 microvolts

## Functions
#### Import data 

In [3]:
# Generator that walks every (subject, task, run) combination
KeyType = Tuple[str, str, str]

def fpaths() -> Iterator[KeyType]:
    """Yield a (subject, task, run) key for each combination to process.

    Reads the notebook-level variables `subjects`, `tasks` and `runs`, which
    must be defined before the generator is iterated. Subjects vary slowest
    and runs fastest. Each run is converted with `str()` because
    layout.get_runs() does not return strings.
    """
    yield from (
        (subject, task_name, str(run_id))
        for subject in subjects
        for task_name in tasks
        for run_id in runs
    )

def get_bids_path(bids_root, sub, task, run):
    """Return the BIDSPath of one EEG recording.

    Parameters
    ----------
    bids_root : root directory of the BIDS dataset.
    sub, task, run : BIDS entities identifying the recording.

    The datatype is always set to 'eeg'.
    """
    entities = {
        'root': bids_root,
        'subject': sub,
        'task': task,
        'run': run,
        'datatype': 'eeg',
    }
    return BIDSPath(**entities)

def import_bids_data(bids_path):
    """Read the recording at `bids_path` and keep only its EEG-typed channels.

    The file is read with `read_raw_bids` (verbose output disabled) and the
    result of `pick_types(eeg = True)` on it is returned.
    """
    recording = read_raw_bids(bids_path, verbose = False)
    return recording.pick_types(eeg = True)

def read_events(raw):
    """Convert the annotations of `raw` into events.

    Returns the two values produced by `mne.events_from_annotations`, in the
    same order: the events and the mapping of event ids.
    """
    from_annotations = mne.events_from_annotations(raw)
    event_array, id_lookup = from_annotations
    return event_array, id_lookup

#### Set electrode locations and create EOGs

In [4]:
# def set_electrode_mappings(chanmap_fpath):
#     dig = mne.channels.read_custom_montage(chanmap_fpath)
#     mapping = {'Ch%s'%i: dig.ch_names[i] for i in range(len(dig.ch_names))}
# #     del mapping['Ch33']
# #     del mapping['Ch61']
# #     del mapping['Ch64']
#     return mapping

# def get_chanpos_fpath(chanpos_root, sub):
#     chanpos_fpath = op.join(chanpos_root, 'subj_' + sub + '.bvct')
#     return chanpos_fpath

# def set_electrode_positions(raw, chanpos_fpath, stim_channel, mapping):
#     dig = mne.channels.read_dig_captrak(chanpos_fpath)
#     raw = raw.set_channel_types({stim_channel: 'stim'}) 
#     raw = raw.rename_channels(mapping)
#     raw = raw.set_montage(dig)
#     return raw

def create_eogs(raw):
    """Derive two bipolar EOG channels from the frontal and ocular electrodes.

    'eog1' is built from Fp1 (anode) and leog (cathode), 'eog2' from Fp2 and
    reog. The reference channels are kept by `set_bipolar_reference`
    (`drop_refs = False`); afterwards 'reog' and 'leog' are dropped explicitly
    and the two new channels are typed as 'eog'.
    """
    bipolar_pairs = (
        ('Fp1', 'leog', 'eog1'),
        ('Fp2', 'reog', 'eog2'),
    )
    for frontal, ocular, derived_name in bipolar_pairs:
        raw = mne.set_bipolar_reference(
            raw,
            anode = frontal,
            cathode = ocular,
            ch_name = derived_name,
            drop_refs = False,
        )
    raw = raw.drop_channels(['reog', 'leog'])
    return raw.set_channel_types({'eog1': 'eog', 'eog2': 'eog'})

#### Resampling and PREP

In [5]:
def resample(raw, fs, events):
    """Resample `raw` to `fs` so the data are more manageable.

    `events` is passed to `raw.resample` alongside the data; the pair it
    returns (recording, events) is handed back unchanged.
    """
    resampled_raw, new_events = raw.resample(fs, events = events)
    return resampled_raw, new_events

def run_PREP(raw, sub, run, LOWPASS):
    """Run the PREP pipeline (notch, exclude bad channels, and re-reference).

    Parameters
    ----------
    raw : recording to clean; its data are loaded first.
    sub, run : labels concatenated into the integer random seed.
    LOWPASS : exclusive upper bound for the line-noise harmonics.

    Returns
    -------
    (raw, bads) : the cleaned EEG with the non-EEG channels added back, and
    the pipeline's `noisy_channels_original`.
    """
    raw.load_data()

    # One seed per recording, e.g. sub '2' and run '1' give 21
    seed = int(''.join((str(sub), str(run))))
    np.random.seed(seed)

    # Line frequency and its harmonics below LOWPASS; fall back to the line
    # frequency alone when that range is empty
    line_freq = raw.info['line_freq']
    harmonics = np.arange(line_freq, LOWPASS, line_freq)
    if not harmonics.size > 0:
        harmonics = [line_freq]
    prep_params = dict(ref_chs = 'eeg', reref_chs = 'eeg', line_freqs = harmonics)

    pipeline = PrepPipeline(raw, prep_params, raw.get_montage(), ransac = False, random_state = seed)
    pipeline = pipeline.fit()

    # The original author noted that `pipeline.raw` "might not include the
    # non-eeg channels", so the two parts are recombined by hand.
    cleaned = pipeline.raw_eeg # cleaned EEG channels
    non_eeg = pipeline.raw_non_eeg # the EOG channels
    cleaned = cleaned.add_channels([non_eeg], force_update_info=True)
    bad_channels = pipeline.noisy_channels_original
    return cleaned, bad_channels

#### Run ICA on one copy of the data
Split the data into two copies, one filtered more liberally for ICA so that high frequency noise can be detected, one band-pass filtered at the behaviorally relevant frequencies. All of the following preprocessing steps will be applied to each of the copies.

In [6]:
def bandpass(raw, h_freq, l_freq):
    """Filter `raw` with lower edge `l_freq` and upper edge `h_freq`.

    Mind the parameter order: the upper edge (`h_freq`) comes BEFORE the
    lower edge (`l_freq`). Passing both by keyword avoids swapping them.

    Returns whatever `raw.filter` returns.
    NOTE(review): MNE's Raw.filter is understood to filter in place and
    return the same object - copy first if the unfiltered data are needed.
    """
    edges = {'l_freq': l_freq, 'h_freq': h_freq}
    return raw.filter(**edges)

def epoch(raw):
    """Cut `raw` into epochs from -0.2 s to 0.250 s around each event.

    Reads the notebook-level variables `events` and `event_ids`, which must
    exist before this function is called.
    """
    epoch_options = dict(
        tmin = -0.2,
        tmax = 0.250,
        baseline = None, # do NOT baseline correct yet; that happens after ICA
        event_id = event_ids, # keeps track of which epoch belongs to which condition
        preload = True, # keep data in memory
    )
    return mne.Epochs(raw, events, **epoch_options)

def compute_ICA(epochs):
    """Fit a 15-component ICA (random_state 0) on the EEG and EOG channels of `epochs`.

    Returns the result of the `fit` call.
    """
    decomposition = ICA(n_components = 15, random_state = 0)
    return decomposition.fit(epochs, picks = ['eeg', 'eog'])

def apply_ICA(epochs_for_ica, epochs):
    """Mark EOG-related components and apply the ICA to `epochs`.

    Uses the notebook-level variable `ica` (it is not a parameter), so
    `compute_ICA` must have been run and its result stored under that name.
    Components are detected on `epochs_for_ica` with a threshold of 1.96 and
    stored in `ica.exclude`; the ICA is then applied to `epochs`.

    Returns
    -------
    (epochs, ica) : the output of `ica.apply` and the ICA object itself.
    """
    flagged_components, _scores = ica.find_bads_eog(epochs_for_ica, threshold = 1.96)
    ica.exclude = flagged_components
    cleaned_epochs = ica.apply(epochs) # apply to aggressively filtered version of data
    return cleaned_epochs, ica

#### Baseline correct and reject trials
Back to applying preprocessing on only one copy of the data. ICA is finished.

In [7]:
def baseline_correct(epochs):
    """Keep only the EEG channels of `epochs` and baseline-correct them.

    The baseline window is (-0.2, 0.) seconds.
    """
    eeg_only = epochs.pick_types(eeg = True) # change syntax?
    return eeg_only.apply_baseline((-0.2, 0.))

def reject_trials(threshold, epochs):
    """Drop bad epochs using `threshold` as the rejection criterion for EEG channels.

    Returns the result of `epochs.drop_bad`.
    """
    criteria = {'eeg': threshold}
    return epochs.drop_bad(reject = criteria)

#### Save results and generate report

In [8]:
def get_save_path(deriv_root, sub, task, run):
    """Return the output path for the cleaned epochs of one recording.

    A `DataSink` for the 'preprocessing' step under `deriv_root` supplies the
    path for the given subject, task and run.
    """
    sink = DataSink(deriv_root, 'preprocessing')

    # Entities describing the cleaned-data file
    entities = dict(
        subject = sub,
        task = task,
        run = run,
        desc = 'clean',
        suffix = 'epo', # this suffix follows MNE, not BIDS, naming conventions
        extension = 'fif.gz',
    )
    return sink.get_path(**entities)

def save_preprocessed_data(fpath, epochs):
    """Save `epochs` to `fpath`, overwriting any existing file."""
    write_options = {'overwrite': True}
    epochs.save(fpath, **write_options)
    
def generate_report(fpath, epochs, ica, bads, threshold = None, report_dir = None):
    """Build and save the HTML report for the subject currently being processed.

    Parameters
    ----------
    fpath : path of the saved epochs file; its folder is scanned for
        '*epo.fif.gz' files to include in the report.
    epochs : cleaned epochs; the '50' condition is averaged for the ERP plot.
    ica : fitted ICA; excluded components are plotted when there are any.
    bads : bad-channel summary, pretty-printed into the report.
    threshold : rejection threshold in volts shown in the report.
        Defaults to the notebook constant `REJECT_THRES`.
    report_dir : folder the HTML file is written to. Defaults to the
        `deriv_root` of the 'preprocessing' DataSink under `DERIV_ROOT`,
        i.e. the same sink `get_save_path` builds.

    The subject label in the file name comes from the notebook-level
    variable `sub`, so this must be called inside the preprocessing loop.
    """
    # These two values used to be read from `thres` and `sink`, neither of
    # which exists at notebook level, so the function raised NameError.
    if threshold is None:
        threshold = REJECT_THRES
    if report_dir is None:
        report_dir = DataSink(DERIV_ROOT, 'preprocessing').deriv_root

    report = mne.Report(verbose = True)
    report.parse_folder(op.dirname(fpath), pattern = '*epo.fif.gz', render_bem = False)

    # Plot the ERP
    fig_erp = epochs['50'].average().plot(spatial_colors = True)
    report.add_figs_to_section(
        fig_erp, 
        captions = 'Average Evoked Response', 
        section = 'evoked'
    )

    # Plot the excluded ICAs
    if ica.exclude: # if we found any bad components
        fig_ica_removed = ica.plot_components(ica.exclude)
        report.add_figs_to_section(
            fig_ica_removed, 
            captions = 'Removed ICA Components', 
            section = 'ICA'
        )

    # Format output: one '<br/>'-prefixed line per line of the pretty-printed summary
    html = '\n'.join('<br/>%s' % line for line in pformat(bads).splitlines())
    report.add_htmls_to_section(html, captions = 'Interpolated Channels', section = 'channels')
    report.add_htmls_to_section('<br/>threshold: {:0.2f} microvolts</br>'.format(threshold * 1e6), 
                                captions = 'Trial Rejection Criteria', section = 'rejection')
    report.add_htmls_to_section(epochs.info._repr_html_(), captions = 'Info', section = 'info')
    report.save(op.join(report_dir, 'sub-%s.html'%sub), overwrite = True)

## Preprocessing wrapper
Since we have to loop over all the data files, the cell below contains the for loop that calls all of the preprocessing functions defined in the preceding sections.

In [None]:
start = time.time()

# Parse BIDS directory
layout = BIDSLayout(BIDS_ROOT)
subjects = layout.get_subjects()
tasks = layout.get_tasks()
runs = layout.get_runs()
print(subjects, tasks, runs)

# fpaths() iterates over the `subjects`, `tasks` and `runs` defined just above
for (sub, task, run) in fpaths():
    if sub == '4':
        continue # REMOVE?

    # Import data
    print("Import data")
    bids_path = get_bids_path(BIDS_ROOT, sub, task, run)
    print(bids_path)
    if not bids_path.fpath.is_file(): # skip if file doesn't exist
        continue
    raw = import_bids_data(bids_path)
    events, event_ids = read_events(raw)

    # Create virtual EOGs
    raw.load_data()
    raw = create_eogs(raw)

    # Resampling and PREP
    print("Resampling and PREP")
    raw, events = resample(raw, FS, events)
    raw, bads = run_PREP(raw, sub, run, LOWPASS)

    # Run ICA on one copy of the data
    print("Run ICA on one copy of the data")
    # Raw.filter modifies the recording in place and returns the same object,
    # so the ICA version is filtered on an explicit copy. Without the copy
    # both names pointed at a single recording that was filtered twice.
    # bandpass() is declared as (raw, h_freq, l_freq); the edges are passed by
    # keyword because the positional calls swapped them (a 1 Hz low-pass and
    # a 270-30 Hz band-stop instead of the filters below).
    raw_for_ica = bandpass(raw.copy(), l_freq = 1., h_freq = None) # 1 Hz high-pass only
    raw = bandpass(raw, l_freq = 30, h_freq = 270) # 30-270 Hz band-pass

    # epoch() uses the `events` / `event_ids` assigned earlier in this iteration
    epochs_for_ica = epoch(raw_for_ica)
    epochs = epoch(raw)

    ica = compute_ICA(epochs_for_ica) # run ICA on less aggressively filtered data
    epochs, ica = apply_ICA(epochs_for_ica, epochs) # apply ICA on more aggressively filtered data

    # Baseline correct and reject trials
    print("Baseline correct and reject trials")
    epochs = baseline_correct(epochs)
    epochs = reject_trials(REJECT_THRES, epochs)

    # Save results and generate report
    print("Save results and generate report")
    fpath = get_save_path(DERIV_ROOT, sub, task, run)
    save_preprocessed_data(fpath, epochs)
    generate_report(fpath, epochs, ica, bads)

end = time.time()
print(end - start)

['2', '6', '4', '3', '5'] ['pitch'] [1, 2]
Import data
../data/bids/sub-2/eeg/sub-2_task-pitch_run-1_eeg.vhdr
Used Annotations descriptions: ['100', '150', '200', '250', '50']


  raw = read_raw_bids(bids_path, verbose = False)
['Aux1']
Consider setting the channel types to be of EEG/sEEG/ECoG/DBS/fNIRS using inst.set_channel_types before calling inst.set_montage, or omit these channels when creating your montage.
  raw = read_raw_bids(bids_path, verbose = False)


Reading 0 ... 21218499  =      0.000 ...  2121.850 secs...
EEG channel type selected for re-referencing
Creating RawArray with float64 data, n_channels=1, n_times=21218500
    Range : 0 ... 21218499 =      0.000 ...  2121.850 secs
Ready.


In [None]:
# Show the measurement info of `raw`. This relies on `raw` still being defined
# by the preprocessing loop above, i.e. it reflects the last recording the loop touched.
raw.info

In [None]:
# start = time.time()

# sub = '2'
# task = 'pitch'
# run = '1'

# # Import data
# print("Import data")
# bids_path = get_bids_path(BIDS_ROOT, sub, task, run)
# print(bids_path)
# raw = import_bids_data(bids_path)
# raw.ch_names

# # # Set channel locations and create EOGs
# # mapping = set_electrode_mappings(CHANMAP_FPATH)
# # chanpos_fpath = get_chanpos_fpath(CHANPOS_ROOT, sub)
# # raw = set_electrode_positions(raw, chanpos_fpath, 'Aux1', mapping)
# # raw = create_eogs(raw)

# # end = time.time()
# # print(end - start)