Skip to content

Commit

Permalink
Fix use_auth_token with ORTModel (#1740)
Browse files Browse the repository at this point in the history
fix use_auth_token
  • Loading branch information
fxmarty committed Mar 19, 2024
1 parent 7e08a82 commit 568aa35
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 5 deletions.
13 changes: 11 additions & 2 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,6 +1378,7 @@ def get_model_files(
model_name_or_path: Union[str, Path],
subfolder: str = "",
cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
use_auth_token: Optional[str] = None,
):
request_exception = None
full_model_path = Path(model_name_or_path) / subfolder
Expand All @@ -1391,7 +1392,9 @@ def get_model_files(
try:
if not isinstance(model_name_or_path, str):
model_name_or_path = str(model_name_or_path)
all_files = huggingface_hub.list_repo_files(model_name_or_path, repo_type="model")
all_files = huggingface_hub.list_repo_files(
model_name_or_path, repo_type="model", token=use_auth_token
)
if subfolder != "":
all_files = [file[len(subfolder) + 1 :] for file in all_files if file.startswith(subfolder)]
except RequestsConnectionError as e: # Hub not accessible
Expand Down Expand Up @@ -1672,6 +1675,7 @@ def infer_library_from_model(
revision: Optional[str] = None,
cache_dir: str = huggingface_hub.constants.HUGGINGFACE_HUB_CACHE,
library_name: Optional[str] = None,
use_auth_token: Optional[str] = None,
):
"""
Infers the library from the model repo.
Expand All @@ -1689,13 +1693,17 @@ def infer_library_from_model(
Path to a directory in which a downloaded pretrained model weights have been cached if the standard cache should not be used.
library_name (`Optional[str]`, *optional*):
The library name of the model. Can be any of "transformers", "timm", "diffusers", "sentence_transformers".
use_auth_token (`Optional[str]`, defaults to `None`):
The token to use as HTTP bearer authorization for remote files.
Returns:
`str`: The library name automatically detected from the model repo.
"""
if library_name is not None:
return library_name

all_files, _ = TasksManager.get_model_files(model_name_or_path, subfolder, cache_dir)
all_files, _ = TasksManager.get_model_files(
model_name_or_path, subfolder, cache_dir, use_auth_token=use_auth_token
)

if "model_index.json" in all_files:
library_name = "diffusers"
Expand All @@ -1710,6 +1718,7 @@ def infer_library_from_model(
"subfolder": subfolder,
"revision": revision,
"cache_dir": cache_dir,
"use_auth_token": use_auth_token,
}
config_dict, kwargs = PretrainedConfig.get_config_dict(model_name_or_path, **kwargs)
model_config = PretrainedConfig.from_dict(config_dict, **kwargs)
Expand Down
4 changes: 3 additions & 1 deletion optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,9 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir)
library_name = TasksManager.infer_library_from_model(
model_id, subfolder, revision, cache_dir, use_auth_token=use_auth_token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)
Expand Down
7 changes: 5 additions & 2 deletions tests/onnxruntime/test_modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,9 +937,12 @@ def test_stable_diffusion_model_on_rocm_ep_str(self):
self.assertEqual(model.vae_encoder.session.get_providers()[0], "ROCMExecutionProvider")
self.assertListEqual(model.providers, ["ROCMExecutionProvider", "CPUExecutionProvider"])

@require_hf_token
def test_load_model_from_hub_private(self):
model = ORTModel.from_pretrained(self.ONNX_MODEL_ID, use_auth_token=os.environ.get("HF_AUTH_TOKEN", None))
subprocess.run("huggingface-cli logout", shell=True)
# Read token of fxmartyclone (dummy user).
token = "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"

model = ORTModelForCustomTasks.from_pretrained("fxmartyclone/tiny-onnx-private-2", use_auth_token=token)
self.assertIsInstance(model.model, onnxruntime.InferenceSession)
self.assertIsInstance(model.config, PretrainedConfig)

Expand Down

0 comments on commit 568aa35

Please sign in to comment.