From 960440fe8df86748bf5925b6d1567d90477b0497 Mon Sep 17 00:00:00 2001 From: Momchil Minkov Date: Wed, 5 Nov 2025 14:35:35 +0100 Subject: [PATCH] feat: updating webapi for new modeler worfklow unifying add not implemented errors introduce WebTask from which SimulationTask and BatchTask subclass fix to taskType Test fixes and fixture to clear TaskFactory registry Remove compose_modeler in favor of Tidy3dBaseModel.from_file Unifying detail and delete methods Passing taskType, cleanup and using MODAL_CM and TERMINAL_CM cleanup --- .../workflows/tidy3d-python-client-tests.yml | 2 +- docs/api/submit_simulations.rst | 1 - .../smatrix/test_run_functions.py | 24 - tests/test_plugins/test_array_factor.py | 15 +- tests/test_web/test_local_cache.py | 6 +- tests/test_web/test_webapi.py | 2 +- tests/test_web/test_webapi_eme.py | 2 +- tests/test_web/test_webapi_heat.py | 2 +- tests/test_web/test_webapi_mode.py | 2 +- tests/test_web/test_webapi_mode_sim.py | 2 +- tidy3d/__init__.py | 2 + tidy3d/components/mode/mode_solver.py | 2 +- tidy3d/components/mode/simulation.py | 4 +- tidy3d/components/tcad/mesher.py | 6 + .../smatrix/component_modelers/base.py | 4 + .../smatrix/component_modelers/modal.py | 11 + .../smatrix/component_modelers/terminal.py | 43 +- tidy3d/plugins/smatrix/run.py | 40 - tidy3d/web/__init__.py | 2 - tidy3d/web/api/container.py | 98 +-- tidy3d/web/api/states.py | 20 +- tidy3d/web/api/tidy3d_stub.py | 6 +- tidy3d/web/api/webapi.py | 826 +++++------------- tidy3d/web/core/constants.py | 1 + tidy3d/web/core/http_util.py | 1 + tidy3d/web/core/task_core.py | 697 +++++++-------- tidy3d/web/core/task_info.py | 68 +- tidy3d/web/core/types.py | 4 +- tidy3d/web/tests/conftest.py | 15 + 29 files changed, 664 insertions(+), 1244 deletions(-) create mode 100644 tidy3d/web/tests/conftest.py diff --git a/.github/workflows/tidy3d-python-client-tests.yml b/.github/workflows/tidy3d-python-client-tests.yml index d005d021fa..d414ca4758 100644 --- a/.github/workflows/tidy3d-python-client-tests.yml +++ b/.github/workflows/tidy3d-python-client-tests.yml @@ -347,7 +347,7 @@ jobs: BRANCH_NAME="${STEPS_EXTRACT_BRANCH_NAME_OUTPUTS_BRANCH_NAME}" echo $BRANCH_NAME # Allow only Jira keys from known projects, even if the branch has an author prefix - ALLOWED_JIRA_PROJECTS=("FXC" "SCEM") + ALLOWED_JIRA_PROJECTS=("FXC" "SCEM" "SCRF") JIRA_PROJECT_PATTERN=$(IFS='|'; echo "${ALLOWED_JIRA_PROJECTS[*]}") JIRA_PATTERN="(${JIRA_PROJECT_PATTERN})-[0-9]+" diff --git a/docs/api/submit_simulations.rst b/docs/api/submit_simulations.rst index 4c08f6e9d7..795f045586 100644 --- a/docs/api/submit_simulations.rst +++ b/docs/api/submit_simulations.rst @@ -94,7 +94,6 @@ Information Containers :template: module.rst tidy3d.web.core.task_info.TaskInfo - tidy3d.web.core.task_info.TaskStatus Mode Solver Web API diff --git a/tests/test_plugins/smatrix/test_run_functions.py b/tests/test_plugins/smatrix/test_run_functions.py index 64c18c9746..a4207165e5 100644 --- a/tests/test_plugins/smatrix/test_run_functions.py +++ b/tests/test_plugins/smatrix/test_run_functions.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json from unittest.mock import MagicMock import pydantic.v1 as pd @@ -14,38 +13,15 @@ make_component_modeler as make_modal_component_modeler, ) from tidy3d import SimulationDataMap -from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.sim_data import SimulationData from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData from tidy3d.plugins.smatrix.run import ( _run_local, - compose_modeler, compose_modeler_data, create_batch, ) -def test_compose_modeler_unsupported_type(tmp_path, monkeypatch): - # Create a dummy HDF5 file path - modeler_file = tmp_path / "dummy_modeler.hdf5" - - # Prepare a dummy JSON string with an unsupported type - dummy_json = {"type": "UnsupportedComponentModeler", "some_key": "some_value"} - dummy_json_str = json.dumps(dummy_json) - - # Mock Tidy3dBaseModel._json_string_from_hdf5 to return our dummy JSON string - def mock_json_string_from_hdf5(filepath): - if filepath == str(modeler_file): - return dummy_json_str - return "" - - monkeypatch.setattr(Tidy3dBaseModel, "_json_string_from_hdf5", mock_json_string_from_hdf5) - - # Expect a TypeError when calling compose_modeler with the unsupported type - with pytest.raises(TypeError, match="Unsupported modeler type: str"): - compose_modeler(modeler_file=str(modeler_file)) - - def test_create_batch(monkeypatch, tmp_path): # Mock Batch and Batch.to_file mock_batch_instance = MagicMock() diff --git a/tests/test_plugins/test_array_factor.py b/tests/test_plugins/test_array_factor.py index 2ffe92afee..959e1dc7b0 100644 --- a/tests/test_plugins/test_array_factor.py +++ b/tests/test_plugins/test_array_factor.py @@ -363,16 +363,15 @@ def make_antenna_sim(): remove_dc_component=False, # Include DC component for more accuracy at low frequencies ) - sim_unit = list(modeler.sim_dict.values())[0] - - return sim_unit + return modeler def test_rectangular_array_calculator_array_make_antenna_array(): """Test automatic antenna array creation.""" freq0 = 10e9 wavelength0 = td.C_0 / 10e9 - sim_unit = make_antenna_sim() + modeler = make_antenna_sim() + sim_unit = list(modeler.sim_dict.values())[0] array_calculator = mw.RectangularAntennaArrayCalculator( array_size=(1, 2, 3), spacings=(0.5 * wavelength0, 0.6 * wavelength0, 0.4 * wavelength0), @@ -437,8 +436,9 @@ def test_rectangular_array_calculator_array_make_antenna_array(): assert len(sim_array.sources) == 6 # check that override_structures are duplicated - assert len(sim_unit.grid_spec.override_structures) == 2 - assert len(sim_array.grid_spec.override_structures) == 7 + # assert len(sim_unit.grid_spec.override_structures) == 2 + # assert len(sim_array.grid_spec.override_structures) == 7 + assert sim_unit.grid.boundaries == modeler.base_sim.grid.boundaries # check that phase shifts are applied correctly phases_expected = array_calculator._antenna_phases @@ -674,7 +674,8 @@ def test_rectangular_array_calculator_simulation_data_from_array_factor(): phase_shifts=(np.pi / 3, np.pi / 4, np.pi / 5), ) - sim_unit = make_antenna_sim() + modeler = make_antenna_sim() + sim_unit = list(modeler.sim_dict.values())[0] monitor = sim_unit.monitors[0] monitor_directivity = sim_unit.monitors[2] diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py index 0762fe08c7..3c8a9c7ee6 100644 --- a/tests/test_web/test_local_cache.py +++ b/tests/test_web/test_local_cache.py @@ -43,7 +43,7 @@ resolve_local_cache, ) from tidy3d.web.cli.app import tidy3d_cli -from tidy3d.web.core.task_core import BatchTask +from tidy3d.web.core.task_core import BatchTask, SimulationTask common.CONNECTION_RETRY_TIME = 0.1 @@ -245,7 +245,9 @@ def _fake_field_map_from_file(*args, **kwargs): monkeypatch.setattr( io_utils, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id] ) - monkeypatch.setattr(BatchTask, "is_batch", lambda *args, **kwargs: "success") + monkeypatch.setattr( + SimulationTask, "get", lambda *args, **kwargs: SimpleNamespace(taskType="FDTD") + ) monkeypatch.setattr( BatchTask, "detail", lambda *args, **kwargs: SimpleNamespace(status="success") ) diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py index 074c44b069..4fdd280bf0 100644 --- a/tests/test_web/test_webapi.py +++ b/tests/test_web/test_webapi.py @@ -415,7 +415,7 @@ def test_run_with_invalid_priority(mock_webapi, priority): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_eme.py b/tests/test_web/test_webapi_eme.py index fc43f0670d..c44cd19b41 100644 --- a/tests/test_web/test_webapi_eme.py +++ b/tests/test_web/test_webapi_eme.py @@ -260,7 +260,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_heat.py b/tests/test_web/test_webapi_heat.py index 15767089e8..571aad371f 100644 --- a/tests/test_web/test_webapi_heat.py +++ b/tests/test_web/test_webapi_heat.py @@ -250,7 +250,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_mode.py b/tests/test_web/test_webapi_mode.py index 8ee869de5b..93e4a2a8bf 100644 --- a/tests/test_web/test_webapi_mode.py +++ b/tests/test_web/test_webapi_mode.py @@ -307,7 +307,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_mode_sim.py b/tests/test_web/test_webapi_mode_sim.py index 275d204a9f..052591beab 100644 --- a/tests/test_web/test_webapi_mode_sim.py +++ b/tests/test_web/test_webapi_mode_sim.py @@ -303,7 +303,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 4acc881996..afa81b60ff 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.boundary import BroadbandModeABCFitterParam, BroadbandModeABCSpec from tidy3d.components.data.index import SimulationDataMap from tidy3d.components.frequency_extrapolation import LowFrequencySmoothingSpec @@ -812,6 +813,7 @@ def set_logging_level(level: str) -> None: "TemperatureData", "TemperatureMonitor", "TetrahedralGridDataset", + "Tidy3dBaseModel", "Transformed", "TriangleMesh", "TriangularGridDataset", diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 3d41c2ffb8..f32a56a4cf 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -2714,7 +2714,7 @@ def _validate_modes_size(self) -> None: "frequencies or modes." ) - def validate_pre_upload(self, source_required: bool = True) -> None: + def validate_pre_upload(self) -> None: """Validate the fully initialized mode solver is ok for upload to our servers.""" self._validate_modes_size() diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index 2e836c55ec..e506706822 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -612,8 +612,8 @@ def plot_pml_mode_plane( """ return self._mode_solver.plot_pml(ax=ax) - def validate_pre_upload(self, source_required: bool = False) -> None: + def validate_pre_upload(self) -> None: super().validate_pre_upload() - self._mode_solver.validate_pre_upload(source_required=source_required) + self._mode_solver.validate_pre_upload() _boundaries_for_zero_dims = validate_boundaries_for_zero_dims(warn_on_change=False) diff --git a/tidy3d/components/tcad/mesher.py b/tidy3d/components/tcad/mesher.py index d35b465c9a..82af0de6f3 100644 --- a/tidy3d/components/tcad/mesher.py +++ b/tidy3d/components/tcad/mesher.py @@ -24,3 +24,9 @@ class VolumeMesher(Tidy3dBaseModel): def _get_simulation_types(self) -> list[TCADAnalysisTypes]: return [TCADAnalysisTypes.MESH] + + def validate_pre_upload(self): + """Validate the VolumeMesher before uploading to the cloud. + Currently no validation but method is required when calling ``web.upload``. + """ + return diff --git a/tidy3d/plugins/smatrix/component_modelers/base.py b/tidy3d/plugins/smatrix/component_modelers/base.py index 1d9715d98f..74c5afe1c0 100644 --- a/tidy3d/plugins/smatrix/component_modelers/base.py +++ b/tidy3d/plugins/smatrix/component_modelers/base.py @@ -343,5 +343,9 @@ def run( ) return data.smatrix() + def validate_pre_upload(self): + """Validate the modeler before upload.""" + self.base_sim.validate_pre_upload(source_required=False) + AbstractComponentModeler.update_forward_refs() diff --git a/tidy3d/plugins/smatrix/component_modelers/modal.py b/tidy3d/plugins/smatrix/component_modelers/modal.py index 591ad4c624..f75c98bc81 100644 --- a/tidy3d/plugins/smatrix/component_modelers/modal.py +++ b/tidy3d/plugins/smatrix/component_modelers/modal.py @@ -59,6 +59,11 @@ class ModalComponentModeler(AbstractComponentModeler): "by ``element_mappings``, the simulation corresponding to this column is skipped automatically.", ) + @property + def base_sim(self): + """The base simulation.""" + return self.simulation + @cached_property def sim_dict(self) -> SimulationMap: """Generates all :class:`.Simulation` objects for the S-matrix calculation. @@ -368,3 +373,9 @@ def get_max_mode_indices(matrix_elements: tuple[str, int]) -> int: max_mode_index_in = get_max_mode_indices(self.matrix_indices_source) return max_mode_index_out, max_mode_index_in + + def task_name_from_index(self, matrix_index: MatrixIndex) -> str: + """Compute task name for a given (port_name, mode_index) without constructing simulations.""" + port_name, mode_index = matrix_index + port = self.get_port_by_name(port_name=port_name) + return self.get_task_name(port=port, mode_index=mode_index) diff --git a/tidy3d/plugins/smatrix/component_modelers/terminal.py b/tidy3d/plugins/smatrix/component_modelers/terminal.py index f865286c4e..cbd7cb7fd9 100644 --- a/tidy3d/plugins/smatrix/component_modelers/terminal.py +++ b/tidy3d/plugins/smatrix/component_modelers/terminal.py @@ -223,7 +223,7 @@ def _warn_refactor_2_10(cls, values): @property def _sim_with_sources(self) -> Simulation: - """Instance of :class:`.Simulation` with all sources and absorbers added for each port, for troubleshooting.""" + """Instance of :class:`.Simulation` with all sources and absorbers added for each port, for plotting.""" sources = [port.to_source(self._source_time) for port in self.ports] absorbers = [ @@ -231,7 +231,9 @@ def _sim_with_sources(self) -> Simulation: for port in self.ports if isinstance(port, WavePort) and port.absorber ] - return self.simulation.updated_copy(sources=sources, internal_absorbers=absorbers) + return self.simulation.updated_copy( + sources=sources, internal_absorbers=absorbers, validate=False + ) @equal_aspect @add_ax_if_none @@ -382,6 +384,10 @@ def matrix_indices_run_sim(self) -> tuple[NetworkIndex, ...]: def sim_dict(self) -> SimulationMap: """Generate all the :class:`.Simulation` objects for the port parameter calculation.""" + # Check base simulation for grid size at ports + TerminalComponentModeler._check_grid_size_at_ports(self.base_sim, self._lumped_ports) + TerminalComponentModeler._check_grid_size_at_wave_ports(self.base_sim, self._wave_ports) + sim_dict = {} # Now, create simulations with wave port sources and mode solver monitors for computing port modes for network_index in self.matrix_indices_run_sim: @@ -389,11 +395,6 @@ def sim_dict(self) -> SimulationMap: # update simulation sim_dict[task_name] = sim_with_src - # Check final simulations for grid size at ports - for _, sim in sim_dict.items(): - TerminalComponentModeler._check_grid_size_at_ports(sim, self._lumped_ports) - TerminalComponentModeler._check_grid_size_at_wave_ports(sim, self._wave_ports) - return SimulationMap(keys=tuple(sim_dict.keys()), values=tuple(sim_dict.values())) @cached_property @@ -414,7 +415,10 @@ def _base_sim_no_radiation_monitors(self) -> Simulation: # Make an initial simulation with new grid_spec to determine where LumpedPorts are snapped sim_wo_source = self.simulation.updated_copy( - grid_spec=grid_spec, lumped_elements=lumped_resistors + grid_spec=grid_spec, + lumped_elements=lumped_resistors, + validate=False, + deep=False, ) snap_centers = {} for port in self._lumped_ports: @@ -480,7 +484,11 @@ def _base_sim_no_radiation_monitors(self) -> Simulation: ) # update base simulation with updated set of shared components - sim_wo_source = sim_wo_source.copy(update=update_dict) + sim_wo_source = sim_wo_source.updated_copy( + **update_dict, + validate=False, + deep=False, + ) # extrude port structures sim_wo_source = self._extrude_port_structures(sim=sim_wo_source) @@ -527,7 +535,10 @@ def base_sim(self) -> Simulation: """The base simulation with all components added, including radiation monitors.""" base_sim_tmp = self._base_sim_no_radiation_monitors mnts_with_radiation = list(base_sim_tmp.monitors) + list(self._finalized_radiation_monitors) - return base_sim_tmp.updated_copy(monitors=mnts_with_radiation) + grid_spec = GridSpec.from_grid(base_sim_tmp.grid) + grid_spec.attrs["from_grid_spec"] = base_sim_tmp.grid_spec + # We skipped validations up to now, here we finally validate the base sim + return base_sim_tmp.updated_copy(monitors=mnts_with_radiation, grid_spec=grid_spec) def _generate_radiation_monitor( self, simulation: Simulation, auto_spec: DirectivityMonitorSpec @@ -712,7 +723,10 @@ def _add_source_to_sim(self, source_index: NetworkIndex) -> tuple[str, Simulatio ) task_name = self.get_task_name(port=port, mode_index=mode_index) - return (task_name, self.base_sim.updated_copy(sources=[port_source])) + return ( + task_name, + self.base_sim.updated_copy(sources=[port_source], validate=False, deep=False), + ) @cached_property def _source_time(self): @@ -863,6 +877,11 @@ def get_radiation_monitor_by_name(self, monitor_name: str) -> DirectivityMonitor return monitor raise Tidy3dKeyError(f"No radiation monitor named '{monitor_name}'.") + def task_name_from_index(self, source_index: NetworkIndex) -> str: + """Compute task name for a given network index without constructing simulations.""" + port, mode_index = self.network_dict[source_index] + return self.get_task_name(port=port, mode_index=mode_index) + def _extrude_port_structures(self, sim: Simulation) -> Simulation: """ Extrude structures intersecting a port plane when a wave port lies on a structure boundary. @@ -983,6 +1002,8 @@ def _extrude_port_structures(self, sim: Simulation) -> Simulation: sim = sim.updated_copy( grid_spec=GridSpec.from_grid(sim.grid), structures=[*sim.structures, *all_new_structures], + validate=False, + deep=False, ) return sim diff --git a/tidy3d/plugins/smatrix/run.py b/tidy3d/plugins/smatrix/run.py index 97f9393338..ff4b9d39f1 100644 --- a/tidy3d/plugins/smatrix/run.py +++ b/tidy3d/plugins/smatrix/run.py @@ -1,10 +1,7 @@ from __future__ import annotations -import json -from os import PathLike from typing import Any -from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.index import SimulationDataMap from tidy3d.log import log from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler @@ -18,43 +15,6 @@ DEFAULT_DATA_DIR = "." -def compose_modeler( - modeler_file: PathLike, -) -> ComponentModelerType: - """Load a component modeler from an HDF5 file. - - This function reads an HDF5 file, determines the modeler type - (`ModalComponentModeler` or `TerminalComponentModeler`), and constructs the - corresponding modeler object. - - Parameters - ---------- - modeler_file : PathLike - Path to the HDF5 file containing the modeler definition. - - Returns - ------- - ComponentModelerType - The loaded `ModalComponentModeler` or `TerminalComponentModeler` object. - - Raises - ------ - TypeError - If the modeler type specified in the file is not supported. - """ - json_str = Tidy3dBaseModel._json_string_from_hdf5(modeler_file) - model_dict = json.loads(json_str) - modeler_type = model_dict["type"] - - if modeler_type == "ModalComponentModeler": - modeler = ModalComponentModeler.from_file(modeler_file) - elif modeler_type == "TerminalComponentModeler": - modeler = TerminalComponentModeler.from_file(modeler_file) - else: - raise TypeError(f"Unsupported modeler type: {type(modeler_type).__name__}") - return modeler - - def compose_modeler_data( modeler: ModalComponentModeler | TerminalComponentModeler, indexed_sim_data: SimulationDataMap, diff --git a/tidy3d/web/__init__.py b/tidy3d/web/__init__.py index 0cdc8942e5..dcdf44c9c3 100644 --- a/tidy3d/web/__init__.py +++ b/tidy3d/web/__init__.py @@ -30,7 +30,6 @@ load, load_simulation, monitor, - postprocess_start, real_cost, start, test, @@ -58,7 +57,6 @@ "load", "load_simulation", "monitor", - "postprocess_start", "real_cost", "run", "run_async", diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index 08092e3062..93e80438c4 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -389,25 +389,7 @@ def status(self) -> str: """Return current status of :class:`Job`.""" if self.load_if_cached: return "success" - if web._is_modeler_batch(self.task_id): - detail = self.get_info() - status = detail.totalStatus.value - return status - else: - return self.get_info().status - - @property - def postprocess_status(self) -> Optional[str]: - """Return current postprocess status of :class:`Job` if it is a Component Modeler.""" - if web._is_modeler_batch(self.task_id): - detail = self.get_info() - return detail.postprocessStatus - else: - log.warning( - f"Task ID '{self.task_id}' is not a modeler batch job. " - "'postprocess_start' is only applicable to Component Modelers" - ) - return + return self.get_info().status def start(self, priority: Optional[int] = None) -> None: """Start running a :class:`Job`. @@ -548,37 +530,6 @@ def estimate_cost(self, verbose: bool = True) -> float: return 0.0 return web.estimate_cost(self.task_id, verbose=verbose, solver_version=self.solver_version) - def postprocess_start(self, worker_group: Optional[str] = None, verbose: bool = True) -> None: - """ - If the job is a modeler batch, checks if the run is complete and starts - the postprocess phase. - - This function does not wait for postprocessing to finish and is only - applicable to Component Modeler batch jobs. - - Parameters - ---------- - worker_group : Optional[str] = None - The specific worker group to run the postprocessing task on. - verbose : bool = True - Whether to print info messages. This overrides the Job's 'verbose' setting for this call. - """ - # First, confirm that the task is a modeler batch job. - if not web._is_modeler_batch(self.task_id): - # If not, inform the user and exit. - # This warning is important and should not be suppressed. - log.warning( - f"Task ID '{self.task_id}' is not a modeler batch job. " - "'postprocess_start' is only applicable to Component Modelers" - ) - return - - # If it is a modeler batch, call the dedicated function to start postprocessing. - # The verbosity is a combination of the job's setting and the method's parameter. - web.postprocess_start( - batch_id=self.task_id, verbose=(self.verbose and verbose), worker_group=worker_group - ) - @staticmethod def _check_path_dir(path: PathLike) -> None: """Make sure parent directory of ``path`` exists and create it if not. @@ -1043,21 +994,6 @@ def get_run_info(self) -> dict[TaskName, RunInfo]: run_info_dict[task_name] = run_info return run_info_dict - def postprocess_start(self, worker_group: Optional[str] = None, verbose: bool = True) -> None: - """ - Start the postprocess phase for all applicable jobs in the batch. - - This simply forwards to each Job's `postprocess_start(...)`. The Job decides - whether it's a Component Modeler task and whether it can/should start now. - This method does not wait for postprocessing to finish. - """ - if self.verbose and verbose: - console = get_logging_console() - console.log("Attempting to start postprocessing for jobs in the batch.") - - for job in self.jobs.values(): - job.postprocess_start(worker_group=worker_group, verbose=verbose) - def monitor( self, *, @@ -1069,7 +1005,6 @@ def monitor( """ Monitor progress of each running task. - - For Component Modeler jobs, automatically triggers postprocessing once run finishes. - Optionally downloads results as soon as a job reaches final success. - Rich progress bars in verbose mode; quiet polling otherwise. @@ -1095,16 +1030,8 @@ def monitor( self._check_path_dir(path_dir=path_dir) download_executor = ThreadPoolExecutor(max_workers=self.num_workers) - def _should_download(job: Job) -> bool: - status = job.status - if not web._is_modeler_batch(job.task_id): - return status == "success" - if status == "success": - return True - return status == "run_success" and getattr(job, "postprocess_status", None) == "success" - def schedule_download(job: Job) -> None: - if download_executor is None or not _should_download(job): + if download_executor is None or job.status not in COMPLETED_STATES: return task_id = job.task_id if task_id in downloads_started: @@ -1128,12 +1055,7 @@ def schedule_download(job: Job) -> None: def check_continue_condition(job: Job) -> bool: if job.load_if_cached: return False - status = job.status - if not web._is_modeler_batch(job.task_id): - return status not in END_STATES - if status == "run_success": - return job.postprocess_status not in END_STATES - return status not in END_STATES + return job.status not in END_STATES def pbar_description( task_name: str, status: str, max_name_length: int, status_width: int @@ -1157,9 +1079,6 @@ def pbar_description( max_task_name = max(len(task_name) for task_name in self.jobs.keys()) max_name_length = min(30, max(max_task_name, 15)) - # track which modeler jobs we've already kicked into postprocess - postprocess_started_tasks: set[str] = set() - try: console = None progress_columns = [] @@ -1195,17 +1114,6 @@ def pbar_description( for task_name, job in self.jobs.items(): status = job.status - # auto-start postprocess for modeler jobs when run finishes - if ( - web._is_modeler_batch(job.task_id) - and status == "run_success" - and job.task_id not in postprocess_started_tasks - ): - job.postprocess_start( - worker_group=postprocess_worker_group, verbose=True - ) - postprocess_started_tasks.add(job.task_id) - schedule_download(job) if self.verbose: diff --git a/tidy3d/web/api/states.py b/tidy3d/web/api/states.py index 9b02d161f6..a7a03935ae 100644 --- a/tidy3d/web/api/states.py +++ b/tidy3d/web/api/states.py @@ -5,15 +5,17 @@ } ERROR_STATES = { - "validate_fail", + "validate_error", "error", "errored", "diverge", "diverged", "blocked", - "run_failed", + "preprocess_error", + "run_error", "aborted", "deleted", + "postprocess_error", } PRE_VALIDATE_STATES = { @@ -30,20 +32,17 @@ "run_success", } -COMPLETED_STATES = {"visualize", "success", "completed", "processed"} +COMPLETED_STATES = {"visualize", "success", "completed", "processed", "postprocess_success"} END_STATES = ERROR_STATES | COMPLETED_STATES -POST_VALIDATE_STATES = { - "validate_success", - "validate_warn", -} +POST_VALIDATE_STATES = {"validate_success", "validate_warn"} RUNNING_STATES = ( PRE_VALIDATE_STATES | POST_VALIDATE_STATES | {"running"} | POST_RUN_STATES | COMPLETED_STATES ) -ALL_POST_VALIDATE_STATES = POST_VALIDATE_STATES | {"running"} | POST_RUN_STATES | COMPLETED_STATES +ALL_POST_VALIDATE_STATES = POST_VALIDATE_STATES | {"running"} | POST_RUN_STATES | END_STATES VALID_PROGRESS_STATES = RUNNING_STATES | PRE_ERROR_STATES @@ -87,6 +86,7 @@ "visualize": round((11 / MAX_STEPS) * COMPLETED_PERCENT), # 85% "success": COMPLETED_PERCENT, # 100% "completed": COMPLETED_PERCENT, # 100% + "postprocess_success": COMPLETED_PERCENT, # 100% # --- Error States --- # All error states map to 0% "validate_fail": 0, @@ -98,4 +98,8 @@ "run_failed": 0, "aborted": 0, "deleted": 0, + "validate_error": 0, + "preprocess_error": 0, + "run_error": 0, + "postprocess_error": 0, } diff --git a/tidy3d/web/api/tidy3d_stub.py b/tidy3d/web/api/tidy3d_stub.py index a55bafe6f9..eb49f90461 100644 --- a/tidy3d/web/api/tidy3d_stub.py +++ b/tidy3d/web/api/tidy3d_stub.py @@ -46,8 +46,8 @@ EMESimulation: TaskType.EME, ModeSimulation: TaskType.MODE, VolumeMesher: TaskType.VOLUME_MESH, - ModalComponentModeler: TaskType.COMPONENT_MODELER, - TerminalComponentModeler: TaskType.TERMINAL_COMPONENT_MODELER, + ModalComponentModeler: TaskType.MODAL_CM, + TerminalComponentModeler: TaskType.TERMINAL_CM, } @@ -116,7 +116,7 @@ def validate_pre_upload(self, source_required: bool) -> None: """Perform some pre-checks on instances of component""" if isinstance(self.simulation, Simulation): self.simulation.validate_pre_upload(source_required) - elif isinstance(self.simulation, EMESimulation): + else: self.simulation.validate_pre_upload() def get_default_task_name(self) -> str: diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index a57f0ebbfe..9134fff385 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -19,12 +19,10 @@ from tidy3d.config import config from tidy3d.exceptions import WebError from tidy3d.log import get_logging_console, log -from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler from tidy3d.web.api.states import ( ALL_POST_VALIDATE_STATES, END_STATES, ERROR_STATES, - POST_VALIDATE_STATES, STATE_PROGRESS_PERCENTAGE, ) from tidy3d.web.cache import CacheEntry, _store_mode_solver_in_cache, resolve_local_cache @@ -39,11 +37,15 @@ SIMULATION_DATA_HDF5_GZ, TaskId, ) -from tidy3d.web.core.exceptions import WebNotFoundError -from tidy3d.web.core.http_util import get_version as _get_protocol_version -from tidy3d.web.core.http_util import http -from tidy3d.web.core.task_core import BatchDetail, BatchTask, Folder, SimulationTask -from tidy3d.web.core.task_info import AsyncJobDetail, ChargeType, TaskInfo +from tidy3d.web.core.task_core import ( + BatchDetail, + BatchTask, + Folder, + SimulationTask, + TaskFactory, + WebTask, +) +from tidy3d.web.core.task_info import ChargeType, TaskInfo from tidy3d.web.core.types import PayType, TaskType from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection @@ -56,7 +58,7 @@ SIM_FILE_JSON = "simulation.json" # not all solvers are supported yet in GUI -GUI_SUPPORTED_TASK_TYPES = ["FDTD", "MODE_SOLVER", "HEAT", "RF"] +GUI_SUPPORTED_TASK_TYPES = ["FDTD", "MODE_SOLVER", "HEAT", "TERMINAL_CM"] # if a solver is in beta stage, cost is subject to change BETA_TASK_TYPES = ["HEAT", "EME", "HEAT_CHARGE", "VOLUME_MESH"] @@ -95,15 +97,6 @@ def _build_website_url(path: str) -> str: return "/".join([base.rstrip("/"), str(path).lstrip("/")]) -def _is_modeler_batch(resource_id: str) -> bool: - """Detect whether the given id corresponds to a modeler batch resource.""" - return BatchTask.is_batch(resource_id, batch_type="RF_SWEEP") - - -def _batch_detail(resource_id: str) -> BatchDetail: - return BatchTask(resource_id).detail(batch_type="RF_SWEEP") - - def _batch_detail_error(resource_id: str) -> Optional[WebError]: """Processes a failed batch job to generate a detailed error. @@ -115,45 +108,40 @@ def _batch_detail_error(resource_id: str) -> Optional[WebError]: Args: resource_id (str): The identifier of the batch resource that failed. - Returns: - An instance of `WebError` if the batch failed, otherwise `None`. + Raises: + An instance of ``WebError`` if the batch failed. """ + + # TODO: test properly try: - batch_detail = BatchTask(batch_id=resource_id).detail(batch_type="RF_SWEEP") - status = batch_detail.totalStatus.value + batch = BatchTask.get(resource_id) + batch_detail = batch.detail() + status = batch_detail.status.lower() except Exception as e: log.error(f"Could not retrieve batch details for '{resource_id}': {e}") - return WebError(f"Failed to retrieve status for batch '{resource_id}'.") + raise WebError(f"Failed to retrieve status for batch '{resource_id}'.") from e if status not in ERROR_STATES: - return None - - log.error(f"The ComponentModeler batch '{resource_id}' has failed with status: {status}") + return - if ( - status == "validate_fail" - and hasattr(batch_detail, "validateErrors") - and batch_detail.validateErrors - ): - error_details = [] - for key, error_str in batch_detail.validateErrors.items(): - try: - error_dict = json.loads(error_str) - validation_error = error_dict.get("validation_error", "Unknown validation error.") - msg = f"- Subtask '{key}' failed: {validation_error}" - log.error(msg) + if hasattr(batch_detail, "validateErrors") and batch_detail.validateErrors: + try: + error_details = [] + for key, error_str in batch_detail.validateErrors.items(): + msg = f"- Subtask '{key}' failed: {error_str}" error_details.append(msg) - except (json.JSONDecodeError, TypeError): - # Handle cases where the error string isn't valid JSON - log.error(f"Could not parse validation error for subtask '{key}'.") - error_details.append(f"- Subtask '{key}': Could not parse error details.") - - details_string = "\n".join(error_details) - full_error_msg = ( - "One or more subtasks failed validation. Please fix the component modeler configuration.\n" - f"Details:\n{details_string}" - ) - return WebError(full_error_msg) + + details_string = "\n".join(error_details) + full_error_msg = ( + "One or more subtasks failed validation. Please fix the component modeler " + "configuration.\n" + f"Details:\n{details_string}" + ) + except Exception as e: + raise WebError( + "One or more subtasks failed validation. Failed to parse validation errors." + ) from e + raise WebError(full_error_msg) # Handle all other generic error states else: @@ -161,168 +149,7 @@ def _batch_detail_error(resource_id: str) -> Optional[WebError]: f"Batch '{resource_id}' failed with status '{status}'. Check server " "logs for details or contact customer support." ) - return WebError(error_msg) - - -def _upload_component_modeler_subtasks( - resource_id: str, verbose: bool = True, solver_version: Optional[str] = None -) -> Optional[WebError]: - """Kicks off and monitors the split and validation of component modeler tasks. - - This function orchestrates a two-phase process. First, it initiates a - server-side asynchronous job to split the components of a modeler batch. - It monitors this job's progress by polling the API and parsing the - response into an `AsyncJobDetail` model until the job completes or fails. - - If the split is successful, the function proceeds to the second phase: - triggering a batch validation via `batch.check()`. It then monitors this - validation process by polling for `BatchDetail` updates. The progress bar, - if verbose, reflects the status according to a predefined state mapping. - - Finally, it processes the terminal state of the validation. If a - 'validate_fail' status occurs, it parses detailed error messages for each - failed subtask and includes them in the raised exception. - - Args: - resource_id (str): The identifier for the batch resource to be processed. - verbose (bool): If True, displays progress bars and logs detailed - status messages to the console during the operation. - solver_version (str): Solver version in which to run validation. - - Raises: - RuntimeError: If the initial asynchronous split job fails. - WebError: If the subsequent batch validation fails, ends in an - unexpected state, or if a 'validate_fail' status is encountered. - """ - console = get_logging_console() if verbose else None - final_error = None - batch_type = "RF_SWEEP" - - split_path = "tidy3d/async-biz/component-modeler-split" - payload = { - "batchType": batch_type, - "batchId": resource_id, - "fileName": "modeler.hdf5.gz", - "protocolVersion": _get_protocol_version(), - } - - if verbose: - console.log("Starting modeler and subtasks validation...") - - initial_resp = http.post(split_path, payload) - split_job_detail = AsyncJobDetail(**initial_resp) - monitor_split_path = f"{split_path}?asyncId={split_job_detail.asyncId}" - - if verbose: - progress_bar = Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeElapsedColumn(), - console=console, - ) - - with progress_bar as progress: - description = "Upload Subtasks" - pbar = progress.add_task(description, completed=split_job_detail.progress, total=100) - while True: - split_job_raw_result = http.get(monitor_split_path) - split_job_detail = AsyncJobDetail(**split_job_raw_result) - - progress.update( - pbar, completed=split_job_detail.progress, description=f"[blue]{description}" - ) - - if split_job_detail.status in END_STATES: - progress.update( - pbar, - completed=split_job_detail.progress, - description=f"[green]{description}", - ) - break - time.sleep(RUN_REFRESH_TIME) - - if split_job_detail.status in ERROR_STATES: - msg = split_job_detail.result or "An unknown error occurred." - final_error = WebError( - f"Component modeler split job failed ({split_job_detail.status}): {msg}" - ) - - if not final_error: - description = "Validating" - pbar = progress.add_task( - completed=10, total=100, description=f"[blue]{description}" - ) - batch = BatchTask(resource_id) - batch.check(solver_version=solver_version, batch_type=batch_type) - - while True: - batch_detail = batch.detail(batch_type=batch_type) - status = batch_detail.totalStatus - progress_percent = STATE_PROGRESS_PERCENTAGE.get(status, 0) - progress.update( - pbar, completed=progress_percent, description=f"[blue]{description}" - ) - - if status in POST_VALIDATE_STATES: - progress.update(pbar, completed=100, description=f"[green]{description}") - task_mapping = json.loads(split_job_detail.result) - console.log( - f"Uploaded Subtasks: \n{_task_dict_to_url_bullet_list(task_mapping)}" - ) - progress.refresh() - break - elif status in ERROR_STATES: - progress.update(pbar, completed=0, description=f"[red]{description}") - progress.refresh() - break - time.sleep(RUN_REFRESH_TIME) - - else: - # Non-verbose mode: Poll for split job completion. - while True: - split_job_raw_result = http.get(monitor_split_path) - split_job_detail = AsyncJobDetail(**split_job_raw_result) - if split_job_detail.status in END_STATES: - break - time.sleep(RUN_REFRESH_TIME) - - # Check for split job failure. - if split_job_detail.status in ERROR_STATES: - msg = split_job_detail.result or "An unknown error occurred." - final_error = WebError( - f"Component modeler split job failed ({split_job_detail.status}): {msg}" - ) - - # If split succeeded, poll for validation completion. - if not final_error: - batch = BatchTask(resource_id) - batch.check(solver_version=solver_version, batch_type=batch_type) - while True: - batch_detail = batch.detail(batch_type=batch_type) - status = batch_detail.totalStatus - if status in POST_VALIDATE_STATES or status in END_STATES: - break - time.sleep(RUN_REFRESH_TIME) - - return _batch_detail_error(resource_id=resource_id) - - -def _task_dict_to_url_bullet_list(data_dict: dict) -> str: - """ - Converts a dictionary into a string formatted as a bullet point list. - - Args: - data_dict: The dictionary to convert. - - Returns: - A string with each key-url/value pair as a bullet point. - """ - # Use a list comprehension to format each key-value pair - # and then join them together with newline characters. - if data_dict is None: - raise WebError("Error in subtask dictionary data.") - return "\n".join([f"- {key}: '{value}'" for key, value in data_dict.items()]) + raise WebError(error_msg) def _copy_simulation_data_from_cache_entry(entry: CacheEntry, path: PathLike) -> bool: @@ -398,7 +225,7 @@ def restore_simulation_if_cached( cached_workflow_type = entry.metadata.get("workflow_type") if cached_task_id is not None and cached_workflow_type is not None and verbose: console = get_logging_console() - url, _ = _get_task_urls(cached_workflow_type, simulation, cached_task_id) + url, _ = _get_task_urls(cached_workflow_type, cached_task_id) console.log( f"Loading simulation from local cache. View cached task using web UI at [link={url}]'{url}'[/link]." ) @@ -572,7 +399,6 @@ def run( ) start( task_id, - verbose=verbose, solver_version=solver_version, worker_group=worker_group, pay_type=pay_type, @@ -600,15 +426,12 @@ def run( def _get_task_urls( task_type: str, - simulation: WorkflowType, resource_id: str, folder_id: Optional[str] = None, group_id: Optional[str] = None, ) -> tuple[str, Optional[str]]: """Log task and folder links to the web UI.""" - if (task_type in ["RF", "COMPONENT_MODELER", "TERMINAL_COMPONENT_MODELER"]) and isinstance( - simulation, TerminalComponentModeler - ): + if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: url = _get_url_rf(group_id or resource_id) else: url = _get_url(resource_id) @@ -694,13 +517,8 @@ def upload( task_name = stub.get_default_task_name() task_type = stub.get_type() - # Component modeler compatibility: map to RF task type - port_name_list = None - if task_type in ("COMPONENT_MODELER", "TERMINAL_COMPONENT_MODELER"): - task_type = "RF" - port_name_list = tuple(simulation.sim_dict.keys()) - task = SimulationTask.create( + task = WebTask.create( task_type, task_name, folder_name, @@ -708,26 +526,10 @@ def upload( simulation_type, parent_tasks, "Gz", - port_name_list=port_name_list, ) - if task_type == "RF": - # Prefer the group id if present in the creation response; avoid extra GET. - group_id = getattr(task, "groupId", None) or getattr(task, "group_id", None) - if not group_id: - try: - detail_task = SimulationTask.get(task.task_id, verbose=False) - group_id = getattr(detail_task, "groupId", None) or getattr( - detail_task, "group_id", None - ) - except Exception: - group_id = None - # Prefer returning batch/group id for downstream batch endpoints - batch_id = getattr(task, "batchId", None) or getattr(task, "batch_id", None) - resource_id = batch_id or task.task_id - else: - group_id = None - resource_id = task.task_id + group_id = getattr(task, "groupId", None) + resource_id = task.task_id if verbose: console.log( @@ -740,16 +542,14 @@ def upload( f"Cost of {solver_name} simulations is subject to change in the future." ) if task_type in GUI_SUPPORTED_TASK_TYPES: - url, folder_url = _get_task_urls( - task_type, simulation, resource_id, task.folder_id, group_id - ) + url, folder_url = _get_task_urls(task_type, resource_id, task.folder_id, group_id) console.log(f"View task using web UI at [link={url}]'{url}'[/link].") console.log(f"Task folder: [link={folder_url}]'{task.folder_name}'[/link].") remote_sim_file = SIM_FILE_HDF5_GZ if task_type == "MODE_SOLVER": remote_sim_file = MODE_FILE_HDF5_GZ - elif task_type == "RF": + elif task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: remote_sim_file = MODELER_FILE_HDF5_GZ task.upload_simulation( @@ -759,15 +559,10 @@ def upload( remote_sim_file=remote_sim_file, ) - if task_type == "RF": - _upload_component_modeler_subtasks(resource_id=resource_id, verbose=verbose) - estimate_cost(task_id=resource_id, solver_version=solver_version, verbose=verbose) task.validate_post_upload(parent_tasks=parent_tasks) - # log the url for the task in the web UI - log.debug(_build_website_url(f"folders/{task.folder_id}/tasks/{resource_id}")) return resource_id @@ -847,20 +642,15 @@ def get_info(task_id: TaskId, verbose: bool = True) -> TaskInfo | BatchDetail: ValueError If no task is found for the given ``task_id``. """ - if _is_modeler_batch(task_id): - batch = BatchTask(task_id) - return batch.detail(batch_type="RF_SWEEP") - else: - task = SimulationTask.get(task_id, verbose) - if not task: - raise ValueError("Task not found.") - return TaskInfo(**{"taskId": task.task_id, "taskType": task.task_type, **task.dict()}) + task = TaskFactory.get(task_id, verbose=verbose) + if not task: + raise ValueError("Task not found.") + return task.detail() @wait_for_connection def start( task_id: TaskId, - verbose: bool = True, solver_version: Optional[str] = None, worker_group: Optional[str] = None, pay_type: Union[PayType, str] = PayType.AUTO, @@ -889,30 +679,10 @@ def start( To monitor progress, can call :meth:`monitor` after starting simulation. """ - console = get_logging_console() if verbose else None - - # Component modeler batch path: hide split/check/submit - if _is_modeler_batch(task_id): - # split (modeler-specific) - batch = BatchTask(task_id) - detail = batch.wait_for_validate(batch_type="RF_SWEEP") - status = detail.totalStatus - status_str = status.value - if status_str in POST_VALIDATE_STATES: - pass - elif status_str not in POST_VALIDATE_STATES: - raise WebError(f"Batch task {task_id} is blocked: {status_str}") - # Submit batch to start runs after validation - batch.submit( - solver_version=solver_version, batch_type="RF_SWEEP", worker_group=worker_group - ) - if verbose: - console.log(f"Component Modeler '{task_id}' validated. Solving...") - return - if priority is not None and (priority < 1 or priority > 10): raise ValueError("Priority must be between '1' and '10' if specified.") - task = SimulationTask.get(task_id) + + task = TaskFactory.get(task_id) if not task: raise ValueError("Task not found.") task.submit( @@ -941,10 +711,21 @@ def get_run_info(task_id: TaskId) -> tuple[Optional[float], Optional[float]]: Average field intensity normalized to max value (1.0). Is ``None`` if run info not available. """ - task = SimulationTask(taskId=task_id) + task = TaskFactory.get(task_id) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") return task.get_running_info() +def _get_batch_detail_handle_error_status(batch: BatchTask) -> BatchDetail: + """Get batch detail and raise error if status is in ERROR_STATES.""" + detail = batch.detail() + status = detail.status.lower() + if status in ERROR_STATES: + _batch_detail_error(batch.task_id) + return detail + + def get_status(task_id: TaskId) -> str: """Get the status of a task. Raises an error if status is "error". @@ -953,22 +734,9 @@ def get_status(task_id: TaskId) -> str: task_id : str Unique identifier of task on server. Returned by :meth:`upload`. """ - if _is_modeler_batch(task_id): - # split (modeler-specific) - batch = BatchTask(task_id) - detail = batch.detail(batch_type="RF_SWEEP") - status = detail.totalStatus - if status == "visualize": - return "success" - if status in ERROR_STATES: - try: - # TODO Try to obtain the error message - pass - except Exception: - # If the error message could not be obtained, raise a generic error message - error_msg = "Error message could not be obtained, please contact customer support." - - raise WebError(f"Error running task {task_id}! {error_msg}") + task = TaskFactory.get(task_id) + if isinstance(task, BatchTask): + return _get_batch_detail_handle_error_status(task).status else: task_info = get_info(task_id) status = task_info.status @@ -1017,9 +785,9 @@ def monitor(task_id: TaskId, verbose: bool = True, worker_group: Optional[str] = """ # Batch/modeler monitoring path - if _is_modeler_batch(task_id): - _monitor_modeler_batch(task_id, verbose=verbose, worker_group=worker_group) - return + task = TaskFactory.get(task_id) + if isinstance(task, BatchTask): + return _monitor_modeler_batch(task_id, verbose=verbose) console = get_logging_console() if verbose else None @@ -1094,7 +862,7 @@ def monitor_preprocess() -> None: if verbose: # verbose case, update progressbar console.log("running solver") - if task_type == "FDTD": + if "FDTD" in task_type: with Progress(console=console) as progress: pbar_pd = progress.add_task("% done", total=100) perc_done, _ = get_run_info(task_id) @@ -1178,29 +946,18 @@ def abort(task_id: TaskId) -> Optional[TaskInfo]: Object containing information about status, size, credits of task. """ console = get_logging_console() - try: - task = SimulationTask.get(task_id, verbose=False) - if task: - task.abort() - url = _get_url(task.task_id) - console.log( - f"Task is aborting. View task using web UI at [link={url}]'{url}'[/link] to check the result." - ) - return TaskInfo(**{"taskId": task.task_id, **task.dict()}) - except WebNotFoundError: - pass # Task not found, might be a batch task - - is_batch = BatchTask.is_batch(task_id, batch_type="RF_SWEEP") - if is_batch: - url = _get_url_rf(task_id) - console.log( - f"Batch task abortion is not yet supported, contact customer support." - f" View task using web UI at [link={url}]'{url}'[/link]." - ) - return - console.log("Task ID cannot be found to be aborted.") - return + task = TaskFactory.get(task_id, verbose=False) + if not task: + return None + url = task.get_url() + task.abort() + console.log( + f"Task is aborting. View task using web UI at [link={url}]'{url}'[/link] to check the result." + ) + return TaskInfo( + **{"taskId": task_id, "taskType": getattr(task, "task_type", None), **task.dict()} + ) @wait_for_connection @@ -1225,59 +982,26 @@ def download( """ path = Path(path) - - if _is_modeler_batch(task_id): - # Use a more descriptive default filename for component modeler downloads. - # If the caller left the default as 'simulation_data.hdf5', prefer 'cm_data.hdf5'. + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): if path.name == "simulation_data.hdf5": path = path.with_name("cm_data.hdf5") - - def _download_cm() -> bool: - try: - BatchTask(task_id).get_data_hdf5( - remote_data_file_gz=CM_DATA_HDF5_GZ, - to_file=path, - verbose=verbose, - progress_callback=progress_callback, - ) - return True - except Exception: - return False - - if not _download_cm(): - BatchTask(task_id).postprocess(batch_type="RF_SWEEP") - # wait for postprocess to finish - while True: - resp = BatchTask(task_id).detail(batch_type="RF_SWEEP") - total = resp.totalTask or 0 - post_succ = resp.postprocessSuccess or 0 - status = resp.totalStatus - status_str = status.value - if status_str in ERROR_STATES: - raise WebError( - f"Batch task {task_id} failed during postprocess: {status_str}" - ) from None - if total > 0 and post_succ >= total: - break - time.sleep(REFRESH_TIME) - if not _download_cm(): - raise WebError("Failed to download 'cm_data' after postprocess completion.") + task.get_data_hdf5( + to_file=path, + remote_data_file_gz=CM_DATA_HDF5_GZ, + verbose=verbose, + progress_callback=progress_callback, + ) return - - # Regular single-task download - task_info = get_info(task_id) - task_type = task_info.taskType - + info = get_info(task_id, verbose=False) remote_data_file = SIMULATION_DATA_HDF5_GZ - if task_type == "MODE_SOLVER": + if info.taskType == "MODE_SOLVER": remote_data_file = MODE_DATA_HDF5_GZ - - task = SimulationTask(taskId=task_id) - task.get_sim_data_hdf5( - path, + task.get_data_hdf5( + to_file=path, + remote_data_file_gz=remote_data_file, verbose=verbose, progress_callback=progress_callback, - remote_data_file=remote_data_file, ) @@ -1295,7 +1019,9 @@ def download_json(task_id: TaskId, path: PathLike = SIM_FILE_JSON, verbose: bool If ``True``, will print progressbars and status, otherwise, will run silently. """ - task = SimulationTask(taskId=task_id) + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") task.get_simulation_json(path, verbose=verbose) @@ -1326,7 +1052,9 @@ def load_simulation( Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] Simulation loaded from downloaded json file. """ - task = SimulationTask.get(task_id) + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") path = Path(path) if path.suffix == ".json": task.get_simulation_json(path, verbose=verbose) @@ -1361,7 +1089,9 @@ def download_log( ---- To load downloaded results into data, call :meth:`load` with option ``replace_existing=False``. """ - task = SimulationTask(taskId=task_id) + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") task.get_log(path, verbose=verbose, progress_callback=progress_callback) @@ -1412,10 +1142,11 @@ def load( Object containing simulation data. """ path = Path(path) + task = TaskFactory.get(task_id) if task_id else None # For component modeler batches, default to a clearer filename if the default was used. if ( task_id - and _is_modeler_batch(task_id) + and isinstance(task, BatchTask) and path.name in {"simulation_data.hdf5", "simulation_data.hdf5.gz"} ): path = path.with_name(path.name.replace("simulation", "cm")) @@ -1428,8 +1159,8 @@ def load( if verbose and task_id is not None: console = get_logging_console() - if _is_modeler_batch(task_id): - console.log(f"loading component modeler data from {path}") + if isinstance(task, BatchTask): + console.log(f"Loading component modeler data from {path}") else: console.log(f"Loading simulation from {path}") @@ -1461,74 +1192,49 @@ def load( return stub_data +def _status_to_stage(status: str) -> tuple[str, int]: + """Map task status to monotonic stage for progress bars.""" + s = (status or "").lower() + # Map a broader set of states to monotonic stages for progress bars + if s in ("draft", "created"): + return ("draft", 0) + if s in ("queue", "queued"): + return ("queued", 1) + if s in ("validating",): + return ("validating", 2) + if s in ("validate_success", "validate_warn", "preprocess", "preprocessing"): + return ("preprocess", 3) + if s in ("running", "preprocess_success"): + return ("running", 4) + if s in ("run_success", "postprocess"): + return ("postprocess", 5) + if s in ("success", "postprocess_success"): + return ("success", 6) + # Unknown states map to earliest stage to avoid showing 100% prematurely + return (s or "unknown", 0) + + def _monitor_modeler_batch( - batch_id: str, + task_id: str, verbose: bool = True, max_detail_tasks: int = 20, - worker_group: Optional[str] = None, ) -> None: """Monitor modeler batch progress with aggregate and per-task views.""" console = get_logging_console() if verbose else None - - def _status_to_stage(status: str) -> tuple[str, int]: - s = (status or "").lower() - # Map a broader set of states to monotonic stages for progress bars - if s in ("draft", "created"): - return ("draft", 0) - if s in ("queue", "queued"): - return ("queued", 1) - if s in ("preprocess",): - return ("preprocess", 1) - if s in ("validating",): - return ("validating", 2) - if s in ("validate_success", "validate_warn"): - return ("validate", 3) - if s in ("running",): - return ("running", 4) - if s in ("postprocess",): - return ("postprocess", 5) - if s in ("run_success", "success"): - return ("success", 6) - # Unknown states map to earliest stage to avoid showing 100% prematurely - return (s or "unknown", 0) - - detail = _batch_detail(batch_id) + task = BatchTask.get(task_id=task_id) + detail = _get_batch_detail_handle_error_status(task) name = detail.name or "modeler_batch" group_id = detail.groupId - - header = f"Subtasks status - {name}" - if group_id: - header += f"\nGroup ID: '{group_id}'" - if console is not None: - console.log(header) + status = detail.status.lower() # Non-verbose path: poll without progress bars then return if not verbose: # Run phase - while True: - d = _batch_detail(batch_id) - s = d.totalStatus.value - total = d.totalTask or 0 - r = d.runSuccess or 0 - if s in ERROR_STATES: - raise WebError(f"Batch {batch_id} terminated: {s}") - # Updated break condition for robustness - if s in ("run_success", "success") or (total and r >= total): - break + while _status_to_stage(status)[0] not in END_STATES: time.sleep(REFRESH_TIME) + detail = _get_batch_detail_handle_error_status(task) + status = detail.status.lower() - postprocess_start(batch_id, verbose=False, worker_group=worker_group) - - while True: - d = _batch_detail(batch_id) - postprocess_status = d.postprocessStatus - if postprocess_status == "success": - break - elif postprocess_status in ERROR_STATES: - raise WebError( - f"Batch {batch_id} terminated. Please contact customer support and provide this Component Modeler batch ID: '{batch_id}'" - ) - time.sleep(REFRESH_TIME) return progress_columns = ( @@ -1537,16 +1243,26 @@ def _status_to_stage(status: str) -> tuple[str, int]: TaskProgressColumn(), TimeElapsedColumn(), ) + # Make the header + header = f"Subtasks status - {name}" + if group_id: + header += f"\nGroup ID: '{group_id}'" + console.log(header) with Progress(*progress_columns, console=console, transient=False) as progress: # Phase: Run (aggregate + per-task) p_run = progress.add_task("Run Total", total=1.0) task_bars: dict[str, int] = {} + stage = _status_to_stage(status)[0] + prev_stage = _status_to_stage(status)[0] + console.log(f"Batch status = {status}") - while True: - detail = _batch_detail(batch_id) - status = detail.totalStatus.value - total = detail.totalTask or 0 + # Note: get_status errors if an erroring status occurred + while stage not in END_STATES: + total = len(detail.tasks) r = detail.runSuccess or 0 + if stage != prev_stage: + prev_stage = stage + console.log(f"Batch status = {stage}") # Create per-task bars as soon as tasks appear if total and total <= max_detail_tasks and detail.tasks: @@ -1593,40 +1309,15 @@ def _status_to_stage(status: str) -> tuple[str, int]: refresh=False, ) - # Updated break condition for robustness - if status in ("run_success", "success") or (total and r >= total): - break - if status in ERROR_STATES: - raise WebError(f"Batch {batch_id} terminated: {status}") - progress.refresh() - time.sleep(REFRESH_TIME) - - postprocess_start(batch_id, verbose=True, worker_group=worker_group) - - p_post = progress.add_task("Postprocess", total=1.0) - while True: - detail = _batch_detail(batch_id) - postprocess_status = detail.postprocessStatus - if postprocess_status == "success": - progress.update(p_post, completed=1.0) - progress.refresh() - break - elif postprocess_status == "queued": - progress.update(p_post, completed=0.22) - elif postprocess_status == "preprocess": - progress.update(p_post, completed=0.33) - elif postprocess_status == "running": - progress.update(p_post, completed=0.55) - elif postprocess_status in ERROR_STATES: - raise WebError( - f"Batch {batch_id} terminated. Please contact customer support and provide this Component Modeler batch ID: '{batch_id}'" - ) progress.refresh() time.sleep(REFRESH_TIME) + detail = _get_batch_detail_handle_error_status(task) + status = detail.status.lower() + stage = _status_to_stage(status)[0] if console is not None: console.log("Modeler has finished running successfully.") - real_cost(batch_id, verbose=verbose) + real_cost(task.task_id, verbose=verbose) @wait_for_connection @@ -1648,7 +1339,7 @@ def delete(task_id: TaskId, versions: bool = False) -> TaskInfo: """ if not task_id: raise ValueError("Task id not found.") - task = SimulationTask.get(task_id, verbose=False) + task = TaskFactory.get(task_id, verbose=False) task.delete(versions) return TaskInfo(**{"taskId": task.task_id, **task.dict()}) @@ -1674,14 +1365,13 @@ def download_simulation( Optional callback function called when downloading file with ``bytes_in_chunk`` as argument. """ - task_info = get_info(task_id) - task_type = task_info.taskType - + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") + info = get_info(task_id, verbose=False) remote_sim_file = SIM_FILE_HDF5_GZ - if task_type == "MODE_SOLVER": + if info.taskType == "MODE_SOLVER": remote_sim_file = MODE_FILE_HDF5_GZ - - task = SimulationTask(taskId=task_id) task.get_simulation_hdf5( path, verbose=verbose, @@ -1779,63 +1469,64 @@ def estimate_cost( console = get_logging_console() if verbose else None - if _is_modeler_batch(task_id): - d = _batch_detail(task_id) - status = d.totalStatus.value - - if status in ALL_POST_VALIDATE_STATES: - est_flex_unit = _batch_detail(task_id).estFlexUnit - if verbose: - console.log( - f"Maximum FlexCredit cost: {est_flex_unit:1.3f}. Minimum cost depends on " - "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " - "cost after a simulation run." - ) - return est_flex_unit - - elif status in ERROR_STATES: - return _batch_detail_error(resource_id=task_id) - - raise WebError("Could not get estimated cost!") - - else: - task = SimulationTask.get(task_id) - - if not task: - raise ValueError("Task not found.") - - task.estimate_cost(solver_version=solver_version) - task_info = get_info(task_id) - status = task_info.metadataStatus - - # Wait for a termination status + task = TaskFactory.get(task_id, verbose=False) + detail = task.detail() + if isinstance(task, BatchTask): + check_task_type = "FDTD" if detail.taskType == "MODAL_CM" else "RF_FDTD" + task.check(solver_version=solver_version, check_task_type=check_task_type) + detail = task.detail() + status = detail.status.lower() while status not in ALL_POST_VALIDATE_STATES: time.sleep(REFRESH_TIME) - task_info = get_info(task_id) - status = task_info.metadataStatus + detail = task.detail() + status = detail.status.lower() + if status in ERROR_STATES: + _batch_detail_error(resource_id=task_id) + est_flex_unit = detail.estFlexUnit + if verbose: + console.log( + f"Maximum FlexCredit cost: {est_flex_unit:1.3f}. Minimum cost depends on " + "task execution details. Use 'web.real_cost(task_id)' after run." + ) + return est_flex_unit - if status in ALL_POST_VALIDATE_STATES: - if verbose: - console.log( - f"Estimated FlexCredit cost: {task_info.estFlexUnit:1.3f}. Minimum cost depends on " - "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " - "cost after a simulation run." - ) - fc_mode = task_info.estFlexCreditMode - fc_post = task_info.estFlexCreditPostProcess - if fc_mode: - console.log(f" {fc_mode:1.3f} FlexCredit of the total cost from mode solves.") - if fc_post: - console.log( - f" {fc_post:1.3f} FlexCredit of the total cost from post-processing." - ) - return task_info.estFlexUnit + # simulation path + task.estimate_cost(solver_version=solver_version) + task_info = get_info(task_id) + status = task_info.metadataStatus - elif status in ERROR_STATES: - log.error(f"The task '{task_id}' has failed: {status}") + # Wait for a termination status + while status not in ALL_POST_VALIDATE_STATES: + time.sleep(REFRESH_TIME) + task_info = get_info(task_id) + status = task_info.metadataStatus - # Something went wrong - raise WebError("Could not get estimated cost!") + if status in ERROR_STATES: + try: + # Try to obtain the error message + task = SimulationTask(taskId=task_id) + with tempfile.NamedTemporaryFile(suffix=".json") as tmp_file: + task.get_error_json(to_file=tmp_file.name, validation=True) + with open(tmp_file.name) as f: + error_content = json.load(f) + error_msg = error_content["validation_error"] + except Exception: + # If the error message could not be obtained, raise a generic error message + error_msg = "Error message could not be obtained, please contact customer support." + raise WebError(f"Error estimating cost for task {task_id}! {error_msg}") + if verbose: + console.log( + f"Estimated FlexCredit cost: {task_info.estFlexUnit:1.3f}. Minimum cost depends on " + "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " + "cost after a simulation run." + ) + fc_mode = task_info.estFlexCreditMode + fc_post = task_info.estFlexCreditPostProcess + if fc_mode: + console.log(f" {fc_mode:1.3f} FlexCredit of the total cost from mode solves.") + if fc_post: + console.log(f" {fc_post:1.3f} FlexCredit of the total cost from post-processing.") + return task_info.estFlexUnit @wait_for_connection @@ -1891,42 +1582,24 @@ def real_cost(task_id: str, verbose: bool = True) -> float | None: ) console = get_logging_console() if verbose else None - if _is_modeler_batch(task_id): - status = _batch_detail(task_id).totalStatus.value - flex_unit = _batch_detail(task_id).realFlexUnit or None - if (status not in ["success", "run_success"]) or (flex_unit is None): - log.warning( - f"Billed FlexCredit for task '{task_id}' is not available. If the task has been " - "successfully run, it should be available shortly. If this issue persists, contact customer support." - ) - else: - if verbose: + task_info = get_info(task_id) + flex_unit = task_info.realFlexUnit + ori_flex_unit = getattr(task_info, "oriRealFlexUnit", flex_unit) + if not flex_unit: + log.warning( + f"Billed FlexCredit for task '{task_id}' is not available. If the task has been " + "successfully run, it should be available shortly." + ) + else: + if verbose: + console.log(f"Billed flex credit cost: {flex_unit:1.3f}.") + if flex_unit != ori_flex_unit and "FDTD" in task_info.taskType: console.log( - f"Billed FlexCredit cost: {flex_unit:1.3f}. Minimum cost depends on " - "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " - "cost after a simulation run." + "Note: the task cost pro-rated due to early shutoff was below the minimum " + "threshold, due to fast shutoff. Decreasing the simulation 'run_time' should " + "decrease the estimated, and correspondingly the billed cost of such tasks." ) - - return flex_unit - else: - task_info = get_info(task_id) - flex_unit = task_info.realFlexUnit - ori_flex_unit = task_info.oriRealFlexUnit - if not flex_unit: - log.warning( - f"Billed FlexCredit for task '{task_id}' is not available. If the task has been " - "successfully run, it should be available shortly." - ) - else: - if verbose: - console.log(f"Billed flex credit cost: {flex_unit:1.3f}.") - if flex_unit != ori_flex_unit and task_info.taskType == "FDTD": - console.log( - "Note: the task cost pro-rated due to early shutoff was below the minimum " - "threshold, due to fast shutoff. Decreasing the simulation 'run_time' should " - "decrease the estimated, and correspondingly the billed cost of such tasks." - ) - return flex_unit + return flex_unit @wait_for_connection @@ -1986,49 +1659,6 @@ def account(verbose: bool = True) -> Account: return account_info -@wait_for_connection -def postprocess_start( - batch_id: str, - verbose: bool = True, - worker_group: Optional[str] = None, -) -> None: - """ - Checks if a batch run is complete and starts the postprocess phase. - - This function does not wait for postprocessing to finish. - """ - console = get_logging_console() if verbose else None - if _is_modeler_batch(batch_id): - # Perform a single check on the run phase status - detail = _batch_detail(batch_id) - status = detail.totalStatus.value - total_tasks = detail.totalTask or 0 - successful_runs = detail.runSuccess or 0 - - if status in ERROR_STATES: - raise WebError(f"Batch '{batch_id}' terminated with error status: {status}") - - # Check if the run phase is complete before proceeding - is_run_complete = status in ("run_success", "success") or ( - total_tasks > 0 and successful_runs >= total_tasks - ) - - if not is_run_complete: - if console: - console.log( - f"Batch '{batch_id}' run phase is not yet complete (Status: {status}). " - f"Cannot start postprocessing." - ) - return # Exit if the run is not done - BatchTask(batch_id).postprocess(batch_type="RF_SWEEP", worker_group=worker_group) - return - else: - raise WebError( - f"Batch ID '{batch_id}' is not a component modeler batch job. " - "'postprocess_start' is only applicable to those classes." - ) - - @wait_for_connection def test() -> None: """Confirm whether Tidy3D authentication is configured. diff --git a/tidy3d/web/core/constants.py b/tidy3d/web/core/constants.py index 109d985026..623af2bba8 100644 --- a/tidy3d/web/core/constants.py +++ b/tidy3d/web/core/constants.py @@ -31,6 +31,7 @@ MODE_FILE_HDF5_GZ = "mode_solver.hdf5.gz" MODE_DATA_HDF5_GZ = "output/mode_solver_data.hdf5.gz" SIM_ERROR_FILE = "output/tidy3d_error.json" +SIM_VALIDATION_FILE = "output/tidy3d_validation.json" # Component modeler specific artifacts MODELER_FILE_HDF5_GZ = "modeler.hdf5.gz" diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py index be66ea8985..9e5c9f3f9f 100644 --- a/tidy3d/web/core/http_util.py +++ b/tidy3d/web/core/http_util.py @@ -42,6 +42,7 @@ class ResponseCodes(Enum): def get_version() -> str: """Get the version for the current environment.""" return core_config.get_version() + # return "2.10.0rc2.1" def get_user_agent() -> str: diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py index e3d7db1a97..f1e4ca8c7d 100644 --- a/tidy3d/web/core/task_core.py +++ b/tidy3d/web/core/task_core.py @@ -5,7 +5,6 @@ import os import pathlib import tempfile -import time from datetime import datetime from os import PathLike from typing import Callable, Optional, Union @@ -17,7 +16,6 @@ import tidy3d as td from tidy3d.config import config from tidy3d.exceptions import ValidationError -from tidy3d.web.common import REFRESH_TIME from . import http_util from .cache import FOLDER_CACHE @@ -25,6 +23,7 @@ SIM_ERROR_FILE, SIM_FILE_HDF5_GZ, SIM_LOG_FILE, + SIM_VALIDATION_FILE, SIMULATION_DATA_HDF5_GZ, ) from .core_config import get_logger_console @@ -34,7 +33,7 @@ from .http_util import http from .s3utils import download_file, download_gz_file, upload_file from .stub import TaskStub -from .task_info import BatchDetail +from .task_info import BatchDetail, TaskInfo from .types import PayType, Queryable, ResourceLifecycle, Submittable, Tidy3DResource @@ -145,8 +144,8 @@ def list_tasks(self, projects_endpoint: str = "tidy3d/projects") -> list[Tidy3DR ) -class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): - """Interface for managing the running of a :class:`.Simulation` task on server.""" +class WebTask(ResourceLifecycle, Submittable, extra=Extra.allow): + """Interface for managing the running a task on the server.""" task_id: Optional[str] = Field( ..., @@ -154,52 +153,6 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): description="Task ID number, set when the task is uploaded, leave as None.", alias="taskId", ) - folder_id: Optional[str] = Field( - None, - title="folder_id", - description="Folder ID number, set when the task is uploaded, leave as None.", - alias="folderId", - ) - status: Optional[str] = Field(title="status", description="Simulation task status.") - - real_flex_unit: float = Field( - None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" - ) - - created_at: Optional[datetime] = Field( - title="created_at", description="Time at which this task was created.", alias="createdAt" - ) - - task_type: Optional[str] = Field( - title="task_type", description="The type of task.", alias="taskType" - ) - - folder_name: Optional[str] = Field( - "default", - title="Folder Name", - description="Name of the folder associated with this task.", - alias="folderName", - ) - - callback_url: str = Field( - None, - title="Callback URL", - description="Http PUT url to receive simulation finish event. " - "The body content is a json file with fields " - "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", - ) - - # simulation_type: str = pd.Field( - # None, - # title="Simulation Type", - # description="Type of simulation, used internally only.", - # ) - - # parent_tasks: Tuple[TaskId, ...] = pd.Field( - # None, - # title="Parent Tasks", - # description="List of parent task ids for the simulation, used internally only." - # ) @classmethod def create( @@ -211,7 +164,6 @@ def create( simulation_type: str = "tidy3d", parent_tasks: Optional[list[str]] = None, file_type: str = "Gz", - port_name_list: Optional[list[str]] = None, projects_endpoint: str = "tidy3d/projects", ) -> SimulationTask: """Create a new task on the server. @@ -246,32 +198,241 @@ def create( simulation_type = "tidy3d" folder = Folder.get(folder_name, create=True) - payload = { - "taskName": task_name, - "taskType": task_type, - "callbackUrl": callback_url, - "simulationType": simulation_type, - "parentTasks": parent_tasks, - "fileType": file_type, - } - # Component modeler: include port names if provided - if port_name_list: - # Align with backend contract: expect 'portNames' (not 'portNameList') - payload["portNames"] = port_name_list - - resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) - # RF group creation may return group-level info without 'taskId'. - # Use 'groupId' (or 'batchId' as fallback) as the resource id for subsequent uploads. - if "taskId" not in resp and task_type == "RF": - # Prefer using 'batchId' as the resource id for uploads (S3 STS expects a task-like id). - if "batchId" in resp: - resp["taskId"] = resp["batchId"] - elif "groupId" in resp: - resp["taskId"] = resp["groupId"] - else: - raise WebError("Missing resource ID for task creation. Contact customer support.") + + if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: + payload = { + "groupName": task_name, + "folderId": folder.folder_id, + "fileType": file_type, + "taskType": task_type, + } + resp = http.post("rf/task", payload) + else: + payload = { + "taskName": task_name, + "taskType": task_type, + "callbackUrl": callback_url, + "simulationType": simulation_type, + "parentTasks": parent_tasks, + "fileType": file_type, + } + resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) + return SimulationTask(**resp, taskType=task_type, folder_name=folder_name) + def get_url(self) -> str: + base = str(config.web.website_endpoint or "") + if isinstance(self, BatchTask): + return "/".join([base.rstrip("/"), f"rf?taskId={self.task_id}"]) + return "/".join([base.rstrip("/"), f"workbench?taskId={self.task_id}"]) + + def get_folder_url(self) -> Optional[str]: + folder_id = getattr(self, "folder_id", None) + if not folder_id: + return None + base = str(config.web.website_endpoint or "") + return "/".join([base.rstrip("/"), f"folders/{folder_id}"]) + + def get_log( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Get log file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_file( + self.task_id, + SIM_LOG_FILE, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_data_hdf5( + self, + to_file: PathLike, + remote_data_file_gz: PathLike = SIMULATION_DATA_HDF5_GZ, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Download data artifact (simulation or batch) with gz fallback handling. + + Parameters + ---------- + remote_data_file_gz : PathLike + Gzipped remote filename. + to_file : PathLike + Local target path. + verbose : bool + Whether to log progress. + progress_callback : Optional[Callable[[float], None]] + Progress callback. + + Returns + ------- + pathlib.Path + Saved local path. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + target_path = pathlib.Path(to_file) + file = None + try: + file = download_gz_file( + resource_id=self.task_id, + remote_filename=remote_data_file_gz, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except ClientError: + if verbose: + console = get_logger_console() + console.log(f"Unable to download '{remote_data_file_gz}'.") + if not file: + try: + file = download_file( + resource_id=self.task_id, + remote_filename=str(remote_data_file_gz)[:-3], + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except Exception as e: + raise WebError( + "Failed to download the data file from the server. " + "Please confirm that the task completed successfully." + ) from e + return file + + @staticmethod + def is_batch(resource_id: str) -> bool: + """Checks if a given resource ID corresponds to a valid batch task. + + This is a utility function to verify a batch task's existence before + instantiating the class. + + Parameters + ---------- + resource_id : str + The unique identifier for the resource. + + Returns + ------- + bool + ``True`` if the resource is a valid batch task, ``False`` otherwise. + """ + try: + # TODO PROPERLY FIXME + # Disable non critical logs due to check for resourceId, until we have a dedicated API for this + resp = http.get( + f"rf/task/{resource_id}/statistics", + suppress_404=True, + ) + status = bool(resp and isinstance(resp, dict) and "status" in resp) + return status + except Exception: + return False + + def delete(self, versions: bool = False) -> None: + """Delete current task from server. + + Parameters + ---------- + versions : bool = False + If ``True``, delete all versions of the task in the task group. Otherwise, delete only + the version associated with the current task ID. + """ + if not self.task_id: + raise ValueError("Task id not found.") + + task_details = self.detail().dict() + + if task_details and "groupId" in task_details: + group_id = task_details["groupId"] + if versions: + http.delete("tidy3d/group", json={"groupIds": [group_id]}) + return + elif "version" in task_details: + version = task_details["version"] + http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) + return + + # Fallback to old method if we can't get the groupId and version + http.delete(f"tidy3d/tasks/{self.task_id}") + + +class SimulationTask(WebTask): + """Interface for managing the running of solver tasks on the server.""" + + folder_id: Optional[str] = Field( + None, + title="folder_id", + description="Folder ID number, set when the task is uploaded, leave as None.", + alias="folderId", + ) + status: Optional[str] = Field(title="status", description="Simulation task status.") + + real_flex_unit: float = Field( + None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" + ) + + created_at: Optional[datetime] = Field( + title="created_at", description="Time at which this task was created.", alias="createdAt" + ) + + task_type: Optional[str] = Field( + title="task_type", description="The type of task.", alias="taskType" + ) + + folder_name: Optional[str] = Field( + "default", + title="Folder Name", + description="Name of the folder associated with this task.", + alias="folderName", + ) + + callback_url: str = Field( + None, + title="Callback URL", + description="Http PUT url to receive simulation finish event. " + "The body content is a json file with fields " + "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", + ) + + # simulation_type: str = pd.Field( + # None, + # title="Simulation Type", + # description="Type of simulation, used internally only.", + # ) + + # parent_tasks: Tuple[TaskId, ...] = pd.Field( + # None, + # title="Parent Tasks", + # description="List of parent task ids for the simulation, used internally only." + # ) + @classmethod def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: """Get task from the server by id. @@ -313,28 +474,16 @@ def get_running_tasks(cls) -> list[SimulationTask]: return [] return parse_obj_as(list[SimulationTask], resp) - def delete(self, versions: bool = False) -> None: - """Delete current task from server. + def detail(self) -> TaskInfo: + """Fetches the detailed information and status of the task. - Parameters - ---------- - versions : bool = False - If ``True``, delete all versions of the task in the task group. Otherwise, delete only the version associated with the current task ID. + Returns + ------- + TaskInfo + An object containing the task's latest data. """ - if not self.task_id: - raise ValueError("Task id not found.") - - task_details = http.get(f"tidy3d/tasks/{self.task_id}") - - if task_details and "groupId" in task_details and "version" in task_details: - group_id = task_details["groupId"] - version = task_details["version"] - if versions: - http.delete("tidy3d/group", json={"groupIds": [group_id]}) - else: - http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) - else: # Fallback to old method if we can't get the groupId and version - http.delete(f"tidy3d/tasks/{self.task_id}") + resp = http.get(f"tidy3d/tasks/{self.task_id}/detail") + return TaskInfo(**{"taskId": self.task_id, "taskType": self.task_type, **resp}) def get_simulation_json(self, to_file: PathLike, verbose: bool = True) -> None: """Get json file for a :class:`.Simulation` from server. @@ -515,65 +664,6 @@ def estimate_cost(self, solver_version: Optional[str] = None) -> float: ) return resp - def get_sim_data_hdf5( - self, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - remote_data_file: PathLike = SIMULATION_DATA_HDF5_GZ, - ) -> pathlib.Path: - """Get simulation data file from Server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while downloading the data. - - Returns - ------- - path: pathlib.Path - Path to saved file. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - - file = None - try: - file = download_gz_file( - resource_id=self.task_id, - remote_filename=remote_data_file, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except ClientError: - if verbose: - console = get_logger_console() - console.log(f"Unable to download '{remote_data_file}'.") - - if not file: - try: - file = download_file( - resource_id=self.task_id, - remote_filename=remote_data_file[:-3], - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except Exception as e: - raise WebError( - "Failed to download the simulation data file from the server. " - "Please confirm that the task was successfully run." - ) from e - - return file - def get_simulation_hdf5( self, to_file: PathLike, @@ -666,7 +756,9 @@ def get_log( progress_callback=progress_callback, ) - def get_error_json(self, to_file: PathLike, verbose: bool = True) -> pathlib.Path: + def get_error_json( + self, to_file: PathLike, verbose: bool = True, validation: bool = False + ) -> pathlib.Path: """Get error json file for a :class:`.Simulation` from server. Parameters @@ -675,6 +767,8 @@ def get_error_json(self, to_file: PathLike, verbose: bool = True) -> pathlib.Pat Save file to path. verbose: bool = True Whether to display progress bars. + validation: bool = False + Whether to get a validation error file or a solver error file. Returns ------- @@ -685,16 +779,17 @@ def get_error_json(self, to_file: PathLike, verbose: bool = True) -> pathlib.Pat raise WebError("Expected field 'task_id' is unset.") target_path = pathlib.Path(to_file) + target_file = SIM_ERROR_FILE if not validation else SIM_VALIDATION_FILE return download_file( self.task_id, - SIM_ERROR_FILE, + target_file, to_file=target_path, verbose=verbose, ) def abort(self) -> requests.Response: - """Aborting current task from server.""" + """Abort the current task on the server.""" if not self.task_id: raise ValueError("Task id not found.") return http.put( @@ -732,68 +827,43 @@ def validate_post_upload(self, parent_tasks: Optional[list[str]] = None) -> None raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e -class BatchTask: - """Provides a client-side interface for managing a remote batch task. - - This class acts as a wrapper around the API endpoints for a specific batch, - allowing users to check, submit, monitor, and download data from it. +class BatchTask(WebTask): + """Interface for managing a batch task on the server.""" - Note: - The 'batch_type' (e.g., "RF_SWEEP") must be provided by the caller to - most methods, as it dictates which backend service handles the request. - """ - - def __init__(self, batch_id: str) -> None: - self.batch_id = batch_id - - @staticmethod - def is_batch(resource_id: str, batch_type: str) -> bool: - """Checks if a given resource ID corresponds to a valid batch task. - - This is a utility function to verify a batch task's existence before - instantiating the class. + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> BatchTask: + """Get batch task by id. Parameters ---------- - resource_id : str - The unique identifier for the resource. - batch_type : str - The type of the batch to check (e.g., "RF_SWEEP"). + task_id: str + Unique identifier of batch on server. + verbose: + If `True`, will print progressbars and status, otherwise, will run silently. Returns ------- - bool - ``True`` if the resource is a valid batch task, ``False`` otherwise. + :class:`.BatchTask` | None + BatchTask object if found, otherwise None. """ try: - # TODO PROPERLY FIXME - # Disable non critical logs due to check for resourceId, until we have a dedicated API for this - resp = http.get( - f"tidy3d/tasks/{resource_id}/batch-detail", - params={"batchType": batch_type}, - suppress_404=True, - ) - status = bool(resp and isinstance(resp, dict) and "status" in resp) - return status - except Exception: - return False + resp = http.get(f"rf/task/{task_id}/statistics") + except WebNotFoundError as e: + td.log.error(f"The requested batch ID '{task_id}' does not exist.") + raise e + # We only need to validate existence; store id on the instance. + return BatchTask(taskId=task_id) if resp else None - def detail(self, batch_type: str) -> BatchDetail: + def detail(self) -> BatchDetail: """Fetches the detailed information and status of the batch. - Parameters - ---------- - batch_type : str - The type of the batch (e.g., "RF_SWEEP"). - Returns ------- BatchDetail An object containing the batch's latest data. """ resp = http.get( - f"tidy3d/tasks/{self.batch_id}/batch-detail", - params={"batchType": batch_type}, + f"rf/task/{self.task_id}/statistics", ) # Some backends may return null for collection fields; coerce to sensible defaults if isinstance(resp, dict): @@ -803,9 +873,9 @@ def detail(self, batch_type: str) -> BatchDetail: def check( self, + check_task_type: str, solver_version: Optional[str] = None, protocol_version: Optional[str] = None, - batch_type: str = "", ) -> requests.Response: """Submits a request to validate the batch configuration on the server. @@ -815,8 +885,6 @@ def check( The version of the solver to use for validation. protocol_version : Optional[str], default=None The data protocol version. Defaults to the current version. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). Returns ------- @@ -826,11 +894,11 @@ def check( if protocol_version is None: protocol_version = _get_protocol_version() return http.post( - f"tidy3d/projects/{self.batch_id}/batch-check", + f"rf/task/{self.task_id}/check", { - "batchType": batch_type, "solverVersion": solver_version, "protocolVersion": protocol_version, + "taskType": check_task_type, }, ) @@ -839,7 +907,8 @@ def submit( solver_version: Optional[str] = None, protocol_version: Optional[str] = None, worker_group: Optional[str] = None, - batch_type: str = "", + pay_type: Union[PayType, str] = PayType.AUTO, + priority: Optional[int] = None, ) -> requests.Response: """Submits the batch for execution on the server. @@ -851,199 +920,67 @@ def submit( The data protocol version. Defaults to the current version. worker_group : Optional[str], default=None Optional identifier for a specific worker group to run on. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). Returns ------- Any The server's response to the submit request. """ - if protocol_version is None: - protocol_version = _get_protocol_version() - return http.post( - f"tidy3d/projects/{self.batch_id}/batch-submit", - { - "batchType": batch_type, - "solverVersion": solver_version, - "protocolVersion": protocol_version, - "workerGroup": worker_group, - }, - ) - - def postprocess( - self, - solver_version: Optional[str] = None, - protocol_version: Optional[str] = None, - worker_group: Optional[str] = None, - batch_type: str = "", - ) -> requests.Response: - """Initiates post-processing for a completed batch run. - Parameters - ---------- - solver_version : Optional[str], default=None - The version of the solver to use for post-processing. - protocol_version : Optional[str], default=None - The data protocol version. Defaults to the current version. - worker_group : Optional[str], default=None - Optional identifier for a specific worker group to run on. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). + # TODO: add support for pay_type and priority arguments + if pay_type != PayType.AUTO: + raise NotImplementedError( + "The 'pay_type' argument is not yet supported and will be ignored." + ) + if priority is not None: + raise NotImplementedError( + "The 'priority' argument is not yet supported and will be ignored." + ) - Returns - ------- - Any - The server's response to the post-process request. - """ if protocol_version is None: protocol_version = _get_protocol_version() return http.post( - f"tidy3d/projects/{self.batch_id}/postprocess", + f"rf/task/{self.task_id}/submit", { - "batchType": batch_type, "solverVersion": solver_version, "protocolVersion": protocol_version, "workerGroup": worker_group, }, ) - def wait_for_validate( - self, timeout: Optional[float] = None, batch_type: str = "" - ) -> BatchDetail: - """Waits for the batch to complete the validation stage by polling its status. - - Parameters - ---------- - timeout : Optional[float], default=None - Maximum time in seconds to wait. If ``None``, waits indefinitely. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). - - Returns - ------- - BatchDetail - The final object after validation completes or a timeout occurs. - - Notes - ----- - This method blocks until the batch status is 'validate_success', - 'validate_warn', 'validate_fail', or another terminal state like 'blocked' - or 'aborted', or until the timeout is reached. - """ - start = datetime.now().timestamp() - while True: - d = self.detail(batch_type=batch_type) - status = d.totalStatus - if status in ("validate_success", "validate_warn", "validate_fail"): - return d - if status in ("blocked", "aborting", "aborted"): - return d - if timeout is not None and (datetime.now().timestamp() - start) > timeout: - return d - time.sleep(REFRESH_TIME) - - def wait_for_run(self, timeout: Optional[float] = None, batch_type: str = "") -> BatchDetail: - """Waits for the batch to complete the execution stage by polling its status. - - Parameters - ---------- - timeout : Optional[float], default=None - Maximum time in seconds to wait. If ``None``, waits indefinitely. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). - - Returns - ------- - BatchDetail - The final object after the run completes or a timeout occurs. - - Notes - ----- - This method blocks until the batch status reaches a terminal run state like - 'run_success', 'run_failed', 'diverged', 'blocked', or 'aborted', - or until the timeout is reached. - """ - start = datetime.now().timestamp() - while True: - d = self.detail(batch_type=batch_type) - status = d.totalStatus - if status in ( - "run_success", - "run_failed", - "diverged", - "blocked", - "aborting", - "aborted", - ): - return d - if timeout is not None and (datetime.now().timestamp() - start) > timeout: - return d - time.sleep(REFRESH_TIME) - - def get_data_hdf5( - self, - remote_data_file_gz: str, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> pathlib.Path: - """Downloads a batch data artifact, with a fallback mechanism. + def abort(self) -> requests.Response: + """Abort the current task on the server.""" + if not self.task_id: + raise ValueError("Batch id not found.") + return http.put(f"rf/task/{self.task_id}/abort", {}) - Parameters - ---------- - remote_data_file_gz : str - Remote gzipped filename to download (e.g., 'output/cm_data.hdf5.gz'). - to_file : PathLike - Local path where the downloaded file will be saved. - verbose : bool, default=True - If ``True``, shows progress logs and messages. - progress_callback : Optional[Callable[[float], None]], default=None - Optional callback function for progress updates, which receives the - download percentage as a float. - Returns - ------- - pathlib.Path - An object pointing to the downloaded local file. +class TaskFactory: + """Factory for obtaining the correct task subclass.""" - Raises - ------ - WebError - If both the gzipped and uncompressed file downloads fail. + _REGISTRY: dict[str, str] = {} - Notes - ----- - This method first attempts to download the gzipped version of a file. - If that fails, it falls back to downloading the uncompressed version. - """ - file = None - try: - file = download_gz_file( - resource_id=self.batch_id, - remote_filename=remote_data_file_gz, - to_file=to_file, - verbose=verbose, - progress_callback=progress_callback, - ) - except ClientError: - if verbose: - console = get_logger_console() - console.log(f"Unable to download '{remote_data_file_gz}'.") + @classmethod + def reset(cls) -> None: + """Clear the cached task kind registry (used in tests).""" + cls._REGISTRY.clear() - if not file: - try: - file = download_file( - resource_id=self.batch_id, - remote_filename=remote_data_file_gz[:-3], - to_file=to_file, - verbose=verbose, - progress_callback=progress_callback, - ) - except Exception as e: - raise WebError( - "Failed to download the batch data file from the server. " - "Please confirm that the batch has been successfully postprocessed." - ) from e + @classmethod + def register(cls, task_id: str, kind: str) -> None: + cls._REGISTRY[task_id] = kind - return file + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> WebTask: + kind = cls._REGISTRY.get(task_id) + if kind == "batch": + return BatchTask.get(task_id, verbose=verbose) + if kind == "simulation": + task = SimulationTask.get(task_id, verbose=verbose) + return task + if WebTask.is_batch(task_id): + cls.register(task_id, "batch") + return BatchTask.get(task_id, verbose=verbose) + task = SimulationTask.get(task_id, verbose=verbose) + if task: + cls.register(task_id, "simulation") + return task diff --git a/tidy3d/web/core/task_info.py b/tidy3d/web/core/task_info.py index 64cf00ea0f..825ab9b6a5 100644 --- a/tidy3d/web/core/task_info.py +++ b/tidy3d/web/core/task_info.py @@ -10,31 +10,6 @@ import pydantic.v1 as pydantic -class TaskStatus(Enum): - """The statuses that the task can be in.""" - - INIT = "initialized" - """The task has been initialized.""" - - QUEUE = "queued" - """The task is in the queue.""" - - PRE = "preprocessing" - """The task is in the preprocessing stage.""" - - RUN = "running" - """The task is running.""" - - POST = "postprocessing" - """The task is in the postprocessing stage.""" - - SUCCESS = "success" - """The task has completed successfully.""" - - ERROR = "error" - """The task has completed with an error.""" - - class TaskBase(pydantic.BaseModel, ABC): """Base configuration for all task objects.""" @@ -153,6 +128,9 @@ class TaskInfo(TaskBase): taskBlockInfo: TaskBlockInfo = None """Blocking information for the task.""" + version: str = None + """Version of the task.""" + class RunInfo(TaskBase): """Information about the run of a task.""" @@ -172,41 +150,6 @@ def display(self) -> None: # ---------------------- Batch (Modeler) detail schema ---------------------- # -class BatchStatus(str, Enum): - """Enumerates the possible statuses for a batch of tasks.""" - - draft = "draft" - """The batch is being configured and has not been submitted.""" - preprocess = "preprocess" - """The batch is undergoing preprocessing.""" - validating = "validating" - """The tasks within the batch are being validated.""" - validate_success = "validate_success" - """All tasks in the batch passed validation.""" - validate_warn = "validate_warn" - """Validation passed, but with warnings.""" - validate_fail = "validate_fail" - """Validation failed for one or more tasks.""" - blocked = "blocked" - """The batch is blocked and cannot run.""" - running = "running" - """The batch is currently executing.""" - aborting = "aborting" - """The batch is in the process of being aborted.""" - run_success = "run_success" - """The batch completed successfully.""" - postprocess = "postprocess" - """The batch is undergoing postprocessing.""" - run_failed = "run_failed" - """The batch execution failed.""" - diverged = "diverged" - """The simulation in the batch diverged.""" - aborted = "aborted" - """The batch was successfully aborted.""" - error = "error" - """An error occurred during the solver run.""" - - class BatchTaskBlockInfo(TaskBlockInfo): """ Extends `TaskBlockInfo` with specific details for batch task blocking. @@ -282,7 +225,6 @@ class BatchDetail(TaskBase): groupId: Identifier for the group the batch belongs to. name: The user-defined name of the batch. status: The current status of the batch. - totalStatus: The overall status, consolidating individual task statuses. totalTask: The total number of tasks in the batch. preprocessSuccess: The count of tasks that completed preprocessing. postprocessStatus: The status of the batch's postprocessing stage. @@ -295,6 +237,7 @@ class BatchDetail(TaskBase): totalCheckMillis: Total time in milliseconds spent on checks. message: A general message providing information about the batch status. tasks: A list of `BatchMember` objects, one for each task in the batch. + taskType: The type of tasks contained in the batch. """ refId: str = None @@ -302,7 +245,6 @@ class BatchDetail(TaskBase): groupId: str = None name: str = None status: str = None - totalStatus: BatchStatus = None totalTask: int = 0 preprocessSuccess: int = 0 postprocessStatus: str = None @@ -317,6 +259,8 @@ class BatchDetail(TaskBase): message: str = None tasks: list[BatchMember] = [] validateErrors: dict = None + taskType: str = "RF" + version: str = None class AsyncJobDetail(TaskBase): diff --git a/tidy3d/web/core/types.py b/tidy3d/web/core/types.py index 39ab0bc926..1c1439366c 100644 --- a/tidy3d/web/core/types.py +++ b/tidy3d/web/core/types.py @@ -56,8 +56,8 @@ class TaskType(str, Enum): EME = "EME" MODE = "MODE" VOLUME_MESH = "VOLUME_MESH" - COMPONENT_MODELER = "COMPONENT_MODELER" - TERMINAL_COMPONENT_MODELER = "TERMINAL_COMPONENT_MODELER" + MODAL_CM = "MODAL_CM" + TERMINAL_CM = "TERMINAL_CM" class PayType(str, Enum): diff --git a/tidy3d/web/tests/conftest.py b/tidy3d/web/tests/conftest.py new file mode 100644 index 0000000000..918dce5588 --- /dev/null +++ b/tidy3d/web/tests/conftest.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from tidy3d_frontend.tidy3d.web.core.task_core import TaskFactory + + +@pytest.fixture(autouse=True) +def clear_task_factory_registry() -> Generator[None, None, None]: + """Ensure TaskFactory registry is empty for each test.""" + TaskFactory.reset() + TaskFactory.reset() + yield + TaskFactory.reset()