Skip to content

Commit

Permalink
feat(components): Use GetModel integration test to manually test writ…
Browse files Browse the repository at this point in the history
…e_user_defined_error function

Signed-off-by: Googler <nobody@google.com>
PiperOrigin-RevId: 635979715
  • Loading branch information
Googler committed May 29, 2024
1 parent b4f91a3 commit 0ee006e
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from google_cloud_pipeline_components.proto import task_error_pb2


def write_user_defined_error(
def write_customized_error(
executor_input: str, error: task_error_pb2.TaskError
):
"""Writes a TaskError to a JSON file ('executor_error.json') in the output directory specified in the executor input.
"""Writes a TaskError customized by the author of the pipelines to a JSON file ('executor_error.json') in the output directory specified in the executor input.
Args:
executor_input: JSON string containing executor input data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,51 @@
# limitations under the License.
"""Remote runner for Get Model based on the Vertex AI SDK."""

import contextlib
from typing import Tuple, Type, Union

from google.api_core.client_options import ClientOptions
from google.cloud import aiplatform_v1 as aip_v1
from google_cloud_pipeline_components.container.utils import artifact_utils
from google_cloud_pipeline_components.container.utils import error_surfacing
from google_cloud_pipeline_components.proto import task_error_pb2
from google_cloud_pipeline_components.types import artifact_types


@contextlib.contextmanager
def catch_write_and_raise(
executor_input: str,
exception_types: Union[
Type[Exception], Tuple[Type[Exception], ...]
] = Exception,
) -> None:
"""Context manager to catch specified exceptions, log them using error_surfacing, and then re-raise."""
try:
yield
except exception_types as e:
task_error = task_error_pb2.TaskError()
task_error.error_message = str(e)
error_surfacing.write_customized_error(executor_input, task_error)
raise


def get_model(
executor_input,
model_name: str,
project: str,
location: str,
) -> None:
"""Get model."""
if not location or not project:
raise ValueError(
'Model resource name must be in the format'
' projects/{project}/locations/{location}/models/{model_name}'
)
with catch_write_and_raise(
exception_types=ValueError,
executor_input=executor_input,
):
if not location or not project:
model_name_error_message = (
'Model resource name must be in the format'
' projects/{project}/locations/{location}/models/{model_name}'
)
raise ValueError(model_name_error_message)
api_endpoint = location + '-aiplatform.googleapis.com'
vertex_uri_prefix = f'https://{api_endpoint}/v1/'
model_resource_name = (
Expand All @@ -40,7 +67,11 @@ def get_model(
client_options = ClientOptions(api_endpoint=api_endpoint)
client = aip_v1.ModelServiceClient(client_options=client_options)
request = aip_v1.GetModelRequest(name=model_resource_name)
get_model_response = client.get_model(request)
with catch_write_and_raise(
exception_types=Exception,
executor_input=executor_input,
):
get_model_response = client.get_model(request)
resp_model_name_without_version = get_model_response.name.split('@', 1)[0]
model_resource_name = (
f'{resp_model_name_without_version}@{get_model_response.version_id}'
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 0ee006e

Please sign in to comment.