diff --git a/kernels/src/kernels/_versions.py b/kernels/src/kernels/_versions.py index be825ff0..b5b4fae8 100644 --- a/kernels/src/kernels/_versions.py +++ b/kernels/src/kernels/_versions.py @@ -1,5 +1,9 @@ import logging +import os +from pathlib import Path +from huggingface_hub import constants +from huggingface_hub.file_download import repo_folder_name from huggingface_hub.hf_api import GitRefInfo logger = logging.getLogger(__name__) @@ -9,6 +13,9 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]: """Get kernel versions that are available in the repository.""" from kernels.utils import _get_hf_api + if constants.HF_HUB_OFFLINE: + return _get_available_versions_from_cache(repo_id) + refs = _get_hf_api().list_repo_refs(repo_id=repo_id, repo_type="kernel") versions = {} @@ -23,6 +30,36 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]: return versions +def _get_available_versions_from_cache(repo_id: str) -> dict[int, GitRefInfo]: + """Get kernel versions from the local Hugging Face cache.""" + cache_dir = os.environ.get("KERNELS_CACHE") or constants.HF_HUB_CACHE + + versions: dict[int, GitRefInfo] = {} + # Tolerate both layouts: the "kernel" repo type used by newer + # huggingface_hub, and the legacy "model" prefix that older caches use. + for repo_type in ("kernel", "model"): + refs_dir = Path(cache_dir) / repo_folder_name(repo_id=repo_id, repo_type=repo_type) / "refs" + if not refs_dir.is_dir(): + continue + for ref_path in refs_dir.iterdir(): + if not ref_path.is_file(): + continue + ref_name = ref_path.name + if not ref_name.startswith("v"): + continue + try: + version = int(ref_name[1:]) + except ValueError: + continue + try: + commit = ref_path.read_text().strip() + except OSError: + continue + versions[version] = GitRefInfo(name=ref_name, ref=ref_name, target_commit=commit) + + return versions + + def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo: """ Get the ref for a kernel with the given version. @@ -31,6 +68,12 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo: ref = versions.get(version_spec, None) if ref is None: + if constants.HF_HUB_OFFLINE and not versions: + raise ValueError( + f"Version {version_spec} of '{repo_id}' is not available in the local cache " + "and Hugging Face Hub is in offline mode. Download the kernel " + "while online first, or pass an explicit `revision=`." + ) raise ValueError( f"Version {version_spec} not found, available versions: {', '.join(str(v) for v in sorted(versions.keys()))}" ) diff --git a/kernels/src/kernels/utils.py b/kernels/src/kernels/utils.py index 5a046014..8b5e4670 100644 --- a/kernels/src/kernels/utils.py +++ b/kernels/src/kernels/utils.py @@ -7,12 +7,14 @@ import os import platform import sys +import warnings from dataclasses import dataclass from importlib.metadata import Distribution from pathlib import Path from types import ModuleType from huggingface_hub import HfApi, constants +from huggingface_hub.errors import LocalEntryNotFoundError from kernels_data import Metadata from kernels._system import glibc_version @@ -52,8 +54,6 @@ def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str]) return if isinstance(trust_remote_code, list): - import warnings - warnings.warn( "Signing identity verification is not yet implemented. " "The provided signing identities will be ignored and the " @@ -62,6 +62,16 @@ def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str]) stacklevel=3, ) + if constants.HF_HUB_OFFLINE: + # Publisher trust cannot be verified offline. The user opted into + # offline mode and the kernel must already be in the local cache, + # so trust was established when it was originally downloaded. + warnings.warn( + f"Skipping publisher trust check for '{repo_id}' because Hugging Face Hub is in offline mode.", + stacklevel=3, + ) + return + publisher = repo_id.split("/", 1)[0] try: @@ -244,10 +254,19 @@ def install_kernel( `Path`: The path to the variant directory. """ api = _get_hf_api(user_agent=user_agent) + if local_files_only or constants.HF_HUB_OFFLINE: + # Same local-cache resolution path used by `load_kernel`, which is + # always offline. Sharing the helper avoids the network dependency + # that `get_variants` would otherwise introduce. + return _resolve_local_variant_path( + api, + repo_id, + revision=revision, + backend=backend, + variant_locks=variant_locks, + ) - if not local_files_only: - repo_id, revision = resolve_status(api, repo_id, revision) - + repo_id, revision = resolve_status(api, repo_id, revision) variants = get_variants(api, repo_id=repo_id, revision=revision) variant, trace = resolve_variant(variants, backend) @@ -266,7 +285,7 @@ def install_kernel( allow_patterns=allow_patterns, cache_dir=CACHE_DIR, revision=revision, - local_files_only=local_files_only, + local_files_only=False, ) ) ) @@ -281,6 +300,61 @@ def install_kernel( raise FileNotFoundError(f"Cannot install kernel from repo {repo_id} (revision: {revision})") +def _resolve_local_variant_path( + api: HfApi, + repo_id: str, + *, + revision: str, + backend: str | None = None, + variant_locks: dict[str, VariantLock] | None = None, +) -> Path: + """Resolve a kernel variant path from the local Hugging Face cache only. + + Used by `load_kernel` (which always operates on a pre-downloaded, locked + kernel) and by the offline branch of `install_kernel`. + """ + try: + local_repo_path = Path( + str( + api.snapshot_download( + repo_id, + repo_type="kernel", + cache_dir=CACHE_DIR, + revision=revision, + local_files_only=True, + ) + ) + ) + except LocalEntryNotFoundError as e: + raise FileNotFoundError( + f"Cannot find a local snapshot for {repo_id} (revision: {revision}). " + "When Hugging Face Hub is in offline mode the kernel must already " + "be present in the local cache." + ) from e + + variants = get_variants_local(local_repo_path / "build") + variant, status = resolve_variant(variants, backend) + if variant is None: + raise FileNotFoundError( + f"Cannot find a build variant for this system in {repo_id} (revision: {revision}):\n\n{variants_trace_str(status)}" + ) + + allow_patterns = [f"build/{variant.variant_str}/*"] + repo_path = Path( + str( + api.snapshot_download( + repo_id, + repo_type="kernel", + allow_patterns=allow_patterns, + cache_dir=CACHE_DIR, + revision=revision, + local_files_only=True, + ) + ) + ) + return _find_kernel_in_repo_path(repo_path, variant=variant, variant_locks=variant_locks) + + def _find_kernel_in_repo_path( repo_path: Path, *, @@ -479,10 +553,7 @@ def has_kernel( def load_kernel( - repo_id: str, - *, - lockfile: Path | None, - backend: str | None = None, + repo_id: str, *, lockfile: Path | None, backend: str | None = None, revision: str | None = None ) -> ModuleType: """ Get a pre-downloaded, locked kernel. @@ -497,13 +568,20 @@ def load_kernel( backend (`str`, *optional*): The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for. The backend will be detected automatically if not provided. + revision (`str`, *optional*): + The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. Returns: `ModuleType`: The imported kernel module. """ - if lockfile is None: + if lockfile is not None and revision is not None: + raise ValueError("`lockfile` and `revision` both cannot be specified at the same time.") + + if lockfile is None and revision is None: locked_sha = _get_caller_locked_kernel(repo_id) - else: + elif revision is not None: + locked_sha = revision + elif lockfile is not None: with open(lockfile, "r") as f: locked_sha = _get_locked_kernel(repo_id, f.read()) @@ -513,39 +591,21 @@ def load_kernel( ) api = _get_hf_api() - variants = get_variants(api, repo_id=repo_id, revision=locked_sha) - variant, status = resolve_variant(variants, backend) - - if variant is None: - raise FileNotFoundError( - f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}):\n\n{variants_trace_str(status)}" - ) - - allow_patterns = [f"build/{variant.variant_str}/*"] - repo_path = Path( - str( - api.snapshot_download( - repo_id, - repo_type="kernel", - allow_patterns=allow_patterns, - cache_dir=CACHE_DIR, - revision=locked_sha, - local_files_only=True, - ) - ) - ) try: - variant_path = _find_kernel_in_repo_path( - repo_path, - variant=variant, - variant_locks=None, + variant_path = _resolve_local_variant_path( + api, + repo_id, + revision=locked_sha, + backend=backend, ) - return _import_from_path(variant_path) - except FileNotFoundError: + except FileNotFoundError as e: raise FileNotFoundError( - f"Locked kernel `{repo_id}` does not have applicable variant or was not downloaded with `kernels download `" - ) + f"Locked kernel `{repo_id}` was not downloaded or does not have an " + "applicable variant. Make sure it's downloaded locally via " + "`kernels download `." + ) from e + return _import_from_path(variant_path) def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType: diff --git a/kernels/tests/test_basic.py b/kernels/tests/test_basic.py index 4ca1cf64..1460f945 100644 --- a/kernels/tests/test_basic.py +++ b/kernels/tests/test_basic.py @@ -3,9 +3,11 @@ import pytest import torch import torch.nn.functional as F +from huggingface_hub import constants from huggingface_hub.errors import HfHubHTTPError from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel +from kernels._versions import resolve_version_spec_as_ref, select_revision_or_version @pytest.fixture @@ -243,6 +245,67 @@ def test_trust_remote_code_flag_allows_untrusted(): get_kernel("kernels-test-untrusted/ci-test-kernel", version=1, trust_remote_code=True) +def test_install_kernel_offline_with_revision(monkeypatch, local_kernel_path): + """install_kernel should resolve a cached snapshot when HF_HUB_OFFLINE=1.""" + expected_path = local_kernel_path + monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True) + + path = install_kernel("kernels-community/relu", revision="v1") + assert path == expected_path + + +def test_install_kernel_offline_avoids_network(monkeypatch, local_kernel_path): + """When HF_HUB_OFFLINE=1, install_kernel must not make any Hub requests.""" + expected_path = local_kernel_path + + class _NoNetwork(RuntimeError): + pass + + def _fail(*_args, **_kwargs): + raise _NoNetwork("Hub access attempted in offline test") + + monkeypatch.setattr("huggingface_hub.hf_api.get_session", _fail) + + # Online path must touch the Hub via get_session and therefore fail. + with pytest.raises(_NoNetwork): + install_kernel("kernels-community/relu", revision="v1") + + # Offline mode resolves entirely from the local cache, so get_session is + # never called. + monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True) + path = install_kernel("kernels-community/relu", revision="v1") + assert path == expected_path + + +def test_install_kernel_offline_with_version(monkeypatch, local_kernel_path): + """get_kernel(version=) should resolve via local refs when HF_HUB_OFFLINE=1.""" + expected_path = local_kernel_path + monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True) + + commit = select_revision_or_version("kernels-community/relu", revision=None, version=1) + path = install_kernel("kernels-community/relu", revision=commit) + assert path == expected_path + + +def test_install_kernel_offline_uncached_revision(monkeypatch): + """install_kernel should fail with a helpful error when offline and uncached.""" + monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True) + + with pytest.raises(FileNotFoundError, match=r"local snapshot"): + install_kernel( + "kernels-test/this-repo-should-not-exist", + revision="0000000000000000000000000000000000000000", + ) + + +def test_version_resolution_offline_missing(monkeypatch): + """resolve_version_spec_as_ref should raise a clear error when offline and no cache.""" + monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True) + + with pytest.raises(ValueError, match=r"offline mode"): + resolve_version_spec_as_ref("kernels-test/this-repo-should-not-exist", 1) + + def silu_and_mul_torch(x: torch.Tensor): d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:]