# Predicting sex from brain rhythms with deep learning

Dans un premier temps, l'objectif est de reproduire le réseau décrit dans l'article original 

https://www.nature.com/articles/s41598-018-21495-7.pdf

In [91]:
import cv2
import numpy as np
from tensorflow import keras

from tensorflow.keras.models import Sequential
import tensorflow.keras.layers as layers
from kerastuner.tuners import RandomSearch

import sklearn


from core.data import *

### Loading and formatting data 

In [25]:
raw_x_train, raw_y_train = load_x('data/x_train.h5'), load_y('data/y_train.csv')

Started loading file data/x_train.h5
Finished loading the file.
Started loading file data/y_train.csv
Finished loading the file.


In [46]:
# here, flatten x and convert to format (N, 7, 500, 1) and y_train in format (N, 2)

def format_data(x_input):
    x = flatten_x(x_input)
    x = reorder_nhwc(x)

    # here, resize x to format 24, 256, 1

    n, h, w, c = len(x), 24, 256, 1
    x_resized = np.zeros((n, h, w, c))

    # resize each element of x
    for i in range(len(x)):
        x_resized[i, :, :, 0] = cv2.resize(x[i, :, :, 0], dsize=(w, h))

    return x_resized

x_resized = format_data(raw_x_train)
y = np.argmax(flatten_y(raw_y_train, 40), axis=1)

print(x_resized.shape, y.shape)

(37840, 24, 256, 1) (37840,)


### Definition of the model

In [94]:
# need to convert 7 x 500 into 24 x 256 ? linear interpolation + subsampling
# Hint : use openCV resize

def get_model(hp):
    model = Sequential()
    # our model is based on convolutional NN + reLu + average pooling
    # cf cours pour la justification théorique
    model.add(layers.Conv2D(filters=hp.Int('units', min_value=50, max_value=200, step=50), 
                            kernel_size=(3, 3), activation='relu', data_format='channels_last', input_shape=(h, w, c)))
    model.add(layers.AveragePooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(rate=0.25))

    model.add(layers.Conv2D(filters=hp.Int('units', min_value=50, max_value=200, step=50),
                            kernel_size=(3, 3), activation='relu', data_format='channels_last'))
    model.add(layers.AveragePooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(rate=0.25))

    model.add(layers.Conv2D(filters=hp.Int('units', min_value=100, max_value=400, step=50),
                            kernel_size=(3, 3), activation='relu', data_format='channels_last', input_shape=(h, w, c)))
    model.add(layers.AveragePooling2D(pool_size=(2, 2)))
    model.add(layers.Dropout(rate=0.25))

    model.add(layers.Conv2D(filters=hp.Int('units', min_value=100, max_value=400, step=50), kernel_size=(1, 7), activation='relu'))
    model.add(layers.AveragePooling2D(pool_size=(1, 2)))
    model.add(layers.Dropout(rate=0.25))

    model.add(layers.Conv2D(filters=hp.Int('units', min_value=50, max_value=200, step=50), kernel_size=(1, 3), activation='relu'))
    model.add(layers.Conv2D(filters=hp.Int('units', min_value=50, max_value=200, step=50), kernel_size=(1, 3), activation='relu'))

    # 2 units for binary_crossentropy
    model.add(layers.Flatten())
    # model.add(layers.Dense(units=2, activation='softmax'))
    model.add(layers.Dense(units=1, activation='sigmoid'))
    
    model.compile(loss='mean_squared_error', # can be binary_crossentropy
              optimizer='sgd',
              metrics=['accuracy'])
    return model

In [99]:
tuner = RandomSearch(
    get_model,
    objective='val_accuracy',
    max_trials=5,
    executions_per_trial=3,
    directory='my_dir',
    project_name='dreemmva')

In [97]:
class_weight = sklearn.utils.class_weight.compute_class_weight('balanced', np.unique(y), y)
print('Class weight is ', class_weight)

Class weight is  [0.64179104 2.26315789]


In [None]:
# model.fit(x_resized, y, epochs=10, batch_size=32, class_weight=class_weight)

tuner.search(x_resized, y, epochs=5)

