# SPI Classification

This notebook implements a simple deep learning classifier that receives all the pairwise interactions (SPIs) of a subject and predicts the subject's label (AVGP or NVGP).

Two models are implemented:
- single-SPI receives one SPI and predicts the label
- multi-SPI receives all the SPIs at the same time and predicts the label


In [44]:
%reload_ext autoreload
%autoreload 2

In [45]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['KERAS_BACKEND'] = 'torch'

import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from src.multimodal.utils import ProgressBarCallback

In [46]:
# step1: load and prep data

data = pd.read_csv('data/Julia2018/spis_dosenbach2010_network_nilearn.csv')

# Drop every SPI that has at least one missing value anywhere in its rows.
na_mask = data.groupby('spi').apply(lambda x: x.isna().sum().sum(), include_groups=False) > 0
na_features = na_mask[na_mask].index
data = data.query('spi not in @na_features')  # remove missing values
# TOP feature set using svm
data = data.query('spi=="partial_correlation"')

# Reshape to one row per (subject, label) and one column per (spi, process).
data = data.melt(id_vars=['subject', 'label', 'spi'], var_name='process', value_name='value')
data = data.pivot(index=['subject', 'label'], columns=['spi', 'process'])
data.columns = data.columns.droplevel(0)  # drop the constant 'value' level
# sort_index returns a new frame; the result must be assigned or the call is a no-op.
data = data.sort_index(axis=1, level=['spi', 'process'])

# convert data to a dict of per-SPI arrays (rows: subjects, columns: processes)
grouped_data = {}
for name, group in tqdm(data.T.groupby('spi')):
    group = group.T
    group = group.droplevel(0, axis=1)  # drop the 'spi' level, keep 'process'
    grouped_data[name] = group.values


  0%|          | 0/1 [00:00<?, ?it/s]

In [47]:
# Prepare labels (y)

# The class of each row lives in the 'label' level of the row index.
labels = data.index.get_level_values('label').values

# Integer-encode the labels and keep them as an (n, 1) column vector.
# (A one-hot encoding through OneHotEncoder would be the alternative target
# format; the integer form is what is used here.)
y_encoder = LabelEncoder()
y_encoder.fit(labels)
y = y_encoder.transform(labels).reshape(-1, 1)

In [48]:
# Helper function to split grouped data

def split_grouped_data(grouped_X, y, test_size=0.25, random_state=None):
    """Apply one stratified train/test split to every SPI table and to the labels.

    A single set of row indices is drawn and reused for all tables, so the
    same rows end up in the train (or test) part of every SPI.

    Args:
        grouped_X (dict): Mapping from SPI name to its table; each table is
            indexed with an integer index array along its first axis.
        y (np.ndarray): Labels (AVGP or NVGP); flattened for stratification.
        test_size (float): Fraction of data to reserve for test set.
        random_state (int): Random seed. None for no seed.

    Returns:
        tuple: ``(X_train, X_test, y_train, y_test)`` where the two ``X``
        entries are dicts with the same keys as ``grouped_X``.
    """
    row_ids = np.arange(y.shape[0])
    flat_labels = y.reshape(-1)

    train_rows, test_rows = train_test_split(
        row_ids,
        test_size=test_size,
        stratify=flat_labels,
        random_state=random_state,
    )

    X_train, X_test = {}, {}
    for spi_name, table in grouped_X.items():
        X_train[spi_name] = table[train_rows]
        X_test[spi_name] = table[test_rows]

    return X_train, X_test, y[train_rows], y[test_rows]

# DEBUG
# X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)

