In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from importlib import reload
import os
import torch
import operator
import copy
from collections import defaultdict
import sys

sys.path.append("..")
from brain_connectivity import (
    dataset,
    gin,
    dense,
    enums,
    training,
    evaluation,
    general_utils,
    data_utils,
)


In [None]:
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(device_name)
device

## Train model

In [None]:
# Search space shared by the graph-based and the dense models.
# Every leaf is a list of candidate values for the parameter grid.
common_hyperparameters = dict(
    # Dataset.
    upsample_ts=[None],
    upsample_ts_method=["iaaft"],
    correlation_type=[enums.CorrelationType.PEARSON],
    batch_size=[2, 4, 8],
    # Model.
    num_hidden_features=[8, 16, 32, 64, 90],
    num_sublayers=[1, 2],
    dropout=[0.0, 0.5],
    # Training.
    optimizer_kwargs=dict(
        lr=[0.01, 0.001, 0.0001],
        weight_decay=[0.0001],
    ),
    epochs=[20],
)

# Graph model, variant 1: graph built with a fixed threshold.
# The shared search space is merged in after the graph-specific keys.
graph_hyperparameters_fixed = {
    "node_features": [enums.NodeFeatures.FC_MATRIX_ROW],
    # How to create FC matrix.
    "geometric_kwargs": dict(
        thresholding_function=[
            enums.ThresholdingFunction.GROUP_AVERAGE,
            enums.ThresholdingFunction.SUBJECT_VALUES,
        ],
        threshold_type=[enums.DataThresholdingType.FIXED_THRESHOLD],
        threshold=[0.01, 0.05, 0.1],
        # FIXME: Cannot name file with `str` of `operator` function due to "<>".
        # thresholding_operator=[operator.ge],
        threshold_by_absolute_value=[True, False],
        return_absolute_value=[False],
    ),
    "eps": [0.0],
    **common_hyperparameters,
}

# Graph model, variant 2: same search space, but with KNN thresholding.
# A deep copy keeps the nested dictionaries independent of the fixed variant.
graph_hyperparameters_knn = copy.deepcopy(graph_hyperparameters_fixed)
graph_hyperparameters_knn["geometric_kwargs"].update(
    threshold_type=[enums.DataThresholdingType.KNN],
    threshold=[5, 10, 20],
)
graph_hyperparameters = [graph_hyperparameters_fixed, graph_hyperparameters_knn]

# Dense model: model-specific keys first, then the shared search space.
dense_hyperparameters = {
    "node_features": [enums.NodeFeatures.FC_MATRIX_ROW],
    "mode": [enums.ConnectivityMode.SINGLE],
    "num_nodes": [90],
    "readout": ["add", "mean", "max"],
    "emb_dropout": [0.0],
    "emb_residual": [None, "add", "mean", "max"],
    "emb_init_weights": ["constant", "normal"],
    "emb_val": [0.0],
    "emb_std": [0.01],
    **common_hyperparameters,
}

In [None]:
# Settings that stay constant across the whole hyperparameter search.
model_params = dict(size_in=90)

dataset_params = dict(
    # Raw data.
    data_folder=os.path.normpath("../data"),
    device=device,
)

training_params = dict(
    # Training regime.
    validation_frequency=1,
    # Optimizer.
    optimizer=torch.optim.Adam,
    # Loss.
    criterion=torch.nn.CrossEntropyLoss(),
    # Scheduler (currently disabled).
    # scheduler=torch.optim.lr_scheduler.StepLR,
    # scheduler_kwargs=dict(step_size=50, gamma=0.5),
    # Plotting.
    fc_matrix_plot_frequency=None,
    fc_matrix_plot_sublayer=0,
)

In [None]:
# Build the parameter grid over both graph search spaces and display its length.
pg = data_utils.DoubleLevelParameterGrid(graph_hyperparameters)
len(pg)

In [None]:
# Subject table: CSV file name and the column holding the classification target.
dataframe_with_subjects = "patients-cleaned.csv"
target_column = "target"

