# Importacions

In [38]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from pathlib import Path

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, GATv2Conv, LayerNorm
from torch_geometric.utils import to_dense_batch

# Càrrega de fitxers

In [39]:
test_path = Path("Datasets/test_pt")
graphs_test = [torch.load(f) for f in sorted(test_path.glob("*.pt"))]
print(f"Nombre de grafs de train: {len(graphs_test)}")

Nombre de grafs de train: 5


In [40]:
# Exemples
data = graphs_test[4]
print(data)
print(data.x)
print(data.id)

Data(x=[29, 2], edge_index=[2, 406], edge_attr=[406, 1], id=[29], name='bayg29', coord=[29])
tensor([[1., 1.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.],
        [0., 0.]])
tensor([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17, 18,
        19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29])


# Càrrega del model

In [41]:
class TSPGNN(nn.Module):
    def __init__(self, in_channels=2, hidden_channels=64, heads=4, num_layers=4):
        super().__init__()

        self.input_proj = nn.Linear(in_channels, hidden_channels * heads)  # projecte input a hidden

        self.layers = nn.ModuleList()
        self.norms = nn.ModuleList()

        # First layer ja té dimensió hidden, residual OK
        self.layers.append(GATv2Conv(hidden_channels * heads, hidden_channels, heads=heads, edge_dim=1))
        self.norms.append(LayerNorm(hidden_channels * heads))

        # Hidden layers
        for _ in range(num_layers - 1):
            self.layers.append(GATv2Conv(hidden_channels * heads, hidden_channels, heads=heads, edge_dim=1))
            self.norms.append(LayerNorm(hidden_channels * heads))

        # Output per node (score per node)
        self.out = nn.Linear(hidden_channels * heads, 1)

    def forward(self, data, return_probs=False):
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        edge_attr = edge_attr.view(-1, 1)  # assegura la forma correcta

        # Map input a hidden
        x = self.input_proj(x)

        # Residual connections
        for conv, norm in zip(self.layers, self.norms):
            h = conv(x, edge_index, edge_attr)
            h = norm(h)
            h = F.relu(h)
            x = x + h  # residual safe, dimensions coincideixen

        logits = self.out(x).squeeze(-1)  # [num_nodes_total]

        if return_probs:
            # Probabilitats per node dins cada graf
            x_dense, mask = to_dense_batch(logits.unsqueeze(-1), batch=data.batch)
            probs = torch.softmax(x_dense, dim=1)  # softmax per nodes dins el graf
            return probs, mask

        return logits

In [42]:
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = TSPGNN(
    in_channels=2,
    hidden_channels=32,
    heads=4,
    num_layers=2
).to(device)

model.load_state_dict(torch.load("model_gnn2.pt", map_location=device))
model.eval()

TSPGNN(
  (input_proj): Linear(in_features=2, out_features=128, bias=True)
  (layers): ModuleList(
    (0-1): 2 x GATv2Conv(128, 32, heads=4)
  )
  (norms): ModuleList(
    (0-1): 2 x LayerNorm(128, affine=True, mode=graph)
  )
  (out): Linear(in_features=128, out_features=1, bias=True)
)

In [43]:
data = graphs_test[4].to(device)

with torch.no_grad():
    logits = model(data)
    next_node = logits.argmax().item()
    node_id = data.id[next_node].item()

print(next_node, node_id)

27 28


# Processament Test

In [None]:
# FUNCIÓ PER ELIMINAR NODE D'UN GRAF
def remove_node(data, idx):

    # 1) Eliminar features de node
    x = torch.cat([data.x[:idx], data.x[idx+1:]], dim=0)

    # 2) Eliminar id del node
    node_id = torch.cat([data.id[:idx], data.id[idx+1:]], dim=0)

    # 3) Eliminar coordenades si existeixen
    if hasattr(data, "coord") and data.coord is not None:
        # coord pot ser llista de python → convertim a tensor si cal
        if isinstance(data.coord, list):
            coord = data.coord[:idx] + data.coord[idx+1:]
        else:
            coord = torch.cat([data.coord[:idx], data.coord[idx+1:]], dim=0)
    else:
        coord = None

    # 4) Filtrar arestes que no toquen el node
    keep_edges = (data.edge_index[0] != idx) & (data.edge_index[1] != idx)
    edge_index = data.edge_index[:, keep_edges]
    edge_attr = data.edge_attr[keep_edges]

    # 5) Reindexar nodes:
    #    si node > idx → node-1
    edge_index = edge_index.clone()
    edge_index[edge_index > idx] -= 1

    # 6) Crear graf final
    new_data = Data(
        x = x,
        id = node_id,
        edge_index = edge_index,
        edge_attr = edge_attr
    )

    # Mantenim el nom
    if hasattr(data, "name"):
        new_data.name = data.name

    # Mantenim coord si existeix
    if coord is not None:
        new_data.coord = coord

    return new_data

