In [None]:
#default_exp topo_solvers

In [None]:
#exporti
from collections import defaultdict

import numpy as np
import torch

In [None]:
#hide
from nbdev.showdoc import show_doc

# Eval module

In [None]:
#exporti
class EvalModule:
    """
    A class that contains methods for the evaluation of topo solvers.

    All methods are static; an instance is only needed for the callable form
    ``EvalModule()(topo_solver, criteria, dataloader)``.
    """

    @staticmethod
    def _push_to_device(solutions, device):
        """Set the ``device`` attribute of every solution in `solutions` to `device`."""
        for solution in solutions:
            solution.device = device


    @staticmethod
    @torch.no_grad()
    def _run_epoch(topo_solver, dataloader, criteria):
        """
        Run `topo_solver` over every batch of `dataloader` and collect criterion values.

        Parameters
        ----------
        topo_solver : callable
            Called as ``topo_solver(problems_or_solutions, eval_mode=True)``; its
            ``device`` attribute is assigned to every ground-truth solution.
        dataloader : iterable
            Yields ``(problems_or_solutions, gt_solutions)`` pairs.
        criteria : iterable
            Callables with a ``name`` attribute, called as
            ``criterion(solutions, gt_solutions, binary=True)`` and returning a
            tensor that is detached, moved to CPU and iterated element-wise.

        Returns
        -------
        collections.defaultdict
            Maps each criterion name to a list of ``numpy.float64`` values, in
            dataloader order.
        """
        # Materialize once: criteria are iterated once per batch and once more at
        # the end, which would silently yield nothing for a one-shot iterable.
        criteria = list(criteria)
        criteria_dict = defaultdict(list)

        for problems_or_solutions, gt_solutions in dataloader:
            EvalModule._push_to_device(gt_solutions, device=topo_solver.device)
            solutions = topo_solver(problems_or_solutions, eval_mode=True)

            # Sanity check: predictions and ground truth must live on the same device
            # before criteria compare their tensors.
            assert solutions[0].θ.device == gt_solutions[0].θ.device, f"EvalModule: {solutions[0].θ.device=}, but {gt_solutions[0].θ.device=}."

            for criterion in criteria:
                criterion_values = criterion(solutions, gt_solutions, binary=True)
                criteria_dict[criterion.name].extend(criterion_values.detach().cpu().numpy())

        # Convert once, after all batches are collected. ``np.float64`` is used because
        # the ``np.float_`` alias no longer exists in NumPy 2.0.
        for criterion in criteria:
            criteria_dict[criterion.name] = [np.float64(value) for value in criteria_dict[criterion.name]]

        return criteria_dict


    @staticmethod
    @torch.no_grad()
    def get_first_solutions(topo_solver, n_solutions, dataloader):
        """
        Returns the first `n_solutions` solutions obtained via the topo_solver from problems from the dataloader.

        Parameters
        ----------
        topo_solver : callable
            Called as ``topo_solver(problems_or_solutions, eval_mode=True)``.
        n_solutions : int
            Maximum number of solutions to return; values <= 0 yield an empty list.
        dataloader : iterable
            Yields ``(problems_or_solutions, gt_solutions)`` pairs.

        Returns
        -------
        list
            At most `n_solutions` solutions, in dataloader order (fewer if the
            dataloader is exhausted first).
        """
        return_solutions = []
        if n_solutions <= 0:
            # Without this guard the loop below would append one solution before checking.
            return return_solutions

        for problems_or_solutions, gt_solutions in dataloader:
            # NOTE(review): gt_solutions are not used here; kept for parity with _run_epoch.
            EvalModule._push_to_device(gt_solutions, device=topo_solver.device)
            solutions = topo_solver(problems_or_solutions, eval_mode=True)

            for solution in solutions:
                return_solutions.append(solution)
                if len(return_solutions) >= n_solutions:
                    return return_solutions

        return return_solutions


    @staticmethod
    def __call__(topo_solver, criteria, dataloader):
        """
        Evaluate criteria with outputs from the topo solver.

        Parameters
        ----------
        topo_solver : callable
            Solver run in ``eval_mode=True`` on each batch of `dataloader`.
        criteria : iterable
            Named criterion callables (see `_run_epoch`).
        dataloader : iterable
            Yields ``(problems_or_solutions, gt_solutions)`` pairs.

        Returns
        -------
        collections.defaultdict
            Maps each criterion name to its list of ``numpy.float64`` values.
        """
        criteria_dict = EvalModule._run_epoch(
            topo_solver=topo_solver,
            dataloader=dataloader,
            criteria=criteria,
        )
        return criteria_dict