In [1]:
import numpy as np
from matplotlib import pyplot as plt
import mne
mne.set_log_level(verbose=None)
import tensorflow as tf
import os
import re
from tqdm import tqdm

from sklearn.model_selection import StratifiedShuffleSplit

from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Activation, Permute, Dropout
from tensorflow.keras.layers import Conv2D, MaxPooling2D, AveragePooling2D
from tensorflow.keras.layers import SeparableConv2D, DepthwiseConv2D
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import SpatialDropout2D
from tensorflow.keras.regularizers import l1_l2
from tensorflow.keras.layers import Input, Flatten
from tensorflow.keras.constraints import max_norm
from tensorflow.keras import backend as K
from tensorflow.keras.losses import BinaryCrossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.metrics import TopKCategoricalAccuracy, AUC, Precision, Recall
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

In [2]:
N_FILES = -1

N_CLASSES = 2
N_CHANNELS = 25
N_SAMPLES = 128

WHERE = "kaggle"
if WHERE=="colab":
    BCICIV_PATH = "/content/BCICIV_2a_gdf"
    BATCH_SIZE = 32
elif WHERE=="kaggle":
    BCICIV_PATH = "../input/eeg-data-augmentation/BCICIV_2a_gdf"
    BATCH_SIZE = 64
elif WHERE=="home":
    PATH_LABELS = "../toy_data/BCICIV_2a_gdf"
    BATCH_SIZE = 4
SHUFFLE_BUFFER_SIZE = BATCH_SIZE * 2 

FIRST_SPLIT = 0.12
SECOND_SPLIT = 0.15

EPOCHS = 5
LEARNING_RATE = 1e-3

In [4]:
data, user_idx, labels = load_data()
X_train, y_train, X_valid, y_valid, X_test, y_test = preprocess_data(data, user_idx, labels)
train_dataset, valid_dataset, test_dataset = make_tf_dataset(X_train, y_train, X_valid, y_valid, X_test, y_test)

Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A01T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...


  next(self.gen)


Reading 0 ... 672527  =      0.000 ...  2690.108 secs...
EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A04E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 660046  =      0.000 ...  2640.184 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A08T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 675269  =      0.000 ...  2701.076 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A03E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 648774  =      0.000 ...  2595.096 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A09E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 675097  =      0.000 ...  2700.388 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A06E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 666372  =      0.000 ...  2665.488 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A06T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 678979  =      0.000 ...  2715.916 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A02T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 677168  =      0.000 ...  2708.672 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A05E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 679862  =      0.000 ...  2719.448 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A03T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 660529  =      0.000 ...  2642.116 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A07E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 673134  =      0.000 ...  2692.536 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A09T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 673327  =      0.000 ...  2693.308 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A08E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 687791  =      0.000 ...  2751.164 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A07T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 681070  =      0.000 ...  2724.280 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A05T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 686119  =      0.000 ...  2744.476 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A02E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 662665  =      0.000 ...  2650.660 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A01E.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 686999  =      0.000 ...  2747.996 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.
Extracting EDF parameters from /kaggle/input/eeg-data-augmentation/BCICIV_2a_gdf/A04T.gdf...
GDF file detected
Setting channel info structure...
Could not determine channel type of the following channels, they will be set as EEG:
EEG-Fz, EEG, EEG, EEG, EEG, EEG, EEG, EEG-C3, EEG, EEG-Cz, EEG, EEG-C4, EEG, EEG, EEG, EEG, EEG, EEG, EEG, EEG-Pz, EEG, EEG, EOG-left, EOG-central, EOG-right
Creating raw.info structure...
Reading 0 ... 600914  =      0.000 ...  2403.656 secs...


  next(self.gen)


EEG channel type selected for re-referencing
Applying average reference.
Applying a custom ('EEG',) reference.


100%|██████████| 17/17 [00:10<00:00,  1.64it/s]


--> This dataset contains 94107 examples, 25 channels and size of sample is 128
--> 25 channels: 22 EEG and 3 EOG
--> Shape of the users and labels arrays: (12045696,) (12045696,)
--> Final shape of data: (94107, 128, 25)


2022-11-20 19:37:04.169496: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-20 19:37:04.259992: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-20 19:37:04.260784: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:937] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2022-11-20 19:37:04.263776: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 AVX512F FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compil

(70391, 25, 128) (12423, 25, 128) (11293, 25, 128) (70391, 2) (12423, 2) (11293, 2)
(array(['1', '2', '3', '4', '5', '6', '7', '8', '9'], dtype='<U1'), array([1016832, 1002112,  979328,  943104, 1021568, 1006336, 1012864,
       1019392, 1008512])) (array(['1', '2', '3', '4', '5', '6', '7', '8', '9'], dtype='<U1'), array([179456, 176896, 172800, 166400, 180352, 177536, 178688, 179968,
       178048])) (array(['1', '2', '3', '4', '5', '6', '7', '8', '9'], dtype='<U1'), array([163200, 160768, 157056, 151296, 163968, 161408, 162432, 163584,
       161792]))


