In [1]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "2" 


import torch_geometric
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import pandas as pd
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, TopKPooling, global_mean_pool, GraphUNet
from torch_geometric.data import Batch
from torch_geometric.utils import to_dense_adj
from tqdm import tqdm


from utils.data import GraphDataModule, save_prediction
from utils.training import train_model
from utils.metrics import evaluate_model

In [2]:
# Seed the NumPy and PyTorch generators before any later cell draws random
# numbers (weight initialisation, any shuffling), so that
# Restart Kernel -> Run All reproduces the same run.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cpu


In [3]:
# Build the project data module from the files under ./data and pull out the
# training and validation loaders. One graph per batch: the model defined
# below concatenates along the node axis and treats each batch as one graph.
data_module = GraphDataModule("./data", batch_size=1, num_workers=1)
train_loader, val_loader = (
    data_module.train_dataloader(),
    data_module.val_dataloader(),
)

Time taken to load ./data\hr_train.csv: 2.889310121536255 seconds
Time taken to load ./data\lr_train.csv: 0.4781632423400879 seconds
Time taken to load ./data\lr_test.csv: 0.48460960388183594 seconds


Converting vectors to graphs: 100%|██████████| 133/133 [00:00<00:00, 339.26it/s]
Converting vectors to graphs: 100%|██████████| 34/34 [00:00<00:00, 297.15it/s]
Converting vectors to graphs: 100%|██████████| 112/112 [00:00<00:00, 747.67it/s]


In [43]:
class UpscalerGNN(nn.Module):
    """Predict a larger adjacency matrix from one smaller input graph.

    Pipeline (see ``forward``):

    1. Node features go through ``nn.Linear`` -> ``GCNConv`` -> ``GraphUNet``,
       giving one embedding per node of the input graph.
    2. Each of the ``output_nodes - input_nodes`` nodes that only exist in the
       output graph gets a learned embedding (``layer2`` applied to a one-hot
       node id).
    3. Both sets of embeddings are concatenated and scored pairwise with a
       key/query product followed by a sigmoid.

    Args:
        input_features: Number of features per input node.
        hidden_nodes: Width of the first linear layer and of the GCN layer.
        hidden_channels: Hidden width handed to ``GraphUNet``.
        out_channels: Size of the node embeddings fed to the key/query layers.
        attention_dim: Size of the key and query projections.
        input_nodes: Number of nodes in every input graph (default 160).
        output_nodes: Number of nodes in the predicted graph (default 268).

    Raises:
        ValueError: If ``output_nodes`` is smaller than ``input_nodes``.
    """

    def __init__(self, input_features, hidden_nodes, hidden_channels, out_channels,
                 attention_dim, input_nodes=160, output_nodes=268):
        super().__init__()
        if output_nodes < input_nodes:
            raise ValueError(
                f"output_nodes ({output_nodes}) must be >= input_nodes ({input_nodes})"
            )
        self.input_nodes = input_nodes
        self.output_nodes = output_nodes

        # Encoder for the nodes that already exist in the input graph.
        self.layer1 = nn.Linear(input_features, hidden_nodes)
        self.gcn = GCNConv(hidden_nodes, hidden_nodes)
        self.graph_unet = GraphUNet(hidden_nodes, hidden_channels, out_channels, depth=3, pool_ratios=0.5)

        # Applied to an identity matrix in forward(), so it acts as a learned
        # lookup table with one embedding per newly created node.
        self.layer2 = nn.Linear(self.output_nodes - self.input_nodes, out_channels)

        # Projections used to score every pair of output nodes.
        self.key = nn.Linear(out_channels, attention_dim)
        self.query = nn.Linear(out_channels, attention_dim)

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device

    def forward(self, samples: "Batch"):
        """Compute pairwise connection scores for all output nodes.

        Args:
            samples: Object exposing ``x``, ``edge_index`` and ``edge_attr``.
                ``x`` must have exactly ``input_nodes`` rows, i.e. the batch
                must hold a single graph. ``edge_attr`` is forwarded to the GCN
                layer as its third positional argument.

        Returns:
            Tensor of sigmoid scores with a leading axis of size 1. Its shape
            is ``(1, output_nodes, output_nodes)`` provided ``graph_unet``
            returns one row per input node.

        Raises:
            ValueError: If ``samples.x`` does not have ``input_nodes`` rows.
        """
        num_nodes = samples.x.shape[0]
        if num_nodes != self.input_nodes:
            # Without this check a multi-graph batch would silently yield a
            # matrix of the wrong size.
            raise ValueError(
                f"expected a single graph with {self.input_nodes} nodes, "
                f"got {num_nodes} node rows (is batch_size > 1?)"
            )

        existing = self.layer1(samples.x)
        existing = self.gcn(existing, samples.edge_index, samples.edge_attr)
        existing = self.graph_unet(existing, samples.edge_index)

        # Build the one-hot ids directly on layer2's device and dtype instead
        # of creating a float32 CPU tensor and copying it on every call.
        weight = self.layer2.weight
        one_hot_ids = torch.eye(
            self.output_nodes - self.input_nodes,
            device=weight.device,
            dtype=weight.dtype,
        )
        created = self.layer2(one_hot_ids)

        nodes = torch.cat((existing, created), dim=0)
        keys = self.key(nodes)
        queries = self.query(nodes)
        # NOTE(review): keys @ queries^T is not symmetric in general; confirm
        # whether the target matrices are expected to be symmetric.
        scores = keys @ queries.transpose(-1, -2)
        return torch.sigmoid(scores).unsqueeze(0)


