In [None]:
#default_exp topo_solvers

In [None]:
#exporti
import copy
import time
import torch
import warnings
import numpy as np
import matplotlib.pyplot as plt
from collections import defaultdict

from dl4to.plotting import plot_curve
from dl4to.utils import create_dir, save_dict_as_txt, cast_to_solutions

In [None]:
#hide
from nbdev.showdoc import show_doc

# Training module

In [None]:
#exporti
class TrainModuleVerboseUtils:
    """
    Static helpers that report the losses stored in a topo solver's log dictionary.

    The log dictionary is read under the keys `train_epochs`, `train_losses` and
    `train_losses_std` and, if validation has been logged, `val_epochs`, `val_losses`
    and `val_losses_std`. Each key maps to a sequence with one entry per logged epoch.
    """
    @staticmethod
    def print_current_losses(topo_solver_logs):
        """
        Print the most recent training loss and, if available, the most recent validation loss.

        Parameters
        ----------
        topo_solver_logs : dict
            The logs of the topo solver. Must contain at least one training entry.
        """
        logs = topo_solver_logs
        train_epoch    = logs["train_epochs"][-1]
        train_loss     = logs['train_losses'][-1]
        train_loss_std = logs['train_losses_std'][-1]
        print(f"Train epoch: {train_epoch}. Train loss: {train_loss:.2}±{train_loss_std:.2}.")

        # The key may be missing (no validation configured) or present but still empty
        # (e.g. created as a side effect of a defaultdict lookup); both mean "nothing to print".
        val_losses = logs.get("val_losses")
        if val_losses is None or len(val_losses) == 0:
            return
        val_epoch    = logs["val_epochs"][-1]
        val_loss     = val_losses[-1]
        val_loss_std = logs['val_losses_std'][-1]
        print(f'Valid epoch: {val_epoch}. Valid loss: {val_loss:.2}±{val_loss_std:.2}.')


    @staticmethod
    def plot_train_and_val_losses(topo_solver_logs):
        """
        Plot the logged training losses and, if any were logged, the validation losses.

        Parameters
        ----------
        topo_solver_logs : dict
            The logs of the topo solver.

        Raises
        ------
        AssertionError
            If the epoch, loss and std lists of the training (or validation) logs differ in length.
        """
        logs = topo_solver_logs
        assert len(logs['train_epochs']) == len(logs['train_losses']) == len(logs['train_losses_std']), "Training loss log inconsistent."

        # `.get` neither raises on a plain dict nor inserts empty entries into a defaultdict
        # when training ran without validation.
        val_epochs     = logs.get('val_epochs', [])
        val_losses     = logs.get('val_losses', [])
        val_losses_std = logs.get('val_losses_std', [])
        assert len(val_epochs) == len(val_losses) == len(val_losses_std), "Validation loss log inconsistent."

        _, axes = plt.subplots(figsize=(7, 3), dpi=200, sharex=False)
        plot_curve(
            x=logs['train_epochs'],
            y=logs['train_losses'],
            y_std=logs['train_losses_std'],
            label="Training loss",
            axis=axes,
            show_all_xticks=False
        )
        # Only draw the validation curve if there is validation data to show.
        if len(val_epochs) > 0:
            plot_curve(
                x=val_epochs,
                y=val_losses,
                y_std=val_losses_std,
                label="Validation loss",
                axis=axes,
                show_all_xticks=False
            )
        plt.legend()
        plt.show()

