# SPI Classification using Conditional GAN (CGAN)


In [1]:
%reload_ext autoreload
%autoreload 2

In [3]:
import os
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'  # disable GPU for testing
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from keras import ops
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import train_test_split
from tqdm.auto import tqdm
from src.multimodal.utils import ProgressBarCallback

In [4]:
# step1: load and prep data

data = pd.read_csv('data/Julia2018/spis_dosenbach2010_network_nilearn.csv')
na_mask = data.groupby('spi').apply(lambda x: x.isna().sum().sum(), include_groups=False) > 0
na_features = na_mask[na_mask].index
data = data.query('spi not in @na_features')  # remove missing values
# TOP feature set using svm
data = data.query('spi=="partial_correlation"')

data = data.melt(id_vars=['subject', 'label', 'spi'], var_name='process', value_name='value')
data = data.pivot(index=['subject', 'label'], columns=['spi', 'process'])
data.columns = data.columns.droplevel(0)
data.sort_index(axis=1, level=['spi', 'process'])

# convert data to list of grouped data
grouped_data = {}
for name, group in tqdm(data.T.groupby('spi')):
    group = group.T
    group = group.droplevel(0, axis=1)
    grouped_data[name] = group.values


  0%|          | 0/1 [00:00<?, ?it/s]

In [8]:
# Prepare labels (y)

labels = data.index.get_level_values('label').values

y_encoder = LabelEncoder()
y = y_encoder.fit_transform(labels)
y_encoder.classes_

array(['AVGP', 'NVGP'], dtype=object)

In [9]:
# Helper function to split grouped data

def split_grouped_data(grouped_X, y, test_size=0.25, random_state=None):
    """Split SPI tables into train and test splits.

    Args:
        grouped_X (dict): Dictionary of SPI tables, one for each SPI.
        y (np.ndarray): Labels (AVGP or NVGP).
        test_size (float): Fraction of data to reserve for test set.
        random_state (int): Random seed. None for no seed.
    """
    train_idx, test_idx = train_test_split(
        np.arange(y.shape[0]),
        test_size=test_size,
        stratify=y,
        random_state=random_state)
    X_train = {k: v[train_idx] for k, v in grouped_X.items()}
    X_test = {k: v[test_idx] for k, v in grouped_X.items()}
    y_train = y[train_idx]
    y_test = y[test_idx]
    
    return X_train, X_test, y_train, y_test

# DEBUG
# X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)

In [10]:
# Helper function to build models

def build_model(feature_name, feature_dim=15, n_classes=2):

    from src.multimodal.models.conditional_gan import ConditionalGAN
    model = ConditionalGAN(input_dim=feature_dim, latent_dim=15,
                 n_classes=n_classes,
                 name=f'cgan_{feature_name}')
    model.compile(
        loss=keras.losses.CategoricalCrossentropy(from_logits=True),
        d_optimizer=keras.optimizers.Adam(learning_rate=0.0003),
        g_optimizer=keras.optimizers.Adam(learning_rate=0.0003)
    )

    return model

# # DEBUG
# model = build_model(list(grouped_data.keys())[0])
# model.summary()

In [20]:
# train and evaluate model

scores = []

# params
normalize = False
n_runs = 1
n_epochs = 1000
patience = 100
version = 'v1'

reusable_pbar = tqdm(
    total=n_epochs, unit='epoch',
    dynamic_ncols=True, leave=False)

feature_names = list(grouped_data.keys())
feature_name = feature_names[0]  # FIXME run over all features
feature_dim = grouped_data[list(grouped_data.keys())[0]].shape[1]

print({
    'normalize': normalize,
    'n_inputs': len(grouped_data.keys()),
    'feature_dim': feature_dim,
    'version': version,
    'n_classes': y.max() + 1,
})

for run in range(1, n_runs+1):

    X_train, X_test, y_train, y_test = split_grouped_data(grouped_data, y)
    model = build_model(feature_name, feature_dim, n_classes=y.max() + 1)

    history = model.fit(
        X_train[feature_name], y_train,
        epochs=n_epochs,
        verbose=0, # type: ignore
        shuffle=True,
        callbacks=[
            keras.callbacks.EarlyStopping(
                monitor='g_loss', mode='min', patience=patience),
            keras.callbacks.CSVLogger(f'tmp/keras_logs/{model.name}_{version}.csv'),
            ProgressBarCallback(
                n_epochs=n_epochs,
                n_runs=n_runs,
                run_index=run, reusable_pbar=reusable_pbar),
        ]
    )

    metrics = model.evaluate(
        X_test[feature_name], y_test,
        return_dict=True,
        verbose=0 # type: ignore
    )
    print(metrics)

    X_real = X_test[feature_name]
    X_fake = np.random.randn(X_real.shape[0], 15)
    
    y_real = np.ones((X_real.shape[0], 1))
    y_fake = np.zeros((X_fake.shape[0], 1))

    X_eval = ops.concatenate([X_real, X_fake], axis=0)
    y_eval = ops.concatenate([y_real, y_fake], axis=0)

    y_pred = model.predict(X_eval, verbose=0)
    accuracy = ((y_pred > 0.5).astype(int) == y_eval.cpu().numpy()).astype(int).mean()

    print(f'run {run:03}/{n_runs:03} - accuracy: {accuracy:.3f}')

    scores.append(metrics['accuracy'])

print('mean accuracy:', np.mean(scores))

  0%|          | 0/1000 [00:00<?, ?epoch/s]

{'normalize': False, 'n_inputs': 1, 'feature_dim': 15, 'version': 'v1', 'n_classes': 2}
{'accuracy': 0.0, 'd_loss': 0.0, 'g_loss': 0.0}
run 001/001 - accuracy: 0.500
mean accuracy: 0.0