In [44]:
# Take one (input, target) pair from the training loader and read the graph
# sizes off the first graph of each side.
batch, target_batch = next(iter(train_loader))
lr_graph, hr_graph = batch[0], target_batch[0]

input_dim, output_dim = lr_graph.x.shape[0], hr_graph.x.shape[0]  # node counts
input_features = lr_graph.x.shape[1]                               # features per input node
print(input_features, input_dim, output_dim)

4 160 268


In [45]:
# Layer widths of the upscaler, from the first linear/GCN layer down to the
# key/query projections.
hidden_nodes, hidden_channels, out_channels, attention_dim = 128, 64, 32, 16

# Arguments are passed positionally in the order of UpscalerGNN.__init__:
# (input_features, hidden_nodes, hidden_channels, out_channels, attention_dim).
model = UpscalerGNN(
    input_features,
    hidden_nodes,
    hidden_channels,
    out_channels,
    attention_dim,
)

In [46]:
# Mean absolute error between the predicted and target matrices.
criterion = nn.L1Loss()

# train_model returns the per-epoch histories plus the state dict it
# considers best; all four are kept for the cells that follow.
(
    train_loss_history,
    val_loss_history,
    lr_history,
    best_model_state_dict,
) = train_model(
    model=model,
    criterion=criterion,
    train_dataloader=train_loader,
    val_dataloader=val_loader,
    num_epochs=100,
)

  4%|▍         | 4/100 [00:21<08:34,  5.35s/it, train_loss=0.253, val_loss=0.244, lr=0.01]


KeyboardInterrupt: 

In [15]:
# Restore the weights that the training cell stored in `best_model_state_dict`.
# NOTE(review): that name is only bound once train_model() has returned; if the
# training cell was interrupted, this uses the value from an earlier run.
model.load_state_dict(best_model_state_dict)
# Score the restored model on the validation loader with the project helper.
loss = evaluate_model(model, val_loader)
print(loss)

# Pickle the whole model object (not just its state dict) to model.pth.
# NOTE(review): loading this file later needs the UpscalerGNN class to be
# defined/importable under the same name.
torch.save(model, 'model.pth')

0.24420016


In [None]:
# Reload the pickled model object written by torch.save(model, 'model.pth').
# NOTE(review): weights_only=False runs the full pickle machinery, which can
# execute arbitrary code -- only load checkpoint files you created yourself.
model = torch.load("model.pth", weights_only=False)

In [None]:
# Ask the data module for its test dataloader.
test_dataloader = data_module.test_dataloader()

In [None]:
# Destination of the predictions produced for the test loader.
submission_file = "outputs/test/submission.csv"

# Make sure the target folder exists before writing. save_prediction is a
# project helper not shown here, so it may not create missing directories.
os.makedirs(os.path.dirname(submission_file), exist_ok=True)
save_prediction(model, test_dataloader, submission_file)

In [None]:
# Read the written submission CSV back into a DataFrame (e.g. to inspect it
# before uploading).
df = pd.read_csv(submission_file)

In [None]:
# Upload the submission CSV to the competition through the Kaggle CLI.
# NOTE(review): the path is typed out again here and must be kept in sync with
# `submission_file`; presumably the CLI also needs Kaggle credentials set up.
!kaggle competitions submit -c dgl-2025-brain-graph-super-resolution-challenge -f outputs/test/submission.csv -m "test"
