Reimplement "Automatic safetensors conversion when lacking these files" #29846

Merged · 2 commits · Mar 27, 2024
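In short, after this change, loading a checkpoint that ships only PyTorch weights spawns a background thread that opens a safetensors conversion PR on the Hub. A sketch of the resulting behavior, assuming a hypothetical repo `some-user/pytorch-only-model` that has `pytorch_model.bin` but no safetensors file:

import threading

from transformers import BertModel

# Hypothetical repo that only has PyTorch weights; the load itself is unchanged.
model = BertModel.from_pretrained("some-user/pytorch-only-model")

# The conversion runs on a fire-and-forget thread; as in the new test below, it can
# be found by name and joined if the caller wants to wait for the PR to be opened.
for thread in threading.enumerate():
    if thread.name == "Thread-autoconversion":
        thread.join(timeout=10)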
37 changes: 34 additions & 3 deletions src/transformers/modeling_utils.py
@@ -29,6 +29,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
+from threading import Thread
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from zipfile import is_zipfile

@@ -3228,9 +3229,39 @@ def from_pretrained(
                         )
                         if resolved_archive_file is not None:
                             is_sharded = True
-                if resolved_archive_file is None:
-                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
-                    # message.
+
+                if resolved_archive_file is not None:
+                    if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
+                        # If the PyTorch file was found, check if there is a safetensors file on the repository.
+                        # If there is no safetensors file on the repository, start an auto conversion.
+                        safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "proxies": proxies,
+                            "token": token,
+                        }
+                        cached_file_kwargs = {
+                            "cache_dir": cache_dir,
+                            "force_download": force_download,
+                            "resume_download": resume_download,
+                            "local_files_only": local_files_only,
+                            "user_agent": user_agent,
+                            "subfolder": subfolder,
+                            "_raise_exceptions_for_gated_repo": False,
+                            "_raise_exceptions_for_missing_entries": False,
+                            "_commit_hash": commit_hash,
+                            **has_file_kwargs,
+                        }
+                        if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
+                            Thread(
+                                target=auto_conversion,
+                                args=(pretrained_model_name_or_path,),
+                                kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
+                                name="Thread-autoconversion",
+                            ).start()
+                else:
+                    # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
+                    # We try those to give a helpful error message.
                     has_file_kwargs = {
                         "revision": revision,
                         "proxies": proxies,
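The pattern here is deliberately fire-and-forget: `from_pretrained` checks once whether safetensors weights exist and, if not, hands the conversion to a named thread that swallows its own errors, so the load neither blocks nor fails. A minimal self-contained sketch of that pattern, with `check_remote_file` and `convert_weights` as hypothetical stand-ins for `has_file` and `auto_conversion`:

from threading import Thread


def check_remote_file(repo_id: str, filename: str) -> bool:
    # Hypothetical stand-in for transformers' `has_file`.
    return False  # pretend the repo has no safetensors weights


def convert_weights(repo_id: str, ignore_errors_during_conversion: bool = False) -> None:
    # Hypothetical stand-in for `auto_conversion`: Hub calls can raise on network errors.
    try:
        raise ConnectionError("simulated Hub outage")
    except Exception:
        if not ignore_errors_during_conversion:
            raise  # with ignore_errors_during_conversion=True, the worker fails silently


repo_id = "some-user/pytorch-only-model"  # hypothetical repo id
if not check_remote_file(repo_id, "model.safetensors"):
    # Named so tests (and callers) can locate it via threading.enumerate(); started
    # without join(), so the surrounding load returns immediately.
    Thread(
        target=convert_weights,
        args=(repo_id,),
        kwargs={"ignore_errors_during_conversion": True},
        name="Thread-autoconversion",
    ).start()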
46 changes: 25 additions & 21 deletions src/transformers/safetensors_conversion.py
@@ -84,24 +84,28 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
     return sha


-def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
-    api = HfApi(token=cached_file_kwargs.get("token"))
-    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
-
-    if sha is None:
-        return None, None
-    cached_file_kwargs["revision"] = sha
-    del cached_file_kwargs["_commit_hash"]
-
-    # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
-    # description.
-    sharded = api.file_exists(
-        pretrained_model_name_or_path,
-        "model.safetensors.index.json",
-        revision=sha,
-        token=cached_file_kwargs.get("token"),
-    )
-    filename = "model.safetensors.index.json" if sharded else "model.safetensors"
-
-    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
-    return resolved_archive_file, sha, sharded
+def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
+    try:
+        api = HfApi(token=cached_file_kwargs.get("token"))
+        sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
+
+        if sha is None:
+            return None, None
+        cached_file_kwargs["revision"] = sha
+        del cached_file_kwargs["_commit_hash"]
+
+        # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
+        # description.
+        sharded = api.file_exists(
+            pretrained_model_name_or_path,
+            "model.safetensors.index.json",
+            revision=sha,
+            token=cached_file_kwargs.get("token"),
+        )
+        filename = "model.safetensors.index.json" if sharded else "model.safetensors"
+
+        resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+        return resolved_archive_file, sha, sharded
+    except Exception as e:
+        if not ignore_errors_during_conversion:
+            raise e
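Note one consequence of the wrapper: when `ignore_errors_during_conversion=True` and anything in the `try` block raises, execution falls off the end of the function, so the caller gets `None` rather than the usual `(resolved_archive_file, sha, sharded)` tuple. The background thread discards the result, so this is harmless there, but any direct caller has to handle it. A generic sketch of the swallow-or-raise shape (illustrative names, not the transformers API):

from typing import Optional, Tuple


def hub_calls(repo_id: str) -> Tuple[str, str, bool]:
    # Hypothetical stand-in for the PR-reference lookup and file download.
    raise ConnectionError("simulated Hub outage")


def fetch_converted(repo_id: str, ignore_errors: bool = False) -> Optional[Tuple[str, str, bool]]:
    try:
        return hub_calls(repo_id)
    except Exception:
        if not ignore_errors:
            raise
    # Swallowed errors fall through to an implicit `return None`.


print(fetch_converted("some-user/model", ignore_errors=True))  # prints: None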
48 changes: 47 additions & 1 deletion tests/test_modeling_utils.py
@@ -20,6 +20,7 @@
 import os.path
 import sys
 import tempfile
+import threading
 import unittest
 import unittest.mock as mock
 import uuid

@@ -1429,7 +1430,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
         bot_opened_pr_title = None

         for discussion in discussions:
-            if discussion.author == "SFconvertBot":
+            if discussion.author == "SFconvertbot":
                 bot_opened_pr = True
                 bot_opened_pr_title = discussion.title

@@ -1452,6 +1453,51 @@ def test_safetensors_on_the_fly_specific_revision(self):
         with self.assertRaises(EnvironmentError):
             BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")

+    def test_absence_of_safetensors_triggers_conversion(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        initial_model = BertModel(config)
+
+        # Push a model on `main`
+        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
+
+        # Download the model that doesn't have safetensors
+        BertModel.from_pretrained(self.repo_name, token=self.token)
+
+        for thread in threading.enumerate():
+            if thread.name == "Thread-autoconversion":
+                thread.join(timeout=10)
+
+        with self.subTest("PR was open with the safetensors account"):
+            discussions = self.api.get_repo_discussions(self.repo_name)
+
+            bot_opened_pr = None
+            bot_opened_pr_title = None
+
+            for discussion in discussions:
+                if discussion.author == "SFconvertbot":
+                    bot_opened_pr = True
+                    bot_opened_pr_title = discussion.title
+
+            self.assertTrue(bot_opened_pr)
+            self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
+
+    @mock.patch("transformers.safetensors_conversion.spawn_conversion")
+    def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):
+        spawn_conversion_mock.side_effect = HTTPError()
+
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        initial_model = BertModel(config)
+
+        # Push a model on `main`
+        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
+
+        # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread
+        BertModel.from_pretrained(self.repo_name, token=self.token)
+

 @require_torch
 @is_staging_test
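One caveat when reusing the test's waiting pattern: `Thread.join(timeout=...)` returns silently even if the thread is still running, so code that needs to know whether the conversion attempt actually finished should check `is_alive()` afterwards. An illustrative helper (not a transformers API):

import threading


def wait_for_autoconversion(timeout: float = 10.0) -> bool:
    # Join the background conversion thread spawned by `from_pretrained`, if any.
    # Returns False if the thread was still alive when the timeout expired.
    for thread in threading.enumerate():
        if thread.name == "Thread-autoconversion":
            thread.join(timeout=timeout)
            return not thread.is_alive()
    return True  # no conversion thread was running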