# Train EpiDeNet (single channel)
In this script, the single-channel version of EpiDeNet is trained.

We first define the different models to train and test.

In [28]:
import tensorflow as tf

INPUT_WINDOW_SIZE = 1016
INPUT_NB_CHANNELS = 1

# Original one-channel epidenet
epidenet_1ch_og = tf.keras.Sequential()
epidenet_1ch_og.add(tf.keras.layers.InputLayer((INPUT_WINDOW_SIZE, INPUT_NB_CHANNELS, 1)))
epidenet_1ch_og.add(tf.keras.layers.Conv2D(4, kernel_size=(4, 1), padding="same"))
epidenet_1ch_og.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_og.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_og.add(tf.keras.layers.MaxPooling2D(pool_size=(8, 1)))
epidenet_1ch_og.add(tf.keras.layers.Conv2D(16, kernel_size=(16, 1), padding="same"))
epidenet_1ch_og.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_og.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_og.add(tf.keras.layers.MaxPooling2D(pool_size=(4, 1)))
epidenet_1ch_og.add(tf.keras.layers.Conv2D(16, kernel_size=(8, 1), padding="same"))
epidenet_1ch_og.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_og.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_og.add(tf.keras.layers.AveragePooling2D(pool_size=(8, 1)))
epidenet_1ch_og.add(tf.keras.layers.Flatten())
epidenet_1ch_og.add(tf.keras.layers.Dense(1, activation="sigmoid"))
# epidenet_1ch_og.summary()

