Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a1c0e12
Add Hugging Face model support to Shape Recommender
Aryanag2 Sep 9, 2025
2126fd0
Add Hugging Face model support to Shape Recommender
Aryanag2 Sep 9, 2025
b24938e
added HUGGINGFACE_CONFIG_URL and is_valid_ocid, and fixed error message
Aryanag2 Sep 9, 2025
0ca295e
Merge branch 'main' into ODSC-76845
mrDzurb Sep 10, 2025
ff37df9
added support for no compartment id provided and unit test for the same
Aryanag2 Sep 10, 2025
05183f9
Merge branch 'ODSC-76845' of https://github.com/oracle/accelerated-da…
Aryanag2 Sep 10, 2025
911a6af
Merge branch 'main' into ODSC-76845
mrDzurb Sep 10, 2025
4c2c6a2
Add Hugging Face model support to Shape Recommender
Aryanag2 Sep 9, 2025
d2bf709
Add Hugging Face model support to Shape Recommender
Aryanag2 Sep 9, 2025
e78ca08
added HUGGINGFACE_CONFIG_URL and is_valid_ocid, and fixed error message
Aryanag2 Sep 9, 2025
8554da2
added support for no compartment id provided and unit test for the same
Aryanag2 Sep 10, 2025
76288dd
Merge branch 'ODSC-76845' of https://github.com/oracle/accelerated-da…
Aryanag2 Sep 10, 2025
367ce70
removing real network call
Aryanag2 Sep 10, 2025
29bd22c
added black formatting
Aryanag2 Sep 10, 2025
efe0953
fixed comments and using _get_model_config to do the same checks as w…
Aryanag2 Sep 11, 2025
19ddce1
using huggingface_hub.hf_hub_download and design changes
Aryanag2 Sep 11, 2025
3d935bb
an example also how to provide the compartment in the parameters.
Aryanag2 Sep 11, 2025
e324143
added compartment id logic _get_model_config as a sanity check
Aryanag2 Sep 11, 2025
2251e3c
added docstrings
Aryanag2 Sep 11, 2025
c3b5afa
commented out search_model_in_catalog logic
Aryanag2 Sep 11, 2025
d68e501
added get_resource_type as an import
Aryanag2 Sep 11, 2025
79d6f5f
fixed imports
Aryanag2 Sep 11, 2025
c840a39
Merge branch 'main' into ODSC-76845
Aryanag2 Sep 11, 2025
a362351
fixed imports
Aryanag2 Sep 11, 2025
49962ff
resolving comments
Aryanag2 Sep 12, 2025
7d02bf1
Merge branch 'main' into ODSC-76845
elizjo Sep 16, 2025
a1bde9f
added changes to tests
Aryanag2 Sep 16, 2025
86b257b
Merge branch 'ODSC-76845' of https://github.com/oracle/accelerated-da…
Aryanag2 Sep 16, 2025
287b406
unittests for huggingface fetching
Aryanag2 Sep 16, 2025
5088b85
modified unit tests
elizjo Sep 17, 2025
2a9d843
fixed unit tests
Aryanag2 Sep 18, 2025
8edc0e8
Merge branch 'main' into ODSC-76845
mrDzurb Sep 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 86 additions & 10 deletions ads/aqua/shaperecommend/recommend.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,14 @@
# Copyright (c) 2025 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/

import json
import os
import re
import shutil
from typing import List, Union
from typing import Dict, List, Optional, Tuple, Union

from huggingface_hub import hf_hub_download
from huggingface_hub.utils import HfHubHTTPError
from pydantic import ValidationError
from rich.table import Table

