Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions kernels/src/kernels/_versions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
import logging
import os
from pathlib import Path

from huggingface_hub import constants
from huggingface_hub.file_download import repo_folder_name
from huggingface_hub.hf_api import GitRefInfo

logger = logging.getLogger(__name__)
Expand All @@ -9,6 +13,9 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
"""Get kernel versions that are available in the repository."""
from kernels.utils import _get_hf_api

if constants.HF_HUB_OFFLINE:
return _get_available_versions_from_cache(repo_id)

refs = _get_hf_api().list_repo_refs(repo_id=repo_id, repo_type="kernel")

versions = {}
Expand All @@ -23,6 +30,36 @@ def _get_available_versions(repo_id: str) -> dict[int, GitRefInfo]:
return versions


def _get_available_versions_from_cache(repo_id: str) -> dict[int, GitRefInfo]:
"""Get kernel versions from the local Hugging Face cache."""
cache_dir = os.environ.get("KERNELS_CACHE") or constants.HF_HUB_CACHE

versions: dict[int, GitRefInfo] = {}
# Tolerate both layouts: the "kernel" repo type used by newer
# huggingface_hub, and the legacy "model" prefix that older caches use.
for repo_type in ("kernel", "model"):
refs_dir = Path(cache_dir) / repo_folder_name(repo_id=repo_id, repo_type=repo_type) / "refs"
if not refs_dir.is_dir():
continue
for ref_path in refs_dir.iterdir():
if not ref_path.is_file():
continue
ref_name = ref_path.name
if not ref_name.startswith("v"):
continue
try:
version = int(ref_name[1:])
except ValueError:
continue
try:
commit = ref_path.read_text().strip()
except OSError:
continue
versions[version] = GitRefInfo(name=ref_name, ref=ref_name, target_commit=commit)

return versions


def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo:
"""
Get the ref for a kernel with the given version.
Expand All @@ -31,6 +68,12 @@ def resolve_version_spec_as_ref(repo_id: str, version_spec: int) -> GitRefInfo:

ref = versions.get(version_spec, None)
if ref is None:
if constants.HF_HUB_OFFLINE and not versions:
raise ValueError(
f"Version {version_spec} of '{repo_id}' is not available in the local cache "
"and Hugging Face Hub is in offline mode. Download the kernel "
"while online first, or pass an explicit `revision=<commit>`."
)
raise ValueError(
f"Version {version_spec} not found, available versions: {', '.join(str(v) for v in sorted(versions.keys()))}"
)
Expand Down
142 changes: 101 additions & 41 deletions kernels/src/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,14 @@
import os
import platform
import sys
import warnings
from dataclasses import dataclass
from importlib.metadata import Distribution
from pathlib import Path
from types import ModuleType

from huggingface_hub import HfApi, constants
from huggingface_hub.errors import LocalEntryNotFoundError
from kernels_data import Metadata

from kernels._system import glibc_version
Expand Down Expand Up @@ -52,8 +54,6 @@ def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str])
return

if isinstance(trust_remote_code, list):
import warnings

warnings.warn(
"Signing identity verification is not yet implemented. "
"The provided signing identities will be ignored and the "
Expand All @@ -62,6 +62,16 @@ def _check_trust_remote_code(repo_id: str, trust_remote_code: bool | list[str])
stacklevel=3,
)

if constants.HF_HUB_OFFLINE:
# Publisher trust cannot be verified offline. The user opted into
# offline mode and the kernel must already be in the local cache,
# so trust was established when it was originally downloaded.
warnings.warn(
f"Skipping publisher trust check for '{repo_id}' because Hugging Face Hub is in offline mode.",
stacklevel=3,
)
return

publisher = repo_id.split("/", 1)[0]

try:
Expand Down Expand Up @@ -244,10 +254,19 @@ def install_kernel(
`Path`: The path to the variant directory.
"""
api = _get_hf_api(user_agent=user_agent)
if local_files_only or constants.HF_HUB_OFFLINE:
# Same local-cache resolution path used by `load_kernel`, which is
# always offline. Sharing the helper avoids the network dependency
# that `get_variants` would otherwise introduce.
return _resolve_local_variant_path(
Copy link
Copy Markdown
Member

@danieldk danieldk May 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can't we load_kernel here and keep the body of _resolve_local_variant_path inside load_kernel? It would more clearly express the relation between the functions.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thought about it a bit but install_kernel returns Path, whereas load_kernel returns ModuleType. To recover the path from a module return by load_kernel(), we will have to do some kind of parsing.

Then the _import_from_path() utility also needs some changes.

The flow would become get_kernel -> install_kernel > load_kernel -> _import_from_path(repo_info=None) — that caches LoadedKernel(repo_info=None).

Then get_kernel does its own _import_from_path(variant_path, repo_info=RepoInfo(...)), hits the cache, and returns the module while silently discarding the actual repo_info.

As a result, get_loaded_kernels() would report repo_info=None for kernels loaded via get_kernel in offline mode which would be a regression, if not handled in _import_from_path.

I think the shared utility solution is a cleaner approach here.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, got it, thanks!

api,
repo_id,
revision=revision,
backend=backend,
variant_locks=variant_locks,
)

