From 7b291843fa63991e0ce34b88a7235eb902d2374a Mon Sep 17 00:00:00 2001
From: Lysandre Debut
Date: Wed, 27 Mar 2024 08:58:08 +0100
Subject: [PATCH] Reimplement "Automatic safetensors conversion when lacking
 these files" (#29846)

* Automatic safetensors conversion when lacking these files (#29390)

* Automatic safetensors conversion when lacking these files

* Remove debug

* Thread name

* Typo

* Ensure that raises do not affect the main thread

* Catch all errors
---
 src/transformers/modeling_utils.py         | 37 +++++++++++++++--
 src/transformers/safetensors_conversion.py | 46 +++++++++++----------
 tests/test_modeling_utils.py               | 48 +++++++++++++++++++++-
 3 files changed, 106 insertions(+), 25 deletions(-)

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 263ae5d2f988cf..19aab734784a4f 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -29,6 +29,7 @@
 from contextlib import contextmanager
 from dataclasses import dataclass
 from functools import partial, wraps
+from threading import Thread
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from zipfile import is_zipfile
 
@@ -3228,9 +3229,39 @@ def from_pretrained(
                     )
                     if resolved_archive_file is not None:
                         is_sharded = True
-                if resolved_archive_file is None:
-                    # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error
-                    # message.
+
+                if resolved_archive_file is not None:
+                    if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]:
+                        # If the PyTorch file was found, check if there is a safetensors file on the repository.
+                        # If there is no safetensors file on the repository, start an auto conversion.
+                        safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
+                        has_file_kwargs = {
+                            "revision": revision,
+                            "proxies": proxies,
+                            "token": token,
+                        }
+                        cached_file_kwargs = {
+                            "cache_dir": cache_dir,
+                            "force_download": force_download,
+                            "resume_download": resume_download,
+                            "local_files_only": local_files_only,
+                            "user_agent": user_agent,
+                            "subfolder": subfolder,
+                            "_raise_exceptions_for_gated_repo": False,
+                            "_raise_exceptions_for_missing_entries": False,
+                            "_commit_hash": commit_hash,
+                            **has_file_kwargs,
+                        }
+                        if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs):
+                            Thread(
+                                target=auto_conversion,
+                                args=(pretrained_model_name_or_path,),
+                                kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
+                                name="Thread-autoconversion",
+                            ).start()
+                else:
+                    # Otherwise, no PyTorch file was found; maybe there is a TF or Flax model file.
+                    # We try those to give a helpful error message.
                     has_file_kwargs = {
                         "revision": revision,
                         "proxies": proxies,
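Note on the modeling_utils.py hunk above: the conversion only fires when a
PyTorch checkpoint (WEIGHTS_NAME or WEIGHTS_INDEX_NAME) was actually resolved
and the repository has no safetensors variant, and it runs on a named thread
so nothing it raises can propagate into `from_pretrained`. A minimal sketch of
that gate in isolation (not part of the patch; `repo_id` is a placeholder and
`print` stands in for `auto_conversion`):

    from threading import Thread
    from typing import Optional

    from huggingface_hub import HfApi

    def maybe_spawn_conversion(repo_id: str, is_sharded: bool, token: Optional[str] = None) -> None:
        # The filename that from_pretrained would look for on its next attempt.
        safe_weights_name = "model.safetensors.index.json" if is_sharded else "model.safetensors"
        # Only spawn the background conversion when safetensors are missing.
        if not HfApi(token=token).file_exists(repo_id, safe_weights_name):
            Thread(
                target=print,  # stand-in for auto_conversion
                args=(f"would open a safetensors conversion PR for {repo_id}",),
                name="Thread-autoconversion",
            ).start()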
diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py
index 46de0e5755fdf0..5d3af9e8aad13a 100644
--- a/src/transformers/safetensors_conversion.py
+++ b/src/transformers/safetensors_conversion.py
@@ -84,24 +84,28 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs):
     return sha
 
 
-def auto_conversion(pretrained_model_name_or_path: str, **cached_file_kwargs):
-    api = HfApi(token=cached_file_kwargs.get("token"))
-    sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
-
-    if sha is None:
-        return None, None
-    cached_file_kwargs["revision"] = sha
-    del cached_file_kwargs["_commit_hash"]
-
-    # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
-    # description.
-    sharded = api.file_exists(
-        pretrained_model_name_or_path,
-        "model.safetensors.index.json",
-        revision=sha,
-        token=cached_file_kwargs.get("token"),
-    )
-    filename = "model.safetensors.index.json" if sharded else "model.safetensors"
-
-    resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
-    return resolved_archive_file, sha, sharded
+def auto_conversion(pretrained_model_name_or_path: str, ignore_errors_during_conversion=False, **cached_file_kwargs):
+    try:
+        api = HfApi(token=cached_file_kwargs.get("token"))
+        sha = get_conversion_pr_reference(api, pretrained_model_name_or_path, **cached_file_kwargs)
+
+        if sha is None:
+            return None, None
+        cached_file_kwargs["revision"] = sha
+        del cached_file_kwargs["_commit_hash"]
+
+        # This is an additional HEAD call that could be removed if we could infer sharded/non-sharded from the PR
+        # description.
+        sharded = api.file_exists(
+            pretrained_model_name_or_path,
+            "model.safetensors.index.json",
+            revision=sha,
+            token=cached_file_kwargs.get("token"),
+        )
+        filename = "model.safetensors.index.json" if sharded else "model.safetensors"
+
+        resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
+        return resolved_archive_file, sha, sharded
+    except Exception as e:
+        if not ignore_errors_during_conversion:
+            raise e
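Note on `auto_conversion` above: with `ignore_errors_during_conversion=True`
(the only way the new background thread calls it), any exception is swallowed
and the function implicitly returns None. Also, the early `sha is None` exit
still returns a 2-tuple while the success path returns a 3-tuple; the threaded
caller discards the return value, so only synchronous callers need to care.
A hedged sketch of a synchronous call (`user/pytorch-only-model` is a
placeholder repo id, and this does perform real Hub requests):

    from transformers.safetensors_conversion import auto_conversion

    # On success: (resolved_archive_file, sha, sharded).
    # On a swallowed failure: None instead of an exception.
    result = auto_conversion(
        "user/pytorch-only-model",  # placeholder: a repo with only *.bin weights
        ignore_errors_during_conversion=True,
        _commit_hash=None,  # deleted internally before the cached_file() call
    )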
diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py
index 4e68fad8ef7fc9..7f82d0dfcaf632 100755
--- a/tests/test_modeling_utils.py
+++ b/tests/test_modeling_utils.py
@@ -20,6 +20,7 @@
 import os.path
 import sys
 import tempfile
+import threading
 import unittest
 import unittest.mock as mock
 import uuid
@@ -1428,7 +1429,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self):
         bot_opened_pr_title = None
 
         for discussion in discussions:
-            if discussion.author == "SFconvertBot":
+            if discussion.author == "SFconvertbot":
                 bot_opened_pr = True
                 bot_opened_pr_title = discussion.title
 
@@ -1451,6 +1452,51 @@ def test_safetensors_on_the_fly_specific_revision(self):
         with self.assertRaises(EnvironmentError):
             BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch")
 
+    def test_absence_of_safetensors_triggers_conversion(self):
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        initial_model = BertModel(config)
+
+        # Push a model on `main`
+        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
+
+        # Download the model that doesn't have safetensors
+        BertModel.from_pretrained(self.repo_name, token=self.token)
+
+        for thread in threading.enumerate():
+            if thread.name == "Thread-autoconversion":
+                thread.join(timeout=10)
+
+        with self.subTest("PR was open with the safetensors account"):
+            discussions = self.api.get_repo_discussions(self.repo_name)
+
+            bot_opened_pr = None
+            bot_opened_pr_title = None
+
+            for discussion in discussions:
+                if discussion.author == "SFconvertbot":
+                    bot_opened_pr = True
+                    bot_opened_pr_title = discussion.title
+
+            self.assertTrue(bot_opened_pr)
+            self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model")
+
+    @mock.patch("transformers.safetensors_conversion.spawn_conversion")
+    def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock):
+        spawn_conversion_mock.side_effect = HTTPError()
+
+        config = BertConfig(
+            vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37
+        )
+        initial_model = BertModel(config)
+
+        # Push a model on `main`
+        initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False)
+
+        # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread
+        BertModel.from_pretrained(self.repo_name, token=self.token)
+
 
 @require_torch
 @is_staging_test
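Usage note: the first new test joins the conversion thread by name before
asserting, and the same pattern works in user code when you want to be sure
the Hub PR was opened before the process exits. A short sketch, with
`user/pytorch-only-model` again a placeholder:

    import threading

    from transformers import BertModel

    # Loading a PyTorch-only checkpoint now starts the conversion in the
    # background; this call neither blocks on it nor fails because of it.
    model = BertModel.from_pretrained("user/pytorch-only-model")

    # Optionally wait for the named conversion thread, as the test does.
    for thread in threading.enumerate():
        if thread.name == "Thread-autoconversion":
            thread.join(timeout=10)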