In [10]:
import sys
import os

# Make the notebook's working directory importable, so that packages living
# next to the notebook (the later cells import from `src`) can be resolved.
module_path = os.path.abspath(os.getcwd())
if module_path not in sys.path:
    sys.path.append(module_path)

In [20]:
# create dataset
!cd src && python3 dataset_creator.py --city --n_samples 5 --n_tasks 4 --out_file ../data/test.pkl

Traceback (most recent call last):
  File "/Users/hxi/University/PMLR/stochastic-vehicle-routing/src/dataset_creator.py", line 5, in <module>
    from src.city import City
ModuleNotFoundError: No module named 'src'


256

In [24]:
from functools import partial
from pathlib import Path

import numpy as np
import torch
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

from src.VSPSolver import solve_vsp
from src.utils import get_model, get_criterion, get_optimizer, get_dataloaders
from src.city import City, SimpleDirectedGraph
from src.visualization_utils import visualize_paths


# add simple tensorboard to the trainer to visualize results 
class InteractiveTrainer:
    """Train, validate and test a model, optionally logging to TensorBoard.

    The model, data loaders, optimizer and criterion are all built from
    ``config`` through the helpers in ``src.utils``. Every batch yielded by the
    loaders is expected to be an ``(inputs, labels, instance)`` triple.

    Args:
        config: Nested mapping. Keys read here are
            ``config["train"]["n_epochs"]``,
            ``config["train"]["eval_every_n_epochs"]``,
            ``config["train"]["save_every_n_epochs"]``,
            ``config["train"]["save_dir"]`` and ``config["data"]["city"]``;
            the whole mapping is also forwarded to the ``src.utils`` helpers.
        log: When true, a ``SummaryWriter`` is created and training steps are
            logged; when false, no writer exists and logging is skipped.
        **kwargs: ``logdir`` (optional) is the TensorBoard log directory. If it
            is omitted, ``SummaryWriter`` falls back to its own default.
    """

    def __init__(self, config, log=False, **kwargs):
        self.config = config
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.model = get_model(config, device=self.device)
        self.train_loader, self.val_loader, self.test_loader = get_dataloaders(config)

        optimizer_class, optimizer_kwargs = get_optimizer(config)
        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_kwargs)

        # The criterion wraps a solver callable that depends on the current
        # instance's graph, so it has to be rebuilt for every batch.
        criterion_class, criterion_kwargs = get_criterion(config)
        self.criterion = lambda func: criterion_class(func, **criterion_kwargs)

        self.n_epochs = config["train"]["n_epochs"]
        self.eval_every = config["train"]["eval_every_n_epochs"]
        self.save_every = config["train"]["save_every_n_epochs"]
        self.save_dir = Path(config["train"]["save_dir"])
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.with_city = config["data"]["city"]

        # Defined here so that train_epoch() also works when called directly.
        self.global_steps = 0
        # .get() avoids a KeyError when logging is requested without a logdir.
        self.writer = SummaryWriter(kwargs.get("logdir")) if log else None

    def update_log(self, instance, iter: int, loss: float, solution: list[int], labels: list[int], **kwargs):
        """Write one training step to TensorBoard; no-op when logging is off.

        Args:
            instance: The batch's problem instance. A path figure is only
                produced when it is a ``City``.
            iter: Global step used as the x-axis value.
            loss: Scalar loss of the step (a float or a one-element tensor).
            solution: Solver output handed to ``visualize_paths``.
            labels: Currently unused.
            **kwargs: Currently unused.
        """
        if self.writer is None:
            return
        if isinstance(instance, City):
            figure = visualize_paths(instance, solution)
            self.writer.add_figure(f"City/city instance at iter {iter}", figure, global_step=iter)
        self.writer.add_scalar("Loss", float(loss), iter)

    def _forward(self, inputs, labels, instance):
        """Run the model on one batch and return ``(theta, solver, loss)``.

        ``solver`` is ``solve_vsp`` bound to the instance's graph and ``loss``
        is the criterion output reduced with ``.mean()``.
        """
        graph = instance.graph if self.with_city else instance
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        theta = self.model(inputs)
        solver = partial(solve_vsp, graph=graph)
        loss = self.criterion(solver)(theta, labels).mean()
        return theta, solver, loss

    def _evaluate(self, loader, progress=False):
        """Return the list of per-batch losses over ``loader`` without gradients."""
        self.model.eval()
        batches = tqdm(loader) if progress else loader
        losses = []
        with torch.no_grad():
            for inputs, labels, instance in batches:
                _, _, loss = self._forward(inputs, labels, instance)
                losses.append(loss.item())
        return losses

    def compute_metrics(self, i):
        """Print the mean validation loss on epochs that are multiples of ``eval_every``."""
        if i % self.eval_every != 0:
            return
        losses = self._evaluate(self.val_loader)
        print(f"Validation loss: {np.mean(losses):.3f}")

    def save_model(self, i):
        """Save the model's state dict as ``epoch{i}.pt`` every ``save_every`` epochs (never at epoch 0)."""
        if i % self.save_every == 0 and i > 0:
            torch.save(self.model.state_dict(), self.save_dir / f"epoch{i}.pt")

    def train_epoch(self, i):
        """Run one optimisation pass over the training loader and print the mean loss.

        Args:
            i: Epoch index, used only for the progress-bar label.
        """
        self.model.train()
        losses = []
        # Wrap the loader itself (not enumerate(...)) so tqdm can see its length.
        for inputs, labels, instance in tqdm(self.train_loader, desc=f"Epoch {i}"):
            self.optimizer.zero_grad()
            theta, solver, loss = self._forward(inputs, labels, instance)
            loss_value = loss.item()
            losses.append(loss_value)

            loss.backward()
            self.optimizer.step()

            self.global_steps += 1
            if self.writer is not None:
                # Workaround to get a single solution from the solver. It is
                # only needed for the figure, so skip it when not logging.
                solution = solver(theta.detach().clone().unsqueeze(0))[0]
                self.update_log(instance, self.global_steps, loss_value, solution, labels)

        print(f"Train loss: {np.mean(losses):.3f}")

    def train(self):
        """Train for ``n_epochs`` epochs, validating and checkpointing on schedule.

        The global step counter restarts from zero on every call.
        """
        self.global_steps = 0
        for i in range(self.n_epochs):
            self.train_epoch(i)
            self.compute_metrics(i)
            self.save_model(i)
        if self.writer is not None:
            # Make sure buffered events reach disk once training is done.
            self.writer.flush()

    def test(self):
        """Print the mean loss over the test loader."""
        losses = self._evaluate(self.test_loader, progress=True)
        print(f"Test loss: {np.mean(losses):.3f}")


