
Commit

fix: added proto message conversion to MDMJob.update fields (#1718)
* fix: added proto message conversion to MDMJob.update fields

* 馃 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed PR comment

* formatting

* 馃 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* replaced string literal with constant

* adding _gca_resource re-assignment to mdm job class

* Added side effects in get_mdm_job pytest mock

* fixing side effects

* formatting

* 馃 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* minor edits to variable names

* Addressed PR feedback

* 馃 Updates from OwlBot post-processor

See https://github.com/googleapis/repo-automation-bots/blob/main/packages/owl-bot/README.md

* addressed more PR comments

* addressed PR comments

* fix linter errors

Co-authored-by: Owl Bot <gcf-owl-bot[bot]@users.noreply.github.com>
rosiezou and gcf-owl-bot[bot] committed Oct 19, 2022
1 parent 3747ce3 commit 9e77c61
Showing 2 changed files with 126 additions and 49 deletions.
25 changes: 15 additions & 10 deletions google/cloud/aiplatform/jobs.py
@@ -2427,7 +2427,8 @@ def update(
are allowed. See https://goo.gl/xmQnxf for more information
and examples of labels.
bigquery_tables_log_ttl (int):
- Optional. The TTL(time to live) of BigQuery tables in user projects
+ Optional. The number of days for which the logs are stored.
+ The TTL(time to live) of BigQuery tables in user projects
which stores logs. A day is the basic unit of
the TTL and we take the ceil of TTL/86400(a
day). e.g. { second: 3600} indicates ttl = 1
@@ -2453,28 +2454,30 @@
will be applied to all deployed models.
"""
self._sync_gca_resource()
- current_job = self.api_client.get_model_deployment_monitoring_job(
-     name=self._gca_resource.name
- )
+ current_job = copy.deepcopy(self._gca_resource)
update_mask: List[str] = []
if display_name is not None:
update_mask.append("display_name")
current_job.display_name = display_name
if schedule_config is not None:
update_mask.append("model_deployment_monitoring_schedule_config")
- current_job.model_deployment_monitoring_schedule_config = schedule_config
+ current_job.model_deployment_monitoring_schedule_config = (
+     schedule_config.as_proto()
+ )
if alert_config is not None:
update_mask.append("model_monitoring_alert_config")
- current_job.model_monitoring_alert_config = alert_config
+ current_job.model_monitoring_alert_config = alert_config.as_proto()
if logging_sampling_strategy is not None:
update_mask.append("logging_sampling_strategy")
- current_job.logging_sampling_strategy = logging_sampling_strategy
+ current_job.logging_sampling_strategy = logging_sampling_strategy.as_proto()
if labels is not None:
update_mask.append("labels")
- current_job.lables = labels
+ current_job.labels = labels
if bigquery_tables_log_ttl is not None:
update_mask.append("log_ttl")
- current_job.log_ttl = bigquery_tables_log_ttl
+ current_job.log_ttl = duration_pb2.Duration(
+     seconds=bigquery_tables_log_ttl * 86400
+ )
if enable_monitoring_pipeline_logs is not None:
update_mask.append("enable_monitoring_pipeline_logs")
current_job.enable_monitoring_pipeline_logs = (
@@ -2491,10 +2494,12 @@ def update(
deployed_model_ids=deployed_model_ids,
)
)
- self.api_client.update_model_deployment_monitoring_job(
+ # TODO: b/254285776 add optional_sync support to model monitoring job
+ lro = self.api_client.update_model_deployment_monitoring_job(
model_deployment_monitoring_job=current_job,
update_mask=field_mask_pb2.FieldMask(paths=update_mask),
)
+ self._gca_resource = lro.result()
return self

def pause(self) -> "ModelDeploymentMonitoringJob":
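Taken together, the change above makes `update()` (1) work on a deep copy of the already-synced resource instead of re-fetching it, (2) convert the high-level config wrappers to their underlying proto messages via `as_proto()` before assigning them to proto fields, (3) express the TTL in seconds as a `Duration`, and (4) block on the returned long-running operation and cache its result. A minimal sketch of that pattern, with a hypothetical `update_job` helper and `client` object standing in for the method and `api_client` in the diff:

```python
import copy

from google.protobuf import duration_pb2, field_mask_pb2


def update_job(client, cached_resource, display_name=None, alert_config=None, log_ttl_days=None):
    """Sketch of the FieldMask-based update pattern used in the diff above."""
    # Mutate a copy so the cached resource stays untouched until the call succeeds.
    job = copy.deepcopy(cached_resource)
    update_mask = []

    if display_name is not None:
        update_mask.append("display_name")
        job.display_name = display_name
    if alert_config is not None:
        # High-level config wrappers must be converted to proto messages before
        # being assigned to proto fields (the bug this PR fixes).
        update_mask.append("model_monitoring_alert_config")
        job.model_monitoring_alert_config = alert_config.as_proto()
    if log_ttl_days is not None:
        # The API stores the TTL as a Duration; a day is 86400 seconds.
        update_mask.append("log_ttl")
        job.log_ttl = duration_pb2.Duration(seconds=log_ttl_days * 86400)

    # The update RPC returns a long-running operation; wait for it and return
    # the refreshed resource so the caller can re-cache it.
    lro = client.update_model_deployment_monitoring_job(
        model_deployment_monitoring_job=job,
        update_mask=field_mask_pb2.FieldMask(paths=update_mask),
    )
    return lro.result()
```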
150 changes: 111 additions & 39 deletions tests/unit/aiplatform/test_jobs.py
@@ -16,6 +16,7 @@
#

import pytest
import copy

from unittest import mock
from importlib import reload
@@ -24,6 +25,7 @@
from google.cloud import storage
from google.cloud import bigquery

from google.api_core import operation
from google.auth import credentials as auth_credentials

from google.cloud import aiplatform
@@ -46,7 +48,9 @@
job_service_client,
)
from google.protobuf import field_mask_pb2 # type: ignore
from google.protobuf import duration_pb2 # type: ignore

import test_endpoints # noqa: F401
from test_endpoints import get_endpoint_with_models_mock # noqa: F401

_TEST_API_CLIENT = job_service_client.JobServiceClient
Expand Down Expand Up @@ -175,6 +179,58 @@
_TEST_JOB_RESOURCE_NAME = f"{_TEST_PARENT}/customJobs/{_TEST_ID}"

_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG = {"TEST_KEY": 0.01}
_TEST_MDM_USER_EMAIL = "TEST_EMAIL"
_TEST_MDM_SAMPLE_RATE = 0.5
_TEST_MDM_LABEL = {"TEST KEY": "TEST VAL"}
_TEST_LOG_TTL_IN_DAYS = 1
_TEST_MDM_NEW_NAME = "NEW_NAME"

_TEST_MDM_OLD_JOB = (
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_DISPLAY_NAME,
endpoint=_TEST_ENDPOINT,
state=_TEST_JOB_STATE_RUNNING,
)
)

_TEST_MDM_EXPECTED_NEW_JOB = gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
name=_TEST_MDM_JOB_NAME,
display_name=_TEST_MDM_NEW_NAME,
endpoint=_TEST_ENDPOINT,
state=_TEST_JOB_STATE_RUNNING,
model_deployment_monitoring_objective_configs=[
gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
deployed_model_id=model_id,
objective_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
drift_thresholds={
"TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(
value=0.01
)
}
)
),
)
for model_id in [model.id for model in test_endpoints._TEST_DEPLOYED_MODELS]
],
logging_sampling_strategy=gca_model_monitoring_compat.SamplingStrategy(
random_sample_config=gca_model_monitoring_compat.SamplingStrategy.RandomSampleConfig(
sample_rate=_TEST_MDM_SAMPLE_RATE
)
),
labels=_TEST_MDM_LABEL,
model_monitoring_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig(
email_alert_config=gca_model_monitoring_compat.ModelMonitoringAlertConfig.EmailAlertConfig(
user_emails=[_TEST_MDM_USER_EMAIL]
)
),
model_deployment_monitoring_schedule_config=gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringScheduleConfig(
monitor_interval=duration_pb2.Duration(seconds=3600)
),
log_ttl=duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400),
enable_monitoring_pipeline_logs=True,
)
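The `_TEST_MDM_EXPECTED_NEW_JOB` constant above serves both as the mocked RPC's final payload and as the value the assertions compare against, which works because protobuf messages compare by field value rather than by identity. A small illustration using the `log_ttl` field it sets (the `timedelta` remark describes how the proto-plus wrapper typically surfaces `Duration` fields and is an assumption, not something shown in this diff):

```python
import datetime

from google.protobuf import duration_pb2

_TEST_LOG_TTL_IN_DAYS = 1

# The API stores the TTL as a Duration in seconds; one day is 86400 seconds.
expected_ttl = duration_pb2.Duration(seconds=_TEST_LOG_TTL_IN_DAYS * 86400)

# protobuf messages compare field by field, so an independently constructed
# Duration with the same value is equal to the one placed on the expected job.
assert expected_ttl == duration_pb2.Duration(seconds=86400)

# The proto-plus wrapper reads Duration fields back as datetime.timedelta,
# which is what the `log_ttl.days` assertion in the test below relies on.
assert expected_ttl.ToTimedelta() == datetime.timedelta(days=1)
```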

# TODO(b/171333554): Move reusable test fixtures to conftest.py file

@@ -988,48 +1044,23 @@ def get_mdm_job_mock():
with mock.patch.object(
_TEST_API_CLIENT, "get_model_deployment_monitoring_job"
) as get_mdm_job_mock:
- get_mdm_job_mock.return_value = (
-     gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
-         name=_TEST_MDM_JOB_NAME,
-         display_name=_TEST_DISPLAY_NAME,
-         state=_TEST_JOB_STATE_RUNNING,
-         endpoint=_TEST_ENDPOINT,
-     )
- )
+ get_mdm_job_mock.side_effect = [
+     _TEST_MDM_OLD_JOB,
+     _TEST_MDM_OLD_JOB,
+     _TEST_MDM_OLD_JOB,
+     _TEST_MDM_EXPECTED_NEW_JOB,
+ ]
yield get_mdm_job_mock
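
The `side_effect` list above makes the mocked `get_model_deployment_monitoring_job` return the old job for the first few calls (object construction and `_sync_gca_resource`) and the updated job on the final refresh. A minimal standalone sketch of this `unittest.mock` behaviour, with placeholder payloads:

```python
from unittest import mock

get_job = mock.Mock()
# With an iterable side_effect, each call consumes the next element in order.
get_job.side_effect = ["old", "old", "old", "new"]

assert get_job() == "old"
assert get_job() == "old"
assert get_job() == "old"
assert get_job() == "new"
# A fifth call would raise StopIteration because the list is exhausted.
```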


@pytest.fixture
@pytest.mark.usefixtures("get_mdm_job_mock")
def update_mdm_job_mock(get_endpoint_with_models_mock): # noqa: F811
with mock.patch.object(
_TEST_API_CLIENT, "update_model_deployment_monitoring_job"
) as update_mdm_job_mock:
- expected_objective_config = gca_model_monitoring_compat.ModelMonitoringObjectiveConfig(
-     prediction_drift_detection_config=gca_model_monitoring_compat.ModelMonitoringObjectiveConfig.PredictionDriftDetectionConfig(
-         drift_thresholds={
-             "TEST_KEY": gca_model_monitoring_compat.ThresholdConfig(value=0.01)
-         }
-     )
- )
- all_configs = []
- for model in get_endpoint_with_models_mock.return_value.deployed_models:
-     all_configs.append(
-         gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringObjectiveConfig(
-             deployed_model_id=model.id,
-             objective_config=expected_objective_config,
-         )
-     )
-
- update_mdm_job_mock.return_vaue.result_type = (
-     gca_model_deployment_monitoring_job_compat.ModelDeploymentMonitoringJob(
-         name=_TEST_MDM_JOB_NAME,
-         display_name=_TEST_DISPLAY_NAME,
-         state=_TEST_JOB_STATE_RUNNING,
-         endpoint=_TEST_ENDPOINT,
-         model_deployment_monitoring_objective_configs=all_configs,
-     )
- )
+ update_mdm_job_lro_mock = mock.Mock(operation.Operation)
+ update_mdm_job_lro_mock.result.return_value = _TEST_MDM_EXPECTED_NEW_JOB
+ update_mdm_job_mock.return_value = update_mdm_job_lro_mock
yield update_mdm_job_mock
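
Because `update_model_deployment_monitoring_job` returns a long-running operation, the fixture above now stubs it with a `mock.Mock` specced against `google.api_core.operation.Operation` whose `result()` yields the expected job, mirroring the new `lro.result()` call in `jobs.py`. A standalone sketch of the pattern (the dict payload is just a stand-in):

```python
from unittest import mock

from google.api_core import operation

# Spec the fake LRO against Operation so unexpected attribute access fails loudly.
fake_lro = mock.Mock(spec=operation.Operation)
fake_lro.result.return_value = {"display_name": "NEW_NAME"}  # stand-in payload

update_rpc = mock.Mock(return_value=fake_lro)

# Code under test: call the RPC, then block on the operation for the final resource.
updated = update_rpc(model_deployment_monitoring_job={}).result()
assert updated == {"display_name": "NEW_NAME"}
fake_lro.result.assert_called_once()
```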


@@ -1046,25 +1077,66 @@ def test_update_mdm_job(self, get_mdm_job_mock, update_mdm_job_mock):
job = jobs.ModelDeploymentMonitoringJob(
model_deployment_monitoring_job_name=_TEST_MDM_JOB_NAME
)
old_job = copy.deepcopy(job._gca_resource)
drift_detection_config = aiplatform.model_monitoring.DriftDetectionConfig(
drift_thresholds=_TEST_MDM_JOB_DRIFT_DETECTION_CONFIG
)
schedule_config = aiplatform.model_monitoring.ScheduleConfig(monitor_interval=1)
alert_config = aiplatform.model_monitoring.EmailAlertConfig(
user_emails=[_TEST_MDM_USER_EMAIL]
)
sampling_strategy = aiplatform.model_monitoring.RandomSampleConfig(
sample_rate=_TEST_MDM_SAMPLE_RATE
)
labels = _TEST_MDM_LABEL
log_ttl = _TEST_LOG_TTL_IN_DAYS
display_name = _TEST_MDM_NEW_NAME
new_config = aiplatform.model_monitoring.ObjectiveConfig(
drift_detection_config=drift_detection_config
)
- job.update(objective_configs=new_config)
+ job.update(
+     display_name=display_name,
+     schedule_config=schedule_config,
+     alert_config=alert_config,
+     logging_sampling_strategy=sampling_strategy,
+     labels=labels,
+     bigquery_tables_log_ttl=log_ttl,
+     enable_monitoring_pipeline_logs=True,
+     objective_configs=new_config,
+ )
new_job = job._gca_resource
assert old_job != new_job
assert new_job.display_name == display_name
assert new_job.logging_sampling_strategy == sampling_strategy.as_proto()
assert (
new_job.model_deployment_monitoring_schedule_config
== schedule_config.as_proto()
)
assert new_job.labels == labels
assert new_job.model_monitoring_alert_config == alert_config.as_proto()
assert new_job.log_ttl.days == _TEST_LOG_TTL_IN_DAYS
assert new_job.enable_monitoring_pipeline_logs
assert (
- job._gca_resource.model_deployment_monitoring_objective_configs[
+ new_job.model_deployment_monitoring_objective_configs[
0
].objective_config.prediction_drift_detection_config
== drift_detection_config.as_proto()
)
get_mdm_job_mock.assert_called_with(
- name=_TEST_MDM_JOB_NAME,
+ name=_TEST_MDM_JOB_NAME, retry=base._DEFAULT_RETRY
)
update_mdm_job_mock.assert_called_once_with(
- model_deployment_monitoring_job=get_mdm_job_mock.return_value,
+ model_deployment_monitoring_job=new_job,
update_mask=field_mask_pb2.FieldMask(
- paths=["model_deployment_monitoring_objective_configs"]
+ paths=[
+     "display_name",
+     "model_deployment_monitoring_schedule_config",
+     "model_monitoring_alert_config",
+     "logging_sampling_strategy",
+     "labels",
+     "log_ttl",
+     "enable_monitoring_pipeline_logs",
+     "model_deployment_monitoring_objective_configs",
+ ]
),
)
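
The final assertion pins down both the payload (the mutated copy, not the originally fetched job) and the exact `FieldMask` sent to the service. `FieldMask` is itself a protobuf message, so an expected mask can be built inline and compared by value; `paths` is a repeated field, so the order in the test must match the order in which `update()` appends them. A small illustration:

```python
from google.protobuf import field_mask_pb2

sent = field_mask_pb2.FieldMask(paths=["display_name", "labels", "log_ttl"])

# Equality is field by field, so an independently built mask with the same
# paths in the same order compares equal.
assert sent == field_mask_pb2.FieldMask(paths=["display_name", "labels", "log_ttl"])

# Repeated fields are order-sensitive, so a reordered mask is not equal,
# which is why the test lists the paths in the order update() appends them.
assert sent != field_mask_pb2.FieldMask(paths=["labels", "display_name", "log_ttl"])
```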
