In [None]:
from medigraph.data.abide import DEFAULT_ABIDE_LOCATION, AbideData
from medigraph.data.io import Dump
import numpy as np
from nilearn import plotting
from medigraph.model.gcn import GCN
from medigraph.model.baseline import DenseNN
import torch
from tqdm.notebook import tqdm
from medigraph.data.preprocess import sanitize_data, visual_sanity_check_input, whiten
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
%load_ext autoreload
%autoreload 2
device

In [None]:
# Build the ABIDE dataset accessor with its default constructor arguments;
# every data-loading call in the cells below goes through `dat`.
dat = AbideData()

### Check connectivity matrix for a single patient
- 111x111 matrices
- We'll retrieve the $6216=\frac{111 \times (111+1)}{2}$ raw coefficients from the upper triangular part of the matrix (diagonal included)

In [None]:
# Display the connectivity matrix of a single subject, then report the size
# of the feature vector the dataset derives for that same subject.
idx = 0  # first subject
mat = dat.get_connectivity_matrix(idx)
matrix_title = f"Patient {idx} connectivity matrix {mat.shape}"
# Colour scale is clamped to [0, 1].
plotting.plot_matrix(mat, title=matrix_title, figure=(6, 6), vmin=0, vmax=1)

feature_vector_input = dat.get_connectivity_features(idx)
print(f"input feature vector shape: {feature_vector_input.shape}")


# Train classifier
### Build the adjacency matrix, feature matrix and classification labels

In [None]:
# Fetch, in one call, the three numpy arrays used for training: input
# features, labels and the adjacency matrix. In the printed annotations,
# V is the number of graph nodes and F the feature dimension.
inp_np, lab_np, adj_np = dat.get_training_data()

# One print call, one line per array.
print(
    f"Adjacency matrix : {adj_np.shape} [VxV]",
    f"Labels {lab_np.shape} : [V]",
    f"Input feature vector {inp_np.shape} : [VxF]",
    sep="\n",
)


In [None]:
# Convert the numpy arrays to float32 tensors placed on `device`
# (GPU when available, CPU otherwise).
def _as_float_tensor(array):
    """Copy `array` into a float32 tensor and move it to `device`."""
    return torch.tensor(array, dtype=torch.float32).to(device)


# NOTE(review): labels are re-read through get_labels() rather than reusing
# `lab_np` returned by get_training_data() -- confirm both agree.
labels_np = dat.get_labels()
adj = _as_float_tensor(adj_np)
inp_raw = _as_float_tensor(inp_np)  # [V=871, F=6216]
lab = _as_float_tensor(labels_np)  # float targets for binary classification


In [None]:
# Two-stage preprocessing of the raw features: sanitize first, then whiten
# the sanitized tensor.
clean_inp = sanitize_data(inp_raw)
inp = whiten(clean_inp)

# Display the shapes of the whitened features, the adjacency and the labels.
tuple(tensor.shape for tensor in (inp, adj, lab))

In [None]:
# Run the visual sanity check on each preprocessing stage in turn.
# Stage 1: raw features as loaded onto the device.
visual_sanity_check_input(inp_raw)
# Stage 2: output of sanitize_data.
visual_sanity_check_input(clean_inp)
# Stage 3: output of whiten (applied to the sanitized features).
visual_sanity_check_input(inp)

In [None]:
# % Sanitize and whiten data
# NOTE(review): unlike the earlier preprocessing cell, this one starts from
# `inp` -- already the result of whiten(sanitize_data(inp_raw)) -- instead of
# `inp_raw`, and it rebinds `clean_inp`. The training cell reads `clean_inp`,
# so it trains on the value assigned here. Confirm whether `inp_raw` was the
# intended input.
clean_inp = sanitize_data(inp)
# Inspect the re-sanitized tensor.
visual_sanity_check_input(clean_inp)
# Whiten once more and inspect; `whitened_inp` is not read by any of the
# later cells visible in this notebook.
whitened_inp = whiten(clean_inp)
visual_sanity_check_input(whitened_inp)

