# 1) Prepare EEG data for training of machine-learning models
+ Import data.
+ Apply filters (bandpass).
+ Detect potential bad channels and replace them by interpolation.
+ Detect potential bad epochs and remove them.

In [1]:
# Import packages
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import mne
#%matplotlib inline
#%matplotlib qt

from mayavi import mlab
#%gui qt

In [2]:
import sys

# Project locations. ROOT is a machine-specific absolute Windows path;
# every other location is built from it by plain string concatenation,
# so each piece keeps its trailing backslash.
ROOT = "C:\\OneDrive - Netherlands eScience Center\\Project_ePodium\\"
PATH_CODE = "".join((ROOT, "EEG_explorer\\"))
PATH_DATA = "".join((ROOT, "Data\\"))
PATH_OUTPUT = "".join((ROOT, "Data\\"))  # same folder as PATH_DATA

# Put the project code folder first on the module search path so that
# modules stored there can be imported by later cells.
sys.path.insert(0, PATH_CODE)

### Custom cnt file importer function:

In [9]:
def read_cnt_file(file, 
                  event_idx = [3, 13, 66],
                  tmin = -0.2,
                  tmax = 0.8,
                  lpass = 0.5, 
                  hpass = 40, 
                  threshold = 5, 
                  max_bad_fraction = 0.2):
    """Read a cnt file, band-pass filter it and return cleaned epochs as arrays.

    The raw recording is filtered, cut into epochs around every event type in
    `event_idx`, screened for bad channels and bad epochs, and the remaining
    epochs of all event types are stacked into one array.

    Relies on the module-level names `mne`, `np` and `helper_functions`;
    `helper_functions` must have been imported before this function is called.

    Parameters
    ----------
    file : str
        Path of the cnt file to read.
    event_idx : list of int
        Event ids for which epochs are extracted. The list is only read.
    tmin, tmax : float
        Start and end of each epoch relative to the event, passed to
        `mne.Epochs`.
    lpass : float
        Lower edge of the band-pass filter (first argument of `Raw.filter`).
    hpass : float
        Upper edge of the band-pass filter (second argument of `Raw.filter`).
    threshold, max_bad_fraction
        Forwarded unchanged to `helper_functions.select_bad_epochs`; their
        exact meaning is defined there.

    Returns
    -------
    signal_collection : numpy.ndarray
        Epoch data of all event types concatenated along the first axis.
        If no epochs were collected an empty array of shape (0, 62, 501)
        is returned.
    label_collection : numpy.ndarray of int
        The event id of every epoch in `signal_collection`.
    """

    # Import file 
    data_raw = mne.io.read_raw_cnt(file, montage=None, eog='auto', preload=True)

    # Band-pass filter using the requested band edges
    # (defaults: 0.5 to 40 Hz. was 0.5 to 30Hz in Stober 2016)
    data_raw.filter(lpass, hpass, fir_design='firwin')

    events = mne.find_events(data_raw, shortest_event=0, stim_channel='STI 014', verbose=False)

    # Set baseline:
    baseline = (None, 0)  # means from the first instant to t = 0

    # Remember which channels were marked bad before any detection ran, so
    # that bad channels found for one event type do not leak into the next.
    initial_bads = list(data_raw.info['bads'])

    # Gather per-event results and concatenate once at the end; this avoids
    # assuming a fixed (channels, samples) shape for every file.
    signal_parts = []
    label_parts = []

    for event_id in event_idx:

        # Start every event type from the same channel selection. Without this
        # reset, channels marked bad in a previous iteration would be excluded
        # here and the resulting epochs would have fewer channels.
        data_raw.info['bads'] = list(initial_bads)

        # Pick EEG channels 
        picks = mne.pick_types(data_raw.info, meg=False, eeg=True, stim=False, eog=False,
                               exclude='bads')

        epochs = mne.Epochs(data_raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                            baseline=baseline, preload=True, verbose=False)

        # Detect potential bad channels and epochs
        bad_channels, bad_epochs = helper_functions.select_bad_epochs(epochs, 
                                                                      event_id, 
                                                                      threshold = threshold, 
                                                                      max_bad_fraction = max_bad_fraction)

        # Interpolate bad channels
        if len(bad_channels) > 0: 
            # Mark bad channels:
            data_raw.info['bads'] = bad_channels
            # Pick EEG channels (including the bad ones, so they can be interpolated):
            picks = mne.pick_types(data_raw.info, meg=False, eeg=True, stim=False, eog=False,
                                   exclude=[])
            epochs = mne.Epochs(data_raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                                baseline=baseline, preload=True, verbose=False)
            # Interpolate bad channels using functionality of 'mne'
            epochs.interpolate_bads()

        # Get signals as array and add to total collection
        signals_cleaned = epochs[str(event_id)].drop(bad_epochs).get_data()
        signal_parts.append(signals_cleaned)
        label_parts.append(event_id*np.ones((signals_cleaned.shape[0])))

    # Leave the raw object's bad-channel list as it was found.
    data_raw.info['bads'] = list(initial_bads)

    if len(signal_parts) > 0:
        signal_collection = np.concatenate(signal_parts, axis=0)
        label_collection = np.concatenate(label_parts, axis=0)
    else:
        # No event ids requested: keep the historical empty result.
        signal_collection = np.zeros((0,62,501))
        label_collection = np.zeros((0))

    return signal_collection, label_collection.astype(int)

