In [None]:
from medigraph.data.abide import AbideData
import numpy as np
from nilearn import plotting
from medigraph.model.gcn import GCN, SparseGCN
from medigraph.model.baseline import DenseNN
import torch
from tqdm.notebook import tqdm
from medigraph.data.preprocess import sanitize_data, visual_sanity_check_input, whiten
from medigraph.train import training_loop, train, plot_learning_curves, test
from medigraph.data.properties import INPUTS, LABELS, TRAIN_MASK, VAL_MASK, TEST_MASK
import matplotlib.pyplot as plt
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%load_ext autoreload
%autoreload 2
device

In [None]:
# Dataset wrapper used by every later cell: connectivity matrices, feature
# vectors, labels and the graph adjacency are all read through `dat`.
dat = AbideData()

### Check the connectivity matrix of a single patient
- Each connectivity matrix is $111 \times 111$.
- We keep the $6216 = \frac{111 \times (111+1)}{2}$ raw coefficients of its upper triangle (diagonal included) as the patient's feature vector.

In [None]:
# Plot the raw connectivity matrix of one subject (index 0), then report the
# size of the feature vector extracted for that same subject.
idx = 0
mat = dat.get_connectivity_matrix(idx)
plotting.plot_matrix(
    mat,
    title=f"Patient {idx} connectivity matrix {mat.shape}",
    vmin=0,
    vmax=1,
    figure=(6, 6),
)
feature_vector_input = dat.get_connectivity_features(idx)
print(f"input feature vector shape: {feature_vector_input.shape}")

# Train classifier
### Build the adjacency matrix, the feature matrix and the classification labels

In [None]:
# Fetch node features, labels and graph adjacency as arrays, then report their
# shapes (V = number of graph nodes, F = features per node).
inp_np, lab_np, adj_np = dat.get_training_data()
for line in (
    f"Adjacency matrix : {adj_np.shape} [VxV]",
    f"Labels {lab_np.shape} : [V]",
    f"Input feature vector {inp_np.shape} : [VxF]",
):
    print(line)

In [None]:
# Convert labels, adjacency and raw features to float32 tensors created
# directly on `device` (GPU when available).
# NOTE(review): `lab_np` from the previous cell is unused; labels are re-read
# through `dat.get_labels()` — confirm both return the same array.
labels_np = dat.get_labels()
adj = torch.tensor(adj_np, dtype=torch.float32, device=device)
inp_raw = torch.tensor(inp_np, dtype=torch.float32, device=device)  # [V=871, F=6216]
lab = torch.tensor(labels_np, dtype=torch.float32, device=device)  # float labels for binary classification

In [None]:
# Clean the raw features, then whiten them. Each stage keeps its own name
# because the next cell inspects inp_raw, clean_inp and inp side by side.
clean_inp = sanitize_data(inp_raw)
inp = whiten(clean_inp)
(inp.shape, adj.shape, lab.shape)

In [None]:
# % Visualization of sanity check
# Visual check of the features at each preprocessing stage
# (raw -> sanitize_data -> whiten) to see what each step changed.
visual_sanity_check_input(inp_raw)
visual_sanity_check_input(clean_inp)
visual_sanity_check_input(inp)

In [None]:
# % Sanity check on the graph adjacency matrix as stored by the GCN model
# (the plot title calls it "normalized").
model = GCN(inp.shape[1], adj, hdim=64)
norm_adj_np = model.adj.detach().cpu().numpy()
plotting.plot_matrix(
    norm_adj_np,
    figure=(6, 6),
    vmax=0.005,  # low upper bound kept from the original, presumably because weights are small
    vmin=0,
    # Report the shape of the matrix actually plotted (VxV). The previous title
    # used `mat.shape`, the 111x111 connectivity matrix of a single patient.
    title=f"Graph normalized adjacency matrix {norm_adj_np.shape}"
)
del model, norm_adj_np

