diff --git a/src/transformers/utils/hub.py b/src/transformers/utils/hub.py index 7103ba6b5035..78c2f6bbb68c 100644 --- a/src/transformers/utils/hub.py +++ b/src/transformers/utils/hub.py @@ -413,6 +413,75 @@ def cached_files( if subfolder is None: subfolder = "" + if local_files_only or is_offline_mode(): + cache_dirs_to_try = [] + + if cache_dir is not None: + cache_dirs_to_try.append(cache_dir) + + for env_var in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE"]: + env_cache = os.environ.get(env_var) + if env_cache and env_cache not in cache_dirs_to_try: + cache_dirs_to_try.append(env_cache) + + default_cache = default_cache_path + if default_cache not in cache_dirs_to_try: + cache_dirs_to_try.append(default_cache) + + for potential_cache_dir in cache_dirs_to_try: + if not os.path.exists(potential_cache_dir): + continue + + # Construct the cache path following HF Hub structure + repo_id_sanitized = path_or_repo_id.replace("/", "--") + model_cache_dir = os.path.join(potential_cache_dir, f"models--{repo_id_sanitized}") + + if not os.path.exists(model_cache_dir): + continue + + # Try to find the file in snapshots + refs_dir = os.path.join(model_cache_dir, "refs") + snapshots_dir = os.path.join(model_cache_dir, "snapshots") + + if os.path.exists(refs_dir) and os.path.exists(snapshots_dir): + # Try to get commit hash from refs + ref_file = os.path.join(refs_dir, revision or "main") + if os.path.exists(ref_file): + with open(ref_file, "r", encoding="utf-8") as f: + commit_hash = f.read().strip() + + # Check if file exists in this snapshot + found_files = [] + for fname in filenames: + if subfolder: + file_path = os.path.join(snapshots_dir, commit_hash, subfolder, fname) + else: + file_path = os.path.join(snapshots_dir, commit_hash, fname) + + if os.path.exists(file_path): + found_files.append(file_path) + + # If we found all files, return them + if len(found_files) == len(filenames): + logger.info(f"Found all cached files in {snapshots_dir}/{commit_hash}") + return 
found_files
+
+            # If ref doesn't exist, try any snapshot that contains all of the requested files
+            if os.path.exists(snapshots_dir):
+                for commit_dir in os.listdir(snapshots_dir):
+                    commit_path = os.path.join(snapshots_dir, commit_dir)
+                    if not os.path.isdir(commit_path):
+                        continue
+                    found_files = []
+                    for fname in filenames:
+                        base = os.path.join(commit_path, subfolder) if subfolder else commit_path
+                        file_path = os.path.join(base, fname)
+                        if os.path.exists(file_path):
+                            found_files.append(file_path)
+                    if len(found_files) == len(filenames):
+                        logger.info(f"Found all cached files in {commit_path}")
+                        return found_files
+
     # Add folder to filenames
     full_filenames = [os.path.join(subfolder, file) for file in filenames]
 
@@ -514,8 +583,6 @@ def cached_files(
                     "Check cache directory permissions. Common causes: 1) another user is downloading the same model (please wait); "
                     "2) a previous download was canceled and the lock file needs manual removal."
                 ) from e
-        elif isinstance(e, ValueError):
-            raise OSError(f"{e}") from e
 
     # Now we try to recover if we can find all files correctly in the cache
     resolved_files = [
@@ -884,6 +951,10 @@ def push_to_hub(
         ```
         """
         ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
+        save_jinja_files = deprecated_kwargs.pop(
+            "save_jinja_files", None
+        )  # TODO: This is only used for testing and should be removed once save_jinja_files becomes the default
+
         repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
         if repo_path_or_name is not None:
             # Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
@@ -929,11 +1000,15 @@
         files_timestamps = self._get_files_timestamps(work_dir)
 
         # Save all files.
- self.save_pretrained( - work_dir, - max_shard_size=max_shard_size, - safe_serialization=safe_serialization, - ) + if save_jinja_files: + self.save_pretrained( + work_dir, + max_shard_size=max_shard_size, + safe_serialization=safe_serialization, + save_jinja_files=True, + ) + else: + self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization) # Update model card if needed: model_card.save(os.path.join(work_dir, "README.md")) diff --git a/tests/utils/test.py b/tests/utils/test.py new file mode 100644 index 000000000000..ecb959899b94 --- /dev/null +++ b/tests/utils/test.py @@ -0,0 +1,178 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for offline mode functionality. +Regression tests for issue #41311: https://github.com/huggingface/transformers/issues/41311 +""" + +import os +import subprocess +import sys +import tempfile +import unittest + + +class TestOfflineMode(unittest.TestCase): + """ + Test that models can be loaded offline after cache is warmed in a subprocess. + These are regression tests for issue #41311. + """ + + def test_subprocess_warm_cache_then_offline_load(self): + """ + Test that warming cache in subprocess allows offline loading in parent process. 
+ Regression test for: https://github.com/huggingface/transformers/issues/41311 + """ + model_name = "hf-internal-testing/tiny-random-bert" + + with tempfile.TemporaryDirectory() as cache_dir: + env = os.environ.copy() + env["HF_HOME"] = cache_dir + + # Step 1: Download model in subprocess + warm_script = f""" +import os +os.environ["HF_HOME"] = "{cache_dir}" + +from transformers import AutoConfig, AutoModel, AutoTokenizer + +config = AutoConfig.from_pretrained("{model_name}") +model = AutoModel.from_pretrained("{model_name}") +tokenizer = AutoTokenizer.from_pretrained("{model_name}") +print("CACHE_WARMED") +""" + + result = subprocess.run( + [sys.executable, "-c", warm_script], + capture_output=True, + text=True, + env=env, + timeout=120, + ) + + self.assertEqual(result.returncode, 0, f"Cache warming failed: {result.stderr}") + self.assertIn("CACHE_WARMED", result.stdout) + + # Step 2: Load offline with socket blocking (after imports) + offline_script = f""" +import os +os.environ["HF_HOME"] = "{cache_dir}" +os.environ["HF_HUB_OFFLINE"] = "1" + +# Import transformers first +from transformers import AutoConfig, AutoModel, AutoTokenizer + +# Then block sockets to ensure no network access +import socket +original_socket = socket.socket +def guarded_socket(*args, **kwargs): + raise RuntimeError("Network access attempted in offline mode!") +socket.socket = guarded_socket + +try: + config = AutoConfig.from_pretrained("{model_name}") + model = AutoModel.from_pretrained("{model_name}") + tokenizer = AutoTokenizer.from_pretrained("{model_name}") + print("OFFLINE_SUCCESS") +except RuntimeError as e: + if "Network access" in str(e): + print(f"NETWORK_ATTEMPTED: {{e}}") + exit(1) + raise +except Exception as e: + print(f"FAILED: {{e}}") + import traceback + traceback.print_exc() + exit(1) +""" + + result = subprocess.run( + [sys.executable, "-c", offline_script], + capture_output=True, + text=True, + env=env, + timeout=120, + ) + + if "NETWORK_ATTEMPTED" in result.stdout: + 
self.fail(f"Network access attempted despite warm cache: {result.stdout}") + + self.assertIn( + "OFFLINE_SUCCESS", + result.stdout, + f"Failed to load offline:\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}", + ) + self.assertEqual(result.returncode, 0) + + def test_pipeline_offline_after_subprocess_warm(self): + """ + Test pipeline API works offline after subprocess cache warming. + """ + model_name = "hf-internal-testing/tiny-random-bert" + + with tempfile.TemporaryDirectory() as cache_dir: + env = os.environ.copy() + env["HF_HOME"] = cache_dir + + # Warm cache + warm_script = f""" +import os +os.environ["HF_HOME"] = "{cache_dir}" + +from transformers import pipeline + +pipe = pipeline("text-classification", model="{model_name}") +print("WARMED") +""" + + result = subprocess.run( + [sys.executable, "-c", warm_script], capture_output=True, text=True, env=env, timeout=120 + ) + self.assertEqual(result.returncode, 0) + + # Load offline + offline_script = f""" +import os +os.environ["HF_HOME"] = "{cache_dir}" +os.environ["HF_HUB_OFFLINE"] = "1" + +from transformers import pipeline +import socket + +# Block sockets after imports +def no_socket(*args, **kwargs): + raise RuntimeError("Network blocked!") +socket.socket = no_socket + +try: + pipe = pipeline("text-classification", model="{model_name}") + print("SUCCESS") +except RuntimeError as e: + if "Network blocked" in str(e): + print(f"BLOCKED: {{e}}") + exit(1) + raise +except Exception as e: + print(f"ERROR: {{e}}") + exit(1) +""" + + result = subprocess.run( + [sys.executable, "-c", offline_script], capture_output=True, text=True, env=env, timeout=120 + ) + + self.assertNotIn("BLOCKED", result.stdout, "Network access attempted") + self.assertIn("SUCCESS", result.stdout) + self.assertEqual(result.returncode, 0)