Skip to content

Commit

Permalink
Catch all errors
Browse files Browse the repository at this point in the history
  • Loading branch information
LysandreJik committed Mar 25, 2024
1 parent 65c092a commit 9a7bcb2
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3256,7 +3256,7 @@ def from_pretrained(
Thread(
target=auto_conversion,
args=(pretrained_model_name_or_path,),
kwargs=cached_file_kwargs,
kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
name="Thread-autoconversion",
).start()
else:
Expand Down
46 changes: 25 additions & 21 deletions src/transformers/safetensors_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,24 +84,28 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
return sha


def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
api = HfApi(token=cached_file_kwargs.get("token"))
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)

if sha is None:
return None, None
cached_file_kwargs["revision"] = sha
del cached_file_kwargs["_commit_hash"]

# This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
# description.
sharded = api.file_exists(
pretrained_model_name_or_path,
"model.safetensors.index.json",
revision=sha,
token=cached_file_kwargs.get("token"),
)
filename = "model.safetensors.index.json" if sharded else "model.safetensors"

resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
return resolved_archive_file, sha, sharded
def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
try:
api = HfApi(token=cached_file_kwargs.get("token"))
sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)

if sha is None:
return None, None
cached_file_kwargs["revision"] = sha
del cached_file_kwargs["_commit_hash"]

# This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
# description.
sharded = api.file_exists(
pretrained_model_name_or_path,
"model.safetensors.index.json",
revision=sha,
token=cached_file_kwargs.get("token"),
)
filename = "model.safetensors.index.json" if sharded else "model.safetensors"

resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
return resolved_archive_file, sha, sharded
except Exception as e:
if not ignore_errors_during_conversion:
raise e

0 comments on commit 9a7bcb2

Please sign in to comment.