2022-11-20 19:37:08.385152: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 1802009600 exceeds 10% of free system memory.
2022-11-20 19:37:10.370632: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 1802009600 exceeds 10% of free system memory.
2022-11-20 19:37:11.940641: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 318028800 exceeds 10% of free system memory.
2022-11-20 19:37:12.290445: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 318028800 exceeds 10% of free system memory.
2022-11-20 19:37:12.660245: W tensorflow/core/framework/cpu_allocator_impl.cc:80] Allocation of 289100800 exceeds 10% of free system memory.


In [12]:
print(np.unique(y_train, return_counts=True))
print(np.unique(y_valid, return_counts=True))
print(np.unique(y_test, return_counts=True))

(array([0., 1.], dtype=float32), array([70391, 70391]))
(array([0., 1.], dtype=float32), array([12423, 12423]))
(array([0., 1.], dtype=float32), array([11293, 11293]))


In [5]:
def get_callbacks():
    return [
        # ModelCheckpoint(
        #     "best_model.h5", save_best_only=True, monitor="loss"
        # ),
        # ReduceLROnPlateau(
        #     monitor="val_top_k_categorical_accuracy",
        #     factor=0.2,
        #     patience=2,
        #     min_lr=0.000001,
        # ),
        EarlyStopping(
            monitor="val_loss",
            min_delta=1e-3,
            patience=10,
            verbose=1,
            mode="auto",
            baseline=None,
            restore_best_weights=True,
        ),
    ]

# https://arxiv.org/pdf/1611.08024.pdf
def EEGNet(n_classes, n_channels=64, n_samples=128, kernel_length=64, n_filters1=8, 
           n_filters2=16, depth_multiplier=2, norm_rate=0.25, dropout_rate=0.5, 
           dropoutType="Dropout"):

    if dropoutType == "SpatialDropout2D":
        dropoutType=SpatialDropout2D
    elif dropoutType == "Dropout":
        dropoutType=Dropout

    inputs = Input(shape=(n_channels, n_samples, 1))

    block1 = Conv2D(n_filters1, (1, kernel_length), padding="same", input_shape=(n_channels, n_samples, 1), use_bias=False)(inputs)
    block1 = BatchNormalization()(block1)
    block1 = DepthwiseConv2D((n_channels, 1), use_bias=False, depth_multiplier=depth_multiplier, depthwise_constraint=max_norm(1.0))(block1)
    block1 = BatchNormalization()(block1)
    block1 = Activation("elu")(block1)
    block1 = AveragePooling2D((1, 4))(block1)
    block1 = dropoutType(dropout_rate)(block1)

    block2 = SeparableConv2D(n_filters2, (1, 16), use_bias=False, padding="same")(block1)
    block2 = BatchNormalization()(block2)
    block2 = Activation("elu")(block2)
    block2 = AveragePooling2D((1, 8))(block2)
    block2 = dropoutType(dropout_rate)(block2)
    block2 = Flatten(name="flatten")(block2)

    classifier = Dense(n_classes, name="dense", kernel_constraint=max_norm(norm_rate))(block2)
    classifier = Activation("softmax", name="softmax")(classifier)

    model = Model(inputs=inputs, outputs=classifier) 

    optimizer = Adam(amsgrad=True, learning_rate=LEARNING_RATE)
    loss = BinaryCrossentropy(from_logits=True)

    model.compile(
        optimizer=optimizer,
        loss=loss,
        metrics=[
            TopKCategoricalAccuracy(k=3),
            AUC(),
            Precision(),
            Recall(),
        ],
    )

    model.summary()

    return model

In [6]:
def plot_history_metrics(history):
    total_plots = len(history.history)
    cols = total_plots // 2

    rows = total_plots // cols

    if total_plots % cols != 0:
        rows += 1

    pos = range(1, total_plots + 1)
    plt.figure(figsize=(15, 10))
    for i, (key, value) in enumerate(history.history.items()):
        plt.subplot(rows, cols, pos[i])
        plt.plot(range(len(value)), value)
        plt.title(str(key))
    plt.show()

In [7]:
model = EEGNet(n_classes=N_CLASSES, n_channels=N_CHANNELS, n_samples=N_SAMPLES)

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 25, 128, 1)]      0         
_________________________________________________________________
conv2d (Conv2D)              (None, 25, 128, 8)        512       
_________________________________________________________________
batch_normalization (BatchNo (None, 25, 128, 8)        32        
_________________________________________________________________
depthwise_conv2d (DepthwiseC (None, 1, 128, 16)        400       
_________________________________________________________________
batch_normalization_1 (Batch (None, 1, 128, 16)        64        
_________________________________________________________________
activation (Activation)      (None, 1, 128, 16)        0         
_________________________________________________________________
average_pooling2d (AveragePo (None, 1, 32, 16)         0     

In [8]:
history = model.fit(
    train_dataset,
    epochs=EPOCHS,
    callbacks=get_callbacks(),
    validation_data=valid_dataset,
#     class_weight=???,
)

Epoch 1/5


  '"`binary_crossentropy` received `from_logits=True`, but the `output`'
2022-11-20 19:37:17.809711: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)
2022-11-20 19:37:18.930192: I tensorflow/stream_executor/cuda/cuda_dnn.cc:369] Loaded cuDNN version 8005


Epoch 2/5
Epoch 3/5
Epoch 4/5

KeyboardInterrupt: 

In [None]:
plot_history_metrics(history)