In [None]:
#default_exp topo_solvers

In [None]:
#exporti
import copy
import torch
from typing import Union

from dl4to.solution import Solution
from dl4to.utils import get_dataloader, cast_to_solutions, save_dict_as_txt, create_dir

In [None]:
#hide
from nbdev.showdoc import show_doc

# Topo solver

In [None]:
#export
class TopoSolver:
    """
    A parent class from which all kinds of topo solvers inherit, i.e., algorithms that solve the topology optimization task.
    Subclasses must override `_get_new_solutions`, which is used by `__call__`.
    """
    def __init__(self, 
                 device:str='cpu', # The device of the topo solver. Can be either "cpu" or "cuda".
                 name:str=None, # The name of the topo solver.
                 trainable:bool=False, # Whether the topo solver is trainable.
                 differentiable:bool=False # Whether the topo solver is differentiable.
                ):
        self._device = device
        self.name = name
        self._trainable = trainable
        self._differentiable = differentiable


    @property
    def device(self):
        """The device of the topo solver, e.g. "cpu" or "cuda"."""
        return self._device


    @device.setter
    def device(self, device):
        self._device = device


    def to(self, 
           device:str # The device of the topo solver. Can be either "cpu" or "cuda".
          ):
        """
        Move the topo solver to `device`. Returns the topo solver itself, so calls can be chained.
        """
        # In this base class only the stored device is updated.
        self.device = device
        return self


    def cuda(self):
        """
        Move the topo solver to `cuda`. Returns the topo solver itself.
        """
        # Return `self` explicitly instead of the result of `self.to(...)`, because
        # subclasses may override `to` without returning anything.
        self.to('cuda')
        return self


    def cpu(self):
        """
        Move the topo solver to `cpu`. Returns the topo solver itself.
        """
        self.to('cpu')
        return self


    @property
    def trainable(self):
        """Whether the topo solver is trainable (fixed at construction)."""
        return self._trainable


    @property
    def differentiable(self):
        """Whether the topo solver is differentiable (fixed at construction)."""
        return self._differentiable


    def clone(self):
        """
        Return a `dl4to.topo_solvers.TopoSolver` object, which is a deep copy (`copy.deepcopy`) of the current topo solver.
        """
        return copy.deepcopy(self)


    def _if_tuple_cast_to_list(self, problems_or_solutions):
        # Tuples are converted to lists so that the rest of the pipeline only has to handle lists.
        if isinstance(problems_or_solutions, tuple):
            return list(problems_or_solutions)
        return problems_or_solutions


    def _prepare_input_in_call(self, problems_or_solutions):
        # Normalize the input to a list of solutions and remember whether the caller passed a
        # collection (list/tuple) or a single object, so `__call__` can mirror that in its output.
        problems_or_solutions = self._if_tuple_cast_to_list(problems_or_solutions)
        was_list = isinstance(problems_or_solutions, list)

        if not was_list:
            problems_or_solutions = [problems_or_solutions]

        solutions = cast_to_solutions(problems_or_solutions)
        return solutions, was_list


    def _get_new_solutions(self, solutions, eval_mode):
        raise NotImplementedError("Must be overridden.")


    def __call__(self, 
                 problems_or_solutions:list, # A list containing problem and solution objects.
                 eval_mode:bool=True # Determines whether to calculate gradients for the backwards pass or not. If `True`, then no gradients are calculated.
                ):
        """
        Perform a forward pass of the topo solver. Expects a problem or solution, or a list/tuple of them.
        Returns a `dl4to.solution.Solution` object, or a list of solutions if the input was a list or tuple.
        """
        solutions, was_list = self._prepare_input_in_call(problems_or_solutions)
        solutions = self._get_new_solutions(solutions, eval_mode)

        if was_list:
            return solutions
        # A single input object must yield exactly one solution.
        assert len(solutions) == 1, f"Expected exactly one solution for a single input, got {len(solutions)}."
        return solutions[0]


    def _get_eval_dir(self, root, dataloader):
        # Create (without a date prefix) the directory `eval_on_<dataset name>` inside `root`;
        # shared by `eval` and `plot_first_solutions_from_dataloader`.
        return create_dir(name=f"eval_on_{dataloader.dataset.name}", path=root, prepend_date=False)


    def eval(self, 
             root:str, # The root directory where the evaluation results are saved.
             criteria:list, # A list of `dl4to.criteria.Criterion` objects that are used for the evaluation.
             dataloader:torch.utils.data.DataLoader # The dataloader that is used for retrieving the validation data.
            ):
        """
        Evaluate criteria with outputs from the topo solver. Returns a `collections.defaultdict` dictionary.
        The logs are also written via `save_dict_as_txt` under the file name `eval_logs` in the directory `eval_on_<dataset name>` inside `root`.
        """
        dir_path = self._get_eval_dir(root, dataloader)

        # NOTE(review): `EvalModule` is not imported in the visible import cells. It is looked up
        # in the module globals at call time, so it must be defined or imported elsewhere in this
        # module; otherwise this raises a NameError. Confirm.
        logs = EvalModule()(
            topo_solver=self,
            criteria=criteria,
            dataloader=dataloader
        )

        save_dict_as_txt(my_dict=logs, dir_path=dir_path, file_name="eval_logs")
        return logs


    def plot_first_solutions_from_dataloader(
        self,
        root:str, # The root where the plots should be saved.
        n_plots:int, # The number of solutions to plot.
        dataloader:torch.utils.data.DataLoader, # The dataloader that is used to obtain the data.
        camera_position:Union[tuple,list]=(0,.1,.12), # x, y, and z coordinates of the camera position.
        export_png:bool=True, # Whether the figure is exported and saved as a png file, in addition to the standard html format.
    ):
        """
        Saves `n_plots` plots of solutions obtained via the topo solver from problems from the dataloader.
        Each solution is also stored with `torch.save` as `solution_<i>.pt` in the directory `eval_on_<dataset name>` inside `root`.
        """
        dir_path = self._get_eval_dir(root, dataloader)

        # NOTE(review): `EvalModule` must be available in the module globals (see `eval`).
        solutions = EvalModule.get_first_solutions(
            topo_solver=self,
            n_solutions=n_plots,
            dataloader=dataloader,
        )

        for i, solution in enumerate(solutions):
            torch.save(solution, f"{dir_path}/solution_{i}.pt")

            # Plot the binarized density only; the PDE is not solved for these plots.
            solution.plot(
                binary=True,
                solve_pde=False, 
                display=False, 
                file_path=f"{dir_path}/{i}",
                camera_position=camera_position,
                show_colorbar=False, 
                export_png=export_png
            )