Train on 37840 samples
Epoch 1/5


 5664/37840 [===>..........................] - ETA: 26:10 - loss: 0.3115 - accuracy: 0.343 - ETA: 22:18 - loss: 0.2859 - accuracy: 0.468 - ETA: 19:43 - loss: 0.2820 - accuracy: 0.500 - ETA: 19:22 - loss: 0.2879 - accuracy: 0.484 - ETA: 19:00 - loss: 0.2744 - accuracy: 0.525 - ETA: 18:57 - loss: 0.2723 - accuracy: 0.552 - ETA: 18:47 - loss: 0.2595 - accuracy: 0.584 - ETA: 18:29 - loss: 0.2546 - accuracy: 0.597 - ETA: 18:13 - loss: 0.2503 - accuracy: 0.611 - ETA: 18:10 - loss: 0.2446 - accuracy: 0.631 - ETA: 18:11 - loss: 0.2432 - accuracy: 0.642 - ETA: 18:09 - loss: 0.2360 - accuracy: 0.656 - ETA: 18:07 - loss: 0.2280 - accuracy: 0.670 - ETA: 18:03 - loss: 0.2219 - accuracy: 0.680 - ETA: 17:48 - loss: 0.2186 - accuracy: 0.691 - ETA: 17:34 - loss: 0.2210 - accuracy: 0.691 - ETA: 17:16 - loss: 0.2212 - accuracy: 0.694 - ETA: 16:58 - loss: 0.2181 - accuracy: 0.701 - ETA: 16:44 - loss: 0.2172 - accuracy: 0.702 - ETA: 16:28 - loss: 0.2158 - accuracy: 0.707 - ETA: 16:13 - loss: 0.2147 - accur











Epoch 2/5


 5696/37840 [===>..........................] - ETA: 11:33 - loss: 0.1861 - accuracy: 0.812 - ETA: 11:33 - loss: 0.1455 - accuracy: 0.859 - ETA: 11:33 - loss: 0.1613 - accuracy: 0.833 - ETA: 11:34 - loss: 0.1672 - accuracy: 0.820 - ETA: 11:34 - loss: 0.1734 - accuracy: 0.800 - ETA: 11:35 - loss: 0.1976 - accuracy: 0.765 - ETA: 11:36 - loss: 0.1998 - accuracy: 0.754 - ETA: 11:36 - loss: 0.1990 - accuracy: 0.753 - ETA: 11:36 - loss: 0.2001 - accuracy: 0.756 - ETA: 11:34 - loss: 0.2046 - accuracy: 0.753 - ETA: 11:32 - loss: 0.2057 - accuracy: 0.750 - ETA: 11:32 - loss: 0.2032 - accuracy: 0.757 - ETA: 11:31 - loss: 0.2076 - accuracy: 0.750 - ETA: 11:32 - loss: 0.2063 - accuracy: 0.750 - ETA: 11:32 - loss: 0.2055 - accuracy: 0.750 - ETA: 11:32 - loss: 0.2015 - accuracy: 0.755 - ETA: 11:31 - loss: 0.1944 - accuracy: 0.766 - ETA: 11:30 - loss: 0.1921 - accuracy: 0.767 - ETA: 11:30 - loss: 0.1933 - accuracy: 0.766 - ETA: 11:29 - loss: 0.1945 - accuracy: 0.765 - ETA: 11:28 - loss: 0.1946 - accur











