In [1]:
from pathlib import Path
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch_geometric.nn import GINEConv
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import GCNConv

In [2]:
def load_train_graphs(root: Path):
    graphs = []
    for pt_file in sorted(root.glob("*.pt"))[278:428]:
        data = torch.load(pt_file, weights_only=False)

        # Eliminar node_id perquè PyG el tracta com a node feature
        if hasattr(data, "node_id"):
            del data.node_id

        graphs.append(data)
    return graphs

root = Path("Datasets/train_pyg")
graphs = load_train_graphs(root)
print(f"Total graphs loaded: {len(graphs)}")

Total graphs loaded: 150


In [3]:
example = graphs[1]
print(example)
print(f"x shape: {example.x.shape}")
print(f"x dtype: {example.x.dtype}")
print(f"edge_index shape: {example.edge_index.shape}")
print(f"edge_index dtype: {example.edge_index.dtype}")
print(f"edge_attr shape: {example.edge_attr.shape}")
print(f"edge_attr dtype: {example.edge_attr.dtype}")
print(f"y shape: {example.y.shape}")
print(f"y dtype: {example.y.dtype}")
#print(f"node_id shape: {example.node_id.shape}")
#print(f"node_id dtype: {example.node_id.dtype}")

Data(x=[48, 2], edge_index=[2, 2256], edge_attr=[2256, 1], y=37)
x shape: torch.Size([48, 2])
x dtype: torch.float32
edge_index shape: torch.Size([2, 2256])
edge_index dtype: torch.int64
edge_attr shape: torch.Size([2256, 1])
edge_attr dtype: torch.float32
y shape: torch.Size([])
y dtype: torch.int64


In [4]:
random.shuffle(graphs)
val_ratio = 0.2
val_size = int(len(graphs) * val_ratio)
train_size = len(graphs) - val_size
graphs_train = graphs[:train_size]
graphs_val = graphs[train_size:]
print(f"Train set: {len(graphs_train)} graphs")
print(f"Val set: {len(graphs_val)} graphs")

train_loader = DataLoader(graphs_train, batch_size=32, shuffle=True)
val_loader = DataLoader(graphs_val, batch_size=32, shuffle=False)

batch_train = next(iter(train_loader))
print(batch_train)
print("Batch x shape:", batch_train.x.shape)
print("Batch edge_index shape:", batch_train.edge_index.shape)
print("Batch edge_attr shape:", batch_train.edge_attr.shape)
print("Batch y shape:", batch_train.y.shape)
print("Batch batch vector shape:", batch_train.batch.shape)

batch_val = next(iter(val_loader))
print(batch_val)
print("Batch x shape:", batch_val.x.shape)
print("Batch edge_index shape:", batch_val.edge_index.shape)
print("Batch edge_attr shape:", batch_val.edge_attr.shape)
print("Batch y shape:", batch_val.y.shape)
print("Batch batch vector shape:", batch_val.batch.shape)

Train set: 120 graphs
Val set: 30 graphs
DataBatch(x=[893, 2], edge_index=[2, 31190], edge_attr=[31190, 1], y=[32], batch=[893], ptr=[33])
Batch x shape: torch.Size([893, 2])
Batch edge_index shape: torch.Size([2, 31190])
Batch edge_attr shape: torch.Size([31190, 1])
Batch y shape: torch.Size([32])
Batch batch vector shape: torch.Size([893])
DataBatch(x=[612, 2], edge_index=[2, 16318], edge_attr=[16318, 1], y=[30], batch=[612], ptr=[31])
Batch x shape: torch.Size([612, 2])
Batch edge_index shape: torch.Size([2, 16318])
Batch edge_attr shape: torch.Size([16318, 1])
Batch y shape: torch.Size([30])
Batch batch vector shape: torch.Size([612])


