In [None]:
import h5py
import numpy as np
from sklearn import model_selection
import matplotlib.pyplot as plt
from sklearn import metrics
import os
import tensorflow as tf
from tensorflow.keras import Model, Input
from tensorflow.keras.layers import Conv2D, UpSampling2D, MaxPooling2D, AveragePooling2D, Attention
from tensorflow.keras.layers import ELU, BatchNormalization, Reshape, Concatenate, Dropout, Add, Multiply

from utils import SeizureState, setup_tf, AttentionPooling, BiasedConv

In [None]:
# File locations (placeholders: point them at real files before running).
val_path, saved_predictions, network_path = (
    'PATH_TO_DATASET.h5',          # HDF5 file with 'filenames' and 'signals' to run the detector on
    'PATH_TO_PREDICTIONS.h5',      # HDF5 file the predictions are written to (overwritten)
    'PATH_TO_NETWORK_WEIGHTS.h5',  # trained U-Net weights
)

# Recording / label constants.
n_channels = 18  # size of the channel axis of the U-Net input
fs = 200         # presumably the sampling rate in Hz; not referenced by the cells below
# Presumably the seizure / background event label strings; not referenced by the cells below.
seizure, background = 'seiz', 'bckg'

In [None]:
# Project helper imported from utils; its body is not visible here. Presumably it
# configures TensorFlow runtime options (e.g. device / memory settings) — TODO confirm in utils.
setup_tf()

In [None]:
def _zscore_channels(recording):
    """Turn one stored recording into a float32 (time, channel) array, z-scored per channel.

    NOTE(review): assumes ``np.vstack`` of the stored entry yields (channels, time), so the
    transpose gives the (time, channels) layout the U-Net input expects — confirm against
    whatever wrote the dataset.
    """
    samples = np.asarray(np.vstack(recording).T, dtype=np.float32)
    channel_mean = np.mean(samples, axis=0)
    channel_std = np.std(samples, axis=0)
    # The small epsilon keeps flat (zero-variance) channels from dividing by zero.
    return (samples - channel_mean) / (channel_std + 1e-8)


# Read every recording (and its file name) from the evaluation file into memory.
with h5py.File(val_path, 'r') as f:
    file_names_test = []
    signals_test = []

    names_ds, recordings_ds = f['filenames'], f['signals']

    for idx in range(len(recordings_ds)):
        file_names_test.append(names_ds[idx])
        signals_test.append(_zscore_channels(recordings_ds[idx]))

# Seizure detection

### Building the U-Net

In [None]:
# Base number of convolution filters. The U-Net stages below use 1x, 2x, 4x and 8x this
# value, so it has to match the configuration the loaded weights were trained with.
n_filters = 8

In [None]:
def conv_bn_elu(tensor, filters, kernel_len):
    """Apply one Conv2D -> BatchNormalization -> ELU stage.

    The kernel is ``kernel_len`` samples long in time and 1 wide, so it never mixes
    neighbouring positions along the second (EEG-channel) axis.
    """
    tensor = Conv2D(filters=filters, kernel_size=(kernel_len, 1), strides=(1, 1),
                    padding='same', activation=None)(tensor)
    tensor = BatchNormalization()(tensor)
    return ELU()(tensor)


# NOTE: layers are created in exactly the same order as a hand-unrolled version would,
# which keeps auto-generated layer names and the topology used by HDF5 weight loading.
input_seq = Input(shape=(None, n_channels, 1))

# Encoder: (filter multiplier, temporal kernel length) per level. Every level after the
# first is preceded by a 4x temporal max-pool, so the deepest level is downsampled by
# 4**5 = 1024 in time.
encoder_spec = [(1, 15), (2, 15), (4, 15), (4, 7), (8, 3), (8, 3)]
skips = []
x = input_seq
for depth, (mult, kernel_len) in enumerate(encoder_spec):
    if depth > 0:
        x = MaxPooling2D(pool_size=(4, 1), padding='same')(x)
    x = conv_bn_elu(x, mult * n_filters, kernel_len)
    skips.append(x)
