In [None]:
import os
# Pin the process to GPU 2 unless the caller already chose devices via the environment.
# This must happen before CUDA is initialized; keeping it above the torch imports is the safe spot.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "2")


import torch_geometric
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool, GraphUNet
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_adj
from tqdm import tqdm


from utils.data import GraphDataModule, save_prediction
from utils.training import train_model
from utils.metrics import evaluate_model
from utils.evaluation import evaluate_metrics

In [None]:
# Run on the (visible) GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
k_folds: int = 3  # number of cross-validation folds built by the data module and trained below

In [None]:
# Build the k-fold splits once; each fold gets its own train and validation loader.
data_module = GraphDataModule(
    "./data",
    num_workers=1,
    k_folds=k_folds,
    p_val=0.33,
)
train_loaders, val_loaders = data_module.train_dataloaders(), data_module.val_dataloaders()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.utils import to_dense_adj
from torch_geometric.data import Batch
import torch.nn.utils.spectral_norm as spectral_norm

class SuperResMLP(nn.Module):
    """
    A Multi-Layer Perceptron (MLP) for brain graph super-resolution.

    The model maps a vectorized low-resolution brain connectivity matrix to a
    high-resolution one. Vectorization keeps only the strictly upper-triangular
    (off-diagonal) entries of each symmetric adjacency matrix. The prediction is
    mirrored back into a symmetric matrix whose diagonal is zero.

    Args:
        num_nodes_input: Number of nodes in every low-resolution input graph.
        num_nodes_output: Number of nodes in every high-resolution output graph.
        num_hidden_nodes: Width of every hidden layer.
        n_layers: Number of residual blocks between the input and output layers.
        dropout: Dropout probability used in the input layer and in each residual block.
    """

    def __init__(self, num_nodes_input: int, num_nodes_output: int, num_hidden_nodes: int, n_layers: int, dropout: float = 0.1):
        super().__init__()
        self.num_nodes_input = num_nodes_input
        self.num_nodes_output = num_nodes_output

        # Count of off-diagonal upper-triangular entries: n * (n - 1) / 2.
        # The diagonal is excluded, matching the vectorization used for symmetric matrices.
        input_size = (num_nodes_input * (num_nodes_input - 1)) // 2
        output_size = (num_nodes_output * (num_nodes_output - 1)) // 2
        hidden_size = num_hidden_nodes

        # Boolean masks selecting the strict upper triangle (diagonal=1).
        # They are registered as buffers so they follow the model through .to(device).
        self.register_buffer("input_mask", torch.triu(torch.ones(num_nodes_input, num_nodes_input), diagonal=1).bool())
        self.register_buffer("output_mask", torch.triu(torch.ones(num_nodes_output, num_nodes_output), diagonal=1).bool())

        # Input layer with spectral normalization.
        # Flatten collapses a trailing edge-feature axis left over after masking, e.g. an
        # edge_attr of shape (E, 1) gives a masked input of shape (B, K, 1). This assumes a
        # single edge feature per edge (otherwise the width no longer equals input_size).
        # TODO confirm against the dataset.
        self.input_layer = nn.Sequential(
            nn.Flatten(start_dim=1),
            spectral_norm(nn.Linear(input_size, hidden_size)),
            nn.BatchNorm1d(hidden_size),
            nn.Dropout(dropout),
            nn.LeakyReLU(negative_slope=0.01),
        )

        # Hidden blocks with spectral normalization. The skip connection and a second
        # LeakyReLU are applied around each block in forward().
        self.residual_blocks = nn.ModuleList([
            nn.Sequential(
                spectral_norm(nn.Linear(hidden_size, hidden_size)),
                nn.BatchNorm1d(hidden_size),
                nn.Dropout(dropout),
                nn.LeakyReLU(negative_slope=0.01)
            ) for _ in range(n_layers)
        ])

        # Output layer with spectral normalization; the sigmoid keeps every predicted
        # upper-triangular entry in [0, 1].
        self.output_layer = nn.Sequential(
            spectral_norm(nn.Linear(hidden_size, output_size)),
            nn.Sigmoid(),
        )

        self._init_weights()

    def _init_weights(self):
        """Applies Xavier-uniform initialization to every Linear layer and zeros its bias.

        ``spectral_norm`` moves the trainable parameter to ``weight_orig`` and leaves
        ``weight`` as a derived tensor that is recomputed on each forward pass. The
        trainable ``weight_orig`` is therefore initialized directly whenever it exists.
        """
        for m in self.modules():
            if isinstance(m, nn.Linear):
                trainable_weight = getattr(m, "weight_orig", m.weight)
                nn.init.xavier_uniform_(trainable_weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    def forward(self, samples: Batch) -> torch.Tensor:
        """
        Predict high-resolution adjacency matrices for a batch of low-resolution graphs.

        Args:
            samples (Batch): A torch_geometric Batch object containing:
                             - edge_index: Graph connectivity in COO format.
                             - edge_attr: Optional edge attributes (used as edge weights).
                             - batch: Batch vector mapping nodes to their respective graphs.

        Returns:
            torch.Tensor: Symmetric matrices of shape
            (batch_size, num_nodes_output, num_nodes_output) with a zero diagonal.
        """
        # Densify every graph. Passing max_num_nodes pins the matrix size to the one the
        # masks were built for, instead of letting to_dense_adj infer it from the input.
        x = to_dense_adj(
            samples.edge_index,
            edge_attr=samples.edge_attr,
            batch=samples.batch,
            max_num_nodes=self.num_nodes_input,
        )
        batch_size = x.size(0)

        # Vectorize: keep only the off-diagonal upper-triangular entries.
        x = x[:, self.input_mask]

        x = self.input_layer(x)

        # Residual blocks: skip connection followed by an extra LeakyReLU.
        for block in self.residual_blocks:
            residual = x
            x = block(x)
            x = x + residual
            x = F.leaky_relu(x, negative_slope=0.01)

        x = self.output_layer(x)

        # Scatter the prediction into the upper triangle, then mirror it to the lower one.
        # new_zeros matches x's dtype and device (torch.zeros would use the default dtype).
        matrix = x.new_zeros((batch_size, self.num_nodes_output, self.num_nodes_output))
        matrix[:, self.output_mask] = x
        matrix = matrix + matrix.transpose(1, 2)

        return matrix

In [None]:
# Peek at one (input, target) batch from the first fold to read the node counts per graph.
batch, target_batch = next(iter(train_loaders[0]))
first_input_graph, first_target_graph = batch[0], target_batch[0]
input_dim = first_input_graph.x.shape[0]
output_dim = first_target_graph.x.shape[0]


In [None]:
criterion = nn.MSELoss()
fold_metrics = []  # one entry per fold: whatever evaluate_metrics returns

for k in range(k_folds):
    # A fresh model per fold, so folds never share weights.
    model = SuperResMLP(input_dim, output_dim, num_hidden_nodes=(input_dim + output_dim) // 2, n_layers=0).to(device)
    train_loader = train_loaders[k]
    val_loader = val_loaders[k]
    train_loss_history, val_loss_history, lr_history, best_model_state_dict = train_model(
        model=model,
        train_dataloader=train_loader,
        val_dataloader=val_loader,
        criterion=criterion,
        num_epochs=100,
    )
    # train_model hands back a separate checkpoint, presumably the best-validation one.
    # Restore it so the metrics below (and any model saved afterwards) use it rather than
    # whatever weights the last epoch left. NOTE(review): this is a no-op if train_model
    # already restored it in place; confirm in utils.training.
    if best_model_state_dict is not None:
        model.load_state_dict(best_model_state_dict)
    fold_metrics.append(evaluate_metrics(model, val_loader))

In [None]:

# Validation score of the model left over from the final fold of the CV loop.
loss = evaluate_model(model, val_loader)
print(loss)

# Pickle the entire nn.Module (not only its state_dict); reloading it therefore needs
# torch.load(..., weights_only=False) and an importable SuperResMLP definition.
torch.save(obj=model, f="model.pth")

In [None]:
# Reload the pickled model onto the current device, so a checkpoint saved on another GPU
# (or on a GPU machine) still loads here. weights_only=False unpickles arbitrary objects,
# which can execute code; only load model.pth files you produced yourself.
model = torch.load("model.pth", map_location=device, weights_only=False)

In [None]:
# Loader over the test split; it is fed to save_prediction to build the submission.
test_dataloader = data_module.test_dataloader()

In [None]:
submission_file = "outputs/test/submission.csv"
# Ensure the output folder exists before writing. It is unclear whether save_prediction
# creates it (verify in utils.data); makedirs with exist_ok=True is a no-op when it exists.
os.makedirs(os.path.dirname(submission_file), exist_ok=True)
save_prediction(model, test_dataloader, submission_file)

In [None]:
# Read the submission back and sanity-check it before uploading.
df = pd.read_csv(submission_file)
assert not df.empty, f"{submission_file} is empty"
assert not df.isna().any().any(), f"{submission_file} contains missing values"
df.head()

In [None]:
!kaggle competitions submit -c dgl-2025-brain-graph-super-resolution-challenge -f outputs/test/submission.csv -m "test"