# The first CSV column is used as the DataFrame index.
subjects_csv_path = os.path.join(
    os.path.normpath("../data"), dataframe_with_subjects
)
df = pd.read_csv(subjects_csv_path, index_col=0)

# Target values as an array, in the row order of the CSV.
targets = df[target_column].values


In [None]:
# Re-import the project modules so that edits to their source files are picked
# up without restarting the kernel.
# Note: `enums` and `general_utils` are not reloaded here.
reload(dataset)
reload(gin)
reload(dense)
reload(evaluation)
reload(training)
reload(data_utils)

In [None]:
# Experiment counter: used as the numeric suffix of the run folder name
# (`test_gin_{i}`) and incremented each time a run folder is created.
i = 3

In [None]:
# Create a fresh folder for this experiment.
# The counter `i` can point at a folder left over from an earlier session
# (e.g. a kernel restart resets it), so advance it to the first unused folder
# name instead of failing with FileExistsError. `exist_ok=False` is kept so an
# existing run is never reused or overwritten.
while True:
    experiment_folder = os.path.join(
        os.path.normpath("../runs"), f"test_gin_{i}"
    )
    try:
        os.makedirs(experiment_folder, exist_ok=False)
        break
    except FileExistsError:
        i += 1

# Move the counter past the folder just created so that re-running this cell
# starts a new experiment.
i += 1
general_utils.close_all_loggers()

In [None]:
# Nested cross-validation.
#   * Outer loop ("assess" folds): estimates performance on the test split.
#   * Inner loop ("select" folds): picks the hyperparameter setting with the
#     highest `mean - std` validation accuracy.
# Init cross-validation.
num_assess_folds = 2
num_select_folds = 3
# Number of repeated final trainings used to offset random initialization.
num_test_runs = 3
# Metrics stored for every final training run.
reported_metrics = ("accuracy", "recall", "precision")

cv = data_utils.StratifiedCrossValidation(
    targets=targets,
    num_assess_folds=num_assess_folds,
    num_select_folds=num_select_folds,
)

