Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 82 additions & 7 deletions src/transformers/utils/hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,75 @@ def cached_files(
if subfolder is None:
subfolder = ""

    # Offline fast path: when the caller forced local-only loading or the library is
    # in offline mode, try to resolve the requested files straight from local HF
    # cache layouts (models--<org>--<name>/refs + snapshots) without any network call.
    if local_files_only or is_offline_mode():
        # Candidate cache roots, in priority order: explicit `cache_dir` argument,
        # then env-configured caches, then the library default.
        cache_dirs_to_try = []

        if cache_dir is not None:
            cache_dirs_to_try.append(cache_dir)

        # NOTE(review): HF_HOME is the root HF directory; the hub cache usually lives
        # under "<HF_HOME>/hub", so "models--*" dirs are unlikely to sit directly in
        # HF_HOME — confirm this entry actually matches the on-disk layout.
        for env_var in ["HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE"]:
            env_cache = os.environ.get(env_var)
            if env_cache and env_cache not in cache_dirs_to_try:
                cache_dirs_to_try.append(env_cache)

        default_cache = default_cache_path
        if default_cache not in cache_dirs_to_try:
            cache_dirs_to_try.append(default_cache)

        for potential_cache_dir in cache_dirs_to_try:
            if not os.path.exists(potential_cache_dir):
                continue

            # Construct the cache path following HF Hub structure
            # ("/" in a repo id is flattened to "--" and prefixed with "models--").
            repo_id_sanitized = path_or_repo_id.replace("/", "--")
            model_cache_dir = os.path.join(potential_cache_dir, f"models--{repo_id_sanitized}")

            if not os.path.exists(model_cache_dir):
                continue

            # Try to find the file in snapshots
            refs_dir = os.path.join(model_cache_dir, "refs")
            snapshots_dir = os.path.join(model_cache_dir, "snapshots")

            if os.path.exists(refs_dir) and os.path.exists(snapshots_dir):
                # Try to get commit hash from refs (the ref file holds the commit
                # hash that `revision` currently points at; defaults to "main").
                ref_file = os.path.join(refs_dir, revision or "main")
                if os.path.exists(ref_file):
                    with open(ref_file, "r", encoding="utf-8") as f:
                        commit_hash = f.read().strip()

                    # Check if file exists in this snapshot
                    found_files = []
                    for fname in filenames:
                        if subfolder:
                            file_path = os.path.join(snapshots_dir, commit_hash, subfolder, fname)
                        else:
                            file_path = os.path.join(snapshots_dir, commit_hash, fname)

                        if os.path.exists(file_path):
                            found_files.append(file_path)

                    # If we found all files, return them
                    if len(found_files) == len(filenames):
                        logger.info(f"Found all cached files in {snapshots_dir}/{commit_hash}")
                        return found_files

                # If ref doesn't exist, try to find any snapshot with the file
                # NOTE(review): this fallback reuses `fname` left over from the loop
                # above — if the ref file was missing that loop never ran and `fname`
                # is undefined (NameError); even when defined, only the *last*
                # requested filename is checked and the rest of `filenames` ignored.
                if os.path.exists(snapshots_dir):
                    for commit_dir in os.listdir(snapshots_dir):
                        commit_path = os.path.join(snapshots_dir, commit_dir)
                        if not os.path.isdir(commit_path):
                            continue

                        if subfolder:
                            file_path = os.path.join(commit_path, subfolder, fname)
                        else:
                            file_path = os.path.join(commit_path, fname)

                        if os.path.exists(file_path):
                            logger.info(f"Found cached file at {file_path}")
                            # NOTE(review): returns a single path string here, while
                            # the branch above returns a list — confirm callers
                            # accept both shapes.
                            return file_path

    # Add folder to filenames
    full_filenames = [os.path.join(subfolder, file) for file in filenames]

Expand Down Expand Up @@ -514,8 +583,6 @@ def cached_files(
"Check cache directory permissions. Common causes: 1) another user is downloading the same model (please wait); "
"2) a previous download was canceled and the lock file needs manual removal."
) from e
elif isinstance(e, ValueError):
raise OSError(f"{e}") from e

# Now we try to recover if we can find all files correctly in the cache
resolved_files = [
Expand Down Expand Up @@ -884,6 +951,10 @@ def push_to_hub(
```
"""
ignore_metadata_errors = deprecated_kwargs.pop("ignore_metadata_errors", False)
save_jinja_files = deprecated_kwargs.pop(
"save_jinja_files", None
) # TODO: This is only used for testing and should be removed once save_jinja_files becomes the default

repo_path_or_name = deprecated_kwargs.pop("repo_path_or_name", None)
if repo_path_or_name is not None:
# Should use `repo_id` instead of `repo_path_or_name`. When using `repo_path_or_name`, we try to infer
Expand Down Expand Up @@ -929,11 +1000,15 @@ def push_to_hub(
files_timestamps = self._get_files_timestamps(work_dir)

# Save all files.
self.save_pretrained(
work_dir,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
)
if save_jinja_files:
self.save_pretrained(
work_dir,
max_shard_size=max_shard_size,
safe_serialization=safe_serialization,
save_jinja_files=True,
)
else:
self.save_pretrained(work_dir, max_shard_size=max_shard_size, safe_serialization=safe_serialization)

# Update model card if needed:
model_card.save(os.path.join(work_dir, "README.md"))
Expand Down
178 changes: 178 additions & 0 deletions tests/utils/test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Tests for offline mode functionality.
Regression tests for issue #41311: https://github.com/huggingface/transformers/issues/41311
"""

import os
import subprocess
import sys
import tempfile
import unittest


class TestOfflineMode(unittest.TestCase):
    """
    Test that models can be loaded offline after the cache has been warmed by a
    separate process.

    Regression tests for issue #41311
    (https://github.com/huggingface/transformers/issues/41311): files downloaded
    by one process must be resolvable by another process running with
    ``HF_HUB_OFFLINE=1`` and no network access at all.
    """

    def _run_script(self, script, env):
        """Run ``script`` with the current interpreter and return the CompletedProcess."""
        return subprocess.run(
            [sys.executable, "-c", script],
            capture_output=True,
            text=True,
            env=env,
            timeout=120,
        )

    def test_subprocess_warm_cache_then_offline_load(self):
        """
        Warm the cache in a child process, then load the same model from a second
        child process that has ``HF_HUB_OFFLINE=1`` set and socket creation blocked.

        Regression test for: https://github.com/huggingface/transformers/issues/41311
        """
        model_name = "hf-internal-testing/tiny-random-bert"

        with tempfile.TemporaryDirectory() as cache_dir:
            env = os.environ.copy()
            env["HF_HOME"] = cache_dir

            # Step 1: download the model in a subprocess.
            # Paths are interpolated with !r so the generated source stays a valid
            # string literal even when the temp dir contains backslashes (Windows)
            # or quote characters.
            warm_script = f"""
import os
os.environ["HF_HOME"] = {cache_dir!r}

from transformers import AutoConfig, AutoModel, AutoTokenizer

config = AutoConfig.from_pretrained({model_name!r})
model = AutoModel.from_pretrained({model_name!r})
tokenizer = AutoTokenizer.from_pretrained({model_name!r})
print("CACHE_WARMED")
"""

            result = self._run_script(warm_script, env)

            self.assertEqual(result.returncode, 0, f"Cache warming failed: {result.stderr}")
            self.assertIn("CACHE_WARMED", result.stdout)

            # Step 2: load offline. Sockets are blocked only *after* the
            # transformers import so the guard catches runtime downloads rather
            # than import-time machinery.
            offline_script = f"""
import os
os.environ["HF_HOME"] = {cache_dir!r}
os.environ["HF_HUB_OFFLINE"] = "1"

# Import transformers first
from transformers import AutoConfig, AutoModel, AutoTokenizer

# Then block sockets to ensure no network access
import socket

def guarded_socket(*args, **kwargs):
    raise RuntimeError("Network access attempted in offline mode!")

socket.socket = guarded_socket

try:
    config = AutoConfig.from_pretrained({model_name!r})
    model = AutoModel.from_pretrained({model_name!r})
    tokenizer = AutoTokenizer.from_pretrained({model_name!r})
    print("OFFLINE_SUCCESS")
except RuntimeError as e:
    if "Network access" in str(e):
        print(f"NETWORK_ATTEMPTED: {{e}}")
        raise SystemExit(1)
    raise
except Exception as e:
    print(f"FAILED: {{e}}")
    import traceback
    traceback.print_exc()
    raise SystemExit(1)
"""

            result = self._run_script(offline_script, env)

            if "NETWORK_ATTEMPTED" in result.stdout:
                self.fail(f"Network access attempted despite warm cache: {result.stdout}")

            self.assertIn(
                "OFFLINE_SUCCESS",
                result.stdout,
                f"Failed to load offline:\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}",
            )
            self.assertEqual(result.returncode, 0)

    def test_pipeline_offline_after_subprocess_warm(self):
        """
        Test that the pipeline API works offline after subprocess cache warming.
        """
        model_name = "hf-internal-testing/tiny-random-bert"

        with tempfile.TemporaryDirectory() as cache_dir:
            env = os.environ.copy()
            env["HF_HOME"] = cache_dir

            # Warm cache (paths interpolated with !r — see sibling test).
            warm_script = f"""
import os
os.environ["HF_HOME"] = {cache_dir!r}

from transformers import pipeline

pipe = pipeline("text-classification", model={model_name!r})
print("WARMED")
"""

            result = self._run_script(warm_script, env)
            self.assertEqual(result.returncode, 0, f"Cache warming failed: {result.stderr}")

            # Load offline with socket creation blocked after imports.
            offline_script = f"""
import os
os.environ["HF_HOME"] = {cache_dir!r}
os.environ["HF_HUB_OFFLINE"] = "1"

from transformers import pipeline
import socket

# Block sockets after imports
def no_socket(*args, **kwargs):
    raise RuntimeError("Network blocked!")

socket.socket = no_socket

try:
    pipe = pipeline("text-classification", model={model_name!r})
    print("SUCCESS")
except RuntimeError as e:
    if "Network blocked" in str(e):
        print(f"BLOCKED: {{e}}")
        raise SystemExit(1)
    raise
except Exception as e:
    print(f"ERROR: {{e}}")
    raise SystemExit(1)
"""

            result = self._run_script(offline_script, env)

            self.assertNotIn("BLOCKED", result.stdout, "Network access attempted")
            self.assertIn(
                "SUCCESS",
                result.stdout,
                f"Failed to run pipeline offline:\nSTDOUT: {result.stdout}\nSTDERR: {result.stderr}",
            )
            self.assertEqual(result.returncode, 0)