# Workflow for data processing
1. Load cnt files.
2. Preprocess raw data (bandpass + detect outliers and 'bad' epochs).
3. Store epoch data and event types as arrays.

In [20]:
import fnmatch
import warnings
# NOTE: this silences all warnings for the rest of the session.
warnings.filterwarnings('ignore')

import helper_functions

dirs = os.listdir(PATH_DATA)
# os.listdir returns entries in arbitrary order; sort the matches so that the
# subset selected below is the same on every run and every machine.
cnt_files = sorted(fnmatch.filter(dirs, "*.cnt"))


# Gather the per-file results and concatenate once after the loop. This keeps
# the integer dtype of the labels (seeding with a float array upcast them) and
# avoids re-copying the growing arrays for every file.
signal_parts = []
label_parts = []

for filename in cnt_files[0:3]:
    print(40*"=")
    print("Imported file ",filename)
    # Import data and events
    file = PATH_DATA + filename

    signal_collect, label_collect = read_cnt_file(file, 
                  event_idx = [3, 13, 66],
                  tmin = -0.2,
                  tmax = 0.8,
                  lpass = 0.5, 
                  hpass = 40, 
                  threshold = 5, 
                  max_bad_fraction = 0.2)

    # Keep signals and labels of this file for the total collection
    signal_parts.append(signal_collect)
    label_parts.append(label_collect)

if len(signal_parts) > 0:
    signal_collection = np.concatenate(signal_parts, axis=0)
    label_collection = np.concatenate(label_parts, axis=0)
else:
    # No cnt files found: fall back to empty collections.
    signal_collection = np.zeros((0,62,501))
    label_collection = np.zeros((0), dtype=int)


Imported file  015_thomas_mmn36w.cnt
Reading 0 ... 370279  =      0.000 ...   740.558 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 66 bad epochs in a total of 11  channels.
Marked 66 bad epochs in a total of 400  epochs.
Found 8 bad epochs in a total of 4  channels.
Marked 8 bad epochs in a total of 50  epochs.
Found 7 bad epochs in a total of 2  channels.
Marked 7 bad epochs in a total of 50  epochs.
Imported file  034_17_mc_mmn36_wk.cnt
Reading 0 ... 373379  =      0.000 ...   746.758 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 72 bad epochs in a total of 52  channels.
Marked 72 bad epochs in a total of 400  epochs.
Found 14 bad epochs in a total of 6  channels.
Marked 14 bad epochs in a total of 50  epo

In [17]:
signal_collection.shape, label_collection.shape

((1626, 62, 501), (1626,))