Skip to content

Commit 06eda5b

Browse files
authored
Raise initial HTTPError if pipeline is not cached locally (#4230)
* Raise initial HTTPError if pipeline is not cached locally * make style
1 parent 8e5921c commit 06eda5b

File tree

1 file changed

+30
-14
lines changed

1 file changed

+30
-14
lines changed

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12481248
allow_patterns = None
12491249
ignore_patterns = None
12501250

1251+
model_info_call_error: Optional[Exception] = None
12511252
if not local_files_only:
12521253
try:
12531254
info = model_info(
@@ -1258,6 +1259,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
12581259
except HTTPError as e:
12591260
logger.warn(f"Couldn't connect to the Hub: {e}.\nWill try to load from local cache.")
12601261
local_files_only = True
1262+
model_info_call_error = e # save error to reraise it if model is not cached locally
12611263

12621264
if not local_files_only:
12631265
config_file = hf_hub_download(
@@ -1389,20 +1391,34 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
13891391
user_agent["custom_pipeline"] = custom_pipeline
13901392

13911393
# download all allow_patterns - ignore_patterns
1392-
cached_folder = snapshot_download(
1393-
pretrained_model_name,
1394-
cache_dir=cache_dir,
1395-
resume_download=resume_download,
1396-
proxies=proxies,
1397-
local_files_only=local_files_only,
1398-
use_auth_token=use_auth_token,
1399-
revision=revision,
1400-
allow_patterns=allow_patterns,
1401-
ignore_patterns=ignore_patterns,
1402-
user_agent=user_agent,
1403-
)
1404-
1405-
return cached_folder
1394+
try:
1395+
return snapshot_download(
1396+
pretrained_model_name,
1397+
cache_dir=cache_dir,
1398+
resume_download=resume_download,
1399+
proxies=proxies,
1400+
local_files_only=local_files_only,
1401+
use_auth_token=use_auth_token,
1402+
revision=revision,
1403+
allow_patterns=allow_patterns,
1404+
ignore_patterns=ignore_patterns,
1405+
user_agent=user_agent,
1406+
)
1407+
except FileNotFoundError:
1408+
# Means we tried to load pipeline with `local_files_only=True` but the files have not been found in local cache.
1409+
# This can happen in two cases:
1410+
# 1. If the user passed `local_files_only=True` => we raise the error directly
1411+
# 2. If we forced `local_files_only=True` when `model_info` failed => we raise the initial error
1412+
if model_info_call_error is None:
1413+
# 1. user passed `local_files_only=True`
1414+
raise
1415+
else:
1416+
# 2. we forced `local_files_only=True` when `model_info` failed
1417+
raise EnvironmentError(
1418+
f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occured"
1419+
" while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
1420+
" above."
1421+
) from model_info_call_error
14061422

14071423
@staticmethod
14081424
def _get_signature_keys(obj):

0 commit comments

Comments
 (0)