3
3
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
4
4
import os
5
5
from datetime import datetime , timedelta
6
+ import pathlib
6
7
from threading import Lock
7
8
from typing import Dict , List , Optional , Set , Union
8
9
23
24
list_os_files_with_extension ,
24
25
load_config ,
25
26
read_file ,
27
+ upload_folder ,
26
28
)
27
29
from ads .aqua .constants import (
28
30
AQUA_MODEL_ARTIFACT_CONFIG ,
74
76
from ads .model .model_metadata import ModelCustomMetadata , ModelCustomMetadataItem
75
77
from ads .telemetry import telemetry
76
78
79
+ from huggingface_hub import HfApi , snapshot_download
80
+
77
81
78
82
class AquaModelApp (AquaApp ):
79
83
"""Provides a suite of APIs to interact with Aqua models within the Oracle
@@ -779,12 +783,13 @@ def get_model_files(os_path: str, model_format: ModelFormat) -> List[str]:
779
783
list_os_files_with_extension (oss_path = os_path , extension = ".gguf" )
780
784
)
781
785
return model_files
782
-
783
- def _validate_model_from_object_storage (
786
+
787
+ def _validate_model (
784
788
self ,
785
- os_path : str ,
786
- model_name : str ,
787
- verified_model : DataScienceModel ,
789
+ os_path : str = None ,
790
+ model_name : str = None ,
791
+ verified_model : DataScienceModel = None ,
792
+ download_from_hf : bool = None ,
788
793
) -> ModelValidationResult :
789
794
"""
790
795
Validates the model configuration and returns the model format telemetry model name.
@@ -793,29 +798,58 @@ def _validate_model_from_object_storage(
793
798
os_path (str): OCI where the model is uploaded - oci://bucket@namespace/prefix
794
799
model_name (str): name of the model
795
800
verified_model (DataScienceModel): If set, then copies all the tags and custom metadata information from the service verified model
801
+ download_from_hf (bool): If set, validates the file formats listed in the model's Hugging Face repository.
796
802
797
803
Returns:
798
804
ModelValidationResult: The result of the model validation.
799
805
800
806
Raises:
801
807
AquaRuntimeError: If there is an error while loading the config file or if the model path is incorrect.
802
- AquaValueError: If the model format is not supported by AQUA."""
803
-
808
+ AquaValueError: If the model format is not supported by AQUA.
809
+ """
804
810
model_formats = []
805
811
validation_result : ModelValidationResult = ModelValidationResult ()
806
812
807
- # as model formats can be either gguf or safetensors or both, we need to validate if there's at least
808
- # one way to register a valid model by checking if the oss path has the required artifacts.
809
- safetensors_model_files = self .get_model_files (os_path , ModelFormat .SAFETENSORS )
810
- gguf_model_files = self .get_model_files (os_path , ModelFormat .GGUF )
813
+ safetensors_model_files = []
814
+ gguf_model_files = []
815
+ hf_download_config_present = False
816
+
817
+ if download_from_hf :
818
+ model_siblings = []
819
+ try :
820
+ model_siblings = HfApi ().model_info (
821
+ repo_id = model_name
822
+ ).siblings
823
+ except Exception as e :
824
+ huggingface_err_message = str (e )
825
+ raise AquaValueError (
826
+ f"Could not get the model_info of { model_name } from https://huggingface.co with message { huggingface_err_message } ."
827
+ )
828
+
829
+ if not model_siblings :
830
+ raise AquaValueError (
831
+ f"Failed to fetch the model files of { model_name } from https://huggingface.co."
832
+ )
833
+ for model_sibling in model_siblings :
834
+ extension = pathlib .Path (model_sibling .rfilename ).suffix [1 :].upper ()
835
+ if ModelFormat .SAFETENSORS .value == extension :
836
+ if AQUA_MODEL_ARTIFACT_CONFIG not in safetensors_model_files :
837
+ safetensors_model_files .append (AQUA_MODEL_ARTIFACT_CONFIG )
838
+ elif ModelFormat .GGUF .value == extension :
839
+ gguf_model_files .append (model_sibling .rfilename )
840
+ elif model_sibling .rfilename == AQUA_MODEL_ARTIFACT_CONFIG :
841
+ hf_download_config_present = True
842
+ else :
843
+ safetensors_model_files = self .get_model_files (os_path , ModelFormat .SAFETENSORS )
844
+ gguf_model_files = self .get_model_files (os_path , ModelFormat .GGUF )
811
845
812
846
if not (safetensors_model_files or gguf_model_files ):
813
847
raise AquaRuntimeError (
814
- f"The model path { os_path } does not container either { ModelFormat .SAFETENSORS } "
815
- f"or { ModelFormat .GGUF } files. Please check if the path is correct or the model "
848
+ f"The model { model_name } does not contain either { ModelFormat .SAFETENSORS . value } "
849
+ f"or { ModelFormat .GGUF . value } files in { os_path } or Hugging Face repository . Please check if the path is correct or the model "
816
850
f"artifacts are available at this location."
817
851
)
818
-
852
+
819
853
if verified_model :
820
854
aqua_model = self .to_aqua_model (verified_model , self .region )
821
855
model_formats = aqua_model .model_formats
@@ -835,62 +869,69 @@ def _validate_model_from_object_storage(
835
869
model_format == ModelFormat .SAFETENSORS
836
870
and len (safetensors_model_files ) > 0
837
871
):
838
- try :
839
- model_config = load_config (
840
- file_path = os_path ,
841
- config_file_name = AQUA_MODEL_ARTIFACT_CONFIG ,
842
- )
843
- except Exception as ex :
844
- logger .error (
845
- f"Exception occurred while loading config file from { os_path } "
846
- f"Exception message: { ex } "
847
- )
848
- raise AquaRuntimeError (
849
- f"The model path { os_path } does not contain the file config.json. "
850
- f"Please check if the path is correct or the model artifacts are available at this location."
851
- ) from ex
872
+ if download_from_hf :
873
+ # validate that config.json exists for a safetensors model from Hugging Face
874
+ if not hf_download_config_present :
875
+ raise AquaRuntimeError (
876
+ f"The model { model_name } does not contain { AQUA_MODEL_ARTIFACT_CONFIG } file as required "
877
+ f"by { ModelFormat .SAFETENSORS .value } format model. Please check if the model name is correct in Hugging Face repository."
878
+ )
852
879
else :
853
880
try :
854
- metadata_model_type = verified_model .custom_metadata_list .get (
855
- AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
856
- ).value
857
- if metadata_model_type :
858
- if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config :
859
- if (
860
- model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]
861
- != metadata_model_type
862
- ):
863
- raise AquaRuntimeError (
864
- f"The { AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE } attribute in { AQUA_MODEL_ARTIFACT_CONFIG } "
865
- f" at { os_path } is invalid, expected { metadata_model_type } for "
866
- f"the model { model_name } . Please check if the path is correct or "
867
- f"the correct model artifacts are available at this location."
868
- f""
869
- )
870
- else :
871
- logger .debug (
872
- f"Could not find { AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE } attribute in "
873
- f"{ AQUA_MODEL_ARTIFACT_CONFIG } . Proceeding with model registration."
874
- )
875
- except Exception :
876
- pass
877
- if verified_model :
878
- validation_result .telemetry_model_name = (
879
- verified_model .display_name
881
+ model_config = load_config (
882
+ file_path = os_path ,
883
+ config_file_name = AQUA_MODEL_ARTIFACT_CONFIG ,
884
+ )
885
+ except Exception as ex :
886
+ logger .error (
887
+ f"Exception occurred while loading config file from { os_path } "
888
+ f"Exception message: { ex } "
880
889
)
881
- elif (
882
- model_config is not None
883
- and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
884
- ):
885
- validation_result .telemetry_model_name = f"{ AQUA_MODEL_TYPE_CUSTOM } _{ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME ]} "
886
- elif (
887
- model_config is not None
888
- and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
889
- ):
890
- validation_result .telemetry_model_name = f"{ AQUA_MODEL_TYPE_CUSTOM } _{ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]} "
890
+ raise AquaRuntimeError (
891
+ f"The model path { os_path } does not contain the file config.json. "
892
+ f"Please check if the path is correct or the model artifacts are available at this location."
893
+ ) from ex
891
894
else :
892
- validation_result .telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
893
-
895
+ try :
896
+ metadata_model_type = verified_model .custom_metadata_list .get (
897
+ AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE
898
+ ).value
899
+ if metadata_model_type :
900
+ if AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config :
901
+ if (
902
+ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]
903
+ != metadata_model_type
904
+ ):
905
+ raise AquaRuntimeError (
906
+ f"The { AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE } attribute in { AQUA_MODEL_ARTIFACT_CONFIG } "
907
+ f" at { os_path } is invalid, expected { metadata_model_type } for "
908
+ f"the model { model_name } . Please check if the path is correct or "
909
+ f"the correct model artifacts are available at this location."
910
+ f""
911
+ )
912
+ else :
913
+ logger .debug (
914
+ f"Could not find { AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE } attribute in "
915
+ f"{ AQUA_MODEL_ARTIFACT_CONFIG } . Proceeding with model registration."
916
+ )
917
+ except Exception :
918
+ pass
919
+ if verified_model :
920
+ validation_result .telemetry_model_name = (
921
+ verified_model .display_name
922
+ )
923
+ elif (
924
+ model_config is not None
925
+ and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME in model_config
926
+ ):
927
+ validation_result .telemetry_model_name = f"{ AQUA_MODEL_TYPE_CUSTOM } _{ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_NAME ]} "
928
+ elif (
929
+ model_config is not None
930
+ and AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE in model_config
931
+ ):
932
+ validation_result .telemetry_model_name = f"{ AQUA_MODEL_TYPE_CUSTOM } _{ model_config [AQUA_MODEL_ARTIFACT_CONFIG_MODEL_TYPE ]} "
933
+ else :
934
+ validation_result .telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
894
935
elif model_format == ModelFormat .GGUF and len (gguf_model_files ) > 0 :
895
936
if verified_model :
896
937
try :
@@ -914,10 +955,54 @@ def _validate_model_from_object_storage(
914
955
if verified_model :
915
956
validation_result .telemetry_model_name = verified_model .display_name
916
957
else :
917
- validation_result .telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
958
+ if download_from_hf :
959
+ validation_result .telemetry_model_name = model_name
960
+ else :
961
+ validation_result .telemetry_model_name = AQUA_MODEL_TYPE_CUSTOM
918
962
919
963
return validation_result
920
964
965
+
966
def _download_model_from_hf(
    self,
    import_model_details: ImportModelDetails,
):
    """Download a Hugging Face model snapshot and upload it to object storage.

    The snapshot is first fetched with ``snapshot_download`` (retried on any
    failure), then materialized under a local directory, and finally pushed
    to ``import_model_details.os_path`` with ``upload_folder``.

    Args:
        import_model_details (ImportModelDetails): Import request. This method
            reads its ``model`` (Hugging Face repo id), ``local_dir`` (optional
            download root) and ``os_path`` (upload destination) attributes.

    Returns:
        The value returned by ``upload_folder`` for the uploaded artifacts.

    Raises:
        Exception: If every download attempt fails. The last underlying
            error is attached as ``__cause__``.
    """
    model_name = import_model_details.model

    # Fall back to ~/cached-model when the caller did not choose a directory;
    # each model gets its own sub-folder named after the repo id.
    local_dir = import_model_details.local_dir
    if not local_dir:
        local_dir = os.path.join(os.path.expanduser("~"), "cached-model")
    local_dir = os.path.join(local_dir, model_name)

    max_attempts = 10
    last_error = None
    for _ in range(max_attempts):
        try:
            # Download to the cache folder. Retried because the transfer can
            # fail part-way through, e.g. on a network failure.
            snapshot_download(repo_id=model_name)
        except Exception as e:  # best-effort retry on any failure
            last_error = e
        else:
            break
    else:
        # All attempts failed: keep the original message text but chain the
        # last error so its traceback is not lost.
        raise Exception(
            f"Could not download the model {model_name} from https://huggingface.co with message {str(last_error)}"
        ) from last_error

    os.makedirs(local_dir, exist_ok=True)
    # Copy the model from the cache to the destination directory
    snapshot_download(repo_id=model_name, local_dir=local_dir)

    # Upload to object storage
    model_artifact_path = upload_folder(
        os_path=import_model_details.os_path,
        local_dir=local_dir,
        model_name=model_name,
    )

    return model_artifact_path
1005
+
921
1006
def register (
922
1007
self , import_model_details : ImportModelDetails = None , ** kwargs
923
1008
) -> AquaModel :
@@ -966,15 +1051,17 @@ def register(
966
1051
)
967
1052
968
1053
# validate model and artifact
1054
+ validation_result = self ._validate_model (
1055
+ os_path = import_model_details .os_path ,
1056
+ model_name = model_name ,
1057
+ verified_model = verified_model ,
1058
+ download_from_hf = import_model_details .download_from_hf ,
1059
+ )
1060
+
1061
+ # download the model from Hugging Face if requested
969
1062
if import_model_details .download_from_hf :
970
- # todo: added placeholder here, this should be called even before starting HF download.
971
- # as validation can be done on the files that will be downloaded.
972
- validation_result = None
973
- else :
974
- validation_result = self ._validate_model_from_object_storage (
975
- os_path = import_model_details .os_path ,
976
- model_name = model_name ,
977
- verified_model = verified_model ,
1063
+ self ._download_model_from_hf (
1064
+ import_model_details = import_model_details
978
1065
)
979
1066
980
1067
# Create Model catalog entry with pass by reference
0 commit comments