Merge pull request #221 from openvinotoolkit/add-TestingClient-to-perform-model-tests

Add `TestingClient` to perform model tests
ljcornel committed May 17, 2023
2 parents f0139bc + 9fa428c commit 02a6a44
Showing 13 changed files with 470 additions and 90 deletions.
3 changes: 3 additions & 0 deletions geti_sdk/data_models/__init__.py
@@ -180,6 +180,7 @@
from .project import Dataset, Pipeline, Project
from .status import ProjectStatus
from .task import Task
from .test_result import Score, TestResult

__all__ = [
"TaskType",
Expand Down Expand Up @@ -211,4 +212,6 @@
"Job",
"CodeDeploymentInformation",
"Dataset",
"TestResult",
"Score",
]
1 change: 1 addition & 0 deletions geti_sdk/data_models/enums/optimization_type.py
@@ -24,6 +24,7 @@ class OptimizationType(Enum):
    POT = "POT"
    MO = "MO"
    ONNX = "ONNX"
    NONE = "NONE"

    def __str__(self) -> str:
        """
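
The NONE value added above presumably covers models that have not gone through any optimization step; the ModelInfo container introduced below converts such strings to OptimizationType members via str_to_enum_converter. A minimal sketch of the lookup, assuming OptimizationType behaves like a standard Python Enum whose member values match the strings shown above:

from geti_sdk.data_models.enums import OptimizationType

# "NONE" is the (assumed) string reported for models that were not optimized;
# looking it up by value yields the new enum member.
opt_type = OptimizationType("NONE")
print(opt_type is OptimizationType.NONE)  # True
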
117 changes: 117 additions & 0 deletions geti_sdk/data_models/test_result.py
@@ -0,0 +1,117 @@
# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
from typing import List, Optional

import attr

from geti_sdk.data_models.enums import OptimizationType
from geti_sdk.data_models.utils import (
    str_to_datetime,
    str_to_enum_converter,
    str_to_task_type,
)


@attr.define()
class DatasetInfo:
    """
    Container for dataset information, specific to test datasets
    """

    id: str
    is_deleted: bool
    n_frames: int
    n_images: int
    n_samples: int
    name: str


@attr.define()
class JobInfo:
    """
    Container for job information, specific to model testing jobs
    """

    id: str
    status: str

    @property
    def is_done(self) -> bool:
        """
        Return True if the testing job has finished, False otherwise
        :return: True for a finished job, False otherwise
        """
        return self.status.lower() == "done"


@attr.define()
class ModelInfo:
    """
    Container for information related to the model, specific to model tests
    """

    group_id: str
    id: str
    n_labels: int
    task_type: str = attr.field(converter=str_to_task_type)
    template_id: str
    optimization_type: str = attr.field(
        converter=str_to_enum_converter(OptimizationType)
    )
    version: int


@attr.define()
class Score:
    """
    Container class holding a score resulting from a model testing job. The metric
    contained can either relate to a single label (`label_id` will be assigned) or
    be averaged over the dataset as a whole (`label_id` will be None).
    Score values range from 0 to 1.
    """

    name: str
    value: float
    label_id: Optional[str] = None


@attr.define()
class TestResult:
    """
    Representation of the results of a model test job that was run for a specific
    model and dataset in an Intel® Geti™ project
    """

    datasets_info: List[DatasetInfo]
    id: str
    job_info: JobInfo
    model_info: ModelInfo
    name: str
    scores: List[Score]
    creation_time: Optional[str] = attr.field(default=None, converter=str_to_datetime)

    def get_mean_score(self) -> Score:
        """
        Return the mean score computed over the full dataset
        :return: Mean score on the dataset
        """
        if not self.job_info.is_done:
            raise ValueError(
                "Unable to retrieve mean model score, the model testing job is not "
                "finished yet."
            )
        return [score for score in self.scores if score.label_id is None][0]
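
As a note on the Score convention above: per-label scores carry the id of the label they refer to, while the dataset-wide aggregate is the single entry whose label_id is None, which is what get_mean_score picks out. A minimal sketch of that convention; the metric name, values and label id below are made up for illustration:

from geti_sdk.data_models import Score

# Hypothetical scores as they might appear in TestResult.scores; the entry
# without a label_id is the aggregate over the whole test dataset.
scores = [
    Score(name="F-measure", value=0.88, label_id="hypothetical-label-id"),
    Score(name="F-measure", value=0.91, label_id=None),
]

# The same lookup that TestResult.get_mean_score performs once the job is done
mean_score = [score for score in scores if score.label_id is None][0]
print(f"{mean_score.name}: {mean_score.value:.2f}")  # prints "F-measure: 0.91"
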
4 changes: 4 additions & 0 deletions geti_sdk/deployment/deployed_model.py
@@ -26,6 +26,7 @@

import attr
import numpy as np
from otx.algorithms.classification.utils import get_cls_inferencer_configuration
from otx.api.entities.color import Color
from otx.api.entities.label import Domain as OTEDomain
from otx.api.entities.label import LabelEntity
@@ -305,6 +306,9 @@ def load_inference_model(
f"{wrapper_module_path}."
) from ex

if model_type == "otx_classification":
configuration = get_cls_inferencer_configuration(self.ote_label_schema)

model = OMZModel.create_model(
name=model_type,
model_adapter=model_adapter,
Expand Down
2 changes: 2 additions & 0 deletions geti_sdk/rest_clients/__init__.py
@@ -97,6 +97,7 @@
from .model_client import ModelClient
from .prediction_client import PredictionClient
from .project_client import ProjectClient
from .testing_client import TestingClient
from .training_client import TrainingClient

__all__ = [
@@ -111,4 +112,5 @@
    "TrainingClient",
    "DeploymentClient",
    "ActiveLearningClient",
    "TestingClient",
]
124 changes: 124 additions & 0 deletions geti_sdk/rest_clients/testing_client.py
@@ -0,0 +1,124 @@
# Copyright (C) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions
# and limitations under the License.
import logging
from typing import List, Optional, Sequence

from geti_sdk.data_models import Dataset, Job, Model, Project, TestResult
from geti_sdk.http_session import GetiSession
from geti_sdk.rest_converters import TestResultRESTConverter
from geti_sdk.utils.job_helpers import get_job_with_timeout, monitor_jobs

SUPPORTED_METRICS = ["global", "local"]


class TestingClient:
    """
    Class to manage testing jobs for a certain Intel® Geti™ project.
    """

    def __init__(self, workspace_id: str, project: Project, session: GetiSession):
        self.session = session
        self.project = project
        self.workspace_id = workspace_id
        self.base_url = f"workspaces/{workspace_id}/projects/{project.id}/tests"

    def test_model(
        self,
        model: Model,
        datasets: Sequence[Dataset],
        name: Optional[str] = None,
        metric: Optional[str] = None,
    ) -> Job:
        """
        Start a model testing job for a specific `model` on one or more `datasets`
        :param model: Model to evaluate
        :param datasets: Testing dataset(s) to evaluate the model on
        :param name: Optional name to assign to the testing job
        :param metric: Optional metric to calculate. This is only valid for either
            anomaly segmentation or anomaly detection models. Possible values are
            `global` or `local`
        :return: Job object representing the testing job
        """
        if name is None:
            name = f"Test_{model.name}_{model.version}"
        dataset_ids = [ds.id for ds in datasets]

        test_data = {
            "name": name,
            "model_group_id": model.model_group_id,
            "model_id": model.id,
            "dataset_ids": dataset_ids,
        }
        if metric is not None:
            if metric not in SUPPORTED_METRICS:
                raise ValueError(
                    f"Invalid metric received! Only `{SUPPORTED_METRICS}` are "
                    f"currently supported."
                )
            test_data.update({"metric": metric})

        response = self.session.get_rest_response(
            url=self.base_url, method="POST", data=test_data
        )
        job = get_job_with_timeout(
            job_id=response["job_ids"][0],
            session=self.session,
            workspace_id=self.workspace_id,
            job_type="testing",
        )
        logging.info(f"Testing job with id {job.id} submitted successfully.")
        return job

    def get_test_result(self, job: Job) -> TestResult:
        """
        Retrieve the result of the model testing job from the Intel® Geti™
        server
        :param job: Job instance representing the model testing job
        :return: TestResult instance containing the test results
        """
        response = self.session.get_rest_response(url=self.base_url, method="GET")
        test_results = [
            TestResultRESTConverter.from_dict(result)
            for result in response["test_results"]
        ]
        result_for_job = next(
            (result for result in test_results if result.job_info.id == job.id), None
        )
        if result_for_job is None:
            raise ValueError(
                f"Unable to find test result for job `{job.name}`, please make sure "
                f"that a valid testing job was passed."
            )
        return result_for_job

    def monitor_jobs(
        self, jobs: List[Job], timeout: int = 10000, interval: int = 15
    ) -> List[Job]:
        """
        Monitor and print the progress of all jobs in the list `jobs`. Execution is
        halted until all jobs have either finished, failed or been cancelled.
        Progress will be reported in 15s intervals
        :param jobs: List of jobs to monitor
        :param timeout: Timeout (in seconds) after which to stop the monitoring
        :param interval: Time interval (in seconds) at which the TestingClient polls
            the server to update the status of the jobs. Defaults to 15 seconds
        :return: List of finished (or failed) jobs with their status updated
        """
        return monitor_jobs(
            session=self.session, jobs=jobs, timeout=timeout, interval=interval
        )
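
Taken together, a typical flow with the new client could look like the sketch below. The Geti connection, the ProjectClient lookup and the way the model and dataset are picked (ModelClient.get_all_active_models, project.datasets) follow the usual geti-sdk patterns but are assumptions rather than part of this changeset; only the TestingClient and TestResult calls mirror the code added here. Host, credentials and the project name are placeholders.

from geti_sdk import Geti
from geti_sdk.rest_clients import ModelClient, ProjectClient, TestingClient

# Connect to the server (placeholder address and credentials)
geti = Geti(host="https://your-geti-server", username="user", password="pass")

project_client = ProjectClient(session=geti.session, workspace_id=geti.workspace_id)
project = project_client.get_project_by_name("my-project")

# Assumption: evaluate the first active model on the project's first dataset
model_client = ModelClient(
    session=geti.session, workspace_id=geti.workspace_id, project=project
)
model = model_client.get_all_active_models()[0]
dataset = project.datasets[0]

testing_client = TestingClient(
    workspace_id=geti.workspace_id, project=project, session=geti.session
)

# Submit the testing job and wait until it has finished, failed or been cancelled
job = testing_client.test_model(model=model, datasets=[dataset])
testing_client.monitor_jobs([job])

# Fetch the results and report the dataset-wide mean score
test_result = testing_client.get_test_result(job)
mean_score = test_result.get_mean_score()
print(f"{mean_score.name}: {mean_score.value:.2f}")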
