[AQUA] Add Hugging Face model support to Shape Recommender #1262
Merged
32 commits
a1c0e12 Add Hugging Face model support to Shape Recommender (Aryanag2)
2126fd0 Add Hugging Face model support to Shape Recommender (Aryanag2)
b24938e added HUGGINGFACE_CONFIG_URL and is_valid_ocid, and fixed error message (Aryanag2)
0ca295e Merge branch 'main' into ODSC-76845 (mrDzurb)
ff37df9 added support for no compartment id provided and unit test for the same (Aryanag2)
05183f9 Merge branch 'ODSC-76845' of https://github.com/oracle/accelerated-da… (Aryanag2)
911a6af Merge branch 'main' into ODSC-76845 (mrDzurb)
4c2c6a2 Add Hugging Face model support to Shape Recommender (Aryanag2)
d2bf709 Add Hugging Face model support to Shape Recommender (Aryanag2)
e78ca08 added HUGGINGFACE_CONFIG_URL and is_valid_ocid, and fixed error message (Aryanag2)
8554da2 added support for no compartment id provided and unit test for the same (Aryanag2)
76288dd Merge branch 'ODSC-76845' of https://github.com/oracle/accelerated-da… (Aryanag2)
367ce70 removing real network call (Aryanag2)
29bd22c added black formatting (Aryanag2)
efe0953 fixed comments and using _get_model_config to do the same checks as w… (Aryanag2)
19ddce1 using huggingface_hub.hf_hub_download and design changes (Aryanag2)
3d935bb an example also how to provide the compartment in the parameters. (Aryanag2)
e324143 added compartment id logic _get_model_config as a sanity check (Aryanag2)
2251e3c added docstrings (Aryanag2)
c3b5afa commented out search_model_in_catalog logic (Aryanag2)
d68e501 added get_resource_type as an import (Aryanag2)
79d6f5f fixed imports (Aryanag2)
c840a39 Merge branch 'main' into ODSC-76845 (Aryanag2)
a362351 fixed imports (Aryanag2)
49962ff resolving comments (Aryanag2)
7d02bf1 Merge branch 'main' into ODSC-76845 (elizjo)
a1bde9f added changes to tests (Aryanag2)
86b257b Merge branch 'ODSC-76845' of https://github.com/oracle/accelerated-da… (Aryanag2)
287b406 unittests for huggingface fetching (Aryanag2)
5088b85 modified unit tests (elizjo)
2a9d843 fixed unit tests (Aryanag2)
8edc0e8 Merge branch 'main' into ODSC-76845 (mrDzurb)
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,9 +2,14 @@ | |
# Copyright (c) 2025 Oracle and/or its affiliates. | ||
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ | ||
|
||
import json | ||
import os | ||
import re | ||
import shutil | ||
from typing import List, Union | ||
from typing import Dict, List, Optional, Tuple, Union | ||
|
||
from huggingface_hub import hf_hub_download | ||
from huggingface_hub.utils import HfHubHTTPError | ||
from pydantic import ValidationError | ||
from rich.table import Table | ||
|
||
|
@@ -17,7 +22,9 @@ | |
) | ||
from ads.aqua.common.utils import ( | ||
build_pydantic_error_message, | ||
format_hf_custom_error_message, | ||
get_resource_type, | ||
is_valid_ocid, | ||
load_config, | ||
load_gpu_shapes_index, | ||
) | ||
|
@@ -37,6 +44,7 @@ | |
ShapeRecommendationReport, | ||
ShapeReport, | ||
) | ||
from ads.config import COMPARTMENT_OCID | ||
from ads.model.datascience_model import DataScienceModel | ||
from ads.model.service.oci_datascience_model_deployment import ( | ||
OCIDataScienceModelDeployment, | ||
|
@@ -91,20 +99,23 @@ def which_shapes( | |
try: | ||
shapes = self.valid_compute_shapes(compartment_id=request.compartment_id) | ||
|
||
ds_model = self._get_data_science_model(request.model_id) | ||
|
||
model_name = ds_model.display_name if ds_model.display_name else "" | ||
|
||
if request.deployment_config: | ||
if is_valid_ocid(request.model_id): | ||
ds_model = self._get_data_science_model(request.model_id) | ||
model_name = ds_model.display_name | ||
else: | ||
model_name = request.model_id | ||
|
||
shape_recommendation_report = ( | ||
ShapeRecommendationReport.from_deployment_config( | ||
request.deployment_config, model_name, shapes | ||
) | ||
) | ||
|
||
else: | ||
data = self._get_model_config(ds_model) | ||
|
||
data, model_name = self._get_model_config_and_name( | ||
model_id=request.model_id, | ||
) | ||
llm_config = LLMConfig.from_raw_config(data) | ||
|
||
shape_recommendation_report = self._summarize_shapes_for_seq_lens( | ||
|
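The hunk above branches on whether `request.model_id` is an OCID or a Hugging Face model ID when choosing the display name. A minimal sketch of that branching follows. `looks_like_ocid` is a hypothetical, deliberately rough stand-in for the project's `is_valid_ocid` (its real implementation is not shown in this diff), and `get_display_name` is an assumed callback standing in for the catalog lookup.

```python
from typing import Callable, Optional


def looks_like_ocid(value: str) -> bool:
    """Rough stand-in for is_valid_ocid; NOT the library's implementation."""
    # OCIDs follow ocid1.<type>.<realm>.[region][.future use].<unique id>,
    # so a plausible OCID starts with "ocid1." and has at least five
    # dot-separated parts (the region part may be empty).
    return value.startswith("ocid1.") and len(value.split(".")) >= 5


def resolve_model_name(
    model_id: str, get_display_name: Callable[[str], Optional[str]]
) -> str:
    """Return the catalog display name for an OCID, else the ID itself."""
    if looks_like_ocid(model_id):
        # Only OCIDs trigger a catalog lookup; fall back to "" if unnamed.
        return get_display_name(model_id) or ""
    # Anything else is assumed to be a Hugging Face model ID.
    return model_id
```

Keeping the `or ""` fallback is a choice made for this sketch so the function always returns a string; the merged code assigns `ds_model.display_name` directly.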
@@ -135,7 +146,57 @@ def which_shapes( | |
|
||
return shape_recommendation_report | ||
|
||
def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary"]: | ||
def _get_model_config_and_name( | ||
self, | ||
model_id: str, | ||
) -> Tuple[Dict, str]: | ||
""" | ||
Loads model configuration by trying OCID logic first, then falling back | ||
to treating the model_id as a Hugging Face Hub ID. | ||
|
||
Parameters | ||
---------- | ||
model_id : str | ||
The model OCID or Hugging Face model ID. | ||
# compartment_id : Optional[str] | ||
# The compartment OCID, used for searching the model catalog. | ||
|
||
Returns | ||
------- | ||
Tuple[Dict, str] | ||
A tuple containing: | ||
- The model configuration dictionary. | ||
- The display name for the model. | ||
""" | ||
if is_valid_ocid(model_id): | ||
logger.info(f"Detected OCID: Fetching OCI model config for '{model_id}'.") | ||
ds_model = self._get_data_science_model(model_id) | ||
config = self._get_model_config(ds_model) | ||
model_name = ds_model.display_name | ||
else: | ||
logger.info( | ||
f"Assuming Hugging Face model ID: Fetching config for '{model_id}'." | ||
) | ||
config = self._fetch_hf_config(model_id) | ||
model_name = model_id | ||
|
||
return config, model_name | ||
|
||
def _fetch_hf_config(self, model_id: str) -> Dict: | ||
""" | ||
Downloads a model's config.json from Hugging Face Hub using the | ||
huggingface_hub library. | ||
""" | ||
|
||
try: | ||
config_path = hf_hub_download(repo_id=model_id, filename="config.json") | ||
with open(config_path, "r", encoding="utf-8") as f: | ||
return json.load(f) | ||
except HfHubHTTPError as e: | ||
Reviewer comment: Please check the existing -
||
format_hf_custom_error_message(e) | ||
|
||
def valid_compute_shapes( | ||
self, compartment_id: Optional[str] = None | ||
) -> List["ComputeShapeSummary"]: | ||
""" | ||
Returns a filtered list of GPU-only ComputeShapeSummary objects by reading and parsing a JSON file. | ||
|
||
|
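The `_fetch_hf_config` helper in the hunk above downloads `config.json` and parses it. One way to exercise that logic without a real network call is to inject the download step, as sketched below. The `download` parameter, `ConfigFetchError`, and `_default_download` are hypothetical names introduced for illustration; the merged code calls `hf_hub_download` directly and handles `HfHubHTTPError` through `format_hf_custom_error_message`.

```python
import json
import os
import tempfile
from typing import Callable, Dict, Optional


class ConfigFetchError(RuntimeError):
    """Hypothetical error type standing in for the project's own exception."""


def _default_download(repo_id: str) -> str:
    # Imported lazily so the sketch can run without the dependency installed.
    from huggingface_hub import hf_hub_download

    return hf_hub_download(repo_id=repo_id, filename="config.json")


def fetch_hf_config(
    model_id: str, download: Optional[Callable[[str], str]] = None
) -> Dict:
    """Load a model's config.json as a dict, with a single return at the end."""
    download = download or _default_download
    try:
        config_path = download(model_id)
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, json.JSONDecodeError) as e:
        # Only file and JSON errors are wrapped here; this is an assumption
        # of the sketch, not the PR's error handling.
        raise ConfigFetchError(f"Could not load config for '{model_id}'.") from e
    return config


# Usage with a fake downloader, so no network call is made.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "config.json")
    with open(path, "w", encoding="utf-8") as f:
        json.dump({"hidden_size": 4096}, f)
    demo_config = fetch_hf_config("org/model", download=lambda _repo_id: path)
```

Raising on every failure path and returning only after the `try` block means the function can never fall through and return `None` implicitly.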
@@ -151,9 +212,23 @@ def valid_compute_shapes(self, compartment_id: str) -> List["ComputeShapeSummary | |
|
||
Raises | ||
------ | ||
ValueError | ||
If the file cannot be opened, parsed, or the 'shapes' key is missing. | ||
AquaValueError | ||
If a compartment_id is not provided and cannot be found in the | ||
environment variables. | ||
""" | ||
if not compartment_id: | ||
compartment_id = COMPARTMENT_OCID | ||
if compartment_id: | ||
logger.info(f"Using compartment_id from environment: {compartment_id}") | ||
|
||
if not compartment_id: | ||
raise AquaValueError( | ||
"A compartment OCID is required to list available shapes. " | ||
"Please specify it using the --compartment_id parameter.\n\n" | ||
"Example:\n" | ||
'ads aqua deployment recommend_shape --model_id "<YOUR_MODEL_OCID>" --compartment_id "<YOUR_COMPARTMENT_OCID>"' | ||
) | ||
|
||
oci_shapes = OCIDataScienceModelDeployment.shapes(compartment_id=compartment_id) | ||
set_user_shapes = {shape.name: shape for shape in oci_shapes} | ||
|
||
|
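The compartment handling in `valid_compute_shapes` above amounts to: use the explicit argument, else the configured default, else fail with an actionable message. A minimal sketch under stated assumptions follows. `resolve_compartment_id` is a hypothetical helper, `default` stands in for `COMPARTMENT_OCID` from `ads.config`, and `ValueError` stands in for `AquaValueError`.

```python
from typing import Optional


def resolve_compartment_id(
    compartment_id: Optional[str], default: Optional[str]
) -> str:
    """Pick the explicit compartment OCID, else the default, else raise."""
    resolved = compartment_id or default
    if not resolved:
        # ValueError stands in for the project's AquaValueError.
        raise ValueError(
            "A compartment OCID is required to list available shapes. "
            "Please specify it using the --compartment_id parameter."
        )
    return resolved
```

Passing the default in as a parameter, rather than reading it from the environment inside the function, keeps the fallback easy to unit test.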
@@ -324,6 +399,7 @@ def _get_model_config(model: DataScienceModel): | |
""" | ||
|
||
model_task = model.freeform_tags.get("task", "").lower() | ||
model_task = re.sub(r"-", "_", model_task) | ||
model_format = model.freeform_tags.get("model_format", "").lower() | ||
|
||
logger.info(f"Current model task type: {model_task}") | ||
|
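The added `re.sub(r"-", "_", model_task)` line normalizes hyphenated task tags so that, for example, `text-generation` and `text_generation` compare equal. A small sketch of the same normalization is shown below. `normalize_task` is a hypothetical helper name, and it uses `str.replace`, which is equivalent to `re.sub` for a literal hyphen.

```python
from typing import Mapping


def normalize_task(freeform_tags: Mapping[str, str]) -> str:
    """Lower-case the 'task' tag and replace hyphens with underscores."""
    # A missing tag yields an empty string, matching the .get("task", "") default.
    return freeform_tags.get("task", "").lower().replace("-", "_")
```

For instance, `normalize_task({"task": "Text-Generation"})` gives `"text_generation"`.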
NIT: It would be cleaner if we had the return at the end: