# Prepare EEG data for training of machine-learning models
+ Import data.
+ Apply filters (bandpass).
+ Detect potential bad channels and replace them by interpolation.
+ Detect potential bad epochs and remove them.

## Import packages & links

In [1]:
# Import packages
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import mne
#%matplotlib inline

from mayavi import mlab

In [2]:
ROOT = "C:\\OneDrive - Netherlands eScience Center\\Project_ePodium\\"
PATH_CODE = ROOT + "EEG_explorer\\"
PATH_DATA = ROOT + "Data\\"
PATH_OUTPUT = ROOT + "Data\\processed\\"
PATH_METADATA = ROOT + "Data\\metadata\\"
file_labels = "metadata.xlsx"

import sys
sys.path.insert(0, PATH_CODE)

In [3]:
metadata = pd.read_excel(PATH_METADATA+file_labels)
print(metadata[["file", "group"]][:10])

                     file  group
0    031_04_mc_mmn36_wk_1      1
1      034_17_mc_mmn36_wk      1
2      036_17_mc_mmn36_wk      1
3        039_04_jc_mnn_wk      1
4      305_17_jc_mmn36_wk      1
5      306_17_mc_mmn36_wk      1
6  307_17_jc_mmn36_wakker      1
7      308_17_jc_mmn36_wk      1
8           309_17_jc_mmn      1
9           310_17_mc_mmn      1


## Search all *.cnt files and check for how many we have a label

In [4]:
import fnmatch
import warnings
warnings.filterwarnings('ignore')

import helper_functions

dirs = os.listdir(PATH_DATA)
cnt_files = fnmatch.filter(dirs, "*.cnt")

In [6]:
files_present = [x[:-4] for x in cnt_files]
files_labels_known = list(metadata["file"])

In [66]:
labels_known = 0
labels_unknown = 0
labels_type = []
for file in files_present:
    if file in files_labels_known:
        labels_known += 1
        labels_type.append(int(metadata[metadata["file"].str.match(file)]['group']))
    else:
        labels_unknown += 1  
        labels_type.append(0)
print("Files with proper labels:", labels_known, "||| Files without proper label:", labels_unknown)

Files with proper labels: 57 ||| Files without proper label: 135


In [72]:
labels_type.count(1), labels_type.count(2)

(24, 33)

## Custom cnt-file import function:

