# SPI Classification

This notebook implements a simple deep learning classifier that receives all the pairwise interactions (SPIs) of a subject and predicts the subject's label (AVGP or NVGP).

Two models are implemented:
- single-SPI receives one SPI and predicts the label
- multi-SPI receives all the SPIs at the same time and predicts the label


In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['KERAS_BACKEND'] = 'torch'

import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from src.multimodal.utils import ProgressBarCallback

In [3]:
# Load the long-format SPI table and reshape it into one wide table whose
# rows are (subject, label) pairs and whose columns are grouped by SPI.

data = pd.read_csv('data/Julia2018/spis_dosenbach2010_network_fast.csv')
# Per SPI, count NaNs across all of its rows/columns; flag SPIs with at least one.
na_mask = data.groupby('spi').apply(lambda x: x.isna().sum().sum(), include_groups=False) > 0
# Names of the SPIs that contain any missing value.
na_features = na_mask[na_mask].index
data = data.query('spi not in @na_features')  # drop every row of an SPI that has missing values
# TOP feature set using svm
# data = data.query('spi=="phase_multitaper_max_fs-1_fmin-0_fmax-0-5"')
# Wide format: index is (subject, label); columns become (value column, spi).
data = data.pivot(index=['subject', 'label'], columns=['spi'])
# Swap the two column levels so 'spi' comes first, then sort the columns.
data = data.reorder_levels([1, 0], axis=1).sort_index(axis=1)

# Build a dict mapping each SPI name to the 2-D array of that SPI's columns
# (one row per (subject, label) index entry).
grouped_data = {}
for name, group in tqdm(data.T.groupby('spi')):
    # Transpose back so rows follow the (subject, label) index again.
    group = group.T
    # Remove the now-constant 'spi' column level.
    group = group.droplevel(0, axis=1)
    grouped_data[name] = group.values

  0%|          | 0/209 [00:00<?, ?it/s]

In [4]:
# Prepare X and y

# X: the full wide table (all SPI columns side by side) as a 2-D array.
X = data.values
# Class labels are stored in the 'label' level of the row index.
labels = data.index.get_level_values('label').values

# Alternative one-hot encoding of the labels (currently disabled):
# y_encoder = OneHotEncoder()
# y = y_encoder.fit_transform(labels.reshape(-1,1)).toarray()

# Encode the labels as integer class ids and reshape them into a column
# vector of shape (n_samples, 1).
y_encoder = LabelEncoder()
y = y_encoder.fit_transform(labels).reshape(-1,1)

In [5]:
# Helper function to split grouped data

def split_grouped_data(grouped_X, y, test_size=0.25, random_state=None):
    """Create one stratified train/test split shared by every SPI table.

    Row indices are split once (stratified on the flattened labels) and the
    same indices are then applied to each table and to ``y``, so all tables
    stay row-aligned with the labels.

    Args:
        grouped_X (dict): Maps an SPI name to its array; rows are indexed in
            the same order as ``y``.
        y (np.ndarray): Labels (AVGP or NVGP), one row per sample.
        test_size (float): Fraction of data to reserve for the test set.
        random_state (int): Random seed. None for no seed.

    Returns:
        tuple: ``(X_train, X_test, y_train, y_test)`` where the ``X_*`` values
        are dicts with the same keys as ``grouped_X``.
    """
    sample_indices = np.arange(y.shape[0])
    train_idx, test_idx = train_test_split(
        sample_indices,
        test_size=test_size,
        stratify=y.reshape(-1),
        random_state=random_state,
    )

    X_train, X_test = {}, {}
    for spi_name, table in grouped_X.items():
        X_train[spi_name] = table[train_idx]
        X_test[spi_name] = table[test_idx]

    return X_train, X_test, y[train_idx], y[test_idx]

# DEBUG
# X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)

In [6]:
# Helper function to build single-input and multi-inputs models

def build_single_input_model(input_shape):
    """Build and compile a small dense classifier for a single input table.

    The network is an 8-unit ReLU hidden layer followed by a 2-unit softmax
    output, trained with sparse categorical cross-entropy and Adam
    (learning rate 0.01).

    Args:
        input_shape (tuple): Shape passed to ``keras.layers.Input``.

    Returns:
        keras.Sequential: The compiled model.
    """
    layer_stack = [
        keras.layers.Input(shape=input_shape),
        # keras.layers.Dense(X.shape[1]*2, activation='relu'),
        keras.layers.Dense(8, activation='relu'),
        keras.layers.Dense(2, activation='softmax'),
    ]
    classifier = keras.Sequential(layer_stack)

    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer=keras.optimizers.Adam(learning_rate=.01),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
    )
    return classifier


