Skip to content

Commit

Permalink
Merge branch 'main' into feature/forecast-auto-select
Browse files Browse the repository at this point in the history
  • Loading branch information
prasankh committed May 9, 2024
2 parents 8bc6aff + 26db0f1 commit 013e485
Show file tree
Hide file tree
Showing 55 changed files with 2,083 additions and 394 deletions.
59 changes: 59 additions & 0 deletions .github/workflows/run-forecast-unit-tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
name: "Forecast Operator Tests"

on:
workflow_dispatch:
pull_request:
branches: [ "main", "operators/**" ]

# Cancel in progress workflows on pull_requests.
# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-a-fallback-value
concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

permissions:
contents: read

# hack for https://github.com/actions/cache/issues/810#issuecomment-1222550359
env:
SEGMENT_DOWNLOAD_TIMEOUT_MINS: 5

jobs:
test:
name: python ${{ matrix.python-version }}
runs-on: ubuntu-latest
timeout-minutes: 180

strategy:
fail-fast: false
matrix:
python-version: ["3.8"]

steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
ref: ${{ github.event.pull_request.head.sha }}


- uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: "pip"
cache-dependency-path: |
pyproject.toml
"**requirements.txt"
"test-requirements-operators.txt"
- uses: ./.github/workflows/set-dummy-conf
name: "Test config setup"

- name: "Run Forecast Tests"
timeout-minutes: 180
shell: bash
run: |
set -x # print commands that are executed
$CONDA/bin/conda init
source /home/runner/.bashrc
pip install -r test-requirements-operators.txt
python -m pytest -v -p no:warnings --durations=5 tests/operators/forecast
10 changes: 3 additions & 7 deletions .github/workflows/run-operators-unit-tests.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
name: "[Py3.8][Py3.10] Operators Tests"
name: "Operators Tests"

on:
workflow_dispatch:
pull_request:
paths:
- "ads/opctl/operator/**"
- "**requirements.txt"
- ".github/workflows/run-operators*.yml"
- "test-requirements-operators.txt"
branches: [ "main", "operators/**" ]

# Cancel in progress workflows on pull_requests.
# https://docs.github.com/en/actions/using-jobs/using-concurrency#example-using-a-fallback-value
Expand Down Expand Up @@ -56,4 +52,4 @@ jobs:
$CONDA/bin/conda init
source /home/runner/.bashrc
pip install -r test-requirements-operators.txt
python -m pytest -v -p no:warnings --durations=5 tests/operators
python -m pytest -v -p no:warnings --durations=5 tests/operators --ignore=tests/operators/forecast
6 changes: 6 additions & 0 deletions THIRD_PARTY_LICENSES.txt
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,12 @@ python-fire
* Source code: https://github.com/google/python-fire
* Project home: https://github.com/google/python-fire

mlforecast
* Copyright 2024 Nixtla
* License: Apache License 2.0
* Source code: https://github.com/Nixtla/mlforecast
* Project home: https://github.com/Nixtla/mlforecast

=======
=============================== Licenses ===============================
------------------------------------------------------------------------
Expand Down
38 changes: 24 additions & 14 deletions ads/aqua/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,35 @@
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/


import logging
import sys
import os

from ads import logger, set_auth
from ads.aqua.utils import fetch_service_compartment
from ads.config import OCI_RESOURCE_PRINCIPAL_VERSION
from ads import set_auth

logger = logging.getLogger(__name__)
handler = logging.StreamHandler(sys.stdout)
logger.setLevel(logging.INFO)
ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL"


def get_logger_level():
"""Retrieves logging level from environment variable `ADS_AQUA_LOG_LEVEL`."""
level = os.environ.get(ENV_VAR_LOG_LEVEL, "INFO").upper()
return level


logger.setLevel(get_logger_level())


def set_log_level(log_level: str):
"""Global for setting logging level."""

log_level = log_level.upper()
logger.setLevel(log_level.upper())
logger.handlers[0].setLevel(log_level)


if OCI_RESOURCE_PRINCIPAL_VERSION:
set_auth("resource_principal")

