# GIN < FC

In [None]:
import numpy as np
import pickle
import pandas as pd 
import matplotlib.pyplot as plt
from importlib import reload

import torch
import torch.nn.functional as F
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import TensorDataset
from torch_geometric.data import Data, DataLoader

from torchinfo import summary
from sklearn.model_selection import StratifiedKFold

import sys
sys.path.append("..")
from brain_connectivity import model as bc
from brain_connectivity.dense import ConnectivityDenseNet, ConnectivityMode
from brain_connectivity.gin import GIN

## Load data

In [None]:
# Root folders, relative to the notebook: CSV inputs and precomputed pickles.
DATA_FOLDER = "../data"
PICKLE_FOLDER = "../pickles"

In [None]:
# Patient metadata table; its first CSV column becomes the DataFrame index.
_metadata_csv = f'{DATA_FOLDER}/patients-cleaned.csv'
df_metadata = pd.read_csv(_metadata_csv, index_col=0)
df_metadata.head(2)

### Select connectivity dataset

In [None]:
# Connectivity dataset selection; each comment lists the supported options.

# Threshold value (used in the file name when KNN is False): 0.01, 0.05, 0.1, 0.15
THRESHOLD = 0.05
# Neighbour count (used in the file name when KNN is True): 3, 5, 7, 10, 15, 20, 40
N = 20
# Correlation type: 'pearson', 'spearman', 'partial-pearson'
CORR_TYPE = 'pearson'
# Thresholding method: 'abs-sample-diff', 'abs-group-avg-diff'
THRESHOLD_METHOD = 'abs-group-avg-diff'
# Threshold type: 'min', 'max', or for kNN 'small', 'large'
THRESHOLD_TYPE = 'min'
# Whether all neighbors (False) or only the top N neighbors (True) are taken
KNN = False

In [None]:
# Folder with the thresholded FC pickles for the selected configuration.
_knn_tag = '-knn' if KNN else ''
fc_folder = f'{PICKLE_FOLDER}/fc-{CORR_TYPE}{_knn_tag}-{THRESHOLD_METHOD}'

# Alternative sources to try: Gini or SGD.
# fc_folder = f'{PICKLE_FOLDER}/fc-{CORR_TYPE}-gini'
# fc_folder = f'{PICKLE_FOLDER}/fc-{CORR_TYPE}-sgd'

# The binary and real-valued files share one stem: the threshold type plus
# either the kNN neighbour count or the threshold value.
_variant = f'knn-{N}' if KNN else f'th-{THRESHOLD}'
_fc_stem = f'{fc_folder}/{THRESHOLD_TYPE}-{_variant}'
fc_file_binary = f'{_fc_stem}-binary.pickle'
fc_file_real = f'{_fc_stem}-real.pickle'

