In [None]:
#default_exp topo_solvers

In [None]:
#exporti
import os
import torch

from dl4to.topo_solvers import TopoSolver
from dl4to.solution import Solution

In [None]:
#hide
from nbdev.showdoc import show_doc

# Trivial solver

In [None]:
#export
class TrivialSolver(TopoSolver):
    """
    A topo solver that returns a constant ("trivial") density for every problem it is given.

    Parameters
    ----------
    θ_default : float
        The value each entry of the returned density is set to, i.e. the factor by which an
        all-ones density is multiplied before being wrapped in a `Solution`.
    device : str
        The device of the topo solver. Possible options are "cpu" and "cuda".
    """
    def __init__(self, θ_default=1., device='cpu'):
        super().__init__(device=device)
        self.θ_default = θ_default


    def _get_name(self):
        """Returns the display name of this solver."""
        solver_name = "TrivialSolver"
        return solver_name


    def _get_new_solution(self, solution):
        """
        Builds a constant-density solution for a single input object.

        Parameters
        ----------
        solution : Solution or Problem
            If the object has a `problem` attribute, that problem is used; otherwise the
            object itself is treated as the problem.

        Returns
        -------
        Solution
            A `Solution` built from the problem and a density tensor of shape
            `(1, *problem.shape)` whose entries are `self.θ_default`.
        """
        # Solutions carry their problem; bare problems are used as-is.
        problem = getattr(solution, 'problem', solution)
        density_shape = (1, *problem.shape)
        # NOTE: the dtype is read from the incoming object, not from `problem`.
        ones = torch.ones(density_shape, device=self.device, dtype=solution.dtype)
        return Solution(problem, ones * self.θ_default)


    def _get_new_solutions(self, solutions, eval_mode):
        """
        Applies `_get_new_solution` to every element of `solutions` and returns the results
        as a list. `eval_mode` is accepted but not used by this solver.
        """
        return list(map(self._get_new_solution, solutions))

In [None]:
#hide
import shutil
from dl4to.criteria import Binariness, WeightedBCE, Fail
from dl4to.datasets import TopoDataset, BasicDataset
from dl4to.utils import get_dataloader

In [None]:
%%time
#hide

def test_attributes():
    """Checks that a TrivialSolver is not trainable and that `.to('cpu')` can be called on it."""
    solver = TrivialSolver(θ_default=1.)
    assert not solver.trainable
    solver.to('cpu')


test_attributes()

CPU times: user 52 µs, sys: 33 µs, total: 85 µs
Wall time: 92.7 µs


In [None]:
%%time
#hide

def test_that_it_is_callable():
    """Smoke test: calling the solver directly on the `ledge` problem must not raise."""
    solver = TrivialSolver(θ_default=1.)
    ledge_problem = BasicDataset().ledge()
    # The result is intentionally discarded; only the absence of an exception is checked.
    solver(ledge_problem)


test_that_it_is_callable()

CPU times: user 5.28 ms, sys: 8.83 ms, total: 14.1 ms
Wall time: 27.6 ms


In [None]:
#hide
from dl4to.pde import FDM
def get_dataloader_ledge(batch_size):
    """
    Returns a dataloader over a two-sample dataset in which both samples are the `ledge`
    problem (resolution 30, FDM PDE solver with `padding_depth=0`) paired with its
    trivial solution.
    """
    ledge = BasicDataset(resolution=30).ledge()
    ledge.pde_solver = FDM(padding_depth=0)
    sample = (ledge, ledge.trivial_solution)
    return get_dataloader(TopoDataset([sample, sample]), batch_size=batch_size)

In [None]:
%%time
#slow
#hide

def test_that_we_can_run_evaluate_over_dataset(batch_size, root="tmp_test_folder"):
    """
    Runs `TrivialSolver.eval` over the small ledge dataloader with a few criteria and checks
    that it completes and writes its output folder.

    Parameters
    ----------
    batch_size : int
        Batch size passed to `get_dataloader_ledge`.
    root : str
        Folder handed to `eval` as its output root; it is removed afterwards.
    """
    trivial_solver = TrivialSolver(θ_default=1.)
    dataloader = get_dataloader_ledge(batch_size)

    criteria = [
        Binariness(),
        WeightedBCE(),
        Fail(),
    ]

    try:
        trivial_solver.eval(
            root=root,
            dataloader=dataloader,
            criteria=criteria,
        )
        # The original cleanup implicitly required `eval` to create `root`; keep that check explicit.
        assert os.path.isdir(root), f"eval did not create the output folder {root!r}"
    finally:
        # Always clean up, even when `eval` fails, so no stale folder leaks into later runs.
        # `ignore_errors` keeps a missing folder from masking the original exception.
        shutil.rmtree(root, ignore_errors=True)


for batch_size in (1, 2, 4):
    test_that_we_can_run_evaluate_over_dataset(batch_size=batch_size)

CPU times: user 3.12 s, sys: 41.9 ms, total: 3.16 s
Wall time: 994 ms