In [26]:
import yaml

# Load the experiment configuration; the file only needs to be open while parsing.
with open("configs/test.yaml", "r") as file:
    config = yaml.safe_load(file)

# Train and evaluate outside the `with` block so the config file handle is not
# held open for the whole run.
trainer = InteractiveTrainer(config, log=True, logdir="runs")
trainer.train()
trainer.test()

True


Epoch 0: 0it [00:00, ?it/s]

Train loss: 3.193
Validation loss: 4.004


Epoch 1: 0it [00:00, ?it/s]

Train loss: 3.612
Validation loss: 4.012


Epoch 2: 0it [00:00, ?it/s]

Train loss: 3.701
Validation loss: 4.003


Epoch 3: 0it [00:00, ?it/s]

Train loss: 3.863
Validation loss: 4.000


Epoch 4: 0it [00:00, ?it/s]

Train loss: 3.868
Validation loss: 4.000


Epoch 5: 0it [00:00, ?it/s]

Train loss: 3.887
Validation loss: 4.002


Epoch 6: 0it [00:00, ?it/s]

Train loss: 3.887
Validation loss: 4.006


Epoch 7: 0it [00:00, ?it/s]

Train loss: 3.886
Validation loss: 4.006


Epoch 8: 0it [00:00, ?it/s]

Train loss: 3.893
Validation loss: 4.004


Epoch 9: 0it [00:00, ?it/s]

Train loss: 3.881
Validation loss: 4.001


Epoch 10: 0it [00:00, ?it/s]

Train loss: 3.881
Validation loss: 4.002


Epoch 11: 0it [00:00, ?it/s]

Train loss: 3.880
Validation loss: 4.003


Epoch 12: 0it [00:00, ?it/s]

Train loss: 3.867
Validation loss: 4.003


Epoch 13: 0it [00:00, ?it/s]

Train loss: 3.891
Validation loss: 4.004


Epoch 14: 0it [00:00, ?it/s]

Train loss: 3.903
Validation loss: 4.004


Epoch 15: 0it [00:00, ?it/s]

Train loss: 3.891
Validation loss: 4.008


Epoch 16: 0it [00:00, ?it/s]

Train loss: 3.870
Validation loss: 4.003


Epoch 17: 0it [00:00, ?it/s]

Train loss: 3.903
Validation loss: 4.003


Epoch 18: 0it [00:00, ?it/s]

Train loss: 3.882


KeyboardInterrupt: 