# SPI Classification using GAN


In [1]:
%reload_ext autoreload
%autoreload 2

In [2]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['KERAS_BACKEND'] = 'torch'

import keras
import numpy as np
import pandas as pd
from sklearn.preprocessing import OneHotEncoder, LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from src.multimodal.utils import ProgressBarCallback

In [3]:
# Load the SPI table and reshape it into one wide matrix per SPI.

data = pd.read_csv('data/Julia2018/spis_dosenbach2010_network_fast.csv')

# Flag every SPI whose rows contain at least one missing value ...
na_mask = (
    data.groupby('spi')
    .apply(lambda x: x.isna().sum().sum(), include_groups=False)
    .gt(0)
)
na_features = na_mask.index[na_mask.to_numpy()]
# ... and drop those SPIs entirely.
data = data.query('spi not in @na_features')

# Keep only the top feature set (selected using svm).
data = data.query('spi=="phase_multitaper_max_fs-1_fmin-0_fmax-0-5"')

# One row per (subject, label); columns become (spi, <value column>) pairs,
# sorted so that the columns of each SPI are contiguous.
data = (
    data.pivot(index=['subject', 'label'], columns=['spi'])
    .reorder_levels([1, 0], axis=1)
    .sort_index(axis=1)
)

# Map each SPI name to its own 2-D value array (rows follow data's index).
grouped_data = {}
for name, group in tqdm(data.T.groupby('spi')):
    # Transpose back to rows-as-samples and strip the 'spi' column level.
    group = group.T.droplevel(0, axis=1)
    grouped_data[name] = group.values

  0%|          | 0/1 [00:00<?, ?it/s]

In [4]:
# Build the feature matrix X and the encoded label column y.

X = data.to_numpy()
labels = data.index.get_level_values('label').to_numpy()

# One-hot alternative, kept for reference:
# y_encoder = OneHotEncoder()
# y = y_encoder.fit_transform(labels.reshape(-1,1)).toarray()

# Integer-encode the labels and store them as a single column, shape (n, 1).
y_encoder = LabelEncoder()
y = y_encoder.fit_transform(labels)[:, np.newaxis]

In [5]:
# Helper function to split grouped data

def split_grouped_data(grouped_X, y, test_size=0.25, random_state=None):
    """Split every SPI table and the labels with one shared, stratified split.

    A single train/test partition of the row positions is drawn (stratified
    on the flattened labels) and then applied to each table in ``grouped_X``
    and to ``y``, so all outputs stay row-aligned.

    Args:
        grouped_X (dict): Dictionary of SPI tables, one for each SPI. Each
            table is indexed along its first axis with the split positions.
        y (np.ndarray): Labels (AVGP or NVGP); ``y.shape[0]`` gives the
            number of samples.
        test_size (float): Fraction of data to reserve for test set.
        random_state (int): Random seed. None for no seed.

    Returns:
        tuple: ``(X_train, X_test, y_train, y_test)`` where the two ``X``
        entries are dictionaries with the same keys as ``grouped_X``.
    """
    sample_positions = np.arange(y.shape[0])
    strata = y.reshape(-1)

    train_idx, test_idx = train_test_split(
        sample_positions,
        test_size=test_size,
        stratify=strata,
        random_state=random_state,
    )

    # Apply the same row selection to every SPI table.
    X_train, X_test = {}, {}
    for spi_name, table in grouped_X.items():
        X_train[spi_name] = table[train_idx]
        X_test[spi_name] = table[test_idx]

    return X_train, X_test, y[train_idx], y[test_idx]

# DEBUG
# X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)

In [24]:
# Helper function to build models

def build_model(feature_names, input_dim=(15,), latent_dim=15,
                learning_rate=0.0003):
    """Build and compile the project GAN.

    Args:
        feature_names: Names of the SPI features. Not used by the current
            implementation; kept so existing callers keep working.
        input_dim (tuple): Passed to ``GAN`` as ``input_dim``.
        latent_dim (int): Passed to ``GAN`` as ``latent_dim``.
        learning_rate (float): Learning rate of the Adam optimizers created
            for the discriminator and for the generator (one each).

    Returns:
        The compiled ``GAN`` instance.
    """
    # Import stays function-local, as in the original notebook cell.
    from src.multimodal.models.gan import GAN

    gan = GAN(input_dim=input_dim, latent_dim=latent_dim)
    gan.compile(
        # Separate optimizer instances so the two networks do not share state.
        d_optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        g_optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
        loss_fn=keras.losses.BinaryCrossentropy(from_logits=True),
    )

    return gan

# DEBUG
# model = build_model(grouped_data.keys())

In [26]:

# Train the GAN for n_runs independent train/test splits.
scores = []
multi_input = True
n_runs = 1
n_epochs = 10

# Key of the SPI table (in the split dictionaries) that is fed to the model.
spi_name = 'phase_multitaper_max_fs-1_fmin-0_fmax-0-5'

# CSVLogger opens this file directly, so make sure its directory exists
# before training starts instead of failing on a fresh checkout.
log_path = 'tmp/keras_logs/spi_gan_v1_train.csv'
os.makedirs(os.path.dirname(log_path), exist_ok=True)

# A single progress bar shared by all runs (handed to ProgressBarCallback).
reusable_pbar = tqdm(
    total=n_epochs, unit='epoch',
    dynamic_ncols=True, leave=False)

for run in range(1, n_runs+1):

    # Fresh split and fresh model for every run.
    X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)
    model = build_model(grouped_data.keys(), X.shape[1:])

    history = model.fit(
        X_train[spi_name], y_train,
        epochs=n_epochs,
        verbose=0, # type: ignore
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(
                # Keep patience at >= 1 even when n_epochs < 10.
                monitor='d_loss', mode='min', patience=max(1, n_epochs//10)),
            # The first run starts a new log; later runs append to it so
            # earlier runs are not overwritten.
            keras.callbacks.CSVLogger(log_path, append=run > 1),
            ProgressBarCallback(
                n_epochs=n_epochs,
                n_runs=n_runs,
                run_index=run, reusable_pbar=reusable_pbar),
        ]
    )

    # Evaluation is currently disabled:
    # score = model.evaluate(
    #     X_test[spi_name], y_test,
    #     return_dict=True,
    #     verbose=0 # type: ignore
    # )

    # print(f'run {run:03}/{n_runs:03} - accuracy: {score["accuracy"]:.3f}')

    # scores.append(score['accuracy'])

# print('mean accuracy:', np.mean(scores))

  0%|          | 0/10 [00:00<?, ?epoch/s]