In [5]:
class TSPGNN(nn.Module):
    def __init__(self, hidden_dim=64, num_layers=3):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # MLP per a la primera capa GINEConv
        self.mlp_in = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Capa GINE inicial
        self.convs = nn.ModuleList()
        self.convs.append(
            GINEConv(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim)
                ),
                edge_dim=1
            )
        )

        # Capes GINE intermèdies
        for _ in range(num_layers - 1):
            self.convs.append(
                GINEConv(
                    nn.Sequential(
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.ReLU(),
                        nn.Linear(hidden_dim, hidden_dim)
                    ),
                    edge_dim=1
                )
            )

        # Capa final: score per node
        self.node_classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        h = self.mlp_in(x)

        for conv in self.convs:
            h = conv(h, edge_index, edge_attr)
            h = F.relu(h)

        logits = self.node_classifier(h).squeeze(-1)
        return logits

In [6]:
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

model = TSPGNN(
    hidden_dim=16,   # 32 o 64
    num_layers=2
).to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3
)

criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0
patience = 20
patience_counter = 0

Device: cuda
Model parameters: 1,489


In [7]:
def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        losses = []
        correct = 0

        for i in range(len(batch.y)):
            start = batch.ptr[i].item()
            end = batch.ptr[i+1].item()

            logits_i = logits[start:end]          # [num_nodes_graph_i]
            target_i = batch.y[i].item()          # scalar

            # CrossEntropyLoss expects [1, num_classes] and [1]
            loss_i = F.cross_entropy(
                logits_i.view(1, -1),
                torch.tensor([target_i], device=device)
            )
            losses.append(loss_i)

            pred_i = logits_i.argmax().item()
            if pred_i == target_i:
                correct += 1

        loss = torch.stack(losses).mean()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += correct
        total_graphs += len(batch.y)

    avg_loss = total_loss / len(loader)
    accuracy = total_correct / total_graphs

    return avg_loss, accuracy


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        losses = []
        correct = 0

        for i in range(len(batch.y)):
            start = batch.ptr[i].item()
            end = batch.ptr[i+1].item()

            logits_i = logits[start:end]
            target_i = batch.y[i].item()

            loss_i = F.cross_entropy(
                logits_i.view(1, -1),
                torch.tensor([target_i], device=device)
            )
            losses.append(loss_i)

            pred_i = logits_i.argmax().item()
            if pred_i == target_i:
                correct += 1

        loss = torch.stack(losses).mean()

        total_loss += loss.item()
        total_correct += correct
        total_graphs += len(batch.y)

    avg_loss = total_loss / len(loader)
    accuracy = total_correct / total_graphs

    return avg_loss, accuracy

In [8]:
num_epochs = 200

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, device)

    print(
        f"Epoch {epoch:03d} | "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
        print("  → New best model saved")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

Epoch 001 | Train Loss: 46.9010 | Train Acc: 0.0583 | Val Loss: 22.5111 | Val Acc: 0.1333
  → New best model saved
Epoch 002 | Train Loss: 29.6182 | Train Acc: 0.0750 | Val Loss: 7.3348 | Val Acc: 0.1000
Epoch 003 | Train Loss: 10.5382 | Train Acc: 0.0500 | Val Loss: 5.8149 | Val Acc: 0.0333
Epoch 004 | Train Loss: 8.4723 | Train Acc: 0.0667 | Val Loss: 5.5957 | Val Acc: 0.0333
Epoch 005 | Train Loss: 5.7392 | Train Acc: 0.0333 | Val Loss: 5.8338 | Val Acc: 0.0667
Epoch 006 | Train Loss: 6.1820 | Train Acc: 0.0917 | Val Loss: 7.2175 | Val Acc: 0.0667
Epoch 007 | Train Loss: 6.4800 | Train Acc: 0.0833 | Val Loss: 5.4449 | Val Acc: 0.1333
Epoch 008 | Train Loss: 5.9909 | Train Acc: 0.0917 | Val Loss: 5.3249 | Val Acc: 0.1000
Epoch 009 | Train Loss: 6.3992 | Train Acc: 0.0750 | Val Loss: 4.4641 | Val Acc: 0.1333
Epoch 010 | Train Loss: 7.1101 | Train Acc: 0.0750 | Val Loss: 4.2008 | Val Acc: 0.1333
Epoch 011 | Train Loss: 4.0139 | Train Acc: 0.1083 | Val Loss: 3.8710 | Val Acc: 0.1000
Epo