In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from pathlib import Path
import numpy as np

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch_geometric.utils import to_dense_batch

In [3]:
def load_train_graphs(root: Path):
    graphs = []
    for pt_file in sorted(root.glob("*.pt")):
        data = torch.load(pt_file, weights_only=False)
        graphs.append(data)
        print(pt_file)
    return graphs

In [4]:
root = Path("Datasets/train_pyg")
graphs = load_train_graphs(root)

print(f"Total graphs loaded: {len(graphs)}")
print(graphs[1513])

Datasets/train_pyg/00000.pt
Datasets/train_pyg/00001.pt
Datasets/train_pyg/00002.pt
Datasets/train_pyg/00003.pt
Datasets/train_pyg/00004.pt
Datasets/train_pyg/00005.pt
Datasets/train_pyg/00006.pt
Datasets/train_pyg/00007.pt
Datasets/train_pyg/00008.pt
Datasets/train_pyg/00009.pt
Datasets/train_pyg/00010.pt
Datasets/train_pyg/00011.pt
Datasets/train_pyg/00012.pt
Datasets/train_pyg/00013.pt
Datasets/train_pyg/00014.pt
Datasets/train_pyg/00015.pt
Datasets/train_pyg/00016.pt
Datasets/train_pyg/00017.pt
Datasets/train_pyg/00018.pt
Datasets/train_pyg/00019.pt
Datasets/train_pyg/00020.pt
Datasets/train_pyg/00021.pt
Datasets/train_pyg/00022.pt
Datasets/train_pyg/00023.pt
Datasets/train_pyg/00024.pt
Datasets/train_pyg/00025.pt
Datasets/train_pyg/00026.pt
Datasets/train_pyg/00027.pt
Datasets/train_pyg/00028.pt
Datasets/train_pyg/00029.pt
Datasets/train_pyg/00030.pt
Datasets/train_pyg/00031.pt
Datasets/train_pyg/00032.pt
Datasets/train_pyg/00033.pt
Datasets/train_pyg/00034.pt
Datasets/train_pyg/0

In [5]:
val_ratio = 0.2
val_size = int(len(graphs) * val_ratio)
train_size = len(graphs) - val_size

graphs_val = graphs[train_size:]
graphs_train = graphs[:train_size]

print(f"Train set: {len(graphs_train)} graphs")
print(f"Val set: {len(graphs_val)} graphs")

Train set: 1212 graphs
Val set: 302 graphs


In [6]:
class TSPDataset(Dataset):
    def __init__(self, graphs):
        self.graphs = graphs

    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]

train_dataset = TSPDataset(graphs_train)
val_dataset = TSPDataset(graphs_val)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(next(iter(train_loader)))
print(next(iter(val_loader)))

DataBatch(x=[2325, 2], edge_index=[2, 283728], edge_attr=[283728, 1], y=[32], node_id=[2325], batch=[2325], ptr=[33])
DataBatch(x=[688, 2], edge_index=[2, 16832], edge_attr=[16832, 1], y=[32], node_id=[688], batch=[688], ptr=[33])


In [7]:
class TSPGNN(nn.Module):
    def __init__(self, in_channels=2, hidden_channels=64, heads=4, num_layers=4, dropout=0.3):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.heads = heads
        self.num_layers = num_layers
        self.dropout_rate = dropout

        out_channels = hidden_channels // heads if hidden_channels % heads == 0 else hidden_channels
        
        self.input_proj = nn.Linear(in_channels, hidden_channels)
        self.input_bn = nn.BatchNorm1d(hidden_channels)

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        for i in range(num_layers):
            self.layers.append(GATv2Conv(
                hidden_channels, 
                out_channels, 
                heads=heads, 
                edge_dim=1,
                concat=True
            ))
            self.norms.append(LayerNorm(hidden_channels))
            self.dropouts.append(nn.Dropout(dropout))

        self.out_proj = nn.Linear(hidden_channels, hidden_channels // 2)
        self.out = nn.Linear(hidden_channels // 2, 1)

    def forward(self, data, return_probs=False):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        edge_attr = edge_attr.view(-1, 1)

        x = self.input_proj(x)
        x = self.input_bn(x)
        x = F.relu(x)

        for i, (conv, norm, dropout) in enumerate(zip(self.layers, self.norms, self.dropouts)):
            h = conv(x, edge_index, edge_attr)
            h = norm(h)
            h = F.relu(h)
            h = dropout(h)
            
            if x.shape == h.shape:
                x = x + h
            else:
                x = h

        x = self.out_proj(x)
        x = F.relu(x)
        logits = self.out(x).squeeze(-1)

        if return_probs:
            x_dense, mask = to_dense_batch(logits.unsqueeze(-1), batch=data.batch)
            probs = torch.softmax(x_dense, dim=1)
            return probs, mask

        return logits

In [8]:
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

model = TSPGNN(
    in_channels=2,
    hidden_channels=64,
    heads=4,
    num_layers=4,
    dropout=0.25
).to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=0.001,
    weight_decay=1e-5
)

scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
    eta_min=1e-6
)

criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0
patience = 20
patience_counter = 0

Device: cuda
Model parameters: 36,993
