In [None]:
# helper functions

def load(path):
    """Parse a TSP text file into a list of ``(points, tour)`` samples.

    Every non-blank line is split on the literal word ``output``: the
    whitespace-separated numbers before it are read as floats and paired up
    into ``(x, y)`` points, the whitespace-separated numbers after it are read
    as ints. Blank lines (e.g. a trailing empty line) are skipped.

    Args:
        path: Path of the text file to read.

    Returns:
        A list with one ``(points, y)`` tuple per non-blank line, where
        ``points`` is a list of ``(float, float)`` tuples and ``y`` is the list
        of ints found after ``output``, exactly as written in the file.

    Raises:
        ValueError: If a non-blank line does not contain ``output`` exactly
            once, or holds a token that is not a valid number.
        IndexError: If a line holds an odd number of coordinates.
    """
    ds = []
    with open(path, "r") as f:
        # Iterate the file lazily instead of materialising it with readlines().
        for line in f:
            # A blank line would otherwise break the two-way unpack below.
            if not line.strip():
                continue
            x, y = line.split("output")
            x = [float(num) for num in x.split()]
            y = [int(num) for num in y.split()]
            # Consecutive floats form one (x, y) coordinate pair.
            points = [(x[i], x[i+1]) for i in range(0, len(x), 2)]
            ds.append((points, y))
    return ds

def tsp_length(points, path):
    """Return the total length of the closed tour through ``points``.

    Args:
        points: Tensor of coordinates, indexed along dim 0 by ``path``
            (assumes shape ``(num_points, 2)`` -- TODO confirm with callers).
        path: Index tensor giving the visiting order.

    Returns:
        A scalar tensor: the sum of Euclidean distances between consecutive
        visited points, including the leg from the last point back to the
        first.
    """
    # Row k is the k-th visited point.
    ordered = points[path]

    # Row k is the point visited right after row k of `ordered`; the shift
    # wraps around, so the last point is paired with the first.
    successors = ordered.roll(-1, dims=0)

    # Euclidean length of every leg, then the tour total.
    leg_lengths = (ordered - successors).pow(2).sum(dim=1).sqrt()
    return leg_lengths.sum()

def mean_tsp_length(x, y):
    """Average ``tsp_length`` over a batch.

    Args:
        x: Batched coordinates; ``x.shape[0]`` is used as the batch size.
        y: Batched tours; ``y[i]`` is the tour for ``x[i]``.

    Returns:
        The mean of the per-sample tour lengths (sum divided by batch size).
    """
    n_samples = x.shape[0]
    tour_lengths = [tsp_length(x[k, :], y[k]) for k in range(n_samples)]
    return sum(tour_lengths) / len(tour_lengths)

In [None]:
# Dataset and collate_fn
import torch
from torch.utils.data import Dataset, DataLoader