In [None]:
show_doc(TopoSolver.to)  # render the nbdev API documentation for `TopoSolver.to`

<h4 id="TopoSolver.to" class="doc_header"><code>TopoSolver.to</code><a href="__main__.py#L28" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.to</code>(**`device`**:`str`)

Move the topo solver to `device`.

||Type|Default|Details|
|---|---|---|---|
|**`device`**|`str`||The device of the topo solver. Can be either "cpu" or "cuda".|


In [None]:
show_doc(TopoSolver.cuda)  # render the nbdev API documentation for `TopoSolver.cuda`

<h4 id="TopoSolver.cuda" class="doc_header"><code>TopoSolver.cuda</code><a href="__main__.py#L37" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.cuda</code>()

Move the topo solver to `cuda`.

In [None]:
show_doc(TopoSolver.cpu)  # render the nbdev API documentation for `TopoSolver.cpu`

<h4 id="TopoSolver.cpu" class="doc_header"><code>TopoSolver.cpu</code><a href="__main__.py#L44" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.cpu</code>()

Move the topo solver to `cpu`.

In [None]:
show_doc(TopoSolver.clone)  # render the nbdev API documentation for `TopoSolver.clone`

<h4 id="TopoSolver.clone" class="doc_header"><code>TopoSolver.clone</code><a href="__main__.py#L61" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.clone</code>()

