In [20]:
import numpy as np
import torch
from torch.utils.data import Dataset

# Step 1: Generate a distance matrix
n = 50
distance_matrix = np.random.rand(n, n)
# Make the matrix symmetric and zero diagonal
distance_matrix = (distance_matrix + distance_matrix.T) / 2
np.fill_diagonal(distance_matrix, 0)

# Step 2: Generate a dataset of possible routes
class RoutesDataset(Dataset):
    def __init__(self, distance_matrix, num_routes):
        self.distance_matrix = distance_matrix
        self.num_routes = num_routes

    def __len__(self):
        return self.num_routes

    def __getitem__(self, idx):
        route1 = np.random.permutation(len(self.distance_matrix))
        distance1 = sum(self.distance_matrix[route1[i-1], route1[i]] for i in range(len(route1)))
        
        route2 = np.random.permutation(len(self.distance_matrix))
        distance2 = sum(self.distance_matrix[route2[i-1], route2[i]] for i in range(len(route2)))
        
        return (route1, distance1), (route2, distance2)

routes_dataset = RoutesDataset(distance_matrix, 1000)

In [21]:
import torch.nn as nn

class RouteClassifier(nn.Module):
    def __init__(self):
        super(RouteClassifier, self).__init__()
        self.fc1 = nn.Linear(n, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, route1, route2):
        x1 = torch.relu(self.fc1(route1))
        x1 = torch.relu(self.fc2(x1))
        x1 = self.fc3(x1)

        x2 = torch.relu(self.fc1(route2))
        x2 = torch.relu(self.fc2(x2))
        x2 = self.fc3(x2)

        return torch.sigmoid(x1 - x2) # probability that route1 is shorter than route2
    
# Example:
classifier = RouteClassifier()
# n length of the route
route1 = torch.randn(32, n)
route2 = torch.randn(32, n)

output = classifier(route1, route2)

In [33]:
from torch.utils.data import DataLoader
from torch.optim import Adam

# Prepare the data
data_loader = DataLoader(routes_dataset, batch_size=32, shuffle=True)

# Initialize the classifier and the optimizer
classifier = RouteClassifier()
optimizer = Adam(classifier.parameters(), lr=0.001)

# Train the classifier
for epoch in range(1_000):  # number of epochs
    print(f"Epoch {epoch}")
    loss_sum = 0
    correct_predictions = 0
    total_predictions = 0
    for (route1, distance1), (route2, distance2) in data_loader:
        route1 = route1.float()
        route2 = route2.float()
        label = (distance1 < distance2).float().view(-1, 1)  # reshape the labels

        # Forward pass
        output = classifier(route1, route2)
        predictions = output.round()  # compute the predictions
        correct_predictions += (predictions == label).sum().item()
        total_predictions += label.size(0)

        loss = nn.BCELoss()(output, label)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        loss_sum += loss.item()
    print(f"Loss: {loss_sum / len(data_loader):.2f}", f"Accuracy: {correct_predictions / total_predictions:.2f}")
    # print(f"Accuracy: {correct_predictions / total_predictions}")
    

Epoch 0
Loss: 0.79 Accuracy: 0.50
Epoch 1
Loss: 0.76 Accuracy: 0.47
Epoch 2
Loss: 0.72 Accuracy: 0.50
Epoch 3
Loss: 0.72 Accuracy: 0.49
Epoch 4
Loss: 0.72 Accuracy: 0.50
Epoch 5
Loss: 0.71 Accuracy: 0.49
Epoch 6
Loss: 0.71 Accuracy: 0.51
Epoch 7
Loss: 0.70 Accuracy: 0.51
Epoch 8
Loss: 0.69 Accuracy: 0.54
Epoch 9
Loss: 0.70 Accuracy: 0.49
Epoch 10
Loss: 0.70 Accuracy: 0.51
Epoch 11
Loss: 0.69 Accuracy: 0.51
Epoch 12
Loss: 0.70 Accuracy: 0.52
Epoch 13
Loss: 0.70 Accuracy: 0.53
Epoch 14
Loss: 0.70 Accuracy: 0.49
Epoch 15
Loss: 0.70 Accuracy: 0.49
Epoch 16
Loss: 0.70 Accuracy: 0.50
Epoch 17
Loss: 0.70 Accuracy: 0.53
Epoch 18
Loss: 0.70 Accuracy: 0.49
Epoch 19
Loss: 0.70 Accuracy: 0.50
Epoch 20
Loss: 0.70 Accuracy: 0.52
Epoch 21
Loss: 0.70 Accuracy: 0.51
Epoch 22
Loss: 0.69 Accuracy: 0.52
Epoch 23
Loss: 0.69 Accuracy: 0.53
Epoch 24
Loss: 0.69 Accuracy: 0.51
Epoch 25
Loss: 0.70 Accuracy: 0.49
Epoch 26
Loss: 0.70 Accuracy: 0.49
Epoch 27
Loss: 0.70 Accuracy: 0.49
Epoch 28
Loss: 0.70 Accuracy: 