diff --git a/ads/aqua/app.py b/ads/aqua/app.py index f94b2b29b..25c055b26 100644 --- a/ads/aqua/app.py +++ b/ads/aqua/app.py @@ -6,13 +6,14 @@ import os import traceback from dataclasses import fields -from typing import Dict, Optional, Union +from typing import Any, Dict, Optional, Union import oci from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails from ads import set_auth from ads.aqua import logger +from ads.aqua.common.entities import ModelConfigResult from ads.aqua.common.enums import ConfigFolder, Tags from ads.aqua.common.errors import AquaRuntimeError, AquaValueError from ads.aqua.common.utils import ( @@ -273,24 +274,24 @@ def get_config( model_id: str, config_file_name: str, config_folder: Optional[str] = ConfigFolder.CONFIG, - ) -> Dict: - """Gets the config for the given Aqua model. + ) -> ModelConfigResult: + """ + Gets the configuration for the given Aqua model along with the model details. Parameters ---------- - model_id: str + model_id : str The OCID of the Aqua model. - config_file_name: str - name of the config file - config_folder: (str, optional): - subfolder path where config_file_name needs to be searched - Defaults to `ConfigFolder.CONFIG`. - When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT` + config_file_name : str + The name of the configuration file. + config_folder : Optional[str] + The subfolder path where config_file_name is searched. + Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT. Returns ------- - Dict: - A dict of allowed configs. + ModelConfigResult + A Pydantic model containing the model_details (extracted from OCI) and the config dictionary. """ config_folder = config_folder or ConfigFolder.CONFIG oci_model = self.ds_client.get_model(model_id).data @@ -302,11 +303,11 @@ def get_config( if oci_model.freeform_tags else False ) - if not oci_aqua: - raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.") + raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.") + + config: Dict[str, Any] = {} - config = {} # if the current model has a service model tag, then if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags: base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG] @@ -326,7 +327,7 @@ def get_config( logger.debug( f"Failed to get artifact path from custom metadata for the model: {model_id}" ) - return config + return ModelConfigResult(config=config, model_details=oci_model) config_path = os.path.join(os.path.dirname(artifact_path), config_folder) if not is_path_exists(config_path): @@ -351,9 +352,8 @@ def get_config( f"{config_file_name} is not available for the model: {model_id}. " f"Check if the custom metadata has the artifact path set." ) - return config - return config + return ModelConfigResult(config=config, model_details=oci_model) @property def telemetry(self): @@ -375,9 +375,11 @@ def build_cli(self) -> str: """ cmd = f"ads aqua {self._command}" params = [ - f"--{field.name} {json.dumps(getattr(self, field.name))}" - if isinstance(getattr(self, field.name), dict) - else f"--{field.name} {getattr(self, field.name)}" + ( + f"--{field.name} {json.dumps(getattr(self, field.name))}" + if isinstance(getattr(self, field.name), dict) + else f"--{field.name} {getattr(self, field.name)}" + ) for field in fields(self.__class__) if getattr(self, field.name) is not None ] diff --git a/ads/aqua/common/entities.py b/ads/aqua/common/entities.py index 3528e9160..13ebf294a 100644 --- a/ads/aqua/common/entities.py +++ b/ads/aqua/common/entities.py @@ -1,7 +1,12 @@ #!/usr/bin/env python -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +from typing import Any, Dict, Optional + +from oci.data_science.models import Model +from pydantic import BaseModel, Field + class ContainerSpec: """ @@ -15,3 +20,25 @@ class ContainerSpec: ENV_VARS = "envVars" RESTRICTED_PARAMS = "restrictedParams" EVALUATION_CONFIGURATION = "evaluationConfiguration" + + +class ModelConfigResult(BaseModel): + """ + Represents the result of getting the AQUA model configuration. + + Attributes: + model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI. + config (Dict[str, Any]): A dictionary of the loaded configuration. + """ + + config: Optional[Dict[str, Any]] = Field( + None, description="Loaded configuration dictionary." + ) + model_details: Optional[Model] = Field( + None, description="Details of the model from OCI." + ) + + class Config: + extra = "ignore" + arbitrary_types_allowed = True + protected_namespaces = () diff --git a/ads/aqua/finetuning/finetuning.py b/ads/aqua/finetuning/finetuning.py index 29b9e1838..229f4deae 100644 --- a/ads/aqua/finetuning/finetuning.py +++ b/ads/aqua/finetuning/finetuning.py @@ -592,7 +592,7 @@ def get_finetuning_config(self, model_id: str) -> Dict: Dict: A dict of allowed finetuning configs. """ - config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG) + config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG).config if not config: logger.debug( f"Fine-tuning config for custom model: {model_id} is not available. Use defaults." diff --git a/ads/aqua/model/model.py b/ads/aqua/model/model.py index 24a07f1c9..40a954f06 100644 --- a/ads/aqua/model/model.py +++ b/ads/aqua/model/model.py @@ -586,7 +586,7 @@ def get_hf_tokenizer_config(self, model_id): """ config = self.get_config( model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT - ) + ).config if not config: logger.debug(f"Tokenizer config for model: {model_id} is not available.") return config diff --git a/ads/aqua/modeldeployment/deployment.py b/ads/aqua/modeldeployment/deployment.py index 1c858d4b7..4743399bc 100644 --- a/ads/aqua/modeldeployment/deployment.py +++ b/ads/aqua/modeldeployment/deployment.py @@ -654,7 +654,7 @@ def get_deployment_config(self, model_id: str) -> Dict: Dict: A dict of allowed deployment configs. """ - config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG) + config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG).config if not config: logger.debug( f"Deployment config for custom model: {model_id} is not available. Use defaults." diff --git a/tests/unitary/with_extras/aqua/test_config.py b/tests/unitary/with_extras/aqua/test_config.py index 3d6d2de1d..04ef25888 100644 --- a/tests/unitary/with_extras/aqua/test_config.py +++ b/tests/unitary/with_extras/aqua/test_config.py @@ -96,8 +96,8 @@ def test_load_config( model_id="test_model_id", config_file_name="test_config_file_name" ) if not path_exists: - assert result == {} + assert result.config == {} if not custom_metadata: - assert result == {} + assert result.config == {} if path_exists and custom_metadata: - assert result == {"config_key": "config_value"} + assert result.config == {"config_key": "config_value"} diff --git a/tests/unitary/with_extras/aqua/test_deployment.py b/tests/unitary/with_extras/aqua/test_deployment.py index cad2759f1..81d005fa1 100644 --- a/tests/unitary/with_extras/aqua/test_deployment.py +++ b/tests/unitary/with_extras/aqua/test_deployment.py @@ -1,7 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*-- -# Copyright (c) 2024 Oracle and/or its affiliates. +# Copyright (c) 2024, 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ import copy @@ -16,6 +16,7 @@ import pytest from parameterized import parameterized +from ads.aqua.common.entities import ModelConfigResult import ads.aqua.modeldeployment.deployment import ads.config from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse @@ -438,11 +439,11 @@ def test_get_deployment_config(self): with open(config_json, "r") as _file: config = json.load(_file) - self.app.get_config = MagicMock(return_value=config) + self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config)) result = self.app.get_deployment_config(TestDataset.MODEL_ID) assert result == config - self.app.get_config = MagicMock(return_value=None) + self.app.get_config = MagicMock(return_value=ModelConfigResult(config=None)) result = self.app.get_deployment_config(TestDataset.MODEL_ID) assert result == None diff --git a/tests/unitary/with_extras/aqua/test_finetuning.py b/tests/unitary/with_extras/aqua/test_finetuning.py index bc2a9c883..e082595ca 100644 --- a/tests/unitary/with_extras/aqua/test_finetuning.py +++ b/tests/unitary/with_extras/aqua/test_finetuning.py @@ -17,6 +17,7 @@ import ads.aqua.finetuning.finetuning import ads.config from ads.aqua.app import AquaApp +from ads.aqua.common.entities import ModelConfigResult from ads.aqua.common.errors import AquaValueError from ads.aqua.finetuning import AquaFineTuningApp from ads.aqua.finetuning.constants import FineTuneCustomMetadata @@ -279,7 +280,7 @@ def test_get_finetuning_config(self): with open(config_json, "r") as _file: config = json.load(_file) - self.app.get_config = MagicMock(return_value=config) + self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config)) result = self.app.get_finetuning_config(model_id="test-model-id") assert result == config