Return a `dl4to.topo_solvers.TopoSolver` object, which is a clone of the current topo solver.

In [None]:
show_doc(TopoSolver.__call__)  # render the nbdev API documentation for `TopoSolver.__call__`

<h4 id="TopoSolver.__call__" class="doc_header"><code>TopoSolver.__call__</code><a href="__main__.py#L90" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.__call__</code>(**`problems_or_solutions`**:`list`, **`eval_mode`**:`bool`=*`True`*)

Perform a forward pass of the topo solver. Expects a list of problems or solutions.
Returns a `dl4to.solution.Solution` object or a list of solutions, if the input was also a list.

||Type|Default|Details|
|---|---|---|---|
|**`problems_or_solutions`**|`list`||A list containing problem and solution objects.|
|**`eval_mode`**|`bool`|`True`|Determines whether to calculate gradients for the backwards pass or not. If `True`, then no gradients are calculated.|


In [None]:
show_doc(TopoSolver.eval)  # render the nbdev API documentation for `TopoSolver.eval`

<h4 id="TopoSolver.eval" class="doc_header"><code>TopoSolver.eval</code><a href="__main__.py#L107" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.eval</code>(**`root`**:`str`, **`criteria`**:`list`, **`dataloader`**:`DataLoader`)

Evaluate criteria with outputs from the topo solver. Returns a `collections.defaultdict` dictionary.

||Type|Default|Details|
|---|---|---|---|
|**`root`**|`str`||The root directory where the evaluation results are saved.|
|**`criteria`**|`list`||A list of `dl4to.criteria.Criterion` objects that are used for the evaluation.|
|**`dataloader`**|`DataLoader`||The dataloader that is used for retrieving the validation data.|


In [None]:
show_doc(TopoSolver.plot_first_solutions_from_dataloader)  # render the nbdev API documentation for this method

<h4 id="TopoSolver.plot_first_solutions_from_dataloader" class="doc_header"><code>TopoSolver.plot_first_solutions_from_dataloader</code><a href="__main__.py#L127" class="source_link" style="float:right">[source]</a></h4>

> <code>TopoSolver.plot_first_solutions_from_dataloader</code>(**`root`**:`str`, **`n_plots`**:`int`, **`dataloader`**:`DataLoader`, **`camera_position`**:`Union`\[`tuple`, `list`\]=*`(0, 0.1, 0.12)`*, **`export_png`**:`bool`=*`True`*)

Saves `n_plots` plots of solutions obtained via the topo_solver from problems from the dataloader.

||Type|Default|Details|
|---|---|---|---|
|**`root`**|`str`||The root where the plots should be saved.|
|**`n_plots`**|`int`||The number of solutions to plot.|
|**`dataloader`**|`DataLoader`||The dataloader that is used to obtain the data.|
|**`camera_position`**|`typing.Union[tuple, list]`|`(0, 0.1, 0.12)`|x, y, and z coordinates of the camera position.|
|**`export_png`**|`bool`|`True`|Whether the figure is exported and saved as a png file, in addition to the standard html format.|


In [None]:
%%time
#hide

def test_that_we_can_instanciate_a_mock():
    """Check that a minimal `TopoSolver` subclass can be instantiated, named, and moved between devices."""
    class MockTopoSolver(TopoSolver):
        # The base class takes the name as a constructor argument (it never calls a
        # `_get_name` hook), so the mock passes its name through `__init__`.
        def __init__(self, **kwargs):
            super().__init__(name="MockTopoSolver", **kwargs)

    topo_solver = MockTopoSolver()
    assert topo_solver.name == "MockTopoSolver"
    assert topo_solver.device == "cpu"
    assert not topo_solver.trainable and not topo_solver.differentiable

    # `cuda()` only updates the stored device string, so this does not require a GPU.
    topo_solver.cuda()
    assert topo_solver.device == "cuda"


test_that_we_can_instanciate_a_mock()

CPU times: user 86 µs, sys: 49 µs, total: 135 µs
Wall time: 149 µs
