Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 53 additions & 3 deletions ads/opctl/decorator/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,48 @@
# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Dict, Callable
from functools import wraps
from typing import Callable, Dict, List

from ads.config import (
JOB_RUN_OCID,
NB_SESSION_OCID,
PIPELINE_RUN_OCID,
DATAFLOW_RUN_OCID,
MD_OCID,
)

RUN_ID_FIELD = "run_id"


class OpctlEnvironmentError(Exception):
"""The custom error to validate OPCTL environment."""

NOT_SUPPORTED_ENVIRONMENTS = (
"Notebook Sessions",
"Data Science Jobs",
"ML Pipelines",
"Data Flow Applications",
)

def __init__(self):
super().__init__(
"This operation cannot be executed in the current environment. "
f"It is not supported in: {', '.join(self.NOT_SUPPORTED_ENVIRONMENTS)}."
)


def print_watch_command(func: callable) -> Callable:
"""The decorator to help build the `opctl watch` command."""

@wraps(func)
def wrapper(*args, **kwargs) -> Dict:
def wrapper(*args: List, **kwargs: Dict) -> Dict:
result = func(*args, **kwargs)
if result and isinstance(result, Dict) and RUN_ID_FIELD in result:
msg_header = f"{'*' * 40} To monitor the progress of a task, execute the following command {'*' * 40}"
msg_header = (
f"{'*' * 40} To monitor the progress of the task, "
"execute the following command {'*' * 40}"
)
print(msg_header)
print(f"ads opctl watch {result[RUN_ID_FIELD]}")
print("*" * len(msg_header))
Expand All @@ -26,6 +54,28 @@ def wrapper(*args, **kwargs) -> Dict:
return wrapper


def validate_environment(func: callable) -> Callable:
"""Validates whether an opctl command can be executed in the current environment."""

@wraps(func)
def wrapper(*args: List, **kwargs: Dict) -> Dict:
if any(
value
for value in (
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Checking only the environment variable might not cause problem if user set the value manually in a supported environment. I think, since the root cause of the operation not being supported is that the environment does not support certain dependency, for example, docker. Maybe it is enough to tell users that since docker is not installed, the operation is not supported.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point! User may not have set up dataflow, for example

JOB_RUN_OCID,
NB_SESSION_OCID,
PIPELINE_RUN_OCID,
DATAFLOW_RUN_OCID,
MD_OCID,
)
):
raise OpctlEnvironmentError()

return func(*args, **kwargs)

return wrapper


def click_options(options):
"""The decorator to help group the click options."""

Expand Down
3 changes: 3 additions & 0 deletions ads/opctl/operator/cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
OPERATOR_BASE_IMAGE,
)
from ads.opctl.operator.common.utils import OperatorInfo, _operator_info
from ads.opctl.decorator.common import validate_environment
from ads.opctl.utils import publish_image as publish_image_cmd

from .__init__ import __operators__
Expand Down Expand Up @@ -324,6 +325,7 @@ def init(


@runtime_dependency(module="docker", install_from=OptionalDependency.OPCTL)
@validate_environment
def build_image(
name: str = None,
source_folder: str = None,
Expand Down Expand Up @@ -454,6 +456,7 @@ def build_image(


@runtime_dependency(module="docker", install_from=OptionalDependency.OPCTL)
@validate_environment
def publish_image(
name: str,
registry: str = None,
Expand Down
38 changes: 38 additions & 0 deletions tests/unitary/with_extras/opctl/test_opctl_decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
#!/usr/bin/env python

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import pytest
from ads.opctl.decorator import common
from ads.opctl.decorator.common import validate_environment, OpctlEnvironmentError
from unittest.mock import patch


class TestOpctlDecorators:
"""Tests the all OPCTL common decorators."""

@patch("ads.opctl.decorator.common.NB_SESSION_OCID", None)
@patch("ads.opctl.decorator.common.JOB_RUN_OCID", None)
@patch("ads.opctl.decorator.common.PIPELINE_RUN_OCID", None)
@patch("ads.opctl.decorator.common.DATAFLOW_RUN_OCID", None)
@patch("ads.opctl.decorator.common.MD_OCID", None)
def test_validate_environment_success(self):
"""Tests validating environment decorator."""

@validate_environment
def mock_function():
return "SUCCESS"

assert mock_function() == "SUCCESS"

@patch("ads.opctl.decorator.common.NB_SESSION_OCID", "TEST")
def test_validate_environment_fail(self):
"""Tests validating environment decorator fails."""

@validate_environment
def mock_function():
return "SUCCESS"

with pytest.raises(OpctlEnvironmentError):
assert mock_function()