Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(components): Use GetModel integration test to manually test write_user_defined_error function #10843

Merged
merged 1 commit into from
May 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
from google_cloud_pipeline_components.proto import task_error_pb2


def write_user_defined_error(
def write_customized_error(
executor_input: str, error: task_error_pb2.TaskError
):
"""Writes a TaskError to a JSON file ('executor_error.json') in the output directory specified in the executor input.
"""Writes a TaskError customized by the author of the pipelines to a JSON file ('executor_error.json') in the output directory specified in the executor input.

Args:
executor_input: JSON string containing executor input data.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,50 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Remote runner for Get Model based on the Vertex AI SDK."""

import contextlib
from typing import Tuple, Type, Union
from google.api_core.client_options import ClientOptions
from google.cloud import aiplatform_v1 as aip_v1
from google_cloud_pipeline_components.container.utils import artifact_utils
from google_cloud_pipeline_components.container.utils import error_surfacing
from google_cloud_pipeline_components.proto import task_error_pb2
from google_cloud_pipeline_components.types import artifact_types


@contextlib.contextmanager
def catch_write_and_raise(
executor_input: str,
exception_types: Union[
Type[Exception], Tuple[Type[Exception], ...]
] = Exception,
):
"""Context manager to catch specified exceptions, log them using error_surfacing, and then re-raise."""
try:
yield
except exception_types as e:
task_error = task_error_pb2.TaskError()
task_error.error_message = str(e)
error_surfacing.write_customized_error(executor_input, task_error)
raise


def get_model(
executor_input,
model_name: str,
project: str,
location: str,
) -> None:
"""Get model."""
if not location or not project:
raise ValueError(
'Model resource name must be in the format'
' projects/{project}/locations/{location}/models/{model_name}'
)
with catch_write_and_raise(
executor_input=executor_input,
exception_types=ValueError,
):
if not location or not project:
model_name_error_message = (
'Model resource name must be in the format'
' projects/{project}/locations/{location}/models/{model_name}'
)
raise ValueError(model_name_error_message)
api_endpoint = location + '-aiplatform.googleapis.com'
vertex_uri_prefix = f'https://{api_endpoint}/v1/'
model_resource_name = (
Expand All @@ -40,7 +65,11 @@ def get_model(
client_options = ClientOptions(api_endpoint=api_endpoint)
client = aip_v1.ModelServiceClient(client_options=client_options)
request = aip_v1.GetModelRequest(name=model_resource_name)
get_model_response = client.get_model(request)
with catch_write_and_raise(
executor_input=executor_input,
exception_types=Exception,
):
get_model_response = client.get_model(request)
resp_model_name_without_version = get_model_response.name.split('@', 1)[0]
model_resource_name = (
f'{resp_model_name_without_version}@{get_model_response.version_id}'
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading