Skip to content

Commit 31575c1

Browse files
authored
Added Logic to Download Model from Hugging Face (#917)
2 parents 656e6e1 + 28c8082 commit 31575c1

File tree

3 files changed

+191
-76
lines changed

3 files changed

+191
-76
lines changed

ads/aqua/common/utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -801,6 +801,33 @@ def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
801801
return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
802802

803803

804+
def upload_folder(os_path: str, local_dir: str, model_name: str) -> str:
    """Upload a local folder to object storage using the OCI CLI.

    The files are uploaded with ``oci os object bulk-upload`` under
    ``<prefix of os_path>/<model_name>/``. Existing objects are not
    overwritten (``--no-overwrite``).

    Args:
        os_path (str): Object storage URI with prefix. This is the path to upload to.
        local_dir (str): Local directory whose content is uploaded.
        model_name (str): Name of the Hugging Face model; appended to the
            prefix of ``os_path`` to form the destination folder.

    Returns:
        str: ``oci://<bucket>@<namespace>/<prefix>/<model_name>/`` URI of the
        destination folder.

    Raises:
        ValueError: If versioning is not enabled on the target bucket.
    """
    os_details: ObjectStorageDetails = ObjectStorageDetails.from_path(os_path)
    if not os_details.is_bucket_versioned():
        raise ValueError(f"Version is not enabled at object storage location {os_path}")
    auth_state = AuthState()
    object_path = os_details.filepath.rstrip("/") + "/" + model_name + "/"
    # Build the argument vector directly instead of formatting one string and
    # re-splitting it: a directory or prefix containing whitespace or quote
    # characters would otherwise be broken into the wrong arguments.
    command = [
        "oci",
        "os",
        "object",
        "bulk-upload",
        "--src-dir",
        str(local_dir),
        "--prefix",
        object_path,
        "-bn",
        str(os_details.bucket),
        "-ns",
        str(os_details.namespace),
        "--auth",
        str(auth_state.oci_iam_type),
        "--profile",
        str(auth_state.oci_key_profile),
        "--no-overwrite",
    ]
    printable_command = " ".join(shlex.quote(arg) for arg in command)
    try:
        logger.info(f"Running: {printable_command}")
        subprocess.check_call(command)
    except subprocess.CalledProcessError as e:
        # check_call() does not capture the child's output, so e.stdout is
        # always None here; only the exit code is meaningful. The failure is
        # logged and the destination URI is still returned, as before.
        logger.error(f"Error uploading the object. Exit code: {e.returncode}")

    return f"oci://{os_details.bucket}@{os_details.namespace}" + "/" + object_path
829+
830+
804831
def is_service_managed_container(container):
    """Tell whether ``container`` uses the service managed container URI scheme.

    A falsy ``container`` (``None``, empty string, ...) is returned unchanged;
    otherwise the result of ``container.startswith(...)`` against
    ``SERVICE_MANAGED_CONTAINER_URI_SCHEME`` is returned.
    """
    if not container:
        return container
    return container.startswith(SERVICE_MANAGED_CONTAINER_URI_SCHEME)
806833

ads/aqua/model/entities.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,8 @@ def _extract_job_lifecycle_details(self, lifecycle_details):
278278
class ImportModelDetails(CLIBuilderMixin):
279279
model: str
280280
os_path: str
281-
download_from_hf: bool = True
281+
download_from_hf: Optional[bool] = True
282+
local_dir: Optional[str] = None
282283
inference_container: Optional[str] = None
283284
finetuning_container: Optional[str] = None
284285
compartment_id: Optional[str] = None

ads/aqua/model/model.py

Lines changed: 162 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
44
import os
55
from datetime import datetime, timedelta
6+
import pathlib
67
from threading import Lock
78
from typing import Dict, List, Optional, Set, Union
89

@@ -23,6 +24,7 @@
2324
list_os_files_with_extension,
2425
load_config,
2526
read_file,
27+
upload_folder,
2628
)
2729
from ads.aqua.constants import (
2830
AQUA_MODEL_ARTIFACT_CONFIG,
@@ -74,6 +76,8 @@
7476
from ads.model.model_metadata import ModelCustomMetadata, ModelCustomMetadataItem
7577
from ads.telemetry import telemetry
7678

79+
from huggingface_hub import HfApi, snapshot_download
80+
7781

7882
class AquaModelApp(AquaApp):
7983
"""Provides a suite of APIs to interact with Aqua models within the Oracle
@@ -779,12 +783,13 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
779783
list_os_files_with_extension(oss_path=os_path, extension=".gguf")
780784
)
781785
return model_files
782-
783-
def _validate_model_from_object_storage(
786+
787+
def _validate_model(
784788
self,
785-
os_path: str,
786-
model_name: str,
787-
verified_model: DataScienceModel,
789+
os_path: str = None,
790+
model_name: str = None,
791+
verified_model: DataScienceModel = None,
792+
download_from_hf: bool = None,
788793
) -> ModelValidationResult:
789794
"""
790795
Validates the model configuration and returns the model format telemetry model name.
@@ -793,29 +798,58 @@ def _validate_model_from_object_storage(
793798
os_path (str): OCI where the model is uploaded - oci://bucket@namespace/prefix
794799
model_name (str): name of the model
795800
verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service verified model
801+
download_from_hf (bool): If set, validates file formats from model downloaded from hugginface.
796802
797803
Returns:
798804
ModelValidationResult: The result of the model validation.
799805
800806
Raises:
801807
AquaRuntimeError: If there is an error while loading the config file or if the model path is incorrect.
802-
AquaValueError: If the model format is not supported by AQUA."""
803-
808+
AquaValueError: If the model format is not supported by AQUA.
809+
"""
804810
model_formats = []
805811
validation_result: ModelValidationResult = ModelValidationResult()
806812

807-
# as model formats can be either gguf or safetensors or both, we need to validate if there's at least
808-
# one way to register a valid model by checking if the oss path has the required artifacts.
809-
safetensors_model_files = self.get_model_files(os_path, ModelFormat.SAFETENSORS)
810-
gguf_model_files = self.get_model_files(os_path, ModelFormat.GGUF)
813+
safetensors_model_files = []
814+
gguf_model_files = []
815+
hf_download_config_present = False
816+
817+
if download_from_hf:
818+
model_siblings = []
819+
try:
820+
model_siblings = HfApi().model_info(
821+
repo_id=model_name
822+
).siblings
823+
except Exception as e:
824+
huggingface_err_message = str(e)
825+
raise AquaValueError(
826+
f"Could not get the model_info of {model_name} from https://huggingface.co with message {huggingface_err_message}."
827+
)
828+
829+
if not model_siblings:
830+
raise AquaValueError(
831+
f"Failed to fetch the model files of {model_name} from https://huggingface.co."
832+
)
833+
for model_sibling in model_siblings:
834+
extension = pathlib.Path(model_sibling.rfilename).suffix[1:].upper()
835+
if ModelFormat.SAFETENSORS.value == extension:
836+
if AQUA_MODEL_ARTIFACT_CONFIG not in safetensors_model_files:
837+
safetensors_model_files.append(AQUA_MODEL_ARTIFACT_CONFIG)
838+
elif ModelFormat.GGUF.value == extension:
839+
gguf_model_files.append(model_sibling.rfilename)
840+
elif model_sibling.rfilename == AQUA_MODEL_ARTIFACT_CONFIG:
841+
hf_download_config_present = True
842+
else:
843+
safetensors_model_files = self.get_model_files(os_path, ModelFormat.SAFETENSORS)
844+
gguf_model_files = self.get_model_files(os_path, ModelFormat.GGUF)
811845

812846
if not (safetensors_model_files or gguf_model_files):
813847
raise AquaRuntimeError(
814-
f"The model path {os_path} does not container either {ModelFormat.SAFETENSORS} "
815-
f"or {ModelFormat.GGUF} files. Please check if the path is correct or the model "
848+
f"The model {model_name} does not contain either {ModelFormat.SAFETENSORS.value} "
849+
f"or {ModelFormat.GGUF.value} files in {os_path} or Hugging Face repository. Please check if the path is correct or the model "
816850
f"artifacts are available at this location."
817851
)
818-
852+
819853
if verified_model:
820854
aqua_model = self.to_aqua_model(verified_model, self.region)
821855
model_formats = aqua_model.model_formats
@@ -835,62 +869,69 @@ def _validate_model_from_object_storage(
835869
model_format == ModelFormat.SAFETENSORS
836870
and len(safetensors_model_files) > 0
837871
):
838-
try:
839-
model_config = load_config(
840-
file_path=os_path,
841-
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
842-
)
843-
except Exception as ex:
844-
logger.error(
845-
f"Exception occurred while loading config file from {os_path}"
846-
f"Exception message: {ex}"
847-
)
848-
raise AquaRuntimeError(
849-
f"The model path {os_path} does not contain the file config.json. "
850-
f"Please check if the path is correct or the model artifacts are available at this location."
851-
) from ex
872+
if download_from_hf:
873+
# validates config.json exists for safetensors model from hugginface
874+
if not hf_download_config_present:
875+
raise AquaRuntimeError(
876+
f"The model {model_name} does not contain {AQUA_MODEL_ARTIFACT_CONFIG} file as required "
877+
f"by {ModelFormat.SAFETENSORS.value} format model. Please check if the model name is correct in Hugging Face repository."
878+
)
852879
else:
853880
try:
854-
metadata_model_type = verified_model.custom_metadata_list.get(
855-
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
856-
).value
857-
if metadata_model_type:
858-
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
859-
if (
860-
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
861-
!= metadata_model_type
862-
):
863-
raise AquaRuntimeError(
864-
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
865-
f" at {os_path} is invalid, expected {metadata_model_type} for "
866-
f"the model {model_name}. Please check if the path is correct or "
867-
f"the correct model artifacts are available at this location."
868-
f""
869-
)
870-
else:
871-
logger.debug(
872-
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
873-
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
874-
)
875-
except Exception:
876-
pass
877-
if verified_model:
878-
validation_result.telemetry_model_name = (
879-
verified_model.display_name
881+
model_config = load_config(
882+
file_path=os_path,
883+
config_file_name=AQUA_MODEL_ARTIFACT_CONFIG,
884+
)
885+
except Exception as ex:
886+
logger.error(
887+
f"Exception occurred while loading config file from {os_path}"
888+
f"Exception message: {ex}"
880889
)
881-
elif (
882-
model_config is not None
883-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
884-
):
885-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
886-
elif (
887-
model_config is not None
888-
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
889-
):
890-
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
890+
raise AquaRuntimeError(
891+
f"The model path {os_path} does not contain the file config.json. "
892+
f"Please check if the path is correct or the model artifacts are available at this location."
893+
) from ex
891894
else:
892-
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
893-
895+
try:
896+
metadata_model_type = verified_model.custom_metadata_list.get(
897+
AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
898+
).value
899+
if metadata_model_type:
900+
if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config:
901+
if (
902+
model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]
903+
!= metadata_model_type
904+
):
905+
raise AquaRuntimeError(
906+
f"The {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in {AQUA_MODEL_ARTIFACT_CONFIG}"
907+
f" at {os_path} is invalid, expected {metadata_model_type} for "
908+
f"the model {model_name}. Please check if the path is correct or "
909+
f"the correct model artifacts are available at this location."
910+
f""
911+
)
912+
else:
913+
logger.debug(
914+
f"Could not find {AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE} attribute in "
915+
f"{AQUA_MODEL_ARTIFACT_CONFIG}. Proceeding with model registration."
916+
)
917+
except Exception:
918+
pass
919+
if verified_model:
920+
validation_result.telemetry_model_name = (
921+
verified_model.display_name
922+
)
923+
elif (
924+
model_config is not None
925+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
926+
):
927+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME]}"
928+
elif (
929+
model_config is not None
930+
and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
931+
):
932+
validation_result.telemetry_model_name = f"{AQUA_MODEL_TYPE_CUSTOM}_{model_config[AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE]}"
933+
else:
934+
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
894935
elif model_format == ModelFormat.GGUF and len(gguf_model_files) > 0:
895936
if verified_model:
896937
try:
@@ -914,10 +955,54 @@ def _validate_model_from_object_storage(
914955
if verified_model:
915956
validation_result.telemetry_model_name = verified_model.display_name
916957
else:
917-
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
958+
if download_from_hf:
959+
validation_result.telemetry_model_name = model_name
960+
else:
961+
validation_result.telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
918962

919963
return validation_result
920964

965+
966+
def _download_model_from_hf(
    self,
    import_model_details: ImportModelDetails,
) -> str:
    """Download a model from Hugging Face and upload it to object storage.

    The snapshot is first fetched into the Hugging Face cache, retrying up to
    10 times on any error. It is then materialised under
    ``<local_dir>/<model name>`` and uploaded with ``upload_folder``.

    Args:
        import_model_details (ImportModelDetails): Import request. This method
            reads ``model`` (Hugging Face repo id), ``local_dir`` (optional
            download directory, defaults to ``~/cached-model``) and ``os_path``
            (object storage destination).

    Returns:
        str: The value returned by ``upload_folder`` for the uploaded model.

    Raises:
        Exception: If every download attempt fails. The last underlying error
            is attached as the cause.
    """
    model_name = import_model_details.model
    local_dir = import_model_details.local_dir
    if not local_dir:
        local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
    local_dir = os.path.join(local_dir, model_name)

    max_attempts = 10
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            # Download to the cache folder. The loop retries when there is a
            # network failure.
            snapshot_download(repo_id=model_name)
        except Exception as e:
            last_error = e
            logger.warning(
                f"Attempt {attempt} of {max_attempts} to download {model_name} "
                f"from https://huggingface.co failed with message {e}"
            )
        else:
            break
    else:
        # Every attempt failed: keep the original error as the cause so the
        # underlying traceback is not lost.
        raise Exception(
            f"Could not download the model {model_name} from https://huggingface.co with message {last_error}"
        ) from last_error

    os.makedirs(local_dir, exist_ok=True)
    # Copy the model from the cache to destination
    snapshot_download(repo_id=model_name, local_dir=local_dir)
    # Upload to object storage
    model_artifact_path = upload_folder(
        os_path=import_model_details.os_path,
        local_dir=local_dir,
        model_name=model_name,
    )

    return model_artifact_path
1005+
9211006
def register(
9221007
self, import_model_details: ImportModelDetails = None, **kwargs
9231008
) -> AquaModel:
@@ -966,15 +1051,17 @@ def register(
9661051
)
9671052

9681053
# validate model and artifact
1054+
validation_result = self._validate_model(
1055+
os_path=import_model_details.os_path,
1056+
model_name=model_name,
1057+
verified_model=verified_model,
1058+
download_from_hf=import_model_details.download_from_hf,
1059+
)
1060+
1061+
# download model from hugginface if indicates
9691062
if import_model_details.download_from_hf:
970-
# todo: added placeholder here, this should be called even before starting HF download.
971-
# as validation can be done on the files that will be downloaded.
972-
validation_result = None
973-
else:
974-
validation_result = self._validate_model_from_object_storage(
975-
os_path=import_model_details.os_path,
976-
model_name=model_name,
977-
verified_model=verified_model,
1063+
self._download_model_from_hf(
1064+
import_model_details=import_model_details
9781065
)
9791066

9801067
# Create Model catalog entry with pass by reference

0 commit comments

Comments
 (0)