In [None]:
#exporti
class EpochLossGetter:
    """
    Runs one pass over a dataloader with a trainable topo solver and reports loss statistics.

    Parameters
    ----------
    topo_solver : dl4to.TopoSolver
        The solver to run. Its `model`, `criterion` and `optimizer` attributes are looked up
        once at construction; a missing attribute is treated like `None`.

    Raises
    ------
    AttributeError
        If the solver has no model, is not trainable, or has no criterion.
    """
    def __init__(self, topo_solver):
        self.solver = topo_solver
        self.model = getattr(self.solver, 'model', None)
        self.criterion = getattr(self.solver, 'criterion', None)
        self.optimizer = getattr(self.solver, 'optimizer', None)
        self._check_attr()


    def _push_to_device(self, solutions):
        """Assign the solver's device to the `device` attribute of every given solution."""
        for solution in solutions:
            solution.device = self.solver.device


    def _check_attr(self):
        """Validate that the solver provides a model and a criterion and is flagged as trainable."""
        if self.model is None:
            raise AttributeError("TrainModule cannot train w/o a model!")
        if not self.solver.trainable:
            raise AttributeError(f"Topo solver `{self.solver.name}` is not trainable.")
        if self.criterion is None:
            raise AttributeError("TrainModule cannot find a criterion!")


    def _run_batch(self, problems_or_solutions, gt_solutions, train):
        """
        Run the solver on one batch and, if `train` is set, perform one optimizer step.

        Returns
        -------
        list
            The per-sample losses of the batch as `np.float64` values.

        Raises
        ------
        AttributeError
            If `train` is set but the solver has no optimizer.
        """
        # The optimizer is not checked at construction time because an evaluation-only use
        # does not need one. Check it here, before the forward pass, so a missing optimizer
        # fails fast with a readable message.
        if train and self.optimizer is None:
            raise AttributeError("TrainModule cannot train w/o an optimizer!")

        self._push_to_device(gt_solutions)
        solutions = self.solver(problems_or_solutions, eval_mode=not train)
        # The criterion is read from the solver at call time.
        losses = self.solver.criterion(solutions, gt_solutions, binary=False)
        loss = losses.mean()
        if train:
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        # NOTE(review): the evaluation path does not disable autograd here; presumably the
        # solver handles that via `eval_mode` — confirm.
        return list(losses.detach().cpu().numpy().astype(np.float64))


    def __call__(self, dataloader, train):
        """
        Process every batch of `dataloader` once.

        Parameters
        ----------
        dataloader : iterable
            Yields pairs `(problems_or_solutions, gt_solutions)`.
        train : bool
            Whether to update the model (`True`) or only evaluate it (`False`).

        Returns
        -------
        tuple
            Mean and standard deviation of all per-sample losses seen in this pass.
        """
        if train:
            self.model.train()
        else:
            self.model.eval()
        losses = []
        for problems_or_solutions, gt_solutions in dataloader:
            losses.extend(
                self._run_batch(
                    problems_or_solutions=problems_or_solutions,
                    gt_solutions=gt_solutions,
                    train=train,
                )
            )
        # The model is always left in eval mode, also after a training pass.
        self.model.eval()
        return np.mean(losses), np.std(losses)

