diff --git a/ads/aqua/__init__.py b/ads/aqua/__init__.py index 45f1a8b03..ea1112623 100644 --- a/ads/aqua/__init__.py +++ b/ads/aqua/__init__.py @@ -5,15 +5,50 @@ import logging -import sys import os +import sys + +from ads import set_auth from ads.aqua.utils import fetch_service_compartment from ads.config import NB_SESSION_OCID, OCI_RESOURCE_PRINCIPAL_VERSION -from ads import set_auth -logger = logging.getLogger(__name__) -handler = logging.StreamHandler(sys.stdout) -logger.setLevel(logging.INFO) +ENV_VAR_LOG_LEVEL = "ADS_AQUA_LOG_LEVEL" + + +def get_logger_level(): + """Retrieves logging level from environment variable `LOG_LEVEL`.""" + level = os.environ.get(ENV_VAR_LOG_LEVEL, "INFO").upper() + return level + + +def configure_aqua_logger(): + """Configures the AQUA logger.""" + log_level = get_logger_level() + logger = logging.getLogger(__name__) + logger.setLevel(log_level) + + handler = logging.StreamHandler(sys.stdout) + formatter = logging.Formatter( + "%(asctime)s - %(name)s.%(module)s - %(levelname)s - %(message)s" + ) + handler.setFormatter(formatter) + handler.setLevel(log_level) + + logger.addHandler(handler) + logger.propagate = False + return logger + + +logger = configure_aqua_logger() + + +def set_log_level(log_level: str): + """Global for setting logging level.""" + + log_level = log_level.upper() + logger.setLevel(log_level.upper()) + logger.handlers[0].setLevel(log_level) + if OCI_RESOURCE_PRINCIPAL_VERSION: set_auth("resource_principal") diff --git a/ads/aqua/base.py b/ads/aqua/base.py index 870001cb6..43796ebb0 100644 --- a/ads/aqua/base.py +++ b/ads/aqua/base.py @@ -19,7 +19,6 @@ get_artifact_path, is_valid_ocid, load_config, - logger, ) from ads.common import oci_client as oc from ads.common.auth import default_signer @@ -164,7 +163,7 @@ def create_model_version_set( tag = Tags.AQUA_FINE_TUNING.value if not model_version_set_id: - tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this + tag = Tags.AQUA_FINE_TUNING.value # TODO: Fix this try: model_version_set = ModelVersionSet.from_name( name=model_version_set_name, diff --git a/ads/aqua/cli.py b/ads/aqua/cli.py index 55dff69bf..3c31129f7 100644 --- a/ads/aqua/cli.py +++ b/ads/aqua/cli.py @@ -3,17 +3,41 @@ # Copyright (c) 2024 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import os +from ads.aqua import ENV_VAR_LOG_LEVEL, set_log_level from ads.aqua.deployment import AquaDeploymentApp +from ads.aqua.evaluation import AquaEvaluationApp from ads.aqua.finetune import AquaFineTuningApp from ads.aqua.model import AquaModelApp -from ads.aqua.evaluation import AquaEvaluationApp class AquaCommand: - """Contains the command groups for project Aqua.""" + """Contains the command groups for project Aqua. + + Acts as an entry point for managing different components of the Aqua + project including model management, fine-tuning, deployment, and + evaluation. + """ model = AquaModelApp fine_tuning = AquaFineTuningApp deployment = AquaDeploymentApp evaluation = AquaEvaluationApp + + def __init__( + self, + log_level: str = os.environ.get(ENV_VAR_LOG_LEVEL, "ERROR").upper(), + ): + """ + Initialize the command line interface settings for the Aqua project. + + FLAGS + ----- + log_level (str): + Sets the logging level for the application. + Default is retrieved from environment variable `LOG_LEVEL`, + or 'ERROR' if not set. Example values include 'DEBUG', 'INFO', + 'WARNING', 'ERROR', and 'CRITICAL'. + """ + set_log_level(log_level) diff --git a/ads/aqua/evaluation.py b/ads/aqua/evaluation.py index dfce6e3ee..896f2f632 100644 --- a/ads/aqua/evaluation.py +++ b/ads/aqua/evaluation.py @@ -975,7 +975,7 @@ def list( self._process_evaluation_summary(model=model, jobrun=jobrun) ) except Exception as exc: - logger.error( + logger.debug( f"Processing evaluation: {model.identifier} generated an exception: {exc}" ) evaluations.append( @@ -1020,7 +1020,7 @@ def _if_eval_artifact_exist( return True if response.status == 200 else False except oci.exceptions.ServiceError as ex: if ex.status == 404: - logger.info("Evaluation artifact not found.") + logger.debug(f"Evaluation artifact not found for {model.identifier}.") return False @telemetry(entry_point="plugin=evaluation&action=get_status", name="aqua") @@ -1566,8 +1566,9 @@ def _build_resource_identifier( ), ) except Exception as e: - logger.error( - f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}" + logger.debug( + f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`. " + f"DEBUG INFO: {str(e)}" ) return AquaResourceIdentifier() @@ -1613,7 +1614,7 @@ def _fetch_runtime_params( ) if not params.get(EvaluationConfig.PARAMS): raise AquaMissingKeyError( - "model parameters have not been saved in correct format in model taxonomy.", + "model parameters have not been saved in correct format in model taxonomy. ", service_payload={"params": params}, ) # TODO: validate the format of parameters. @@ -1645,7 +1646,7 @@ def _build_job_identifier( except Exception as e: logger.debug( - f"Failed to get job details from job_run_details: {job_run_details}" + f"Failed to get job details from job_run_details: {job_run_details} " f"DEBUG INFO:{str(e)}" ) return AquaResourceIdentifier() diff --git a/ads/aqua/finetune.py b/ads/aqua/finetune.py index 2f3811ec1..24dfe5f9a 100644 --- a/ads/aqua/finetune.py +++ b/ads/aqua/finetune.py @@ -15,7 +15,7 @@ UpdateModelProvenanceDetails, ) -from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID +from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger from ads.aqua.base import AquaApp from ads.aqua.data import AquaResourceIdentifier, Resource, Tags from ads.aqua.exception import AquaFileExistsError, AquaValueError @@ -29,7 +29,6 @@ UNKNOWN, UNKNOWN_DICT, get_container_image, - logger, upload_local_to_os, ) from ads.common.auth import default_signer diff --git a/ads/aqua/model.py b/ads/aqua/model.py index 1a5042e18..982553f50 100644 --- a/ads/aqua/model.py +++ b/ads/aqua/model.py @@ -14,7 +14,7 @@ from cachetools import TTLCache from oci.data_science.models import JobRun, Model -from ads.aqua import logger, utils +from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID, logger, utils from ads.aqua.base import AquaApp from ads.aqua.constants import ( TRAINING_METRICS_FINAL, @@ -26,7 +26,6 @@ ) from ads.aqua.data import AquaResourceIdentifier, Tags from ads.aqua.exception import AquaRuntimeError - from ads.aqua.training.exceptions import exit_code_dict from ads.aqua.utils import ( LICENSE_TXT, @@ -50,7 +49,6 @@ PROJECT_OCID, TENANCY_OCID, ) -from ads.aqua import ODSC_MODEL_COMPARTMENT_OCID from ads.model import DataScienceModel from ads.model.model_metadata import MetadataTaxonomyKeys, ModelCustomMetadata from ads.telemetry import telemetry @@ -228,7 +226,7 @@ def __post_init__( ).value except Exception as e: logger.debug( - f"Failed to extract model hyperparameters from {model.id}:" f"{str(e)}" + f"Failed to extract model hyperparameters from {model.id}: " f"{str(e)}" ) model_hyperparameters = {} diff --git a/ads/aqua/utils.py b/ads/aqua/utils.py index d3369ba33..d543ce89f 100644 --- a/ads/aqua/utils.py +++ b/ads/aqua/utils.py @@ -10,7 +10,6 @@ import os import random import re -import sys from enum import Enum from functools import wraps from pathlib import Path @@ -22,22 +21,16 @@ from oci.data_science.models import JobRun, Model from ads.aqua.constants import RqsAdditionalDetails -from ads.aqua.data import AquaResourceIdentifier, Tags +from ads.aqua.data import AquaResourceIdentifier from ads.aqua.exception import AquaFileNotFoundError, AquaRuntimeError, AquaValueError from ads.common.auth import default_signer from ads.common.object_storage_details import ObjectStorageDetails from ads.common.oci_resource import SEARCH_TYPE, OCIResource from ads.common.utils import get_console_link, upload_to_os -from ads.config import ( - AQUA_SERVICE_MODELS_BUCKET, - CONDA_BUCKET_NS, - TENANCY_OCID, -) +from ads.config import AQUA_SERVICE_MODELS_BUCKET, CONDA_BUCKET_NS, TENANCY_OCID from ads.model import DataScienceModel, ModelVersionSet -# TODO: allow the user to setup the logging level? -logging.basicConfig(stream=sys.stdout, level=logging.INFO) -logger = logging.getLogger("ODSC_AQUA") +logger = logging.getLogger("ads.aqua") UNKNOWN = "" UNKNOWN_DICT = {} @@ -145,10 +138,6 @@ def get_status(evaluation_status: str, job_run_status: str = None): MODEL_BY_REFERENCE_OSS_PATH_KEY = "artifact_location" -def get_logger(): - return logger - - def random_color_generator(word: str): seed = sum([ord(c) for c in word]) % 13 random.seed(seed) @@ -235,7 +224,7 @@ def read_file(file_path: str, **kwargs) -> str: with fsspec.open(file_path, "r", **kwargs.get("auth", {})) as f: return f.read() except Exception as e: - logger.error(f"Failed to read file {file_path}. {e}") + logger.debug(f"Failed to read file {file_path}. {e}") return UNKNOWN @@ -485,7 +474,7 @@ def _build_resource_identifier( ), ) except Exception as e: - logger.error( + logger.debug( f"Failed to construct AquaResourceIdentifier from given id=`{id}`, and name=`{name}`, {str(e)}" ) return AquaResourceIdentifier() diff --git a/tests/unitary/with_extras/aqua/test_cli.py b/tests/unitary/with_extras/aqua/test_cli.py new file mode 100644 index 000000000..3ae0764bc --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_cli.py @@ -0,0 +1,47 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- + +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import logging +import subprocess +from unittest import TestCase +from unittest.mock import patch + +from parameterized import parameterized + +from ads.aqua.cli import AquaCommand + + +class TestAquaCLI(TestCase): + """Tests the AQUA CLI.""" + + DEFAUL_AQUA_CLI_LOGGING_LEVEL = "ERROR" + logger = logging.getLogger(__name__) + logging.basicConfig( + format="%(asctime)s %(module)s %(levelname)s: %(message)s", + datefmt="%m/%d/%Y %I:%M:%S %p", + level=logging.INFO, + ) + + def test_entrypoint(self): + """Tests CLI entrypoint.""" + result = subprocess.run(["ads", "aqua", "--help"], capture_output=True) + self.logger.info(f"{self._testMethodName}\n" + result.stderr.decode("utf-8")) + assert result.returncode == 0 + + @parameterized.expand( + [ + ("default", None, DEFAUL_AQUA_CLI_LOGGING_LEVEL), + ("set logging level", "info", "info"), + ] + ) + @patch("ads.aqua.cli.set_log_level") + def test_aquacommand(self, name, arg, expected, mock_setting_log): + """Tests aqua command initailzation.""" + if arg: + AquaCommand(arg) + else: + AquaCommand() + mock_setting_log.assert_called_with(expected) diff --git a/tests/unitary/with_extras/aqua/test_global.py b/tests/unitary/with_extras/aqua/test_global.py new file mode 100644 index 000000000..6ea41aee6 --- /dev/null +++ b/tests/unitary/with_extras/aqua/test_global.py @@ -0,0 +1,45 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*-- +# Copyright (c) 2024 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +import unittest +from unittest.mock import MagicMock, patch + +from ads.aqua import configure_aqua_logger, get_logger_level, set_log_level + + +class TestAquaLogging(unittest.TestCase): + DEFAULT_AQUA_LOG_LEVEL = "INFO" + + @patch.dict("os.environ", {}) + def test_get_logger_level_default(self): + """Test default log level when environment variable is not set.""" + self.assertEqual(get_logger_level(), self.DEFAULT_AQUA_LOG_LEVEL) + + @patch.dict("os.environ", {"ADS_AQUA_LOG_LEVEL": "DEBUG"}) + def test_get_logger_level_from_env(self): + """Test log level is correctly read from environment variable.""" + self.assertEqual(get_logger_level(), "DEBUG") + + @patch("logging.getLogger") + @patch("logging.StreamHandler") + def test_configure_aqua_logger(self, mock_handler, mock_get_logger): + """Test that logger is correctly configured.""" + mock_logger = MagicMock() + mock_get_logger.return_value = mock_logger + + logger = configure_aqua_logger() + + mock_get_logger.assert_called_once_with("ads.aqua") + mock_logger.setLevel.assert_called_with(self.DEFAULT_AQUA_LOG_LEVEL) + + @patch("ads.aqua.logger", create=True) + def test_set_log_level(self, mock_logger): + """Test that the log level of the logger is set correctly.""" + mock_handler = MagicMock() + mock_logger.handlers = [mock_handler] + + set_log_level("warning") + + mock_logger.setLevel.assert_called_with("WARNING")