In [4]:
%matplotlib inline

In [9]:
%%capture
!pip install mne
!pip install torch
!pip install pytorch-lightning
!pip install pytorch
!pip install torchmetrics

In [6]:
import pandas as pd
import numpy as np
from scipy import signal
from sklearn.model_selection import train_test_split
from sklearn.linear_model import RidgeClassifierCV, Ridge, LogisticRegressionCV
from sklearn.metrics import classification_report
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from mne.decoding import Vectorizer
import torch.nn as nn
import torch.nn.functional as F
import torch
from torch.utils.data import DataLoader, TensorDataset, Dataset
import pytorch_lightning as pl
import torchmetrics

In [8]:
dataset_X, dataset_y = None, None
    
with np.load('./data/dataset.npz') as file:
    dataset_X, dataset_y = file['dataset_X'], file['dataset_y']

In [10]:
train_X, test_X, train_y, test_y = train_test_split(dataset_X, dataset_y)

In [11]:
from sklearn.base import BaseEstimator, TransformerMixin

class Transformer(BaseEstimator, TransformerMixin):
    '''
    Base class for transformers providing dummy implementation
        of the methods expected by sklearn
    '''
    def fit(self, x, y=None):
        return self

class ButterFilter(Transformer):
    '''Applies Scipy's Butterworth filter'''
    def __init__(self, sampling_rate: int, order: int, highpass: int, lowpass: int) -> None:
        self.sampling_rate = sampling_rate
        self.order = order
        self.highpass = highpass
        self.lowpass = lowpass

        normal_cutoff = [a / (0.5 * self.sampling_rate) for a in (self.highpass, self.lowpass)]
        self.filter = signal.butter(self.order, normal_cutoff, btype='bandpass')

    def transform(self, x):
        out = np.empty_like(x)
        out[:] = [signal.filtfilt(*self.filter, item) for item in x]
        return out

class ChannellwiseScaler(Transformer):
    '''Performs channelwise scaling according to given scaler
    '''
    def __init__(self, scaler: Transformer):
        '''Args:
            scaler: instance of one of sklearn.preprocessing classes
                StandardScaler or MinMaxScaler or analogue
        '''
        self.scaler = scaler

    def fit(self, x: np.ndarray, y=None):
        '''
        Args:
            x: array of eegs, that is every element of x is (n_channels, n_ticks)
                x shaped (n_eegs) of 2d array or (n_eegs, n_channels, n_ticks)
        '''
        for signals in x:
            self.scaler.partial_fit(signals.T)
        return self

    def transform(self, x):
        '''Scales each channel

        Wors either with one record, 2-dim input, (n_channels, n_samples)
            or many records 3-dim, (n_records, n_channels, n_samples)
        Returns the same format as input
        '''
        scaled = np.empty_like(x)
        for i, signals in enumerate(x):
            # double T for scaling each channel separately
            scaled[i] = self.scaler.transform(signals.T).T
        return scaled

In [12]:
from sklearn.discriminant_analysis import LinearDiscriminantAnalysis
from sklearn.linear_model import RidgeClassifierCV
from sklearn.neural_network import MLPClassifier
from sklearn.svm import LinearSVC, SVC

