In [9]:
from pathlib import Path
import numpy as np
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch_geometric.nn import GINEConv
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch_geometric.utils import to_dense_batch
from torch_geometric.nn import GCNConv

In [15]:
from pathlib import Path
import torch

def load_raw_graphs(root: Path, limit=None):
    graphs = []
    for i, pt_file in enumerate(sorted(root.glob("*.pt"))):
        if limit is not None and i >= limit:
            break
        data = torch.load(pt_file, weights_only=False)
        if hasattr(data, "node_id"):
            del data.node_id
        graphs.append(data)
    return graphs

root = Path("Datasets/train_pyg")
graphs = load_raw_graphs(root, limit=5000)
print(f"Total graphs loaded: {len(graphs)}")

for idx, g in enumerate(graphs):
    num_nodes = g.x.size(0)
    y = int(g.y)

    if num_nodes == 0:
        print(f"❌ Graf {idx} té 0 nodes")

    if y < 0 or y >= num_nodes:
        print(f"❌ Graf {idx} target fora de rang: y={y}, num_nodes={num_nodes}")
else:
    print("✔️ Tots els grafs tenen y dins [0, num_nodes)")


Total graphs loaded: 4165
❌ Graf 600 target fora de rang: y=-1, num_nodes=9
❌ Graf 1107 target fora de rang: y=-1, num_nodes=24
❌ Graf 1450 target fora de rang: y=-1, num_nodes=21
❌ Graf 1512 target fora de rang: y=-1, num_nodes=5
✔️ Tots els grafs tenen y dins [0, num_nodes)


In [12]:
def load_train_graphs(root: Path):
    graphs = []
    for pt_file in sorted(root.glob("*.pt")):
        data = torch.load(pt_file, weights_only=False)

        # Eliminar node_id perquè PyG el tracta com a node feature
        if hasattr(data, "node_id"):
            del data.node_id

        graphs.append(data)
    return graphs

root = Path("Datasets/train_pyg")
graphs = load_train_graphs(root)
print(f"Total graphs loaded: {len(graphs)}")

Total graphs loaded: 4165


In [3]:
example = graphs[1]
print(example)
print(f"x shape: {example.x.shape}")
print(f"x dtype: {example.x.dtype}")
print(f"edge_index shape: {example.edge_index.shape}")
print(f"edge_index dtype: {example.edge_index.dtype}")
print(f"edge_attr shape: {example.edge_attr.shape}")
print(f"edge_attr dtype: {example.edge_attr.dtype}")
print(f"y shape: {example.y.shape}")
print(f"y dtype: {example.y.dtype}")
#print(f"node_id shape: {example.node_id.shape}")
#print(f"node_id dtype: {example.node_id.dtype}")

Data(x=[280, 2], edge_index=[2, 78120], edge_attr=[78120, 1], y=241)
x shape: torch.Size([280, 2])
x dtype: torch.float32
edge_index shape: torch.Size([2, 78120])
edge_index dtype: torch.int64
edge_attr shape: torch.Size([78120, 1])
edge_attr dtype: torch.float32
y shape: torch.Size([])
y dtype: torch.int64


In [13]:
random.shuffle(graphs)
val_ratio = 0.2
val_size = int(len(graphs) * val_ratio)
train_size = len(graphs) - val_size
graphs_train = graphs[:train_size]
graphs_val = graphs[train_size:]
print(f"Train set: {len(graphs_train)} graphs")
print(f"Val set: {len(graphs_val)} graphs")

train_loader = DataLoader(graphs_train, batch_size=32, shuffle=True)
val_loader = DataLoader(graphs_val, batch_size=32, shuffle=False)

batch_train = next(iter(train_loader))
print(batch_train)
print("Batch x shape:", batch_train.x.shape)
print("Batch edge_index shape:", batch_train.edge_index.shape)
print("Batch edge_attr shape:", batch_train.edge_attr.shape)
print("Batch y shape:", batch_train.y.shape)
print("Batch batch vector shape:", batch_train.batch.shape)

batch_val = next(iter(val_loader))
print(batch_val)
print("Batch x shape:", batch_val.x.shape)
print("Batch edge_index shape:", batch_val.edge_index.shape)
print("Batch edge_attr shape:", batch_val.edge_attr.shape)
print("Batch y shape:", batch_val.y.shape)
print("Batch batch vector shape:", batch_val.batch.shape)

Train set: 3332 graphs
Val set: 833 graphs
DataBatch(x=[4072, 2], edge_index=[2, 1207992], edge_attr=[1207992, 1], y=[32], batch=[4072], ptr=[33])
Batch x shape: torch.Size([4072, 2])
Batch edge_index shape: torch.Size([2, 1207992])
Batch edge_attr shape: torch.Size([1207992, 1])
Batch y shape: torch.Size([32])
Batch batch vector shape: torch.Size([4072])
DataBatch(x=[4961, 2], edge_index=[2, 1556172], edge_attr=[1556172, 1], y=[32], batch=[4961], ptr=[33])
Batch x shape: torch.Size([4961, 2])
Batch edge_index shape: torch.Size([2, 1556172])
Batch edge_attr shape: torch.Size([1556172, 1])
Batch y shape: torch.Size([32])
Batch batch vector shape: torch.Size([4961])


In [14]:
from torch_geometric.loader import DataLoader

