In [1]:
import torch
from torch import Tensor
import torch.nn as nn
import torch_geometric.nn as gnn

from icecream import ic

In [33]:
import numpy as np
from pathlib import Path
from torch_geometric.data import Data, Dataset
from torch_geometric.loader import DataLoader
from sklearn.model_selection import train_test_split

class TspDataset(Dataset):
    def __init__(self, split = "train"):
        super().__init__()
        data_root = Path("../data/processed_heuristic-threshold")
        instances = data_root / "instances" / "1"
        instances = sum((list(instance.iterdir()) for instance in instances.iterdir() if instance.stem != "1"), [])

        train_instances, val_instances = train_test_split(instances, test_size = 0.2, random_state = 42)
        self.instances = train_instances if split == "train" else val_instances
    
    def get(self, idx):
        entry = self.instances[idx]
        with np.load(entry / "pairwise.npz") as data:
            distance_matrix = torch.tensor(data['arr_0'], dtype=torch.float)
        distance_matrix = distance_matrix / distance_matrix.max()
        route_mask = torch.from_numpy(np.loadtxt(entry / "sol_mask.txt", dtype=np.float32))[:-1, :-1]
        route_mask = route_mask + route_mask.mT
        
        # route_distance = data['route_distance']
        # distance_matrix = distance_matrix.cuda()
        # route_mask = route_mask.cuda()

        graph = Data(x=distance_matrix[:, :1], edge_attr=distance_matrix, y=route_mask)
        return graph
    
    def len(self):
        return len(self.instances)

In [34]:
from icecream import ic
from einops import rearrange

# class ConvBlock(nn.Module):
#     def __init__(self, input_dim, output_dim):
#         super().__init__()
#         self.activation_fn = nn.ReLU()
#         self.graph_conv = gnn.DenseGATConv(input_dim, output_dim)

#     def forward(self, graph):
#         node_feats = self.graph_conv(graph.x, graph.edge_attr)
#         node_feats = self.activation_fn(node_feats)

#         graph.x = node_feats
#         return graph

class ConvBlock(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.activation_fn = nn.ReLU()

        self.node_conv = gnn.DenseGATConv(input_dim, output_dim)
        
        self.node_edge_conv = gnn.DenseGATConv(input_dim, output_dim)
        self.edge_value_conv = nn.Conv2d(output_dim, 1, 1)
        self.edge_edge_conv = nn.Conv2d(input_dim, output_dim, 1)
    
    def forward(self, graph):    
        node_feats, edge_values, edge_feats = graph
        
        new_node_feats = self.node_conv(node_feats, edge_values)
        
        node_edge_feats = self.node_edge_conv(node_feats, edge_values)
        edge_edge_feats = self.edge_edge_conv(edge_feats)
       
        node_edge_feats = rearrange(node_edge_feats, "b n c -> b c n 1") + rearrange(node_edge_feats, "b n c -> b c 1 n")

        new_edge_feats = node_edge_feats + edge_edge_feats
        
        new_edge_values = torch.sigmoid(self.edge_value_conv(new_edge_feats))
        new_edge_values = rearrange(new_edge_values, "b 1 n1 n2 -> b n1 n2")

        new_node_feats = self.activation_fn(new_node_feats)
        new_edge_feats = self.activation_fn(new_edge_feats)

        return (new_node_feats, new_edge_values, new_edge_feats)

        # node_feats2 = self.graph_conv2(graph.x, graph.edge_attr)
        # node_feats2 = self.activation_fn(node_feats2)
        
        # edge_feats1 = rearrange(node_feats1, "b n c -> b c n 1")
        # edge_feats2 = rearrange(node_feats2, "b n c -> b c 1 n")
        
        # edge_feats = self.edge_conv1(edge_feats1) + self.edge_conv2(edge_feats2)
        # # edge_feats = edge_feats1 @ edge_feats2

        # # gate = self.gate_linear(node_feats.mean(-2))
        # # gate = torch.sigmoid(gate)
        # # edge_feats = gate * edge_feats + (1 - gate) * graph.edge_attr

        # node_feats = node_feats1 + node_feats2# + graph.x.mean(-1, keepdim=True)
        # edge_feats = edge_feats# + graph.edge_attr

        # return Data(x = node_feats, edge_attr = edge_feats)

class Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.model = nn.Sequential(
            ConvBlock(1, 32),
            ConvBlock(32, 32),
            ConvBlock(32, 32),
            ConvBlock(32, 1)
        )

    def forward(self, graph):
        node_feats, edge_feats = graph.x, graph.edge_attr
        node_feats, edge_feats = node_feats.unsqueeze(0), edge_feats.unsqueeze(0)
        graph = (node_feats, edge_feats, edge_feats.unsqueeze(1))
        _, _, edge_feats = self.model(graph)

        out = rearrange(edge_feats, "1 1 n1 n2 -> n1 n2")

        # graph.edge_attr = torch.softmax(graph.edge_attr, -1)
        # out = torch.sigmoid(out)
        return out

In [None]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

train_dataset = TspDataset()
val_dataset = TspDataset(split = "val")

train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, pin_memory=True)
val_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=False, pin_memory=True)
# loss_fn = nn.BCELoss(reduction = 'none')
# loss_fn = nn.BCEWithLogitsLoss(reduction = 'none')
loss_fn = nn.CrossEntropyLoss()

