# Modular TRM Training for Traveling Salesman Problem (TSP)

This notebook demonstrates a modular approach to training and evaluating a TRM neural network on synthetic TSP instances using PyTorch. The code is organized for easy adaptation to other combinatorial optimization problems.

## 1. Import Libraries and Set Up Environment
Import all required libraries, set random seeds, and configure device (CPU/GPU) for TSP experiments.

In [1]:
import os, math, random
from dataclasses import dataclass
from typing import Tuple, Any, List
import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
from torch.utils.data import Dataset, DataLoader
import sys
sys.path.append(os.path.join("..", "src"))
from exploretinyrm.trm import TRM, TRMConfig

def set_seed(seed: int = 123):
    random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
set_seed(123)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


## 2. AMP and EMA Utilities
Define automatic mixed precision (AMP) and exponential moving average (EMA) utility functions and classes for training stability.

In [2]:
try:
    from torch.amp import autocast as _autocast, GradScaler as _GradScaler
    _USE_TORCH_AMP = True
except ImportError:
    from torch.cuda.amp import autocast as _autocast, GradScaler as _GradScaler
    _USE_TORCH_AMP = False

def make_grad_scaler(is_cuda: bool):
    if _USE_TORCH_AMP:
        try:
            return _GradScaler("cuda", enabled=is_cuda)
        except TypeError:
            return _GradScaler(enabled=is_cuda)
    else:
        return _GradScaler(enabled=is_cuda)

def amp_autocast(is_cuda: bool, use_amp: bool):
    if _USE_TORCH_AMP:
        try:
            return _autocast(device_type="cuda", enabled=(is_cuda and use_amp))
        except TypeError:
            return _autocast(enabled=(is_cuda and use_amp))
    else:
        return _autocast(enabled=(is_cuda and use_amp))

class EMA:
    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        self.decay = decay
        self.shadow = {
            name: param.detach().clone()
            for name, param in model.named_parameters()
            if param.requires_grad
        }

    def update(self, model: torch.nn.Module) -> None:
        d = self.decay
        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad:
                    continue
                self.shadow[name].mul_(d).add_(param.detach(), alpha=1.0 - d)

    def copy_to(self, model: torch.nn.Module) -> None:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in self.shadow:
                    param.copy_(self.shadow[name])

from contextlib import contextmanager

@contextmanager
def use_ema_weights(model: torch.nn.Module, ema: EMA):
    backup = {
        name: param.detach().clone()
        for name, param in model.named_parameters()
        if param.requires_grad
    }
    ema.copy_to(model)
    try:
        yield
    finally:
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name in backup:
                    param.copy_(backup[name])

## 3. TSP Dataset Preparation
Implement dataset generation for synthetic TSP instances, including random city coordinates and optimal tour labels.

In [3]:
class TSPDataset(Dataset):
    """Synthetic TSP instances: random city coordinates and optimal tour labels."""
    def __init__(self, n_samples: int, n_cities: int = 5, seed: int = 0):
        self.n_cities = n_cities
        self.rng = np.random.default_rng(seed)
        self.samples = [self._generate_sample() for _ in range(n_samples)]
    def __len__(self): return len(self.samples)
    def __getitem__(self, idx): return self.samples[idx]
    def _generate_sample(self):
        coords = self.rng.uniform(0, 1, (self.n_cities, 2))
        # Compute distance matrix
        dist = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)
        # Find optimal tour (brute-force for small n_cities)
        from itertools import permutations
        best_tour = None
        best_length = float('inf')
        for perm in permutations(range(self.n_cities)):
            length = sum(dist[perm[i], perm[(i+1)%self.n_cities]] for i in range(self.n_cities))
            if length < best_length:
                best_length = length
                best_tour = perm
        x_tokens = coords.flatten() # [n_cities*2]
        y_tokens = np.array(best_tour, dtype=np.int64) # [n_cities]
        return torch.from_numpy(x_tokens).float(), torch.from_numpy(y_tokens)

def get_tsp_loaders(n_train=512, n_val=128, batch_size=16, n_cities=5, seed=123):
    ds_tr = TSPDataset(n_train, n_cities=n_cities, seed=seed)
    ds_va = TSPDataset(n_val, n_cities=n_cities, seed=seed+1)
    return (
        DataLoader(ds_tr, batch_size=batch_size, shuffle=True, drop_last=True, pin_memory=True),
        DataLoader(ds_va, batch_size=batch_size, shuffle=False, pin_memory=True)
    )

train_loader, val_loader = get_tsp_loaders(
    n_train=2048,
    n_val=512,
    batch_size=16,
    n_cities=5,
    seed=123
)