ODSC_MODEL_COMPARTMENT_OCID = os.environ.get("ODSC_MODEL_COMPARTMENT_OCID")
if not ODSC_MODEL_COMPARTMENT_OCID:
try:
ODSC_MODEL_COMPARTMENT_OCID = fetch_service_compartment()
except Exception as e:
logger.error(
f"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua, due to {e}."
)
ODSC_MODEL_COMPARTMENT_OCID = (
os.environ.get("ODSC_MODEL_COMPARTMENT_OCID") or fetch_service_compartment()
)
2 changes: 0 additions & 2 deletions ads/aqua/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
get_artifact_path,
is_valid_ocid,
load_config,
logger,
)
from ads.common import oci_client as oc
from ads.common.auth import default_signer
Expand Down Expand Up @@ -164,7 +163,6 @@ def create_model_version_set(
tag = Tags.AQUA_FINE_TUNING.value

if not model_version_set_id:
tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this
try:
model_version_set = ModelVersionSet.from_name(
name=model_version_set_name,
Expand Down
52 changes: 50 additions & 2 deletions ads/aqua/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,65 @@

# Copyright (c) 2024 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
import os
import sys

from ads.aqua import (
ENV_VAR_LOG_LEVEL,
set_log_level,
ODSC_MODEL_COMPARTMENT_OCID,
logger,
)
from ads.aqua.deployment import AquaDeploymentApp
from ads.aqua.evaluation import AquaEvaluationApp
from ads.aqua.finetune import AquaFineTuningApp
from ads.aqua.model import AquaModelApp
from ads.aqua.evaluation import AquaEvaluationApp
from ads.config import NB_SESSION_OCID
from ads.common.utils import LOG_LEVELS


class AquaCommand:
"""Contains the command groups for project Aqua."""
"""Contains the command groups for project Aqua.
Acts as an entry point for managing different components of the Aqua
project including model management, fine-tuning, deployment, and
evaluation.
"""

model = AquaModelApp
fine_tuning = AquaFineTuningApp
deployment = AquaDeploymentApp
evaluation = AquaEvaluationApp

def __init__(
self,
log_level: str = os.environ.get(ENV_VAR_LOG_LEVEL, "ERROR").upper(),
):
"""
Initialize the command line interface settings for the Aqua project.
FLAGS
-----
log_level (str):
Sets the logging level for the application.
Default is retrieved from environment variable `LOG_LEVEL`,
or 'ERROR' if not set. Example values include 'DEBUG', 'INFO',
'WARNING', 'ERROR', and 'CRITICAL'.
"""
if log_level.upper() not in LOG_LEVELS:
logger.error(
f"Log level should be one of {LOG_LEVELS}. Setting default to ERROR."
)
log_level = "ERROR"
set_log_level(log_level)
# gracefully exit if env var is not set
if not ODSC_MODEL_COMPARTMENT_OCID:
logger.debug(
"ODSC_MODEL_COMPARTMENT_OCID environment variable is not set for Aqua."
)
if NB_SESSION_OCID:
logger.error(
f"Aqua is not available for the notebook session {NB_SESSION_OCID}. For more information, "
f"please refer to the documentation."
)
sys.exit(1)
8 changes: 8 additions & 0 deletions ads/aqua/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
RequestException,
ServiceError,
)
from tornado.web import HTTPError

from ads.aqua.exception import AquaError
from ads.aqua.extension.base_handler import AquaAPIhandler
Expand Down Expand Up @@ -58,6 +59,7 @@ def inner_function(self: AquaAPIhandler, *args, **kwargs):
except ServiceError as error:
self.write_error(
status_code=error.status or 500,
message=error.message,
reason=error.message,
service_payload=error.args[0] if error.args else None,
exc_info=sys.exc_info(),
Expand Down Expand Up @@ -91,6 +93,12 @@ def inner_function(self: AquaAPIhandler, *args, **kwargs):
service_payload=error.service_payload,
exc_info=sys.exc_info(),
)
except HTTPError as e:
self.write_error(
status_code=e.status_code,
reason=e.log_message,
exc_info=sys.exc_info(),
)
except Exception as ex:
self.write_error(
status_code=500,
Expand Down
71 changes: 37 additions & 34 deletions ads/aqua/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
UNKNOWN_DICT,
get_resource_name,
get_model_by_reference_paths,
get_ocid_substring,
AQUA_MODEL_TYPE_SERVICE,
AQUA_MODEL_TYPE_CUSTOM,
)
from ads.aqua.finetune import FineTuneCustomMetadata
from ads.aqua.data import AquaResourceIdentifier
Expand Down Expand Up @@ -391,40 +394,27 @@ def create(
.with_runtime(container_runtime)
).deploy(wait_for_completion=False)

if is_fine_tuned_model:
# tracks unique deployments that were created in the user compartment
self.telemetry.record_event_async(
category="aqua/custom/deployment", action="create", detail=model_name
)
# tracks the shape used for deploying the custom models
self.telemetry.record_event_async(
category="aqua/custom/deployment/create",
action="shape",
detail=instance_shape,
)
# tracks the shape used for deploying the custom models by name
self.telemetry.record_event_async(
category=f"aqua/custom/{model_name}/deployment/create",
action="shape",
detail=instance_shape,
)
else:
# tracks unique deployments that were created in the user compartment
self.telemetry.record_event_async(
category="aqua/service/deployment", action="create", detail=model_name
)
# tracks the shape used for deploying the service models
self.telemetry.record_event_async(
category="aqua/service/deployment/create",
action="shape",
detail=instance_shape,
)
# tracks the shape used for deploying the service models by name
self.telemetry.record_event_async(
category=f"aqua/service/{model_name}/deployment/create",
action="shape",
detail=instance_shape,
)
model_type = (
AQUA_MODEL_TYPE_CUSTOM if is_fine_tuned_model else AQUA_MODEL_TYPE_SERVICE
)
deployment_id = deployment.dsc_model_deployment.id
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
telemetry_kwargs = {"ocid": get_ocid_substring(deployment_id, key_len=8)}

# tracks unique deployments that were created in the user compartment
self.telemetry.record_event_async(
category=f"aqua/{model_type}/deployment",
action="create",
detail=model_name,
**telemetry_kwargs,
)
# tracks the shape used for deploying the custom or service models by name
self.telemetry.record_event_async(
category=f"aqua/{model_type}/deployment/create",
action="shape",
detail=instance_shape,
value=model_name,
)

return AquaDeployment.from_oci_model_deployment(
deployment.dsc_model_deployment, self.region
Expand Down Expand Up @@ -471,6 +461,19 @@ def list(self, **kwargs) -> List["AquaDeployment"]:
)
)

# log telemetry if MD is in active or failed state
deployment_id = model_deployment.id
state = model_deployment.lifecycle_state.upper()
if state in ["ACTIVE", "FAILED"]:
# tracks unique deployments that were listed in the user compartment
# we arbitrarily choose last 8 characters of OCID to identify MD in telemetry
self.telemetry.record_event_async(
category=f"aqua/deployment",
action="list",
detail=get_ocid_substring(deployment_id, key_len=8),
value=state,
)

# tracks number of times deployment listing was called
self.telemetry.record_event_async(category="aqua/deployment", action="list")

Expand Down
Loading

0 comments on commit 013e485

Please sign in to comment.