if not local_files_only:
repo_id, revision = resolve_status(api, repo_id, revision)

repo_id, revision = resolve_status(api, repo_id, revision)
variants = get_variants(api, repo_id=repo_id, revision=revision)
variant, trace = resolve_variant(variants, backend)

Expand All @@ -266,7 +285,7 @@ def install_kernel(
allow_patterns=allow_patterns,
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=local_files_only,
local_files_only=False,
)
)
)
Expand All @@ -281,6 +300,61 @@ def install_kernel(
raise FileNotFoundError(f"Cannot install kernel from repo {repo_id} (revision: {revision})")


def _resolve_local_variant_path(
api: HfApi,
repo_id: str,
*,
revision: str,
backend: str | None = None,
variant_locks: dict[str, VariantLock] | None = None,
) -> Path:
"""Resolve a kernel variant path from the local Hugging Face cache only.

Used by `load_kernel` (which always operates on a pre-downloaded, locked
kernel) and by the offline branch of `install_kernel`.
"""
try:
local_repo_path = Path(
str(
api.snapshot_download(
repo_id,
repo_type="kernel",
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=True,
)
)
)
except LocalEntryNotFoundError as e:
raise FileNotFoundError(
f"Cannot find a local snapshot for {repo_id} (revision: {revision}). "
"When Hugging Face Hub is in offline mode the kernel must already "
"be present in the local cache."
) from e

variants = get_variants_local(local_repo_path / "build")
variant, status = resolve_variant(variants, backend)
if variant is None:
raise FileNotFoundError(
f"Cannot find a build variant for this system in {repo_id} (revision: {revision}):\n\n{variants_trace_str(status)}"
)

allow_patterns = [f"build/{variant.variant_str}/*"]
repo_path = Path(
str(
api.snapshot_download(
repo_id,
repo_type="kernel",
allow_patterns=allow_patterns,
cache_dir=CACHE_DIR,
revision=revision,
local_files_only=True,
)
)
)
return _find_kernel_in_repo_path(repo_path, variant=variant, variant_locks=variant_locks)


def _find_kernel_in_repo_path(
repo_path: Path,
*,
Expand Down Expand Up @@ -479,10 +553,7 @@ def has_kernel(


def load_kernel(
repo_id: str,
*,
lockfile: Path | None,
backend: str | None = None,
repo_id: str, *, lockfile: Path | None, backend: str | None = None, revision: str | None = None
) -> ModuleType:
"""
Get a pre-downloaded, locked kernel.
Expand All @@ -497,13 +568,20 @@ def load_kernel(
backend (`str`, *optional*):
The backend to load the kernel for. Can only be `cpu` or the backend that Torch is compiled for.
The backend will be detected automatically if not provided.
revision (`str`, *optional*):
The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`.

Returns:
`ModuleType`: The imported kernel module.
"""
if lockfile is None:
if lockfile is not None and revision is not None:
raise ValueError("`lockfile` and `revision` both cannot be specified at the same time.")

if lockfile is None and revision is None:
locked_sha = _get_caller_locked_kernel(repo_id)
else:
elif revision is not None:
locked_sha = revision
elif lockfile is not None:
with open(lockfile, "r") as f:
locked_sha = _get_locked_kernel(repo_id, f.read())

Expand All @@ -513,39 +591,21 @@ def load_kernel(
)

api = _get_hf_api()
variants = get_variants(api, repo_id=repo_id, revision=locked_sha)
variant, status = resolve_variant(variants, backend)

if variant is None:
raise FileNotFoundError(
f"Cannot find a build variant for this system in {repo_id} (revision: {locked_sha}):\n\n{variants_trace_str(status)}"
)

allow_patterns = [f"build/{variant.variant_str}/*"]
repo_path = Path(
str(
api.snapshot_download(
repo_id,
repo_type="kernel",
allow_patterns=allow_patterns,
cache_dir=CACHE_DIR,
revision=locked_sha,
local_files_only=True,
)
)
)

try:
variant_path = _find_kernel_in_repo_path(
repo_path,
variant=variant,
variant_locks=None,
variant_path = _resolve_local_variant_path(
api,
repo_id,
revision=locked_sha,
backend=backend,
)
return _import_from_path(variant_path)
except FileNotFoundError:
except FileNotFoundError as e:
raise FileNotFoundError(
f"Locked kernel `{repo_id}` does not have applicable variant or was not downloaded with `kernels download <project>`"
)
f"Locked kernel `{repo_id}` was not downloaded or does not have an "
"applicable variant. Make sure it's downloaded locally via "
"`kernels download <project>`."
) from e
return _import_from_path(variant_path)


def get_locked_kernel(repo_id: str, local_files_only: bool = False) -> ModuleType:
Expand Down
63 changes: 63 additions & 0 deletions kernels/tests/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@
import pytest
import torch
import torch.nn.functional as F
from huggingface_hub import constants
from huggingface_hub.errors import HfHubHTTPError

from kernels import get_kernel, get_local_kernel, has_kernel, install_kernel
from kernels._versions import resolve_version_spec_as_ref, select_revision_or_version


@pytest.fixture
Expand Down Expand Up @@ -243,6 +245,67 @@ def test_trust_remote_code_flag_allows_untrusted():
get_kernel("kernels-test-untrusted/ci-test-kernel", version=1, trust_remote_code=True)


def test_install_kernel_offline_with_revision(monkeypatch, local_kernel_path):
Copy link
Copy Markdown
Member

@danieldk danieldk May 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love these tests!

"""install_kernel should resolve a cached snapshot when HF_HUB_OFFLINE=1."""
expected_path = local_kernel_path
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)

path = install_kernel("kernels-community/relu", revision="v1")
assert path == expected_path


def test_install_kernel_offline_avoids_network(monkeypatch, local_kernel_path):
"""When HF_HUB_OFFLINE=1, install_kernel must not make any Hub requests."""
expected_path = local_kernel_path

class _NoNetwork(RuntimeError):
pass

def _fail(*_args, **_kwargs):
raise _NoNetwork("Hub access attempted in offline test")

monkeypatch.setattr("huggingface_hub.hf_api.get_session", _fail)

# Online path must touch the Hub via get_session and therefore fail.
with pytest.raises(_NoNetwork):
install_kernel("kernels-community/relu", revision="v1")

# Offline mode resolves entirely from the local cache, so get_session is
# never called.
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)
path = install_kernel("kernels-community/relu", revision="v1")
assert path == expected_path


def test_install_kernel_offline_with_version(monkeypatch, local_kernel_path):
"""get_kernel(version=) should resolve via local refs when HF_HUB_OFFLINE=1."""
expected_path = local_kernel_path
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)

commit = select_revision_or_version("kernels-community/relu", revision=None, version=1)
path = install_kernel("kernels-community/relu", revision=commit)
assert path == expected_path


def test_install_kernel_offline_uncached_revision(monkeypatch):
"""install_kernel should fail with a helpful error when offline and uncached."""
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)

with pytest.raises(FileNotFoundError, match=r"local snapshot"):
install_kernel(
"kernels-test/this-repo-should-not-exist",
revision="0000000000000000000000000000000000000000",
)


def test_version_resolution_offline_missing(monkeypatch):
"""resolve_version_spec_as_ref should raise a clear error when offline and no cache."""
monkeypatch.setattr(constants, "HF_HUB_OFFLINE", True)

with pytest.raises(ValueError, match=r"offline mode"):
resolve_version_spec_as_ref("kernels-test/this-repo-should-not-exist", 1)


def silu_and_mul_torch(x: torch.Tensor):
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
Loading