-
Notifications
You must be signed in to change notification settings - Fork 18
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #221 from openvinotoolkit/add-TestingClient-to-per…
…form-model-tests Add `TestingClient` to perform model tests
- Loading branch information
Showing
13 changed files
with
470 additions
and
90 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,117 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
from typing import List, Optional | ||
|
||
import attr | ||
|
||
from geti_sdk.data_models.enums import OptimizationType | ||
from geti_sdk.data_models.utils import ( | ||
str_to_datetime, | ||
str_to_enum_converter, | ||
str_to_task_type, | ||
) | ||
|
||
|
||
@attr.define() | ||
class DatasetInfo: | ||
""" | ||
Container for dataset information, specific to test datasets | ||
""" | ||
|
||
id: str | ||
is_deleted: bool | ||
n_frames: int | ||
n_images: int | ||
n_samples: int | ||
name: str | ||
|
||
|
||
@attr.define() | ||
class JobInfo: | ||
""" | ||
Container for job information, specific to model testing jobs | ||
""" | ||
|
||
id: str | ||
status: str | ||
|
||
@property | ||
def is_done(self) -> bool: | ||
""" | ||
Return True if the testing job has finished, False otherwise | ||
:return: True for a finished job, False otherwise | ||
""" | ||
return self.status.lower() == "done" | ||
|
||
|
||
@attr.define() | ||
class ModelInfo: | ||
""" | ||
Container for information related to the model, specific for model tests | ||
""" | ||
|
||
group_id: str | ||
id: str | ||
n_labels: int | ||
task_type: str = attr.field(converter=str_to_task_type) | ||
template_id: str | ||
optimization_type: str = attr.field( | ||
converter=str_to_enum_converter(OptimizationType) | ||
) | ||
version: int | ||
|
||
|
||
@attr.define() | ||
class Score: | ||
""" | ||
Container class holding a score resulting from a model testing job. The metric | ||
contained can either relate to a single label (`label_id` will be assigned) or | ||
averaged over the dataset as a whole (`label_id` will be None) | ||
Score values range from 0 to 1 | ||
""" | ||
|
||
name: str | ||
value: float | ||
label_id: Optional[str] = None | ||
|
||
|
||
@attr.define() | ||
class TestResult: | ||
""" | ||
Representation of the results of a model test job that was run for a specific | ||
model and dataset in an Intel® Geti™ project | ||
""" | ||
|
||
datasets_info: List[DatasetInfo] | ||
id: str | ||
job_info: JobInfo | ||
model_info: ModelInfo | ||
name: str | ||
scores: List[Score] | ||
creation_time: Optional[str] = attr.field(default=None, converter=str_to_datetime) | ||
|
||
def get_mean_score(self) -> Score: | ||
""" | ||
Return the mean score computed over the full dataset | ||
:return: Mean score on the dataset | ||
""" | ||
if not self.job_info.is_done: | ||
raise ValueError( | ||
"Unable to retrieve mean model score, the model testing job is not " | ||
"finished yet." | ||
) | ||
return [score for score in self.scores if score.label_id is None][0] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
# Copyright (C) 2023 Intel Corporation | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions | ||
# and limitations under the License. | ||
import logging | ||
from typing import List, Optional, Sequence | ||
|
||
from geti_sdk.data_models import Dataset, Job, Model, Project, TestResult | ||
from geti_sdk.http_session import GetiSession | ||
from geti_sdk.rest_converters import TestResultRESTConverter | ||
from geti_sdk.utils.job_helpers import get_job_with_timeout, monitor_jobs | ||
|
||
SUPPORTED_METRICS = ["global", "local"] | ||
|
||
|
||
class TestingClient: | ||
""" | ||
Class to manage testing jobs for a certain Intel® Geti™ project. | ||
""" | ||
|
||
def __init__(self, workspace_id: str, project: Project, session: GetiSession): | ||
self.session = session | ||
self.project = project | ||
self.workspace_id = workspace_id | ||
self.base_url = f"workspaces/{workspace_id}/projects/{project.id}/tests" | ||
|
||
def test_model( | ||
self, | ||
model: Model, | ||
datasets: Sequence[Dataset], | ||
name: Optional[str] = None, | ||
metric: Optional[str] = None, | ||
) -> Job: | ||
""" | ||
Start a model testing job for a specific `model` and `dataset` | ||
:param model: Model to evaluate | ||
:param datasets: Testing dataset(s) to evaluate the model on | ||
:param name: Optional name to assign to the testing job | ||
:param metric: Optional metric to calculate. This is only valid for either | ||
anomaly segmentation or anomaly detection models. Possible values are | ||
`global` or `local` | ||
:return: Job object representing the testing job | ||
""" | ||
if name is None: | ||
name = f"Test_{model.name}_{model.version}" | ||
dataset_ids = [ds.id for ds in datasets] | ||
|
||
test_data = { | ||
"name": name, | ||
"model_group_id": model.model_group_id, | ||
"model_id": model.id, | ||
"dataset_ids": dataset_ids, | ||
} | ||
if metric is not None: | ||
if metric not in SUPPORTED_METRICS: | ||
raise ValueError( | ||
f"Invalid metric received! Only `{SUPPORTED_METRICS}` are " | ||
f"supported currently." | ||
) | ||
test_data.update({"metric": metric}) | ||
|
||
response = self.session.get_rest_response( | ||
url=self.base_url, method="POST", data=test_data | ||
) | ||
job = get_job_with_timeout( | ||
job_id=response["job_ids"][0], | ||
session=self.session, | ||
workspace_id=self.workspace_id, | ||
job_type="testing", | ||
) | ||
logging.info(f"Testing job with id {job.id} submitted successfully.") | ||
return job | ||
|
||
def get_test_result(self, job: Job) -> TestResult: | ||
""" | ||
Retrieve the result of the model testing job from the Intel® Geti™ | ||
server | ||
:param job: Job instance representing the model testing job | ||
:return: TestResult instance containing the test results | ||
""" | ||
response = self.session.get_rest_response(url=self.base_url, method="GET") | ||
test_results = [ | ||
TestResultRESTConverter.from_dict(result) | ||
for result in response["test_results"] | ||
] | ||
result_for_job = next( | ||
(result for result in test_results if result.job_info.id == job.id), None | ||
) | ||
if result_for_job is None: | ||
raise ValueError( | ||
f"Unable to find test result for job `{job.name}`, please make sure " | ||
f"that a valid testing job was passed." | ||
) | ||
return result_for_job | ||
|
||
def monitor_jobs( | ||
self, jobs: List[Job], timeout: int = 10000, interval: int = 15 | ||
) -> List[Job]: | ||
""" | ||
Monitor and print the progress of all jobs in the list `jobs`. Execution is | ||
halted until all jobs have either finished, failed or were cancelled. | ||
Progress will be reported in 15s intervals | ||
:param jobs: List of jobs to monitor | ||
:param timeout: Timeout (in seconds) after which to stop the monitoring | ||
:param interval: Time interval (in seconds) at which the TrainingClient polls | ||
the server to update the status of the jobs. Defaults to 15 seconds | ||
:return: List of finished (or failed) jobs with their status updated | ||
""" | ||
return monitor_jobs( | ||
session=self.session, jobs=jobs, timeout=timeout, interval=interval | ||
) |
Oops, something went wrong.