From 04017c2a46ad85f4606925befffcf500a3004b85 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 1 Mar 2024 11:16:57 +0100 Subject: [PATCH 1/5] Automatic safetensors conversion when lacking these files --- src/transformers/modeling_utils.py | 37 ++++++++++++++++++++-- src/transformers/safetensors_conversion.py | 1 + tests/test_modeling_utils.py | 28 +++++++++++++++- 3 files changed, 62 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 7bda8a20165b..cf9ae0454955 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -29,6 +29,7 @@ from contextlib import contextmanager from dataclasses import dataclass from functools import partial, wraps +from threading import Thread from typing import Any, Callable, Dict, List, Optional, Tuple, Union from zipfile import is_zipfile @@ -3207,9 +3208,39 @@ def from_pretrained( ) if resolved_archive_file is not None: is_sharded = True - if resolved_archive_file is None: - # Otherwise, maybe there is a TF or Flax model file. We try those to give a helpful error - # message. + + if resolved_archive_file is not None: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repositories, start an auto conversion + safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "resume_download": resume_download, + "local_files_only": local_files_only, + "user_agent": user_agent, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + logging.set_verbosity_debug() + cls._auto_conversion = Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs=cached_file_kwargs, + ) + cls._auto_conversion.start() + else: + # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. + # We try those to give a helpful error message. has_file_kwargs = { "revision": revision, "proxies": proxies, diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py index 46de0e5755fd..06efd583484f 100644 --- a/src/transformers/safetensors_conversion.py +++ b/src/transformers/safetensors_conversion.py @@ -80,6 +80,7 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): logger.info("Safetensors PR exists") sha = f"refs/pr/{pr.num}" + logger.info(sha) return sha diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 0d52e5a87bed..5df5ebe77108 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1428,7 +1428,7 @@ def test_safetensors_on_the_fly_wrong_user_opened_pr(self): bot_opened_pr_title = None for discussion in discussions: - if discussion.author == "SFconvertBot": + if discussion.author == "SFconvertbot": bot_opened_pr = True bot_opened_pr_title = discussion.title @@ -1451,6 +1451,32 @@ def test_safetensors_on_the_fly_specific_revision(self): with self.assertRaises(EnvironmentError): BertModel.from_pretrained(self.repo_name, use_safetensors=True, token=self.token, revision="new-branch") + def test_absence_of_safetensors_triggers_conversion(self): + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + initial_model = BertModel(config) + + # Push a model on `main` + initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) + + initial_model = BertModel.from_pretrained(self.repo_name, token=self.token) + BertModel._auto_conversion.join() + + with self.subTest("PR was open with the safetensors account"): + discussions = self.api.get_repo_discussions(self.repo_name) + + bot_opened_pr = None + bot_opened_pr_title = None + + for discussion in discussions: + if discussion.author == "SFconvertbot": + bot_opened_pr = True + bot_opened_pr_title = discussion.title + + self.assertTrue(bot_opened_pr) + self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") + @require_torch @is_staging_test From 38c31c323602d550720537113279efc7d4d3070f Mon Sep 17 00:00:00 2001 From: Lysandre Date: Fri, 1 Mar 2024 11:18:28 +0100 Subject: [PATCH 2/5] Remove debug --- src/transformers/modeling_utils.py | 1 - src/transformers/safetensors_conversion.py | 1 - 2 files changed, 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index cf9ae0454955..9f051a963150 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3231,7 +3231,6 @@ def from_pretrained( **has_file_kwargs, } if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - logging.set_verbosity_debug() cls._auto_conversion = Thread( target=auto_conversion, args=(pretrained_model_name_or_path,), diff --git a/src/transformers/safetensors_conversion.py b/src/transformers/safetensors_conversion.py index 06efd583484f..46de0e5755fd 100644 --- a/src/transformers/safetensors_conversion.py +++ b/src/transformers/safetensors_conversion.py @@ -80,7 +80,6 @@ def get_conversion_pr_reference(api: HfApi, model_id: str, **kwargs): logger.info("Safetensors PR exists") sha = f"refs/pr/{pr.num}" - logger.info(sha) return sha From b6cabae38f15d1c361b0afbd16aaca721e38277f Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Mar 2024 14:04:45 +0100 Subject: [PATCH 3/5] Thread name --- src/transformers/modeling_utils.py | 7 ++++--- tests/test_modeling_utils.py | 5 ++++- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 9f051a963150..0557b0f76faa 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3231,12 +3231,13 @@ def from_pretrained( **has_file_kwargs, } if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - cls._auto_conversion = Thread( + logging.set_verbosity_debug() + Thread( target=auto_conversion, args=(pretrained_model_name_or_path,), kwargs=cached_file_kwargs, - ) - cls._auto_conversion.start() + name="Thread-autoconversion", + ).start() else: # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. # We try those to give a helpful error message. diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 5df5ebe77108..3eac47dd768a 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -20,6 +20,7 @@ import os.path import sys import tempfile +import threading import unittest import unittest.mock as mock import uuid @@ -1461,7 +1462,9 @@ def test_absence_of_safetensors_triggers_conversion(self): initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) initial_model = BertModel.from_pretrained(self.repo_name, token=self.token) - BertModel._auto_conversion.join() + for thread in threading.enumerate(): + if thread.name == "Thread-autoconversion": + thread.join(timeout=10) with self.subTest("PR was open with the safetensors account"): discussions = self.api.get_repo_discussions(self.repo_name) From c7baeae2a1cbec99918fe4cff5c50d3c764736f1 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Mar 2024 14:21:18 +0100 Subject: [PATCH 4/5] Typo --- src/transformers/modeling_utils.py | 57 +++++++++++++++--------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 0557b0f76faa..4ddebc044e71 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3210,34 +3210,35 @@ def from_pretrained( is_sharded = True if resolved_archive_file is not None: - # If the PyTorch file was found, check if there is a safetensors file on the repository - # If there is no safetensors file on the repositories, start an auto conversion - safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME - has_file_kwargs = { - "revision": revision, - "proxies": proxies, - "token": token, - } - cached_file_kwargs = { - "cache_dir": cache_dir, - "force_download": force_download, - "resume_download": resume_download, - "local_files_only": local_files_only, - "user_agent": user_agent, - "subfolder": subfolder, - "_raise_exceptions_for_gated_repo": False, - "_raise_exceptions_for_missing_entries": False, - "_commit_hash": commit_hash, - **has_file_kwargs, - } - if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - logging.set_verbosity_debug() - Thread( - target=auto_conversion, - args=(pretrained_model_name_or_path,), - kwargs=cached_file_kwargs, - name="Thread-autoconversion", - ).start() + if filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]: + # If the PyTorch file was found, check if there is a safetensors file on the repository + # If there is no safetensors file on the repositories, start an auto conversion + safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME + has_file_kwargs = { + "revision": revision, + "proxies": proxies, + "token": token, + } + cached_file_kwargs = { + "cache_dir": cache_dir, + "force_download": force_download, + "resume_download": resume_download, + "local_files_only": local_files_only, + "user_agent": user_agent, + "subfolder": subfolder, + "_raise_exceptions_for_gated_repo": False, + "_raise_exceptions_for_missing_entries": False, + "_commit_hash": commit_hash, + **has_file_kwargs, + } + if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): + logging.set_verbosity_debug() + Thread( + target=auto_conversion, + args=(pretrained_model_name_or_path,), + kwargs=cached_file_kwargs, + name="Thread-autoconversion", + ).start() else: # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file. # We try those to give a helpful error message. From a3d54cdf472f62bec0cc3fe4f0665036c1ce00a8 Mon Sep 17 00:00:00 2001 From: Lysandre Date: Mon, 4 Mar 2024 16:26:15 +0100 Subject: [PATCH 5/5] Ensure that raises do not affect the main thread --- src/transformers/modeling_utils.py | 1 - tests/test_modeling_utils.py | 19 ++++++++++++++++++- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py index 4ddebc044e71..b54230779416 100644 --- a/src/transformers/modeling_utils.py +++ b/src/transformers/modeling_utils.py @@ -3232,7 +3232,6 @@ def from_pretrained( **has_file_kwargs, } if not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs): - logging.set_verbosity_debug() Thread( target=auto_conversion, args=(pretrained_model_name_or_path,), diff --git a/tests/test_modeling_utils.py b/tests/test_modeling_utils.py index 3eac47dd768a..a334cb0f2853 100755 --- a/tests/test_modeling_utils.py +++ b/tests/test_modeling_utils.py @@ -1461,7 +1461,9 @@ def test_absence_of_safetensors_triggers_conversion(self): # Push a model on `main` initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) - initial_model = BertModel.from_pretrained(self.repo_name, token=self.token) + # Download the model that doesn't have safetensors + BertModel.from_pretrained(self.repo_name, token=self.token) + for thread in threading.enumerate(): if thread.name == "Thread-autoconversion": thread.join(timeout=10) @@ -1480,6 +1482,21 @@ def test_absence_of_safetensors_triggers_conversion(self): self.assertTrue(bot_opened_pr) self.assertEqual(bot_opened_pr_title, "Adding `safetensors` variant of this model") + @mock.patch("transformers.safetensors_conversion.spawn_conversion") + def test_absence_of_safetensors_triggers_conversion_failed(self, spawn_conversion_mock): + spawn_conversion_mock.side_effect = HTTPError() + + config = BertConfig( + vocab_size=99, hidden_size=32, num_hidden_layers=5, num_attention_heads=4, intermediate_size=37 + ) + initial_model = BertModel(config) + + # Push a model on `main` + initial_model.push_to_hub(self.repo_name, token=self.token, safe_serialization=False) + + # The auto conversion is mocked to always raise; ensure that it doesn't raise in the main thread + BertModel.from_pretrained(self.repo_name, token=self.token) + @require_torch @is_staging_test