preproc = make_pipeline(
    ButterFilter(512 // 5, 4, 0.5, 20), 
    ChannellwiseScaler(StandardScaler()),
    Vectorizer(),
)

preproc.fit(train_X)

In [13]:
in_features = 16 * 99
input_size = 1200
hidden_size = 800
num_classes = 2

In [14]:
class P300Mlp(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.accuracy = torchmetrics.classification.BinaryAccuracy()
        self.val_accuracy = torchmetrics.classification.BinaryAccuracy()
        self.model = nn.Sequential(
            nn.Linear(in_features, input_size, dtype=torch.double),
            nn.BatchNorm1d(input_size, dtype=torch.double),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(input_size, hidden_size, dtype=torch.double),
            nn.BatchNorm1d(hidden_size, dtype=torch.double),
            nn.Dropout(p=0.3),
            nn.ReLU(),
            nn.Linear(hidden_size, 1, dtype=torch.double),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.model(x)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)

    def training_step(self, train_batch, train_idx):
        X, y = train_batch

        y_pred = self.model(X).view(-1)
        loss = F.binary_cross_entropy(y_pred, y)

        self.accuracy(y_pred, y)
        self.log('train_acc', self.accuracy, prog_bar=True, on_step=True, on_epoch=True)
        self.log('train_loss', loss, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, train_batch, train_idx):
        X, y = train_batch

        y_pred = self.model(X).view(-1)
        loss = F.binary_cross_entropy(y_pred, y)

        self.val_accuracy(y_pred, y)
        self.log('valid_acc', self.val_accuracy, prog_bar=True, on_step=True, on_epoch=True)
        self.log('valid_loss', loss, on_step=True, on_epoch=True)
        return loss

    def predict_step(self, batch, batch_idx):
        return self.model(batch[0])

In [15]:
shape = train_X.shape
valid_shape = test_X.shape

In [17]:
dataloader = DataLoader(
    TensorDataset(
        torch.tensor(preproc.transform(train_X).reshape(shape[0], shape[1] * shape[2])).type(torch.double), 
        torch.tensor(train_y).type(torch.double)
    ),
    batch_size=8,
    num_workers=8
)

valid_dataloader = DataLoader(
    TensorDataset(
        torch.tensor(preproc.transform(test_X).reshape(valid_shape[0], valid_shape[1] * valid_shape[2])).type(torch.double),
        torch.tensor(test_y).type(torch.double)
    ),
    batch_size=8,
    num_workers=8
)

model = P300Mlp()

In [23]:
trainer = pl.Trainer(max_epochs=10)
trainer.fit(model, dataloader, valid_dataloader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name         | Type           | Params
------------------------------------------------
0 | accuracy     | BinaryAccuracy | 0     
1 | val_accuracy | BinaryAccuracy | 0     
2 | model        | Sequential     | 2.9 M 
------------------------------------------------
2.9 M     Trainable params
0         Non-trainable params
2.9 M     Total params
11.470    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [19]:
test_shape = test_X.shape
test_shape

(3958, 16, 99)

In [24]:
test_dataloader = DataLoader(
    TensorDataset(
        torch.tensor(preproc.transform(test_X).reshape(test_shape[0], test_shape[1] * test_shape[2]))
    ),
    batch_size=1
)

model.eval()

with torch.no_grad():
    y_pred = trainer.predict(model, test_dataloader)

  rank_zero_warn(


Predicting: 69it [00:00, ?it/s]

In [25]:
from sklearn.metrics import roc_auc_score, precision_recall_curve

print(roc_auc_score(test_y, torch.tensor(y_pred).flatten()))
precisions, recalls, thresholds = precision_recall_curve(test_y, y_pred)

f1_max = 0

for cur_precision, cur_recall, cur_threshold in zip(precisions, recalls, thresholds):
    f1 = 2 * (cur_precision * cur_recall) / (cur_precision + cur_recall) 
    if f1_max < f1:
        f1_max = f1
        print(f'{f1} {cur_precision} {cur_recall} {cur_threshold}')

0.8587146156091


  y = np.asarray(y)
  y = np.asarray(y)


0.24634470536109881 0.14047498736735725 1.0 tensor([[1.0073e-06]], dtype=torch.float64)
0.24639929093729224 0.14051048774323982 1.0 tensor([[8.7618e-06]], dtype=torch.float64)
0.24645390070921985 0.14054600606673406 1.0 tensor([[9.2038e-06]], dtype=torch.float64)
0.2465085346929727 0.14058154235145384 1.0 tensor([[1.0833e-05]], dtype=torch.float64)
0.24656319290465634 0.1406170966110268 1.0 tensor([[1.2152e-05]], dtype=torch.float64)
0.24661787536039034 0.14065266885909436 1.0 tensor([[1.4650e-05]], dtype=torch.float64)
0.2466725820763088 0.14068825910931174 1.0 tensor([[1.6691e-05]], dtype=torch.float64)
0.24672731306856002 0.140723867375348 1.0 tensor([[1.9257e-05]], dtype=torch.float64)
0.2467820683533067 0.14075949367088608 1.0 tensor([[1.9279e-05]], dtype=torch.float64)
0.24683684794672586 0.14079513800962268 1.0 tensor([[1.9715e-05]], dtype=torch.float64)
0.2468916518650089 0.1408308004052685 1.0 tensor([[2.0875e-05]], dtype=torch.float64)
0.24694648012436154 0.14086648087154802 

0.26072684642438454 0.1499056349420329 1.0 tensor([[0.0008]], dtype=torch.float64)
0.2607879924953096 0.1499460625674218 1.0 tensor([[0.0008]], dtype=torch.float64)
0.2608491672531081 0.14998651200431615 1.0 tensor([[0.0008]], dtype=torch.float64)
0.26091037071797274 0.15002698327037237 1.0 tensor([[0.0008]], dtype=torch.float64)
0.260971602910115 0.15006747638326587 1.0 tensor([[0.0008]], dtype=torch.float64)
0.26103286384976526 0.15010799136069114 1.0 tensor([[0.0008]], dtype=torch.float64)
0.2610941535571731 0.15014852822036187 1.0 tensor([[0.0008]], dtype=torch.float64)
0.2611554720526068 0.1501890869800108 1.0 tensor([[0.0009]], dtype=torch.float64)
0.2612168193563542 0.1502296676573899 1.0 tensor([[0.0009]], dtype=torch.float64)
0.2612781954887218 0.15027027027027026 1.0 tensor([[0.0009]], dtype=torch.float64)
0.26133960047003524 0.15031089483644228 1.0 tensor([[0.0009]], dtype=torch.float64)
0.2614010343206394 0.1503515413737155 1.0 tensor([[0.0009]], dtype=torch.float64)
0.2614

0.2765046587761269 0.16076134699853586 0.987410071942446 tensor([[0.0025]], dtype=torch.float64)
0.27657430730478594 0.1608084358523726 0.987410071942446 tensor([[0.0025]], dtype=torch.float64)
0.2766439909297052 0.1608555523000293 0.987410071942446 tensor([[0.0025]], dtype=torch.float64)
0.2767137096774193 0.16090269636576787 0.987410071942446 tensor([[0.0026]], dtype=torch.float64)
0.2767834635744895 0.16094986807387862 0.987410071942446 tensor([[0.0026]], dtype=torch.float64)
0.2768532526475038 0.16099706744868036 0.987410071942446 tensor([[0.0026]], dtype=torch.float64)
0.2769230769230769 0.16104429451452038 0.987410071942446 tensor([[0.0026]], dtype=torch.float64)
0.27699293642785067 0.16109154929577466 0.987410071942446 tensor([[0.0026]], dtype=torch.float64)
0.2770628311884935 0.16113883181684766 0.987410071942446 tensor([[0.0026]], dtype=torch.float64)
0.2771175726927939 0.16122388937922918 0.9856115107913669 tensor([[0.0026]], dtype=torch.float64)
0.27718765806777945 0.1612713

0.2957900492072171 0.17440361057382334 0.9730215827338129 tensor([[0.0052]], dtype=torch.float64)
0.29587093245829915 0.1744598516607546 0.9730215827338129 tensor([[0.0052]], dtype=torch.float64)
0.29595185995623635 0.17451612903225808 0.9730215827338129 tensor([[0.0052]], dtype=torch.float64)
0.2960328317373461 0.17457244272345918 0.9730215827338129 tensor([[0.0052]], dtype=torch.float64)
0.2961138478379858 0.17462879276952872 0.9730215827338129 tensor([[0.0053]], dtype=torch.float64)
0.2961949082945524 0.1746851792056829 0.9730215827338129 tensor([[0.0053]], dtype=torch.float64)
0.296276013143483 0.17474160206718345 0.9730215827338129 tensor([[0.0053]], dtype=torch.float64)
0.2963571624212545 0.17479806138933765 0.9730215827338129 tensor([[0.0053]], dtype=torch.float64)
0.2964383561643835 0.17485455720749837 0.9730215827338129 tensor([[0.0053]], dtype=torch.float64)
0.29651959440942727 0.17491108955706433 0.9730215827338129 tensor([[0.0053]], dtype=torch.float64)
0.2966008771929825 0

0.32250300842358604 0.1936416184971098 0.9640287769784173 tensor([[0.0096]], dtype=torch.float64)
0.3226000601865784 0.19371160101192628 0.9640287769784173 tensor([[0.0097]], dtype=torch.float64)
0.3226971703792896 0.19378163412870572 0.9640287769784173 tensor([[0.0097]], dtype=torch.float64)
0.3227943390545016 0.1938517179023508 0.9640287769784173 tensor([[0.0097]], dtype=torch.float64)
0.32289156626506027 0.1939218523878437 0.9640287769784173 tensor([[0.0097]], dtype=torch.float64)
0.3229888520638747 0.1939920376402461 0.9640287769784173 tensor([[0.0097]], dtype=torch.float64)
0.323086196503918 0.1940622737146995 0.9640287769784173 tensor([[0.0097]], dtype=torch.float64)
0.3231835996382273 0.1941325606664252 0.9640287769784173 tensor([[0.0098]], dtype=torch.float64)
0.3232810615199035 0.19420289855072465 0.9640287769784173 tensor([[0.0098]], dtype=torch.float64)
0.3233785822021116 0.19427328742297933 0.9640287769784173 tensor([[0.0098]], dtype=torch.float64)
0.32347616173808086 0.194

0.34983277591973244 0.21487263763352507 0.9406474820143885 tensor([[0.0167]], dtype=torch.float64)
0.3499498159919706 0.21496095355528155 0.9406474820143885 tensor([[0.0167]], dtype=torch.float64)
0.3500669344042838 0.21504934210526316 0.9406474820143885 tensor([[0.0167]], dtype=torch.float64)
0.35010060362173046 0.21516900247320692 0.9388489208633094 tensor([[0.0170]], dtype=torch.float64)
0.3502180476350218 0.21525773195876288 0.9388489208633094 tensor([[0.0170]], dtype=torch.float64)
0.35033557046979863 0.21534653465346534 0.9388489208633094 tensor([[0.0170]], dtype=torch.float64)
0.3504531722054381 0.21543541064795707 0.9388489208633094 tensor([[0.0171]], dtype=torch.float64)
0.35057085292142376 0.21552436003303055 0.9388489208633094 tensor([[0.0171]], dtype=torch.float64)
0.35068861269734636 0.21561338289962825 0.9388489208633094 tensor([[0.0171]], dtype=torch.float64)
0.35080645161290325 0.21570247933884298 0.9388489208633094 tensor([[0.0171]], dtype=torch.float64)
0.350924369747

0.3884013735215566 0.24648910411622277 0.9154676258992805 tensor([[0.0309]], dtype=torch.float64)
0.3885496183206107 0.24660852713178294 0.9154676258992805 tensor([[0.0309]], dtype=torch.float64)
0.38869797632684233 0.2467280659234125 0.9154676258992805 tensor([[0.0309]], dtype=torch.float64)
0.38884644766997706 0.24684772065955382 0.9154676258992805 tensor([[0.0309]], dtype=torch.float64)
0.3889950324799389 0.24696749150897623 0.9154676258992805 tensor([[0.0310]], dtype=torch.float64)
0.38914373088685017 0.2470873786407767 0.9154676258992805 tensor([[0.0310]], dtype=torch.float64)
0.38929254302103244 0.24720738222438077 0.9154676258992805 tensor([[0.0310]], dtype=torch.float64)
0.38944146901300686 0.24732750242954324 0.9154676258992805 tensor([[0.0311]], dtype=torch.float64)
0.389590508993494 0.24744773942634904 0.9154676258992805 tensor([[0.0311]], dtype=torch.float64)
0.38973966309341507 0.24756809338521402 0.9154676258992805 tensor([[0.0312]], dtype=torch.float64)
0.389888931443891

0.43138974134151686 0.2852173913043478 0.8848920863309353 tensor([[0.0510]], dtype=torch.float64)
0.431578947368421 0.2853828306264501 0.8848920863309353 tensor([[0.0511]], dtype=torch.float64)
0.43176831943835015 0.28554846198491 0.8848920863309353 tensor([[0.0511]], dtype=torch.float64)
0.4319578577699736 0.2857142857142857 0.8848920863309353 tensor([[0.0511]], dtype=torch.float64)
0.4321475625823452 0.28588030214991283 0.8848920863309353 tensor([[0.0512]], dtype=torch.float64)
0.4323374340949033 0.28604651162790695 0.8848920863309353 tensor([[0.0513]], dtype=torch.float64)
0.43252747252747253 0.2862129144851658 0.8848920863309353 tensor([[0.0513]], dtype=torch.float64)
0.4327176781002639 0.28637951105937137 0.8848920863309353 tensor([[0.0513]], dtype=torch.float64)
0.432908051033876 0.2865463016889924 0.8848920863309353 tensor([[0.0514]], dtype=torch.float64)
0.43309859154929586 0.2867132867132867 0.8848920863309353 tensor([[0.0515]], dtype=torch.float64)
0.43328929986789955 0.28688

0.4936120789779327 0.3644939965694683 0.7643884892086331 tensor([[0.1180]], dtype=torch.float64)
0.4938988959907031 0.3648068669527897 0.7643884892086331 tensor([[0.1181]], dtype=torch.float64)
0.49418604651162795 0.3651202749140893 0.7643884892086331 tensor([[0.1183]], dtype=torch.float64)
0.4944735311227457 0.3654342218400688 0.7643884892086331 tensor([[0.1186]], dtype=torch.float64)
0.49476135040745056 0.3657487091222031 0.7643884892086331 tensor([[0.1189]], dtype=torch.float64)
0.4950495049504951 0.3660637381567614 0.7643884892086331 tensor([[0.1191]], dtype=torch.float64)
0.49532710280373843 0.36678200692041524 0.762589928057554 tensor([[0.1198]], dtype=torch.float64)
0.49561659848042083 0.3670995670995671 0.762589928057554 tensor([[0.1199]], dtype=torch.float64)
0.495906432748538 0.36741767764298094 0.762589928057554 tensor([[0.1199]], dtype=torch.float64)
0.49619660620245754 0.3677363399826539 0.762589928057554 tensor([[0.1202]], dtype=torch.float64)
0.49648711943793916 0.368055

In [26]:
results = []
for i in y_pred:
    results.append(0. if i <= 0.3194 else 1.)

results = np.array(results)

print(classification_report(test_y, results))
results

              precision    recall  f1-score   support

         0.0       0.93      0.92      0.93      3402
         1.0       0.55      0.59      0.57       556

    accuracy                           0.87      3958
   macro avg       0.74      0.75      0.75      3958
weighted avg       0.88      0.87      0.88      3958



array([0., 0., 0., ..., 0., 0., 0.])