In [6]:
import numpy as np
import pandas as pd
import glob
import matplotlib.pyplot as plt
from scipy import signal

# Pre-Processing

things to do: ICA filtering to remove artifacts, butterworth filter, epoching ...

In [17]:
train_labels = pd.read_csv('Data/TrainLabels.csv')
submission = pd.read_csv('Data/SampleSubmission.csv')

There are 60 feedbacks for each session, AKA, 12 5 letter words. Each feedback/letter was either a right or wrong prediction from the user. Using the EEG data, we must train a model on the tendencies within the EEG data itself, whenever a feedback was presented.

In [18]:
train_labels.head(60)

Unnamed: 0,IdFeedBack,Prediction
0,S02_Sess01_FB001,1
1,S02_Sess01_FB002,1
2,S02_Sess01_FB003,0
3,S02_Sess01_FB004,0
4,S02_Sess01_FB005,1
5,S02_Sess01_FB006,0
6,S02_Sess01_FB007,1
7,S02_Sess01_FB008,1
8,S02_Sess01_FB009,0
9,S02_Sess01_FB010,1


Collecting all the names of the training files, and then running a loop through each file, it is imported as a DataFrame, and then turned into an array, where it is appended to the training/test set.

In [19]:
train_files = glob.glob('Data/train/Data*.csv')
test_files = glob.glob('Data/test/Data*.csv')
train_files[0:6]

['Data/train\\Data_S02_Sess01.csv',
 'Data/train\\Data_S02_Sess02.csv',
 'Data/train\\Data_S02_Sess03.csv',
 'Data/train\\Data_S02_Sess04.csv',
 'Data/train\\Data_S02_Sess05.csv',
 'Data/train\\Data_S06_Sess01.csv']

In [20]:
'''
Ingest Data by looping through files

Take only rows where FeedBackEvent == 1 (when the prediction was good) 
and then drop the column

Append values to list of arrays called temp

Input: files -> array of string of file names (Data_S*_Sess*.csv)
Output: temp -> final array of all appended values
'''
def extract_d(files):
    for i, f in enumerate(files):
        print(i,f)
        df = pd.read_csv(f)
        df = df[df.FeedBackEvent == 1]
        df = df.drop('FeedBackEvent', axis = 1)
        if i == 0:
            temp = np.array(df)
        else:
            temp = np.vstack((temp,np.array(df)))
            
    return temp

In [21]:
train = extract_d(train_files)

0 Data/train\Data_S02_Sess01.csv
1 Data/train\Data_S02_Sess02.csv
2 Data/train\Data_S02_Sess03.csv
3 Data/train\Data_S02_Sess04.csv
4 Data/train\Data_S02_Sess05.csv
5 Data/train\Data_S06_Sess01.csv
6 Data/train\Data_S06_Sess02.csv
7 Data/train\Data_S06_Sess03.csv
8 Data/train\Data_S06_Sess04.csv
9 Data/train\Data_S06_Sess05.csv
10 Data/train\Data_S07_Sess01.csv
11 Data/train\Data_S07_Sess02.csv
12 Data/train\Data_S07_Sess03.csv
13 Data/train\Data_S07_Sess04.csv
14 Data/train\Data_S07_Sess05.csv
15 Data/train\Data_S11_Sess01.csv
16 Data/train\Data_S11_Sess02.csv
17 Data/train\Data_S11_Sess03.csv
18 Data/train\Data_S11_Sess04.csv
19 Data/train\Data_S11_Sess05.csv
20 Data/train\Data_S12_Sess01.csv
21 Data/train\Data_S12_Sess02.csv
22 Data/train\Data_S12_Sess03.csv
23 Data/train\Data_S12_Sess04.csv
24 Data/train\Data_S12_Sess05.csv
25 Data/train\Data_S13_Sess01.csv
26 Data/train\Data_S13_Sess02.csv
27 Data/train\Data_S13_Sess03.csv
28 Data/train\Data_S13_Sess04.csv
29 Data/train\Data_S13_S

In [22]:
test = extract_d(test_files)

