From 9bedcfd10ade0bbc2ee5e1f8745c06cb36f79e01 Mon Sep 17 00:00:00 2001 From: ljcornel Date: Tue, 16 May 2023 13:38:30 +0200 Subject: [PATCH 1/4] Add `TestingClient` to perform model tests --- geti_sdk/data_models/__init__.py | 3 + geti_sdk/data_models/test_result.py | 117 ++++++++++++++ geti_sdk/rest_clients/__init__.py | 2 + geti_sdk/rest_clients/testing_client.py | 114 +++++++++++++ geti_sdk/rest_clients/training_client.py | 105 +++--------- geti_sdk/rest_converters/__init__.py | 4 + .../annotation_rest_converter.py | 4 +- .../normalized_annotation_rest_converter.py | 3 +- .../rest_converters/model_rest_converter.py | 3 +- .../test_result_rest_converter.py | 36 +++++ geti_sdk/utils/job_helpers.py | 150 ++++++++++++++++++ 11 files changed, 451 insertions(+), 90 deletions(-) create mode 100644 geti_sdk/data_models/test_result.py create mode 100644 geti_sdk/rest_clients/testing_client.py create mode 100644 geti_sdk/rest_converters/test_result_rest_converter.py create mode 100644 geti_sdk/utils/job_helpers.py diff --git a/geti_sdk/data_models/__init__.py b/geti_sdk/data_models/__init__.py index 4a263f2c..f43c9eb4 100644 --- a/geti_sdk/data_models/__init__.py +++ b/geti_sdk/data_models/__init__.py @@ -180,6 +180,7 @@ from .project import Dataset, Pipeline, Project from .status import ProjectStatus from .task import Task +from .test_result import Score, TestResult __all__ = [ "TaskType", @@ -211,4 +212,6 @@ "Job", "CodeDeploymentInformation", "Dataset", + "TestResult", + "Score", ] diff --git a/geti_sdk/data_models/test_result.py b/geti_sdk/data_models/test_result.py new file mode 100644 index 00000000..4f08f8fd --- /dev/null +++ b/geti_sdk/data_models/test_result.py @@ -0,0 +1,117 @@ +# Copyright (C) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +from typing import List, Optional + +import attr + +from geti_sdk.data_models.enums import OptimizationType +from geti_sdk.data_models.utils import ( + str_to_datetime, + str_to_enum_converter, + str_to_task_type, +) + + +@attr.define() +class DatasetInfo: + """ + Container for dataset information, specific to test datasets + """ + + id: str + is_deleted: bool + n_frames: int + n_images: int + n_samples: int + name: str + + +@attr.define() +class JobInfo: + """ + Container for job information, specific to model testing jobs + """ + + id: str + status: str + + @property + def is_done(self) -> bool: + """ + Return True if the testing job has finished, False otherwise + + :return: True for a finished job, False otherwise + """ + return self.status.lower() == "done" + + +@attr.define() +class ModelInfo: + """ + Container for information related to the model, specific for model tests + """ + + group_id: str + id: str + n_labels: int + task_type: str = attr.field(converter=str_to_task_type) + template_id: str + optimization_type: str = attr.field( + converter=str_to_enum_converter(OptimizationType) + ) + version: int + + +@attr.define() +class Score: + """ + Container class holding a score resulting from a model testing job. The metric + contained can either relate to a single label (`label_id` will be assigned) or + averaged over the dataset as a whole (`label_id` will be None) + + Score values range from 0 to 1 + """ + + name: str + value: float + label_id: Optional[str] = None + + +@attr.define() +class TestResult: + """ + Representation of the results of a model test job that was run for a specific + model and dataset in an Intel® Geti™ project + """ + + datasets_info: List[DatasetInfo] + id: str + job_info: JobInfo + model_info: ModelInfo + name: str + scores: List[Score] + creation_time: Optional[str] = attr.field(default=None, converter=str_to_datetime) + + def get_mean_score(self) -> Score: + """ + Return the mean score computed over the full dataset + + :return: Mean score on the dataset + """ + if not self.job_info.is_done: + raise ValueError( + "Unable to retrieve mean model score, the model testing job is not " + "finished yet." + ) + return [score for score in self.scores if score.label_id is None][0] diff --git a/geti_sdk/rest_clients/__init__.py b/geti_sdk/rest_clients/__init__.py index 3296b2e5..03d47d6b 100644 --- a/geti_sdk/rest_clients/__init__.py +++ b/geti_sdk/rest_clients/__init__.py @@ -97,6 +97,7 @@ from .model_client import ModelClient from .prediction_client import PredictionClient from .project_client import ProjectClient +from .testing_client import TestingClient from .training_client import TrainingClient __all__ = [ @@ -111,4 +112,5 @@ "TrainingClient", "DeploymentClient", "ActiveLearningClient", + "TestingClient", ] diff --git a/geti_sdk/rest_clients/testing_client.py b/geti_sdk/rest_clients/testing_client.py new file mode 100644 index 00000000..9a8c628d --- /dev/null +++ b/geti_sdk/rest_clients/testing_client.py @@ -0,0 +1,114 @@ +# Copyright (C) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +from typing import List, Optional, Sequence + +from geti_sdk.data_models import Dataset, Job, Model, Project, TestResult +from geti_sdk.http_session import GetiSession +from geti_sdk.rest_converters import TestResultRESTConverter +from geti_sdk.utils.job_helpers import get_job_with_timeout, monitor_jobs + +SUPPORTED_METRICS = ["global", "local"] + + +class TestingClient: + """ + Class to manage testing jobs for a certain Intel® Geti™ project. + """ + + def __init__(self, workspace_id: str, project: Project, session: GetiSession): + self.session = session + self.project = project + self.workspace_id = workspace_id + self.base_url = f"workspaces/{workspace_id}/projects/{project.id}/tests" + + def test_model( + self, + model: Model, + datasets: Sequence[Dataset], + name: Optional[str] = None, + metric: Optional[str] = None, + ) -> Job: + """ + Start a model testing job for a specific `model` and `dataset` + + :param model: Model to evaluate + :param datasets: Testing dataset(s) to evaluate the model on + :param name: Optional name to assign to the testing job + :param metric: Optional metric to calculate. This is only valid for either + anomaly segmentation or anomaly detection models. Possible values are + `global` or `local` + :return: Job object representing the testing job + """ + if name is None: + name = ( + f"Testing job for model `{model.name}` on datasets " + f"`{[ds.name for ds in datasets]}`" + ) + dataset_ids = [ds.id for ds in datasets] + + test_data = { + "name": name, + "model_group_id": model.model_group_id, + "model_id": model.id, + "dataset_ids": dataset_ids, + } + if metric is not None: + if metric not in SUPPORTED_METRICS: + raise ValueError( + f"Invalid metric received! Only `{SUPPORTED_METRICS}` are " + f"supported currently." + ) + test_data.update({"metric": metric}) + + response = self.session.get_rest_response( + url=self.base_url, method="POST", data=test_data + ) + return get_job_with_timeout( + job_id=response["job_ids"][0], + session=self.session, + workspace_id=self.workspace_id, + job_type="testing", + ) + + def get_test_result(self, test_id: str) -> TestResult: + """ + Retrieve the result of the model test with id `test_id` from the Intel® Geti™ + server + + :param test_id: Unique ID of the test to fetch the results for + :return: TestResult instance containing the test results + """ + response = self.session.get_rest_response( + url=self.base_url + "/" + test_id, method="GET" + ) + return TestResultRESTConverter.from_dict(response) + + def monitor_jobs( + self, jobs: List[Job], timeout: int = 10000, interval: int = 15 + ) -> List[Job]: + """ + Monitor and print the progress of all jobs in the list `jobs`. Execution is + halted until all jobs have either finished, failed or were cancelled. + + Progress will be reported in 15s intervals + + :param jobs: List of jobs to monitor + :param timeout: Timeout (in seconds) after which to stop the monitoring + :param interval: Time interval (in seconds) at which the TrainingClient polls + the server to update the status of the jobs. Defaults to 15 seconds + :return: List of finished (or failed) jobs with their status updated + """ + return monitor_jobs( + session=self.session, jobs=jobs, timeout=timeout, interval=interval + ) diff --git a/geti_sdk/rest_clients/training_client.py b/geti_sdk/rest_clients/training_client.py index 25b276c5..7277ac30 100644 --- a/geti_sdk/rest_clients/training_client.py +++ b/geti_sdk/rest_clients/training_client.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions # and limitations under the License. import logging -import time from typing import Any, Dict, List, Optional, Union from geti_sdk.data_models import ( @@ -26,7 +25,7 @@ from geti_sdk.data_models.containers import AlgorithmList from geti_sdk.data_models.enums import JobState, JobType from geti_sdk.data_models.project import Dataset -from geti_sdk.http_session import GetiRequestException, GetiSession +from geti_sdk.http_session import GetiSession from geti_sdk.platform_versions import GETI_11_VERSION from geti_sdk.rest_converters import ( ConfigurationRESTConverter, @@ -34,11 +33,12 @@ StatusRESTConverter, ) from geti_sdk.utils import get_supported_algorithms +from geti_sdk.utils.job_helpers import get_job_with_timeout, monitor_jobs class TrainingClient: """ - Class to manage training jobs for a certain project. + Class to manage training jobs for a certain Intel® Geti™ project. """ def __init__(self, workspace_id: str, project: Project, session: GetiSession): @@ -96,26 +96,6 @@ def get_jobs(self, project_only: bool = True) -> List[Job]: else: return job_list - def get_job_by_id(self, job_id: str) -> Optional[Job]: - """ - Return the details of a Job by its `job_id`. - - :param job_id: ID of the job to retrieve - :return: Job instance containing detailed information and status of the job. - If no job by the specified ID is found on the Intel® Geti™ platform, this - method returns None - """ - try: - response = self.session.get_rest_response( - url=f"workspaces/{self.workspace_id}/jobs/{job_id}", method="GET" - ) - except GetiRequestException as error: - if error.status_code == 404: - return None - else: - raise error - return JobRESTConverter.from_dict(response) - def get_algorithms_for_task(self, task: Union[Task, int]) -> AlgorithmList: """ Return a list of supported algorithms for a specific task. @@ -243,26 +223,18 @@ def train_task( job_id = job.id else: job_id = response["job_ids"][0] - job = self.get_job_by_id(job_id=job_id) - - if job is not None: - logging.info(f"Training job with ID {job_id} submitted successfully.") - else: - t_start = time.time() - while job is None and (time.time() - t_start < 10): - logging.info( - f"Training request was submitted but the training job status could " - f"not be retrieved from the platform yet. Re-attempting to fetch " - f"job status. Looking for job with ID {job_id}" - ) - time.sleep(2) - job = self.get_job_by_id(job_id=job_id) - if job is None: - raise RuntimeError( - "Train request was submitted but the TrainingClient was unable to " - "find the resulting training job on the Intel® Geti™ server." - ) - job.workspace_id = self.workspace_id + try: + job = get_job_with_timeout( + job_id=job_id, + session=self.session, + workspace_id=self.workspace_id, + job_type="training", + ) + except RuntimeError: + raise RuntimeError( + "Training job was submitted, but the TrainingClient was unable to " + "find the resulting job on the platform." + ) return job def monitor_jobs( @@ -280,50 +252,9 @@ def monitor_jobs( the server to update the status of the jobs. Defaults to 15 seconds :return: List of finished (or failed) jobs with their status updated """ - monitoring = True - completed_states = [ - JobState.FINISHED, - JobState.CANCELLED, - JobState.FAILED, - JobState.ERROR, - ] - logging.info("---------------- Monitoring progress -------------------") - jobs_to_monitor = [ - job for job in jobs if job.status.state not in completed_states - ] - try: - t_start = time.time() - t_elapsed = 0 - while monitoring and t_elapsed < timeout: - msg = "" - complete_count = 0 - for job in jobs_to_monitor: - job.update(self.session) - msg += ( - f"{job.name} -- " - f" Phase: {job.status.user_friendly_message} " - f" State: {job.status.state} " - f" Progress: {job.status.progress:.1f}%" - ) - if job.status.state in completed_states: - complete_count += 1 - if complete_count == len(jobs_to_monitor): - monitoring = False - logging.info(msg) - time.sleep(interval) - t_elapsed = time.time() - t_start - except KeyboardInterrupt: - logging.info("Job monitoring interrupted, stopping...") - for job in jobs: - job.update(self.session) - return jobs - if t_elapsed < timeout: - logging.info("All jobs completed, monitoring stopped.") - else: - logging.info( - f"Monitoring stopped after {t_elapsed:.1f} seconds due to timeout." - ) - return jobs + return monitor_jobs( + session=self.session, jobs=jobs, timeout=timeout, interval=interval + ) def get_jobs_for_task(self, task: Task, running_only: bool = True) -> List[Job]: """ diff --git a/geti_sdk/rest_converters/__init__.py b/geti_sdk/rest_converters/__init__.py index c01e8368..54f24835 100644 --- a/geti_sdk/rest_converters/__init__.py +++ b/geti_sdk/rest_converters/__init__.py @@ -50,6 +50,8 @@ .. autoclass:: JobRESTConverter :members: +.. autoclass:: TestResultRESTConverter + :members: """ from .annotation_rest_converter import AnnotationRESTConverter from .configuration_rest_converter import ConfigurationRESTConverter @@ -59,6 +61,7 @@ from .prediction_rest_converter import PredictionRESTConverter from .project_rest_converter import ProjectRESTConverter from .status_rest_converter import StatusRESTConverter +from .test_result_rest_converter import TestResultRESTConverter __all__ = [ "ProjectRESTConverter", @@ -69,4 +72,5 @@ "ModelRESTConverter", "StatusRESTConverter", "JobRESTConverter", + "TestResultRESTConverter", ] diff --git a/geti_sdk/rest_converters/annotation_rest_converter/annotation_rest_converter.py b/geti_sdk/rest_converters/annotation_rest_converter/annotation_rest_converter.py index 1aeb41f7..092f8507 100644 --- a/geti_sdk/rest_converters/annotation_rest_converter/annotation_rest_converter.py +++ b/geti_sdk/rest_converters/annotation_rest_converter/annotation_rest_converter.py @@ -18,8 +18,10 @@ import attr from omegaconf import OmegaConf -from geti_sdk.data_models import Annotation, AnnotationScene, MediaType, ScoredLabel +from geti_sdk.data_models import Annotation, AnnotationScene from geti_sdk.data_models.enums import ShapeType +from geti_sdk.data_models.label import ScoredLabel +from geti_sdk.data_models.media import MediaType from geti_sdk.data_models.media_identifiers import ( ImageIdentifier, MediaIdentifier, diff --git a/geti_sdk/rest_converters/annotation_rest_converter/normalized_annotation_rest_converter.py b/geti_sdk/rest_converters/annotation_rest_converter/normalized_annotation_rest_converter.py index a3b80f1f..328503f5 100644 --- a/geti_sdk/rest_converters/annotation_rest_converter/normalized_annotation_rest_converter.py +++ b/geti_sdk/rest_converters/annotation_rest_converter/normalized_annotation_rest_converter.py @@ -17,8 +17,9 @@ import attr -from geti_sdk.data_models import Annotation, AnnotationScene, ScoredLabel +from geti_sdk.data_models import Annotation, AnnotationScene from geti_sdk.data_models.enums import ShapeType +from geti_sdk.data_models.label import ScoredLabel from geti_sdk.data_models.shapes import Shape from geti_sdk.data_models.utils import ( attr_value_serializer, diff --git a/geti_sdk/rest_converters/model_rest_converter.py b/geti_sdk/rest_converters/model_rest_converter.py index 7a26b0be..717a4ce4 100644 --- a/geti_sdk/rest_converters/model_rest_converter.py +++ b/geti_sdk/rest_converters/model_rest_converter.py @@ -14,7 +14,8 @@ from typing import Any, Dict -from geti_sdk.data_models import Model, ModelGroup, OptimizedModel +from geti_sdk.data_models.model import Model, OptimizedModel +from geti_sdk.data_models.model_group import ModelGroup from geti_sdk.utils import deserialize_dictionary diff --git a/geti_sdk/rest_converters/test_result_rest_converter.py b/geti_sdk/rest_converters/test_result_rest_converter.py new file mode 100644 index 00000000..0ae44d07 --- /dev/null +++ b/geti_sdk/rest_converters/test_result_rest_converter.py @@ -0,0 +1,36 @@ +# Copyright (C) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. + +from typing import Any, Dict + +from geti_sdk.data_models import TestResult +from geti_sdk.utils import deserialize_dictionary + + +class TestResultRESTConverter: + """ + Class that handles conversion of Intel® Geti™ REST output for test results to + objects, and vice-versa + """ + + @staticmethod + def from_dict(result_dict: Dict[str, Any]) -> TestResult: + """ + Create a TestResult instance from the input dictionary passed in `result_dict`. + + :param result_dict: Dictionary representing a test result on the Intel® Geti™ + server, as returned by the /tests endpoints + :return: TestResult instance, holding the result data contained in result_dict + """ + return deserialize_dictionary(result_dict, output_type=TestResult) diff --git a/geti_sdk/utils/job_helpers.py b/geti_sdk/utils/job_helpers.py new file mode 100644 index 00000000..50bc7406 --- /dev/null +++ b/geti_sdk/utils/job_helpers.py @@ -0,0 +1,150 @@ +# Copyright (C) 2023 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions +# and limitations under the License. +import logging +import time +from typing import List, Optional + +from geti_sdk.data_models.enums.job_state import JobState +from geti_sdk.data_models.job import Job +from geti_sdk.http_session import GetiRequestException, GetiSession +from geti_sdk.rest_converters.job_rest_converter import JobRESTConverter + + +def get_job_by_id( + job_id: str, session: GetiSession, workspace_id: str +) -> Optional[Job]: + """ + Retrieve Job information from the Intel® Geti™ server. + + :param job_id: Unique ID of the job to retrieve + :param session: GetiSession instance addressing the Intel® Geti™ platform + :param workspace_id: ID of the workspace in which the job was created + :return: Job instance holding the details of the job + """ + try: + response = session.get_rest_response( + url=f"workspaces/{workspace_id}/jobs/{job_id}", method="GET" + ) + except GetiRequestException as error: + if error.status_code == 404: + return None + else: + raise error + return JobRESTConverter.from_dict(response) + + +def get_job_with_timeout( + job_id: str, + session: GetiSession, + workspace_id: str, + job_type: str = "training", + timeout: int = 15, +) -> Job: + """ + Retrieve a Job from the Intel® Geti™ server, by it's unique ID. If the job is not + found within the specified `timeout`, a RuntimeError is raised. + + :param job_id: Unique ID of the job to retrieve + :param session: GetiSession instance addressing the Intel® Geti™ platform + :param workspace_id: ID of the workspace in which the job was created + :param job_type: String representing the type of job, for instance "training" or + "testing" + :param timeout: Time (in seconds) after which the job retrieval will timeout + :raises: RuntimeError if the job is not found within the specified timeout + :return: Job instance holding the details of the job + """ + job = get_job_by_id(job_id=job_id, session=session, workspace_id=workspace_id) + if job is not None: + logging.info( + f"{job_type.capitalize()} job with ID {job_id} retrieved successfully." + ) + else: + t_start = time.time() + while job is None and (time.time() - t_start < timeout): + logging.info( + f"{job_type.capitalize()} job status could not be retrieved from the " + f"platform yet. Re-attempting to fetch job status. Looking for job " + f"with ID {job_id}" + ) + time.sleep(2) + job = get_job_by_id( + session=session, job_id=job_id, workspace_id=workspace_id + ) + if job is None: + raise RuntimeError( + f"Unable to find the resulting {job_type} job on the Intel® Geti™ " + f"server." + ) + job.workspace_id = workspace_id + return job + + +def monitor_jobs( + session: GetiSession, jobs: List[Job], timeout: int = 10000, interval: int = 15 +) -> List[Job]: + """ + Monitor and print the progress of all jobs in the list `jobs`. Execution is + halted until all jobs have either finished, failed or were cancelled. + + Progress will be reported in 15s intervals + + :param session: GetiSession instance addressing the Intel® Geti™ platform + :param jobs: List of jobs to monitor + :param timeout: Timeout (in seconds) after which to stop the monitoring + :param interval: Time interval (in seconds) at which the TrainingClient polls + the server to update the status of the jobs. Defaults to 15 seconds + :return: List of finished (or failed) jobs with their status updated + """ + monitoring = True + completed_states = [ + JobState.FINISHED, + JobState.CANCELLED, + JobState.FAILED, + JobState.ERROR, + ] + logging.info("---------------- Monitoring progress -------------------") + jobs_to_monitor = [job for job in jobs if job.status.state not in completed_states] + try: + t_start = time.time() + t_elapsed = 0 + while monitoring and t_elapsed < timeout: + msg = "" + complete_count = 0 + for job in jobs_to_monitor: + job.update(session) + msg += ( + f"{job.name} -- " + f" Phase: {job.status.user_friendly_message} " + f" State: {job.status.state} " + f" Progress: {job.status.progress:.1f}%" + ) + if job.status.state in completed_states: + complete_count += 1 + if complete_count == len(jobs_to_monitor): + break + logging.info(msg) + time.sleep(interval) + t_elapsed = time.time() - t_start + except KeyboardInterrupt: + logging.info("Job monitoring interrupted, stopping...") + for job in jobs: + job.update(session) + return jobs + if t_elapsed < timeout: + logging.info("All jobs completed, monitoring stopped.") + else: + logging.info( + f"Monitoring stopped after {t_elapsed:.1f} seconds due to timeout." + ) + return jobs From 168d5e7216c0c286c2c83c2f076b1f8633447ee6 Mon Sep 17 00:00:00 2001 From: ljcornel Date: Tue, 16 May 2023 14:57:46 +0200 Subject: [PATCH 2/4] Add logging and fix task type conversion --- geti_sdk/rest_clients/testing_client.py | 29 ++++++++++++------- .../test_result_rest_converter.py | 3 ++ 2 files changed, 22 insertions(+), 10 deletions(-) diff --git a/geti_sdk/rest_clients/testing_client.py b/geti_sdk/rest_clients/testing_client.py index 9a8c628d..fae05127 100644 --- a/geti_sdk/rest_clients/testing_client.py +++ b/geti_sdk/rest_clients/testing_client.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions # and limitations under the License. +import logging from typing import List, Optional, Sequence from geti_sdk.data_models import Dataset, Job, Model, Project, TestResult @@ -51,10 +52,7 @@ def test_model( :return: Job object representing the testing job """ if name is None: - name = ( - f"Testing job for model `{model.name}` on datasets " - f"`{[ds.name for ds in datasets]}`" - ) + name = f"Test_{model.name}_{model.version}" dataset_ids = [ds.id for ds in datasets] test_data = { @@ -71,6 +69,7 @@ def test_model( ) test_data.update({"metric": metric}) + logging.info("Starting model testing job...") response = self.session.get_rest_response( url=self.base_url, method="POST", data=test_data ) @@ -81,18 +80,28 @@ def test_model( job_type="testing", ) - def get_test_result(self, test_id: str) -> TestResult: + def get_test_result(self, job: Job) -> TestResult: """ - Retrieve the result of the model test with id `test_id` from the Intel® Geti™ + Retrieve the result of the model testing job from the Intel® Geti™ server - :param test_id: Unique ID of the test to fetch the results for + :param job: Job instance representing the model testing job :return: TestResult instance containing the test results """ - response = self.session.get_rest_response( - url=self.base_url + "/" + test_id, method="GET" + response = self.session.get_rest_response(url=self.base_url, method="GET") + test_results = [ + TestResultRESTConverter.from_dict(result) + for result in response["test_results"] + ] + result_for_job = next( + (result for result in test_results if result.job_info.id == job.id), None ) - return TestResultRESTConverter.from_dict(response) + if result_for_job is None: + raise ValueError( + f"Unable to find test result for job `{job.name}`, please make sure " + f"that a valid testing job was passed." + ) + return result_for_job def monitor_jobs( self, jobs: List[Job], timeout: int = 10000, interval: int = 15 diff --git a/geti_sdk/rest_converters/test_result_rest_converter.py b/geti_sdk/rest_converters/test_result_rest_converter.py index 0ae44d07..f8a581ca 100644 --- a/geti_sdk/rest_converters/test_result_rest_converter.py +++ b/geti_sdk/rest_converters/test_result_rest_converter.py @@ -33,4 +33,7 @@ def from_dict(result_dict: Dict[str, Any]) -> TestResult: server, as returned by the /tests endpoints :return: TestResult instance, holding the result data contained in result_dict """ + # Need to convert task type to lower case + task_type = result_dict["model_info"]["task_type"] + result_dict["model_info"]["task_type"] = task_type.lower() return deserialize_dictionary(result_dict, output_type=TestResult) From 42fe3c172adc2ba09508be899ad95a1698458e31 Mon Sep 17 00:00:00 2001 From: ljcornel Date: Tue, 16 May 2023 17:03:37 +0200 Subject: [PATCH 3/4] Fix optimization type --- geti_sdk/data_models/enums/optimization_type.py | 1 + geti_sdk/rest_clients/testing_client.py | 5 +++-- geti_sdk/rest_clients/training_client.py | 1 + geti_sdk/utils/job_helpers.py | 6 +++--- 4 files changed, 8 insertions(+), 5 deletions(-) diff --git a/geti_sdk/data_models/enums/optimization_type.py b/geti_sdk/data_models/enums/optimization_type.py index 8d914301..f407f23a 100644 --- a/geti_sdk/data_models/enums/optimization_type.py +++ b/geti_sdk/data_models/enums/optimization_type.py @@ -24,6 +24,7 @@ class OptimizationType(Enum): POT = "POT" MO = "MO" ONNX = "ONNX" + NONE = "NONE" def __str__(self) -> str: """ diff --git a/geti_sdk/rest_clients/testing_client.py b/geti_sdk/rest_clients/testing_client.py index fae05127..4ca55fcb 100644 --- a/geti_sdk/rest_clients/testing_client.py +++ b/geti_sdk/rest_clients/testing_client.py @@ -69,16 +69,17 @@ def test_model( ) test_data.update({"metric": metric}) - logging.info("Starting model testing job...") response = self.session.get_rest_response( url=self.base_url, method="POST", data=test_data ) - return get_job_with_timeout( + job = get_job_with_timeout( job_id=response["job_ids"][0], session=self.session, workspace_id=self.workspace_id, job_type="testing", ) + logging.info(f"Testing job with id {job.id} submitted successfully.") + return job def get_test_result(self, job: Job) -> TestResult: """ diff --git a/geti_sdk/rest_clients/training_client.py b/geti_sdk/rest_clients/training_client.py index 7277ac30..48f3f980 100644 --- a/geti_sdk/rest_clients/training_client.py +++ b/geti_sdk/rest_clients/training_client.py @@ -235,6 +235,7 @@ def train_task( "Training job was submitted, but the TrainingClient was unable to " "find the resulting job on the platform." ) + logging.info(f"Training job with id {job.id} submitted successfully.") return job def monitor_jobs( diff --git a/geti_sdk/utils/job_helpers.py b/geti_sdk/utils/job_helpers.py index 50bc7406..9a60e89a 100644 --- a/geti_sdk/utils/job_helpers.py +++ b/geti_sdk/utils/job_helpers.py @@ -66,13 +66,13 @@ def get_job_with_timeout( """ job = get_job_by_id(job_id=job_id, session=session, workspace_id=workspace_id) if job is not None: - logging.info( - f"{job_type.capitalize()} job with ID {job_id} retrieved successfully." + logging.debug( + f"{job_type.capitalize()} job with ID {job_id} retrieved from the platform." ) else: t_start = time.time() while job is None and (time.time() - t_start < timeout): - logging.info( + logging.debug( f"{job_type.capitalize()} job status could not be retrieved from the " f"platform yet. Re-attempting to fetch job status. Looking for job " f"with ID {job_id}" From 9fa428c1b28b48efe7c3b850473aded34aa2dc66 Mon Sep 17 00:00:00 2001 From: ljcornel Date: Wed, 17 May 2023 14:06:34 +0200 Subject: [PATCH 4/4] Add step to get inferencer configuration --- geti_sdk/deployment/deployed_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/geti_sdk/deployment/deployed_model.py b/geti_sdk/deployment/deployed_model.py index da9ed780..e8a6d921 100644 --- a/geti_sdk/deployment/deployed_model.py +++ b/geti_sdk/deployment/deployed_model.py @@ -26,6 +26,7 @@ import attr import numpy as np +from otx.algorithms.classification.utils import get_cls_inferencer_configuration from otx.api.entities.color import Color from otx.api.entities.label import Domain as OTEDomain from otx.api.entities.label import LabelEntity @@ -305,6 +306,9 @@ def load_inference_model( f"{wrapper_module_path}." ) from ex + if model_type == "otx_classification": + configuration = get_cls_inferencer_configuration(self.ote_label_schema) + model = OMZModel.create_model( name=model_type, model_adapter=model_adapter,