def dice_loss(y_pred, y_true):
    y_pred = torch.sigmoid(y_pred)
    # Flatten the predictions and ground truth
    y_true_flat = y_true.flatten()
    y_pred_flat = y_pred.flatten()

    # Compute the intersection and union
    intersection = (y_true_flat * y_pred_flat).sum() + 1
    union = (y_true_flat).sum() + (y_pred_flat).sum() + 1

    # Compute the Dice loss
    dice_loss = 1 - 2 * intersection / union

    return dice_loss

# loss_fn = dice_loss
# loss_fn = nn.MSELoss()

from tqdm import tqdm

while True:
    lowest_val_loss = float("inf")
    patience = 0
    loss_arr = []
    std_arr = []
    model = Model().to(device)
    optimizer = torch.optim.Adam(model.parameters())
    for epoch in range(200):
        if patience >= 3:
            continue
        cu_loss = None
        cu_std = None
        pbar = tqdm(train_dataloader, delay=1)
        for i, batch in enumerate(pbar):
            graph = batch.to(device)
            out = model(graph)
        
            loss = loss_fn(out, graph.y)
    
            # loss = loss * graph.y
            loss = loss.mean()
        
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    
            loss = loss.detach().cpu().item()
            loss_arr.append(loss)
            cu_loss = cu_loss + loss if cu_loss is not None else loss

            std = torch.std(out)
            std = std.detach().cpu().item()
            std_arr.append(std)
            cu_std = cu_std + std if cu_std is not None else std
    
            pbar.set_postfix(loss = cu_loss / (i + 1), std = cu_std / (i + 1))

            if i > 10 and (cu_std / (i + 1) < 1e-4):
                break
        else:
            with torch.no_grad():
                cu_loss = None
                for batch in val_dataloader:
                    graph = batch.to(device)
                    out = model(graph)
        
                    loss = loss_fn(out, graph.y)
                    loss = loss.mean()
                    cu_loss = cu_loss + loss if cu_loss is not None else loss
            val_loss = cu_loss / len(val_dataloader)
            print(f"epoch {epoch}: {val_loss}")
            if val_loss < lowest_val_loss:
                patience = 0
                lowest_val_loss = val_loss
            else:
                patience += 1
            continue
        break
    else:
        break

100%|██████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.39it/s, loss=4.86, std=0.6]


epoch 0: 4.587519645690918


100%|████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.47it/s, loss=4.43, std=0.844]


epoch 1: 4.195655345916748


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.28it/s, loss=4.09, std=1.06]


epoch 2: 3.953949213027954


100%|██████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.24it/s, loss=3.93, std=1.2]


epoch 3: 3.8702948093414307


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.66it/s, loss=3.86, std=1.28]


epoch 4: 3.8247079849243164


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 51.76it/s, loss=3.81, std=1.35]


epoch 5: 3.853071689605713


100%|██████████████████████████████████████████████████████████| 1080/1080 [00:21<00:00, 51.33it/s, loss=3.78, std=1.4]


epoch 6: 3.7317054271698


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 51.49it/s, loss=3.75, std=1.43]


epoch 7: 3.726921319961548


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:21<00:00, 49.85it/s, loss=3.74, std=1.45]


epoch 8: 3.751734495162964


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.42it/s, loss=3.72, std=1.48]


epoch 9: 3.6599910259246826


100%|██████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 51.43it/s, loss=3.7, std=1.51]


epoch 10: 3.7452828884124756


100%|██████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 51.66it/s, loss=3.7, std=1.55]


epoch 11: 3.6534581184387207


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:21<00:00, 50.92it/s, loss=3.68, std=1.59]


epoch 12: 3.693665027618408


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:20<00:00, 52.02it/s, loss=3.68, std=1.58]


epoch 13: 3.6370913982391357


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:21<00:00, 50.72it/s, loss=3.67, std=1.64]


epoch 14: 3.636951446533203


100%|█████████████████████████████████████████████████████████| 1080/1080 [00:19<00:00, 54.08it/s, loss=3.67, std=1.64]


epoch 15: 3.7290799617767334


 80%|██████████████████████████████████████████████▎           | 863/1080 [00:17<00:04, 54.07it/s, loss=3.64, std=1.66]

In [None]:
graph = tsp_dataset[11]

import matplotlib.pyplot as plt

with torch.no_grad():
    # plt.imshow(graph.edge_attr.squeeze())
    # plt.show()
    out = model(graph)
    plt.imshow(out)
    plt.colorbar()
    plt.show()
    plt.imshow(graph.y)
    plt.show()

In [None]:
import matplotlib.pyplot as plt

plt.plot(loss_arr)
plt.show()
plt.plot(std_arr)
plt.show()