train_loader = DataLoader(graphs, batch_size=32, shuffle=False)

for batch_idx, batch in enumerate(train_loader):
    ptr = batch.ptr
    y = batch.y

    for i in range(len(y)):
        start = ptr[i].item()
        end = ptr[i+1].item()
        num_nodes = end - start
        target = y[i].item()

        if num_nodes == 0 or target < 0 or target >= num_nodes:
            print(f"❌ Batch {batch_idx}, graf {i}: target={target}, num_nodes={num_nodes}")
            print("ptr:", ptr)
            print("y:", y)
            raise SystemExit

print("✔️ Tots els batches tenen targets coherents amb num_nodes")


❌ Batch 55, graf 23: target=-1, num_nodes=24
ptr: tensor([   0,  109,  558,  830, 1110, 1324, 1393, 1451, 1508, 2086, 2161, 2498,
        2569, 2653, 2875, 3148, 3443, 3505, 4117, 4130, 4461, 4524, 4951, 5054,
        5078, 5084, 5174, 5290, 5564, 5622, 5799, 5843, 5854])
y: tensor([ 41,   3,  34, 241, 132,  20,  11,  28, 175,  38,  21,  39,  35, 185,
        255, 148,   7, 174,  12, 328,  24,  53,  33,  -1,   1,  31,  71,  42,
         53,   8,  21,   9])


SystemExit: 

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [5]:
class TSPGNN(nn.Module):
    def __init__(self, hidden_dim=64, num_layers=3):
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_layers = num_layers

        # MLP per a la primera capa GINEConv
        self.mlp_in = nn.Sequential(
            nn.Linear(2, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Capa GINE inicial
        self.convs = nn.ModuleList()
        self.convs.append(
            GINEConv(
                nn.Sequential(
                    nn.Linear(hidden_dim, hidden_dim),
                    nn.ReLU(),
                    nn.Linear(hidden_dim, hidden_dim)
                ),
                edge_dim=1
            )
        )

        # Capes GINE intermèdies
        for _ in range(num_layers - 1):
            self.convs.append(
                GINEConv(
                    nn.Sequential(
                        nn.Linear(hidden_dim, hidden_dim),
                        nn.ReLU(),
                        nn.Linear(hidden_dim, hidden_dim)
                    ),
                    edge_dim=1
                )
            )

        # Capa final: score per node
        self.node_classifier = nn.Linear(hidden_dim, 1)

    def forward(self, x, edge_index, edge_attr, batch):
        h = self.mlp_in(x)

        for conv in self.convs:
            h = conv(h, edge_index, edge_attr)
            h = F.relu(h)

        logits = self.node_classifier(h).squeeze(-1)
        return logits

In [6]:
torch.cuda.empty_cache()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

model = TSPGNN(
    hidden_dim=16,   # 32 o 64
    num_layers=2
).to(device)

num_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {num_params:,}")

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=1e-3
)

criterion = nn.CrossEntropyLoss()

best_val_acc = 0.0
patience = 20
patience_counter = 0

Device: cuda
Model parameters: 1,489


In [7]:
def train_one_epoch(model, loader, optimizer, device):
    model.train()
    total_loss = 0.0
    total_correct = 0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        optimizer.zero_grad()

        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        losses = []
        correct = 0

        for i in range(len(batch.y)):
            start = batch.ptr[i].item()
            end = batch.ptr[i+1].item()

            logits_i = logits[start:end]          # [num_nodes_graph_i]
            target_i = batch.y[i].item()          # scalar

            # CrossEntropyLoss expects [1, num_classes] and [1]
            loss_i = F.cross_entropy(
                logits_i.view(1, -1),
                torch.tensor([target_i], device=device)
            )
            losses.append(loss_i)

            pred_i = logits_i.argmax().item()
            if pred_i == target_i:
                correct += 1

        loss = torch.stack(losses).mean()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        total_correct += correct
        total_graphs += len(batch.y)

    avg_loss = total_loss / len(loader)
    accuracy = total_correct / total_graphs

    return avg_loss, accuracy


@torch.no_grad()
def evaluate(model, loader, device):
    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_graphs = 0

    for batch in loader:
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.edge_attr, batch.batch)

        losses = []
        correct = 0

        for i in range(len(batch.y)):
            start = batch.ptr[i].item()
            end = batch.ptr[i+1].item()

            logits_i = logits[start:end]
            target_i = batch.y[i].item()

            loss_i = F.cross_entropy(
                logits_i.view(1, -1),
                torch.tensor([target_i], device=device)
            )
            losses.append(loss_i)

            pred_i = logits_i.argmax().item()
            if pred_i == target_i:
                correct += 1

        loss = torch.stack(losses).mean()

        total_loss += loss.item()
        total_correct += correct
        total_graphs += len(batch.y)

    avg_loss = total_loss / len(loader)
    accuracy = total_correct / total_graphs

    return avg_loss, accuracy

In [8]:
num_epochs = 200

for epoch in range(1, num_epochs + 1):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, device)
    val_loss, val_acc = evaluate(model, val_loader, device)

    print(
        f"Epoch {epoch:03d} | "
        f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
        f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}"
    )

    # Early stopping
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "best_model.pt")
        print("  → New best model saved")
    else:
        patience_counter += 1
        if patience_counter >= patience:
            print("Early stopping triggered")
            break

RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
