In [1]:
%matplotlib inline

In [2]:
%%capture
!pip install mne

In [31]:
import os
import pandas as pd
import numpy as np
from scipy import signal
from sklearn.model_selection import train_test_split
from sklearn.linear_model import RidgeClassifierCV, Ridge
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from mne.decoding import Vectorizer

In [14]:
data = pd.read_csv('./data/epochs.csv', index_col=[0])

In [15]:
data.head()

Unnamed: 0,time,condition,epoch,Fp1,Fp2,F5,AFz,F6,T7,Cz,T8,P7,P3,Pz,P4,P8,O1,Oz,O2,STI 014
0,-0.097656,Target,0,-1.965512,43.550689,-18.469894,7.615624,13.062358,2.950307,5.19308,38.892224,2.482709,2.366828,-10.425305,6.175924,28.6126,15.837989,-0.00855,-5.403414,0.0
1,-0.087891,Target,0,-8.878524,-0.753702,-39.762111,5.581617,-17.066895,-9.037337,-23.830406,19.622402,-41.573694,-27.536086,-45.680985,-20.230455,-17.798841,-18.669673,-28.362084,-27.623759,0.0
2,-0.078125,Target,0,-7.9233,22.35738,-14.613151,1.043238,3.730958,-22.740073,-0.526012,12.032473,3.816798,-0.776339,-13.769359,-15.783225,-12.258964,10.163884,-5.59576,-5.650848,0.0
3,-0.068359,Target,0,10.888668,2.224371,-22.415468,22.70799,2.656191,-8.252214,-3.393788,24.320502,-23.151776,-10.029675,-28.945462,-0.673688,13.192939,1.618473,-8.307013,-7.111478,0.0
4,-0.058594,Target,0,-9.551013,11.443248,-17.984111,-4.119327,5.810256,-0.089011,-6.248183,33.287151,-9.872043,-13.848481,-19.698242,-8.886211,-3.499024,-4.661148,-11.851558,-11.887896,0.0


In [46]:
def split_by_channels(data):
    columns = data.columns
    result = []

    for col in columns:
        result.append(data[col].to_list())

    return np.array(result, np.float)

def to_record(data, epoch_idx):
    epoch_data = data[(data['epoch'] == i) & (data['STI 014'] <= 0)]
    exclude_cols = ~data.columns.isin(['time', 'condition', 'epoch', 'STI 014'])

    label = 0. if epoch_data['condition'].iloc[-1] == 'NonTarget' else 1.
    values = split_by_channels(epoch_data.loc[:, exclude_cols])

    return (values, label)

In [54]:
dataset_X = None
dataset_y = None

if os.path.isfile('./data/dataset.npz'):
    with np.load('./data/dataset.npz') as file:
        dataset_X, dataset_y = file['dataset_X'], file['dataset_y']
else:
    for i in data['epoch'].unique():
        X, y = to_record(data, i)
        if dataset_X is None:
            dataset_X = np.array([X[:, :99]], np.float)
        else:
            dataset_X = np.append(dataset_X, [X[:, :99]], axis=0)

        if dataset_y is None:
            dataset_y = np.array([y], np.float)
        else:
            dataset_y = np.append(dataset_y, [y], axis=0)
            
    np.savez_compressed('./data/dataset.npz', dataset_X=dataset_X, dataset_y=dataset_y)

In [55]:
train_X, test_X, train_y, test_y = train_test_split(dataset_X, dataset_y)

In [56]:
from sklearn.base import BaseEstimator, TransformerMixin

class Transformer(BaseEstimator, TransformerMixin):
    '''
    Base class for transformers providing dummy implementation
        of the methods expected by sklearn
    '''
    def fit(self, x, y=None):
        return self

class ButterFilter(Transformer):
    '''Applies Scipy's Butterworth filter'''
    def __init__(self, sampling_rate: int, order: int, highpass: int, lowpass: int) -> None:
        self.sampling_rate = sampling_rate
        self.order = order
        self.highpass = highpass
        self.lowpass = lowpass

        normal_cutoff = [a / (0.5 * self.sampling_rate) for a in (self.highpass, self.lowpass)]
        self.filter = signal.butter(self.order, normal_cutoff, btype='bandpass')

    def transform(self, x):
        out = np.empty_like(x)
        out[:] = [signal.filtfilt(*self.filter, item) for item in x]
        return out

class ChannellwiseScaler(Transformer):
    '''Performs channelwise scaling according to given scaler
    '''
    def __init__(self, scaler: Transformer):
        '''Args:
            scaler: instance of one of sklearn.preprocessing classes
                StandardScaler or MinMaxScaler or analogue
        '''
        self.scaler = scaler

    def fit(self, x: np.ndarray, y=None):
        '''
        Args:
            x: array of eegs, that is every element of x is (n_channels, n_ticks)
                x shaped (n_eegs) of 2d array or (n_eegs, n_channels, n_ticks)
        '''
        for signals in x:
            self.scaler.partial_fit(signals.T)
        return self

    def transform(self, x):
        '''Scales each channel

        Wors either with one record, 2-dim input, (n_channels, n_samples)
            or many records 3-dim, (n_records, n_channels, n_samples)
        Returns the same format as input
        '''
        scaled = np.empty_like(x)
        for i, signals in enumerate(x):
            # double T for scaling each channel separately
            scaled[i] = self.scaler.transform(signals.T).T
        return scaled

In [60]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import RidgeClassifierCV
from sklearn.neural_network import MLPClassifier
from sklearn.svm import LinearSVC, SVC

preproc = make_pipeline(
    ButterFilter(512 // 5, 4, 0.5, 20), 
    ChannellwiseScaler(StandardScaler()),
    Vectorizer(),
)

preproc.fit(train_X)

# cls = MLPClassifier(hidden_layer_sizes=(1000, 500), learning_rate='adaptive', 
#                    activation='relu')
cls = LinearDiscriminantAnalysis(solver='eigen', shrinkage='auto')
cls.fit(preproc.transform(train_X), train_y)

preds = cls.predict(preproc.transform(test_X))
print(classification_report(test_y, preds))

              precision    recall  f1-score   support

         0.0       0.92      0.96      0.94      3345
         1.0       0.71      0.52      0.60       613

    accuracy                           0.89      3958
   macro avg       0.81      0.74      0.77      3958
weighted avg       0.88      0.89      0.89      3958

