In [None]:
import torch
from torch import Tensor
import torch.nn as nn
from torch_geometric.nn import DenseGCNConv

In [None]:
import numpy as np
from pathlib import Path
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader

class TspDataset(Dataset):
    """TSP instances stored as one sub-directory per instance.

    Each directory under ``<data_root>/instances`` is expected to contain:

    * ``pairwise.npz`` -- array ``arr_0``: the pairwise distance matrix;
    * ``solution.npz`` -- array ``route_mask``: the training target
      (presumably 1 for edges on the reference tour, 0 otherwise -- TODO confirm).

    Each item is a ``torch_geometric.data.Data`` with
    ``x`` = constant node features of shape ``(n, 1)``,
    ``edge_attr`` = the distance matrix, and ``y`` = the route mask (float).

    Args:
        data_root: directory containing the ``instances`` folder
            (default ``"processed_heuristic"``, relative to the working dir).
    """

    def __init__(self, data_root="processed_heuristic"):
        # This dataset reads its own files; the PyG root is not used for
        # raw/processed caching, hence the empty root.
        super().__init__("")
        instances_dir = Path(data_root) / "instances"
        # iterdir() order is filesystem-dependent; sort so that index -> instance
        # is stable across machines and runs. Skip stray files (e.g. .DS_Store),
        # which would otherwise fail when loading "<entry>/pairwise.npz".
        self.instances = sorted(p for p in instances_dir.iterdir() if p.is_dir())

    def get(self, idx):
        """Load instance ``idx`` from disk and wrap it as a PyG ``Data`` object."""
        entry = self.instances[idx]
        with np.load(entry / "pairwise.npz") as pairwise:
            distance_matrix = torch.tensor(pairwise["arr_0"], dtype=torch.float)
        with np.load(entry / "solution.npz") as solution:
            route_mask = torch.tensor(solution["route_mask"], dtype=torch.float)

        num_nodes = distance_matrix.shape[0]
        # Constant *non-zero* node features. A GCN layer such as DenseGCNConv
        # computes A_norm @ (x @ W) + b; with all-zero x that is just b for every
        # node, so the distance matrix would have no effect on the output.
        node_feats = torch.ones(num_nodes, 1, dtype=torch.float)
        return Data(x=node_feats, edge_attr=distance_matrix, y=route_mask)

    def len(self):
        """Number of instance directories found under ``<data_root>/instances``."""
        return len(self.instances)

In [None]:
from icecream import ic
from einops import rearrange

class Model(nn.Module):
    """Single dense-GCN layer that scores every ordered node pair (i, j).

    Node embeddings ``h`` come from one ``DenseGCNConv`` layer, with
    ``graph.edge_attr`` used as the dense (weighted) adjacency, followed by
    ReLU. The score of pair (i, j) is a learned 1x1 convolution over the
    per-channel products ``h[i, c] * h[j, c]``.

    Args:
        in_channels: width of the input node features ``graph.x`` (default 1).
        hidden_dim: width of the node embeddings (default 32).
    """

    def __init__(self, in_channels=1, hidden_dim=32):
        super().__init__()
        self.activation_fn = nn.ReLU()
        self.conv1 = DenseGCNConv(in_channels, hidden_dim)
        # 1x1 conv: collapses the hidden_dim channel axis to one score per pair.
        self.edge_conv1 = nn.Conv2d(hidden_dim, 1, 1)

    def forward(self, graph):
        """Return ``(node_feats, edge_logits)``.

        ``node_feats`` has shape ``(b, n, hidden_dim)``; ``edge_logits`` has
        shape ``(b, n, n)``. The edge scores are raw, unbounded logits (no
        sigmoid is applied), so pair them with ``nn.BCEWithLogitsLoss``.
        """
        node_feats = self.activation_fn(self.conv1(graph.x, graph.edge_attr))
        # (b, n, c) -> (b, c, n, 1); multiplying by its transpose (b, c, 1, n)
        # yields the per-channel outer product h[:, c, i] * h[:, c, j].
        per_channel = rearrange(node_feats, "b n c -> b c n 1")
        pair_feats = per_channel @ per_channel.mT
        edge_logits = rearrange(self.edge_conv1(pair_feats), "b 1 n1 n2 -> b n1 n2")
        return node_feats, edge_logits

In [None]:
# Seed every stochastic step here (weight init, shuffling) so that
# Restart & Run All reproduces the run.
SEED = 0
torch.manual_seed(SEED)

tsp_dataset = TspDataset()
# Keep batch_size=1: the model treats graph.x / graph.edge_attr as one dense
# graph, while PyG's default collation concatenates several (n, n) edge_attr
# matrices along dim 0.
tsp_dataloader = DataLoader(tsp_dataset, batch_size=1, shuffle=True)
model = Model()
# The model outputs unbounded logits (no sigmoid). nn.BCELoss requires
# probabilities in [0, 1]; BCEWithLogitsLoss takes logits and is numerically stabler.
loss_fn = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters())

from tqdm.auto import tqdm

model.train()
progress = tqdm(tsp_dataloader)
for graph in progress:
    _, logits = model(graph)
    # Match the target's shape explicitly; a bare squeeze() would also drop
    # the node axes for a single-node instance.
    logits = logits.reshape_as(graph.y)
    # BCE losses take (prediction, target) -- in that order.
    loss = loss_fn(logits, graph.y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    progress.set_postfix(loss=f"{loss.item():.4f}")
    

100%|██████████| 1856/1856 [18:47<00:00,  1.65it/s]


<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=cf2edc64-aead-4b81-a3fd-24d0376819e4' target="_blank">
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>