In [4]:
# Show some examples of the dataset
for i in range(2):
    x, y = train_loader.dataset[i]
    print(f"Example {i}:")
    print("City coordinates:", x.numpy().reshape(-1, 2))
    print("Optimal tour:", y.numpy())

Example 0:
City coordinates: [[0.6823519  0.05382102]
 [0.22035988 0.18437181]
 [0.1759059  0.8120945 ]
 [0.92334497 0.2765744 ]
 [0.81975454 0.8898927 ]]
Optimal tour: [0 3 4 2 1]
Example 1:
City coordinates: [[0.51297045 0.2449646 ]
 [0.8242416  0.21376297]
 [0.74146706 0.6299402 ]
 [0.92740726 0.23190819]
 [0.79912513 0.51816505]]
Optimal tour: [0 1 3 4 2]


## 4. Model Configuration and Initialization
Configure TRM model parameters for TSP, instantiate the model, optimizer, scaler, and EMA.

In [5]:
N_CITIES = 5
INPUT_TOKENS = 1000  # Discretized coordinate values (for embedding)
OUTPUT_TOKENS = N_CITIES  # City indices: 0..N_CITIES-1
SEQ_LEN = N_CITIES * 2  # Each city has (x, y)

D_MODEL = 128
N_SUP   = 16
N       = 6
T       = 3
USE_ATT = False

cfg = TRMConfig(
    input_vocab_size=INPUT_TOKENS,
    output_vocab_size=OUTPUT_TOKENS,
    seq_len=SEQ_LEN,
    d_model=D_MODEL,
    n_layers=2,
    use_attention=USE_ATT,
    n_heads=8,
    dropout=0.0,
    mlp_ratio=4.0,
    token_mlp_ratio=2.0,
    n=N,
    T=T,
    k_last_ops=None,
    stabilize_input_sums=True
)

model = TRM(cfg).to(device)
print("Params (M):", sum(p.numel() for p in model.parameters())/1e6)

optimizer = torch.optim.AdamW(
    model.parameters(), lr=3e-4, weight_decay=0.0, betas=(0.9, 0.95)
)

scaler = make_grad_scaler(device.type == "cuda")
ema = EMA(model, decay=0.999)

Params (M): 0.523552


NVIDIA GeForce RTX 5060 Ti with CUDA capability sm_120 is not compatible with the current PyTorch installation.
The current PyTorch install supports CUDA capabilities sm_50 sm_60 sm_61 sm_70 sm_75 sm_80 sm_86 sm_90 sm_37 compute_37.
If you want to use the NVIDIA GeForce RTX 5060 Ti GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/



## 5. Sanity Checks on Data
Run assertions to verify input shapes, value ranges, and tour validity for the TSP dataset.

In [6]:
# Sanity checks for TSP data
for i in range(len(train_loader.dataset)):
    x, y = train_loader.dataset[i]
    coords = x.numpy().reshape(N_CITIES, 2)
    assert coords.shape == (N_CITIES, 2)
    assert np.all((coords >= 0) & (coords <= 1))
    assert y.shape[0] == N_CITIES
    assert set(y.numpy()) == set(range(N_CITIES)), "Tour must visit each city exactly once"
print("Sanity OK: coordinates in [0,1], valid tour.")

Sanity OK: coordinates in [0,1], valid tour.


## 6. Training and Evaluation Functions
Define training and evaluation functions, including loss calculation, metric reporting, and EMA evaluation for TSP tours.

In [7]:
def tour_ce_loss(logits: torch.Tensor, y_true: torch.Tensor) -> torch.Tensor:
    # logits: [B, N, C], y_true: [B, N]
    B, N, C = logits.shape
    return F.cross_entropy(logits.reshape(B*N, C), y_true.reshape(B*N))

def tour_exact_match(preds, y_true):
    # Returns fraction of cities in correct position
    preds = preds.cpu().numpy() if isinstance(preds, torch.Tensor) else preds
    y_true = y_true.cpu().numpy() if isinstance(y_true, torch.Tensor) else y_true
    return np.mean(preds == y_true)

def train_one_epoch(
    model: TRM,
    loader: DataLoader,
    optimizer,
    scaler,
    epoch: int,
    use_amp: bool = True,
    ema: "EMA | None" = None
):
    model.train()
    total_ce, total_em, total_steps = 0.0, 0.0, 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device, non_blocking=True)
        y_true   = y_true.to(device,   non_blocking=True)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            optimizer.zero_grad(set_to_none=True)
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
            logits_tour = logits.float().view(logits.size(0), N_CITIES, logits.size(2))
            loss_ce = tour_ce_loss(logits_tour, y_true)
            with torch.no_grad():
                em = torch.tensor([tour_exact_match(logits_tour[i].argmax(dim=-1), y_true[i]) for i in range(x_tokens.size(0))], device=logits.device)
            loss = loss_ce
            if use_amp:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                if ema is not None:
                    ema.update(model)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                if ema is not None:
                    ema.update(model)
            total_ce   += loss_ce.detach().item()
            total_em   += em.mean().item()
            total_steps += 1
    print(f"Epoch {epoch:02d} | CE {total_ce/max(1,total_steps):.4f} | Exact match {total_em/max(1,total_steps):.3f}")

