diff --git a/src/huggingface_hub/file_download.py b/src/huggingface_hub/file_download.py index 6344e2b4fb..1720036084 100644 --- a/src/huggingface_hub/file_download.py +++ b/src/huggingface_hub/file_download.py @@ -122,6 +122,10 @@ def is_tf_available(): return _tf_available +def get_tf_version(): + return _tf_version + + def is_fastai_available(): return _fastai_available diff --git a/src/huggingface_hub/keras_mixin.py b/src/huggingface_hub/keras_mixin.py index 0947e2a95f..46183a48c3 100644 --- a/src/huggingface_hub/keras_mixin.py +++ b/src/huggingface_hub/keras_mixin.py @@ -8,6 +8,7 @@ import yaml from huggingface_hub import ModelHubMixin from huggingface_hub.file_download import ( + get_tf_version, is_graphviz_available, is_pydot_available, is_tf_available, @@ -509,7 +510,11 @@ def _from_pretrained( # Root is either a local filepath matching model_id or a cached snapshot if not os.path.isdir(model_id): storage_folder = snapshot_download( - repo_id=model_id, revision=revision, cache_dir=cache_dir + repo_id=model_id, + revision=revision, + cache_dir=cache_dir, + library_name="keras", + library_version=get_tf_version(), ) else: storage_folder = model_id