Skip to content

Commit

Permalink
chore: Support both full resource name and resource id in Model Monit…
Browse files Browse the repository at this point in the history
…oring SDK.

PiperOrigin-RevId: 631518248
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed May 7, 2024
1 parent e47d436 commit c03767c
Showing 1 changed file with 111 additions and 27 deletions.
138 changes: 111 additions & 27 deletions vertexai/resources/preview/ml_monitoring/model_monitors.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import copy
import dataclasses
import json
import re
from typing import Any, Dict, List, Optional

from google.auth import credentials as auth_credentials
Expand Down Expand Up @@ -237,6 +238,53 @@ def _transform_field_schema(
return result


def _get_schedule_name(
schedule_name: str
) -> str:
if schedule_name:
client = initializer.global_config.create_client(
client_class=utils.ScheduleClientWithOverride,
)
if client.parse_schedule_path(schedule_name):
return schedule_name
elif re.match("^{}$".format("[0-9]{0,127}"), schedule_name):
return client.schedule_path(
project=initializer.global_config.project,
location=initializer.global_config.location,
schedule=schedule_name,
)
else:
raise ValueError(
"schedule name must be of the format `projects/{project}/locations/{location}/schedules/{schedule}` or `{schedule}`"
)
return schedule_name


def _get_model_monitoring_job_name(
model_monitoring_job_name: str,
model_monitor_name: str,
) -> str:
if model_monitoring_job_name:
client = initializer.global_config.create_client(
client_class=utils.ModelMonitoringClientWithOverride,
)
if client.parse_model_monitoring_job_path(model_monitoring_job_name):
return model_monitoring_job_name
elif re.match("^{}$".format("[0-9]{0,127}"), model_monitoring_job_name):
model_monitor_name = model_monitor_name.split("/")[-1]
return client.model_monitoring_job_path(
project=initializer.global_config.project,
location=initializer.global_config.location,
model_monitor=model_monitor_name,
model_monitoring_job=model_monitoring_job_name,
)
else:
raise ValueError(
"model monitoring job name must be of the format `projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}` or `{model_monitoring_job}`"
)
return model_monitoring_job_name


@dataclasses.dataclass
class MetricsSearchResponse:
"""MetricsSearchResponse represents a response of the search metrics request.
Expand Down Expand Up @@ -784,6 +832,8 @@ def update_schedule(
schedule_name (str):
Required. The resource name of schedule that needs to be updated.
Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
or
``{schedule}``
display_name (str):
Optional. The user-defined name of the Schedule.
The name can be up to 128 characters long and can be consist of
Expand Down Expand Up @@ -833,8 +883,9 @@ def update_schedule(
project=self.project,
location=self.location,
)

current_schedule = copy.deepcopy(self.get_schedule(schedule_name=schedule_name))
schedule_name = _get_schedule_name(schedule_name)
current_schedule = copy.deepcopy(
self.get_schedule(schedule_name=schedule_name))
update_mask = []
if display_name is not None:
update_mask.append("display_name")
Expand Down Expand Up @@ -911,13 +962,17 @@ def delete_schedule(self, schedule_name: str) -> None:
schedule_name (str):
Required. The resource name of schedule that needs to be deleted.
Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
or
``{schedule}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ScheduleClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
api_client.select_version("v1beta1").delete_schedule(name=schedule_name)
schedule_name = _get_schedule_name(schedule_name)
return api_client.select_version("v1beta1").delete_schedule(
name=schedule_name)

def pause_schedule(self, schedule_name: str) -> None:
"""Pauses an existing Schedule.
Expand All @@ -926,13 +981,17 @@ def pause_schedule(self, schedule_name: str) -> None:
schedule_name (str):
Required. The resource name of schedule that needs to be paused.
Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
or
``{schedule}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ScheduleClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
api_client.select_version("v1beta1").pause_schedule(name=schedule_name)
schedule_name = _get_schedule_name(schedule_name)
return api_client.select_version("v1beta1").pause_schedule(
name=schedule_name)

def resume_schedule(self, schedule_name: str) -> None:
"""Resumes an existing Schedule.
Expand All @@ -941,13 +1000,17 @@ def resume_schedule(self, schedule_name: str) -> None:
schedule_name (str):
Required. The resource name of schedule that needs to be resumed.
Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
or
``{schedule}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ScheduleClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
api_client.select_version("v1beta1").resume_schedule(name=schedule_name)
schedule_name = _get_schedule_name(schedule_name)
return api_client.select_version("v1beta1").resume_schedule(
name=schedule_name)

def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule":
"""Gets an existing Schedule.
Expand All @@ -956,6 +1019,8 @@ def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule":
schedule_name (str):
Required. The resource name of schedule that needs to be fetched.
Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
or
``{schedule}``
Returns:
Schedule: The schedule requested.
Expand All @@ -965,6 +1030,7 @@ def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule":
credentials=self.credentials,
location_override=self.location,
)
schedule_name = _get_schedule_name(schedule_name)
return api_client.select_version("v1beta1").get_schedule(name=schedule_name)

