From c03767ca9ee23bce2f9738a265f5025dc5bce024 Mon Sep 17 00:00:00 2001
From: A Vertex SDK engineer
Date: Tue, 7 May 2024 12:54:33 -0700
Subject: [PATCH] chore: Support both full resource name and resource id in Model Monitoring SDK.

PiperOrigin-RevId: 631518248
---
Review notes (text in this section is not part of the commit message):
* Bare IDs are now expanded with the model monitor's own project/location
  instead of the global config; the global config is only a fallback.
* The helpers no longer create an API client just to parse/format a name.
* IDs are validated with a full match, so a trailing newline is rejected.
* delete/pause/resume_schedule keep returning None, as annotated.
* The show_*_stats methods reuse _get_model_monitoring_job_name.
* The commit and index hashes are those of the original patch.

 .../preview/ml_monitoring/model_monitors.py   | 180 ++++++++++++++++--
 1 file changed, 156 insertions(+), 24 deletions(-)

diff --git a/vertexai/resources/preview/ml_monitoring/model_monitors.py b/vertexai/resources/preview/ml_monitoring/model_monitors.py
index dfe18f65f4..6e26858f08 100644
--- a/vertexai/resources/preview/ml_monitoring/model_monitors.py
+++ b/vertexai/resources/preview/ml_monitoring/model_monitors.py
@@ -18,6 +18,7 @@
 import copy
 import dataclasses
 import json
+import re
 from typing import Any, Dict, List, Optional
 
 from google.auth import credentials as auth_credentials
