In [6]:
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import pandas as pd
import torch
import copy
from tqdm import tqdm
from metrics import evaluation_metrics

from slim import SLIMDataModule
import torch.nn as nn

# Instantiate the DataModule

In [None]:
import matplotlib.pyplot as plt

data_module = SLIMDataModule(data_dir="./data")
train_dataloader = data_module.train_dataloader()
# Get first batch
batch = next(iter(train_dataloader))
# Visualise the images

In [147]:
batch[0][0]

tensor([[0.0000, 0.6371, 0.3994,  ..., 0.0000, 0.0147, 0.3531],
        [0.6371, 0.0000, 0.5750,  ..., 0.0000, 0.0818, 0.2694],
        [0.3994, 0.5750, 0.0000,  ..., 0.0000, 0.0000, 0.1344],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.5492, 0.0000],
        [0.0147, 0.0818, 0.0000,  ..., 0.5492, 0.0000, 0.1841],
        [0.3531, 0.2694, 0.1344,  ..., 0.0000, 0.1841, 0.0000]])

In [161]:
def symmetric_normalize(A_tilde):
    """
    Performs symmetric normalization of A_tilde (Adj. matrix with self loops):
      A_norm = D^{-1/2} * A_tilde * D^{-1/2}
    Where D_{ii} = sum of row i in A_tilde.

    A_tilde (N, N): Adj. matrix with self loops
    Returns:
      A_norm : (N, N)
    """

    eps = 1e-5
    d = A_tilde.sum(dim=1) + eps
    D_inv = torch.diag(torch.pow(d, -0.5))
    return D_inv @ A_tilde @ D_inv


def batch_normalize(batch):
    batch_n = torch.zeros_like(batch)
    for i, A in enumerate(batch):
        batch_n[i] = symmetric_normalize(A + torch.eye(n=A.shape[0]))
    return batch_n

In [164]:
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group["lr"]


def train_model(
    model,
    train_dataloader,
    val_dataloader,
    num_epochs=100,
    lr=0.01,
    validate_every=1,
    patience=10,
    criterion=None,
    **kwargs,
):
    """
    Train the model, validate every 'validate_every' epochs, and pick the
    checkpoint with best validation accuracy.

    Parameters:
    -----------
    model : torch.nn.Module
        The PyTorch model to train.
    train_dataloader : torch.utils.data.DataLoader
        DataLoader for the training set.
    val_dataloader : torch.utils.data.DataLoader
        DataLoader for the validation set.
    num_epochs : int
        Number of training epochs.
    lr : float
        Learning rate for the optimizer.
    validate_every : int
        Validate (and possibly checkpoint) every 'validate_every' epochs.
    patience : int
        Patience for learning rate scheduler.
    criterion : torch.nn.Module
        Loss function.

    Returns:
    --------
    best_loss_history : list
        The training loss history across epochs.
    best_model_state_dict : dict
        The state dictionary of the model achieving the best validation accuracy.
    """

    model.train()
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=patience
    )
    train_loss_history = []
    val_loss_history = []

    best_val_loss = torch.inf
    best_model_state_dict = None
    val_loss = 0.0

    progress_bar = tqdm(range(num_epochs))
    for epoch in progress_bar:
        progress_bar.set_description(f"Epoch {epoch}|{num_epochs}")
        model.train()
        epoch_loss = 0.0

        for batch in train_dataloader:
            inputs, targets = batch
            inputs = batch_normalize(inputs)
            optimizer.zero_grad()

            # Forward pass on training data
            outputs = model.forward(inputs, **kwargs)
            loss = criterion(outputs, targets.to(model.device))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Record training loss
            epoch_loss += loss.item()

        avg_loss = epoch_loss / len(train_dataloader)
        train_loss_history.append(avg_loss)

        # Validation step
        if (epoch + 1) % validate_every == 0 or (epoch + 1) == num_epochs:
            model.eval()
            val_loss = 0.0
            for batch in val_dataloader:
                inputs, targets = batch
                outputs = model(inputs)

                val_loss += criterion(outputs, targets.to(model.device)).item()

            val_loss /= len(val_dataloader)
            val_loss_history.append(val_loss)
            scheduler.step(val_loss)

            lr = get_lr(optimizer)

            # Check if this is the best f1 score so far
            if val_loss > best_val_loss:
                best_val_loss = val_loss
                best_model_state_dict = copy.deepcopy(model.state_dict())

            if lr < 1e-5:
                break

        progress_bar.set_postfix({"train_loss": avg_loss, "val_loss": val_loss})

    # If we have a best model, load it
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)

    return train_loss_history, val_loss_history, best_model_state_dict