In [None]:
#export
class TrainModule:
    """
    A class that contains methods for the training of topo solvers.

    Parameters
    ----------
    topo_solver : dl4to.TopoSolver
        The trainable topo solver that should be trained.
    """
    def __init__(self, topo_solver):
        self.solver = topo_solver
        self.epoch_module = EpochLossGetter(self.solver)


    def _train_epoch(self, dataloader, epoch):
        """Train for one epoch and append the epoch index, mean loss and loss std to the solver logs."""
        train_loss, train_loss_std = self.epoch_module(dataloader=dataloader, train=True)
        self.solver.logs['train_epochs'].append(epoch)
        self.solver.logs['train_losses'].append(train_loss)
        self.solver.logs['train_losses_std'].append(train_loss_std)


    def _eval_epoch(self, dataloader, epoch, verbose):
        """Evaluate on `dataloader`, log the validation loss statistics and optionally print them."""
        val_loss, val_loss_std = self.epoch_module(dataloader=dataloader, train=False)
        self.solver.logs['val_epochs'].append(epoch)
        self.solver.logs['val_losses'].append(val_loss)
        self.solver.logs['val_losses_std'].append(val_loss_std)

        if verbose:
            TrainModuleVerboseUtils.print_current_losses(topo_solver_logs=self.solver.logs)


    def _write_solver_to_disc(self, dir_path, prefix, tick):
        """
        Store the elapsed time since `tick` in the logs, then save the solver and its logs.

        The solver is written to `{dir_path}/{prefix}_solver.pt` and the logs are passed to
        `save_dict_as_txt` with the file name `{prefix}_logs`.
        """
        self.solver.logs['duration'] = time.time() - tick
        torch.save(self.solver, f'{dir_path}/{prefix}_solver.pt')
        save_dict_as_txt(self.solver.logs, dir_path=dir_path, file_name=f'{prefix}_logs')


    def _write_solver_to_disc_if_best_yet(self, dir_path, tick):
        """
        Save the solver with the prefix "best" if the latest validation loss is the lowest so far.

        Returns
        -------
        float
            The best validation loss logged so far.
        """
        val_losses = self.solver.logs['val_losses']
        val_loss = val_losses[-1]
        best_val_loss = min(val_losses)
        # `best_val_loss` includes the latest entry, so a tie with an earlier epoch also saves.
        if val_loss > best_val_loss:
            return best_val_loss
        self._write_solver_to_disc(dir_path=dir_path, prefix="best", tick=tick)
        return val_loss


    def _create_folder_and_save_topo_solver_args(self, root, epochs, patience):
        """
        Create the `train_results` directory below `root` and write the solver description into it.

        The description consists of the solver's own arguments (`get_args_as_dict`) merged with
        `root`, `epochs` and `patience`; the latter win on key collisions.

        Returns
        -------
        The value returned by `create_dir` (used as the output directory path).
        """
        dir_path = create_dir(name="train_results", path=root, prepend_date=False)
        train_args = {
            'root': root,
            'epochs': epochs,
            'patience': patience,
        }
        args_dict = {**self.solver.get_args_as_dict(), **train_args}
        save_dict_as_txt(my_dict=args_dict, dir_path=dir_path, file_name="solver_description")

        return dir_path


    def _get_best_epoch(self):
        """Return the epoch index at which the lowest validation loss was logged."""
        val_losses = self.solver.logs['val_losses']
        val_epochs = self.solver.logs['val_epochs']
        assert len(val_losses) == len(val_epochs), "TrainModule: len(val_losses) != len(val_epochs)"
        best_val_loss_idx = np.argmin(val_losses)
        return val_epochs[best_val_loss_idx]


    def _input_check(self, dataloader_val, validation_interval, patience):
        """
        Reject argument combinations that cannot work.

        Raises
        ------
        ValueError
            If validation or early stopping is requested without a validation dataloader,
            or if `validation_interval` is 0.
        """
        if dataloader_val is None:
            if validation_interval is not None:
                raise ValueError("You can not validate without a dataloader_val. Set validation_interval=None or add dataloader_val.")
            if patience != "inf":
                raise ValueError("patience != 'inf' requires dataloader_val != None.")
            return
        # `epoch % 0` would otherwise raise a ZeroDivisionError only after the first
        # training epoch has already been computed.
        if validation_interval is not None and validation_interval == 0:
            raise ValueError("validation_interval must not be 0. Set validation_interval=None to disable validation.")


    def __call__(
        self, root, dataloader_train,
        dataloader_val=None, epochs=100, 
        validation_interval=10, verbose=True,
        patience=None):
        """
        Run the training for the topo solver.

        Parameters
        ----------
        root : str
            The directory where the training results should be saved.
        dataloader_train : torch.utils.data.Dataloader
            The dataloader that contains the training data.
        dataloader_val : torch.utils.data.Dataloader
            The dataloader that contains the validation data. Required unless `validation_interval` is `None`.
        epochs : int
            The maximal number of training epochs.
        validation_interval : int
            The number of epochs after which a validation step is performed and printed. The last epoch is always validated. Use `None` to disable validation.
        verbose : bool
            Whether to print information on the current training status, like the current loss and epoch.
        patience : int
            If, at a validation step, more than `patience` epochs have passed since the epoch with the best validation loss, then the training is stopped. The best solver is only saved to disk (`best_solver.pt`); the solver in memory is not reset to it. `None` disables early stopping.

        Raises
        ------
        ValueError
            If validation or early stopping is requested without `dataloader_val`, or if `validation_interval` is 0.
        """
        if patience is None:
            patience = "inf"
        self._input_check(dataloader_val, validation_interval, patience)
        dir_path = self._create_folder_and_save_topo_solver_args(root, epochs, patience)
        tick = time.time()

        # Tracked separately from the loop variable, which stays unbound if `epochs` is 0.
        epochs_run = 0
        for epoch in range(epochs):
            self._train_epoch(dataloader_train, epoch)
            epochs_run = epoch + 1
            validation_free_epoch = (validation_interval is None) or ((epoch % validation_interval != 0) and (epoch != epochs - 1))
            if validation_free_epoch:
                continue
            self._eval_epoch(dataloader_val, epoch, verbose)
            self._write_solver_to_disc_if_best_yet(dir_path, tick)
            if patience != "inf":
                we_have_lost_hope = epoch > self._get_best_epoch() + patience
                if we_have_lost_hope:
                    break
        self._write_solver_to_disc(dir_path=dir_path, prefix="last", tick=tick)
        if verbose:
            TrainModuleVerboseUtils.plot_train_and_val_losses(topo_solver_logs=self.solver.logs)
        print(f'Finished training after {epochs_run} epochs.\n')

In [None]:
#hide
from dl4to.criteria import WeightedBCE

In [None]:
#hide
class MockTrainableTopoSolver:
    """Minimal stand-in for a trainable topo solver, used by the tests in this notebook."""
    def __init__(self):
        # Attributes that are set to real objects.
        self.trainable = True
        self.model = torch.nn.Sequential()
        self.criterion = WeightedBCE()
        # Attributes that exist on the mock but are left unset.
        self.logs = self.optimizer = self.metrics = None

In [None]:
%%time
#hide

def test_that_we_can_instanciate():
    """Smoke test: building a `TrainModule` around the mock solver must not raise."""
    TrainModule(MockTrainableTopoSolver())


test_that_we_can_instanciate()

CPU times: user 124 µs, sys: 23 µs, total: 147 µs
Wall time: 154 µs