In [None]:
# FUNCIÓ PRINCIPAL PER EXECUTAR TSP
def run_tsp(model, data, max_stuck=2):
    """
    Executa TSP per un sol graf.
    - Opció B: si next == current, no eliminem cap node.
    - Protecció anti-bucle: si el model repeteix 'next=current' massa vegades,
      parem aquest graf però deixem continuar la resta.
    """
    device = next(model.parameters()).device

    # Fem una còpia segura en CPU per treballar
    data = data.clone().cpu()

    path = []

    # --- Trobar indices ---
    initial_idx = torch.argmax(data.x[:, 0]).item()
    current_idx = torch.argmax(data.x[:, 1]).item()

    initial_id = data.id[initial_idx].item()
    path.append(initial_id)

    print(f"\n=== START GRAPH {data.name} ===")
    print(f"Initial node ID = {initial_id}")
    print("----------------------------------------")

    step = 0
    stuck_count = 0   # comptador de casos next=current repetits

    # Bucle principal
    while data.x.size(0) > 3:

        step += 1
        current_id = data.id[current_idx].item()

        # Predicció
        logits = model(data.to(device)).cpu().detach()
        next_idx = logits.argmax().item()
        next_id  = data.id[next_idx].item()

        print(f"\nStep {step}:")
        print("Nodes disponibles (IDs):", data.id.tolist())
        print(f"Current ID:   {current_id}")
        print(f"Predicted ID: {next_id}")

        # Afegim al path
        path.append(next_id)

        # OPCIÓ B — CONTROL DE BUCLE
        if next_idx == current_idx:
            stuck_count += 1
            print(f"⚠️  Model stuck: next=current ({stuck_count}/{max_stuck})")

            if stuck_count >= max_stuck:
                print(f"❌ GRAF {data.name}: ATURAT PER BUCLE INFINIT")
                print(f"Path parcial retornat: {path}")
                print("========================================\n")
                return path  # Retornem path parcial i parem només aquest graf

            # Mantenim els flags i seguim sense eliminar nodes
            continue
        else:
            stuck_count = 0

        # Eliminar CURRENT (si no és l'inicial)
        if current_idx != initial_idx:

            remove_idx = current_idx
            data = remove_node(data, remove_idx)

            # Corregir desplaçaments
            if remove_idx < next_idx:
                next_idx -= 1
            if remove_idx < initial_idx:
                initial_idx -= 1

        # Actualitzar flags
        new_x = torch.zeros((data.x.size(0), 2))
        new_x[initial_idx, 0] = 1
        new_x[next_idx,    1] = 1
        data.x = new_x

        current_idx = next_idx

    # Finalització amb 3 nodes
    print("\n=== FINALITZACIÓ (últims 3 nodes) ===")
    print("Nodes restants (IDs):", data.id.tolist())

    remaining = list(range(data.x.size(0)))
    remaining.remove(initial_idx)
    remaining.remove(current_idx)

    other_idx = remaining[0]
    other_id = data.id[other_idx].item()

    print(f"Current ID: {data.id[current_idx].item()}")
    print(f"Other ID:   {other_id}")
    print(f"Initial ID: {initial_id}")

    # seqüència final
    path.append(other_id)

    print(f"Path final generat: {path}")
    print("========================================\n")

    return path

In [None]:
# EXECUCIÓ SOBRE TEST
all_paths = []

for i, data in enumerate(graphs_test):
    data = data.to(device)

    path = run_tsp(model, data)

    print(f"Graph {i} ({data.name}):")
    print(path)
    print()

    all_paths.append(path)


=== START GRAPH a280 ===
Initial node ID = 1
----------------------------------------

Step 1:
Nodes disponibles (IDs): [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197, 1

In [None]:
# VALIDAR PATHS
valid_paths = []

for i, path in enumerate(all_paths):
    data = graphs_test[i]
    num_nodes = data.num_nodes

    expected_nodes = set(data.id.tolist())
    path_set = set(path)

    valid_nodes = (path_set == expected_nodes)
    valid_len   = (len(path) == num_nodes)

    initial_node = data.id[torch.argmax(data.x[:, 0]).item()].item()
    valid_initial = (path[0] == initial_node)

    if valid_nodes and valid_len and valid_initial:
        print(f"Graph {i} ({data.name}): ✅ Path vàlid")

        valid_paths.append((graphs_test[i].name, path))

    else:
        print(f"Graph {i} ({data.name}): ❌ Path invàlid")

        if not valid_nodes:
            print(f"  - Falta algun node: {expected_nodes - path_set}")
            print(f"  - Nodes sobrants: {path_set - expected_nodes}")

        if not valid_len:
            print(f"  - Longitud incorrecta: {len(path)} (esperat {num_nodes})")

        if not valid_initial:
            print(f"  - L'inicial incorrecte: esperat {initial_node}, rebut {path[0]}")

Graph 0 (a280): ✅ Path vàlid
Graph 1 (ali535): ❌ Path invàlid
  - Falta algun node: {220, 21}
  - Nodes sobrants: set()
Graph 2 (att48): ❌ Path invàlid
  - Falta algun node: {3, 6, 7, 8, 11, 16, 17, 19, 30}
  - Nodes sobrants: set()
  - Longitud incorrecta: 41 (esperat 48)
Graph 3 (att532): ✅ Path vàlid
Graph 4 (bayg29): ✅ Path vàlid


In [50]:
# GUARDAR PATHS
output_dir = Path("Datasets")
output_dir.mkdir(exist_ok=True)

output_file = output_dir / "test_paths.txt"

with open(output_file, "w") as f:
    for name, path in valid_paths:
        f.write(f"{name} : {path}\n")