diff --git a/ads/aqua/shaperecommend/recommend.py b/ads/aqua/shaperecommend/recommend.py index 1ef38a173..e83a0aa50 100644 --- a/ads/aqua/shaperecommend/recommend.py +++ b/ads/aqua/shaperecommend/recommend.py @@ -2,9 +2,14 @@ # Copyright (c) 2025 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ +import json +import os +import re import shutil -from typing import List, Union +from typing import Dict, List, Optional, Tuple, Union +from huggingface_hub import hf_hub_download +from huggingface_hub.utils import HfHubHTTPError from pydantic import ValidationError from rich.table import Table @@ -17,7 +22,9 @@ ) from ads.aqua.common.utils import ( build_pydantic_error_message, + format_hf_custom_error_message, get_resource_type, + is_valid_ocid, load_config, load_gpu_shapes_index, ) @@ -37,6 +44,7 @@ ShapeRecommendationReport, ShapeReport, ) +from ads.config import COMPARTMENT_OCID from ads.model.datascience_model import DataScienceModel from ads.model.service.oci_datascience_model_deployment import ( OCIDataScienceModelDeployment, @@ -91,11 +99,13 @@ def which_shapes( try: shapes = self.valid_compute_shapes(compartment_id=request.compartment_id) - ds_model = self._get_data_science_model(request.model_id) - - model_name = ds_model.display_name if ds_model.display_name else "" - if request.deployment_config: + if is_valid_ocid(request.model_id): + ds_model = self._get_data_science_model(request.model_id) + model_name = ds_model.display_name + else: + model_name = request.model_id + shape_recommendation_report = ( ShapeRecommendationReport.from_deployment_config( request.deployment_config, model_name, shapes @@ -103,8 +113,9 @@ def which_shapes( ) else: - data = self._get_model_config(ds_model) - + data, model_name = self._get_model_config_and_name( + model_id=request.model_id, + ) llm_config = LLMConfig.from_raw_config(data) shape_recommendation_report = 
def _get_model_config_and_name(
    self,
    model_id: str,
) -> Tuple[Dict, str]:
    """
    Loads the raw model configuration and a display name for ``model_id``.

    ``model_id`` is handled as an OCI model OCID when ``is_valid_ocid``
    accepts it; any other value is treated as a Hugging Face Hub model ID.

    Parameters
    ----------
    model_id : str
        The model OCID or Hugging Face model ID.

    Returns
    -------
    Tuple[Dict, str]
        A tuple containing:
        - The model configuration dictionary.
        - The display name for the model (an empty string when the OCI
          model has no display name; the model ID itself for Hugging Face).

    Raises
    ------
    AquaValueError
        If no configuration could be obtained for a Hugging Face model ID.
    """
    if is_valid_ocid(model_id):
        logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
        ds_model = self._get_data_science_model(model_id)
        config = self._get_model_config(ds_model)
        # display_name can be unset; fall back to "" so the declared
        # ``str`` contract holds for downstream report builders.
        model_name = ds_model.display_name or ""
        return config, model_name

    logger.info(
        f"Assuming Hugging Face model ID: Fetching config for '{model_id}'."
    )
    config = self._fetch_hf_config(model_id)
    if config is None:
        # _fetch_hf_config yields None when the Hub request failed and the
        # error formatter did not raise; fail here with a clear message
        # instead of handing None to the config parser.
        raise AquaValueError(
            f"Unable to load config.json for Hugging Face model '{model_id}'. "
            "Please verify the model ID and that you have access to the repository."
        )
    return config, model_id


def _fetch_hf_config(self, model_id: str) -> Optional[Dict]:
    """
    Downloads a model's config.json from Hugging Face Hub using the
    huggingface_hub library and returns its parsed JSON content.

    Parameters
    ----------
    model_id : str
        The Hugging Face Hub repository ID to download config.json from.

    Returns
    -------
    Optional[Dict]
        The parsed content of config.json. ``None`` is returned only when
        the download fails with ``HfHubHTTPError`` and
        ``format_hf_custom_error_message`` returns instead of raising.
    """
    try:
        # Only the Hub request can raise HfHubHTTPError, so keep the
        # guarded region limited to it.
        config_path = hf_hub_download(repo_id=model_id, filename="config.json")
    except HfHubHTTPError as e:
        # NOTE(review): format_hf_custom_error_message presumably raises a
        # user-facing error of its own - confirm against its definition.
        format_hf_custom_error_message(e)
        return None

    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)
+ AquaValueError + If a compartment_id is not provided and cannot be found in the + environment variables. """ + if not compartment_id: + compartment_id = COMPARTMENT_OCID + if compartment_id: + logger.info(f"Using compartment_id from environment: {compartment_id}") + + if not compartment_id: + raise AquaValueError( + "A compartment OCID is required to list available shapes. " + "Please specify it using the --compartment_id parameter.\n\n" + "Example:\n" + 'ads aqua deployment recommend_shape --model_id "" --compartment_id ""' + ) + oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id) set_user_shapes = {shape.name: shape for shape in oci_shapes} @@ -324,6 +399,7 @@ def _get_model_config(model: DataScienceModel): """ model_task = model.freeform_tags.get("task", "").lower() + model_task = re.sub(r"-", "_", model_task) model_format = model.freeform_tags.get("model_format", "").lower() logger.info(f"Current model task type: {model_task}") diff --git a/ads/aqua/shaperecommend/shape_report.py b/ads/aqua/shaperecommend/shape_report.py index c9555c4cf..b1296fa85 100644 --- a/ads/aqua/shaperecommend/shape_report.py +++ b/ads/aqua/shaperecommend/shape_report.py @@ -29,7 +29,8 @@ class RequestRecommend(BaseModel): """ model_id: str = Field( - ..., description="The OCID of the model to recommend feasible compute shapes." 
@patch("ads.aqua.shaperecommend.recommend.hf_hub_download")
@patch("builtins.open", new_callable=mock_open)
def test_fetch_hf_config_success(self, mock_file, mock_download):
    """A config.json downloaded from Hugging Face is parsed and returned."""
    hub_repo = "test/model"
    payload = {"model_type": "llama", "hidden_size": 4096}

    # The patched open() hands back the serialized payload; the patched
    # download only has to report where the file "was saved".
    mock_file.return_value.read.return_value = json.dumps(payload)
    mock_download.return_value = "/fake/path/config.json"

    recommender = AquaShapeRecommend()
    fetched = recommender._fetch_hf_config(hub_repo)

    mock_download.assert_called_once_with(repo_id=hub_repo, filename="config.json")
    assert fetched == payload
@patch("ads.aqua.shaperecommend.recommend.hf_hub_download")
@patch("ads.aqua.shaperecommend.recommend.format_hf_custom_error_message")
def test_fetch_hf_config_http_error(self, mock_format_error, mock_download):
    """A failing Hugging Face request is routed to the error formatter."""
    from huggingface_hub.utils import HfHubHTTPError

    hub_failure = HfHubHTTPError("Model not found")
    mock_download.side_effect = hub_failure

    recommender = AquaShapeRecommend()
    # The formatter is mocked out here, so nothing is raised and the
    # method falls through to its None result.
    outcome = recommender._fetch_hf_config("nonexistent/model")

    mock_format_error.assert_called_once_with(hub_failure)
    assert outcome is None