@torch.no_grad()
def evaluate(model: TRM, loader: DataLoader, n_sup_eval: int = N_SUP, show_examples: int = 3):
    model.eval()
    acc_list = []
    example_count = 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device)
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(n_sup_eval):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
        logits_tour = logits.float().view(logits.size(0), N_CITIES, logits.size(2))
        em = torch.tensor([tour_exact_match(logits_tour[i].argmax(dim=-1), y_true[i]) for i in range(x_tokens.size(0))], device=logits.device)
        acc = em.mean()
        acc_list.append(acc)
        # Print a few examples from the first batch
        if example_count < show_examples:
            preds = logits_tour.argmax(dim=-1)
            xs = x_tokens.cpu().numpy()
            ys = y_true.cpu().numpy()
            preds_np = preds.cpu().numpy()
            for i in range(min(show_examples - example_count, xs.shape[0])):
                print(f"\nExample {example_count + 1}:")
                print("City coordinates:")
                print(xs[i].reshape(N_CITIES, 2))
                print("Predicted tour:", preds_np[i])
                print("True tour:", ys[i])
                print("Exact match score:", tour_exact_match(preds_np[i], ys[i]))
                example_count += 1
    acc = torch.stack(acc_list).mean().item()
    print(f"Validation | Exact match {acc:.3f}")
    return acc

@torch.no_grad()
def evaluate_with_ema(model: TRM, ema: EMA, loader: DataLoader, n_sup_eval: int = N_SUP):
    with use_ema_weights(model, ema):
        return evaluate(model, loader, n_sup_eval=n_sup_eval)

## 7. Training Loop
Run the main training loop for several epochs, reporting metrics for both raw and EMA weights.

In [8]:
EPOCHS = 5
for epoch in range(1, EPOCHS+1):
    train_one_epoch(model, train_loader, optimizer, scaler, epoch, use_amp=False, ema=ema)
    acc_raw = evaluate(model, val_loader)
    acc_ema = evaluate_with_ema(model, ema, val_loader)
    print(f"Validation (raw) | Tour accuracy {acc_raw:.3f}")
    print(f"Validation (EMA) | Tour accuracy {acc_ema:.3f}")

RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

## 8. Model Inference and Visualization
Show predicted tours for a batch of TSP instances and compare to true optimal tours, including graphical visualization.

In [None]:
import matplotlib.pyplot as plt
@torch.no_grad()
def show_predictions(model: TRM, loader: DataLoader, n_batches: int = 1):
    model.eval()
    shown = 0
    for x_tokens, y_true in loader:
        x_tokens = x_tokens.to(device)
        y_true   = y_true.to(device)
        y_state, z_state = model.init_state(batch_size=x_tokens.size(0), device=device)
        for _ in range(N_SUP):
            y_state, z_state, logits, halt_logit = model.forward_step(
                x_tokens, y=y_state, z=z_state, n=N, T=T, k_last_ops=None
            )
        logits_tour = logits.view(logits.size(0), N_CITIES, logits.size(2))
        preds = logits_tour.argmax(dim=-1).cpu().numpy()
        xs = x_tokens.cpu().numpy()
        ys = y_true.cpu().numpy()
        for i in range(min(4, xs.shape[0])):
            coords = xs[i].reshape(N_CITIES, 2)
            pred_tour = preds[i]
            true_tour = ys[i]
            plt.figure(figsize=(6,3))
            plt.subplot(1,2,1)
            plt.title("Predicted Tour")
            plt.plot(coords[pred_tour,0], coords[pred_tour,1], marker='o')
            for j, (x, y) in enumerate(coords):
                plt.text(x, y, str(j), fontsize=12)
            plt.subplot(1,2,2)
            plt.title("True Tour")
            plt.plot(coords[true_tour,0], coords[true_tour,1], marker='o')
            for j, (x, y) in enumerate(coords):
                plt.text(x, y, str(j), fontsize=12)
            plt.suptitle(f"TSP Example {shown+i}")
            plt.show()
        shown += 1
        if shown >= n_batches:
            break

show_predictions(model, val_loader, n_batches=1)