@@ -237,6 +238,99 @@ def _transform_field_schema(
     return result
 
 
+# A bare resource ID is 1-127 digits; fullmatch also rejects a trailing newline.
+_RESOURCE_ID_PATTERN = re.compile(r"[0-9]{1,127}")
+_SCHEDULE_NAME_PATTERN = re.compile(
+    r"projects/[^/]+/locations/[^/]+/schedules/[^/]+"
+)
+_MODEL_MONITORING_JOB_NAME_PATTERN = re.compile(
+    r"projects/[^/]+/locations/[^/]+/modelMonitors/[^/]+"
+    r"/modelMonitoringJobs/[^/]+"
+)
+
+
+def _get_schedule_name(
+    schedule_name: str,
+    project: Optional[str] = None,
+    location: Optional[str] = None,
+) -> str:
+    """Returns the full resource name for a schedule name or schedule ID.
+
+    Args:
+        schedule_name (str):
+            Required. Either the full resource name
+            ``projects/{project}/locations/{location}/schedules/{schedule}``
+            or the numeric ``{schedule}`` ID.
+        project (str):
+            Optional. Project used to expand a schedule ID. Defaults to the
+            project of the global config.
+        location (str):
+            Optional. Location used to expand a schedule ID. Defaults to the
+            location of the global config.
+
+    Returns:
+        str: The full schedule resource name. An empty ``schedule_name`` is
+        returned unchanged.
+
+    Raises:
+        ValueError: If ``schedule_name`` is neither a full resource name nor
+            a numeric ID.
+    """
+    if not schedule_name:
+        return schedule_name
+    if _SCHEDULE_NAME_PATTERN.fullmatch(schedule_name):
+        return schedule_name
+    if not _RESOURCE_ID_PATTERN.fullmatch(schedule_name):
+        raise ValueError(
+            "schedule name must be of the format `projects/{project}/locations/{location}/schedules/{schedule}` or `{schedule}`"
+        )
+    project = project or initializer.global_config.project
+    location = location or initializer.global_config.location
+    return f"projects/{project}/locations/{location}/schedules/{schedule_name}"
+
+
+def _get_model_monitoring_job_name(
+    model_monitoring_job_name: str,
+    model_monitor_name: str,
+) -> str:
+    """Returns the full resource name for a model monitoring job name or ID.
+
+    Args:
+        model_monitoring_job_name (str):
+            Required. Either the full resource name
+            ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
+            or the numeric ``{model_monitoring_job}`` ID.
+        model_monitor_name (str):
+            Required. Resource name (or ID) of the parent model monitor, used
+            to expand a job ID.
+
+    Returns:
+        str: The full model monitoring job resource name. An empty
+        ``model_monitoring_job_name`` is returned unchanged.
+
+    Raises:
+        ValueError: If ``model_monitoring_job_name`` is neither a full
+            resource name nor a numeric ID.
+    """
+    if not model_monitoring_job_name:
+        return model_monitoring_job_name
+    if _MODEL_MONITORING_JOB_NAME_PATTERN.fullmatch(model_monitoring_job_name):
+        return model_monitoring_job_name
+    if not _RESOURCE_ID_PATTERN.fullmatch(model_monitoring_job_name):
+        raise ValueError(
+            "model monitoring job name must be of the format `projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}` or `{model_monitoring_job}`"
+        )
+    if not model_monitor_name.startswith("projects/"):
+        # Only a monitor ID is known, so fall back to the global config.
+        model_monitor_name = (
+            f"projects/{initializer.global_config.project}"
+            f"/locations/{initializer.global_config.location}"
+            f"/modelMonitors/{model_monitor_name.split('/')[-1]}"
+        )
+    # Nest under the parent's own project/location rather than the global config.
+    return f"{model_monitor_name}/modelMonitoringJobs/{model_monitoring_job_name}"
+
+
 @dataclasses.dataclass
 class MetricsSearchResponse:
     """MetricsSearchResponse represents a response of the search metrics request.
@@ -784,6 +878,8 @@ def update_schedule(
             schedule_name (str):
                 Required. The resource name of schedule that needs to be updated.
                 Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
+                or
+                ``{schedule}``
             display_name (str):
                 Optional. The user-defined name of the Schedule.
                 The name can be up to 128 characters long and can be consist of
@@ -833,8 +929,10 @@ def update_schedule(
             project=self.project,
             location=self.location,
         )
-
-        current_schedule = copy.deepcopy(self.get_schedule(schedule_name=schedule_name))
+        schedule_name = _get_schedule_name(
+            schedule_name, project=self.project, location=self.location
+        )
+        current_schedule = copy.deepcopy(self.get_schedule(schedule_name=schedule_name))
         update_mask = []
         if display_name is not None:
             update_mask.append("display_name")
@@ -911,13 +1009,18 @@ def delete_schedule(self, schedule_name: str) -> None:
             schedule_name (str):
                 Required. The resource name of schedule that needs to be deleted.
                 Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
+                or
+                ``{schedule}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ScheduleClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
+        schedule_name = _get_schedule_name(
+            schedule_name, project=self.project, location=self.location
+        )
         api_client.select_version("v1beta1").delete_schedule(name=schedule_name)
 
     def pause_schedule(self, schedule_name: str) -> None:
         """Pauses an existing Schedule.
@@ -926,13 +1029,18 @@ def pause_schedule(self, schedule_name: str) -> None:
             schedule_name (str):
                 Required. The resource name of schedule that needs to be paused.
                 Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
+                or
+                ``{schedule}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ScheduleClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
+        schedule_name = _get_schedule_name(
+            schedule_name, project=self.project, location=self.location
+        )
         api_client.select_version("v1beta1").pause_schedule(name=schedule_name)
 
     def resume_schedule(self, schedule_name: str) -> None:
         """Resumes an existing Schedule.
@@ -941,13 +1049,18 @@ def resume_schedule(self, schedule_name: str) -> None:
             schedule_name (str):
                 Required. The resource name of schedule that needs to be resumed.
                 Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
+                or
+                ``{schedule}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ScheduleClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
+        schedule_name = _get_schedule_name(
+            schedule_name, project=self.project, location=self.location
+        )
         api_client.select_version("v1beta1").resume_schedule(name=schedule_name)
 
     def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule":
         """Gets an existing Schedule.
@@ -956,6 +1069,8 @@ def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule":
             schedule_name (str):
                 Required. The resource name of schedule that needs to be fetched.
                 Format: ``projects/{project}/locations/{location}/schedules/{schedule}``
+                or
+                ``{schedule}``
 
         Returns:
             Schedule: The schedule requested.
@@ -965,6 +1080,9 @@ def get_schedule(self, schedule_name: str) -> "gca_schedule.Schedule":
             credentials=self.credentials,
             location_override=self.location,
         )