0 Data/test\Data_S01_Sess01.csv
1 Data/test\Data_S01_Sess02.csv
2 Data/test\Data_S01_Sess03.csv
3 Data/test\Data_S01_Sess04.csv
4 Data/test\Data_S01_Sess05.csv
5 Data/test\Data_S03_Sess01.csv
6 Data/test\Data_S03_Sess02.csv
7 Data/test\Data_S03_Sess03.csv
8 Data/test\Data_S03_Sess04.csv
9 Data/test\Data_S03_Sess05.csv
10 Data/test\Data_S04_Sess01.csv
11 Data/test\Data_S04_Sess02.csv
12 Data/test\Data_S04_Sess03.csv
13 Data/test\Data_S04_Sess04.csv
14 Data/test\Data_S04_Sess05.csv
15 Data/test\Data_S05_Sess01.csv
16 Data/test\Data_S05_Sess02.csv
17 Data/test\Data_S05_Sess03.csv
18 Data/test\Data_S05_Sess04.csv
19 Data/test\Data_S05_Sess05.csv
20 Data/test\Data_S08_Sess01.csv
21 Data/test\Data_S08_Sess02.csv
22 Data/test\Data_S08_Sess03.csv
23 Data/test\Data_S08_Sess04.csv
24 Data/test\Data_S08_Sess05.csv
25 Data/test\Data_S09_Sess01.csv
26 Data/test\Data_S09_Sess02.csv
27 Data/test\Data_S09_Sess03.csv
28 Data/test\Data_S09_Sess04.csv
29 Data/test\Data_S09_Sess05.csv
30 Data/test\Data_S1

Output new arrays into 2 csvs for model use

In [14]:
cols = ['Time', 'Fp1', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
       'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz',
       'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4',
       'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6',
       'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7',
       'POz', 'P08', 'O1', 'O2', 'EOG']
test_df = pd.DataFrame(test, columns = cols)
test_df.to_csv('Data/test_df.csv')

In [78]:
cols = ['Time', 'Fp1', 'Fp2', 'AF7', 'AF3', 'AF4', 'AF8', 'F7', 'F5', 'F3',
       'F1', 'Fz', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCz',
       'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4',
       'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6',
       'TP8', 'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8', 'PO7',
       'POz', 'P08', 'O1', 'O2', 'EOG']
train_df = pd.DataFrame(train, columns = cols)
train_df.to_csv('Data/train_df.csv')

Butter worth filter:

In [7]:
train_df = pd.read_csv('Data/train_df.csv')
test_df = pd.read_csv('Data/test_df.csv')

In [56]:
def butter_filter(order, low_pass, high_pass, fs,sig):
    nyq = 0.5 * fs
    lp = low_pass / nyq
    hp = high_pass / nyq
    sos = signal.butter(order, [lp, hp], btype='band', output = 'sos')
    return signal.sosfilt(sos, sig)

In [81]:
sig = np.array(train)
EEG = sig[:,1:-1]

In [82]:
order = 5
low_pass = 1
high_pass = 40
fs = 200
filtered = butter_filter(order, low_pass, high_pass, fs, EEG)

In [85]:
df_filt = pd.DataFrame(filtered, columns = cols[1:-1])
df_filt.insert(0, column='time', value=train_df.Time)
df_filt