@torch.no_grad()
def evaluate_model(model, dataloader, criterion):
    """
    Runs forward pass, calculates binary predictions (threshold=0.5),
    and returns the accuracy score.
    """
    model.eval()
    val_loss = 0.0
    eval_metrics = {
        "mae": 0,
        "pcc": 0,
        "js_dis": 0,
        "avg_mae_bc": 0,
        "avg_mae_ec": 0,
        "avg_mae_pc": 0,
    }

    for batch in dataloader:
        inputs, targets = batch
        inputs.to(model.device)
        outputs.to(model.device)
        outputs = model(inputs)

        val_loss += criterion(targets, outputs).item()
        batch_metrics = evaluation_metrics(
            outputs.detach().numpy(), targets.detach().numpy()
        )

        for k, v in batch_metrics.items():
            eval_metrics[k] += v

    val_loss /= len(dataloader)
    for v in eval_metrics.values():
        v /= len(dataloader)
    return val_loss, eval_metrics

In [115]:
X = torch.randn(32, 15, 112)

new_x = torch.nn.functional.interpolate(X, size=(189), mode="linear").squeeze(0)
new_x.shape

torch.Size([32, 15, 189])

In [178]:
from torch_geometric.nn.conv import SAGEConv, GCNConv
import torch.nn as nn

num_nodes = 1000  # Adjust based on your dataset
embedding_dim = 128  # Dimension of node embeddings

# Trainable node embeddings
node_embeddings = nn.Embedding(num_nodes, embedding_dim)


class GCNLayer(nn.Module):
    """
    A single layer of a Graph Convolutional Network (GCN).
    """

    def __init__(self, input_dim, output_dim, use_nonlinearity=True):
        super(GCNLayer, self).__init__()
        self.use_nonlinearity = use_nonlinearity
        self.Omega = nn.Parameter(
            torch.randn(input_dim, output_dim)
            * torch.sqrt(torch.tensor(2.0) / (input_dim + output_dim))
        )
        self.beta = nn.Parameter(torch.zeros(output_dim))

    def forward(self, A_normalized, H_k):
        agg = torch.matmul(A_normalized, H_k)  # local agg
        H_k_next = torch.matmul(agg, self.Omega) + self.beta
        return nn.functional.relu(H_k_next) if self.use_nonlinearity else H_k_next


# GraphSAGE model
class GraphSAGE(nn.Module):
    def __init__(self, hidden_channels, out_size, n_layers: int = 2):
        super().__init__()
        self.in_channels = hidden_channels
        self.out_size = out_size
        self.conv = nn.ModuleList()
        for _ in range(n_layers - 1):
            self.conv.append(
                GCNLayer(input_dim=hidden_channels, output_dim=hidden_channels)
            )
        self.conv.append(
            GCNLayer(
                input_dim=hidden_channels,
                output_dim=hidden_channels,
                use_nonlinearity=False,
            )
        )

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, A):
        A = A.to(self.device)
        X = torch.ones(
            A.shape[0],
            A.shape[1],
            self.in_channels,
            dtype=torch.float32,
            device=self.device,
        )
        for layer in self.conv:
            X = layer(A, X)
        X = X.permute(0, 2, 1)
        X = torch.nn.functional.interpolate(
            X, size=(self.out_size,), mode="linear"
        ).squeeze(0)
        X = X.permute(0, 2, 1)
        A_pred = torch.zeros(
            A.shape[0],
            self.out_size,
            self.out_size,
            dtype=torch.float32,
            device=self.device,
        )
        for i, x in enumerate(X):
            A_pred[i] = torch.sigmoid(x @ x.T)
        A_pred = A_pred * (
            A_pred > 0.2
        )  # Thresholding to preserve sparse brain connectivity
        return A_pred

In [181]:
# Define the model, loss function, and optimizer
in_dim = batch[0].shape[1]
out_dim = batch[1].shape[1]
dim = 100
n_layers = 5

model = GraphSAGE(out_size=out_dim, hidden_channels=dim, n_layers=n_layers)
model.to(torch.device("mps"))
criterion = nn.MSELoss()

In [None]:
train_losses, val_losses, _ = train_model(
    model=model,
    train_dataloader=data_module.train_dataloader(),
    val_dataloader=data_module.val_dataloader(),
    num_epochs=100,
    lr=0.01,
    validate_every=1,
    patience=50,
    criterion=criterion,
)

Epoch 99|100: 100%|██████████| 100/100 [00:13<00:00,  7.47it/s, train_loss=0.104, val_loss=0.549]


In [None]:
def plot_loss(train_losses, val_losses, title="Losses"):
    plt.figure()
    plt.plot(train_losses, label="Train Loss")
    plt.plot(val_losses, label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title(title)
    plt.legend()
    plt.show()

In [None]:
plot_loss(train_losses=train_losses, val_losses=val_losses)

In [None]:
## Evaluation metrics

_, eval_metrics = evaluate_model(
    model, data_module.val_dataloader(), criterion=criterion
)

print(eval_metrics)