Expand All @@ -17,7 +22,9 @@
)
from ads.aqua.common.utils import (
build_pydantic_error_message,
format_hf_custom_error_message,
get_resource_type,
is_valid_ocid,
load_config,
load_gpu_shapes_index,
)
Expand All @@ -37,6 +44,7 @@
ShapeRecommendationReport,
ShapeReport,
)
from ads.config import COMPARTMENT_OCID
from ads.model.datascience_model import DataScienceModel
from ads.model.service.oci_datascience_model_deployment import (
OCIDataScienceModelDeployment,
Expand Down Expand Up @@ -91,20 +99,23 @@ def which_shapes(
try:
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id)

ds_model = self._get_data_science_model(request.model_id)

model_name = ds_model.display_name if ds_model.display_name else ""

if request.deployment_config:
if is_valid_ocid(request.model_id):
ds_model = self._get_data_science_model(request.model_id)
model_name = ds_model.display_name
else:
model_name = request.model_id

shape_recommendation_report = (
ShapeRecommendationReport.from_deployment_config(
request.deployment_config, model_name, shapes
)
)

else:
data = self._get_model_config(ds_model)

data, model_name = self._get_model_config_and_name(
model_id=request.model_id,
)
llm_config = LLMConfig.from_raw_config(data)

shape_recommendation_report = self._summarize_shapes_for_seq_lens(
Expand Down Expand Up @@ -135,7 +146,57 @@ def which_shapes(

return shape_recommendation_report

def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]:
def _get_model_config_and_name(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: It would be more clean if we had return at the end:

    if is_valid_ocid(model_id):
        logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
        ds_model = self._validate_model_ocid(model_id)
        config = self._get_model_config(ds_model)
        model_name = ds_model.display_name
    else:
        logger.info(f"Assuming Hugging Face model ID: Fetching config for '{model_id}'.")
        config = self._fetch_hf_config(model_id)
        model_name = model_id

    return config, model_name

self,
model_id: str,
) -> Tuple[Dict, str]:
"""
Loads model configuration by trying OCID logic first, then falling back
to treating the model_id as a Hugging Face Hub ID.

Parameters
----------
model_id : str
The model OCID or Hugging Face model ID.
# compartment_id : Optional[str]
# The compartment OCID, used for searching the model catalog.

Returns
-------
Tuple[Dict, str]
A tuple containing:
- The model configuration dictionary.
- The display name for the model.
"""
if is_valid_ocid(model_id):
logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.")
ds_model = self._get_data_science_model(model_id)
config = self._get_model_config(ds_model)
model_name = ds_model.display_name
else:
logger.info(
f"Assuming Hugging Face model ID: Fetching config for '{model_id}'."
)
config = self._fetch_hf_config(model_id)
model_name = model_id

return config, model_name

def _fetch_hf_config(self, model_id: str) -> Dict:
"""
Downloads a model's config.json from Hugging Face Hub using the
huggingface_hub library.
"""
try:
config_path = hf_hub_download(repo_id=model_id, filename="config.json")
with open(config_path, "r", encoding="utf-8") as f:
return json.load(f)
except HfHubHTTPError as e:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please check the existing - format_hf_custom_error_message, i'm wondering if it can be reused here?

format_hf_custom_error_message(e)

def valid_compute_shapes(
self, compartment_id: Optional[str] = None
) -> List["ComputeShapeSummary"]:
"""
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file.

Expand All @@ -151,9 +212,23 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary

Raises
------
ValueError
If the file cannot be opened, parsed, or the 'shapes' key is missing.
AquaValueError
If a compartment_id is not provided and cannot be found in the
environment variables.
"""
if not compartment_id:
compartment_id = COMPARTMENT_OCID
if compartment_id:
logger.info(f"Using compartment_id from environment: {compartment_id}")

if not compartment_id:
raise AquaValueError(
"A compartment OCID is required to list available shapes. "
"Please specify it using the --compartment_id parameter.\n\n"
"Example:\n"
'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"'
)

oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id)
set_user_shapes = {shape.name: shape for shape in oci_shapes}

Expand Down Expand Up @@ -324,6 +399,7 @@ def _get_model_config(model: DataScienceModel):
"""

model_task = model.freeform_tags.get("task", "").lower()
model_task = re.sub(r"-", "_", model_task)
model_format = model.freeform_tags.get("model_format", "").lower()

logger.info(f"Current model task type: {model_task}")
Expand Down
3 changes: 2 additions & 1 deletion ads/aqua/shaperecommend/shape_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ class RequestRecommend(BaseModel):
"""

model_id: str = Field(
..., description="The OCID of the model to recommend feasible compute shapes."
...,
description="The OCID or Hugging Face ID of the model to recommend feasible compute shapes.",
)
generate_table: Optional[bool] = (
Field(
Expand Down
57 changes: 50 additions & 7 deletions tests/unitary/with_extras/aqua/test_recommend.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@
import json
import os
import re
from unittest.mock import MagicMock

from unittest.mock import MagicMock, mock_open
from unittest.mock import patch
import pytest

from ads.aqua.common.entities import ComputeShapeSummary
from ads.aqua.common.errors import AquaRecommendationError

from ads.aqua.common.errors import AquaRecommendationError, AquaValueError
from ads.aqua.modeldeployment.config_loader import AquaDeploymentConfig
from ads.aqua.shaperecommend.estimator import (
LlamaMemoryEstimator,
Expand All @@ -21,7 +22,9 @@
get_estimator,
)
from ads.aqua.shaperecommend.llm_config import LLMConfig
from ads.aqua.shaperecommend.recommend import AquaShapeRecommend
from ads.aqua.shaperecommend.recommend import (
AquaShapeRecommend,
)
from ads.aqua.shaperecommend.shape_report import (
DeploymentParams,
ModelConfig,
Expand Down Expand Up @@ -275,6 +278,41 @@ def create(config_file=""):


class TestAquaShapeRecommend:

@patch("ads.aqua.shaperecommend.recommend.hf_hub_download")
@patch("builtins.open", new_callable=mock_open)
def test_fetch_hf_config_success(self, mock_file, mock_download):
"""Test successful config fetch from Hugging Face"""
app = AquaShapeRecommend()
model_id = "test/model"
config_path = "/fake/path/config.json"
expected_config = {"model_type": "llama", "hidden_size": 4096}

mock_download.return_value = config_path
mock_file.return_value.read.return_value = json.dumps(expected_config)

result = app._fetch_hf_config(model_id)

assert result == expected_config
mock_download.assert_called_once_with(repo_id=model_id, filename="config.json")

@patch("ads.aqua.shaperecommend.recommend.hf_hub_download")
@patch("ads.aqua.shaperecommend.recommend.format_hf_custom_error_message")
def test_fetch_hf_config_http_error(self, mock_format_error, mock_download):
"""Test error handling when Hugging Face request fails"""
from huggingface_hub.utils import HfHubHTTPError

app = AquaShapeRecommend()
model_id = "nonexistent/model"
http_error = HfHubHTTPError("Model not found")
mock_download.side_effect = http_error

# The method doesn't re-raise, so it returns None
result = app._fetch_hf_config(model_id)

assert result is None
mock_format_error.assert_called_once_with(http_error)

@pytest.mark.parametrize(
"config, expected_recs, expected_troubleshoot",
[
Expand Down Expand Up @@ -398,18 +436,23 @@ def test_which_shapes_valid_from_file(
)[1],
)

raw = load_config(config_file)
mock_raw_config = load_config(config_file)
mock_ds_model_name = mock_model.display_name

if service_managed_model:
config = AquaDeploymentConfig(**raw)
config = AquaDeploymentConfig(**mock_raw_config)

request = RequestRecommend(
model_id="ocid1.datasciencemodel.oc1.TEST",
generate_table=False,
deployment_config=config,
)
else:
monkeypatch.setattr(app, "_get_model_config", lambda _: raw)
monkeypatch.setattr(
app,
"_get_model_config_and_name",
lambda model_id: (mock_raw_config, mock_ds_model_name),
)

request = RequestRecommend(
model_id="ocid1.datasciencemodel.oc1.TEST", generate_table=False
Expand Down