def build_multi_input_model(feature_names, feature_shape=(15,)):
    """Build and compile a model with one small classifier head per feature.

    Every feature gets its own named input and its own head (8-unit ReLU
    layer followed by a 2-unit softmax). The per-feature softmax outputs are
    averaged to form the model output.

    Args:
        feature_names (list): Names of the features; each is used as the name
            of an input layer and as the prefix of its ``<name>_head`` model.
        feature_shape (tuple): Shape passed to every ``keras.layers.Input``.

    Returns:
        keras.Model: The compiled multi-input model.
    """

    def _make_head(feature_name):
        # One named input plus its private two-layer classifier.
        head_input = keras.layers.Input(shape=feature_shape, name=feature_name)
        head = keras.Sequential(
            [
                head_input,
                keras.layers.Dense(8, activation='relu'),
                keras.layers.Dense(2, activation='softmax'),
            ],
            name=f'{feature_name}_head',
        )
        return head_input, head(head_input)

    heads = [_make_head(f) for f in feature_names]
    inputs = [head_input for head_input, _ in heads]
    outputs = [head_output for _, head_output in heads]

    # Disabled alternative: concatenate the head outputs and learn the merge.
    # concat = keras.layers.concatenate(outputs)
    # merger = keras.Sequential([
    #     keras.layers.Dense(8, activation='relu'),
    #     keras.layers.Dense(2, activation='softmax')
    # ])(concat)

    # Combine the heads by averaging their class probabilities.
    merger = keras.layers.Average()(outputs)

    multi_model = keras.Model(inputs=inputs, outputs=merger)
    multi_model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer=keras.optimizers.Adam(learning_rate=.01),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
    )
    return multi_model

# DEBUG
# model = build_single_input_model(X.shape[1:])
# model = build_multi_input_model(list(grouped_data.keys()))

In [7]:

# Evaluate the model over repeated random train/test splits.
#
# For each run a fresh stratified split is drawn, a new model is built and
# trained, and its test accuracy is recorded. The mean accuracy over all
# runs is printed at the end.
scores = []
multi_input = True   # True: one input per SPI; False: single flat input X
n_runs = 100
n_epochs = 2000      # upper bound; early stopping may end a run sooner

# CSVLogger writes into this folder; create it up front so a fresh checkout
# does not fail on a missing directory when training starts.
os.makedirs('tmp/keras_logs', exist_ok=True)

# A single progress bar object is handed to every run's callback.
reusable_pbar = tqdm(
    total=n_epochs, unit='epoch',
    dynamic_ncols=True, leave=False)

for run in range(1, n_runs+1):

    if multi_input:
        X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)
        model = build_multi_input_model(list(grouped_data.keys()))
    else:
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, stratify=y)
        model = build_single_input_model(X.shape[1:])

    history = model.fit(
        X_train, y_train,
        epochs=n_epochs,
        verbose=0, # type: ignore
        shuffle=True,
        callbacks=[
            # Stop once the training loss has not improved for n_epochs//10 epochs.
            keras.callbacks.EarlyStopping(monitor='loss', patience=n_epochs//10),
            # NOTE(review): every run logs to the same file; with the default
            # CSVLogger settings earlier runs are presumably overwritten - confirm.
            keras.callbacks.CSVLogger('tmp/keras_logs/spi_v1_train.csv'),
            ProgressBarCallback(
                n_epochs=n_epochs,
                n_runs=n_runs,
                run_index=run, reusable_pbar=reusable_pbar),
        ]
    )
    # Held-out performance for this run, returned as a {metric name: value} dict.
    score = model.evaluate(
        X_test, y_test,
        return_dict=True,
        verbose=0 # type: ignore
    )

    print(f'run {run:03}/{n_runs:03} - accuracy: {score["accuracy"]:.3f}')

    scores.append(score['accuracy'])

print('mean accuracy:', np.mean(scores))

  0%|          | 0/2000 [00:00<?, ?epoch/s]

run 001/100 - accuracy: 0.750
run 002/100 - accuracy: 0.875
run 003/100 - accuracy: 0.625
run 004/100 - accuracy: 0.875
run 005/100 - accuracy: 0.750
run 006/100 - accuracy: 0.625
run 007/100 - accuracy: 0.625
run 008/100 - accuracy: 0.875
run 009/100 - accuracy: 0.750
run 010/100 - accuracy: 0.625
run 011/100 - accuracy: 0.500
run 012/100 - accuracy: 0.750
run 013/100 - accuracy: 0.625
run 014/100 - accuracy: 0.625
run 015/100 - accuracy: 0.750
run 016/100 - accuracy: 0.500
run 017/100 - accuracy: 0.625
run 018/100 - accuracy: 0.500
run 019/100 - accuracy: 0.625
run 020/100 - accuracy: 0.750
run 021/100 - accuracy: 0.625
run 022/100 - accuracy: 0.625
run 023/100 - accuracy: 0.750
run 024/100 - accuracy: 0.875
run 025/100 - accuracy: 0.750
run 026/100 - accuracy: 0.375
run 027/100 - accuracy: 0.625
run 028/100 - accuracy: 0.625
run 029/100 - accuracy: 0.750
run 030/100 - accuracy: 0.750
run 031/100 - accuracy: 0.500
run 032/100 - accuracy: 0.500
run 033/100 - accuracy: 0.750
run 034/10