Skip to content

[AQUA] Enhance get_config to Return Model Details and Configuration in a Pydantic format #1107

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Mar 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 23 additions & 21 deletions ads/aqua/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
import os
import traceback
from dataclasses import fields
from typing import Dict, Optional, Union
from typing import Any, Dict, Optional, Union

import oci
from oci.data_science.models import UpdateModelDetails, UpdateModelProvenanceDetails

from ads import set_auth
from ads.aqua import logger
from ads.aqua.common.entities import ModelConfigResult
from ads.aqua.common.enums import ConfigFolder, Tags
from ads.aqua.common.errors import AquaRuntimeError, AquaValueError
from ads.aqua.common.utils import (
Expand Down Expand Up @@ -273,24 +274,24 @@ def get_config(
model_id: str,
config_file_name: str,
config_folder: Optional[str] = ConfigFolder.CONFIG,
) -> Dict:
"""Gets the config for the given Aqua model.
) -> ModelConfigResult:
"""
Gets the configuration for the given Aqua model along with the model details.

Parameters
----------
model_id: str
model_id : str
The OCID of the Aqua model.
config_file_name: str
name of the config file
config_folder: (str, optional):
subfolder path where config_file_name needs to be searched
Defaults to `ConfigFolder.CONFIG`.
When searching inside model artifact directory , the value is ConfigFolder.ARTIFACT`
config_file_name : str
The name of the configuration file.
config_folder : Optional[str]
The subfolder path where config_file_name is searched.
Defaults to ConfigFolder.CONFIG. For model artifact directories, use ConfigFolder.ARTIFACT.

Returns
-------
Dict:
A dict of allowed configs.
ModelConfigResult
A Pydantic model containing the model_details (extracted from OCI) and the config dictionary.
"""
config_folder = config_folder or ConfigFolder.CONFIG
oci_model = self.ds_client.get_model(model_id).data
Expand All @@ -302,11 +303,11 @@ def get_config(
if oci_model.freeform_tags
else False
)

if not oci_aqua:
raise AquaRuntimeError(f"Target model {oci_model.id} is not Aqua model.")
raise AquaRuntimeError(f"Target model {oci_model.id} is not an Aqua model.")

config: Dict[str, Any] = {}

config = {}
# if the current model has a service model tag, then
if Tags.AQUA_SERVICE_MODEL_TAG in oci_model.freeform_tags:
base_model_ocid = oci_model.freeform_tags[Tags.AQUA_SERVICE_MODEL_TAG]
Expand All @@ -326,7 +327,7 @@ def get_config(
logger.debug(
f"Failed to get artifact path from custom metadata for the model: {model_id}"
)
return config
return ModelConfigResult(config=config, model_details=oci_model)

config_path = os.path.join(os.path.dirname(artifact_path), config_folder)
if not is_path_exists(config_path):
Expand All @@ -351,9 +352,8 @@ def get_config(
f"{config_file_name} is not available for the model: {model_id}. "
f"Check if the custom metadata has the artifact path set."
)
return config

return config
return ModelConfigResult(config=config, model_details=oci_model)

@property
def telemetry(self):
Expand All @@ -375,9 +375,11 @@ def build_cli(self) -> str:
"""
cmd = f"ads aqua {self._command}"
params = [
f"--{field.name} {json.dumps(getattr(self, field.name))}"
if isinstance(getattr(self, field.name), dict)
else f"--{field.name} {getattr(self, field.name)}"
(
f"--{field.name} {json.dumps(getattr(self, field.name))}"
if isinstance(getattr(self, field.name), dict)
else f"--{field.name} {getattr(self, field.name)}"
)
for field in fields(self.__class__)
if getattr(self, field.name) is not None
]
Expand Down
29 changes: 28 additions & 1 deletion ads/aqua/common/entities.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#!/usr/bin/env python
# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

from typing import Any, Dict, Optional

from oci.data_science.models import Model
from pydantic import BaseModel, Field


class ContainerSpec:
"""
Expand All @@ -15,3 +20,25 @@ class ContainerSpec:
ENV_VARS = "envVars"
RESTRICTED_PARAMS = "restrictedParams"
EVALUATION_CONFIGURATION = "evaluationConfiguration"


class ModelConfigResult(BaseModel):
"""
Represents the result of getting the AQUA model configuration.

Attributes:
model_details (Dict[str, Any]): A dictionary containing model details extracted from OCI.
config (Dict[str, Any]): A dictionary of the loaded configuration.
"""

config: Optional[Dict[str, Any]] = Field(
None, description="Loaded configuration dictionary."
)
model_details: Optional[Model] = Field(
None, description="Details of the model from OCI."
)

class Config:
extra = "ignore"
arbitrary_types_allowed = True
protected_namespaces = ()
2 changes: 1 addition & 1 deletion ads/aqua/finetuning/finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,7 +592,7 @@ def get_finetuning_config(self, model_id: str) -> Dict:
Dict:
A dict of allowed finetuning configs.
"""
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG)
config = self.get_config(model_id, AQUA_MODEL_FINETUNING_CONFIG).config
if not config:
logger.debug(
f"Fine-tuning config for custom model: {model_id} is not available. Use defaults."
Expand Down
2 changes: 1 addition & 1 deletion ads/aqua/model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -586,7 +586,7 @@ def get_hf_tokenizer_config(self, model_id):
"""
config = self.get_config(
model_id, AQUA_MODEL_TOKENIZER_CONFIG, ConfigFolder.ARTIFACT
)
).config
if not config:
logger.debug(f"Tokenizer config for model: {model_id} is not available.")
return config
Expand Down
2 changes: 1 addition & 1 deletion ads/aqua/modeldeployment/deployment.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,7 @@ def get_deployment_config(self, model_id: str) -> Dict:
Dict:
A dict of allowed deployment configs.
"""
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG)
config = self.get_config(model_id, AQUA_MODEL_DEPLOYMENT_CONFIG).config
if not config:
logger.debug(
f"Deployment config for custom model: {model_id} is not available. Use defaults."
Expand Down
6 changes: 3 additions & 3 deletions tests/unitary/with_extras/aqua/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,8 @@ def test_load_config(
model_id="test_model_id", config_file_name="test_config_file_name"
)
if not path_exists:
assert result == {}
assert result.config == {}
if not custom_metadata:
assert result == {}
assert result.config == {}
if path_exists and custom_metadata:
assert result == {"config_key": "config_value"}
assert result.config == {"config_key": "config_value"}
7 changes: 4 additions & 3 deletions tests/unitary/with_extras/aqua/test_deployment.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--

# Copyright (c) 2024 Oracle and/or its affiliates.
# Copyright (c) 2024, 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import copy
Expand All @@ -16,6 +16,7 @@
import pytest
from parameterized import parameterized

from ads.aqua.common.entities import ModelConfigResult
import ads.aqua.modeldeployment.deployment
import ads.config
from ads.aqua.modeldeployment import AquaDeploymentApp, MDInferenceResponse
Expand Down Expand Up @@ -438,11 +439,11 @@ def test_get_deployment_config(self):
with open(config_json, "r") as _file:
config = json.load(_file)

self.app.get_config = MagicMock(return_value=config)
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config))
result = self.app.get_deployment_config(TestDataset.MODEL_ID)
assert result == config

self.app.get_config = MagicMock(return_value=None)
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=None))
result = self.app.get_deployment_config(TestDataset.MODEL_ID)
assert result == None

Expand Down
3 changes: 2 additions & 1 deletion tests/unitary/with_extras/aqua/test_finetuning.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import ads.aqua.finetuning.finetuning
import ads.config
from ads.aqua.app import AquaApp
from ads.aqua.common.entities import ModelConfigResult
from ads.aqua.common.errors import AquaValueError
from ads.aqua.finetuning import AquaFineTuningApp
from ads.aqua.finetuning.constants import FineTuneCustomMetadata
Expand Down Expand Up @@ -279,7 +280,7 @@ def test_get_finetuning_config(self):
with open(config_json, "r") as _file:
config = json.load(_file)

self.app.get_config = MagicMock(return_value=config)
self.app.get_config = MagicMock(return_value=ModelConfigResult(config=config))
result = self.app.get_finetuning_config(model_id="test-model-id")
assert result == config

Expand Down