# Alternative file names without the threshold/kNN stem.
# fc_file_binary = f'{fc_folder}/binary.pickle'
# fc_file_real = f'{fc_folder}/real.pickle'

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    NOTE: unpickling can execute arbitrary code; only load files produced by
    this project's own preprocessing.
    """
    with open(path, 'rb') as f:
        return pickle.load(f)


# FC matrices used for the dense (FC) dataset.
# NOTE(review): hard-coded to Pearson, independent of CORR_TYPE. Confirm this
# is intended when comparing other correlation types.
raw_matrix = _load_pickle(f'{PICKLE_FOLDER}/fc-pearson.pickle')

# Thresholded matrices: the binary one defines graph edges, the real-valued one
# provides the node features.
edge_index_matrix = _load_pickle(fc_file_binary)
fc_matrix = _load_pickle(fc_file_real)

num_samples, num_parcels, _ = edge_index_matrix.shape

# Node features and edges are paired by sample index, so both pickles must
# describe the same samples and parcels. Fail loudly instead of silently
# building graphs whose features do not match their edges.
if np.shape(fc_matrix) != edge_index_matrix.shape:
    raise ValueError(
        f'Real-valued FC shape {np.shape(fc_matrix)} does not match '
        f'binary FC shape {edge_index_matrix.shape}.'
    )

edge_index_matrix.shape

## Split data

In [None]:
with open(f'{PICKLE_FOLDER}/test-indices.pickle', 'rb') as f:
    test_indices = pickle.load(f)

# Every sample not held out for testing is used for training/cross-validation.
# The comprehension keeps ascending index order explicitly; iterating a set
# difference only happens to be ordered, and the StratifiedKFold folds below
# depend on this order.
_held_out = set(test_indices)
train_indices = [i for i in range(num_samples) if i not in _held_out]
train_targets = df_metadata.iloc[train_indices]["target"].reset_index(drop=True)

print(f'Train set size: {len(train_indices)}')
print(f'Test set size: {len(test_indices)}')

## Prepare data

In [None]:
# Prefer the GPU when torch can see one; fall back to the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

### `Data` object fields

- `data.x`: Node feature matrix with shape `[num_nodes, num_node_features]`

- `data.edge_index`: Graph connectivity in COO format with shape `[2, num_edges]` and type `torch.long`

- `data.edge_attr`: Edge feature matrix with shape `[num_edges, num_edge_features]`

- `data.y`: Target to train against (may have arbitrary shape), e.g., node-level targets of shape `[num_nodes, *]` or graph-level targets of shape `[1, *]`

- `data.pos`: Node position matrix with shape `[num_nodes, num_dimensions]`

### Select node features

- onehot
- correlations

In [None]:
# Node-feature builders: each takes a sample index and returns that sample's
# node feature matrix.


def correlations_in_nodes(i):
    """Return sample `i`'s real-valued FC matrix as float32; row k holds node k's features."""
    return torch.from_numpy(fc_matrix[i]).float()


def onehot_in_nodes(i):
    """Return a one-hot (identity) encoding of the brain regions.

    `i` is ignored: every sample gets the same features. See the GIN for
    phenotype paper.
    """
    return torch.eye(num_parcels)

In [None]:
# Node features: each node carries its FC row. Those rows, like the one-hot
# rows, are assumed to be num_parcels wide.
num_node_features = num_parcels
features_in_nodes = correlations_in_nodes
num_node_features

## Create datasets

### Graph dataset

In [None]:
def _sample_to_graph(sample_idx, label):
    """Build the PyG graph for one sample and move it to `device`.

    Every nonzero entry (row, col) of the sample's binary matrix becomes an
    edge. Node features come from `features_in_nodes`, and the graph-level
    target is a 1-element int64 tensor.
    """
    edges = np.vstack(np.nonzero(edge_index_matrix[sample_idx]))
    graph = Data(
        x=features_in_nodes(sample_idx),
        edge_index=torch.from_numpy(edges).long(),
        y=torch.tensor([label], dtype=torch.int64),
    )
    return graph.to(device)


graph_dataset = [
    _sample_to_graph(sample_idx, label)
    for label, sample_idx in zip(train_targets, train_indices)
]

In [None]:
print(f'True train data: {len(graph_dataset)}')

# Sanity-check the Data layout using the first graph's tensor shapes.
_first = graph_dataset[0]
print('Data object')
for _label, _shape in (
    ('Edge index', _first.edge_index.shape),
    ('Node features', _first.x.shape),
    ('Target', _first.y.shape),
):
    print(f'{_label}: {_shape}')

### Dense dataset

In [None]:
# Dense counterpart of the graph dataset: item k is (FC matrix, target) for the
# k-th training sample. The inputs are cast to float32 to match the graph node
# features (`correlations_in_nodes` casts as well), which is also torch's
# default parameter dtype. An un-cast float64 pickle would otherwise hit
# float32 layers.
dense_dataset = TensorDataset(
    torch.from_numpy(raw_matrix[train_indices]).to(torch.float32),
    torch.from_numpy(train_targets.values),
)

## Define GIN & FC architectures

In [None]:
# Architecture FC: a torchinfo summary of a preview dense network. The preview
# is for inspection only; training builds its own nets from model_params.
_fc_preview = ConnectivityDenseNet(
    num_parcels, ConnectivityMode.SINGLE, num_node_features, 32, num_sublayers=3,
)
summary(_fc_preview)

In [None]:
# Architecture GIN: a torchinfo summary of a preview GIN. The preview is for
# inspection only; training builds its own nets from model_params.
_gin_preview = GIN(size_in=num_node_features, num_hidden_features=32, num_sublayers=3)
summary(_gin_preview)

## Train model

In [None]:
# Shuffled, stratified K-fold splitter. The fixed seed makes the folds
# reproducible for a given input order.
NUM_FOLDS = 2
skf = StratifiedKFold(NUM_FOLDS, shuffle=True, random_state=42)

In [None]:
# Training parameters settings.
# Training settings, passed as keyword arguments to bc.Model (key order is
# kept because the dict is also written into each run's architecture file).
training_params = dict(
    # Training regime.
    epochs=100,
    validation_frequency=1,
    # Optimizer.
    # NOTE(review): torch.optim.Adam takes no `momentum` argument (it uses
    # `betas`); confirm how bc.Model consumes this key.
    optimizer=torch.optim.Adam,
    momentum=0.9,
    learning_rate=0.001,
    weight_decay=0.0001,
    # Scheduler.
    step_size=50,
    gamma=0.5,
    # Loss.
    criterion=torch.nn.CrossEntropyLoss(),
)

In [None]:
# Model parameters settings.
# Constructor arguments shared by GIN and ConnectivityDenseNet.
model_params = dict(size_in=num_node_features, num_hidden_features=64, num_sublayers=5)

# Extra arguments for GIN only.
gin_params = dict(eps=0.2)

# Extra arguments for ConnectivityDenseNet only.
dense_params = dict(mode=ConnectivityMode.SINGLE, num_nodes=num_parcels)

In [None]:
# Experiment folder.
# Runs are logged under ../{EXP_FOLDER}/{MODEL_TYPE}/id=<EXP_ID:03d>,fold=<k>.
EXP_FOLDER = 'runs/fc-vs-gin'
EXP_ID = 1

BATCH_SIZE = 2

# GRAPH trains GIN on graph_dataset; DENSE trains ConnectivityDenseNet on dense_dataset.
MODEL_TYPE = bc.ModelType.GRAPH

In [None]:
from enum import Enum

# Pick up edits to brain_connectivity.model without restarting the kernel.
reload(bc)

# Reloading re-creates every class in `bc`, including `bc.ModelType`, but
# MODEL_TYPE still holds a member of the old class. Plain Enum members compare
# by identity, so the training loop's `MODEL_TYPE == bc.ModelType.GRAPH` checks
# would all be False. Re-resolve the member by name against the reloaded class.
if isinstance(MODEL_TYPE, Enum):
    MODEL_TYPE = bc.ModelType[MODEL_TYPE.name]

In [None]:
def _build_net():
    """Return a freshly initialised network for MODEL_TYPE, moved to `device`.

    Raises:
        Exception: if MODEL_TYPE is neither GRAPH nor DENSE.
    """
    if MODEL_TYPE == bc.ModelType.GRAPH:
        return GIN(**model_params, **gin_params).to(device)
    if MODEL_TYPE == bc.ModelType.DENSE:
        return ConnectivityDenseNet(**model_params, **dense_params).to(device)
    raise Exception('Unsupported model type.')


def _save_architecture(path, net):
    """Write the data source, feature builder, training settings and `net`'s structure to `path`."""
    with open(path, 'w', encoding="utf-8") as f:
        if MODEL_TYPE == bc.ModelType.GRAPH:
            f.write(fc_folder + '\n' + fc_file_binary + '\n' + fc_file_real + '\n')
        # Log the function's name; its repr embeds a per-process memory address.
        f.write(features_in_nodes.__name__ + '\n')
        f.write(str(training_params) + '\n')
        f.write(str(net) + '\n\n')
        f.write(str(summary(net)))


