Versió optimitzada de GNN2 amb millores en normalització d'arestes i arquitectura de la xarxa.

Millores principals:
- Normalització GLOBAL dels pesos d'arestes usant estadístiques de TRAIN únicament
- Arquitectura GNN millorada amb dropout i connexions residuals
- Scheduler learning rate sofisticat (CosineAnnealingWarmRestarts)
- Early stopping per evitar overfitting
- Validació amb una partició train/val (el conjunt de test no té etiquetes)

# Importacions

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from pathlib import Path
import numpy as np

from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GATv2Conv, LayerNorm
from torch_geometric.utils import to_dense_batch

# Càrrega de fitxers

In [None]:
train_path = Path("Datasets/train_pt")

# Load every serialized graph, in sorted file order so runs are reproducible.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects;
# only point this at trusted dataset files.
graphs_train = [
    torch.load(pt_file, weights_only=False)
    for pt_file in sorted(train_path.glob("*.pt"))
]

print(f"Nombre de grafs de train: {len(graphs_train)}")

In [None]:
# Inspect one sample graph to sanity-check tensor shapes and the raw edge-weight scale.
data = graphs_train[300]
edge_w = data.edge_attr

print(data)
print(f"Node features shape: {data.x.shape}")
print(f"Edge index shape: {data.edge_index.shape}")
print(f"Edge attr shape: {edge_w.shape}")
print(f"Target: {data.y}")
print("\nEdge weights statistics:")
print(f"Min: {edge_w.min():.4f}, Max: {edge_w.max():.4f}")
print(f"Mean: {edge_w.mean():.4f}, Std: {edge_w.std():.4f}")

# Normalització d'arestes

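**Important**: La normalització es fa amb estadístiques de TRAIN únicament (la mitjana i la desviació estàndard no es calculen amb els grafs de validació).
Això evita fuites d'informació (*data leakage*) de validació cap a l'entrenament, i aplica exactament la mateixa transformació a totes les arestes, de manera que el model veu els pesos en una escala uniforme.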

In [None]:
# Fit the normalisation statistics on the graphs that end up in the TRAIN split
# only. The train/val split cell keeps the *last* `val_ratio` fraction of
# `graphs_train` (in sorted file order) for validation, so only the leading
# `n_fit_graphs` graphs are used here; otherwise validation edge weights would
# leak into the mean/std.
val_ratio = 0.15  # must match the value used in the train/val split cell
n_fit_graphs = len(graphs_train) - int(len(graphs_train) * val_ratio)

all_edge_weights = torch.cat([g.edge_attr for g in graphs_train[:n_fit_graphs]], dim=0)
edge_mean = all_edge_weights.mean()
edge_std = all_edge_weights.std()

print("\nEdge weight statistics (before normalization):")
print(f"Min: {all_edge_weights.min():.4f}, Max: {all_edge_weights.max():.4f}")
print(f"Mean: {edge_mean:.4f}, Std: {edge_std:.4f}")

# Apply the same train-derived transform to every graph (train and future val).
# NOTE: this mutates the graphs in place, so re-running this cell without
# reloading the data would normalise the weights twice.
for g in graphs_train:
    g.edge_attr = (g.edge_attr - edge_mean) / (edge_std + 1e-8)

# Report post-normalisation stats on the fitting subset (should be ~0 mean, ~1 std).
all_normalized = torch.cat([g.edge_attr for g in graphs_train[:n_fit_graphs]], dim=0)
print("\nEdge weight statistics (after normalization):")
print(f"Min: {all_normalized.min():.4f}, Max: {all_normalized.max():.4f}")
print(f"Mean: {all_normalized.mean():.4f}, Std: {all_normalized.std():.4f}")
print(f"\nExample edge weights after normalization: {graphs_train[300].edge_attr[:5].squeeze()}")

# Train/Val Split

In [None]:
val_ratio = 0.15
n_total = len(graphs_train)
val_size = int(n_total * val_ratio)
train_size = n_total - val_size

# Deterministic tail split: the last `val_size` graphs (in sorted file order)
# become the validation set. Both slices are taken from the original list
# before `graphs_train` is rebound.
graphs_train, graphs_val = graphs_train[:train_size], graphs_train[train_size:]

print(f"Train set: {len(graphs_train)} graphs")
print(f"Val set: {len(graphs_val)} graphs")

In [None]:
class TSPGraphDataset(Dataset):
    """Thin PyG ``Dataset`` wrapper around an in-memory list of graphs.

    Args:
        graphs: indexable sequence of graph objects; items are returned as-is
            (no transform is applied).
    """

    def __init__(self, graphs):
        # Initialise the PyG base class so its bookkeeping attributes
        # (``transform``, ``_indices``, ...) exist. Without this, base-class
        # helpers such as ``indices()``, ``shuffle()`` or ``index_select()``
        # fail with AttributeError. No root/download/process hooks are
        # overridden, so nothing is read from or written to disk here.
        super().__init__()
        self.graphs = graphs

    # PyG's dataset hooks, used by the base-class helpers.
    def len(self):
        return len(self.graphs)

    def get(self, idx):
        return self.graphs[idx]

    # Direct list-like access, kept so ``len(ds)`` / ``ds[i]`` behave as before.
    def __len__(self):
        return len(self.graphs)

    def __getitem__(self, idx):
        return self.graphs[idx]

BATCH_SIZE = 32

train_dataset, val_dataset = TSPGraphDataset(graphs_train), TSPGraphDataset(graphs_val)

# Only the training batches are shuffled; validation order stays fixed so the
# metrics are reproducible between epochs.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Disseny GNN (Versió Millorada)

In [None]:
class ImprovedTSPGNN(nn.Module):
    """GATv2-based node scorer producing one logit per node.

    Architecture: linear input projection + BatchNorm + ReLU, then
    ``num_layers`` residual blocks (GATv2Conv -> LayerNorm -> ReLU -> Dropout),
    then a two-layer MLP head that outputs a single logit per node.

    Args:
        in_channels: number of input node features.
        hidden_channels: width of the hidden node representation.
        heads: number of attention heads in each GATv2 layer.
        num_layers: number of residual GATv2 blocks.
        dropout: dropout probability applied at the end of each block.
    """

    def __init__(self, in_channels=2, hidden_channels=64, heads=4, num_layers=4, dropout=0.3):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.heads = heads
        self.num_layers = num_layers
        self.dropout_rate = dropout

        # Each conv must output exactly `hidden_channels` features so that the
        # LayerNorm and the residual connection line up. If the width splits
        # evenly across heads, concatenate per-head outputs of size
        # hidden_channels // heads. Otherwise each head outputs the full width
        # and the heads are averaged (concat=False); keeping concat=True there
        # would yield hidden_channels * heads features and break LayerNorm.
        concat_heads = hidden_channels % heads == 0
        out_channels = hidden_channels // heads if concat_heads else hidden_channels

        self.input_proj = nn.Linear(in_channels, hidden_channels)
        self.input_bn = nn.BatchNorm1d(hidden_channels)

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()
        self.dropouts = nn.ModuleList()

        for _ in range(num_layers):
            self.layers.append(GATv2Conv(
                hidden_channels,
                out_channels,
                heads=heads,
                edge_dim=1,  # edge_attr is reshaped to (num_edges, 1) in forward
                concat=concat_heads,
            ))
            self.norms.append(LayerNorm(hidden_channels))
            self.dropouts.append(nn.Dropout(dropout))

        self.out_proj = nn.Linear(hidden_channels, hidden_channels // 2)
        self.out = nn.Linear(hidden_channels // 2, 1)

    def forward(self, data, return_probs=False):
        """Score every node of a (possibly batched) graph.

        Args:
            data: graph object exposing ``x``, ``edge_index``, ``edge_attr``
                and ``batch`` (the latter used only when ``return_probs``).
            return_probs: if True, return per-graph softmax probabilities in
                dense, padded form instead of flat logits.

        Returns:
            If ``return_probs`` is False: a 1-D tensor with one logit per node.
            Otherwise ``(probs, mask)`` as produced by ``to_dense_batch``:
            ``probs`` has shape (num_graphs, max_nodes, 1) and sums to 1 over
            the real nodes of each graph; padded slots have probability 0
            and are False in ``mask``.
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        edge_attr = edge_attr.view(-1, 1)

        x = F.relu(self.input_bn(self.input_proj(x)))

        for conv, norm, dropout in zip(self.layers, self.norms, self.dropouts):
            h = conv(x, edge_index, edge_attr)
            # NOTE(review): PyG's LayerNorm defaults to mode="graph"; called
            # without a batch vector it normalises over all nodes of the
            # mini-batch at once. Confirm this is intended (vs norm(h, data.batch)).
            h = norm(h)
            h = F.relu(h)
            h = dropout(h)
            # Residual connection: conv output width equals hidden_channels
            # by construction (see __init__), so shapes always match.
            x = x + h

        x = F.relu(self.out_proj(x))
        logits = self.out(x).squeeze(-1)

        if return_probs:
            # Pad with -inf (not the default 0) so padded slots get exactly zero
            # probability and do not steal softmax mass from the real nodes of
            # smaller graphs in the batch.
            x_dense, mask = to_dense_batch(
                logits.unsqueeze(-1), batch=data.batch, fill_value=float("-inf")
            )
            probs = torch.softmax(x_dense, dim=1)
            return probs, mask

        return logits

# Training

## Configuració

In [None]:
# Release cached GPU memory left over from a previous run of this notebook.
torch.cuda.empty_cache()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Device: {device}")

model = ImprovedTSPGNN(in_channels=2, hidden_channels=64, heads=4, num_layers=4, dropout=0.25)
model = model.to(device)

num_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {num_params:,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)

# NOTE(review): T_0 / T_mult are counted in scheduler.step() calls, so whether
# they mean batches or epochs depends on where step() is invoked — confirm.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer, T_0=10, T_mult=2, eta_min=1e-6
)

criterion = nn.CrossEntropyLoss()

# Early-stopping state.
best_val_acc, patience, patience_counter = 0.0, 20, 0

## Funcions

In [None]:
def train_epoch(model, loader, optimizer, criterion, device, scheduler=None):
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    Each graph in a batch is treated as its own classification problem over
    its nodes. Its node logits are scored against ``data.y[i]`` with
    ``criterion``, and the batch loss is the average of those per-graph
    losses. Gradients are clipped to norm 1.0 before every optimizer step.

    Args:
        model: network returning one logit per node for a batched graph.
        loader: iterable of batched graphs exposing ``.batch`` and ``.y``.
        optimizer: optimizer that updates ``model``'s parameters.
        criterion: loss called with ``(1, num_nodes)`` logits and a ``(1,)`` target.
        device: device each batch is moved to.
        scheduler: optional LR scheduler, stepped once per batch when given.

    Returns:
        Average of the per-batch losses over the epoch (float).
    """
    model.train()
    running = 0.0

    for graphs in loader:
        graphs = graphs.to(device)
        optimizer.zero_grad()

        node_logits = model(graphs)
        dense_logits, node_mask = to_dense_batch(node_logits.unsqueeze(-1), batch=graphs.batch)
        n_graphs = dense_logits.size(0)

        # Per-graph losses, evaluated and summed in graph order.
        per_graph = (
            criterion(
                dense_logits[g, : node_mask[g].sum(), 0].unsqueeze(0),
                graphs.y[g].unsqueeze(0),
            )
            for g in range(n_graphs)
        )
        loss = sum(per_graph, 0.0) / n_graphs
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        running += loss.item()

    return running / len(loader)

In [None]:
@torch.no_grad()
def evaluate(model, loader, device):
    """Compute top-1 accuracy and the mean normalised rank of the target node.

    For each graph, the model's per-node probabilities (restricted to the
    graph's real nodes via the mask) are ranked in descending order.

    Args:
        model: network for which ``model(data, return_probs=True)`` returns
            ``(probs_dense, mask)``, with ``probs_dense`` of shape
            (num_graphs, max_nodes, 1) and ``mask`` of shape (num_graphs, max_nodes).
        loader: iterable of batched graphs; ``data.y[i]`` is the index of the
            target node of graph ``i`` within that graph.
        device: device each batch is moved to.

    Returns:
        ``(top1_acc, mean_normalized_rank)`` as Python floats. The normalised
        rank is ``(rank - 1) / (num_nodes - 1)``: 0 means the target was ranked
        first, 1 means it was ranked last, and single-node graphs count as 0.
        Both values are 0.0 for an empty loader.
    """
    model.eval()
    total_correct_top1 = 0
    total_graphs = 0
    rank_sum = 0.0

    for data in loader:
        data = data.to(device)
        probs_dense, mask = model(data, return_probs=True)

        # Move node counts and targets to Python once per batch. This avoids a
        # device sync per graph and keeps the rank arithmetic in float64
        # Python floats instead of 0-d float32 device tensors.
        nodes_per_graph = mask.sum(dim=1).tolist()
        targets = data.y.view(-1).tolist()

        for i, num_nodes_i in enumerate(nodes_per_graph):
            probs_i = probs_dense[i, :num_nodes_i, 0]
            target_i = targets[i]

            if int(probs_i.argmax()) == target_i:
                total_correct_top1 += 1
            total_graphs += 1

            sorted_indices = torch.argsort(probs_i, descending=True)
            rank = int((sorted_indices == target_i).nonzero(as_tuple=True)[0].item()) + 1
            rank_sum += (rank - 1) / (num_nodes_i - 1) if num_nodes_i > 1 else 0.0

    top1_acc = total_correct_top1 / total_graphs if total_graphs > 0 else 0.0
    mean_normalized_rank = rank_sum / total_graphs if total_graphs > 0 else 0.0

    return top1_acc, mean_normalized_rank

## Entrenament

In [None]:
num_epochs = 100
best_val_acc = 0.0
patience_counter = 0

print(f"\nEntrenament durant {num_epochs} epochs (amb early stopping si patience={patience}):")
print("-" * 90)

for epoch in range(1, num_epochs + 1):
    # The warm-restart scheduler is stepped once per EPOCH here rather than
    # once per batch inside train_epoch. This makes T_0=10 / T_mult=2 mean
    # restarts after 10, 30 and 70 epochs, instead of after 10, 30, 70... batches.
    train_loss = train_epoch(model, train_loader, optimizer, criterion, device, scheduler=None)
    scheduler.step()

    train_acc, train_rank = evaluate(model, train_loader, device)
    val_acc, val_rank = evaluate(model, val_loader, device)

    # Checkpoint on strict improvement of validation accuracy; otherwise count
    # towards the early-stopping patience.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), "model_gnn2_improved_best.pt")
    else:
        patience_counter += 1

    if epoch % 5 == 0 or epoch == 1:
        print(f"Epoch {epoch:3d} | Loss: {train_loss:.4f} | "
              f"Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | "
              f"Train Rank: {train_rank:.4f} | Val Rank: {val_rank:.4f}")

    if patience_counter >= patience:
        print(f"\nEarly stopping at epoch {epoch} (no improvement for {patience} epochs)")
        break

print("-" * 90)
print(f"Best validation accuracy: {best_val_acc:.4f}")

## Guardar model

In [None]:
save_path = "model_gnn2_improved_final.pt"
# Checkpoint path written by the training loop whenever validation accuracy improves.
best_path = "model_gnn2_improved_best.pt"

torch.save(model.state_dict(), save_path)
print(f"Model guardat a {save_path}")
print(f"Best model guardat a {best_path}")