In [None]:
# Train a dense baseline and a GCN on the same whitened features, and collect
# each model's training metrics under its name (plotted by `plot_metrics`).
training_data = {
    INPUTS: inp,
    LABELS: lab,
}
metric_dict = {}
for model_name in ("Dense", "GCN"):
    # Only the GCN receives the graph; the dense baseline sees node features alone.
    model = (
        GCN(inp.shape[1], adj, hdim=64)
        if model_name == "GCN"
        else DenseNN(inp.shape[1], hdim=64)
    )
    model.to(device)
    model, metrics = training_loop(model, training_data, device, n_epochs=1000)
    metric_dict[model_name] = metrics

In [None]:
def plot_metrics(metric_dict: dict):
    """Plot training loss and training accuracy curves for several models.

    Args:
        metric_dict: maps a model name to its metrics dict, which must contain
            the sequences "training_losses" and "training_accuracies".

    Draws one curve per model on each of the two panels, then shows the figure.
    Returns None, so calling it as a cell's last line does not re-display the figure.
    """
    _, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 6))
    for model_name, metric in metric_dict.items():
        ax_loss.plot(metric["training_losses"], label=model_name)
        # Same label on both panels; the panel title already says "Accuracy".
        ax_acc.plot(metric["training_accuracies"], label=model_name)

    ax_loss.set_title("Training loss (Binary Cross Entropy)")
    ax_loss.set_ylabel("Loss")
    ax_acc.set_title("Accuracy")
    ax_acc.set_ylabel("Accuracy")
    for ax in (ax_loss, ax_acc):
        # Assumes one metric entry per epoch — TODO confirm against training_loop.
        ax.set_xlabel("Epoch")
        ax.grid()
        if metric_dict:  # legend() warns when no labelled artists exist
            ax.legend()

    plt.show()


plot_metrics(metric_dict)

# Train a GCN based on Kipf's GitHub implementation

In [None]:
# NOTE(review): this rebinds `inp`, `lab` and `adj` (previously device tensors)
# to the raw arrays returned by get_training_data(); later cells that use these
# names depend on which cells ran last.
inp, lab, adj = dat.get_training_data()
for line in (f"Adjacency matrix : {adj.shape} [VxV]", f"Labels {lab.shape} : [V]"):
    print(line)

In [None]:
# Same adjacency sanity check for the Kipf-style SparseGCN: densify the
# adjacency it stores and plot it (low vmax to make small weights visible).
model = SparseGCN(inp.shape[1], nhid=16, nclass=2, adjacency=adj)
adj_mat = model.adj.to_dense().cpu().numpy()
plotting.plot_matrix(
    adj_mat,
    title=f"Graph normalized adjacency matrix {adj_mat.shape}",
    vmin=0,
    vmax=0.005,
    figure=(6, 6),
)
del model

### Training model 1

In [None]:

