# MultiSPI DL Classification

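This notebook implements a simple deep learning classifier that receives the SPIs of a subject and predicts the subject's label (AVGP or NVGP). Note that the data-loading cell currently restricts the input to a single SPI (`phase_multitaper_max_fs-1_fmin-0_fmax-0-5`, the top feature set when using an SVM) and drops any SPI that contains missing values.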

In [1]:
# autoreload: re-import edited modules before running code (only useful once
# local project modules are imported; none are imported in this notebook yet).
%reload_ext autoreload
%autoreload 2

In [2]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['KERAS_BACKEND'] = 'torch'

import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm

In [3]:
# load data
# Source table of precomputed SPIs and the single SPI used as the feature set.
DATA_PATH = 'data/Julia2018/spis_dosenbach2010_network_fast.csv'
SELECTED_SPI = 'phase_multitaper_max_fs-1_fmin-0_fmax-0-5'  # TOP featureset using svm

data = pd.read_csv(DATA_PATH)

# An SPI is flagged when any of its rows contains at least one missing value.
na_mask = data.groupby('spi').apply(lambda x: x.isna().sum().sum(), include_groups=False) > 0
na_features = na_mask[na_mask].index

data = data.query('spi == @SELECTED_SPI')
data = data.query('spi not in @na_features')  # remove features with missing values

# Fail early with a clear message instead of training on an empty matrix.
if data.empty:
    raise ValueError(
        f'No usable rows for SPI {SELECTED_SPI!r}: it is absent from '
        f'{DATA_PATH!r} or was dropped because it contains missing values.')

# One row per (subject, label); one column per (value column, spi) pair,
# flattened into a single string name.
data = data.pivot(index=['subject', 'label'], columns=['spi'])
data.columns = ['_'.join(col) for col in data.columns.values]

In [4]:

# Feature matrix: one row per (subject, label) pair produced by the pivot.
X = data.values
labels = data.index.get_level_values('label').values

# Integer-encode the labels for the sparse categorical loss used by the
# model, as a column vector of shape (n_samples, 1).
y_encoder = LabelEncoder()
y = y_encoder.fit_transform(labels).reshape(-1, 1)

# The classifier ends in a 2-unit softmax, so exactly two classes are expected;
# any other count would only surface later as an obscure loss/indexing error.
assert len(y_encoder.classes_) == 2, (
    f'expected 2 classes, found {list(y_encoder.classes_)}')

In [5]:
class ProgressBar(keras.callbacks.Callback):
    """Keras callback that shows training progress on a tqdm bar.

    One tqdm bar can be shared by many consecutive ``fit`` calls (one per
    evaluation run) by passing it as ``reusable_pbar``; otherwise a new bar
    is created.

    Args:
        n_epochs: Planned number of epochs; used as the bar's ``total``.
        n_runs: Total number of runs, shown in the bar description.
        run_index: Index of the current run, shown in the bar description.
        reusable_pbar: Existing tqdm bar to reuse instead of creating one.
    """

    def __init__(self, n_epochs=None,
                 n_runs=None,
                 run_index=None, reusable_pbar=None):
        # Initialise the base Callback state before Keras attaches the model.
        super().__init__()

        self.n_epochs = n_epochs
        self.pbar = reusable_pbar
        if self.pbar is None:
            self.pbar = tqdm(
                total=n_epochs, unit='epoch',
                dynamic_ncols=True, leave=False)
        else:
            self.pbar.total = n_epochs

        # Formatting None with ``:02`` raises TypeError, so only describe the
        # run with the values that were actually provided.
        if run_index is not None:
            if n_runs is not None:
                self.pbar.set_description(f'run {run_index:02}/{n_runs:02}')
            else:
                self.pbar.set_description(f'run {run_index:02}')

    def on_train_begin(self, logs=None):
        """Reset the bar at the start of each ``fit`` call."""
        self.pbar.reset()

    def on_epoch_end(self, epoch, logs=None):
        """Show the epoch's metrics and advance the bar to ``epoch + 1``."""
        self.pbar.set_postfix(logs)
        # Assumes Keras' 0-based ``epoch`` index: advance by the difference so
        # the counter equals the number of completed epochs.
        self.pbar.update(epoch - self.pbar.n + 1)

    def on_train_end(self, logs=None):
        """Reset the bar so it can be reused by the next run."""
        self.pbar.reset()

# build model
def build_model(input_shape=None, n_classes=2, learning_rate=.01):
    """Build and compile a small fully connected softmax classifier.

    Calling ``build_model()`` with no arguments reproduces the original
    notebook configuration.

    Args:
        input_shape: Shape of one sample, without the batch axis (an int is
            treated as a 1-D shape). Defaults to ``X.shape[1:]`` of the
            notebook-level feature matrix ``X``.
        n_classes: Number of softmax output units. Defaults to 2.
        learning_rate: Learning rate for the Adam optimizer. Defaults to 0.01.

    Returns:
        A compiled ``keras.Sequential`` model using sparse categorical
        cross-entropy, i.e. expecting integer-encoded class labels.
    """
    # Fall back to the notebook's global feature matrix, as before.
    if input_shape is None:
        input_shape = X.shape[1:]
    elif isinstance(input_shape, int):
        input_shape = (input_shape,)
    input_shape = tuple(input_shape)

    model = keras.Sequential(
        [
            keras.layers.Input(shape=input_shape),
            # First hidden layer is twice as wide as the input's last axis.
            keras.layers.Dense(input_shape[-1] * 2, activation='relu'),
            keras.layers.Dense(16, activation='relu'),
            keras.layers.Dense(n_classes, activation='softmax'),
        ]
    )

    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name='accuracy'),
        ],
    )

    return model

# evaluate model
# Repeated random sub-sampling validation: every run draws a new stratified
# 75/25 split, trains a fresh model and records its test accuracy.
SEED = 42  # base seed; run `r` uses SEED + r so each run is reproducible
scores = []
n_runs = 500
n_epochs = 200
log_path = 'tmp/keras_logs/spi_v1_train.csv'

# Make sure the log directory exists so CSVLogger can open the file on a
# fresh checkout.
os.makedirs(os.path.dirname(log_path), exist_ok=True)

reusable_pbar = tqdm(
    total=n_epochs, unit='epoch',
    dynamic_ncols=True, leave=False)

try:
    for run in range(1, n_runs+1):
        seed = SEED + run
        # Seed the RNGs used by Keras (see keras.utils.set_random_seed) and
        # pin the split explicitly as well.
        keras.utils.set_random_seed(seed)
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.25, stratify=y, random_state=seed)

        model = build_model()

        history = model.fit(
            X_train, y_train,
            epochs=n_epochs,
            verbose=0,
            shuffle=True,
            callbacks=[
                keras.callbacks.EarlyStopping(monitor='loss', patience=n_epochs//10),
                # NOTE: this file is rewritten by every run, so only the
                # last run's training history is kept.
                keras.callbacks.CSVLogger(log_path),
                ProgressBar(n_epochs=n_epochs,
                            n_runs=n_runs,
                            run_index=run, reusable_pbar=reusable_pbar),
            ]
        )
        score = model.evaluate(
            X_test, y_test,
            return_dict=True,
            verbose=0)

        print(f'run {run:03}/{n_runs:03} - accuracy: {score["accuracy"]:.3f}')

        scores.append(score['accuracy'])
finally:
    # Release the shared progress bar even if training is interrupted.
    reusable_pbar.close()

print('mean accuracy:', np.mean(scores))
print('std accuracy:', np.std(scores))

  0%|          | 0/200 [00:00<?, ?epoch/s]

run 001/500 - accuracy: 0.625
run 002/500 - accuracy: 0.875
run 003/500 - accuracy: 0.875
run 004/500 - accuracy: 0.750
run 005/500 - accuracy: 0.500
run 006/500 - accuracy: 0.500
run 007/500 - accuracy: 0.875
run 008/500 - accuracy: 0.625
run 009/500 - accuracy: 0.750
run 010/500 - accuracy: 0.750
run 011/500 - accuracy: 0.750
run 012/500 - accuracy: 1.000
run 013/500 - accuracy: 0.750
run 014/500 - accuracy: 0.875
run 015/500 - accuracy: 0.875
run 016/500 - accuracy: 0.500
run 017/500 - accuracy: 0.875
run 018/500 - accuracy: 0.625
run 019/500 - accuracy: 1.000
run 020/500 - accuracy: 0.875
run 021/500 - accuracy: 0.875
run 022/500 - accuracy: 0.875
run 023/500 - accuracy: 0.750
run 024/500 - accuracy: 0.625
run 025/500 - accuracy: 0.500
run 026/500 - accuracy: 0.625
run 027/500 - accuracy: 0.750
run 028/500 - accuracy: 0.750
run 029/500 - accuracy: 0.875
run 030/500 - accuracy: 0.625
run 031/500 - accuracy: 0.625
run 032/500 - accuracy: 0.750
run 033/500 - accuracy: 1.000
run 034/50