In [None]:
#default_exp topo_solvers

In [None]:
#exporti
import os
import torch
from collections import defaultdict

from dl4to.topo_solvers import TopoSolver
from dl4to.utils import get_dataloader, cast_to_problems

In [None]:
#hide
from nbdev.showdoc import show_doc

# Oracle Solver

In [None]:
#export
class OracleSolver(TopoSolver):
    """
    A topo solver that gets a topo dataset of problems and solutions and returns the ground truth solution for any given problem object from the dataset.
    """
    def __init__(self, 
                 dataset:"dl4to.dataset.TopoDataset", # The dataset which is used to look for the given problem and return the associated ground truth solution.
                 device:str='cpu' # The device of the topo solver. Possible options are "cpu" and "cuda".
                ):
        super().__init__(device=device, name="OracleSolver")
        self.logs = defaultdict(list)
        self.dataset = dataset


    def _get_new_solutions(self, solutions, eval_mode):
        """
        Look up each given problem in `self.dataset` and return the ground-truth solution stored with it.

        `solutions` is converted to problems via `cast_to_problems` first. Each problem is then
        located with `==` comparison against `self.dataset.get_problems()`. The ground-truth
        solution is taken from element `[1]` of the matching dataset entry.

        `eval_mode` is not used by this solver. It is presumably required by the `TopoSolver`
        interface (TODO confirm).

        Returns a list of ground-truth solutions, in the same order as the input problems.
        Raises `ValueError` if any problem is not contained in the dataset.
        """
        problems = cast_to_problems(solutions)
        problems_in_dataset = self.dataset.get_problems()

        # Resolve all indices before fetching anything, so a missing problem fails fast.
        problem_indices = []
        for position, problem in enumerate(problems):
            try:
                problem_indices.append(problems_in_dataset.index(problem))
            except ValueError:
                # The default `list.index` message embeds the full repr of the problem and does
                # not say which batch entry failed. Replace it with an actionable one.
                raise ValueError(
                    f"OracleSolver: problem at position {position} is not contained in the dataset, "
                    "so no ground-truth solution is available for it."
                ) from None

        gt_solutions = [self.dataset[i][1] for i in problem_indices]
        return gt_solutions

In [None]:
#hide
from dl4to.datasets import BasicDataset, TopoDataset

In [None]:
%%time
#hide

def test_with_ledge_and_trivial_solution():
    """The oracle must return the stored ground-truth solution for a problem that is in its dataset."""
    ledge_problem = BasicDataset().ledge()
    expected = ledge_problem.trivial_solution

    single_pair_dataset = TopoDataset()
    # Overwrite the dataset contents with exactly one (problem, solution) pair.
    single_pair_dataset.dataset = [(ledge_problem, expected)]

    returned = OracleSolver(dataset=single_pair_dataset)(ledge_problem)
    assert torch.allclose(returned.θ, expected.θ)


test_with_ledge_and_trivial_solution()

Found 0 files.
importing dataset...


0it [00:00, ?it/s]

done!
CPU times: user 27 ms, sys: 6 ms, total: 33 ms
Wall time: 36.9 ms
