diff --git a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py index 797f8c6f534..1cc7fa011e9 100644 --- a/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py +++ b/components/google-cloud/google_cloud_pipeline_components/container/v1/model/get_model/remote_runner.py @@ -16,6 +16,8 @@ from google.api_core.client_options import ClientOptions from google.cloud import aiplatform_v1 as aip_v1 from google_cloud_pipeline_components.container.utils import artifact_utils +from google_cloud_pipeline_components.container.utils import error_surfacing +from google_cloud_pipeline_components.proto import task_error_pb2 from google_cloud_pipeline_components.types import artifact_types @@ -26,11 +28,15 @@ def get_model( location: str, ) -> None: """Get model.""" + task_error = task_error_pb2.TaskError() if not location or not project: - raise ValueError( + model_name_error_message = ( 'Model resource name must be in the format' ' projects/{project}/locations/{location}/models/{model_name}' ) + task_error.error_message = model_name_error_message + error_surfacing.write_user_defined_error(executor_input, task_error) + raise ValueError(model_name_error_message) api_endpoint = location + '-aiplatform.googleapis.com' vertex_uri_prefix = f'https://{api_endpoint}/v1/' model_resource_name = ( @@ -40,7 +46,12 @@ def get_model( client_options = ClientOptions(api_endpoint=api_endpoint) client = aip_v1.ModelServiceClient(client_options=client_options) request = aip_v1.GetModelRequest(name=model_resource_name) - get_model_response = client.get_model(request) + try: + get_model_response = client.get_model(request) + except Exception as e: + task_error.error_message = str(e) + error_surfacing.write_user_defined_error(executor_input, task_error) + raise resp_model_name_without_version = get_model_response.name.split('@', 1)[0] model_resource_name = ( f'{resp_model_name_without_version}@{get_model_response.version_id}' diff --git a/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py b/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py index baaaec862ea..1dbf8868519 100755 --- a/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py +++ b/components/google-cloud/google_cloud_pipeline_components/proto/task_error_pb2.py @@ -5,7 +5,6 @@ """Generated protocol buffer code.""" from google.protobuf import descriptor as _descriptor from google.protobuf import descriptor_pool as _descriptor_pool -from google.protobuf import runtime_version as _runtime_version from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder # @@protoc_insertion_point(imports)