Unnamed: 0,time,Fp1,Fp2,AF7,AF3,AF4,AF8,F7,F5,F3,...,Pz,P2,P4,P6,P8,PO7,POz,P08,O1,O2
0,47.995,16.487111,117.488482,372.896142,694.320804,835.750085,672.111795,378.058242,228.048886,287.564569,...,-198.893643,-202.710706,-264.157470,-339.862939,-310.514648,-156.122245,-12.295022,-18.975215,-138.891111,-184.346622
1,55.860,14.715102,104.814552,332.742163,620.071351,746.410189,596.847682,328.229563,194.882281,259.484352,...,-172.363169,-173.867309,-230.553577,-303.867944,-277.924090,-129.524208,9.347657,0.015911,-118.988280,-160.782316
2,63.775,13.014403,93.276124,297.953448,558.442526,675.602346,542.939624,301.513433,182.826152,242.795575,...,-143.722077,-146.733458,-204.771072,-283.375411,-266.425135,-124.018349,16.325505,12.320011,-106.059587,-155.168439
3,71.660,10.033497,72.690745,234.415974,442.171602,534.219745,420.066311,215.485711,116.839199,172.515980,...,-108.135784,-110.376451,-156.960953,-225.261445,-210.407706,-80.420813,48.488254,40.063051,-79.272553,-131.109167
4,79.525,7.790420,56.344143,181.032040,339.131153,403.534712,303.558271,132.469266,52.597404,109.497253,...,-74.090527,-81.104778,-119.137547,-170.601771,-153.434742,-44.462744,61.959545,54.152540,-47.842579,-93.013904
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5435,1273.420,2.236742,15.581834,47.579870,85.496155,108.412137,123.506338,147.928145,151.858856,82.655340,...,-2.646996,-22.720768,1.577076,59.074610,130.838368,153.305606,44.190093,-139.317802,-222.401015,-171.634660
5436,1288.220,2.947350,21.760664,71.513376,139.777541,189.582760,212.668793,232.250839,230.275420,149.801027,...,-37.546229,-53.638088,-31.372996,20.625414,90.348569,115.564093,7.321757,-177.791397,-254.558835,-185.791997
5437,1299.085,2.205803,16.727750,55.111263,105.771874,140.824811,163.179314,195.356141,208.508287,142.170422,...,-44.633218,-61.041456,-29.650096,34.046097,107.397276,125.667478,8.094177,-181.013528,-259.137556,-196.467291
5438,1311.585,3.859401,27.755727,88.678967,168.196828,220.536943,236.912768,245.813011,233.945385,150.180431,...,-51.624861,-69.109814,-44.341437,12.374947,85.710754,110.535242,-2.538594,-192.574998,-269.167971,-196.841206


In [97]:
train_df

Unnamed: 0,Time,Fp1,Fp2,AF7,AF3,AF4,AF8,F7,F5,F3,...,P2,P4,P6,P8,PO7,POz,P08,O1,O2,EOG
0,47.995,830.677222,979.638619,847.257758,766.929505,555.929311,853.074414,777.926970,910.416082,798.451960,...,527.221414,568.057948,817.546838,741.133120,742.770663,313.393700,932.304475,750.347476,969.756009,-1591.606547
1,55.860,741.397327,872.008242,766.300141,677.442596,457.186287,782.537999,700.256845,824.570032,711.395305,...,447.781496,493.392074,757.499946,663.731070,643.358617,280.655005,835.169209,682.361834,985.912887,-1632.253751
2,63.775,655.710247,800.222757,690.103381,608.341656,428.382266,705.809794,657.269119,744.304954,623.569055,...,384.128875,437.289179,695.405881,609.067129,588.200731,240.277625,735.432159,637.662106,939.623087,-1588.697668
3,71.660,505.521992,656.193543,535.766742,460.644899,295.091007,548.810110,491.984303,573.692467,457.116312,...,261.897538,307.653602,551.175548,481.569049,443.286527,93.176044,597.832214,487.044101,765.703260,-1471.171030
4,79.525,392.508071,504.661776,395.453906,339.234001,157.797753,389.749396,359.081097,444.971754,326.184913,...,151.220874,177.583071,372.191496,339.157296,304.684089,-32.193317,449.920405,315.785541,620.761488,-1695.429472
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
5435,1273.420,112.694747,114.898153,70.390233,193.279030,266.860094,256.430933,17.197693,13.012652,-28.589442,...,266.366283,597.599634,536.529318,65.652461,49.511812,410.058346,172.048085,-71.402527,140.903768,175.836648
5436,1288.220,148.497599,213.298034,168.922279,265.369510,422.072884,386.100337,122.566070,88.679542,110.368632,...,379.944861,714.940603,674.233784,152.643107,156.223042,548.254459,308.678100,90.574961,266.753191,59.219730
5437,1299.085,111.135944,181.903146,74.119436,243.381662,359.825901,324.794443,100.102038,98.790124,72.437542,...,355.662238,693.343212,632.322554,123.951267,121.201240,506.900448,252.061627,12.931682,225.064720,65.476238
5438,1311.585,194.449881,242.082911,192.424477,313.985728,442.439491,388.237099,149.393032,129.524614,130.964152,...,401.690767,740.662045,697.448886,160.463993,164.674692,570.302978,321.015431,90.127412,276.589415,-4.818828
