diff --git a/.github/workflows/tidy3d-python-client-tests.yml b/.github/workflows/tidy3d-python-client-tests.yml index d005d021fa..d414ca4758 100644 --- a/.github/workflows/tidy3d-python-client-tests.yml +++ b/.github/workflows/tidy3d-python-client-tests.yml @@ -347,7 +347,7 @@ jobs: BRANCH_NAME="${STEPS_EXTRACT_BRANCH_NAME_OUTPUTS_BRANCH_NAME}" echo $BRANCH_NAME # Allow only Jira keys from known projects, even if the branch has an author prefix - ALLOWED_JIRA_PROJECTS=("FXC" "SCEM") + ALLOWED_JIRA_PROJECTS=("FXC" "SCEM" "SCRF") JIRA_PROJECT_PATTERN=$(IFS='|'; echo "${ALLOWED_JIRA_PROJECTS[*]}") JIRA_PATTERN="(${JIRA_PROJECT_PATTERN})-[0-9]+" diff --git a/docs/api/submit_simulations.rst b/docs/api/submit_simulations.rst index 4c08f6e9d7..795f045586 100644 --- a/docs/api/submit_simulations.rst +++ b/docs/api/submit_simulations.rst @@ -94,7 +94,6 @@ Information Containers :template: module.rst tidy3d.web.core.task_info.TaskInfo - tidy3d.web.core.task_info.TaskStatus Mode Solver Web API diff --git a/tests/test_plugins/smatrix/test_run_functions.py b/tests/test_plugins/smatrix/test_run_functions.py index 64c18c9746..a4207165e5 100644 --- a/tests/test_plugins/smatrix/test_run_functions.py +++ b/tests/test_plugins/smatrix/test_run_functions.py @@ -1,6 +1,5 @@ from __future__ import annotations -import json from unittest.mock import MagicMock import pydantic.v1 as pd @@ -14,38 +13,15 @@ make_component_modeler as make_modal_component_modeler, ) from tidy3d import SimulationDataMap -from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.sim_data import SimulationData from tidy3d.plugins.smatrix.data.terminal import TerminalComponentModelerData from tidy3d.plugins.smatrix.run import ( _run_local, - compose_modeler, compose_modeler_data, create_batch, ) -def test_compose_modeler_unsupported_type(tmp_path, monkeypatch): - # Create a dummy HDF5 file path - modeler_file = tmp_path / "dummy_modeler.hdf5" - - # Prepare a dummy JSON string with an unsupported type - dummy_json = {"type": "UnsupportedComponentModeler", "some_key": "some_value"} - dummy_json_str = json.dumps(dummy_json) - - # Mock Tidy3dBaseModel._json_string_from_hdf5 to return our dummy JSON string - def mock_json_string_from_hdf5(filepath): - if filepath == str(modeler_file): - return dummy_json_str - return "" - - monkeypatch.setattr(Tidy3dBaseModel, "_json_string_from_hdf5", mock_json_string_from_hdf5) - - # Expect a TypeError when calling compose_modeler with the unsupported type - with pytest.raises(TypeError, match="Unsupported modeler type: str"): - compose_modeler(modeler_file=str(modeler_file)) - - def test_create_batch(monkeypatch, tmp_path): # Mock Batch and Batch.to_file mock_batch_instance = MagicMock() diff --git a/tests/test_plugins/test_array_factor.py b/tests/test_plugins/test_array_factor.py index 2ffe92afee..959e1dc7b0 100644 --- a/tests/test_plugins/test_array_factor.py +++ b/tests/test_plugins/test_array_factor.py @@ -363,16 +363,15 @@ def make_antenna_sim(): remove_dc_component=False, # Include DC component for more accuracy at low frequencies ) - sim_unit = list(modeler.sim_dict.values())[0] - - return sim_unit + return modeler def test_rectangular_array_calculator_array_make_antenna_array(): """Test automatic antenna array creation.""" freq0 = 10e9 wavelength0 = td.C_0 / 10e9 - sim_unit = make_antenna_sim() + modeler = make_antenna_sim() + sim_unit = list(modeler.sim_dict.values())[0] array_calculator = mw.RectangularAntennaArrayCalculator( array_size=(1, 2, 3), spacings=(0.5 * wavelength0, 0.6 * wavelength0, 0.4 * wavelength0), @@ -437,8 +436,9 @@ def test_rectangular_array_calculator_array_make_antenna_array(): assert len(sim_array.sources) == 6 # check that override_structures are duplicated - assert len(sim_unit.grid_spec.override_structures) == 2 - assert len(sim_array.grid_spec.override_structures) == 7 + # assert len(sim_unit.grid_spec.override_structures) == 2 + # assert len(sim_array.grid_spec.override_structures) == 7 + assert sim_unit.grid.boundaries == modeler.base_sim.grid.boundaries # check that phase shifts are applied correctly phases_expected = array_calculator._antenna_phases @@ -674,7 +674,8 @@ def test_rectangular_array_calculator_simulation_data_from_array_factor(): phase_shifts=(np.pi / 3, np.pi / 4, np.pi / 5), ) - sim_unit = make_antenna_sim() + modeler = make_antenna_sim() + sim_unit = list(modeler.sim_dict.values())[0] monitor = sim_unit.monitors[0] monitor_directivity = sim_unit.monitors[2] diff --git a/tests/test_web/test_local_cache.py b/tests/test_web/test_local_cache.py index 0762fe08c7..3c8a9c7ee6 100644 --- a/tests/test_web/test_local_cache.py +++ b/tests/test_web/test_local_cache.py @@ -43,7 +43,7 @@ resolve_local_cache, ) from tidy3d.web.cli.app import tidy3d_cli -from tidy3d.web.core.task_core import BatchTask +from tidy3d.web.core.task_core import BatchTask, SimulationTask common.CONNECTION_RETRY_TIME = 0.1 @@ -245,7 +245,9 @@ def _fake_field_map_from_file(*args, **kwargs): monkeypatch.setattr( io_utils, "load_simulation", lambda task_id, *args, **kwargs: TASK_TO_SIM[task_id] ) - monkeypatch.setattr(BatchTask, "is_batch", lambda *args, **kwargs: "success") + monkeypatch.setattr( + SimulationTask, "get", lambda *args, **kwargs: SimpleNamespace(taskType="FDTD") + ) monkeypatch.setattr( BatchTask, "detail", lambda *args, **kwargs: SimpleNamespace(status="success") ) diff --git a/tests/test_web/test_webapi.py b/tests/test_web/test_webapi.py index 074c44b069..4fdd280bf0 100644 --- a/tests/test_web/test_webapi.py +++ b/tests/test_web/test_webapi.py @@ -415,7 +415,7 @@ def test_run_with_invalid_priority(mock_webapi, priority): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_eme.py b/tests/test_web/test_webapi_eme.py index fc43f0670d..c44cd19b41 100644 --- a/tests/test_web/test_webapi_eme.py +++ b/tests/test_web/test_webapi_eme.py @@ -260,7 +260,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_heat.py b/tests/test_web/test_webapi_heat.py index 15767089e8..571aad371f 100644 --- a/tests/test_web/test_webapi_heat.py +++ b/tests/test_web/test_webapi_heat.py @@ -250,7 +250,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_mode.py b/tests/test_web/test_webapi_mode.py index 8ee869de5b..93e4a2a8bf 100644 --- a/tests/test_web/test_webapi_mode.py +++ b/tests/test_web/test_webapi_mode.py @@ -307,7 +307,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tests/test_web/test_webapi_mode_sim.py b/tests/test_web/test_webapi_mode_sim.py index 275d204a9f..052591beab 100644 --- a/tests/test_web/test_webapi_mode_sim.py +++ b/tests/test_web/test_webapi_mode_sim.py @@ -303,7 +303,7 @@ def test_get_info(mock_get_info): @responses.activate -def test_get_run_info(mock_get_run_info): +def test_get_run_info(mock_get_run_info, mock_get_info): assert get_run_info(TASK_ID) == (100, 0) diff --git a/tidy3d/__init__.py b/tidy3d/__init__.py index 4acc881996..afa81b60ff 100644 --- a/tidy3d/__init__.py +++ b/tidy3d/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.boundary import BroadbandModeABCFitterParam, BroadbandModeABCSpec from tidy3d.components.data.index import SimulationDataMap from tidy3d.components.frequency_extrapolation import LowFrequencySmoothingSpec @@ -812,6 +813,7 @@ def set_logging_level(level: str) -> None: "TemperatureData", "TemperatureMonitor", "TetrahedralGridDataset", + "Tidy3dBaseModel", "Transformed", "TriangleMesh", "TriangularGridDataset", diff --git a/tidy3d/components/mode/mode_solver.py b/tidy3d/components/mode/mode_solver.py index 3d41c2ffb8..f32a56a4cf 100644 --- a/tidy3d/components/mode/mode_solver.py +++ b/tidy3d/components/mode/mode_solver.py @@ -2714,7 +2714,7 @@ def _validate_modes_size(self) -> None: "frequencies or modes." ) - def validate_pre_upload(self, source_required: bool = True) -> None: + def validate_pre_upload(self) -> None: """Validate the fully initialized mode solver is ok for upload to our servers.""" self._validate_modes_size() diff --git a/tidy3d/components/mode/simulation.py b/tidy3d/components/mode/simulation.py index 2e836c55ec..e506706822 100644 --- a/tidy3d/components/mode/simulation.py +++ b/tidy3d/components/mode/simulation.py @@ -612,8 +612,8 @@ def plot_pml_mode_plane( """ return self._mode_solver.plot_pml(ax=ax) - def validate_pre_upload(self, source_required: bool = False) -> None: + def validate_pre_upload(self) -> None: super().validate_pre_upload() - self._mode_solver.validate_pre_upload(source_required=source_required) + self._mode_solver.validate_pre_upload() _boundaries_for_zero_dims = validate_boundaries_for_zero_dims(warn_on_change=False) diff --git a/tidy3d/components/tcad/mesher.py b/tidy3d/components/tcad/mesher.py index d35b465c9a..82af0de6f3 100644 --- a/tidy3d/components/tcad/mesher.py +++ b/tidy3d/components/tcad/mesher.py @@ -24,3 +24,9 @@ class VolumeMesher(Tidy3dBaseModel): def _get_simulation_types(self) -> list[TCADAnalysisTypes]: return [TCADAnalysisTypes.MESH] + + def validate_pre_upload(self): + """Validate the VolumeMesher before uploading to the cloud. + Currently no validation but method is required when calling ``web.upload``. + """ + return diff --git a/tidy3d/plugins/smatrix/component_modelers/base.py b/tidy3d/plugins/smatrix/component_modelers/base.py index 1d9715d98f..74c5afe1c0 100644 --- a/tidy3d/plugins/smatrix/component_modelers/base.py +++ b/tidy3d/plugins/smatrix/component_modelers/base.py @@ -343,5 +343,9 @@ def run( ) return data.smatrix() + def validate_pre_upload(self): + """Validate the modeler before upload.""" + self.base_sim.validate_pre_upload(source_required=False) + AbstractComponentModeler.update_forward_refs() diff --git a/tidy3d/plugins/smatrix/component_modelers/modal.py b/tidy3d/plugins/smatrix/component_modelers/modal.py index 591ad4c624..f75c98bc81 100644 --- a/tidy3d/plugins/smatrix/component_modelers/modal.py +++ b/tidy3d/plugins/smatrix/component_modelers/modal.py @@ -59,6 +59,11 @@ class ModalComponentModeler(AbstractComponentModeler): "by ``element_mappings``, the simulation corresponding to this column is skipped automatically.", ) + @property + def base_sim(self): + """The base simulation.""" + return self.simulation + @cached_property def sim_dict(self) -> SimulationMap: """Generates all :class:`.Simulation` objects for the S-matrix calculation. @@ -368,3 +373,9 @@ def get_max_mode_indices(matrix_elements: tuple[str, int]) -> int: max_mode_index_in = get_max_mode_indices(self.matrix_indices_source) return max_mode_index_out, max_mode_index_in + + def task_name_from_index(self, matrix_index: MatrixIndex) -> str: + """Compute task name for a given (port_name, mode_index) without constructing simulations.""" + port_name, mode_index = matrix_index + port = self.get_port_by_name(port_name=port_name) + return self.get_task_name(port=port, mode_index=mode_index) diff --git a/tidy3d/plugins/smatrix/component_modelers/terminal.py b/tidy3d/plugins/smatrix/component_modelers/terminal.py index f865286c4e..cbd7cb7fd9 100644 --- a/tidy3d/plugins/smatrix/component_modelers/terminal.py +++ b/tidy3d/plugins/smatrix/component_modelers/terminal.py @@ -223,7 +223,7 @@ def _warn_refactor_2_10(cls, values): @property def _sim_with_sources(self) -> Simulation: - """Instance of :class:`.Simulation` with all sources and absorbers added for each port, for troubleshooting.""" + """Instance of :class:`.Simulation` with all sources and absorbers added for each port, for plotting.""" sources = [port.to_source(self._source_time) for port in self.ports] absorbers = [ @@ -231,7 +231,9 @@ def _sim_with_sources(self) -> Simulation: for port in self.ports if isinstance(port, WavePort) and port.absorber ] - return self.simulation.updated_copy(sources=sources, internal_absorbers=absorbers) + return self.simulation.updated_copy( + sources=sources, internal_absorbers=absorbers, validate=False + ) @equal_aspect @add_ax_if_none @@ -382,6 +384,10 @@ def matrix_indices_run_sim(self) -> tuple[NetworkIndex, ...]: def sim_dict(self) -> SimulationMap: """Generate all the :class:`.Simulation` objects for the port parameter calculation.""" + # Check base simulation for grid size at ports + TerminalComponentModeler._check_grid_size_at_ports(self.base_sim, self._lumped_ports) + TerminalComponentModeler._check_grid_size_at_wave_ports(self.base_sim, self._wave_ports) + sim_dict = {} # Now, create simulations with wave port sources and mode solver monitors for computing port modes for network_index in self.matrix_indices_run_sim: @@ -389,11 +395,6 @@ def sim_dict(self) -> SimulationMap: # update simulation sim_dict[task_name] = sim_with_src - # Check final simulations for grid size at ports - for _, sim in sim_dict.items(): - TerminalComponentModeler._check_grid_size_at_ports(sim, self._lumped_ports) - TerminalComponentModeler._check_grid_size_at_wave_ports(sim, self._wave_ports) - return SimulationMap(keys=tuple(sim_dict.keys()), values=tuple(sim_dict.values())) @cached_property @@ -414,7 +415,10 @@ def _base_sim_no_radiation_monitors(self) -> Simulation: # Make an initial simulation with new grid_spec to determine where LumpedPorts are snapped sim_wo_source = self.simulation.updated_copy( - grid_spec=grid_spec, lumped_elements=lumped_resistors + grid_spec=grid_spec, + lumped_elements=lumped_resistors, + validate=False, + deep=False, ) snap_centers = {} for port in self._lumped_ports: @@ -480,7 +484,11 @@ def _base_sim_no_radiation_monitors(self) -> Simulation: ) # update base simulation with updated set of shared components - sim_wo_source = sim_wo_source.copy(update=update_dict) + sim_wo_source = sim_wo_source.updated_copy( + **update_dict, + validate=False, + deep=False, + ) # extrude port structures sim_wo_source = self._extrude_port_structures(sim=sim_wo_source) @@ -527,7 +535,10 @@ def base_sim(self) -> Simulation: """The base simulation with all components added, including radiation monitors.""" base_sim_tmp = self._base_sim_no_radiation_monitors mnts_with_radiation = list(base_sim_tmp.monitors) + list(self._finalized_radiation_monitors) - return base_sim_tmp.updated_copy(monitors=mnts_with_radiation) + grid_spec = GridSpec.from_grid(base_sim_tmp.grid) + grid_spec.attrs["from_grid_spec"] = base_sim_tmp.grid_spec + # We skipped validations up to now, here we finally validate the base sim + return base_sim_tmp.updated_copy(monitors=mnts_with_radiation, grid_spec=grid_spec) def _generate_radiation_monitor( self, simulation: Simulation, auto_spec: DirectivityMonitorSpec @@ -712,7 +723,10 @@ def _add_source_to_sim(self, source_index: NetworkIndex) -> tuple[str, Simulatio ) task_name = self.get_task_name(port=port, mode_index=mode_index) - return (task_name, self.base_sim.updated_copy(sources=[port_source])) + return ( + task_name, + self.base_sim.updated_copy(sources=[port_source], validate=False, deep=False), + ) @cached_property def _source_time(self): @@ -863,6 +877,11 @@ def get_radiation_monitor_by_name(self, monitor_name: str) -> DirectivityMonitor return monitor raise Tidy3dKeyError(f"No radiation monitor named '{monitor_name}'.") + def task_name_from_index(self, source_index: NetworkIndex) -> str: + """Compute task name for a given network index without constructing simulations.""" + port, mode_index = self.network_dict[source_index] + return self.get_task_name(port=port, mode_index=mode_index) + def _extrude_port_structures(self, sim: Simulation) -> Simulation: """ Extrude structures intersecting a port plane when a wave port lies on a structure boundary. @@ -983,6 +1002,8 @@ def _extrude_port_structures(self, sim: Simulation) -> Simulation: sim = sim.updated_copy( grid_spec=GridSpec.from_grid(sim.grid), structures=[*sim.structures, *all_new_structures], + validate=False, + deep=False, ) return sim diff --git a/tidy3d/plugins/smatrix/run.py b/tidy3d/plugins/smatrix/run.py index 97f9393338..ff4b9d39f1 100644 --- a/tidy3d/plugins/smatrix/run.py +++ b/tidy3d/plugins/smatrix/run.py @@ -1,10 +1,7 @@ from __future__ import annotations -import json -from os import PathLike from typing import Any -from tidy3d.components.base import Tidy3dBaseModel from tidy3d.components.data.index import SimulationDataMap from tidy3d.log import log from tidy3d.plugins.smatrix.component_modelers.modal import ModalComponentModeler @@ -18,43 +15,6 @@ DEFAULT_DATA_DIR = "." -def compose_modeler( - modeler_file: PathLike, -) -> ComponentModelerType: - """Load a component modeler from an HDF5 file. - - This function reads an HDF5 file, determines the modeler type - (`ModalComponentModeler` or `TerminalComponentModeler`), and constructs the - corresponding modeler object. - - Parameters - ---------- - modeler_file : PathLike - Path to the HDF5 file containing the modeler definition. - - Returns - ------- - ComponentModelerType - The loaded `ModalComponentModeler` or `TerminalComponentModeler` object. - - Raises - ------ - TypeError - If the modeler type specified in the file is not supported. - """ - json_str = Tidy3dBaseModel._json_string_from_hdf5(modeler_file) - model_dict = json.loads(json_str) - modeler_type = model_dict["type"] - - if modeler_type == "ModalComponentModeler": - modeler = ModalComponentModeler.from_file(modeler_file) - elif modeler_type == "TerminalComponentModeler": - modeler = TerminalComponentModeler.from_file(modeler_file) - else: - raise TypeError(f"Unsupported modeler type: {type(modeler_type).__name__}") - return modeler - - def compose_modeler_data( modeler: ModalComponentModeler | TerminalComponentModeler, indexed_sim_data: SimulationDataMap, diff --git a/tidy3d/web/__init__.py b/tidy3d/web/__init__.py index 0cdc8942e5..dcdf44c9c3 100644 --- a/tidy3d/web/__init__.py +++ b/tidy3d/web/__init__.py @@ -30,7 +30,6 @@ load, load_simulation, monitor, - postprocess_start, real_cost, start, test, @@ -58,7 +57,6 @@ "load", "load_simulation", "monitor", - "postprocess_start", "real_cost", "run", "run_async", diff --git a/tidy3d/web/api/container.py b/tidy3d/web/api/container.py index 08092e3062..93e80438c4 100644 --- a/tidy3d/web/api/container.py +++ b/tidy3d/web/api/container.py @@ -389,25 +389,7 @@ def status(self) -> str: """Return current status of :class:`Job`.""" if self.load_if_cached: return "success" - if web._is_modeler_batch(self.task_id): - detail = self.get_info() - status = detail.totalStatus.value - return status - else: - return self.get_info().status - - @property - def postprocess_status(self) -> Optional[str]: - """Return current postprocess status of :class:`Job` if it is a Component Modeler.""" - if web._is_modeler_batch(self.task_id): - detail = self.get_info() - return detail.postprocessStatus - else: - log.warning( - f"Task ID '{self.task_id}' is not a modeler batch job. " - "'postprocess_start' is only applicable to Component Modelers" - ) - return + return self.get_info().status def start(self, priority: Optional[int] = None) -> None: """Start running a :class:`Job`. @@ -548,37 +530,6 @@ def estimate_cost(self, verbose: bool = True) -> float: return 0.0 return web.estimate_cost(self.task_id, verbose=verbose, solver_version=self.solver_version) - def postprocess_start(self, worker_group: Optional[str] = None, verbose: bool = True) -> None: - """ - If the job is a modeler batch, checks if the run is complete and starts - the postprocess phase. - - This function does not wait for postprocessing to finish and is only - applicable to Component Modeler batch jobs. - - Parameters - ---------- - worker_group : Optional[str] = None - The specific worker group to run the postprocessing task on. - verbose : bool = True - Whether to print info messages. This overrides the Job's 'verbose' setting for this call. - """ - # First, confirm that the task is a modeler batch job. - if not web._is_modeler_batch(self.task_id): - # If not, inform the user and exit. - # This warning is important and should not be suppressed. - log.warning( - f"Task ID '{self.task_id}' is not a modeler batch job. " - "'postprocess_start' is only applicable to Component Modelers" - ) - return - - # If it is a modeler batch, call the dedicated function to start postprocessing. - # The verbosity is a combination of the job's setting and the method's parameter. - web.postprocess_start( - batch_id=self.task_id, verbose=(self.verbose and verbose), worker_group=worker_group - ) - @staticmethod def _check_path_dir(path: PathLike) -> None: """Make sure parent directory of ``path`` exists and create it if not. @@ -1043,21 +994,6 @@ def get_run_info(self) -> dict[TaskName, RunInfo]: run_info_dict[task_name] = run_info return run_info_dict - def postprocess_start(self, worker_group: Optional[str] = None, verbose: bool = True) -> None: - """ - Start the postprocess phase for all applicable jobs in the batch. - - This simply forwards to each Job's `postprocess_start(...)`. The Job decides - whether it's a Component Modeler task and whether it can/should start now. - This method does not wait for postprocessing to finish. - """ - if self.verbose and verbose: - console = get_logging_console() - console.log("Attempting to start postprocessing for jobs in the batch.") - - for job in self.jobs.values(): - job.postprocess_start(worker_group=worker_group, verbose=verbose) - def monitor( self, *, @@ -1069,7 +1005,6 @@ def monitor( """ Monitor progress of each running task. - - For Component Modeler jobs, automatically triggers postprocessing once run finishes. - Optionally downloads results as soon as a job reaches final success. - Rich progress bars in verbose mode; quiet polling otherwise. @@ -1095,16 +1030,8 @@ def monitor( self._check_path_dir(path_dir=path_dir) download_executor = ThreadPoolExecutor(max_workers=self.num_workers) - def _should_download(job: Job) -> bool: - status = job.status - if not web._is_modeler_batch(job.task_id): - return status == "success" - if status == "success": - return True - return status == "run_success" and getattr(job, "postprocess_status", None) == "success" - def schedule_download(job: Job) -> None: - if download_executor is None or not _should_download(job): + if download_executor is None or job.status not in COMPLETED_STATES: return task_id = job.task_id if task_id in downloads_started: @@ -1128,12 +1055,7 @@ def schedule_download(job: Job) -> None: def check_continue_condition(job: Job) -> bool: if job.load_if_cached: return False - status = job.status - if not web._is_modeler_batch(job.task_id): - return status not in END_STATES - if status == "run_success": - return job.postprocess_status not in END_STATES - return status not in END_STATES + return job.status not in END_STATES def pbar_description( task_name: str, status: str, max_name_length: int, status_width: int @@ -1157,9 +1079,6 @@ def pbar_description( max_task_name = max(len(task_name) for task_name in self.jobs.keys()) max_name_length = min(30, max(max_task_name, 15)) - # track which modeler jobs we've already kicked into postprocess - postprocess_started_tasks: set[str] = set() - try: console = None progress_columns = [] @@ -1195,17 +1114,6 @@ def pbar_description( for task_name, job in self.jobs.items(): status = job.status - # auto-start postprocess for modeler jobs when run finishes - if ( - web._is_modeler_batch(job.task_id) - and status == "run_success" - and job.task_id not in postprocess_started_tasks - ): - job.postprocess_start( - worker_group=postprocess_worker_group, verbose=True - ) - postprocess_started_tasks.add(job.task_id) - schedule_download(job) if self.verbose: diff --git a/tidy3d/web/api/states.py b/tidy3d/web/api/states.py index 9b02d161f6..a7a03935ae 100644 --- a/tidy3d/web/api/states.py +++ b/tidy3d/web/api/states.py @@ -5,15 +5,17 @@ } ERROR_STATES = { - "validate_fail", + "validate_error", "error", "errored", "diverge", "diverged", "blocked", - "run_failed", + "preprocess_error", + "run_error", "aborted", "deleted", + "postprocess_error", } PRE_VALIDATE_STATES = { @@ -30,20 +32,17 @@ "run_success", } -COMPLETED_STATES = {"visualize", "success", "completed", "processed"} +COMPLETED_STATES = {"visualize", "success", "completed", "processed", "postprocess_success"} END_STATES = ERROR_STATES | COMPLETED_STATES -POST_VALIDATE_STATES = { - "validate_success", - "validate_warn", -} +POST_VALIDATE_STATES = {"validate_success", "validate_warn"} RUNNING_STATES = ( PRE_VALIDATE_STATES | POST_VALIDATE_STATES | {"running"} | POST_RUN_STATES | COMPLETED_STATES ) -ALL_POST_VALIDATE_STATES = POST_VALIDATE_STATES | {"running"} | POST_RUN_STATES | COMPLETED_STATES +ALL_POST_VALIDATE_STATES = POST_VALIDATE_STATES | {"running"} | POST_RUN_STATES | END_STATES VALID_PROGRESS_STATES = RUNNING_STATES | PRE_ERROR_STATES @@ -87,6 +86,7 @@ "visualize": round((11 / MAX_STEPS) * COMPLETED_PERCENT), # 85% "success": COMPLETED_PERCENT, # 100% "completed": COMPLETED_PERCENT, # 100% + "postprocess_success": COMPLETED_PERCENT, # 100% # --- Error States --- # All error states map to 0% "validate_fail": 0, @@ -98,4 +98,8 @@ "run_failed": 0, "aborted": 0, "deleted": 0, + "validate_error": 0, + "preprocess_error": 0, + "run_error": 0, + "postprocess_error": 0, } diff --git a/tidy3d/web/api/tidy3d_stub.py b/tidy3d/web/api/tidy3d_stub.py index a55bafe6f9..eb49f90461 100644 --- a/tidy3d/web/api/tidy3d_stub.py +++ b/tidy3d/web/api/tidy3d_stub.py @@ -46,8 +46,8 @@ EMESimulation: TaskType.EME, ModeSimulation: TaskType.MODE, VolumeMesher: TaskType.VOLUME_MESH, - ModalComponentModeler: TaskType.COMPONENT_MODELER, - TerminalComponentModeler: TaskType.TERMINAL_COMPONENT_MODELER, + ModalComponentModeler: TaskType.MODAL_CM, + TerminalComponentModeler: TaskType.TERMINAL_CM, } @@ -116,7 +116,7 @@ def validate_pre_upload(self, source_required: bool) -> None: """Perform some pre-checks on instances of component""" if isinstance(self.simulation, Simulation): self.simulation.validate_pre_upload(source_required) - elif isinstance(self.simulation, EMESimulation): + else: self.simulation.validate_pre_upload() def get_default_task_name(self) -> str: diff --git a/tidy3d/web/api/webapi.py b/tidy3d/web/api/webapi.py index a57f0ebbfe..9134fff385 100644 --- a/tidy3d/web/api/webapi.py +++ b/tidy3d/web/api/webapi.py @@ -19,12 +19,10 @@ from tidy3d.config import config from tidy3d.exceptions import WebError from tidy3d.log import get_logging_console, log -from tidy3d.plugins.smatrix.component_modelers.terminal import TerminalComponentModeler from tidy3d.web.api.states import ( ALL_POST_VALIDATE_STATES, END_STATES, ERROR_STATES, - POST_VALIDATE_STATES, STATE_PROGRESS_PERCENTAGE, ) from tidy3d.web.cache import CacheEntry, _store_mode_solver_in_cache, resolve_local_cache @@ -39,11 +37,15 @@ SIMULATION_DATA_HDF5_GZ, TaskId, ) -from tidy3d.web.core.exceptions import WebNotFoundError -from tidy3d.web.core.http_util import get_version as _get_protocol_version -from tidy3d.web.core.http_util import http -from tidy3d.web.core.task_core import BatchDetail, BatchTask, Folder, SimulationTask -from tidy3d.web.core.task_info import AsyncJobDetail, ChargeType, TaskInfo +from tidy3d.web.core.task_core import ( + BatchDetail, + BatchTask, + Folder, + SimulationTask, + TaskFactory, + WebTask, +) +from tidy3d.web.core.task_info import ChargeType, TaskInfo from tidy3d.web.core.types import PayType, TaskType from .connect_util import REFRESH_TIME, get_grid_points_str, get_time_steps_str, wait_for_connection @@ -56,7 +58,7 @@ SIM_FILE_JSON = "simulation.json" # not all solvers are supported yet in GUI -GUI_SUPPORTED_TASK_TYPES = ["FDTD", "MODE_SOLVER", "HEAT", "RF"] +GUI_SUPPORTED_TASK_TYPES = ["FDTD", "MODE_SOLVER", "HEAT", "TERMINAL_CM"] # if a solver is in beta stage, cost is subject to change BETA_TASK_TYPES = ["HEAT", "EME", "HEAT_CHARGE", "VOLUME_MESH"] @@ -95,15 +97,6 @@ def _build_website_url(path: str) -> str: return "/".join([base.rstrip("/"), str(path).lstrip("/")]) -def _is_modeler_batch(resource_id: str) -> bool: - """Detect whether the given id corresponds to a modeler batch resource.""" - return BatchTask.is_batch(resource_id, batch_type="RF_SWEEP") - - -def _batch_detail(resource_id: str) -> BatchDetail: - return BatchTask(resource_id).detail(batch_type="RF_SWEEP") - - def _batch_detail_error(resource_id: str) -> Optional[WebError]: """Processes a failed batch job to generate a detailed error. @@ -115,45 +108,40 @@ def _batch_detail_error(resource_id: str) -> Optional[WebError]: Args: resource_id (str): The identifier of the batch resource that failed. - Returns: - An instance of `WebError` if the batch failed, otherwise `None`. + Raises: + An instance of ``WebError`` if the batch failed. """ + + # TODO: test properly try: - batch_detail = BatchTask(batch_id=resource_id).detail(batch_type="RF_SWEEP") - status = batch_detail.totalStatus.value + batch = BatchTask.get(resource_id) + batch_detail = batch.detail() + status = batch_detail.status.lower() except Exception as e: log.error(f"Could not retrieve batch details for '{resource_id}': {e}") - return WebError(f"Failed to retrieve status for batch '{resource_id}'.") + raise WebError(f"Failed to retrieve status for batch '{resource_id}'.") from e if status not in ERROR_STATES: - return None - - log.error(f"The ComponentModeler batch '{resource_id}' has failed with status: {status}") + return - if ( - status == "validate_fail" - and hasattr(batch_detail, "validateErrors") - and batch_detail.validateErrors - ): - error_details = [] - for key, error_str in batch_detail.validateErrors.items(): - try: - error_dict = json.loads(error_str) - validation_error = error_dict.get("validation_error", "Unknown validation error.") - msg = f"- Subtask '{key}' failed: {validation_error}" - log.error(msg) + if hasattr(batch_detail, "validateErrors") and batch_detail.validateErrors: + try: + error_details = [] + for key, error_str in batch_detail.validateErrors.items(): + msg = f"- Subtask '{key}' failed: {error_str}" error_details.append(msg) - except (json.JSONDecodeError, TypeError): - # Handle cases where the error string isn't valid JSON - log.error(f"Could not parse validation error for subtask '{key}'.") - error_details.append(f"- Subtask '{key}': Could not parse error details.") - - details_string = "\n".join(error_details) - full_error_msg = ( - "One or more subtasks failed validation. Please fix the component modeler configuration.\n" - f"Details:\n{details_string}" - ) - return WebError(full_error_msg) + + details_string = "\n".join(error_details) + full_error_msg = ( + "One or more subtasks failed validation. Please fix the component modeler " + "configuration.\n" + f"Details:\n{details_string}" + ) + except Exception as e: + raise WebError( + "One or more subtasks failed validation. Failed to parse validation errors." + ) from e + raise WebError(full_error_msg) # Handle all other generic error states else: @@ -161,168 +149,7 @@ def _batch_detail_error(resource_id: str) -> Optional[WebError]: f"Batch '{resource_id}' failed with status '{status}'. Check server " "logs for details or contact customer support." ) - return WebError(error_msg) - - -def _upload_component_modeler_subtasks( - resource_id: str, verbose: bool = True, solver_version: Optional[str] = None -) -> Optional[WebError]: - """Kicks off and monitors the split and validation of component modeler tasks. - - This function orchestrates a two-phase process. First, it initiates a - server-side asynchronous job to split the components of a modeler batch. - It monitors this job's progress by polling the API and parsing the - response into an `AsyncJobDetail` model until the job completes or fails. - - If the split is successful, the function proceeds to the second phase: - triggering a batch validation via `batch.check()`. It then monitors this - validation process by polling for `BatchDetail` updates. The progress bar, - if verbose, reflects the status according to a predefined state mapping. - - Finally, it processes the terminal state of the validation. If a - 'validate_fail' status occurs, it parses detailed error messages for each - failed subtask and includes them in the raised exception. - - Args: - resource_id (str): The identifier for the batch resource to be processed. - verbose (bool): If True, displays progress bars and logs detailed - status messages to the console during the operation. - solver_version (str): Solver version in which to run validation. - - Raises: - RuntimeError: If the initial asynchronous split job fails. - WebError: If the subsequent batch validation fails, ends in an - unexpected state, or if a 'validate_fail' status is encountered. - """ - console = get_logging_console() if verbose else None - final_error = None - batch_type = "RF_SWEEP" - - split_path = "tidy3d/async-biz/component-modeler-split" - payload = { - "batchType": batch_type, - "batchId": resource_id, - "fileName": "modeler.hdf5.gz", - "protocolVersion": _get_protocol_version(), - } - - if verbose: - console.log("Starting modeler and subtasks validation...") - - initial_resp = http.post(split_path, payload) - split_job_detail = AsyncJobDetail(**initial_resp) - monitor_split_path = f"{split_path}?asyncId={split_job_detail.asyncId}" - - if verbose: - progress_bar = Progress( - TextColumn("[progress.description]{task.description}"), - BarColumn(), - TaskProgressColumn(), - TimeElapsedColumn(), - console=console, - ) - - with progress_bar as progress: - description = "Upload Subtasks" - pbar = progress.add_task(description, completed=split_job_detail.progress, total=100) - while True: - split_job_raw_result = http.get(monitor_split_path) - split_job_detail = AsyncJobDetail(**split_job_raw_result) - - progress.update( - pbar, completed=split_job_detail.progress, description=f"[blue]{description}" - ) - - if split_job_detail.status in END_STATES: - progress.update( - pbar, - completed=split_job_detail.progress, - description=f"[green]{description}", - ) - break - time.sleep(RUN_REFRESH_TIME) - - if split_job_detail.status in ERROR_STATES: - msg = split_job_detail.result or "An unknown error occurred." - final_error = WebError( - f"Component modeler split job failed ({split_job_detail.status}): {msg}" - ) - - if not final_error: - description = "Validating" - pbar = progress.add_task( - completed=10, total=100, description=f"[blue]{description}" - ) - batch = BatchTask(resource_id) - batch.check(solver_version=solver_version, batch_type=batch_type) - - while True: - batch_detail = batch.detail(batch_type=batch_type) - status = batch_detail.totalStatus - progress_percent = STATE_PROGRESS_PERCENTAGE.get(status, 0) - progress.update( - pbar, completed=progress_percent, description=f"[blue]{description}" - ) - - if status in POST_VALIDATE_STATES: - progress.update(pbar, completed=100, description=f"[green]{description}") - task_mapping = json.loads(split_job_detail.result) - console.log( - f"Uploaded Subtasks: \n{_task_dict_to_url_bullet_list(task_mapping)}" - ) - progress.refresh() - break - elif status in ERROR_STATES: - progress.update(pbar, completed=0, description=f"[red]{description}") - progress.refresh() - break - time.sleep(RUN_REFRESH_TIME) - - else: - # Non-verbose mode: Poll for split job completion. - while True: - split_job_raw_result = http.get(monitor_split_path) - split_job_detail = AsyncJobDetail(**split_job_raw_result) - if split_job_detail.status in END_STATES: - break - time.sleep(RUN_REFRESH_TIME) - - # Check for split job failure. - if split_job_detail.status in ERROR_STATES: - msg = split_job_detail.result or "An unknown error occurred." - final_error = WebError( - f"Component modeler split job failed ({split_job_detail.status}): {msg}" - ) - - # If split succeeded, poll for validation completion. - if not final_error: - batch = BatchTask(resource_id) - batch.check(solver_version=solver_version, batch_type=batch_type) - while True: - batch_detail = batch.detail(batch_type=batch_type) - status = batch_detail.totalStatus - if status in POST_VALIDATE_STATES or status in END_STATES: - break - time.sleep(RUN_REFRESH_TIME) - - return _batch_detail_error(resource_id=resource_id) - - -def _task_dict_to_url_bullet_list(data_dict: dict) -> str: - """ - Converts a dictionary into a string formatted as a bullet point list. - - Args: - data_dict: The dictionary to convert. - - Returns: - A string with each key-url/value pair as a bullet point. - """ - # Use a list comprehension to format each key-value pair - # and then join them together with newline characters. - if data_dict is None: - raise WebError("Error in subtask dictionary data.") - return "\n".join([f"- {key}: '{value}'" for key, value in data_dict.items()]) + raise WebError(error_msg) def _copy_simulation_data_from_cache_entry(entry: CacheEntry, path: PathLike) -> bool: @@ -398,7 +225,7 @@ def restore_simulation_if_cached( cached_workflow_type = entry.metadata.get("workflow_type") if cached_task_id is not None and cached_workflow_type is not None and verbose: console = get_logging_console() - url, _ = _get_task_urls(cached_workflow_type, simulation, cached_task_id) + url, _ = _get_task_urls(cached_workflow_type, cached_task_id) console.log( f"Loading simulation from local cache. View cached task using web UI at [link={url}]'{url}'[/link]." ) @@ -572,7 +399,6 @@ def run( ) start( task_id, - verbose=verbose, solver_version=solver_version, worker_group=worker_group, pay_type=pay_type, @@ -600,15 +426,12 @@ def run( def _get_task_urls( task_type: str, - simulation: WorkflowType, resource_id: str, folder_id: Optional[str] = None, group_id: Optional[str] = None, ) -> tuple[str, Optional[str]]: """Log task and folder links to the web UI.""" - if (task_type in ["RF", "COMPONENT_MODELER", "TERMINAL_COMPONENT_MODELER"]) and isinstance( - simulation, TerminalComponentModeler - ): + if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: url = _get_url_rf(group_id or resource_id) else: url = _get_url(resource_id) @@ -694,13 +517,8 @@ def upload( task_name = stub.get_default_task_name() task_type = stub.get_type() - # Component modeler compatibility: map to RF task type - port_name_list = None - if task_type in ("COMPONENT_MODELER", "TERMINAL_COMPONENT_MODELER"): - task_type = "RF" - port_name_list = tuple(simulation.sim_dict.keys()) - task = SimulationTask.create( + task = WebTask.create( task_type, task_name, folder_name, @@ -708,26 +526,10 @@ def upload( simulation_type, parent_tasks, "Gz", - port_name_list=port_name_list, ) - if task_type == "RF": - # Prefer the group id if present in the creation response; avoid extra GET. - group_id = getattr(task, "groupId", None) or getattr(task, "group_id", None) - if not group_id: - try: - detail_task = SimulationTask.get(task.task_id, verbose=False) - group_id = getattr(detail_task, "groupId", None) or getattr( - detail_task, "group_id", None - ) - except Exception: - group_id = None - # Prefer returning batch/group id for downstream batch endpoints - batch_id = getattr(task, "batchId", None) or getattr(task, "batch_id", None) - resource_id = batch_id or task.task_id - else: - group_id = None - resource_id = task.task_id + group_id = getattr(task, "groupId", None) + resource_id = task.task_id if verbose: console.log( @@ -740,16 +542,14 @@ def upload( f"Cost of {solver_name} simulations is subject to change in the future." ) if task_type in GUI_SUPPORTED_TASK_TYPES: - url, folder_url = _get_task_urls( - task_type, simulation, resource_id, task.folder_id, group_id - ) + url, folder_url = _get_task_urls(task_type, resource_id, task.folder_id, group_id) console.log(f"View task using web UI at [link={url}]'{url}'[/link].") console.log(f"Task folder: [link={folder_url}]'{task.folder_name}'[/link].") remote_sim_file = SIM_FILE_HDF5_GZ if task_type == "MODE_SOLVER": remote_sim_file = MODE_FILE_HDF5_GZ - elif task_type == "RF": + elif task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: remote_sim_file = MODELER_FILE_HDF5_GZ task.upload_simulation( @@ -759,15 +559,10 @@ def upload( remote_sim_file=remote_sim_file, ) - if task_type == "RF": - _upload_component_modeler_subtasks(resource_id=resource_id, verbose=verbose) - estimate_cost(task_id=resource_id, solver_version=solver_version, verbose=verbose) task.validate_post_upload(parent_tasks=parent_tasks) - # log the url for the task in the web UI - log.debug(_build_website_url(f"folders/{task.folder_id}/tasks/{resource_id}")) return resource_id @@ -847,20 +642,15 @@ def get_info(task_id: TaskId, verbose: bool = True) -> TaskInfo | BatchDetail: ValueError If no task is found for the given ``task_id``. """ - if _is_modeler_batch(task_id): - batch = BatchTask(task_id) - return batch.detail(batch_type="RF_SWEEP") - else: - task = SimulationTask.get(task_id, verbose) - if not task: - raise ValueError("Task not found.") - return TaskInfo(**{"taskId": task.task_id, "taskType": task.task_type, **task.dict()}) + task = TaskFactory.get(task_id, verbose=verbose) + if not task: + raise ValueError("Task not found.") + return task.detail() @wait_for_connection def start( task_id: TaskId, - verbose: bool = True, solver_version: Optional[str] = None, worker_group: Optional[str] = None, pay_type: Union[PayType, str] = PayType.AUTO, @@ -889,30 +679,10 @@ def start( To monitor progress, can call :meth:`monitor` after starting simulation. """ - console = get_logging_console() if verbose else None - - # Component modeler batch path: hide split/check/submit - if _is_modeler_batch(task_id): - # split (modeler-specific) - batch = BatchTask(task_id) - detail = batch.wait_for_validate(batch_type="RF_SWEEP") - status = detail.totalStatus - status_str = status.value - if status_str in POST_VALIDATE_STATES: - pass - elif status_str not in POST_VALIDATE_STATES: - raise WebError(f"Batch task {task_id} is blocked: {status_str}") - # Submit batch to start runs after validation - batch.submit( - solver_version=solver_version, batch_type="RF_SWEEP", worker_group=worker_group - ) - if verbose: - console.log(f"Component Modeler '{task_id}' validated. Solving...") - return - if priority is not None and (priority < 1 or priority > 10): raise ValueError("Priority must be between '1' and '10' if specified.") - task = SimulationTask.get(task_id) + + task = TaskFactory.get(task_id) if not task: raise ValueError("Task not found.") task.submit( @@ -941,10 +711,21 @@ def get_run_info(task_id: TaskId) -> tuple[Optional[float], Optional[float]]: Average field intensity normalized to max value (1.0). Is ``None`` if run info not available. """ - task = SimulationTask(taskId=task_id) + task = TaskFactory.get(task_id) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") return task.get_running_info() +def _get_batch_detail_handle_error_status(batch: BatchTask) -> BatchDetail: + """Get batch detail and raise error if status is in ERROR_STATES.""" + detail = batch.detail() + status = detail.status.lower() + if status in ERROR_STATES: + _batch_detail_error(batch.task_id) + return detail + + def get_status(task_id: TaskId) -> str: """Get the status of a task. Raises an error if status is "error". @@ -953,22 +734,9 @@ def get_status(task_id: TaskId) -> str: task_id : str Unique identifier of task on server. Returned by :meth:`upload`. """ - if _is_modeler_batch(task_id): - # split (modeler-specific) - batch = BatchTask(task_id) - detail = batch.detail(batch_type="RF_SWEEP") - status = detail.totalStatus - if status == "visualize": - return "success" - if status in ERROR_STATES: - try: - # TODO Try to obtain the error message - pass - except Exception: - # If the error message could not be obtained, raise a generic error message - error_msg = "Error message could not be obtained, please contact customer support." - - raise WebError(f"Error running task {task_id}! {error_msg}") + task = TaskFactory.get(task_id) + if isinstance(task, BatchTask): + return _get_batch_detail_handle_error_status(task).status else: task_info = get_info(task_id) status = task_info.status @@ -1017,9 +785,9 @@ def monitor(task_id: TaskId, verbose: bool = True, worker_group: Optional[str] = """ # Batch/modeler monitoring path - if _is_modeler_batch(task_id): - _monitor_modeler_batch(task_id, verbose=verbose, worker_group=worker_group) - return + task = TaskFactory.get(task_id) + if isinstance(task, BatchTask): + return _monitor_modeler_batch(task_id, verbose=verbose) console = get_logging_console() if verbose else None @@ -1094,7 +862,7 @@ def monitor_preprocess() -> None: if verbose: # verbose case, update progressbar console.log("running solver") - if task_type == "FDTD": + if "FDTD" in task_type: with Progress(console=console) as progress: pbar_pd = progress.add_task("% done", total=100) perc_done, _ = get_run_info(task_id) @@ -1178,29 +946,18 @@ def abort(task_id: TaskId) -> Optional[TaskInfo]: Object containing information about status, size, credits of task. """ console = get_logging_console() - try: - task = SimulationTask.get(task_id, verbose=False) - if task: - task.abort() - url = _get_url(task.task_id) - console.log( - f"Task is aborting. View task using web UI at [link={url}]'{url}'[/link] to check the result." - ) - return TaskInfo(**{"taskId": task.task_id, **task.dict()}) - except WebNotFoundError: - pass # Task not found, might be a batch task - - is_batch = BatchTask.is_batch(task_id, batch_type="RF_SWEEP") - if is_batch: - url = _get_url_rf(task_id) - console.log( - f"Batch task abortion is not yet supported, contact customer support." - f" View task using web UI at [link={url}]'{url}'[/link]." - ) - return - console.log("Task ID cannot be found to be aborted.") - return + task = TaskFactory.get(task_id, verbose=False) + if not task: + return None + url = task.get_url() + task.abort() + console.log( + f"Task is aborting. View task using web UI at [link={url}]'{url}'[/link] to check the result." + ) + return TaskInfo( + **{"taskId": task_id, "taskType": getattr(task, "task_type", None), **task.dict()} + ) @wait_for_connection @@ -1225,59 +982,26 @@ def download( """ path = Path(path) - - if _is_modeler_batch(task_id): - # Use a more descriptive default filename for component modeler downloads. - # If the caller left the default as 'simulation_data.hdf5', prefer 'cm_data.hdf5'. + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): if path.name == "simulation_data.hdf5": path = path.with_name("cm_data.hdf5") - - def _download_cm() -> bool: - try: - BatchTask(task_id).get_data_hdf5( - remote_data_file_gz=CM_DATA_HDF5_GZ, - to_file=path, - verbose=verbose, - progress_callback=progress_callback, - ) - return True - except Exception: - return False - - if not _download_cm(): - BatchTask(task_id).postprocess(batch_type="RF_SWEEP") - # wait for postprocess to finish - while True: - resp = BatchTask(task_id).detail(batch_type="RF_SWEEP") - total = resp.totalTask or 0 - post_succ = resp.postprocessSuccess or 0 - status = resp.totalStatus - status_str = status.value - if status_str in ERROR_STATES: - raise WebError( - f"Batch task {task_id} failed during postprocess: {status_str}" - ) from None - if total > 0 and post_succ >= total: - break - time.sleep(REFRESH_TIME) - if not _download_cm(): - raise WebError("Failed to download 'cm_data' after postprocess completion.") + task.get_data_hdf5( + to_file=path, + remote_data_file_gz=CM_DATA_HDF5_GZ, + verbose=verbose, + progress_callback=progress_callback, + ) return - - # Regular single-task download - task_info = get_info(task_id) - task_type = task_info.taskType - + info = get_info(task_id, verbose=False) remote_data_file = SIMULATION_DATA_HDF5_GZ - if task_type == "MODE_SOLVER": + if info.taskType == "MODE_SOLVER": remote_data_file = MODE_DATA_HDF5_GZ - - task = SimulationTask(taskId=task_id) - task.get_sim_data_hdf5( - path, + task.get_data_hdf5( + to_file=path, + remote_data_file_gz=remote_data_file, verbose=verbose, progress_callback=progress_callback, - remote_data_file=remote_data_file, ) @@ -1295,7 +1019,9 @@ def download_json(task_id: TaskId, path: PathLike = SIM_FILE_JSON, verbose: bool If ``True``, will print progressbars and status, otherwise, will run silently. """ - task = SimulationTask(taskId=task_id) + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") task.get_simulation_json(path, verbose=verbose) @@ -1326,7 +1052,9 @@ def load_simulation( Union[:class:`.Simulation`, :class:`.HeatSimulation`, :class:`.EMESimulation`] Simulation loaded from downloaded json file. """ - task = SimulationTask.get(task_id) + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") path = Path(path) if path.suffix == ".json": task.get_simulation_json(path, verbose=verbose) @@ -1361,7 +1089,9 @@ def download_log( ---- To load downloaded results into data, call :meth:`load` with option ``replace_existing=False``. """ - task = SimulationTask(taskId=task_id) + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") task.get_log(path, verbose=verbose, progress_callback=progress_callback) @@ -1412,10 +1142,11 @@ def load( Object containing simulation data. """ path = Path(path) + task = TaskFactory.get(task_id) if task_id else None # For component modeler batches, default to a clearer filename if the default was used. if ( task_id - and _is_modeler_batch(task_id) + and isinstance(task, BatchTask) and path.name in {"simulation_data.hdf5", "simulation_data.hdf5.gz"} ): path = path.with_name(path.name.replace("simulation", "cm")) @@ -1428,8 +1159,8 @@ def load( if verbose and task_id is not None: console = get_logging_console() - if _is_modeler_batch(task_id): - console.log(f"loading component modeler data from {path}") + if isinstance(task, BatchTask): + console.log(f"Loading component modeler data from {path}") else: console.log(f"Loading simulation from {path}") @@ -1461,74 +1192,49 @@ def load( return stub_data +def _status_to_stage(status: str) -> tuple[str, int]: + """Map task status to monotonic stage for progress bars.""" + s = (status or "").lower() + # Map a broader set of states to monotonic stages for progress bars + if s in ("draft", "created"): + return ("draft", 0) + if s in ("queue", "queued"): + return ("queued", 1) + if s in ("validating",): + return ("validating", 2) + if s in ("validate_success", "validate_warn", "preprocess", "preprocessing"): + return ("preprocess", 3) + if s in ("running", "preprocess_success"): + return ("running", 4) + if s in ("run_success", "postprocess"): + return ("postprocess", 5) + if s in ("success", "postprocess_success"): + return ("success", 6) + # Unknown states map to earliest stage to avoid showing 100% prematurely + return (s or "unknown", 0) + + def _monitor_modeler_batch( - batch_id: str, + task_id: str, verbose: bool = True, max_detail_tasks: int = 20, - worker_group: Optional[str] = None, ) -> None: """Monitor modeler batch progress with aggregate and per-task views.""" console = get_logging_console() if verbose else None - - def _status_to_stage(status: str) -> tuple[str, int]: - s = (status or "").lower() - # Map a broader set of states to monotonic stages for progress bars - if s in ("draft", "created"): - return ("draft", 0) - if s in ("queue", "queued"): - return ("queued", 1) - if s in ("preprocess",): - return ("preprocess", 1) - if s in ("validating",): - return ("validating", 2) - if s in ("validate_success", "validate_warn"): - return ("validate", 3) - if s in ("running",): - return ("running", 4) - if s in ("postprocess",): - return ("postprocess", 5) - if s in ("run_success", "success"): - return ("success", 6) - # Unknown states map to earliest stage to avoid showing 100% prematurely - return (s or "unknown", 0) - - detail = _batch_detail(batch_id) + task = BatchTask.get(task_id=task_id) + detail = _get_batch_detail_handle_error_status(task) name = detail.name or "modeler_batch" group_id = detail.groupId - - header = f"Subtasks status - {name}" - if group_id: - header += f"\nGroup ID: '{group_id}'" - if console is not None: - console.log(header) + status = detail.status.lower() # Non-verbose path: poll without progress bars then return if not verbose: # Run phase - while True: - d = _batch_detail(batch_id) - s = d.totalStatus.value - total = d.totalTask or 0 - r = d.runSuccess or 0 - if s in ERROR_STATES: - raise WebError(f"Batch {batch_id} terminated: {s}") - # Updated break condition for robustness - if s in ("run_success", "success") or (total and r >= total): - break + while _status_to_stage(status)[0] not in END_STATES: time.sleep(REFRESH_TIME) + detail = _get_batch_detail_handle_error_status(task) + status = detail.status.lower() - postprocess_start(batch_id, verbose=False, worker_group=worker_group) - - while True: - d = _batch_detail(batch_id) - postprocess_status = d.postprocessStatus - if postprocess_status == "success": - break - elif postprocess_status in ERROR_STATES: - raise WebError( - f"Batch {batch_id} terminated. Please contact customer support and provide this Component Modeler batch ID: '{batch_id}'" - ) - time.sleep(REFRESH_TIME) return progress_columns = ( @@ -1537,16 +1243,26 @@ def _status_to_stage(status: str) -> tuple[str, int]: TaskProgressColumn(), TimeElapsedColumn(), ) + # Make the header + header = f"Subtasks status - {name}" + if group_id: + header += f"\nGroup ID: '{group_id}'" + console.log(header) with Progress(*progress_columns, console=console, transient=False) as progress: # Phase: Run (aggregate + per-task) p_run = progress.add_task("Run Total", total=1.0) task_bars: dict[str, int] = {} + stage = _status_to_stage(status)[0] + prev_stage = _status_to_stage(status)[0] + console.log(f"Batch status = {status}") - while True: - detail = _batch_detail(batch_id) - status = detail.totalStatus.value - total = detail.totalTask or 0 + # Note: get_status errors if an erroring status occurred + while stage not in END_STATES: + total = len(detail.tasks) r = detail.runSuccess or 0 + if stage != prev_stage: + prev_stage = stage + console.log(f"Batch status = {stage}") # Create per-task bars as soon as tasks appear if total and total <= max_detail_tasks and detail.tasks: @@ -1593,40 +1309,15 @@ def _status_to_stage(status: str) -> tuple[str, int]: refresh=False, ) - # Updated break condition for robustness - if status in ("run_success", "success") or (total and r >= total): - break - if status in ERROR_STATES: - raise WebError(f"Batch {batch_id} terminated: {status}") - progress.refresh() - time.sleep(REFRESH_TIME) - - postprocess_start(batch_id, verbose=True, worker_group=worker_group) - - p_post = progress.add_task("Postprocess", total=1.0) - while True: - detail = _batch_detail(batch_id) - postprocess_status = detail.postprocessStatus - if postprocess_status == "success": - progress.update(p_post, completed=1.0) - progress.refresh() - break - elif postprocess_status == "queued": - progress.update(p_post, completed=0.22) - elif postprocess_status == "preprocess": - progress.update(p_post, completed=0.33) - elif postprocess_status == "running": - progress.update(p_post, completed=0.55) - elif postprocess_status in ERROR_STATES: - raise WebError( - f"Batch {batch_id} terminated. Please contact customer support and provide this Component Modeler batch ID: '{batch_id}'" - ) progress.refresh() time.sleep(REFRESH_TIME) + detail = _get_batch_detail_handle_error_status(task) + status = detail.status.lower() + stage = _status_to_stage(status)[0] if console is not None: console.log("Modeler has finished running successfully.") - real_cost(batch_id, verbose=verbose) + real_cost(task.task_id, verbose=verbose) @wait_for_connection @@ -1648,7 +1339,7 @@ def delete(task_id: TaskId, versions: bool = False) -> TaskInfo: """ if not task_id: raise ValueError("Task id not found.") - task = SimulationTask.get(task_id, verbose=False) + task = TaskFactory.get(task_id, verbose=False) task.delete(versions) return TaskInfo(**{"taskId": task.task_id, **task.dict()}) @@ -1674,14 +1365,13 @@ def download_simulation( Optional callback function called when downloading file with ``bytes_in_chunk`` as argument. """ - task_info = get_info(task_id) - task_type = task_info.taskType - + task = TaskFactory.get(task_id, verbose=False) + if isinstance(task, BatchTask): + raise NotImplementedError("Operation not implemented for modeler batches.") + info = get_info(task_id, verbose=False) remote_sim_file = SIM_FILE_HDF5_GZ - if task_type == "MODE_SOLVER": + if info.taskType == "MODE_SOLVER": remote_sim_file = MODE_FILE_HDF5_GZ - - task = SimulationTask(taskId=task_id) task.get_simulation_hdf5( path, verbose=verbose, @@ -1779,63 +1469,64 @@ def estimate_cost( console = get_logging_console() if verbose else None - if _is_modeler_batch(task_id): - d = _batch_detail(task_id) - status = d.totalStatus.value - - if status in ALL_POST_VALIDATE_STATES: - est_flex_unit = _batch_detail(task_id).estFlexUnit - if verbose: - console.log( - f"Maximum FlexCredit cost: {est_flex_unit:1.3f}. Minimum cost depends on " - "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " - "cost after a simulation run." - ) - return est_flex_unit - - elif status in ERROR_STATES: - return _batch_detail_error(resource_id=task_id) - - raise WebError("Could not get estimated cost!") - - else: - task = SimulationTask.get(task_id) - - if not task: - raise ValueError("Task not found.") - - task.estimate_cost(solver_version=solver_version) - task_info = get_info(task_id) - status = task_info.metadataStatus - - # Wait for a termination status + task = TaskFactory.get(task_id, verbose=False) + detail = task.detail() + if isinstance(task, BatchTask): + check_task_type = "FDTD" if detail.taskType == "MODAL_CM" else "RF_FDTD" + task.check(solver_version=solver_version, check_task_type=check_task_type) + detail = task.detail() + status = detail.status.lower() while status not in ALL_POST_VALIDATE_STATES: time.sleep(REFRESH_TIME) - task_info = get_info(task_id) - status = task_info.metadataStatus + detail = task.detail() + status = detail.status.lower() + if status in ERROR_STATES: + _batch_detail_error(resource_id=task_id) + est_flex_unit = detail.estFlexUnit + if verbose: + console.log( + f"Maximum FlexCredit cost: {est_flex_unit:1.3f}. Minimum cost depends on " + "task execution details. Use 'web.real_cost(task_id)' after run." + ) + return est_flex_unit - if status in ALL_POST_VALIDATE_STATES: - if verbose: - console.log( - f"Estimated FlexCredit cost: {task_info.estFlexUnit:1.3f}. Minimum cost depends on " - "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " - "cost after a simulation run." - ) - fc_mode = task_info.estFlexCreditMode - fc_post = task_info.estFlexCreditPostProcess - if fc_mode: - console.log(f" {fc_mode:1.3f} FlexCredit of the total cost from mode solves.") - if fc_post: - console.log( - f" {fc_post:1.3f} FlexCredit of the total cost from post-processing." - ) - return task_info.estFlexUnit + # simulation path + task.estimate_cost(solver_version=solver_version) + task_info = get_info(task_id) + status = task_info.metadataStatus - elif status in ERROR_STATES: - log.error(f"The task '{task_id}' has failed: {status}") + # Wait for a termination status + while status not in ALL_POST_VALIDATE_STATES: + time.sleep(REFRESH_TIME) + task_info = get_info(task_id) + status = task_info.metadataStatus - # Something went wrong - raise WebError("Could not get estimated cost!") + if status in ERROR_STATES: + try: + # Try to obtain the error message + task = SimulationTask(taskId=task_id) + with tempfile.NamedTemporaryFile(suffix=".json") as tmp_file: + task.get_error_json(to_file=tmp_file.name, validation=True) + with open(tmp_file.name) as f: + error_content = json.load(f) + error_msg = error_content["validation_error"] + except Exception: + # If the error message could not be obtained, raise a generic error message + error_msg = "Error message could not be obtained, please contact customer support." + raise WebError(f"Error estimating cost for task {task_id}! {error_msg}") + if verbose: + console.log( + f"Estimated FlexCredit cost: {task_info.estFlexUnit:1.3f}. Minimum cost depends on " + "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " + "cost after a simulation run." + ) + fc_mode = task_info.estFlexCreditMode + fc_post = task_info.estFlexCreditPostProcess + if fc_mode: + console.log(f" {fc_mode:1.3f} FlexCredit of the total cost from mode solves.") + if fc_post: + console.log(f" {fc_post:1.3f} FlexCredit of the total cost from post-processing.") + return task_info.estFlexUnit @wait_for_connection @@ -1891,42 +1582,24 @@ def real_cost(task_id: str, verbose: bool = True) -> float | None: ) console = get_logging_console() if verbose else None - if _is_modeler_batch(task_id): - status = _batch_detail(task_id).totalStatus.value - flex_unit = _batch_detail(task_id).realFlexUnit or None - if (status not in ["success", "run_success"]) or (flex_unit is None): - log.warning( - f"Billed FlexCredit for task '{task_id}' is not available. If the task has been " - "successfully run, it should be available shortly. If this issue persists, contact customer support." - ) - else: - if verbose: + task_info = get_info(task_id) + flex_unit = task_info.realFlexUnit + ori_flex_unit = getattr(task_info, "oriRealFlexUnit", flex_unit) + if not flex_unit: + log.warning( + f"Billed FlexCredit for task '{task_id}' is not available. If the task has been " + "successfully run, it should be available shortly." + ) + else: + if verbose: + console.log(f"Billed flex credit cost: {flex_unit:1.3f}.") + if flex_unit != ori_flex_unit and "FDTD" in task_info.taskType: console.log( - f"Billed FlexCredit cost: {flex_unit:1.3f}. Minimum cost depends on " - "task execution details. Use 'web.real_cost(task_id)' to get the billed FlexCredit " - "cost after a simulation run." + "Note: the task cost pro-rated due to early shutoff was below the minimum " + "threshold, due to fast shutoff. Decreasing the simulation 'run_time' should " + "decrease the estimated, and correspondingly the billed cost of such tasks." ) - - return flex_unit - else: - task_info = get_info(task_id) - flex_unit = task_info.realFlexUnit - ori_flex_unit = task_info.oriRealFlexUnit - if not flex_unit: - log.warning( - f"Billed FlexCredit for task '{task_id}' is not available. If the task has been " - "successfully run, it should be available shortly." - ) - else: - if verbose: - console.log(f"Billed flex credit cost: {flex_unit:1.3f}.") - if flex_unit != ori_flex_unit and task_info.taskType == "FDTD": - console.log( - "Note: the task cost pro-rated due to early shutoff was below the minimum " - "threshold, due to fast shutoff. Decreasing the simulation 'run_time' should " - "decrease the estimated, and correspondingly the billed cost of such tasks." - ) - return flex_unit + return flex_unit @wait_for_connection @@ -1986,49 +1659,6 @@ def account(verbose: bool = True) -> Account: return account_info -@wait_for_connection -def postprocess_start( - batch_id: str, - verbose: bool = True, - worker_group: Optional[str] = None, -) -> None: - """ - Checks if a batch run is complete and starts the postprocess phase. - - This function does not wait for postprocessing to finish. - """ - console = get_logging_console() if verbose else None - if _is_modeler_batch(batch_id): - # Perform a single check on the run phase status - detail = _batch_detail(batch_id) - status = detail.totalStatus.value - total_tasks = detail.totalTask or 0 - successful_runs = detail.runSuccess or 0 - - if status in ERROR_STATES: - raise WebError(f"Batch '{batch_id}' terminated with error status: {status}") - - # Check if the run phase is complete before proceeding - is_run_complete = status in ("run_success", "success") or ( - total_tasks > 0 and successful_runs >= total_tasks - ) - - if not is_run_complete: - if console: - console.log( - f"Batch '{batch_id}' run phase is not yet complete (Status: {status}). " - f"Cannot start postprocessing." - ) - return # Exit if the run is not done - BatchTask(batch_id).postprocess(batch_type="RF_SWEEP", worker_group=worker_group) - return - else: - raise WebError( - f"Batch ID '{batch_id}' is not a component modeler batch job. " - "'postprocess_start' is only applicable to those classes." - ) - - @wait_for_connection def test() -> None: """Confirm whether Tidy3D authentication is configured. diff --git a/tidy3d/web/core/constants.py b/tidy3d/web/core/constants.py index 109d985026..623af2bba8 100644 --- a/tidy3d/web/core/constants.py +++ b/tidy3d/web/core/constants.py @@ -31,6 +31,7 @@ MODE_FILE_HDF5_GZ = "mode_solver.hdf5.gz" MODE_DATA_HDF5_GZ = "output/mode_solver_data.hdf5.gz" SIM_ERROR_FILE = "output/tidy3d_error.json" +SIM_VALIDATION_FILE = "output/tidy3d_validation.json" # Component modeler specific artifacts MODELER_FILE_HDF5_GZ = "modeler.hdf5.gz" diff --git a/tidy3d/web/core/http_util.py b/tidy3d/web/core/http_util.py index be66ea8985..9e5c9f3f9f 100644 --- a/tidy3d/web/core/http_util.py +++ b/tidy3d/web/core/http_util.py @@ -42,6 +42,7 @@ class ResponseCodes(Enum): def get_version() -> str: """Get the version for the current environment.""" return core_config.get_version() + # return "2.10.0rc2.1" def get_user_agent() -> str: diff --git a/tidy3d/web/core/task_core.py b/tidy3d/web/core/task_core.py index e3d7db1a97..f1e4ca8c7d 100644 --- a/tidy3d/web/core/task_core.py +++ b/tidy3d/web/core/task_core.py @@ -5,7 +5,6 @@ import os import pathlib import tempfile -import time from datetime import datetime from os import PathLike from typing import Callable, Optional, Union @@ -17,7 +16,6 @@ import tidy3d as td from tidy3d.config import config from tidy3d.exceptions import ValidationError -from tidy3d.web.common import REFRESH_TIME from . import http_util from .cache import FOLDER_CACHE @@ -25,6 +23,7 @@ SIM_ERROR_FILE, SIM_FILE_HDF5_GZ, SIM_LOG_FILE, + SIM_VALIDATION_FILE, SIMULATION_DATA_HDF5_GZ, ) from .core_config import get_logger_console @@ -34,7 +33,7 @@ from .http_util import http from .s3utils import download_file, download_gz_file, upload_file from .stub import TaskStub -from .task_info import BatchDetail +from .task_info import BatchDetail, TaskInfo from .types import PayType, Queryable, ResourceLifecycle, Submittable, Tidy3DResource @@ -145,8 +144,8 @@ def list_tasks(self, projects_endpoint: str = "tidy3d/projects") -> list[Tidy3DR ) -class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): - """Interface for managing the running of a :class:`.Simulation` task on server.""" +class WebTask(ResourceLifecycle, Submittable, extra=Extra.allow): + """Interface for managing the running a task on the server.""" task_id: Optional[str] = Field( ..., @@ -154,52 +153,6 @@ class SimulationTask(ResourceLifecycle, Submittable, extra=Extra.allow): description="Task ID number, set when the task is uploaded, leave as None.", alias="taskId", ) - folder_id: Optional[str] = Field( - None, - title="folder_id", - description="Folder ID number, set when the task is uploaded, leave as None.", - alias="folderId", - ) - status: Optional[str] = Field(title="status", description="Simulation task status.") - - real_flex_unit: float = Field( - None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" - ) - - created_at: Optional[datetime] = Field( - title="created_at", description="Time at which this task was created.", alias="createdAt" - ) - - task_type: Optional[str] = Field( - title="task_type", description="The type of task.", alias="taskType" - ) - - folder_name: Optional[str] = Field( - "default", - title="Folder Name", - description="Name of the folder associated with this task.", - alias="folderName", - ) - - callback_url: str = Field( - None, - title="Callback URL", - description="Http PUT url to receive simulation finish event. " - "The body content is a json file with fields " - "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", - ) - - # simulation_type: str = pd.Field( - # None, - # title="Simulation Type", - # description="Type of simulation, used internally only.", - # ) - - # parent_tasks: Tuple[TaskId, ...] = pd.Field( - # None, - # title="Parent Tasks", - # description="List of parent task ids for the simulation, used internally only." - # ) @classmethod def create( @@ -211,7 +164,6 @@ def create( simulation_type: str = "tidy3d", parent_tasks: Optional[list[str]] = None, file_type: str = "Gz", - port_name_list: Optional[list[str]] = None, projects_endpoint: str = "tidy3d/projects", ) -> SimulationTask: """Create a new task on the server. @@ -246,32 +198,241 @@ def create( simulation_type = "tidy3d" folder = Folder.get(folder_name, create=True) - payload = { - "taskName": task_name, - "taskType": task_type, - "callbackUrl": callback_url, - "simulationType": simulation_type, - "parentTasks": parent_tasks, - "fileType": file_type, - } - # Component modeler: include port names if provided - if port_name_list: - # Align with backend contract: expect 'portNames' (not 'portNameList') - payload["portNames"] = port_name_list - - resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) - # RF group creation may return group-level info without 'taskId'. - # Use 'groupId' (or 'batchId' as fallback) as the resource id for subsequent uploads. - if "taskId" not in resp and task_type == "RF": - # Prefer using 'batchId' as the resource id for uploads (S3 STS expects a task-like id). - if "batchId" in resp: - resp["taskId"] = resp["batchId"] - elif "groupId" in resp: - resp["taskId"] = resp["groupId"] - else: - raise WebError("Missing resource ID for task creation. Contact customer support.") + + if task_type in ["RF", "TERMINAL_CM", "MODAL_CM"]: + payload = { + "groupName": task_name, + "folderId": folder.folder_id, + "fileType": file_type, + "taskType": task_type, + } + resp = http.post("rf/task", payload) + else: + payload = { + "taskName": task_name, + "taskType": task_type, + "callbackUrl": callback_url, + "simulationType": simulation_type, + "parentTasks": parent_tasks, + "fileType": file_type, + } + resp = http.post(f"{projects_endpoint}/{folder.folder_id}/tasks", payload) + return SimulationTask(**resp, taskType=task_type, folder_name=folder_name) + def get_url(self) -> str: + base = str(config.web.website_endpoint or "") + if isinstance(self, BatchTask): + return "/".join([base.rstrip("/"), f"rf?taskId={self.task_id}"]) + return "/".join([base.rstrip("/"), f"workbench?taskId={self.task_id}"]) + + def get_folder_url(self) -> Optional[str]: + folder_id = getattr(self, "folder_id", None) + if not folder_id: + return None + base = str(config.web.website_endpoint or "") + return "/".join([base.rstrip("/"), f"folders/{folder_id}"]) + + def get_log( + self, + to_file: PathLike, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Get log file from Server. + + Parameters + ---------- + to_file: PathLike + Save file to path. + verbose: bool = True + Whether to display progress bars. + progress_callback : Callable[[float], None] = None + Optional callback function called while downloading the data. + + Returns + ------- + path: pathlib.Path + Path to saved file. + """ + + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + + target_path = pathlib.Path(to_file) + + return download_file( + self.task_id, + SIM_LOG_FILE, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + + def get_data_hdf5( + self, + to_file: PathLike, + remote_data_file_gz: PathLike = SIMULATION_DATA_HDF5_GZ, + verbose: bool = True, + progress_callback: Optional[Callable[[float], None]] = None, + ) -> pathlib.Path: + """Download data artifact (simulation or batch) with gz fallback handling. + + Parameters + ---------- + remote_data_file_gz : PathLike + Gzipped remote filename. + to_file : PathLike + Local target path. + verbose : bool + Whether to log progress. + progress_callback : Optional[Callable[[float], None]] + Progress callback. + + Returns + ------- + pathlib.Path + Saved local path. + """ + if not self.task_id: + raise WebError("Expected field 'task_id' is unset.") + target_path = pathlib.Path(to_file) + file = None + try: + file = download_gz_file( + resource_id=self.task_id, + remote_filename=remote_data_file_gz, + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except ClientError: + if verbose: + console = get_logger_console() + console.log(f"Unable to download '{remote_data_file_gz}'.") + if not file: + try: + file = download_file( + resource_id=self.task_id, + remote_filename=str(remote_data_file_gz)[:-3], + to_file=target_path, + verbose=verbose, + progress_callback=progress_callback, + ) + except Exception as e: + raise WebError( + "Failed to download the data file from the server. " + "Please confirm that the task completed successfully." + ) from e + return file + + @staticmethod + def is_batch(resource_id: str) -> bool: + """Checks if a given resource ID corresponds to a valid batch task. + + This is a utility function to verify a batch task's existence before + instantiating the class. + + Parameters + ---------- + resource_id : str + The unique identifier for the resource. + + Returns + ------- + bool + ``True`` if the resource is a valid batch task, ``False`` otherwise. + """ + try: + # TODO PROPERLY FIXME + # Disable non critical logs due to check for resourceId, until we have a dedicated API for this + resp = http.get( + f"rf/task/{resource_id}/statistics", + suppress_404=True, + ) + status = bool(resp and isinstance(resp, dict) and "status" in resp) + return status + except Exception: + return False + + def delete(self, versions: bool = False) -> None: + """Delete current task from server. + + Parameters + ---------- + versions : bool = False + If ``True``, delete all versions of the task in the task group. Otherwise, delete only + the version associated with the current task ID. + """ + if not self.task_id: + raise ValueError("Task id not found.") + + task_details = self.detail().dict() + + if task_details and "groupId" in task_details: + group_id = task_details["groupId"] + if versions: + http.delete("tidy3d/group", json={"groupIds": [group_id]}) + return + elif "version" in task_details: + version = task_details["version"] + http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) + return + + # Fallback to old method if we can't get the groupId and version + http.delete(f"tidy3d/tasks/{self.task_id}") + + +class SimulationTask(WebTask): + """Interface for managing the running of solver tasks on the server.""" + + folder_id: Optional[str] = Field( + None, + title="folder_id", + description="Folder ID number, set when the task is uploaded, leave as None.", + alias="folderId", + ) + status: Optional[str] = Field(title="status", description="Simulation task status.") + + real_flex_unit: float = Field( + None, title="real FlexCredits", description="Billed FlexCredits.", alias="realCost" + ) + + created_at: Optional[datetime] = Field( + title="created_at", description="Time at which this task was created.", alias="createdAt" + ) + + task_type: Optional[str] = Field( + title="task_type", description="The type of task.", alias="taskType" + ) + + folder_name: Optional[str] = Field( + "default", + title="Folder Name", + description="Name of the folder associated with this task.", + alias="folderName", + ) + + callback_url: str = Field( + None, + title="Callback URL", + description="Http PUT url to receive simulation finish event. " + "The body content is a json file with fields " + "``{'id', 'status', 'name', 'workUnit', 'solverVersion'}``.", + ) + + # simulation_type: str = pd.Field( + # None, + # title="Simulation Type", + # description="Type of simulation, used internally only.", + # ) + + # parent_tasks: Tuple[TaskId, ...] = pd.Field( + # None, + # title="Parent Tasks", + # description="List of parent task ids for the simulation, used internally only." + # ) + @classmethod def get(cls, task_id: str, verbose: bool = True) -> SimulationTask: """Get task from the server by id. @@ -313,28 +474,16 @@ def get_running_tasks(cls) -> list[SimulationTask]: return [] return parse_obj_as(list[SimulationTask], resp) - def delete(self, versions: bool = False) -> None: - """Delete current task from server. + def detail(self) -> TaskInfo: + """Fetches the detailed information and status of the task. - Parameters - ---------- - versions : bool = False - If ``True``, delete all versions of the task in the task group. Otherwise, delete only the version associated with the current task ID. + Returns + ------- + TaskInfo + An object containing the task's latest data. """ - if not self.task_id: - raise ValueError("Task id not found.") - - task_details = http.get(f"tidy3d/tasks/{self.task_id}") - - if task_details and "groupId" in task_details and "version" in task_details: - group_id = task_details["groupId"] - version = task_details["version"] - if versions: - http.delete("tidy3d/group", json={"groupIds": [group_id]}) - else: - http.delete(f"tidy3d/group/{group_id}/versions", json={"versions": [version]}) - else: # Fallback to old method if we can't get the groupId and version - http.delete(f"tidy3d/tasks/{self.task_id}") + resp = http.get(f"tidy3d/tasks/{self.task_id}/detail") + return TaskInfo(**{"taskId": self.task_id, "taskType": self.task_type, **resp}) def get_simulation_json(self, to_file: PathLike, verbose: bool = True) -> None: """Get json file for a :class:`.Simulation` from server. @@ -515,65 +664,6 @@ def estimate_cost(self, solver_version: Optional[str] = None) -> float: ) return resp - def get_sim_data_hdf5( - self, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - remote_data_file: PathLike = SIMULATION_DATA_HDF5_GZ, - ) -> pathlib.Path: - """Get simulation data file from Server. - - Parameters - ---------- - to_file: PathLike - Save file to path. - verbose: bool = True - Whether to display progress bars. - progress_callback : Callable[[float], None] = None - Optional callback function called while downloading the data. - - Returns - ------- - path: pathlib.Path - Path to saved file. - """ - if not self.task_id: - raise WebError("Expected field 'task_id' is unset.") - - target_path = pathlib.Path(to_file) - - file = None - try: - file = download_gz_file( - resource_id=self.task_id, - remote_filename=remote_data_file, - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except ClientError: - if verbose: - console = get_logger_console() - console.log(f"Unable to download '{remote_data_file}'.") - - if not file: - try: - file = download_file( - resource_id=self.task_id, - remote_filename=remote_data_file[:-3], - to_file=target_path, - verbose=verbose, - progress_callback=progress_callback, - ) - except Exception as e: - raise WebError( - "Failed to download the simulation data file from the server. " - "Please confirm that the task was successfully run." - ) from e - - return file - def get_simulation_hdf5( self, to_file: PathLike, @@ -666,7 +756,9 @@ def get_log( progress_callback=progress_callback, ) - def get_error_json(self, to_file: PathLike, verbose: bool = True) -> pathlib.Path: + def get_error_json( + self, to_file: PathLike, verbose: bool = True, validation: bool = False + ) -> pathlib.Path: """Get error json file for a :class:`.Simulation` from server. Parameters @@ -675,6 +767,8 @@ def get_error_json(self, to_file: PathLike, verbose: bool = True) -> pathlib.Pat Save file to path. verbose: bool = True Whether to display progress bars. + validation: bool = False + Whether to get a validation error file or a solver error file. Returns ------- @@ -685,16 +779,17 @@ def get_error_json(self, to_file: PathLike, verbose: bool = True) -> pathlib.Pat raise WebError("Expected field 'task_id' is unset.") target_path = pathlib.Path(to_file) + target_file = SIM_ERROR_FILE if not validation else SIM_VALIDATION_FILE return download_file( self.task_id, - SIM_ERROR_FILE, + target_file, to_file=target_path, verbose=verbose, ) def abort(self) -> requests.Response: - """Aborting current task from server.""" + """Abort the current task on the server.""" if not self.task_id: raise ValueError("Task id not found.") return http.put( @@ -732,68 +827,43 @@ def validate_post_upload(self, parent_tasks: Optional[list[str]] = None) -> None raise WebError(f"Provided 'parent_tasks' failed validation: {e!s}") from e -class BatchTask: - """Provides a client-side interface for managing a remote batch task. - - This class acts as a wrapper around the API endpoints for a specific batch, - allowing users to check, submit, monitor, and download data from it. +class BatchTask(WebTask): + """Interface for managing a batch task on the server.""" - Note: - The 'batch_type' (e.g., "RF_SWEEP") must be provided by the caller to - most methods, as it dictates which backend service handles the request. - """ - - def __init__(self, batch_id: str) -> None: - self.batch_id = batch_id - - @staticmethod - def is_batch(resource_id: str, batch_type: str) -> bool: - """Checks if a given resource ID corresponds to a valid batch task. - - This is a utility function to verify a batch task's existence before - instantiating the class. + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> BatchTask: + """Get batch task by id. Parameters ---------- - resource_id : str - The unique identifier for the resource. - batch_type : str - The type of the batch to check (e.g., "RF_SWEEP"). + task_id: str + Unique identifier of batch on server. + verbose: + If `True`, will print progressbars and status, otherwise, will run silently. Returns ------- - bool - ``True`` if the resource is a valid batch task, ``False`` otherwise. + :class:`.BatchTask` | None + BatchTask object if found, otherwise None. """ try: - # TODO PROPERLY FIXME - # Disable non critical logs due to check for resourceId, until we have a dedicated API for this - resp = http.get( - f"tidy3d/tasks/{resource_id}/batch-detail", - params={"batchType": batch_type}, - suppress_404=True, - ) - status = bool(resp and isinstance(resp, dict) and "status" in resp) - return status - except Exception: - return False + resp = http.get(f"rf/task/{task_id}/statistics") + except WebNotFoundError as e: + td.log.error(f"The requested batch ID '{task_id}' does not exist.") + raise e + # We only need to validate existence; store id on the instance. + return BatchTask(taskId=task_id) if resp else None - def detail(self, batch_type: str) -> BatchDetail: + def detail(self) -> BatchDetail: """Fetches the detailed information and status of the batch. - Parameters - ---------- - batch_type : str - The type of the batch (e.g., "RF_SWEEP"). - Returns ------- BatchDetail An object containing the batch's latest data. """ resp = http.get( - f"tidy3d/tasks/{self.batch_id}/batch-detail", - params={"batchType": batch_type}, + f"rf/task/{self.task_id}/statistics", ) # Some backends may return null for collection fields; coerce to sensible defaults if isinstance(resp, dict): @@ -803,9 +873,9 @@ def detail(self, batch_type: str) -> BatchDetail: def check( self, + check_task_type: str, solver_version: Optional[str] = None, protocol_version: Optional[str] = None, - batch_type: str = "", ) -> requests.Response: """Submits a request to validate the batch configuration on the server. @@ -815,8 +885,6 @@ def check( The version of the solver to use for validation. protocol_version : Optional[str], default=None The data protocol version. Defaults to the current version. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). Returns ------- @@ -826,11 +894,11 @@ def check( if protocol_version is None: protocol_version = _get_protocol_version() return http.post( - f"tidy3d/projects/{self.batch_id}/batch-check", + f"rf/task/{self.task_id}/check", { - "batchType": batch_type, "solverVersion": solver_version, "protocolVersion": protocol_version, + "taskType": check_task_type, }, ) @@ -839,7 +907,8 @@ def submit( solver_version: Optional[str] = None, protocol_version: Optional[str] = None, worker_group: Optional[str] = None, - batch_type: str = "", + pay_type: Union[PayType, str] = PayType.AUTO, + priority: Optional[int] = None, ) -> requests.Response: """Submits the batch for execution on the server. @@ -851,199 +920,67 @@ def submit( The data protocol version. Defaults to the current version. worker_group : Optional[str], default=None Optional identifier for a specific worker group to run on. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). Returns ------- Any The server's response to the submit request. """ - if protocol_version is None: - protocol_version = _get_protocol_version() - return http.post( - f"tidy3d/projects/{self.batch_id}/batch-submit", - { - "batchType": batch_type, - "solverVersion": solver_version, - "protocolVersion": protocol_version, - "workerGroup": worker_group, - }, - ) - - def postprocess( - self, - solver_version: Optional[str] = None, - protocol_version: Optional[str] = None, - worker_group: Optional[str] = None, - batch_type: str = "", - ) -> requests.Response: - """Initiates post-processing for a completed batch run. - Parameters - ---------- - solver_version : Optional[str], default=None - The version of the solver to use for post-processing. - protocol_version : Optional[str], default=None - The data protocol version. Defaults to the current version. - worker_group : Optional[str], default=None - Optional identifier for a specific worker group to run on. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). + # TODO: add support for pay_type and priority arguments + if pay_type != PayType.AUTO: + raise NotImplementedError( + "The 'pay_type' argument is not yet supported and will be ignored." + ) + if priority is not None: + raise NotImplementedError( + "The 'priority' argument is not yet supported and will be ignored." + ) - Returns - ------- - Any - The server's response to the post-process request. - """ if protocol_version is None: protocol_version = _get_protocol_version() return http.post( - f"tidy3d/projects/{self.batch_id}/postprocess", + f"rf/task/{self.task_id}/submit", { - "batchType": batch_type, "solverVersion": solver_version, "protocolVersion": protocol_version, "workerGroup": worker_group, }, ) - def wait_for_validate( - self, timeout: Optional[float] = None, batch_type: str = "" - ) -> BatchDetail: - """Waits for the batch to complete the validation stage by polling its status. - - Parameters - ---------- - timeout : Optional[float], default=None - Maximum time in seconds to wait. If ``None``, waits indefinitely. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). - - Returns - ------- - BatchDetail - The final object after validation completes or a timeout occurs. - - Notes - ----- - This method blocks until the batch status is 'validate_success', - 'validate_warn', 'validate_fail', or another terminal state like 'blocked' - or 'aborted', or until the timeout is reached. - """ - start = datetime.now().timestamp() - while True: - d = self.detail(batch_type=batch_type) - status = d.totalStatus - if status in ("validate_success", "validate_warn", "validate_fail"): - return d - if status in ("blocked", "aborting", "aborted"): - return d - if timeout is not None and (datetime.now().timestamp() - start) > timeout: - return d - time.sleep(REFRESH_TIME) - - def wait_for_run(self, timeout: Optional[float] = None, batch_type: str = "") -> BatchDetail: - """Waits for the batch to complete the execution stage by polling its status. - - Parameters - ---------- - timeout : Optional[float], default=None - Maximum time in seconds to wait. If ``None``, waits indefinitely. - batch_type : str, default="" - The type of the batch (e.g., "RF_SWEEP"). - - Returns - ------- - BatchDetail - The final object after the run completes or a timeout occurs. - - Notes - ----- - This method blocks until the batch status reaches a terminal run state like - 'run_success', 'run_failed', 'diverged', 'blocked', or 'aborted', - or until the timeout is reached. - """ - start = datetime.now().timestamp() - while True: - d = self.detail(batch_type=batch_type) - status = d.totalStatus - if status in ( - "run_success", - "run_failed", - "diverged", - "blocked", - "aborting", - "aborted", - ): - return d - if timeout is not None and (datetime.now().timestamp() - start) > timeout: - return d - time.sleep(REFRESH_TIME) - - def get_data_hdf5( - self, - remote_data_file_gz: str, - to_file: PathLike, - verbose: bool = True, - progress_callback: Optional[Callable[[float], None]] = None, - ) -> pathlib.Path: - """Downloads a batch data artifact, with a fallback mechanism. + def abort(self) -> requests.Response: + """Abort the current task on the server.""" + if not self.task_id: + raise ValueError("Batch id not found.") + return http.put(f"rf/task/{self.task_id}/abort", {}) - Parameters - ---------- - remote_data_file_gz : str - Remote gzipped filename to download (e.g., 'output/cm_data.hdf5.gz'). - to_file : PathLike - Local path where the downloaded file will be saved. - verbose : bool, default=True - If ``True``, shows progress logs and messages. - progress_callback : Optional[Callable[[float], None]], default=None - Optional callback function for progress updates, which receives the - download percentage as a float. - Returns - ------- - pathlib.Path - An object pointing to the downloaded local file. +class TaskFactory: + """Factory for obtaining the correct task subclass.""" - Raises - ------ - WebError - If both the gzipped and uncompressed file downloads fail. + _REGISTRY: dict[str, str] = {} - Notes - ----- - This method first attempts to download the gzipped version of a file. - If that fails, it falls back to downloading the uncompressed version. - """ - file = None - try: - file = download_gz_file( - resource_id=self.batch_id, - remote_filename=remote_data_file_gz, - to_file=to_file, - verbose=verbose, - progress_callback=progress_callback, - ) - except ClientError: - if verbose: - console = get_logger_console() - console.log(f"Unable to download '{remote_data_file_gz}'.") + @classmethod + def reset(cls) -> None: + """Clear the cached task kind registry (used in tests).""" + cls._REGISTRY.clear() - if not file: - try: - file = download_file( - resource_id=self.batch_id, - remote_filename=remote_data_file_gz[:-3], - to_file=to_file, - verbose=verbose, - progress_callback=progress_callback, - ) - except Exception as e: - raise WebError( - "Failed to download the batch data file from the server. " - "Please confirm that the batch has been successfully postprocessed." - ) from e + @classmethod + def register(cls, task_id: str, kind: str) -> None: + cls._REGISTRY[task_id] = kind - return file + @classmethod + def get(cls, task_id: str, verbose: bool = True) -> WebTask: + kind = cls._REGISTRY.get(task_id) + if kind == "batch": + return BatchTask.get(task_id, verbose=verbose) + if kind == "simulation": + task = SimulationTask.get(task_id, verbose=verbose) + return task + if WebTask.is_batch(task_id): + cls.register(task_id, "batch") + return BatchTask.get(task_id, verbose=verbose) + task = SimulationTask.get(task_id, verbose=verbose) + if task: + cls.register(task_id, "simulation") + return task diff --git a/tidy3d/web/core/task_info.py b/tidy3d/web/core/task_info.py index 64cf00ea0f..825ab9b6a5 100644 --- a/tidy3d/web/core/task_info.py +++ b/tidy3d/web/core/task_info.py @@ -10,31 +10,6 @@ import pydantic.v1 as pydantic -class TaskStatus(Enum): - """The statuses that the task can be in.""" - - INIT = "initialized" - """The task has been initialized.""" - - QUEUE = "queued" - """The task is in the queue.""" - - PRE = "preprocessing" - """The task is in the preprocessing stage.""" - - RUN = "running" - """The task is running.""" - - POST = "postprocessing" - """The task is in the postprocessing stage.""" - - SUCCESS = "success" - """The task has completed successfully.""" - - ERROR = "error" - """The task has completed with an error.""" - - class TaskBase(pydantic.BaseModel, ABC): """Base configuration for all task objects.""" @@ -153,6 +128,9 @@ class TaskInfo(TaskBase): taskBlockInfo: TaskBlockInfo = None """Blocking information for the task.""" + version: str = None + """Version of the task.""" + class RunInfo(TaskBase): """Information about the run of a task.""" @@ -172,41 +150,6 @@ def display(self) -> None: # ---------------------- Batch (Modeler) detail schema ---------------------- # -class BatchStatus(str, Enum): - """Enumerates the possible statuses for a batch of tasks.""" - - draft = "draft" - """The batch is being configured and has not been submitted.""" - preprocess = "preprocess" - """The batch is undergoing preprocessing.""" - validating = "validating" - """The tasks within the batch are being validated.""" - validate_success = "validate_success" - """All tasks in the batch passed validation.""" - validate_warn = "validate_warn" - """Validation passed, but with warnings.""" - validate_fail = "validate_fail" - """Validation failed for one or more tasks.""" - blocked = "blocked" - """The batch is blocked and cannot run.""" - running = "running" - """The batch is currently executing.""" - aborting = "aborting" - """The batch is in the process of being aborted.""" - run_success = "run_success" - """The batch completed successfully.""" - postprocess = "postprocess" - """The batch is undergoing postprocessing.""" - run_failed = "run_failed" - """The batch execution failed.""" - diverged = "diverged" - """The simulation in the batch diverged.""" - aborted = "aborted" - """The batch was successfully aborted.""" - error = "error" - """An error occurred during the solver run.""" - - class BatchTaskBlockInfo(TaskBlockInfo): """ Extends `TaskBlockInfo` with specific details for batch task blocking. @@ -282,7 +225,6 @@ class BatchDetail(TaskBase): groupId: Identifier for the group the batch belongs to. name: The user-defined name of the batch. status: The current status of the batch. - totalStatus: The overall status, consolidating individual task statuses. totalTask: The total number of tasks in the batch. preprocessSuccess: The count of tasks that completed preprocessing. postprocessStatus: The status of the batch's postprocessing stage. @@ -295,6 +237,7 @@ class BatchDetail(TaskBase): totalCheckMillis: Total time in milliseconds spent on checks. message: A general message providing information about the batch status. tasks: A list of `BatchMember` objects, one for each task in the batch. + taskType: The type of tasks contained in the batch. """ refId: str = None @@ -302,7 +245,6 @@ class BatchDetail(TaskBase): groupId: str = None name: str = None status: str = None - totalStatus: BatchStatus = None totalTask: int = 0 preprocessSuccess: int = 0 postprocessStatus: str = None @@ -317,6 +259,8 @@ class BatchDetail(TaskBase): message: str = None tasks: list[BatchMember] = [] validateErrors: dict = None + taskType: str = "RF" + version: str = None class AsyncJobDetail(TaskBase): diff --git a/tidy3d/web/core/types.py b/tidy3d/web/core/types.py index 39ab0bc926..1c1439366c 100644 --- a/tidy3d/web/core/types.py +++ b/tidy3d/web/core/types.py @@ -56,8 +56,8 @@ class TaskType(str, Enum): EME = "EME" MODE = "MODE" VOLUME_MESH = "VOLUME_MESH" - COMPONENT_MODELER = "COMPONENT_MODELER" - TERMINAL_COMPONENT_MODELER = "TERMINAL_COMPONENT_MODELER" + MODAL_CM = "MODAL_CM" + TERMINAL_CM = "TERMINAL_CM" class PayType(str, Enum): diff --git a/tidy3d/web/tests/conftest.py b/tidy3d/web/tests/conftest.py new file mode 100644 index 0000000000..918dce5588 --- /dev/null +++ b/tidy3d/web/tests/conftest.py @@ -0,0 +1,15 @@ +from __future__ import annotations + +from collections.abc import Generator + +import pytest +from tidy3d_frontend.tidy3d.web.core.task_core import TaskFactory + + +@pytest.fixture(autouse=True) +def clear_task_factory_registry() -> Generator[None, None, None]: + """Ensure TaskFactory registry is empty for each test.""" + TaskFactory.reset() + TaskFactory.reset() + yield + TaskFactory.reset()