In [51]:
def read_cnt_file(file,
                  label_group,
                  event_idx = [3, 13, 66],
                  channel_set = "30",
                  tmin = -0.2,
                  tmax = 0.8,
                  lpass = 0.5, 
                  hpass = 40, 
                  threshold = 5, 
                  max_bad_fraction = 0.2):
    """ Function to read cnt file. Run bandpass filter. 
    Then detect and correct/remove bad channels and bad epochs.
    Store resulting epochs as arrays.
    
    Args:
    --------
    file: str
        Name of file to import.
    label_group: int
        Unique ID of specific group (must be >0).
    channel_set: str
        Select among pre-defined channel sets. Here: "30" or "62"
    """
    
    if channel_set == "30":
        channel_set = ['O2', 'O1', 'OZ', 'PZ', 'P4', 'CP4', 'P8', 'C4', 'TP8', 'T8', 'P7', 
                       'P3', 'CP3', 'CPZ', 'CZ', 'FC4', 'FT8', 'TP7', 'C3', 'FCZ', 'FZ', 
                       'F4', 'F8', 'T7', 'FT7', 'FC3', 'F3', 'FP2', 'F7', 'FP1']
    elif channel_set == "62":
        channel_set = ['O2', 'O1', 'OZ', 'PZ', 'P4', 'CP4', 'P8', 'C4', 'TP8', 'T8', 'P7', 
                       'P3', 'CP3', 'CPZ', 'CZ', 'FC4', 'FT8', 'TP7', 'C3', 'FCZ', 'FZ', 
                       'F4', 'F8', 'T7', 'FT7', 'FC3', 'F3', 'FP2', 'F7', 'FP1', 'AFZ', 'PO3', 
                       'P1', 'POZ', 'P2', 'PO4', 'CP2', 'P6', 'M1', 'CP6', 'C6', 'PO8', 'PO7', 
                       'P5', 'CP5', 'CP1', 'C1', 'C2', 'FC2', 'FC6', 'C5', 'FC1', 'F2', 'F6', 
                       'FC5', 'F1', 'AF4', 'AF8', 'F5', 'AF7', 'AF3', 'FPZ']
    else:
        print("Predefined channel set given by 'channel_set' not known...")
        
    
    # Initialize array
    signal_collection = np.zeros((0,len(channel_set),501))
    label_collection = np.zeros((0))
    
    # Import file 
    data_raw = mne.io.read_raw_cnt(file, montage=None, eog='auto', preload=True)
    
    # Band-pass filter (between 0.5 and 40 Hz. was 0.5 to 30Hz in Stober 2016)
    data_raw.filter(0.5, 40, fir_design='firwin')

    events = mne.find_events(data_raw, shortest_event=0, stim_channel='STI 014', verbose=False)
    
    # Set baseline:
    baseline = (None, 0)  # means from the first instant to t = 0

    # Select channels to exclude (if any)
    channels_exclude = [x for x in data_raw.ch_names if x not in channel_set]
    channels_exclude = [x for x in channels_exclude if x not in ['HEOG', 'VEOG', 'STI 014']]
    
    for event_id in event_idx:
    
        # Pick EEG channels 
        picks = mne.pick_types(data_raw.info, meg=False, eeg=True, stim=False, eog=False,
                           #exclude=data_exclude)#'bads'])
                               include=channel_set, exclude=channels_exclude)#'bads'])

        epochs = mne.Epochs(data_raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                        baseline=baseline, preload=True, verbose=False)

        # Detect potential bad channels and epochs
        bad_channels, bad_epochs = helper_functions.select_bad_epochs(epochs, 
                                                                      event_id, 
                                                                      threshold = threshold, 
                                                                      max_bad_fraction = max_bad_fraction)

        # Interpolate bad channels
        if len(bad_channels) > 0: 
            # Mark bad channels:
            data_raw.info['bads'] = bad_channels
            # Pick EEG channels:
            picks = mne.pick_types(data_raw.info, meg=False, eeg=True, stim=False, eog=False,
                               #exclude=data_exclude)#'bads'])
                               include=channel_set, exclude=channels_exclude)#'bads'])
            epochs = mne.Epochs(data_raw, events, event_id, tmin, tmax, proj=True, picks=picks,
                            baseline=baseline, preload=True, verbose=False)
            # Interpolate bad channels using functionality of 'mne'
            epochs.interpolate_bads()

        # Get signals as array and add to total collection
        signals_cleaned = epochs[str(event_id)].drop(bad_epochs).get_data()
        signal_collection = np.concatenate((signal_collection, signals_cleaned), axis=0)
        label_collection = np.concatenate((label_collection, event_id*label_group*np.ones((signals_cleaned.shape[0]))), axis=0)

    return signal_collection, label_collection.astype(int)

In [139]:
# channel names for 30 EEG channel case: 
print(epochs.ch_names)

['O2', 'O1', 'OZ', 'PZ', 'P4', 'CP4', 'P8', 'C4', 'TP8', 'T8', 'P7', 'P3', 'CP3', 'CPZ', 'CZ', 'FC4', 'FT8', 'TP7', 'C3', 'FCZ', 'FZ', 'F4', 'F8', 'T7', 'FT7', 'FC3', 'F3', 'FP2', 'F7', 'FP1']


In [126]:
# channel names for 62 EEG channel case: 
print(epochs.ch_names)

['O2', 'O1', 'OZ', 'PZ', 'P4', 'CP4', 'P8', 'C4', 'TP8', 'T8', 'P7', 'P3', 'CP3', 'CPZ', 'CZ', 'FC4', 'FT8', 'TP7', 'C3', 'FCZ', 'FZ', 'F4', 'F8', 'T7', 'FT7', 'FC3', 'F3', 'FP2', 'F7', 'FP1', 'AFZ', 'PO3', 'P1', 'POZ', 'P2', 'PO4', 'CP2', 'P6', 'M1', 'CP6', 'C6', 'PO8', 'PO7', 'P5', 'CP5', 'CP1', 'C1', 'C2', 'FC2', 'FC6', 'C5', 'FC1', 'F2', 'F6', 'FC5', 'F1', 'AF4', 'AF8', 'F5', 'AF7', 'AF3', 'FPZ']


