In [None]:
#default_exp datasets

In [None]:
#exporti
import torch
import numpy as np
from copy import deepcopy
from collections import defaultdict

from dl4to.datasets import TopoDataset

# SIMP dataset

In [None]:
#export

class SIMPDataset(TopoDataset):
    """Dataset of ``(problem, solution)`` pairs built from intermediate SIMP iterates.

    For every problem in ``problems`` the dataset receives, in this order:

    * ``(problem, 0. * problem.trivial_solution)``,
    * ``(problem, 0.5 * problem.trivial_solution)``,
    * one ``(problem, solution.detach())`` entry for each solution returned by
      ``simp(problems_or_solutions=problem)``.

    The dataset indices of the SIMP entries are recorded per problem index in
    ``self.list_of_problems_and_orig_solution_indices`` so that :meth:`augment`
    can later compare them against a SIMP re-run.

    Parameters
    ----------
    problems : sequence of problems
        Problems to solve. :meth:`augment` indexes and deep-copies this object,
        so it should be a re-iterable sequence (e.g. a list), not a one-shot
        iterator.
    simp : callable
        SIMP solver. Its ``return_intermediate_solutions`` flag must be enabled
        (presumably so that each call returns every intermediate iterate rather
        than only the final one).
    name : optional
        Forwarded to ``TopoDataset.__init__``.
    verbose : bool, default True
        Forwarded to ``TopoDataset.__init__``.

    Raises
    ------
    ValueError
        If ``simp.return_intermediate_solutions`` is not enabled.
    """

    def __init__(self, problems, simp, name=None, verbose=True):
        self.problems = problems
        self.simp = simp
        # Explicit check instead of `assert`: asserts are stripped under `python -O`,
        # and both _generate_dataset and augment rely on the intermediate iterates.
        if not self.simp.return_intermediate_solutions:
            raise ValueError(
                "SIMPDataset requires `simp.return_intermediate_solutions` to be True."
            )
        dataset = self._generate_dataset()
        super().__init__(dataset=dataset,
                         name=name,
                         verbose=verbose)


    def _generate_dataset(self):
        """Run SIMP on every problem and return the list of ``(problem, solution)`` samples.

        Side effect: (re)initialises ``self.list_of_problems_and_orig_solution_indices``,
        mapping each problem's position in ``self.problems`` to the dataset indices
        of its SIMP iterates.
        """
        self.list_of_problems_and_orig_solution_indices = defaultdict(list)
        dataset = []
        for i_problem, problem in enumerate(self.problems):
            # Two fixed baseline samples per problem: zero- and half-scaled trivial solutions.
            dataset.append((problem, 0. * problem.trivial_solution))
            dataset.append((problem, 0.5 * problem.trivial_solution))
            for solution in self.simp(problems_or_solutions=problem):
                dataset.append((problem, solution.detach()))
                self.list_of_problems_and_orig_solution_indices[i_problem].append(len(dataset) - 1)
        return dataset


    def augment(self, pde_solver, max_augmentation_per_problem=5, threshold=1e-3, verbose=True):
        """Add SIMP iterates that change when SIMP is re-run with a different PDE solver.

        For each problem, a deep copy is solved with SIMP using ``pde_solver``.
        Each returned iterate is paired, in order, with the iterate stored from the
        original run. They are compared by the mean absolute difference of ``_θ``
        over the voxels where ``Ω_design == -1``. Iterates whose deviation exceeds
        ``threshold`` are appended to the dataset with the original problem's PDE
        solver re-attached, up to ``max_augmentation_per_problem`` per problem.
        Iterates beyond the number stored from the original run are not inspected.

        Parameters
        ----------
        pde_solver :
            PDE solver used for the SIMP re-run.
        max_augmentation_per_problem : int, default 5
            Maximum number of samples added per problem. Once reached, the
            remaining iterates of that problem are skipped.
        threshold : float, default 1e-3
            Deviation above which an iterate counts as an error and is added.
        verbose : bool, default True
            If True, print a short per-problem report.
        """
        problems_ = deepcopy(self.problems)
        for i_problem, problem in enumerate(problems_):
            old_pde_solver = self.problems[i_problem].pde_solver
            # `.get` avoids inserting empty entries into the defaultdict.
            orig_indices = self.list_of_problems_and_orig_solution_indices.get(i_problem, [])

            problem.pde_solver = pde_solver
            solutions = self.simp(problems_or_solutions=problem)
            # Loop-invariant comparison mask (presumably the free design region).
            design_mask = problem.Ω_design == -1

            solutions_to_add_to_dataset = []
            simp_deviations = []
            simp_correct_counter = 0
            first_error = 0.
            # Original indices come first in zip so that the re-run is never advanced
            # past the iterates that have a stored counterpart.
            for orig_index, solution in zip(orig_indices, solutions):
                if len(solutions_to_add_to_dataset) >= max_augmentation_per_problem:
                    break
                solution._θ = solution._θ.detach()
                orig_simp_solution = self.dataset[orig_index][1]
                simp_deviation = (
                    solution._θ[design_mask] - orig_simp_solution._θ[design_mask]
                ).abs().mean().item()
                simp_deviations.append(simp_deviation)
                if simp_deviation > threshold:
                    if not solutions_to_add_to_dataset:
                        first_error = simp_deviation
                    solutions_to_add_to_dataset.append(solution)
                else:
                    simp_correct_counter += 1

            problem.pde_solver = old_pde_solver
            for solution in solutions_to_add_to_dataset:
                solution.pde_solver = old_pde_solver
                solution.problem.pde_solver = old_pde_solver
                # Clear state computed during the re-run with `pde_solver`
                # (presumably a cached PDE solution for the current θ).
                solution.u_current_θ = None
                self.dataset.append((solution.problem, solution))
            augmentation_counter = len(solutions_to_add_to_dataset)
            # `dataset` / `_size` are not set in this class; presumably maintained by TopoDataset.
            self._size += augmentation_counter

            if verbose:
                print(f"Problem {i_problem}: {simp_correct_counter} correct SIMP iterations (next iteration had an error of {first_error:.3f}).")
                print(simp_deviations)
                print(f"Augmented dataset by {augmentation_counter} samples.")