From a69cbf4e64c7bc054d814d64f6877180f7cd3a25 Mon Sep 17 00:00:00 2001 From: Lysandre Debut Date: Tue, 5 Mar 2024 13:37:55 +0100 Subject: [PATCH] Automatic safetensors conversion when lacking these files (#29390) * Automatic safetensors conversion when lacking these files * Remove debug * Thread name * Typo * Ensure that raises do not affect the main thread --- src/transformers/modeling_utils.py | 37 +++++++++++++++++++++-- tests/test_modeling_utils.py | 48 +++++++++++++++++++++++++++++- 2 files changed, 81 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7bda8a20165b5e..b542307794168d 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -29,6 +29,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps +from threading import Thread from typing import Any, Callable, Dict, List, Optional, Tuple, Union from zipfile import is_zipfile @@ -3207,9 +3208,39 @@ def from_pretrained( ) if resolved_archive_file is not None: is_sharded = True - if resolved_archive_file is None: - # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error - # message. 
+ + if resolved_archive_file is not None: + if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repository, start an auto conversion + safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "resume_download": resume_download, + "local_files_only": local_files_only, + "user_agent": user_agent, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs=cached_file_kwargs, + name="Thread-autoconversion", + ).start() + else: + # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. + # We try those to give a helpful error message. 
has_file_kwargs = { "revision": revision, "proxies": proxies, diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 0d52e5a87bed35..a334cb0f2853b5 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -20,6 +20,7 @@ import os.path import sys import tempfile +import threading import unittest import unittest.mock as mock import uuid @@ -1428,7 +1429,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self): bot_opened_pr_title = None for discussion in discussions: - if discussion.author == "SFconvertBot": + if discussion.author == "SFconvertbot": bot_opened_pr = True bot_opened_pr_title = discussion.title @@ -1451,6 +1452,51 @@ def test_safetensors_on_the_fly_specific_revision(self): with self.assertRaises(EnvironmentError): BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch") + def test_absence_of_safetensors_triggers_conversion(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + initial_model = BertModel(config) + + # Push a model on `main` + initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) + + # Download the model that doesn't have safetensors + BertModel.from_pretrained(self.repo_name, token=self.token) + + for thread in threading.enumerate(): + if thread.name == "Thread-autoconversion": + thread.join(timeout=10) + + with self.subTest("PR was open with the safetensors account"): + discussions = self.api.get_repo_discussions(self.repo_name) + + bot_opened_pr = None + bot_opened_pr_title = None + + for discussion in discussions: + if discussion.author == "SFconvertbot": + bot_opened_pr = True + bot_opened_pr_title = discussion.title + + self.assertTrue(bot_opened_pr) + self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") + + @mock.patch("transformers.safetensors_conversion.spawn_conversion") + def 
test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock): + spawn_conversion_mock.side_effect = HTTPError() + + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + initial_model = BertModel(config) + + # Push a model on `main` + initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) + + # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread + BertModel.from_pretrained(self.repo_name, token=self.token) + @require_torch @is_staging_test