+        schedule_name = _get_schedule_name(
+            schedule_name, project=self.project, location=self.location
+        )
         return api_client.select_version("v1beta1").get_schedule(name=schedule_name)
 
     def list_schedules(
@@ -1375,13 +1493,18 @@ def delete_model_monitoring_job(self, model_monitoring_job_name: str) -> None:
                 needs to be deleted.
                 Format:
                 ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
+                or
+                ``{model_monitoring_job}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ModelMonitoringClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
-        api_client.delete_model_monitoring_job(name=model_monitoring_job_name)
+        job_resource_name = _get_model_monitoring_job_name(
+            model_monitoring_job_name, self._gca_resource.name
+        )
+        api_client.delete_model_monitoring_job(name=job_resource_name)
 
     def get_model_monitoring_job(
         self, model_monitoring_job_name: str
@@ -1400,21 +1523,15 @@ def get_model_monitoring_job(
             ModelMonitoringJob: The model monitoring job get.
         """
         self.wait()
-        if model_monitoring_job_name.startswith("projects/"):
-            return ModelMonitoringJob(
-                model_monitoring_job_name=model_monitoring_job_name,
-                project=self.project,
-                location=self.location,
-                credentials=self.credentials,
-            )
-        else:
-            return ModelMonitoringJob(
-                model_monitoring_job_name=model_monitoring_job_name,
-                model_monitor_id=self._gca_resource.name,
-                project=self.project,
-                location=self.location,
-                credentials=self.credentials,
-            )
+        job_resource_name = _get_model_monitoring_job_name(
+            model_monitoring_job_name, self._gca_resource.name
+        )
+        return ModelMonitoringJob(
+            model_monitoring_job_name=job_resource_name,
+            project=self.project,
+            location=self.location,
+            credentials=self.credentials,
+        )
 
     def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None:
         """The method to visualize the feature drift result from a model monitoring job as a histogram chart and a table.
@@ -1424,17 +1541,22 @@ def show_feature_drift_stats(self, model_monitoring_job_name: str) -> None:
                 Required. The resource name of model monitoring job to show the drift stats from.
                 Format:
                 ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
+                or
+                ``{model_monitoring_job}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ModelMonitoringClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
-        job = api_client.get_model_monitoring_job(name=model_monitoring_job_name)
+        job_resource_name = _get_model_monitoring_job_name(
+            model_monitoring_job_name, self._gca_resource.name
+        )
+        job = api_client.get_model_monitoring_job(name=job_resource_name)
         output_directory = (
             job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix
         )
-        job_id = model_monitoring_job_name.split("/")[-1]
+        job_id = job_resource_name.split("/")[-1]
         target_output, baseline_output = _feature_drift_stats_output_path(
             output_directory, job_id
         )
@@ -1455,17 +1577,22 @@ def show_output_drift_stats(self, model_monitoring_job_name: str) -> None:
                 Required. The resource name of model monitoring job to show the drift stats from.
                 Format:
                 ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
+                or
+                ``{model_monitoring_job}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ModelMonitoringClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
-        job = api_client.get_model_monitoring_job(name=model_monitoring_job_name)
+        job_resource_name = _get_model_monitoring_job_name(
+            model_monitoring_job_name, self._gca_resource.name
+        )
+        job = api_client.get_model_monitoring_job(name=job_resource_name)
         output_directory = (
             job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix
         )
-        job_id = model_monitoring_job_name.split("/")[-1]
+        job_id = job_resource_name.split("/")[-1]
         target_output, baseline_output = _prediction_output_stats_output_path(
             output_directory, job_id
         )
@@ -1486,17 +1613,22 @@ def show_feature_attribution_drift_stats(
                 feature attribution drift stats from.
                 Format:
                 ``projects/{project}/locations/{location}/modelMonitors/{model_monitor}/modelMonitoringJobs/{model_monitoring_job}``
+                or
+                ``{model_monitoring_job}``
         """
         api_client = initializer.global_config.create_client(
             client_class=utils.ModelMonitoringClientWithOverride,
             credentials=self.credentials,
             location_override=self.location,
         )
-        job = api_client.get_model_monitoring_job(name=model_monitoring_job_name)
+        job_resource_name = _get_model_monitoring_job_name(
+            model_monitoring_job_name, self._gca_resource.name
+        )
+        job = api_client.get_model_monitoring_job(name=job_resource_name)
         output_directory = (
             job.model_monitoring_spec.output_spec.gcs_base_directory.output_uri_prefix
         )
-        job_id = model_monitoring_job_name.split("/")[-1]
+        job_id = job_resource_name.split("/")[-1]
         target_stats_output = _feature_attribution_target_stats_output_path(
             output_directory, job_id
         )