def list_schedules(
Expand Down Expand Up @@ -1375,13 +1441,17 @@ def delete_model_monitoring_job(self, model_monitoring_job_name: str) -> None:
needs to be deleted.
Format:
``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
or
``{model_monitoring_job}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ModelMonitoringClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
api_client.delete_model_monitoring_job(name=model_monitoring_job_name)
job_resource_name = _get_model_monitoring_job_name(
model_monitoring_job_name, self._gca_resource.name)
api_client.delete_model_monitoring_job(name=job_resource_name)

def get_model_monitoring_job(
self, model_monitoring_job_name: str
Expand All @@ -1400,21 +1470,14 @@ def get_model_monitoring_job(
ModelMonitoringJob: The model monitoring job get.
"""
self.wait()
if model_monitoring_job_name.startswith("projects/"):
return ModelMonitoringJob(
model_monitoring_job_name=model_monitoring_job_name,
project=self.project,
location=self.location,
credentials=self.credentials,
)
else:
return ModelMonitoringJob(
model_monitoring_job_name=model_monitoring_job_name,
model_monitor_id=self._gca_resource.name,
project=self.project,
location=self.location,
credentials=self.credentials,
)
job_resource_name = _get_model_monitoring_job_name(
model_monitoring_job_name, self._gca_resource.name)
return ModelMonitoringJob(
model_monitoring_job_name=job_resource_name,
project=self.project,
location=self.location,
credentials=self.credentials,
)

def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None:
"""The method to visualize the feature drift result from a model monitoring job as a histogram chart and a table.
Expand All @@ -1424,17 +1487,24 @@ def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None:
Required. The resource name of model monitoring job to show the
drift stats from.
Format: ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
or
``{model_monitoring_job}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ModelMonitoringClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
job = api_client.get_model_monitoring_job(name=model_monitoring_job_name)
if model_monitoring_job_name.startswith("projects/"):
job_resource_name = model_monitoring_job_name
job_id = model_monitoring_job_name.split("/")[-1]
else:
job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}"
job_id = model_monitoring_job_name
job = api_client.get_model_monitoring_job(name=job_resource_name)
output_directory = (
job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix
)
job_id = model_monitoring_job_name.split("/")[-1]
target_output, baseline_output = _feature_drift_stats_output_path(
output_directory, job_id
)
Expand All @@ -1455,17 +1525,24 @@ def show_output_drift_stats(self, model_monitoring_job_name: str) -> None:
Required. The resource name of model monitoring job to show the
drift stats from.
Format: ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
or
``{model_monitoring_job}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ModelMonitoringClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
job = api_client.get_model_monitoring_job(name=model_monitoring_job_name)
if model_monitoring_job_name.startswith("projects/"):
job_resource_name = model_monitoring_job_name
job_id = model_monitoring_job_name.split("/")[-1]
else:
job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}"
job_id = model_monitoring_job_name
job = api_client.get_model_monitoring_job(name=job_resource_name)
output_directory = (
job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix
)
job_id = model_monitoring_job_name.split("/")[-1]
target_output, baseline_output = _prediction_output_stats_output_path(
output_directory, job_id
)
Expand All @@ -1486,17 +1563,24 @@ def show_feature_attribution_drift_stats(
feature attribution drift stats from.
Format:
``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
or
``{model_monitoring_job}``
"""
api_client = initializer.global_config.create_client(
client_class=utils.ModelMonitoringClientWithOverride,
credentials=self.credentials,
location_override=self.location,
)
job = api_client.get_model_monitoring_job(name=model_monitoring_job_name)
if model_monitoring_job_name.startswith("projects/"):
job_resource_name = model_monitoring_job_name
job_id = model_monitoring_job_name.split("/")[-1]
else:
job_resource_name = f"{self._gca_resource.name}/modelMonitoringJobs/{model_monitoring_job_name}"
job_id = model_monitoring_job_name
job = api_client.get_model_monitoring_job(name=job_resource_name)
output_directory = (
job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix
)
job_id = model_monitoring_job_name.split("/")[-1]
target_stats_output = _feature_attribution_target_stats_output_path(
output_directory, job_id
)
Expand Down

0 comments on commit c03767c

Please sign in to comment.