Skip to content

Commit

Permalink
feat(components): Use GetModel integration test to manually test writ…
Browse files Browse the repository at this point in the history
…e_user_defined_error function

Signed-off-by: Googler <nobody@google.com>
PiperOrigin-RevId: 635979715
  • Loading branch information
Googler committed May 23, 2024
1 parent c18ec0b commit e75e6da
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
from google.api_core.client_options import ClientOptions
from google.cloud import aiplatform_v1 as aip_v1
from google_cloud_pipeline_components.container.utils import artifact_utils
from google_cloud_pipeline_components.container.utils import error_surfacing
from google_cloud_pipeline_components.proto import task_error_pb2
from google_cloud_pipeline_components.types import artifact_types


Expand All @@ -26,11 +28,15 @@ def get_model(
location: str,
) -> None:
"""Get model."""
task_error = task_error_pb2.TaskError()
if not location or not project:
raise ValueError(
model_name_error_message = (
'Model resource name must be in the format'
' projects/{project}/locations/{location}/models/{model_name}'
)
task_error.error_message = model_name_error_message
error_surfacing.write_user_defined_error(executor_input, task_error)
raise ValueError(model_name_error_message)
api_endpoint = location + '-aiplatform.googleapis.com'
vertex_uri_prefix = f'https://{api_endpoint}/v1/'
model_resource_name = (
Expand All @@ -40,7 +46,12 @@ def get_model(
client_options = ClientOptions(api_endpoint=api_endpoint)
client = aip_v1.ModelServiceClient(client_options=client_options)
request = aip_v1.GetModelRequest(name=model_resource_name)
get_model_response = client.get_model(request)
try:
get_model_response = client.get_model(request)
except Exception as e:
task_error.error_message = str(e)
error_surfacing.write_user_defined_error(executor_input, task_error)
raise
resp_model_name_without_version = get_model_response.name.split('@', 1)[0]
model_resource_name = (
f'{resp_model_name_without_version}@{get_model_response.version_id}'
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e75e6da

Please sign in to comment.