Epoch 3/5


 5696/37840 [===>..........................] - ETA: 11:54 - loss: 0.1792 - accuracy: 0.781 - ETA: 11:54 - loss: 0.1751 - accuracy: 0.781 - ETA: 11:51 - loss: 0.1749 - accuracy: 0.781 - ETA: 11:52 - loss: 0.1854 - accuracy: 0.757 - ETA: 11:51 - loss: 0.1753 - accuracy: 0.781 - ETA: 11:53 - loss: 0.1787 - accuracy: 0.776 - ETA: 11:54 - loss: 0.1804 - accuracy: 0.772 - ETA: 11:53 - loss: 0.1734 - accuracy: 0.789 - ETA: 11:53 - loss: 0.1714 - accuracy: 0.791 - ETA: 11:51 - loss: 0.1713 - accuracy: 0.790 - ETA: 11:51 - loss: 0.1722 - accuracy: 0.786 - ETA: 11:51 - loss: 0.1730 - accuracy: 0.786 - ETA: 11:51 - loss: 0.1686 - accuracy: 0.793 - ETA: 11:51 - loss: 0.1708 - accuracy: 0.790 - ETA: 11:49 - loss: 0.1758 - accuracy: 0.783 - ETA: 11:49 - loss: 0.1773 - accuracy: 0.781 - ETA: 11:48 - loss: 0.1787 - accuracy: 0.779 - ETA: 11:48 - loss: 0.1766 - accuracy: 0.783 - ETA: 11:48 - loss: 0.1795 - accuracy: 0.776 - ETA: 11:47 - loss: 0.1788 - accuracy: 0.778 - ETA: 11:47 - loss: 0.1798 - accur











Epoch 4/5


 5696/37840 [===>..........................] - ETA: 13:01 - loss: 0.1770 - accuracy: 0.781 - ETA: 12:23 - loss: 0.1642 - accuracy: 0.796 - ETA: 14:01 - loss: 0.1628 - accuracy: 0.802 - ETA: 15:02 - loss: 0.1511 - accuracy: 0.820 - ETA: 14:30 - loss: 0.1628 - accuracy: 0.793 - ETA: 14:18 - loss: 0.1702 - accuracy: 0.781 - ETA: 13:54 - loss: 0.1627 - accuracy: 0.794 - ETA: 13:47 - loss: 0.1643 - accuracy: 0.793 - ETA: 13:49 - loss: 0.1709 - accuracy: 0.781 - ETA: 13:48 - loss: 0.1747 - accuracy: 0.775 - ETA: 13:37 - loss: 0.1745 - accuracy: 0.775 - ETA: 13:36 - loss: 0.1738 - accuracy: 0.778 - ETA: 13:36 - loss: 0.1732 - accuracy: 0.778 - ETA: 13:32 - loss: 0.1727 - accuracy: 0.779 - ETA: 13:28 - loss: 0.1744 - accuracy: 0.775 - ETA: 13:34 - loss: 0.1716 - accuracy: 0.779 - ETA: 13:36 - loss: 0.1664 - accuracy: 0.788 - ETA: 13:42 - loss: 0.1735 - accuracy: 0.776 - ETA: 13:48 - loss: 0.1713 - accuracy: 0.779 - ETA: 13:50 - loss: 0.1722 - accuracy: 0.778 - ETA: 13:46 - loss: 0.1749 - accur











Epoch 5/5


 5696/37840 [===>..........................] - ETA: 12:28 - loss: 0.1839 - accuracy: 0.750 - ETA: 13:24 - loss: 0.1741 - accuracy: 0.765 - ETA: 14:36 - loss: 0.1835 - accuracy: 0.750 - ETA: 14:06 - loss: 0.1748 - accuracy: 0.765 - ETA: 13:44 - loss: 0.1729 - accuracy: 0.768 - ETA: 14:00 - loss: 0.1741 - accuracy: 0.765 - ETA: 14:06 - loss: 0.1695 - accuracy: 0.776 - ETA: 14:08 - loss: 0.1687 - accuracy: 0.777 - ETA: 14:07 - loss: 0.1785 - accuracy: 0.760 - ETA: 14:10 - loss: 0.1775 - accuracy: 0.762 - ETA: 14:08 - loss: 0.1750 - accuracy: 0.767 - ETA: 14:05 - loss: 0.1706 - accuracy: 0.776 - ETA: 13:58 - loss: 0.1690 - accuracy: 0.781 - ETA: 13:59 - loss: 0.1712 - accuracy: 0.776 - ETA: 13:52 - loss: 0.1712 - accuracy: 0.777 - ETA: 13:46 - loss: 0.1707 - accuracy: 0.777 - ETA: 13:38 - loss: 0.1685 - accuracy: 0.781 - ETA: 13:30 - loss: 0.1654 - accuracy: 0.786 - ETA: 13:24 - loss: 0.1685 - accuracy: 0.781 - ETA: 13:17 - loss: 0.1696 - accuracy: 0.779 - ETA: 13:14 - loss: 0.1678 - accur