In [None]:
# % Train
# Full-batch training of a dense baseline and a GCN on the same inputs and
# labels, recording the loss and accuracy of every epoch for later plotting.
N_EPOCHS = 1000
LOG_EVERY = 100  # print a progress line every LOG_EVERY epochs
SEED = 42  # fixed RNG seed so that a re-run reproduces the same curves
# `clean_inp` is whatever the preprocessing cells above last bound it to.
selected_inp = clean_inp

# The loss holds no state, so a single instance serves both models.
criterion = torch.nn.BCEWithLogitsLoss()

training_losses_dict = {}
training_accuracy_dict = {}
# Keep "GCN" last: the cells below read `model.adj` from the `model` variable
# left over after this loop.
for model_name in ["Dense", "GCN"]:
    # Re-seed before building each model so its run does not depend on how
    # much randomness the previous model consumed.
    torch.manual_seed(SEED)
    if model_name == "GCN":
        model = GCN(selected_inp.shape[1], adj, hdim=64)
    else:
        model = DenseNN(selected_inp.shape[1], hdim=64)
    model.to(device)
    model.train()  # nothing switches to eval mode below, so set it once
    optim = torch.optim.Adam(model.parameters(), lr=1.E-4, weight_decay=0.1)
    training_losses = []
    training_accuracies = []
    for ep in tqdm(range(N_EPOCHS)):
        optim.zero_grad()
        logit = model(selected_inp)
        loss = criterion(logit, lab)
        loss.backward()
        optim.step()
        with torch.no_grad():
            # Accuracy of the forward pass that produced `loss`, i.e. measured
            # with the weights from before this epoch's optimizer step.
            predicted_prob = torch.sigmoid(logit).squeeze()
            predicted = (predicted_prob >= 0.5).long()  # threshold at 0.5 -> {0, 1}
            correct = (predicted == lab).sum().item()
            accuracy = correct / lab.shape[0]
        # Store plain Python floats rather than tensors: cheaper to keep
        # around and directly usable for plotting.
        loss_value = loss.item()
        training_losses.append(loss_value)
        training_accuracies.append(accuracy)
        if ep % LOG_EVERY == 0:
            print(f"Epoch {ep} loss: {loss_value:10f} - accuracy: {accuracy:.2%}")
    training_losses_dict[model_name] = training_losses
    training_accuracy_dict[model_name] = training_accuracies

In [None]:
import matplotlib.pyplot as plt

# Plot the per-epoch training curves of every trained model side by side:
# loss on the left panel, accuracy on the right panel.
fig, axs = plt.subplots(1, 2, figsize=(10, 6))
loss_ax, acc_ax = axs
for model_name, training_losses in training_losses_dict.items():
    # float() accepts plain numbers as well as 0-d tensors, so the curve is
    # drawn regardless of how the loss values were stored.
    loss_ax.plot([float(value) for value in training_losses], label=model_name)
    acc_ax.plot(training_accuracy_dict[model_name], label=f"{model_name} accuracy")
for ax in axs:
    ax.set_xlabel("Epoch")
    ax.legend()
    ax.grid()
loss_ax.set_title("Training loss (Binary Cross Entropy)")
loss_ax.set_ylabel("Loss")
acc_ax.set_title("Accuracy")
acc_ax.set_ylabel("Accuracy")

plt.show()

In [None]:
# Plot the adjacency matrix stored on the trained model (`model.adj`).
# `model` is the variable left over from the training loop, i.e. the last
# model trained there; it has to be one that exposes `.adj`.
norm_adj = model.adj.detach().cpu().numpy()
plotting.plot_matrix(
    norm_adj,
    figure=(6, 6),
    vmax=0.005,
    vmin=0,
    # Report the shape of the matrix actually plotted, not the shape of the
    # single-patient connectivity matrix `mat`.
    title=f"Graph normalized adjacency matrix {norm_adj.shape}"
)

In [None]:
# Compare the row sums of the input adjacency with those of the adjacency
# stored on the model. A notebook cell only renders its last expression, so
# both results are returned together as a tuple instead of discarding the
# first one.
adj.sum(axis=1), model.adj.sum(axis=1)

In [None]:
# Multiply the model's adjacency matrix with the feature matrix `inp`
# (the whitened features).
# NOTE(review): the training cell fed `clean_inp` to the models, not `inp` --
# confirm which tensor is meant here.
torch.matmul(model.adj, inp)