# StratifiedKFold only needs the sample count from X, so a zero placeholder is
# enough; the folds are stratified on train_targets.
fold_splits = skf.split(np.zeros(len(train_targets)), train_targets)
for fold_id, (train_index, val_index) in enumerate(fold_splits):
    experiment_str = f'id={EXP_ID:03d},fold={fold_id}'
    run_dir = f"../{EXP_FOLDER}/{MODEL_TYPE}/{experiment_str}"

    # Init TB writer. SummaryWriter creates run_dir, which also receives the
    # architecture file below.
    writer = SummaryWriter(run_dir)
    try:
        net = _build_net()

        # Train and validation folds must both come from the dataset that
        # matches MODEL_TYPE.
        dataset = graph_dataset if MODEL_TYPE == bc.ModelType.GRAPH else dense_dataset
        X_train = [dataset[i] for i in train_index]
        X_val = [dataset[i] for i in val_index]

        # PyG's DataLoader batches `Data` objects and falls back to standard
        # collation for plain tensor tuples, so it serves both model types.
        trainloader = DataLoader(X_train, batch_size=BATCH_SIZE, shuffle=True)
        valloader = DataLoader(X_val, batch_size=BATCH_SIZE, shuffle=False)

        _save_architecture(f"{run_dir}/architecture", net)

        # Init model wrapper and run training.
        model = bc.Model(
            model=net,
            trainloader=trainloader,
            valloader=valloader,
            writer=writer,
            **training_params
        )
        model.train()
    finally:
        # Flush pending TensorBoard events and release the event file, even if
        # this fold fails.
        writer.close()

    # Uncomment to run a single fold during exploration.
    # break

print('Finished training')