# variant 3.2
epidenet_1ch_v3_2 = tf.keras.Sequential()
epidenet_1ch_v3_2.add(tf.keras.layers.InputLayer((INPUT_WINDOW_SIZE, INPUT_NB_CHANNELS, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.ZeroPadding2D(padding=(2, 0)))
epidenet_1ch_v3_2.add(tf.keras.layers.Conv2D(4, kernel_size=(4, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_v3_2.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_v3_2.add(tf.keras.layers.MaxPooling2D(pool_size=(8, 1)))

epidenet_1ch_v3_2.add(tf.keras.layers.ZeroPadding2D(padding=(2, 0)))
epidenet_1ch_v3_2.add(tf.keras.layers.Conv2D(16, kernel_size=(5, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_v3_2.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_v3_2.add(tf.keras.layers.ZeroPadding2D(padding=(2, 0)))
epidenet_1ch_v3_2.add(tf.keras.layers.Conv2D(16, kernel_size=(5, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_v3_2.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_v3_2.add(tf.keras.layers.ZeroPadding2D(padding=(2, 0)))
epidenet_1ch_v3_2.add(tf.keras.layers.Conv2D(16, kernel_size=(5, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_v3_2.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_v3_2.add(tf.keras.layers.ZeroPadding2D(padding=(2, 0)))
epidenet_1ch_v3_2.add(tf.keras.layers.Conv2D(16, kernel_size=(4, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_v3_2.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_v3_2.add(tf.keras.layers.MaxPooling2D(pool_size=(4, 1)))

epidenet_1ch_v3_2.add(tf.keras.layers.ZeroPadding2D(padding=(2, 0)))
epidenet_1ch_v3_2.add(tf.keras.layers.Conv2D(16, kernel_size=(8, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.BatchNormalization())
epidenet_1ch_v3_2.add(tf.keras.layers.Activation("relu"))
epidenet_1ch_v3_2.add(tf.keras.layers.MaxPooling2D(pool_size=(4, 1)))

epidenet_1ch_v3_2.add(tf.keras.layers.AveragePooling2D(pool_size=(7, 1), strides=(1, 1)))
epidenet_1ch_v3_2.add(tf.keras.layers.Flatten())
epidenet_1ch_v3_2.add(tf.keras.layers.Dense(1, activation="sigmoid"))
epidenet_1ch_v3_2.summary()

Model: "sequential_13"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 zero_padding2d_18 (ZeroPad  (None, 1020, 1, 1)        0         
 ding2D)                                                         
                                                                 
 conv2d_45 (Conv2D)          (None, 1017, 1, 4)        20        
                                                                 
 batch_normalization_43 (Ba  (None, 1017, 1, 4)        16        
 tchNormalization)                                               
                                                                 
 activation_43 (Activation)  (None, 1017, 1, 4)        0         
                                                                 
 max_pooling2d_27 (MaxPooli  (None, 127, 1, 4)         0         
 ng2D)                                                           
                                                     

Load train and test datasets

In [29]:
from brainmepnas import Dataset

dataset = Dataset("data/chbmit_singlech")
train_data = dataset.get_data({"5": [1, 2, 3, 4]}, set="train", shuffle=True, shuffle_seed=42)
test_data = dataset.get_data({"5": [0]}, set="test", shuffle=False)

Train models

In [31]:
# Compile the model
monitoring_metrics = [tf.keras.metrics.AUC(num_thresholds=25,
                                           curve='PR',
                                           name="auc_pr")]

epidenet_1ch_og.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1**-4,
                                beta_1=0.9, beta_2=0.999), 
                        loss="binary_crossentropy",
              metrics=monitoring_metrics)
epidenet_1ch_v3_2.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1**-4,
                                beta_1=0.9, beta_2=0.999), 
                          loss="binary_crossentropy",
              metrics=monitoring_metrics)

# Callbacks
callbacks = [tf.keras.callbacks.EarlyStopping(monitor="val_loss",
                                                      patience=10,
                                                      mode="min",
                                                      start_from_epoch=10)]

# Train model
epidenet_1ch_og.fit(train_data[0][:,:1016], train_data[1],
                    validation_split=0.2,
                    epochs=1000, batch_size=256,
                    verbose=1,
                    callbacks=callbacks,
                    use_multiprocessing=False, shuffle=False)
epidenet_1ch_v3_2.fit(train_data[0][:,:1016], train_data[1],
                    validation_split=0.2,
                    epochs=1000, batch_size=256,
                    verbose=1,
                    callbacks=callbacks,
                    use_multiprocessing=False, shuffle=False)

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000


<keras.src.callbacks.History at 0x7fed505dacd0>

Test model on test set



We convert the models to tflite.

In [34]:
# -*- coding: utf-8 -*-

# import built-in module
import tempfile
from typing import Literal, Optional

# import third-party modules
import tensorflow as tf
import numpy as np

# import your own module


def generate_tflite_model(keras_model: tf.keras.Model,
                          input_format: Literal["float", "int8"],
                          output_format: Literal["float", "int8"],
                          representative_input: Optional[np.ndarray] = None):
    """
    Convert the given keras model to a tflite model.

    Parameters
    ----------
    keras_model: tf.keras.Model
        Keras model to convert.
    input_format: Literal["float", "int8"]
        Format of the input to the keras model.
    output_format: Literal["float", "int8"]
        Format of the output of the keras model.
    representative_input: np.ndarray, optional
        Representative input data, used to perform the quantization. If not
        given, quantization is performed using randomly generated data between
        -1 and 1 (float) or -128 and 127 (int8).

    Returns
    -------
    tflite_model: tf.keras.Model
    """
    if input_format == "float":
        input_tensor_type = tf.float32
        representative_dataset_min = -1
        representative_dataset_max = 1
    elif input_format == "int8":
        input_tensor_type = tf.int8
        representative_dataset_min = -128
        representative_dataset_max = 127
    else:
        raise ValueError(f"input_format={input_format} is not supported.")

    if output_format == "float":
        output_tensor_type = tf.float32
    elif output_format == "int8":
        output_tensor_type = tf.int8
    else:
        raise ValueError(f"output_format={output_format} is not supported.")

    keras_model.build()

    if representative_input is None:
        def representative_dataset():
            nb_samples = 100
            min_value = representative_dataset_min
            max_value = representative_dataset_max
            input_shape = keras_model.input_shape[1:]
            for i in range(nb_samples):
                data = (max_value - min_value) * np.random.random_sample(input_shape) + min_value
                data = data.astype(np.float32)
                yield [np.expand_dims(data, axis=0)]
    else:
        def representative_dataset():
            for x in representative_input:
                yield [np.expand_dims(x, axis=[0, 3]).astype(
                np.float32)]

    # Bug: tensorflow 2.16.1
    # converter.convert() raises AttributeError, see https://github.com/tensorflow/tensorflow/issues/63867
    # Fix is to use save model.
    # converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
    with tempfile.TemporaryDirectory() as tmp_dir:
        keras_model.export(tmp_dir)
        converter = tf.lite.TFLiteConverter.from_saved_model(tmp_dir)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]  # Should always be set to [tf.lite.Optimize.DEFAULT]
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]  # Should always be set to [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = input_tensor_type  # Either tf.float32, tf.int8 (recommended), tf.uint8
        converter.inference_output_type = output_tensor_type  # Either tf.float32, tf.int8 (recommended), tf.uint8
        converter.representative_dataset = representative_dataset
        tflite_model = converter.convert()

    return tflite_model

In [61]:
epidenet_1ch_og_tflite = generate_tflite_model(epidenet_1ch_og, input_format="float", output_format="float", representative_input=train_data[0][:,:1016])
with open("epidenet_1ch_og.tflite", "wb") as f:
    f.write(epidenet_1ch_og_tflite)
    
epidenet_1ch_v3_2_tflite = generate_tflite_model(epidenet_1ch_v3_2, input_format="float", output_format="float", representative_input=train_data[0][:,:1016])
with open("epidenet_1ch_v3_2.tflite", "wb") as f:
    f.write(epidenet_1ch_v3_2_tflite)

INFO:tensorflow:Assets written to: /tmp/tmpd7dh4qao/assets


INFO:tensorflow:Assets written to: /tmp/tmpd7dh4qao/assets


Saved artifact at '/tmp/tmpd7dh4qao'. The following endpoints are available:

* Endpoint 'serve'
  Args:
    args_0: float32 Tensor, shape=(None, 1016, 1, 1)
  Returns:
    float32 Tensor, shape=(None, 1)


2024-09-02 17:26:16.969913: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2024-09-02 17:26:16.969959: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2024-09-02 17:26:16.970092: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpd7dh4qao
2024-09-02 17:26:16.970499: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2024-09-02 17:26:16.970505: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /tmp/tmpd7dh4qao
2024-09-02 17:26:16.971603: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2024-09-02 17:26:16.993043: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/tmpd7dh4qao
2024-09-02 17:26:17.001339: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 31246 m

INFO:tensorflow:Assets written to: /tmp/tmpuiltd5da/assets


INFO:tensorflow:Assets written to: /tmp/tmpuiltd5da/assets


Saved artifact at '/tmp/tmpuiltd5da'. The following endpoints are available:

* Endpoint 'serve'
  Args:
    args_0: float32 Tensor, shape=(None, 1016, 1, 1)
  Returns:
    float32 Tensor, shape=(None, 1)


2024-09-02 17:26:18.867921: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:364] Ignored output_format.
2024-09-02 17:26:18.867978: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:367] Ignored drop_control_dependency.
2024-09-02 17:26:18.868116: I tensorflow/cc/saved_model/reader.cc:45] Reading SavedModel from: /tmp/tmpuiltd5da
2024-09-02 17:26:18.868660: I tensorflow/cc/saved_model/reader.cc:91] Reading meta graph with tags { serve }
2024-09-02 17:26:18.868668: I tensorflow/cc/saved_model/reader.cc:132] Reading SavedModel debug info (if present) from: /tmp/tmpuiltd5da
2024-09-02 17:26:18.870370: I tensorflow/cc/saved_model/loader.cc:231] Restoring SavedModel bundle.
2024-09-02 17:26:18.901595: I tensorflow/cc/saved_model/loader.cc:215] Running initialization op on SavedModel bundle at path: /tmp/tmpuiltd5da
2024-09-02 17:26:18.912237: I tensorflow/cc/saved_model/loader.cc:314] SavedModel load for tags { serve }; Status: success: OK. Took 44120 m

Test quantized model

In [62]:
from brainmepnas import AccuracyMetrics
import csv


interpreter_og = tf.lite.Interpreter("epidenet_1ch_og.tflite")
interpreter_og.allocate_tensors()
input_details = interpreter_og.get_input_details()
output_details = interpreter_og.get_output_details()
predicted_og = []
formatted_test_data = np.expand_dims(test_data[0][:,:1016].astype(np.float32), axis=3)
for sample in formatted_test_data:
    interpreter_og.set_tensor(input_details[0]["index"], [sample])
    interpreter_og.invoke()
    predicted_og.append(interpreter_og.get_tensor(output_details[0]["index"]))

am_og = AccuracyMetrics(test_data[1].flatten(), np.array(predicted_og).flatten(), sample_duration=1016/256, sample_offset=2, threshold="max_f_score")

interpreter_v3_2 = tf.lite.Interpreter("epidenet_1ch_v3_2.tflite")
interpreter_v3_2.allocate_tensors()
input_details = interpreter_v3_2.get_input_details()
output_details = interpreter_v3_2.get_output_details()
predicted_v3_2 = []
formatted_test_data = np.expand_dims(test_data[0][:,:1016].astype(np.float32), axis=3)
for sample in formatted_test_data:
    interpreter_v3_2.set_tensor(input_details[0]["index"], [sample])
    interpreter_v3_2.invoke()
    predicted_v3_2.append(interpreter_v3_2.get_tensor(output_details[0]["index"]))

am_v3_2 = AccuracyMetrics(test_data[1].flatten(), np.array(predicted_v3_2).flatten(), sample_duration=1016/256, sample_offset=2, threshold="max_f_score")

with open("epidenet_1ch_og.csv", "w") as f:
    am_dict = am_og.as_dict()
    writer = csv.DictWriter(f, fieldnames=am_dict.keys())
    writer.writeheader()
    writer.writerow(am_dict)
    
with open("epidenet_1ch_v3_2.csv", "w") as f:
    am_dict = am_v3_2.as_dict()
    writer = csv.DictWriter(f, fieldnames=am_dict.keys())
    writer.writeheader()
    writer.writerow(am_dict)