# MultiSPI DL Classification

This notebooks implements a simple deep learning classifier that receives all the SPIs of a subject and predict their label (AVGP or NVGP).

In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['KERAS_BACKEND'] = 'torch'

import torch
import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split

In [None]:
# load data

data = pd.read_csv('data/Julia2018/spis_dosenbach2010_network_fast.csv')
na_mask = data.groupby('spi').apply(lambda x: x.isna().sum().sum(), include_groups=False) > 0
na_spis = na_mask[na_mask].index
data = data.query('spi not in @na_spis')
data = data.pivot(index=['subject', 'label'], columns=['spi'])
data.columns = ['_'.join(col) for col in data.columns.values]

In [None]:

X = data.values
labels = data.index.get_level_values('label').values

# y_encoder = OneHotEncoder()
# y = y_encoder.fit_transform(labels).toarray()

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(labels).reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42, stratify=y)

In [50]:


# build model
model = keras.Sequential(
    [
        keras.layers.Input(shape=X.shape[1:]),
        keras.layers.Dense(128, activation='relu'),
        keras.layers.Dense(2, activation='softmax'),
    ]
)

model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(),
    optimizer=keras.optimizers.Adam(learning_rate=1e-3),
    metrics=[
        keras.metrics.Accuracy(name='acc'),
    ],
)

model.summary()

# evaluate model

callbacks = [
    keras.callbacks.ModelCheckpoint(filepath="tmp/checkpoints/spi_v1_at_epoch_{epoch}.keras"),
    # keras.callbacks.EarlyStopping(monitor="loss", patience=10)
]

model.fit(
    X_train, y_train,
    batch_size=8,
    epochs=1000,
    verbose=0,
    callbacks=callbacks,
)
score = model.evaluate(X_test, y_test, verbose=0)


Epoch 1/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 40ms/step - acc: 0.1042 - loss: 5.6612    
Epoch 2/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 29ms/step - acc: 0.2969 - loss: 6.5480
Epoch 3/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 26ms/step - acc: 0.2552 - loss: 8.5627 
Epoch 4/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - acc: 0.3177 - loss: 9.0664
Epoch 5/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/step - acc: 0.4193 - loss: 7.0517
Epoch 6/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - acc: 0.4271 - loss: 7.0517
Epoch 7/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 19ms/step - acc: 0.4453 - loss: 8.3109
Epoch 8/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 22ms/step - acc: 0.4714 - loss: 7.5554
Epoch 9/1000
[1m3/3[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 20ms/ste

KeyboardInterrupt: 