In [99]:
# Helper function to build single-input and multi-inputs models
def build_multi_input_model(feature_names, feature_shape):
    """Build and compile a model with one input and one dense head per feature.

    Each named input is passed through its own stack of Dense layers
    (32 -> 16 -> 8 with ReLU, then a 2-unit softmax). The per-head softmax
    outputs are averaged element-wise to form the model output.

    Args:
        feature_names: Iterable of names; each becomes the name of an Input
            layer and the prefix of its head (``'<name>_head'``).
        feature_shape: Shape passed to every Input layer (without batch axis).

    Returns:
        A compiled ``keras.Model`` using SGD, sparse categorical
        cross-entropy loss and a sparse categorical accuracy metric named
        ``'accuracy'``.
    """
    hidden_units = (32, 16, 8)

    def _make_head(name):
        # Independent stack per feature: hidden ReLU layers + softmax output.
        stack = [keras.layers.Dense(units, activation='relu') for units in hidden_units]
        stack.append(keras.layers.Dense(2, activation='softmax'))
        return keras.Sequential(stack, name=f'{name}_head')

    inputs, heads = [], []
    for name in feature_names:
        spi_input = keras.layers.Input(shape=feature_shape, name=name)
        inputs.append(spi_input)
        heads.append(_make_head(name)(spi_input))

    # Combine the heads by averaging their class probabilities.
    averaged = keras.layers.average(heads)

    model = keras.Model(inputs=inputs, outputs=averaged)
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(),
        optimizer='sgd',
        metrics=[keras.metrics.SparseCategoricalAccuracy(name='accuracy')],
    )
    return model

# DEBUG
# model = build_multi_input_model(list(grouped_data.keys()), (15,))
# model.summary()

In [100]:
# evaluate
#
# Repeats a random stratified split / train / test cycle `n_runs` times with a
# freshly built model each time and reports the mean and std of test accuracy.

scores = []
normalize = True
n_runs = 10
n_epochs = 10000

# CSVLogger opens its file directly and does not create missing directories,
# so make sure the log folder exists before training starts.
log_path = 'tmp/keras_logs/spi_v1_train.csv'
os.makedirs(os.path.dirname(log_path), exist_ok=True)

reusable_pbar = tqdm(
    total=n_epochs, unit='epoch',
    dynamic_ncols=True, leave=False)

feature_names = list(grouped_data)
# Per-sample shape of a table (everything after the leading row axis).
feature_shape = grouped_data[feature_names[0]].shape[1:]

print({
    'normalize': normalize,
    'n_inputs': len(feature_names),
    'feature_shape': feature_shape,
})

try:
    for run in range(1, n_runs+1):

        # No random_state is passed, so every run draws a different split.
        X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)
        model = build_multi_input_model(feature_names, feature_shape)

        # normalize data

        if normalize:
            scaler = MinMaxScaler(feature_range=(-1, 1))
            for k in X_train:
                # The tables are transposed before scaling, so MinMaxScaler's
                # per-column scaling acts on each row of the original table
                # separately. Re-fitting on the test table therefore uses no
                # statistics shared between rows.
                X_train[k] = scaler.fit_transform(X_train[k].T).T
                X_test[k] = scaler.fit_transform(X_test[k].T).T

        history = model.fit(
            X_train, y_train,
            epochs=n_epochs,
            verbose=0, # type: ignore
            shuffle=True,
            callbacks=[
                # Stop once the training loss has not improved for 100 epochs.
                keras.callbacks.EarlyStopping(monitor='loss', patience=100),
                # NOTE: CSVLogger defaults to append=False, so each run
                # overwrites the log written by the previous run.
                keras.callbacks.CSVLogger(log_path),
                ProgressBarCallback(
                    n_epochs=n_epochs,
                    n_runs=n_runs,
                    run_index=run, reusable_pbar=reusable_pbar),
            ]
        )
        metrics = model.evaluate(
            X_test, y_test,
            return_dict=True,
            verbose=0 # type: ignore
        )

        print(f'run {run:03}/{n_runs:03} - accuracy: {metrics["accuracy"]:.3f}')

        scores.append(metrics['accuracy'])
finally:
    # Release the shared progress bar even if a run fails; tqdm's close()
    # is safe to call on a bar that is already closed.
    reusable_pbar.close()

print('mean accuracy: {:.3f}±{:.3f}'.format(np.mean(scores), np.std(scores)))

  0%|          | 0/10000 [00:00<?, ?epoch/s]

{'normalize': True, 'n_inputs': 1, 'feature_shape': (15,)}
run 001/010 - accuracy: 0.625
run 002/010 - accuracy: 0.625
run 003/010 - accuracy: 0.625
run 004/010 - accuracy: 0.375
run 005/010 - accuracy: 0.500
run 006/010 - accuracy: 0.375
run 007/010 - accuracy: 0.250
run 008/010 - accuracy: 0.625
run 009/010 - accuracy: 0.375
run 010/010 - accuracy: 0.625
mean accuracy: 0.500±0.137