class TspDataset(Dataset):
    """Map-style dataset that serves pre-loaded samples from a sequence."""

    def __init__(self, data):
        """Keep a reference to ``data``; the sequence is not copied."""
        super().__init__()
        self.data = data

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the sample stored at ``index`` unchanged."""
        sample = self.data[index]
        return sample
    
def collate_fn(batch):
    """Turn a list of ``(points, tour)`` samples into two tensors.

    Args:
        batch: Sequence of ``(points, tour)`` pairs of equal size.

    Returns:
        ``(inputs, targets)``: ``inputs`` is the tensor built from all point
        lists; ``targets`` is the tensor of tours with the final entry dropped
        and every index shifted down by one.
    """
    coords, tours = zip(*batch)
    coord_tensor = torch.tensor(coords)

    # The last entry of each tour repeats the start city, so it is dropped.
    open_tours = torch.tensor([tour[:-1] for tour in tours])

    # Tours in the data count cities from 1; shift to 0-based indices.
    return coord_tensor, open_tours - 1


In [None]:
# simple trainer
import torch.optim as optim
import matplotlib.pyplot as plt

class SimpleTrainer:
    """Minimal step-based training loop with loss and tour-length tracking.

    The model is called as ``model(X, y)`` and must return the 3-tuple
    ``(indices, all_logits, loss)``; only ``indices`` and ``loss`` are used.

    Attributes:
        losses: ``(step, loss)`` tuples, one per training step.
        pred_lengths: ``(step, length)`` tuples holding the batch-mean length
            of the predicted tours divided by the number of cities.
        gt_lengths: Same as ``pred_lengths`` for the ground-truth tours.
    """

    def __init__(self, model, device='cuda'):
        """Move ``model`` to ``device`` and create empty metric histories.

        Args:
            model: The ``torch.nn.Module`` to train.
            device: Target device. If a CUDA device is requested but CUDA is
                not available, training falls back to the CPU.
        """
        if torch.device(device).type == 'cuda' and not torch.cuda.is_available():
            print("CUDA is not available; falling back to CPU.")
            device = 'cpu'
        self.model = model.to(device)
        self.device = device
        self.losses = []  # Store (step, loss) tuples
        self.pred_lengths = []  # Store (step, pred length per city) tuples
        self.gt_lengths = []  # Store (step, gt length per city) tuples

    def train(self, train_loader, num_steps=5000, lr=1e-3, log_interval=100):
        """Run ``num_steps`` AdamW updates, cycling through ``train_loader``.

        Args:
            train_loader: Iterable of ``(X, y)`` batches; it is restarted
                whenever it is exhausted.
            num_steps: Number of optimizer steps to perform.
            lr: Learning rate passed to AdamW.
            log_interval: Print progress every this many steps; ``0`` or
                ``None`` disables progress printing.
        """
        optimizer = optim.AdamW(self.model.parameters(), lr=lr)
        self.model.train()

        train_iter = iter(train_loader)

        print(f"Starting training for {num_steps} steps...")

        for step in range(num_steps):
            # Get batch, starting a new pass over the loader when it runs out
            try:
                X, y = next(train_iter)
            except StopIteration:
                train_iter = iter(train_loader)
                X, y = next(train_iter)

            X, y = X.to(self.device), y.to(self.device)

            # Training step
            optimizer.zero_grad()
            indices, _, loss = self.model(X, y)
            loss.backward()

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)

            optimizer.step()

            # Tour lengths are monitoring-only, so no autograd tracking is needed
            with torch.no_grad():
                pred_length = mean_tsp_length(X, indices)
                gt_length = mean_tsp_length(X, y)

            # Normalize by problem size (number of cities)
            problem_size = X.shape[1]  # Assuming X has shape [batch_size, num_cities, 2]
            pred_length_normalized = float(pred_length) / problem_size
            gt_length_normalized = float(gt_length) / problem_size

            # Store metrics as plain Python numbers so no tensors are kept alive
            loss_value = loss.item()
            self.losses.append((step, loss_value))
            self.pred_lengths.append((step, pred_length_normalized))
            self.gt_lengths.append((step, gt_length_normalized))

            # Print progress (a falsy log_interval turns logging off instead of
            # dividing by zero)
            if log_interval and step % log_interval == 0:
                print(f"Step {step:4d} | Loss: {loss_value:.4f}")
                print(f"Step {step:4d} | TSP mean pred len/city: {pred_length_normalized:.4f}")
                print(f"Step {step:4d} | TSP mean gt   len/city: {gt_length_normalized:.4f}")

        print("Training completed!")

    def plot_loss(self):
        """Plot the recorded training loss against the step index."""
        if not self.losses:
            print("No losses to plot!")
            return

        steps, losses = zip(*self.losses)
        plt.figure(figsize=(10, 6))
        plt.plot(steps, losses)
        plt.title('Training Loss')
        plt.xlabel('Step')
        plt.ylabel('Loss')
        plt.grid(True)
        plt.show()

In [None]:
# Paths of the TSP-5 training and test files.
train_path = "tsp_data/tsp_5_train/tsp5.txt"
test_path = "tsp_data/tsp_5_train/tsp5_test.txt"

# Deliberately not named `train` / `test`: the notebook defines an evaluation
# function called `test` further down, and re-running either cell would
# otherwise overwrite the other object.
train_data, test_data = load(train_path), load(test_path)

train_ds = TspDataset(train_data)
train_dl = DataLoader(train_ds, batch_size=32, collate_fn=collate_fn)

test_ds = TspDataset(test_data)
test_dl = DataLoader(test_ds, batch_size=32, collate_fn=collate_fn)

In [None]:
from ptr_net import PtrNet

# Pointer network over 2-D points, trained on the TSP-5 loader.
model = PtrNet(
    input_size=2,
    hidden_size=128,
    start_token_value=(-1.0, -1.0),
)
trainer = SimpleTrainer(model)
trainer.train(train_dl, num_steps=100_000)

In [5]:
import torch

@torch.no_grad()
def test(model, test_loader):
    model.eval()
    device = next(model.parameters()).device
    
    gt_distances = []
    pred_distances = []
    for batch in test_loader:
        X, y = batch
        X, y = X.to(device), y.to(device)
        indices, _ = model(X)
        pred_distances.append(mean_tsp_length(X, indices))
        gt_distances.append(mean_tsp_length(X, y))
    
    print(f"Mean gt distance {sum(gt_distances)/len(gt_distances)}")
    print(f"Mean pred distance {sum(pred_distances)/len(pred_distances)}")


In [None]:
# Evaluate the current model on the TSP-5 test loader; prints the mean
# ground-truth and mean predicted tour length.
test(model, test_dl)

Mean gt distance 2.1227197647094727
Mean pred distance 2.1257359981536865


In [None]:
# Plot the per-step training loss recorded by the trainer.
trainer.plot_loss()

In [None]:
import os
import random
from torch.utils.data import IterableDataset

class TspMultiDataset(IterableDataset):
    """Endless stream of pre-batched samples drawn from several TSP files.

    Every file in ``data_dir`` is parsed with ``load`` into its own dataset.
    Each yielded item is a contiguous slice of ``batch_size`` samples taken
    from ONE randomly chosen dataset, so all samples in an item come from the
    same file.
    """

    def __init__(self, batch_size, data_dir="tsp_data/tsp_5-20_train/"):
        """Load every regular file found in ``data_dir``.

        Args:
            batch_size: Number of consecutive samples in each yielded slice.
            data_dir: Directory holding the data files.
        """
        super().__init__()
        # sorted(): os.listdir returns entries in arbitrary order.
        candidates = [os.path.join(data_dir, name) for name in sorted(os.listdir(data_dir))]
        # Skip sub-directories (e.g. Jupyter's .ipynb_checkpoints), which
        # cannot be opened as data files.
        paths = [p for p in candidates if os.path.isfile(p)]
        self.datasets = [load(p) for p in paths]
        self.batch_size = batch_size

    def __iter__(self):
        """Yield random contiguous slices forever; this iterator never ends."""
        while True:
            idx_ds = random.randint(0, len(self.datasets)-1)
            dataset = self.datasets[idx_ds]
            # random.randint includes BOTH endpoints, so the last valid start
            # is len - batch_size. Clamp at 0 so a dataset no longer than one
            # batch yields everything it has instead of raising ValueError.
            last_start = max(0, len(dataset) - self.batch_size)
            idx = random.randint(0, last_start)
            yield dataset[idx:idx+self.batch_size]

def collate_batch_fn(batch):
    """Collate one pre-batched item produced by ``TspMultiDataset``.

    Args:
        batch: The list handed over by the DataLoader; only its first element
            (a list of ``(points, tour)`` samples) is used, which matches a
            DataLoader configured with ``batch_size=1``.

    Returns:
        The ``(inputs, targets)`` tensors produced by ``collate_fn`` for that
        list of samples.
    """
    # Unwrap the DataLoader's outer list and reuse the regular collate logic
    # rather than duplicating it.
    return collate_fn(batch[0])

In [16]:
import ptr_net
from importlib import reload
reload(ptr_net)
from ptr_net import PtrNet


# Fresh, wider pointer network for the multi-size run.
model = PtrNet(
    input_size=2,
    hidden_size=256,
    start_token_value=(-1.0, -1.0),
)

# Each item produced by this dataset is already a batch of 128 samples.
multi_ds = TspMultiDataset(batch_size=128)

In [None]:
from torch.utils.data import DataLoader

# batch_size=1 because every dataset item is already a full batch;
# collate_batch_fn unwraps the one-element list the DataLoader builds.
multi_train_dl = DataLoader(
    multi_ds,
    batch_size=1,
    collate_fn=collate_batch_fn,
)
trainer = SimpleTrainer(model)
trainer.train(
    multi_train_dl,
    num_steps=1_000_000,
    log_interval=1_000,
)

In [None]:
# Plot the per-step training loss of the multi-size run.
trainer.plot_loss()

In [11]:
# Persist only the model weights (state_dict), not the whole module.
# NOTE(review): the "627k" in the file name presumably marks the training
# step at which this checkpoint was taken -- confirm.
torch.save(model.state_dict(), "627k_ptrnet_5-20_tsp.pth")

In [7]:
# test
import torch
from ptr_net import PtrNet


# Rebuild the architecture with the same hyper-parameters used for training,
# then restore the saved weights.
model = PtrNet(input_size = 2, hidden_size = 256, start_token_value=(-1.0, -1.0))
# map_location="cpu": the checkpoint may hold CUDA tensors; mapping them to
# the CPU lets this cell run on machines without a GPU. load_state_dict then
# copies the values into the (CPU) model parameters.
state_dict = torch.load("627k_ptrnet_5-20_tsp.pth", map_location="cpu")
model.load_state_dict(state_dict)

<All keys matched successfully>

In [10]:
# Evaluate the restored model on the tsp_20 test file
# (presumably 20-city instances, going by the file name -- confirm).
test_20 = TspDataset(load("tsp_data/tsp_20_test.txt"))
test_20_dl = DataLoader(test_20, batch_size=32, collate_fn=collate_fn)
test(model, test_20_dl)

Mean gt distance 4.242640495300293
Mean pred distance 4.282567977905273


In [10]:
# Evaluate the restored model on the tsp_40 test file
# (presumably 40-city instances, going by the file name -- confirm).
test_40 = TspDataset(load("tsp_data/tsp_40_test.txt"))
test_40_dl = DataLoader(test_40, batch_size=32, collate_fn=collate_fn)
test(model, test_40_dl)

Mean gt distance 5.821553707122803
Mean pred distance 8.100526809692383


In [11]:
# Evaluate the restored model on the tsp50 test file
# (presumably 50-city instances, going by the file name -- confirm).
test_50 = TspDataset(load("tsp_data/tsp50_test.txt"))
test_50_dl = DataLoader(test_50, batch_size=32, collate_fn=collate_fn)
test(model, test_50_dl)

Mean gt distance 6.434800624847412
Mean pred distance 10.240120887756348
