From c5e755149854e8cf3b59803a0c4204f7a69ca725 Mon Sep 17 00:00:00 2001 From: Dmitrii Cherkasov Date: Fri, 15 Sep 2023 12:13:14 -0700 Subject: [PATCH] Adds opctl environment validator decorator. --- ads/opctl/decorator/common.py | 61 +++++++++++++++++-- ads/opctl/operator/cmd.py | 4 +- .../opctl/test_opctl_decorators.py | 38 ++++++++++++ 3 files changed, 97 insertions(+), 6 deletions(-) create mode 100644 tests/unitary/with_extras/opctl/test_opctl_decorators.py diff --git a/ads/opctl/decorator/common.py b/ads/opctl/decorator/common.py index 98100977c..4d032a2d3 100644 --- a/ads/opctl/decorator/common.py +++ b/ads/opctl/decorator/common.py @@ -4,22 +4,73 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -from typing import Dict, Callable from functools import wraps +from typing import Callable, Dict, List + +from ads.config import ( + JOB_RUN_OCID, + NB_SESSION_OCID, + PIPELINE_RUN_OCID, + DATAFLOW_RUN_OCID, + MD_OCID, +) RUN_ID_FIELD = "run_id" -def print_watch_command(func: callable)->Callable: + +class OpctlEnvironmentError(Exception): + """The custom error to validate OPCTL environment.""" + + NOT_SUPPORTED_ENVIRONMENTS = ( + "Notebook Sessions", + "Data Science Jobs", + "ML Pipelines", + "Data Flow Applications", + ) + + def __init__(self): + super().__init__( + "This operation cannot be executed in the current environment. " + f"It is not supported in: {', '.join(self.NOT_SUPPORTED_ENVIRONMENTS)}." + ) + + +def print_watch_command(func: callable) -> Callable: """The decorator to help build the `opctl watch` command.""" + @wraps(func) - def wrapper(*args, **kwargs)->Dict: + def wrapper(*args: List, **kwargs: Dict) -> Dict: result = func(*args, **kwargs) if result and isinstance(result, Dict) and RUN_ID_FIELD in result: msg_header = ( - f"{'*' * 40} To monitor the progress of a task, execute the following command {'*' * 40}" + f"{'*' * 40} To monitor the progress of the task, " + "execute the following command {'*' * 40}" ) print(msg_header) print(f"ads opctl watch {result[RUN_ID_FIELD]}") print("*" * len(msg_header)) return result - return wrapper \ No newline at end of file + + return wrapper + + +def validate_environment(func: callable) -> Callable: + """Validates whether an opctl command can be executed in the current environment.""" + + @wraps(func) + def wrapper(*args: List, **kwargs: Dict) -> Dict: + if any( + value + for value in ( + JOB_RUN_OCID, + NB_SESSION_OCID, + PIPELINE_RUN_OCID, + DATAFLOW_RUN_OCID, + MD_OCID, + ) + ): + raise OpctlEnvironmentError() + + return func(*args, **kwargs) + + return wrapper diff --git a/ads/opctl/operator/cmd.py b/ads/opctl/operator/cmd.py index d2faf9d48..29388ab6e 100644 --- a/ads/opctl/operator/cmd.py +++ b/ads/opctl/operator/cmd.py @@ -36,6 +36,7 @@ ) from ads.opctl.operator.common.const import PACK_TYPE from ads.opctl.operator.common.utils import OperatorInfo, _operator_info +from ads.opctl.decorator.common import validate_environment from ads.opctl.utils import publish_image as publish_image_cmd from .__init__ import __operators__ @@ -46,7 +47,6 @@ ) from .common.utils import ( _build_image, - _load_yaml_from_uri, _operator_info_list, ) @@ -312,6 +312,7 @@ def init( @runtime_dependency(module="docker", install_from=OptionalDependency.OPCTL) +@validate_environment def build_image( name: str = None, source_folder: str = None, @@ -442,6 +443,7 @@ def build_image( @runtime_dependency(module="docker", install_from=OptionalDependency.OPCTL) +@validate_environment def publish_image( name: str, registry: str = None, diff --git a/tests/unitary/with_extras/opctl/test_opctl_decorators.py b/tests/unitary/with_extras/opctl/test_opctl_decorators.py new file mode 100644 index 000000000..27d93f45b --- /dev/null +++ b/tests/unitary/with_extras/opctl/test_opctl_decorators.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python + +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import pytest +from ads.opctl.decorator import common +from ads.opctl.decorator.common import validate_environment, OpctlEnvironmentError +from unittest.mock import patch + + +class TestOpctlDecorators: + """Tests the all OPCTL common decorators.""" + + @patch("ads.opctl.decorator.common.NB_SESSION_OCID", None) + @patch("ads.opctl.decorator.common.JOB_RUN_OCID", None) + @patch("ads.opctl.decorator.common.PIPELINE_RUN_OCID", None) + @patch("ads.opctl.decorator.common.DATAFLOW_RUN_OCID", None) + @patch("ads.opctl.decorator.common.MD_OCID", None) + def test_validate_environment_success(self): + """Tests validating environment decorator.""" + + @validate_environment + def mock_function(): + return "SUCCESS" + + assert mock_function() == "SUCCESS" + + @patch("ads.opctl.decorator.common.NB_SESSION_OCID", "TEST") + def test_validate_environment_fail(self): + """Tests validating environment decorator fails.""" + + @validate_environment + def mock_function(): + return "SUCCESS" + + with pytest.raises(OpctlEnvironmentError): + assert mock_function()