## Check how many EEG channels the cnt-files feature... 

In [95]:
format_collection = []
for i, filename in enumerate(cnt_files):
    # Import file 
    file = PATH_DATA + filename
    data_raw = mne.io.read_raw_cnt(file, montage=None, eog='auto', preload=True)
    format_collection.append((i, len(data_raw.ch_names)))
    print(i, len(data_raw.ch_names))

Reading 0 ... 370279  =      0.000 ...   740.558 secs...
0 65
Reading 0 ... 373379  =      0.000 ...   746.758 secs...
1 65
Reading 0 ... 743799  =      0.000 ...  1487.598 secs...
2 65
Reading 0 ... 778159  =      0.000 ...  1556.318 secs...
3 65
Reading 0 ... 1497319  =      0.000 ...  2994.638 secs...
4 65
Reading 0 ... 372959  =      0.000 ...   745.918 secs...
5 65
Reading 0 ... 758319  =      0.000 ...  1516.638 secs...
6 65
Reading 0 ... 751639  =      0.000 ...  1503.278 secs...
7 65
Reading 0 ... 458319  =      0.000 ...   916.638 secs...
8 65
Reading 0 ... 376679  =      0.000 ...   753.358 secs...
9 65
Reading 0 ... 368719  =      0.000 ...   737.438 secs...
10 65
Reading 0 ... 774119  =      0.000 ...  1548.238 secs...
11 65
Reading 0 ... 373459  =      0.000 ...   746.918 secs...
12 33
Reading 0 ... 758559  =      0.000 ...  1517.118 secs...
13 65
Reading 0 ... 743039  =      0.000 ...  1486.078 secs...
14 65
Reading 0 ... 743199  =      0.000 ...  1486.398 secs...
15 65
R

129 65
Reading 0 ... 372199  =      0.000 ...   744.398 secs...
130 65
Reading 0 ... 744299  =      0.000 ...  1488.598 secs...
131 65
Reading 0 ... 742379  =      0.000 ...  1484.758 secs...
132 65
Reading 0 ... 740279  =      0.000 ...  1480.558 secs...
133 65
Reading 0 ... 153799  =      0.000 ...   307.598 secs...
134 65
Reading 0 ... 736119  =      0.000 ...  1472.238 secs...
135 65
Reading 0 ... 760639  =      0.000 ...  1521.278 secs...
136 65
Reading 0 ... 739679  =      0.000 ...  1479.358 secs...
137 65
Reading 0 ... 586559  =      0.000 ...  1173.118 secs...
138 65
Reading 0 ... 741919  =      0.000 ...  1483.838 secs...
139 65
Reading 0 ... 997039  =      0.000 ...  1994.078 secs...
140 33
Reading 0 ... 997039  =      0.000 ...  1994.078 secs...
141 33
Reading 0 ... 370459  =      0.000 ...   740.918 secs...
142 33
Reading 0 ... 741979  =      0.000 ...  1483.958 secs...
143 33
Reading 0 ... 381219  =      0.000 ...   762.438 secs...
144 33
Reading 0 ... 378859  =      0.00

In [111]:
a,b = zip(*format_collection)
len(np.where((np.array(b) == 65))[0]), len(np.where((np.array(b) == 33))[0]), len(a)

(128, 64, 192)

In [40]:
# number of cnt files which have both 65 channels (62 EEG) AND a label.
len(np.where((np.array(b) == 65) & (np.array(labels_type) == 1))[0])

NameError: name 'b' is not defined

So far we 'only' have about 60 cnt-files of which we have a label ("risk group" vs "no risc group").
And only 42 of them feature 62 EEG channels. I hence switched to 30 EEG channels and picked the ones that are present in all patient datasets.

# Workflow data processing
1. Load cnt files.
2. Select same number of channels (here: 30 same channels which exist for both 30 and 62 channel data)
3. Preprocess raw data (bandpass + detect outliers and 'bad' epochs).
4. Store epoch data and event type as array

## CAREFUL!
+ We have **NO PROPER LABELS** (yet).  
Only for 57 out of 192 CNT files do we have the information that they belong into "group 1"or "group 2". One of them supposedly is the risk-group, the other the control group.
+ In the risk group only about 40% will become dislexic. With only 24 and 33 cases for each group will can hence expect no more than 10-15 "real" dislexia labeled cases.

In [52]:
# Initialize array
signal_collection = np.zeros((0,30,501)) #62
label_collection = np.zeros((0))
metadata_collection = []

for i, filename in enumerate(cnt_files):
    # First check if we have proper label for that file
    if len(metadata[metadata["file"].str.match(filename[:-4])]['group']) < 1:
        print("No proper label found for file: ", filename)
    else:
        label_group = int(metadata[metadata["file"].str.match(filename[:-4])]['group'])
        
        print(40*"=")
        print("Importing file: ",filename)
        print("Data belongs into group: ", label_group)

        # Import data and events
        file = PATH_DATA + filename

        signal_collect, label_collect = read_cnt_file(file, 
                                                      label_group,
                                                      event_idx = [3, 13, 66],
                                                      channel_set = "30",
                                                      tmin = -0.2,
                                                      tmax = 0.8,
                                                      lpass = 0.5, 
                                                      hpass = 40, 
                                                      threshold = 5, 
                                                      max_bad_fraction = 0.2)

        # Get signals as array and add to total collection
        signal_collection = np.concatenate((signal_collection, signal_collect), axis=0)
        label_collection = np.concatenate((label_collection, label_collect), axis=0)
        metadata_collection.append((i, filename, signal_collection.shape[0]))


No proper label found for file:  015_thomas_mmn36w.cnt
Importing file:  034_17_mc_mmn36_wk.cnt
Data belongs into group:  1
Reading 0 ... 373379  =      0.000 ...   746.758 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 44 bad epochs in a total of 24  channels.
Marked 44 bad epochs in a total of 400  epochs.
Found 6 bad epochs in a total of 2  channels.
Marked 6 bad epochs in a total of 50  epochs.
Found 7 bad epochs in a total of 5  channels.
Marked 7 bad epochs in a total of 50  epochs.
Importing file:  036_17_mc_mmn36_wk.cnt
Data belongs into group:  1
Reading 0 ... 743799  =      0.000 ...  1487.598 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 414 bad epochs in a total of 23  channels.
Found bad channel (m

Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 135 bad epochs in a total of 15  channels.
Marked 135 bad epochs in a total of 953  epochs.
Found 20 bad epochs in a total of 9  channels.
Marked 20 bad epochs in a total of 118  epochs.
Found 12 bad epochs in a total of 12  channels.
Marked 12 bad epochs in a total of 118  epochs.
No proper label found for file:  344_17_mmn36_wk.cnt
No proper label found for file:  344_29_jc_mmn36_wk.cnt
Importing file:  345_17_mc_mmn36_wk.cnt
Data belongs into group:  1
Reading 0 ... 726239  =      0.000 ...  1452.478 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 55 bad epochs in a total of 22  channels.
Marked 55 bad epochs in a total of 800  epochs.
Found 8 bad epochs in a total of 15

Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 5 bad epochs in a total of 4  channels.
Marked 5 bad epochs in a total of 400  epochs.
Found 1 bad epochs in a total of 1  channels.
Marked 1 bad epochs in a total of 50  epochs.
No outliers found with given threshold.
No proper label found for file:  457_17_jd_mmn36_wk.cnt
No proper label found for file:  457_29_jd_mmn36_wk.cnt
No proper label found for file:  465_17_jd_mmn_36_wk.cnt
No proper label found for file:  465_29_jd_mmn36_wk.cnt
Importing file:  466_17_md_mmn36_wk.cnt
Data belongs into group:  2
Reading 0 ... 1089419  =      0.000 ...  2178.838 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 39 bad epochs in a total of 30  channels.
Marked 39 bad epochs in a tota

Found 6 bad epochs in a total of 30  channels.
Marked 6 bad epochs in a total of 50  epochs.
Found 28 bad epochs in a total of 30  channels.
Found bad channel (more than 10.0  bad epochs): Channel no:  6
Found bad channel (more than 10.0  bad epochs): Channel no:  8
Marked 11 bad epochs in a total of 50  epochs.
No proper label found for file:  604-133-29m-jc-mmn36.cnt
Importing file:  605-131-17m-jc-mmn.cnt
Data belongs into group:  1
Reading 0 ... 378919  =      0.000 ...   757.838 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 9 bad epochs in a total of 16  channels.
Marked 9 bad epochs in a total of 400  epochs.
Found 1 bad epochs in a total of 2  channels.
Marked 1 bad epochs in a total of 50  epochs.
Found 1 bad epochs in a total of 3  channels.
Marked 1 bad epochs in a total of 50  epochs.
No proper label found for file:  605-131-29m-jc-mmn3

Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 73 bad epochs in a total of 21  channels.
Marked 73 bad epochs in a total of 800  epochs.
Found 7 bad epochs in a total of 23  channels.
Marked 7 bad epochs in a total of 100  epochs.
Found 10 bad epochs in a total of 9  channels.
Marked 10 bad epochs in a total of 100  epochs.
No proper label found for file:  642-485-29m-jc-mmn36.cnt
Importing file:  646-478-17m-mc-mmn36.cnt
Data belongs into group:  1
Reading 0 ... 586559  =      0.000 ...  1173.118 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 28 bad epochs in a total of 9  channels.
Marked 28 bad epochs in a total of 634  epochs.
Found 7 bad epochs in a total of 4  channels.
Marked 7 bad epochs in a total of 80  epoch

h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 148 bad epochs in a total of 30  channels.
Marked 148 bad epochs in a total of 1200  epochs.
Found 22 bad epochs in a total of 24  channels.
Marked 22 bad epochs in a total of 150  epochs.
Found 17 bad epochs in a total of 23  channels.
Marked 17 bad epochs in a total of 150  epochs.
No proper label found for file:  751-452-29m-jr-mmn36.cnt
Importing file:  751-542-17m-jr-mmn36.cnt
Data belongs into group:  2
Reading 0 ... 742799  =      0.000 ...  1485.598 secs...
Setting up band-pass filter from 0.5 - 40 Hz
l_trans_bandwidth chosen to be 0.5 Hz
h_trans_bandwidth chosen to be 10.0 Hz
Filter length of 3301 samples (6.602 sec) selected
Found 24 bad epochs in a total of 12  channels.
Marked 24 bad epochs in a total of 800  epochs.
No outliers found with given threshold.
Found 1 bad epochs in a total of 1  channels.
Marked 1 bad epochs in a total of 100  epochs.
No proper label found for file: 

In [53]:
signal_collection.shape, label_collection.shape

((39083, 30, 501), (39083,))

In [74]:
metadata_collection[:10]

[(1, '034_17_mc_mmn36_wk.cnt', 443),
 (2, '036_17_mc_mmn36_wk.cnt', 1384),
 (18, '175_17_jd_mmn_wk.cnt', 1858),
 (26, '305_17_jc_mmn36_wk.cnt', 2349),
 (27, '306_17_mc_mmn36_wk.cnt', 2813),
 (28, '307_17_jc_mmn36_wakker.cnt', 3447),
 (29, '308_17_jc_mmn36_wk.cnt', 3938),
 (30, '309_17_jc_mmn.cnt', 4425),
 (33, '314_17_mc_mmn36_wk.cnt', 4917),
 (38, '337_17_jc_mmn36_wk.cnt', 5389)]

We hence get a dataset of 39083 datapoints with known label.  
Each datapoint consits of a 1-second EEG signal of 30 channels with a 500Hz sampling rate. Thus arrays with a size of 30 x 501. 

## Labels 
Here we have 6 labels. 
1. Group 1, stimuli 3 --> "3"
2. Group 1, stimuli 13 --> "13"
3. Group 1, stimuli 66 --> "66"
4. Group 2, stimuli 3 --> "6"
5. Group 2, stimuli 13 --> "26"
6. Group 2, stimuli 66 --> "132"

In [41]:
label_collection[1500:2000].astype(int)

array([  6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,
         6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   6,   

# Save entire processed dataset:

In [75]:
filename = PATH_OUTPUT + "EEG_data_30channels_1s_corrected.npy"
np.save(filename, signal_collection)

filename = PATH_OUTPUT + "EEG_data_30channels_1s_corrected_labels.npy"
np.save(filename, label_collection)

import csv
filename = PATH_OUTPUT + "EEG_data_30channels_1s_corrected_metadata.csv"

with open(filename, 'w') as csvFile:
    writer = csv.writer(csvFile)
    writer.writerows(metadata_collection)
csvFile.close()