diff --git a/tidy3d/plugins/invdes/design.py b/tidy3d/plugins/invdes/design.py index 3ed0c043e3..289e8c0129 100644 --- a/tidy3d/plugins/invdes/design.py +++ b/tidy3d/plugins/invdes/design.py @@ -10,7 +10,6 @@ import pydantic.v1 as pd import tidy3d as td -import tidy3d.web as web from tidy3d.components.autograd import get_static from tidy3d.exceptions import ValidationError from tidy3d.plugins.expressions.metrics import Metric @@ -71,7 +70,7 @@ def objective_fn(params: anp.ndarray, aux_data: dict = None) -> float: post_process_val = post_process_fn(data) elif isinstance(data, td.SimulationData): post_process_val = self.metric.evaluate(data) - elif isinstance(data, web.BatchData): + elif getattr(data, "type", None) == "BatchData": raise NotImplementedError("Metrics currently do not support 'BatchData'") else: raise ValueError(f"Invalid data type: {type(data)}") @@ -100,6 +99,21 @@ def initial_simulation(self) -> td.Simulation: initial_params = self.design_region.initial_parameters return self.to_simulation(initial_params) + def run(self, simulation, **kwargs) -> td.SimulationData: + """Run a single tidy3d simulation.""" + from tidy3d.web import run + + kwargs.setdefault("verbose", self.verbose) + kwargs.setdefault("task_name", self.task_name) + return run(simulation, **kwargs) + + def run_async(self, simulations, **kwargs) -> web.BatchData: # noqa: F821 + """Run a batch of tidy3d simulations.""" + from tidy3d.web import run_async + + kwargs.setdefault("verbose", self.verbose) + return run_async(simulations, **kwargs) + class InverseDesign(AbstractInverseDesign): """Container for an inverse design problem.""" @@ -221,8 +235,7 @@ def to_simulation(self, params: anp.ndarray) -> td.Simulation: def to_simulation_data(self, params: anp.ndarray, **kwargs) -> td.SimulationData: """Convert the ``InverseDesign`` to a ``td.Simulation`` and run it.""" simulation = self.to_simulation(params=params) - kwargs.setdefault("task_name", self.task_name) - return web.run(simulation, verbose=self.verbose, **kwargs) + return self.run(simulation, **kwargs) class InverseDesignMulti(AbstractInverseDesign): @@ -292,11 +305,10 @@ def to_simulation(self, params: anp.ndarray) -> dict[str, td.Simulation]: simulation_list = [design.to_simulation(params) for design in self.designs] return dict(zip(self.task_names, simulation_list)) - def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData: + def to_simulation_data(self, params: anp.ndarray, **kwargs) -> web.BatchData: # noqa: F821 """Convert the ``InverseDesignMulti`` to a set of ``td.Simulation``s and run async.""" simulations = self.to_simulation(params) - kwargs.setdefault("verbose", self.verbose) - return web.run_async(simulations, **kwargs) + return self.run_async(simulations, **kwargs) InverseDesignType = typing.Union[InverseDesign, InverseDesignMulti]