Train on 37840 samples
Epoch 1/5


 5696/37840 [===>..........................] - ETA: 18:59 - loss: 0.1346 - accuracy: 0.906 - ETA: 15:22 - loss: 0.1471 - accuracy: 0.828 - ETA: 14:17 - loss: 0.1469 - accuracy: 0.843 - ETA: 13:46 - loss: 0.1753 - accuracy: 0.812 - ETA: 13:35 - loss: 0.1760 - accuracy: 0.806 - ETA: 13:29 - loss: 0.2008 - accuracy: 0.770 - ETA: 13:17 - loss: 0.1933 - accuracy: 0.790 - ETA: 13:05 - loss: 0.1847 - accuracy: 0.804 - ETA: 12:56 - loss: 0.1834 - accuracy: 0.798 - ETA: 12:52 - loss: 0.1747 - accuracy: 0.809 - ETA: 12:47 - loss: 0.1769 - accuracy: 0.806 - ETA: 12:49 - loss: 0.1785 - accuracy: 0.804 - ETA: 12:48 - loss: 0.1782 - accuracy: 0.807 - ETA: 13:02 - loss: 0.1781 - accuracy: 0.805 - ETA: 13:21 - loss: 0.1770 - accuracy: 0.804 - ETA: 13:36 - loss: 0.1776 - accuracy: 0.804 - ETA: 13:43 - loss: 0.1819 - accuracy: 0.797 - ETA: 13:51 - loss: 0.1868 - accuracy: 0.789 - ETA: 14:04 - loss: 0.1834 - accuracy: 0.794 - ETA: 14:03 - loss: 0.1836 - accuracy: 0.793 - ETA: 14:05 - loss: 0.1838 - accur

 8288/37840 [=====>........................] - ETA: 12:06 - loss: 0.1908 - accuracy: 0.781 - ETA: 12:09 - loss: 0.1912 - accuracy: 0.780 - ETA: 12:11 - loss: 0.1915 - accuracy: 0.780 - ETA: 12:14 - loss: 0.1913 - accuracy: 0.780 - ETA: 12:16 - loss: 0.1915 - accuracy: 0.780 - ETA: 12:19 - loss: 0.1916 - accuracy: 0.780 - ETA: 12:21 - loss: 0.1917 - accuracy: 0.780 - ETA: 12:22 - loss: 0.1915 - accuracy: 0.780 - ETA: 12:25 - loss: 0.1921 - accuracy: 0.779 - ETA: 12:26 - loss: 0.1923 - accuracy: 0.779 - ETA: 12:28 - loss: 0.1925 - accuracy: 0.778 - ETA: 12:31 - loss: 0.1924 - accuracy: 0.778 - ETA: 12:33 - loss: 0.1927 - accuracy: 0.778 - ETA: 12:36 - loss: 0.1925 - accuracy: 0.778 - ETA: 12:39 - loss: 0.1923 - accuracy: 0.778 - ETA: 12:42 - loss: 0.1920 - accuracy: 0.779 - ETA: 12:45 - loss: 0.1920 - accuracy: 0.779 - ETA: 12:48 - loss: 0.1923 - accuracy: 0.778 - ETA: 12:51 - loss: 0.1920 - accuracy: 0.779 - ETA: 12:54 - loss: 0.1921 - accuracy: 0.779 - ETA: 12:56 - loss: 0.1920 - accur

In [55]:
raw_x_test = load_x('data/x_test.h5')

Started loading file data/x_test.h5
Finished loading the file.


In [74]:
x_test = format_data(raw_x_test)

print(x_test.shape)

raw_predictions = model.predict(x_test)
y_test = average_predictions(raw_predictions, nb_trials = 40)
print(y_test)

(37840, 24, 256, 1)
[0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 0.
 0. 0. 0. 0. 0. 0. 0. 0. 0. 0. 