lvl0, lvl1, lvl2, lvl3, lvl4, lvl5 = skips

# Bottleneck: a width-20 pool (stride = pool size, 'same' padding) spans all 18 channels,
# leaving a single position on the channel axis; then two dropout-regularised conv stages.
x = MaxPooling2D(pool_size=(1, 20), padding='same')(lvl5)
for _ in range(2):
    x = conv_bn_elu(x, 4 * n_filters, 3)
    x = Dropout(rate=0.5)(x)

# Decoder: upsample 4x in time, combine with the matching encoder level through the
# project's AttentionPooling layer (from utils), concatenate with the upsampled path and
# refine with a conv stage. Pairs are (encoder skip, temporal kernel length).
for skip, kernel_len in [(lvl4, 3), (lvl3, 7), (lvl2, 15), (lvl1, 15), (lvl0, 15)]:
    upsampled = UpSampling2D(size=(4, 1))(x)
    attended = AttentionPooling(filters=4*n_filters, channels=n_channels)([upsampled, skip])
    x = Concatenate(axis=-1)([upsampled, attended])
    x = conv_bn_elu(x, 4 * n_filters, kernel_len)

# Head: one extra refinement stage, then a single-filter sigmoid conv giving one score
# per time step.
x = conv_bn_elu(x, 4 * n_filters, 15)
output = Conv2D(filters=1, kernel_size=(15, 1), strides=(1, 1), padding='same', activation='sigmoid')(x)

unet = Model(input_seq, output)

In [None]:
# Restore trained parameters into the freshly built graph. The weights file is presumably
# saved from a model with this exact architecture (same n_filters and layer sequence);
# otherwise loading fails or mis-assigns weights.
unet.load_weights(network_path)
unet.summary()

### Prediction step

In [None]:
# Run the U-Net on every recording and collect one sigmoid score per retained time step.
# The encoder max-pools time by 4 five times (4**5 = 1024), so each input is truncated to
# a multiple of that factor; the trailing remainder of each recording gets no prediction.
y_probas = []
reduction = 4 ** 5  # == 4096//4, total temporal downsampling of the U-Net
with tf.device('cpu:0'):  # force inference onto the CPU
    for signal in signals_test:
        usable_len = len(signal)//reduction*reduction
        if usable_len == 0:
            # Shorter than one downsampling window: there is nothing valid to feed the
            # network, so store an empty prediction instead of running a zero-length input.
            y_probas.append(np.zeros(0, dtype=np.float32))
            continue
        batch = signal[np.newaxis, :usable_len, :, np.newaxis]  # (1, time, channels, 1)
        # verbose=0: don't print a progress bar for every single recording.
        prediction = unet.predict(batch, verbose=0)[0, :, 0, 0]
        y_probas.append(prediction)

# Saving predictions

In [None]:
# Variable-length dtypes: each recording's prediction has its own length, and file names
# are strings. string_dtype() is the modern equivalent of special_dtype(vlen=str).
dt_fl = h5py.vlen_dtype(np.dtype('float32'))
dt_str = h5py.string_dtype(encoding='utf-8')

# Check alignment *before* opening the file: mode 'w' truncates any existing output, so a
# mismatch must not leave behind a half-written predictions file.
assert len(y_probas) == len(file_names_test), (
    f'{len(y_probas)} predictions for {len(file_names_test)} file names')

with h5py.File(saved_predictions, 'w') as f:
    dset_signals = f.create_dataset('signals', (len(file_names_test),), dtype=dt_fl)
    dset_file_names = f.create_dataset('filenames', (len(file_names_test),), dtype=dt_str)

    for i, (name, proba) in enumerate(zip(file_names_test, y_probas)):
        dset_signals[i] = proba
        dset_file_names[i] = name