for outer_id in cv.outer_cross_validation():
    outer_folder = os.path.join(experiment_folder, f"{outer_id:03d}")
    os.makedirs(outer_folder, exist_ok=False)
    logger = general_utils.get_logger(
        "cv", os.path.join(outer_folder, "cv.txt")
    )

    # ------------------------------------------------------------------
    # Model selection.
    # ------------------------------------------------------------------
    # Keep best hyperparameters.
    best_hyperparameters = None
    best_mean_accuracy = 0
    best_std_accuracy = 0

    hyperparameter_grid = data_utils.DoubleLevelParameterGrid(
        graph_hyperparameters
    )
    for hyper_id, hyperparameters in enumerate(hyperparameter_grid):
        # NOTE: debugging limit - only the first two settings are evaluated.
        # Checked before logging so that cv.txt lists evaluated settings only.
        if hyper_id == 2:
            break
        logger.info(f"Hyperparameters: {hyperparameters}")

        log_folder = os.path.join(
            outer_folder,
            f"{hyper_id:04d}_{training.stringify(hyperparameters)}",
        )
        os.makedirs(log_folder, exist_ok=False)

        model, data, trainer = training.init_geometric_traning(
            log_folder,
            device,
            hyperparameters,
            targets,
            model_params=model_params,
            dataset_params=dataset_params,
            training_params=training_params,
        )

        # Run training.
        # NOTE(review): the same `model` object is passed to every inner fold.
        # Confirm that `trainer.train` re-initializes its weights per fold;
        # otherwise later folds start from weights that have already seen
        # their validation subjects.
        train_dataset = "train"
        eval_dataset = "val"
        for inner_id in cv.inner_cross_validation():
            logger.debug(f"Inner fold {inner_id+1} / {num_select_folds}")
            trainer.train(
                model=model,
                named_trainloader=(
                    train_dataset,
                    data.geometric_loader(
                        dataset=train_dataset, indices=cv.train_indices
                    ),
                ),
                named_evalloader=(
                    eval_dataset,
                    data.geometric_loader(
                        dataset=eval_dataset, indices=cv.val_indices
                    ),
                ),
                fold=inner_id,
            )

        # Results.
        train_results, eval_results = trainer.get_results(
            train_dataset=train_dataset, eval_dataset=eval_dataset
        )
        logger.debug(f"Train: {train_results}")
        logger.debug(f"Val: {eval_results}")

        # Update best setting based on eval accuracy: take the evaluation step
        # with the highest `mean - std` accuracy.
        max_index = np.argmax(
            eval_results["accuracy"][0] - eval_results["accuracy"][1]
        )
        max_mean_accuracy = eval_results["accuracy"][0][max_index]
        max_std_accuracy = eval_results["accuracy"][1][max_index]

        # The first evaluated setting is always accepted. Otherwise
        # `best_hyperparameters` would stay `None` whenever no setting reaches
        # `mean - std > 0`, and the assessment stage would crash.
        if best_hyperparameters is None or (
            max_mean_accuracy - max_std_accuracy
        ) > (best_mean_accuracy - best_std_accuracy):
            # Assumes one evaluation per epoch, so index `k` corresponds to
            # epoch `k + 1` - TODO confirm against `validation_frequency`.
            hyperparameters["epochs"] = max_index + 1
            best_hyperparameters = hyperparameters
            best_mean_accuracy = max_mean_accuracy
            best_std_accuracy = max_std_accuracy

    if best_hyperparameters is None:
        # Only reachable when the grid yielded no settings at all.
        raise RuntimeError(
            "Model selection did not evaluate any hyperparameter setting."
        )

    # ------------------------------------------------------------------
    # Model assessment.
    # ------------------------------------------------------------------
    logger.info(f"Best hyperparameters: {best_hyperparameters}")
    logger.info(f"Best mean accuracy: {best_mean_accuracy}")
    logger.info(f"Best std accuracy: {best_std_accuracy}")

    # Average over several runs to offset random initialization.
    test_results = defaultdict(list)
    dev_results = defaultdict(list)
    for test_id in range(num_test_runs):
        log_folder = os.path.join(outer_folder, f"test_{test_id}")
        os.makedirs(log_folder, exist_ok=False)

        model, data, trainer = training.init_geometric_traning(
            log_folder,
            device,
            best_hyperparameters,
            targets,
            model_params=model_params,
            dataset_params=dataset_params,
            training_params=training_params,
        )
        # Run training.
        train_dataset = "dev"
        eval_dataset = "test"
        trainer.train(
            model=model,
            named_trainloader=(
                train_dataset,
                data.geometric_loader(
                    dataset=train_dataset, indices=cv.dev_indices
                ),
            ),
            named_evalloader=(
                eval_dataset,
                data.geometric_loader(
                    dataset=eval_dataset, indices=cv.test_indices
                ),
            ),
            fold=f"test_{test_id}",
        )
        # Results: store the last recorded (first, second) pair of every
        # metric, i.e. `results[metric][0][-1]` and `results[metric][1][-1]`.
        train_results, eval_results = trainer.get_results(
            train_dataset=train_dataset, eval_dataset=eval_dataset
        )
        for metric in reported_metrics:
            dev_results[metric].append(
                (train_results[metric][0][-1], train_results[metric][1][-1])
            )
            test_results[metric].append(
                (eval_results[metric][0][-1], eval_results[metric][1][-1])
            )

    logger.info(f"Dev: {dev_results}")
    logger.info(f"Test: {test_results}")
    general_utils.close_logger("cv")
    # NOTE: debugging limit - only the first outer fold is run.
    break

general_utils.close_all_loggers()
print("Finished training")


In [None]:
# Inspect the test-split metrics collected in the last executed outer fold
# (one tuple per final training run and metric).
test_results

In [None]:
# Inspect the validation-accuracy std stored for the selected hyperparameter setting.
best_std_accuracy

In [None]:
# Inspect the last hyperparameter setting bound by the model-selection loop.
hyperparameters

In [None]:
# Close every open logger. The call parentheses are required: without them the
# cell only displays the function object and no logger is closed.
general_utils.close_all_loggers()