### Environment Setup and Dependency Installation
Skip this step unless you run into problems with dataset creation (see the section below).

!source .venv/bin/activate
%pip install -r requirements.txt --break-system-packages

### Dataset Creation
First create the dataset using the command below. If the command fails, try running it directly in the terminal instead.
After the dataset is created, adapt the config file in `configs`. The default is `configs/test.yaml`.


In [None]:
# create dataset
!python3 src/dataset_creator.py --city --n_samples 5 --n_tasks 4 --out_file data/test.pkl

### The Interactive Trainer
This is an adaptation of `src/trainer.py`. It can be used to log the city instances and the VSP solutions.

In [5]:
from functools import partial
from pathlib import Path

import numpy as np
import torch
from tqdm.auto import tqdm
from torch.utils.tensorboard import SummaryWriter

from src.VSPSolver import solve_vsp
from src.utils import get_model, get_criterion, get_optimizer, get_dataloaders
from src.city import City
from src.visualization_utils import visualize_paths, visualize_tasks


# add simple tensorboard to the trainer to visualize results 
class InteractiveTrainer:
    """Trainer with optional TensorBoard logging of losses and city figures.

    The model, data loaders, optimizer and criterion are all built from
    ``config`` through the helpers in ``src.utils``.  When ``log`` is true a
    ``SummaryWriter`` is created and every training step is logged through
    :meth:`update_log`.
    """

    def __init__(self, config, log=False, **kwargs):
        """Build all training components from ``config``.

        Args:
            config: Parsed configuration; the keys ``train.n_epochs``,
                ``train.eval_every_n_epochs``, ``train.save_every_n_epochs``,
                ``train.save_dir`` and ``data.city`` are read here.
            log: Whether to write TensorBoard logs.
            **kwargs: ``logdir`` (optional) is forwarded to ``SummaryWriter``.
        """
        self.config = config
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

        self.model = get_model(config, device=self.device)
        self.train_loader, self.val_loader, self.test_loader = get_dataloaders(config)

        optimizer_class, optimizer_kwargs = get_optimizer(config)
        self.optimizer = optimizer_class(self.model.parameters(), **optimizer_kwargs)

        # The criterion wraps a solver callable, and the solver depends on the
        # graph of the current sample, so a fresh criterion is built per sample.
        criterion_class, criterion_kwargs = get_criterion(config)
        self.criterion = lambda func: criterion_class(func, **criterion_kwargs)

        train_cfg = config["train"]
        self.n_epochs = train_cfg["n_epochs"]
        self.eval_every = train_cfg["eval_every_n_epochs"]
        self.save_every = train_cfg["save_every_n_epochs"]
        self.save_dir = Path(train_cfg["save_dir"])
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.with_city = config["data"]["city"]
        self.log = log
        # Initialised here (not only in train()) so train_epoch() can be
        # called on its own with logging enabled.
        self.global_steps = 0
        # A missing "logdir" falls back to SummaryWriter's own default
        # directory instead of raising KeyError.
        self.writer = SummaryWriter(kwargs.get("logdir")) if self.log else None

    def _forward(self, inputs, labels, instance):
        """Run the model on one sample and compute its loss.

        Returns:
            ``(theta, solver, labels, loss)`` where ``theta`` is the model
            output, ``solver`` is ``solve_vsp`` bound to the sample's graph,
            ``labels`` has been moved to ``self.device`` and ``loss`` is the
            mean of the criterion output.
        """
        graph = instance.graph if self.with_city else instance
        inputs = inputs.to(self.device)
        labels = labels.to(self.device)

        theta = self.model(inputs)
        solver = partial(solve_vsp, graph=graph)
        # .mean() is applied in every loop so that train, validation and test
        # losses are reduced the same way.
        loss = self.criterion(solver)(theta, labels).mean()
        return theta, solver, labels, loss

    # compare all paths to chosen path and SVSP path
    def update_log(self, instance, iter: int, loss: float, solution: list[int], labels: list[int], **kwargs):
        """Write the loss and, for ``City`` instances, two figures to TensorBoard.

        Does nothing when no writer exists (``log=False``).  The parameter
        name ``iter`` shadows the builtin but is kept for keyword callers.
        """
        if self.writer is None:
            return
        if isinstance(instance, City):
            solution_figure = visualize_tasks(instance, vsp_path=solution)
            gt_figure = visualize_tasks(instance, vsp_path=labels)
            self.writer.add_figure(f"City/iter {iter}: solution", solution_figure)
            self.writer.add_figure(f"City/iter {iter}: ground truth", gt_figure)
        self.writer.add_scalar("Loss", loss, iter)

    def compute_metrics(self, i):
        """Print the mean validation loss if ``i`` is a multiple of ``eval_every``."""
        if i % self.eval_every != 0:
            return
        losses = []
        self.model.eval()
        with torch.no_grad():
            for inputs, labels, instance in self.val_loader:
                _, _, _, loss = self._forward(inputs, labels, instance)
                losses.append(loss.item())

        print(f"Validation loss: {np.mean(losses):.3f}")

    def save_model(self, i):
        """Save the model's state dict to ``save_dir/epoch{i}.pt``.

        Only happens when ``i`` is a positive multiple of ``save_every``.
        """
        if i % self.save_every == 0 and i > 0:
            torch.save(self.model.state_dict(), self.save_dir / f"epoch{i}.pt")

    def train_epoch(self, i):
        """Run one optimisation pass over the training loader and print the mean loss."""
        self.model.train()
        losses = []
        # Iterate the loader directly: wrapping it in enumerate() hides its
        # length from tqdm, so the progress bar would show no total.
        for inputs, labels, instance in tqdm(self.train_loader, desc=f"Epoch {i}"):
            self.optimizer.zero_grad()
            theta, solver, labels, loss = self._forward(inputs, labels, instance)
            losses.append(loss.item())

            loss.backward()
            self.optimizer.step()

            # visual logging here
            if self.log:
                # The extra solver call is only needed for the figures, so it
                # is skipped when logging is off.
                # workaround to get single solution from the solver
                solution = solver(theta.detach().clone().unsqueeze(0))[0]
                self.global_steps += 1
                self.update_log(instance, self.global_steps, loss.item(), solution, labels)

        print(f"Train loss: {np.mean(losses):.3f}")

    def train(self):
        """Train for ``n_epochs`` epochs, evaluating and checkpointing along the way."""
        self.global_steps = 0
        for i in range(self.n_epochs):
            self.train_epoch(i)
            self.compute_metrics(i)
            self.save_model(i)
        if self.writer is not None:
            # Make sure pending events reach disk before the caller opens TensorBoard.
            self.writer.flush()

    def test(self):
        """Print the mean loss over the test loader."""
        self.model.eval()
        losses = []
        with torch.no_grad():
            for inputs, labels, instance in tqdm(self.test_loader):
                _, _, _, loss = self._forward(inputs, labels, instance)
                losses.append(loss.item())

        print(f"Test loss: {np.mean(losses):.3f}")


### Train on the dataset
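Run the following to start training our model. Logging is disabled in the cell below (`log=False`); pass `log=True` to record the loss and the figures, which are then stored in the directory given by `logdir` (`runs` here). Use `tensorboard --logdir=runs` to view them.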

In [None]:
import yaml

# Load the configuration first and close the file before the (potentially
# long-running) training starts; only the parsing needs the open handle.
with open("configs/test.yaml", "r", encoding="utf-8") as file:
    config = yaml.safe_load(file)

# log=False disables TensorBoard logging; `logdir` is only used when log=True.
trainer = InteractiveTrainer(config, log=False, logdir="runs")
trainer.train()
trainer.test()

### Inspecting the Dataset 
Check for city instances where the task is easy.

In [10]:
import pickle
from src.city import SimpleDirectedGraph, Vertex, Edge

def get_outgoing_edges(vertex: Vertex, graph: SimpleDirectedGraph, sol: list[int]) -> list[Edge]:
    """Return the edges leaving ``vertex`` that are marked with 1 in ``sol``.

    ``sol`` is indexed by the position of each edge in ``graph.get_edges()``.
    """
    return [
        edge
        for idx, edge in enumerate(graph.get_edges())
        if edge.from_vertex == vertex and sol[idx] == 1
    ]

def check_easy_problems(data_path: str, with_city: bool):
    """Print the percentage of dataset instances whose solution counts as easy.

    An instance is counted when the number of selected edges leaving the
    graph's source is either 1 or ``len(g.get_vertices()) - 2``.

    Args:
        data_path: Path to the pickled dataset.  The keys ``"Y"`` and either
            ``"cities"`` or ``"graphs"`` are read.
        with_city: If true, graphs are taken from ``city.graph`` of every
            entry in ``data["cities"]``; otherwise ``data["graphs"]`` is used.
    """
    # NOTE(security): pickle.load can execute arbitrary code; only open
    # dataset files you created yourself or otherwise trust.
    with open(data_path, "rb") as file:
        data = pickle.load(file)

    Y = data["Y"]
    graphs = [city.graph for city in data["cities"]] if with_city else data["graphs"]

    # Guard against an empty dataset, which would otherwise divide by zero.
    if len(graphs) == 0:
        print("percentage easy problems: n/a (dataset contains no instances)")
        return

    counter = 0
    for g, s in zip(graphs, Y):
        # Computed once per instance; each call scans every edge of the graph.
        n_source_edges = len(get_outgoing_edges(g.get_source(), g, s))
        # NOTE(review): the "- 2" presumably excludes source and sink, i.e.
        # every task gets its own vehicle -- confirm against src.city.
        if n_source_edges == len(g.get_vertices()) - 2 or n_source_edges == 1:
            counter += 1

    print(f"percentage easy problems: {counter/len(graphs)*100}")

In [11]:
check_easy_problems("data/test.pkl", with_city=True)

percentage easy problems: 70.0