def get_training_dict(
    data: AbideData,
    device: torch.device = device,
    nb_train: int = 600,
    nb_val: int = 100,
    seed: "int | None" = None,
):
    """Build the input/label tensors and a random node split for `train`/`test`.

    Features from `data.get_training_data()` are converted to float32 on
    `device` and passed through `sanitize_data` then `whiten`. Labels become a
    float32 column tensor of shape [V, 1]. Node indices are shuffled and split
    into `nb_train` training nodes, `nb_val` validation nodes and the rest for test.

    Args:
        data: dataset wrapper providing `get_training_data()`.
        device: device for the feature and label tensors (masks stay on CPU).
        nb_train: number of training nodes.
        nb_val: number of validation nodes.
        seed: if given, the split is drawn from `np.random.default_rng(seed)`
            and is reproducible; if None, numpy's global RNG is used (the
            original behaviour).

    Returns:
        (training_dict, adj): a dict with INPUTS, LABELS, TRAIN_MASK, VAL_MASK
        and TEST_MASK entries (masks are int64 index tensors), plus the
        adjacency exactly as returned by `data.get_training_data()`.

    Raises:
        ValueError: if `nb_train` or `nb_val` is negative, or if
            `nb_train + nb_val` exceeds the number of nodes (which previously
            produced silently truncated validation / empty test sets).
    """
    graph_signals, node_labels, adj = data.get_training_data()

    inp = torch.tensor(graph_signals, dtype=torch.float32).to(device)
    labels = torch.tensor(node_labels, dtype=torch.float32).unsqueeze(1).to(device)
    inp = whiten(sanitize_data(inp))

    n_nodes = inp.shape[0]
    if nb_train < 0 or nb_val < 0 or nb_train + nb_val > n_nodes:
        raise ValueError(
            f"Invalid split: nb_train={nb_train}, nb_val={nb_val} for {n_nodes} nodes"
        )

    # One random permutation of node indices; consecutive slices of it give
    # disjoint train / val / test index sets.
    rng = np.random if seed is None else np.random.default_rng(seed)
    shuffled = rng.permutation(n_nodes)
    split_train = nb_train
    split_val = nb_train + nb_val

    return {
        INPUTS: inp,
        LABELS: labels,
        TRAIN_MASK: torch.tensor(shuffled[:split_train], dtype=torch.long),
        VAL_MASK: torch.tensor(shuffled[split_train:split_val], dtype=torch.long),
        TEST_MASK: torch.tensor(shuffled[split_val:], dtype=torch.long),
    }, adj

In [None]:
# Random split (500 train / 200 val / remaining nodes test) and a SparseGCN
# with 64 hidden units, a single output (nclass=1) and dropout 0.3.
training_dict, graph_adj = get_training_dict(dat, nb_train=500, nb_val=200)
model_GCN1 = SparseGCN(
    training_dict[INPUTS].shape[1],
    nhid=64,
    nclass=1,
    adjacency=graph_adj,
    proba_dropout=0.3,
)

In [None]:
# Train model_GCN1 for 400 epochs, then persist its weights (state_dict only).
trained_model_GCN, train_log, val_log = train(
    model_GCN1,
    training_dict,
    nEpochs=400,
    optimizer_params=dict(lr=0.01, weight_decay=0.05),
)
torch.save(trained_model_GCN.state_dict(), "__trained_model_GCN")

In [None]:
# Learning curves from the train/val logs returned by `train` for model_GCN1.
plot_learning_curves(train_log, val_log, title="Training of model_GCN1")

In [None]:
# Evaluate model_GCN1 on the held-out test nodes of training_dict.
loss_test, acc_test = test(trained_model_GCN, training_dict)
print("Test loss: {:.4f}, Test accuracy: {:.4f}".format(loss_test.item(), acc_test.item()))

In [None]:
# Train a bigger model (hidden sizes [2048, 512, 32]) on the same split.
# Use `graph_adj`: it comes from the same get_training_data() call that built the
# features and masks in `training_dict`, exactly like model_GCN1. The bare name
# `adj` is rebound by several earlier cells (first as a device tensor, later as
# the raw array), so its value depended on which cells had run last.
model_GCN2 = SparseGCN(
    training_dict[INPUTS].shape[1],
    nhid=[2048, 512, 32],
    nclass=1,
    adjacency=graph_adj,
)

trained_model_GCN2, train_log2, val_log2 = train(
    model_GCN2,
    training_dict,
    nEpochs=200,
    optimizer_params={'lr': 0.001, 'weight_decay': 0.001},
)

plot_learning_curves(train_log2, val_log2, title="Training of model_GCN2")

In [None]:
# Evaluate the bigger model_GCN2 on the same held-out test nodes.
loss_test, acc_test = test(trained_model_GCN2, training_dict)
print("Test loss: {:.4f}, Test accuracy: {:.4f}".format(loss_test.item(), acc_test.item()))