From 88ec2be38b1f735ef054d07ed501c1315074b148 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 25 Nov 2025 09:35:44 +0000 Subject: [PATCH 1/3] Support functions as layers This change adds two pieces of new functionality. First, it introduces the `(Locked|Local)?FuncRepository` classes; these can be used to extend a layer with a kernel function. For instance, a layer like ``` @use_kernel_forward_from_hub("SiluAndMul") class SiluAndMul(nn.Module): def forward(self, input: torch.Tensor) -> torch.Tensor: d = input.shape[-1] // 2 return F.silu(input[..., :d]) * input[..., d:] ``` can now also be kernelized using a function `silu_and_mul` from the Hub: ``` with use_kernel_mapping({ "SiluAndMul": { "cuda": FuncRepository( repo_id="kernels-community/activation", func_name="silu_and_mul", ), } }): kernelize(...) ``` This makes it easier to kernelize pure layers (layers that do not use module state), since the Hub kernel does not have to provide a `layers` Python module with wrappers. Second, we introduce a decorator `use_kernel_func_from_hub` that turns functions into layers that can be kernelized. For example: ``` @use_kernel_func_from_hub("silu_and_mul") def silu_and_mul(x: torch.Tensor) -> torch.Tensor: d = x.shape[-1] // 2 return F.silu(x[..., :d]) * x[..., d:] ``` will implicitly create an instance of the following class: ``` class Func(nn.Module): # We add some magic to preserve the function's signature. def forward(self, *args, **kwargs): return silu_and_mul(*args, **kwargs) ``` Due to the `__call__` implementation of `nn.Module`, the instance still behaves as a function: ``` out = silu_and_mul(x) ``` However, when the function is used as a member of an `nn.Module`, it will be kernelized: ``` class FeedForward(nn.Module): def __init__(self, in_features: int, out_features: int): self.linear = nn.Linear(in_features, out_features) # Note: silu_and_mul is a Torch module.
self.silu_and_mul = silu_and_mul def forward(self, x: torch.Tensor) -> torch.Tensor: return self.silu_and_mul(self.linear(x)) ``` --- docs/source/api/layers.md | 16 ++ docs/source/layers.md | 46 +++++ src/kernels/__init__.py | 17 +- src/kernels/layer/__init__.py | 12 +- src/kernels/layer/func.py | 306 +++++++++++++++++++++++++++++++++ src/kernels/layer/globals.py | 4 +- src/kernels/layer/kernelize.py | 8 +- src/kernels/layer/layer.py | 45 +++-- src/kernels/layer/repos.py | 49 +++--- tests/conftest.py | 12 ++ tests/test_func.py | 129 ++++++++++++++ tests/test_interval_tree.py | 2 +- tests/test_kernel_locking.py | 58 ++++++- tests/test_layer.py | 42 +++-- 14 files changed, 673 insertions(+), 73 deletions(-) create mode 100644 src/kernels/layer/func.py create mode 100644 tests/test_func.py diff --git a/docs/source/api/layers.md b/docs/source/api/layers.md index 44169c8..db52214 100644 --- a/docs/source/api/layers.md +++ b/docs/source/api/layers.md @@ -6,6 +6,10 @@ [[autodoc]] kernels.use_kernel_forward_from_hub +### use_kernel_func_from_hub + +[[autodoc]] kernels.use_kernel_func_from_hub + ### replace_kernel_forward_from_hub [[autodoc]] kernels.replace_kernel_forward_from_hub @@ -36,14 +40,26 @@ [[autodoc]] kernels.Mode +### FuncRepository + +[[autodoc]] kernels.FuncRepository + ### LayerRepository [[autodoc]] kernels.LayerRepository +### LocalFuncRepository + +[[autodoc]] kernels.LocalFuncRepository + ### LocalLayerRepository [[autodoc]] kernels.LocalLayerRepository +### LocalFuncRepository + +[[autodoc]] kernels.LockedFuncRepository + ### LockedLayerRepository [[autodoc]] kernels.LockedLayerRepository diff --git a/docs/source/layers.md b/docs/source/layers.md index 4d3dc71..951a67a 100644 --- a/docs/source/layers.md +++ b/docs/source/layers.md @@ -43,6 +43,37 @@ replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul") it signifies that the maintainer intends to keep the `forward` signature compatible with layers from the hub. +### Using a function as a layer + +Sometimes it can be useful to make a function extensible, for example +because the function cannot be replaced by a layer. In such a case, you +can annotate the function with the `use_kernel_func_from_hub` function: + +```python +@use_kernel_func_from_hub("silu_and_mul") +def silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] +``` + +This will replace the function by an instantiated `torch.nn.Module` +(singleton) that calls the function itself in its forward method. So, the +'function module' + +**Note:** for kernelization to see the function, it must be a member of +another `torch.nn.Module` that is past of the model. For example: + +```python +class FeedForward(nn.Module): + def __init__(self, in_features: int, out_features: int): + self.linear = nn.Linear(in_features, out_features) + # Note: silu_and_mul is a Torch module. + self.silu_and_mul = silu_and_mul + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return self.silu_and_mul(self.linear(x)) +``` + ## Kernelizing a model A model will not use Hub kernels by default, even if it contains extensible @@ -157,6 +188,21 @@ with use_kernel_mapping(kernel_layer_mapping): This ensures that the mapping is not active anymore outside the `with`-scope. 
+If the layer is stateless (it does not use member variables _or_ it was +originally a function that was converted into a kernel layer with +`use_kernel_func_from_hub`), it can also be mapped to a kernel function: + +```python +kernel_layer_mapping = { + "SiluAndMul": { + "cuda": FuncRepository( + repo_id="kernels-community/activation", + func_name="silu_and_mul", + ), + } +} +``` + ### Using version bounds Kernels are versioned using tags of the form `v..`. diff --git a/src/kernels/__init__.py b/src/kernels/__init__.py index 0e521ee..85057db 100644 --- a/src/kernels/__init__.py +++ b/src/kernels/__init__.py @@ -2,15 +2,22 @@ __version__ = importlib.metadata.version("kernels") -from kernels.layer import Device, CUDAProperties -from kernels.layer import kernelize, register_kernel_mapping, use_kernel_mapping -from kernels.layer import Mode from kernels.layer import ( + CUDAProperties, + Device, + FuncRepository, LayerRepository, + LocalFuncRepository, LocalLayerRepository, + LockedFuncRepository, LockedLayerRepository, + Mode, + kernelize, + register_kernel_mapping, replace_kernel_forward_from_hub, use_kernel_forward_from_hub, + use_kernel_func_from_hub, + use_kernel_mapping, ) from kernels.utils import ( get_kernel, @@ -25,8 +32,11 @@ "__version__", "CUDAProperties", "Device", + "FuncRepository", "LayerRepository", + "LocalFuncRepository", "LocalLayerRepository", + "LockedFuncRepository", "LockedLayerRepository", "Mode", "get_kernel", @@ -38,7 +48,6 @@ "load_kernel", "register_kernel_mapping", "replace_kernel_forward_from_hub", - "replace_kernel_func_from_hub", "use_kernel_forward_from_hub", "use_kernel_func_from_hub", "use_kernel_mapping", diff --git a/src/kernels/layer/__init__.py b/src/kernels/layer/__init__.py index 85e17dd..bf8d209 100644 --- a/src/kernels/layer/__init__.py +++ b/src/kernels/layer/__init__.py @@ -1,4 +1,10 @@ -from .device import Device, CUDAProperties +from .device import CUDAProperties, Device +from .func import ( + FuncRepository, + LocalFuncRepository, + LockedFuncRepository, + use_kernel_func_from_hub, +) from .kernelize import ( kernelize, register_kernel_mapping, @@ -16,13 +22,17 @@ __all__ = [ "CUDAProperties", "Device", + "FuncRepository", "LayerRepository", + "LocalFuncRepository", "LocalLayerRepository", + "LockedFuncRepository", "LockedLayerRepository", "Mode", "kernelize", "register_kernel_mapping", "replace_kernel_forward_from_hub", "use_kernel_forward_from_hub", + "use_kernel_func_from_hub", "use_kernel_mapping", ] diff --git a/src/kernels/layer/func.py b/src/kernels/layer/func.py new file mode 100644 index 0000000..e2f1127 --- /dev/null +++ b/src/kernels/layer/func.py @@ -0,0 +1,306 @@ +import functools +import inspect +from inspect import Parameter, Signature +from pathlib import Path +from types import ModuleType +from typing import TYPE_CHECKING, Callable, Optional, Protocol, Type + +from kernels.layer.repos import RepositoryProtocol + +from .._versions import select_revision_or_version +from ..utils import ( + _get_caller_locked_kernel, + _get_locked_kernel, + get_kernel, + get_local_kernel, +) + +if TYPE_CHECKING: + from torch import nn + + +class FuncRepositoryProtocol(RepositoryProtocol, Protocol): + @property + def func_name(self) -> str: ... + + +class FuncRepository: + """ + Repository and name of a function for kernel mapping. + + Args: + repo_id (`str`): + The Hub repository containing the layer. + func_name (`str`): + The name of the function within the kernel repository. 
+ revision (`str`, *optional*, defaults to `"main"`): + The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. + version (`str`, *optional*): + The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. + Cannot be used together with `revision`. + + Example: + ```python + from kernels import FuncRepository + + # Reference a specific layer by revision + layer_repo = FuncRepository( + repo_id="kernels-community/activation", + func_name="silu_and_mul", + ) + + # Reference a layer by version constraint + layer_repo_versioned = FuncRepository( + repo_id="kernels-community/activation", + func_name="silu_and_mul", + version=">=0.0.3,<0.1" + ) + ``` + """ + + def __init__( + self, + repo_id: str, + *, + func_name: str, + revision: Optional[str] = None, + version: Optional[str] = None, + ): + if revision is not None and version is not None: + raise ValueError( + "Either a revision or a version must be specified, not both." + ) + + self._repo_id = repo_id + self.func_name = func_name + + # We are going to resolve these lazily, since we do not want + # to do a network request for every registered FuncRepository. + self._revision = revision + self._version = version + + @functools.lru_cache() + def _resolve_revision(self) -> str: + return select_revision_or_version( + repo_id=self._repo_id, revision=self._revision, version=self._version + ) + + def load(self) -> Type["nn.Module"]: + kernel = get_kernel(self._repo_id, revision=self._resolve_revision()) + return _get_kernel_func(self, kernel) + + def __eq__(self, other): + return ( + isinstance(other, FuncRepository) + and self.func_name == other.func_name + and self._repo_id == other._repo_id + and self._revision == other._revision + and self._version == other._version + ) + + def __hash__(self): + return hash((self.func_name, self._repo_id, self._revision, self._version)) + + def __str__(self) -> str: + return f"`{self._repo_id}` (revision: {self._resolve_revision()}), function `{self.func_name}`" + + +class LocalFuncRepository: + """ + Repository and function name from a local directory for kernel mapping. + + Args: + repo_path (`Path`): + The local repository containing the layer. + package_name (`str`): + Package name of the kernel. + func_name (`str`): + The name of the function within the kernel repository. 
+ + Example: + ```python + from pathlib import Path + + from kernels import LocalFuncRepository + + # Reference a specific layer by revision + layer_repo = LocalFuncRepository( + repo_path=Path("/home/daniel/kernels/activation"), + package_name="activation", + func_name="silu_and_mul", + ) + ``` + """ + + def __init__( + self, + repo_path: Path, + *, + package_name: str, + func_name: str, + ): + self._repo_path = repo_path + self._package_name = package_name + self.func_name = func_name + + def load(self) -> Type["nn.Module"]: + kernel = get_local_kernel(self._repo_path, self._package_name) + return _get_kernel_func(self, kernel) + + def __eq__(self, other): + return ( + isinstance(other, LocalFuncRepository) + and self.func_name == other.func_name + and self._repo_path == other._repo_path + and self._package_name == other._package_name + ) + + def __hash__(self): + return hash((self.func_name, self._repo_path, self._package_name)) + + def __str__(self) -> str: + return f"`{self._repo_path}` (package: {self._package_name}), layer `{self.func_name}`" + + +def use_kernel_func_from_hub(func_name: str): + """ + Decorator that makes a function extensible using the specified function name. + + This is a decorator factory that returns a decorator which prepares a function to use kernels from the + Hugging Face Hub. + + The function will be exposed as an instance of `torch.nn.Module` in which + the function is called in `forward`. For the function to be properly + kernelized, it **must** be a member of another `torch.nn.Module` that is + part of the model (see the example). + + Args: + func_name (`str`): + The name of the function name to use for kernel lookup in registered mappings. + + Returns: + `Callable`: A decorator function that can be applied to layer classes. + + Example: + ```python + import torch + import torch.nn as nn + + from kernels import use_kernel_func_from_hub + from kernels import Mode, kernelize + + @use_kernel_func_from_hub("my_custom_func") + def my_custom_func(x: torch.Tensor): + # Original implementation + return x + + class MyModel(torch.nn.Module): + def __init__(self): + super().__init__() + self.fn = my_custom_func + + def forward(self, x): + return self.fn(x) + + model = MyModel() + + # The layer can now be kernelized: + # model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda") + ``` + """ + + def decorator(func): + Func = _create_func_module(func) + Func.kernel_layer_name = func_name + return Func() + + return decorator + + +class LockedFuncRepository: + """ + Repository and name of a function. + + In contrast to `FuncRepository`, this class uses repositories that + are locked inside a project. + """ + + def __init__( + self, + repo_id: str, + *, + lockfile: Optional[Path] = None, + func_name: str, + ): + """ + Construct a function repository. + + Args: + repo_id (`str`): The Hub repository containing the function. + lockfile (`Path`, *optional*): Path to the lockfile. If not provided, + the lockfile will be inferred from the caller's context. + func_name (`str`): The name of the function within the kernel repository. 
+ """ + self._repo_id = repo_id + self._lockfile = lockfile + self.func_name = func_name + + @functools.lru_cache() + def _resolve_revision(self) -> str: + if self._lockfile is None: + locked_sha = _get_caller_locked_kernel(self._repo_id) + else: + with open(self._lockfile, "r") as f: + locked_sha = _get_locked_kernel(self._repo_id, f.read()) + + if locked_sha is None: + raise ValueError(f"Kernel `{self._repo_id}` is not locked") + + return locked_sha + + def load(self) -> Type["nn.Module"]: + kernel = get_kernel(repo_id=self._repo_id, revision=self._resolve_revision()) + return _get_kernel_func(self, kernel) + + def __eq__(self, other): + return ( + isinstance(other, LockedFuncRepository) + and self.func_name == other.func_name + and self._repo_id == other._repo_id + ) + + def __hash__(self): + return hash((self.func_name, self._repo_id)) + + def __str__(self) -> str: + return f"`{self._repo_id}` (revision: {self._resolve_revision()}), function `{self.func_name}`" + + +def _get_kernel_func( + repo: FuncRepositoryProtocol, kernel: ModuleType +) -> Type["nn.Module"]: + func = getattr(kernel, repo.func_name, None) + if func is None: + raise ValueError(f"Function `{repo.func_name}` not found in `{repo}`") + + return _create_func_module(func) + + +def _create_func_module(func: Callable) -> Type["nn.Module"]: + from torch import nn + + class Func(nn.Module): + def forward(self, *args, **kwargs): + return func(*args, **kwargs) + + # Use function signature with args prepended by self to support + # module validation. + func_sig = inspect.signature(func) + new_args = [Parameter("self", Parameter.POSITIONAL_OR_KEYWORD)] + new_args.extend(func_sig.parameters.values()) + Func.forward.__signature__ = Signature( # type: ignore[attr-defined] + parameters=new_args, + return_annotation=func_sig.return_annotation, + ) + + return Func diff --git a/src/kernels/layer/globals.py b/src/kernels/layer/globals.py index eaf7c7b..1c250d6 100644 --- a/src/kernels/layer/globals.py +++ b/src/kernels/layer/globals.py @@ -1,10 +1,10 @@ import os +from contextvars import ContextVar from typing import ( Dict, ) -from contextvars import ContextVar -from .repos import DeviceRepos +from .repos import DeviceRepos _DISABLE_KERNEL_MAPPING: bool = bool(int(os.environ.get("DISABLE_KERNEL_MAPPING", "0"))) diff --git a/src/kernels/layer/kernelize.py b/src/kernels/layer/kernelize.py index 659d24b..35ed11f 100644 --- a/src/kernels/layer/kernelize.py +++ b/src/kernels/layer/kernelize.py @@ -11,7 +11,7 @@ from .repos import DeviceRepos from .globals import _KERNEL_MAPPING from .layer import kernelize_layer -from .repos import LayerRepositoryProtocol +from .repos import RepositoryProtocol from .mode import Mode from .device import Device @@ -25,7 +25,7 @@ def use_kernel_mapping( str, Dict[ Union[Device, str], - Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]], + Union[RepositoryProtocol, Dict[Mode, RepositoryProtocol]], ], ], *, @@ -104,7 +104,7 @@ def register_kernel_mapping( str, Dict[ Union[Device, str], - Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]], + Union[RepositoryProtocol, Dict[Mode, RepositoryProtocol]], ], ], inherit_mapping: bool = True, @@ -116,7 +116,7 @@ def register_kernel_mapping( depending on the device and mode. This should be used in conjunction with [`kernelize`]. 
Args: - mapping (`Dict[str, Dict[Union[Device, str], Union[LayerRepositoryProtocol, Dict[Mode, LayerRepositoryProtocol]]]]`): + mapping (`Dict[str, Dict[Union[Device, str], Union[RepositoryProtocol, Dict[Mode, RepositoryProtocol]]]]`): The kernel mapping to register globally. Maps layer names to device-specific kernels. The mapping can specify different kernels for different modes (training, inference, etc.). inherit_mapping (`bool`, *optional*, defaults to `True`): diff --git a/src/kernels/layer/layer.py b/src/kernels/layer/layer.py index 9b37d78..9965c64 100644 --- a/src/kernels/layer/layer.py +++ b/src/kernels/layer/layer.py @@ -10,6 +10,7 @@ TYPE_CHECKING, Dict, Optional, + Protocol, Type, ) @@ -23,12 +24,17 @@ get_local_kernel, ) from .mode import Mode -from .repos import _select_repository, LayerRepositoryProtocol +from .repos import _select_repository, RepositoryProtocol if TYPE_CHECKING: from torch import nn +class LayerRepositoryProtocol(RepositoryProtocol, Protocol): + @property + def layer_name(self) -> str: ... + + class LayerRepository: """ Repository and name of a layer for kernel mapping. @@ -90,8 +96,9 @@ def _resolve_revision(self) -> str: repo_id=self._repo_id, revision=self._revision, version=self._version ) - def load(self) -> ModuleType: - return get_kernel(self._repo_id, revision=self._resolve_revision()) + def load(self) -> Type["nn.Module"]: + kernel = get_kernel(self._repo_id, revision=self._resolve_revision()) + return _get_kernel_layer(self, kernel) def __eq__(self, other): return ( @@ -147,8 +154,9 @@ def __init__( self._package_name = package_name self.layer_name = layer_name - def load(self) -> ModuleType: - return get_local_kernel(self._repo_path, self._package_name) + def load(self) -> Type["nn.Module"]: + kernel = get_local_kernel(self._repo_path, self._package_name) + return _get_kernel_layer(self, kernel) def __eq__(self, other): return ( @@ -203,8 +211,9 @@ def _resolve_revision(self) -> str: return locked_sha - def load(self) -> ModuleType: - return get_kernel(repo_id=self._repo_id, revision=self._resolve_revision()) + def load(self) -> Type["nn.Module"]: + kernel = get_kernel(repo_id=self._repo_id, revision=self._resolve_revision()) + return _get_kernel_layer(self, kernel) def __eq__(self, other): return ( @@ -220,7 +229,7 @@ def __str__(self) -> str: return f"`{self._repo_id}` (revision: {self._resolve_revision()}), layer `{self.layer_name}`" -_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {} +_CACHED_LAYER: Dict[RepositoryProtocol, Type["nn.Module"]] = {} def replace_kernel_forward_from_hub( @@ -352,7 +361,7 @@ def kernelize_layer( repo, repo_mode = repo_with_mode - logging.info(f"Using layer `{repo.layer_name}` from repo {repo}") + logging.info(f"Using function/layer from repo {repo}") logging.debug(f"kernelize mode: {mode}, repo mode: {repo_mode}") layer = _get_layer_memoize(repo, module_class) @@ -373,11 +382,11 @@ def kernelize_layer( ) -def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]: +def _get_kernel_layer( + repo: LayerRepositoryProtocol, kernel: ModuleType +) -> Type["nn.Module"]: """Get a layer from a kernel.""" - kernel = repo.load() - if getattr(kernel, "layers", None) is None: raise ValueError(f"Kernel repo {repo} does not define any layers.") @@ -387,7 +396,7 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]: return layer -def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol): +def _validate_layer(*, check_cls, cls, repo: RepositoryProtocol): import 
torch.nn as nn # The layer must have at least have the following properties: (1) it @@ -471,7 +480,7 @@ def _validate_layer_has_mode( *, layer_name: str, module: Type["nn.Module"], - repo: LayerRepositoryProtocol, + repo: RepositoryProtocol, repo_mode: Mode, ): """ @@ -480,7 +489,7 @@ def _validate_layer_has_mode( if Mode.TRAINING in repo_mode and not getattr(module, "has_backward", True): raise ValueError( - f"Layer `{repo.layer_name}` from repo {repo} does not support backward.\n" + f"Function/layer from repo {repo} does not support backward.\n" f"Was registered for `{layer_name}` with mode `{repo_mode}`" ) @@ -488,7 +497,7 @@ def _validate_layer_has_mode( module, "can_torch_compile", False ): raise ValueError( - f"Layer `{repo.layer_name}` from repo {repo} does not support torch.compile.\n" + f"Function/layer from repo {repo} does not support torch.compile.\n" f"Was registered for `{layer_name}` with mode `{repo_mode}`" ) @@ -496,13 +505,13 @@ def _validate_layer_has_mode( def _get_layer_memoize( - repo: LayerRepositoryProtocol, module_class: Type["nn.Module"] + repo: RepositoryProtocol, module_class: Type["nn.Module"] ) -> Type["nn.Module"]: layer = _CACHED_LAYER.get(repo, None) if layer is not None: return layer - layer = _get_kernel_layer(repo) + layer = repo.load() _validate_layer(check_cls=module_class, cls=layer, repo=repo) _CACHED_LAYER[repo] = layer diff --git a/src/kernels/layer/repos.py b/src/kernels/layer/repos.py index 47e8dd0..7576b4a 100644 --- a/src/kernels/layer/repos.py +++ b/src/kernels/layer/repos.py @@ -1,6 +1,5 @@ from abc import ABC, abstractmethod -from types import ModuleType -from typing import Dict, Optional, Protocol, Tuple +from typing import TYPE_CHECKING, Dict, Optional, Protocol, Tuple, Type import sys from functools import lru_cache @@ -9,12 +8,12 @@ from ._interval_tree import IntervalTree from .device import CUDAProperties, ROCMProperties +if TYPE_CHECKING: + from torch import nn -class LayerRepositoryProtocol(Protocol): - @property - def layer_name(self) -> str: ... - def load(self) -> ModuleType: ... +class RepositoryProtocol(Protocol): + def load(self) -> Type["nn.Module"]: ... class DeviceRepos(ABC): @@ -42,10 +41,10 @@ def create_repo(device: Device) -> "DeviceRepos": @abstractmethod def repos( self, - ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: ... + ) -> Optional[Dict[Mode, RepositoryProtocol]]: ... @abstractmethod - def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]): """ Insert a repository for a specific device and mode. 
""" @@ -53,7 +52,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): class _XPURepos(DeviceRepos): - _repos: Dict[Mode, LayerRepositoryProtocol] + _repos: Dict[Mode, RepositoryProtocol] def __init__(self): super().__init__() @@ -62,10 +61,10 @@ def __init__(self): @property def repos( self, - ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: + ) -> Optional[Dict[Mode, RepositoryProtocol]]: return self._repos - def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]): if device.type != "xpu": raise ValueError(f"Device type must be 'xpu', got {device.type}") @@ -73,7 +72,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): class _NPURepos(DeviceRepos): - _repos: Dict[Mode, LayerRepositoryProtocol] + _repos: Dict[Mode, RepositoryProtocol] def __init__(self): super().__init__() @@ -82,10 +81,10 @@ def __init__(self): @property def repos( self, - ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: + ) -> Optional[Dict[Mode, RepositoryProtocol]]: return self._repos - def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]): if device.type != "npu": raise ValueError(f"Device type must be 'npu', got {device.type}") @@ -93,7 +92,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): class _MPSRepos(DeviceRepos): - _repos: Dict[Mode, LayerRepositoryProtocol] + _repos: Dict[Mode, RepositoryProtocol] def __init__(self): super().__init__() @@ -102,10 +101,10 @@ def __init__(self): @property def repos( self, - ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: + ) -> Optional[Dict[Mode, RepositoryProtocol]]: return self._repos - def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]): if device.type != "mps": raise ValueError(f"Device type must be 'mps', got {device.type}") @@ -113,7 +112,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): class _CUDARepos(DeviceRepos): - _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]] + _repos: IntervalTree[Dict[Mode, RepositoryProtocol]] def __init__(self): super().__init__() @@ -122,11 +121,11 @@ def __init__(self): @property def repos( self, - ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: + ) -> Optional[Dict[Mode, RepositoryProtocol]]: capability = _find_capability() return self.repos_by_capability.find_smallest_interval(capability) - def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + def insert(self, device: Device, repos: Dict[Mode, RepositoryProtocol]): assert device.properties is None or isinstance( device.properties, CUDAProperties ) @@ -144,7 +143,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): class _ROCMRepos(DeviceRepos): - _repos: IntervalTree[Dict[Mode, LayerRepositoryProtocol]] + _repos: IntervalTree[Dict[Mode, RepositoryProtocol]] def __init__(self): super().__init__() @@ -153,11 +152,11 @@ def __init__(self): @property def repos( self, - ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]: + ) -> Optional[Dict[Mode, RepositoryProtocol]]: capability = _find_capability() return self.repos_by_capability.find_smallest_interval(capability) - def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): + def insert(self, device: Device, repos: Dict[Mode, 
RepositoryProtocol]): assert device.properties is None or isinstance( device.properties, ROCMProperties ) @@ -202,10 +201,10 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]): def _select_repository( - repositories: Dict[Mode, LayerRepositoryProtocol], + repositories: Dict[Mode, RepositoryProtocol], *, mode: Mode, -) -> Optional[Tuple[LayerRepositoryProtocol, Mode]]: +) -> Optional[Tuple[RepositoryProtocol, Mode]]: # Get the fallback priority list for the requested mode if mode not in _MODE_FALLBACK_PRIORITY: raise ValueError(f"Unsupported mode: {mode}") diff --git a/tests/conftest.py b/tests/conftest.py index 7369646..5c307df 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -31,6 +31,18 @@ def pytest_addoption(parser): ) +@pytest.fixture +def device(): + if torch.cuda.is_available(): + return "cuda" + elif hasattr(torch, "xpu") and torch.xpu.is_available(): + return "xpu" + elif _get_privateuse_backend_name() == "npu": + return "npu" + + pytest.skip("No CUDA, NPU or XPU") + + def pytest_runtest_setup(item): if "cuda_only" in item.keywords and not has_cuda: pytest.skip("skipping CUDA-only test on host without CUDA") diff --git a/tests/test_func.py b/tests/test_func.py new file mode 100644 index 0000000..47728ec --- /dev/null +++ b/tests/test_func.py @@ -0,0 +1,129 @@ +import pytest +import torch +import torch.nn.functional as F +from torch import nn + +from kernels import ( + FuncRepository, + LayerRepository, + LocalFuncRepository, + Mode, + install_kernel, + kernelize, + use_kernel_func_from_hub, + use_kernel_mapping, +) + + +# A function + layer that we can map arbitrary functions to for testing. +@use_kernel_func_from_hub("surprise_me") +def surprise_me(x: torch.Tensor): + return x + + +class SurpriseMe(nn.Module): + def __init__(self): + super().__init__() + self.surprise_me = surprise_me + + def forward(self, x: torch.Tensor): + return self.surprise_me(x) + + +def test_decorator(): + @use_kernel_func_from_hub("identity_func") + def identity(x): + return x + + assert type(identity).kernel_layer_name == "identity_func" + assert isinstance(identity, nn.Module) + + +def test_kernel_func(device): + model = SurpriseMe() + + x = torch.arange(-10, 10, device=device).float() + assert model(x) is x + + with use_kernel_mapping( + { + "surprise_me": { + device: FuncRepository( + "kernels-test/flattened-build", + func_name="silu_and_mul", + ) + } + } + ): + model = kernelize(model, mode=Mode.INFERENCE, device=device) + + torch.testing.assert_close(model(x), _silu_and_mul(x)) + + # And empty mapping should revert to the original implementation. + with use_kernel_mapping({"surprise_me": {}}): + model = kernelize(model, mode=Mode.INFERENCE, device=device) + + assert model(x) is x + + +@pytest.mark.cuda_only +def test_kernel_func_with_layer(): + model = SurpriseMe() + + x = torch.arange(-10, 10, device="cuda").float() + assert model(x) is x + + # We can also replace the function by a pure layer. + with use_kernel_mapping( + { + "surprise_me": { + "cuda": LayerRepository( + "kernels-community/activation", + layer_name="SiluAndMul", + ) + } + } + ): + model = kernelize(model, mode=Mode.INFERENCE, device="cuda") + + torch.testing.assert_close(model(x), _silu_and_mul(x)) + + # And empty mapping should revert to the original implementation. 
+ with use_kernel_mapping({"surprise_me": {}}): + model = kernelize(model, mode=Mode.INFERENCE, device="cuda") + + assert model(x) is x + + +def test_local_kernel_func(device): + model = SurpriseMe() + + x = torch.arange(-10, 10).float() + assert model(x) is x + + package_name, path = install_kernel("kernels-test/flattened-build", "main") + + with use_kernel_mapping( + { + "surprise_me": { + device: LocalFuncRepository( + repo_path=path.parent.parent, + package_name=package_name, + func_name="silu_and_mul", + ) + } + } + ): + model = kernelize(model, mode=Mode.INFERENCE, device=device) + + torch.testing.assert_close(model(x), _silu_and_mul(x)) + + with use_kernel_mapping({"do_something_func": {}}): + model = kernelize(model, mode=Mode.INFERENCE, device=device) + + assert model(x) is x + + +def _silu_and_mul(x: torch.Tensor) -> torch.Tensor: + d = x.shape[-1] // 2 + return F.silu(x[..., :d]) * x[..., d:] diff --git a/tests/test_interval_tree.py b/tests/test_interval_tree.py index 8b8b957..00ccc2b 100644 --- a/tests/test_interval_tree.py +++ b/tests/test_interval_tree.py @@ -177,7 +177,7 @@ def test_property_based_interval_tree(): simple_result = simple.find_smallest_interval(point) assert tree_result == simple_result, ( - f"Mismatch for point {point} after inserting {i+1} intervals. " + f"Mismatch for point {point} after inserting {i + 1} intervals. " f"Tree: {tree_result}, Simple: {simple_result}. " f"Last inserted: ({start}, {end})" ) diff --git a/tests/test_kernel_locking.py b/tests/test_kernel_locking.py index 7daaa88..5b7795f 100644 --- a/tests/test_kernel_locking.py +++ b/tests/test_kernel_locking.py @@ -4,15 +4,17 @@ import pytest import torch.nn as nn -from kernels import load_kernel -from kernels.cli import download_kernels -from kernels.layer import ( +from kernels import ( + LockedFuncRepository, LockedLayerRepository, Mode, kernelize, + load_kernel, use_kernel_forward_from_hub, + use_kernel_func_from_hub, use_kernel_mapping, ) +from kernels.cli import download_kernels # Mock download arguments class. 
@@ -36,7 +38,7 @@ def test_load_locked(): @pytest.mark.cuda_only -def test_layer_locked(): +def test_layer_locked(device): project_dir = Path(__file__).parent / "layer_locking" @use_kernel_forward_from_hub("Version") @@ -49,7 +51,7 @@ def forward(self) -> str: with use_kernel_mapping( { "Version": { - "cuda": LockedLayerRepository( + device: LockedLayerRepository( repo_id="kernels-test/versions", layer_name="Version", lockfile=project_dir / "kernels.lock", @@ -57,5 +59,49 @@ def forward(self) -> str: }, } ): - version = kernelize(version, device="cuda", mode=Mode.INFERENCE) + version = kernelize(version, device=device, mode=Mode.INFERENCE) assert version() == "0.1.1" + + +def test_func_locked(device): + project_dir = Path(__file__).parent / "layer_locking" + + @use_kernel_func_from_hub("version") + def version(): + return "0.0.0" + + class Version(nn.Module): + def __init__(self): + super().__init__() + self.version = version + + def forward(self) -> str: + return self.version() + + model = Version() + + print(model.version.forward) + + with use_kernel_mapping( + { + "version": { + device: LockedFuncRepository( + repo_id="kernels-test/versions", + func_name="version", + lockfile=project_dir / "kernels.lock", + ) + }, + } + ): + model = kernelize(model, device=device, mode=Mode.INFERENCE) + + assert version() == "0.1.1" + + print(model.version.forward) + + with use_kernel_mapping({"version": {}}): + model = kernelize(model, mode=Mode.INFERENCE, device=device) + + assert version() == "0.0.0" + + print(model.version.forward) diff --git a/tests/test_layer.py b/tests/test_layer.py index 43ff377..32494fd 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -9,6 +9,7 @@ from kernels import ( CUDAProperties, Device, + FuncRepository, LayerRepository, LocalLayerRepository, Mode, @@ -123,18 +124,6 @@ def forward(self, input: torch.Tensor) -> torch.Tensor: return super().forward(input) -@pytest.fixture -def device(): - if torch.cuda.is_available(): - return "cuda" - elif hasattr(torch, "xpu") and torch.xpu.is_available(): - return "xpu" - elif _get_privateuse_backend_name() == "npu": - return "npu" - - pytest.skip("No CUDA, NPU or XPU") - - def test_arg_kinds(): @use_kernel_forward_from_hub("ArgKind") class ArgKind(nn.Module): @@ -171,6 +160,35 @@ def test_hub_forward(cls): assert silu_and_mul_with_kernel.n_calls == 0 +@pytest.mark.cuda_only +@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice]) +def test_hub_func(cls): + torch.random.manual_seed(0) + + silu_and_mul = SiluAndMul() + X = torch.randn((32, 64), device="cuda") + Y = silu_and_mul(X) + + # SiluAndMul is pure, so we can also use a function. 
+ with use_kernel_mapping( + { + "surprise_me": { + "cuda": FuncRepository( + "kernels-test/flattened-build", + func_name="silu_and_mul", + ) + } + } + ): + silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE) + Y_kernel = silu_and_mul_with_kernel(X) + + torch.testing.assert_close(Y_kernel, Y) + + assert silu_and_mul.n_calls == 1 + assert silu_and_mul_with_kernel.n_calls == 0 + + @pytest.mark.rocm_only def test_hub_forward_rocm(): torch.manual_seed(0) From 9956630fb52632af3ac1321320042f8dc279db7d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 25 Nov 2025 10:14:51 +0000 Subject: [PATCH 2/3] test_layer_locked: remove unnecessary cuda_only mark --- tests/test_kernel_locking.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_kernel_locking.py b/tests/test_kernel_locking.py index 5b7795f..2cbe905 100644 --- a/tests/test_kernel_locking.py +++ b/tests/test_kernel_locking.py @@ -37,7 +37,6 @@ def test_load_locked(): load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") -@pytest.mark.cuda_only def test_layer_locked(device): project_dir = Path(__file__).parent / "layer_locking" From 579e8f2550504740a3b26624f85612ea341a7682 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Wed, 26 Nov 2025 12:07:03 +0100 Subject: [PATCH 3/3] Documentation fixes Co-authored-by: Mohamed Mekkouri <93391238+MekkCyber@users.noreply.github.com> --- docs/source/api/layers.md | 2 +- docs/source/layers.md | 11 +++++------ 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/docs/source/api/layers.md b/docs/source/api/layers.md index db52214..37bfb41 100644 --- a/docs/source/api/layers.md +++ b/docs/source/api/layers.md @@ -56,7 +56,7 @@ [[autodoc]] kernels.LocalLayerRepository -### LocalFuncRepository +### LockedFuncRepository [[autodoc]] kernels.LockedFuncRepository diff --git a/docs/source/layers.md b/docs/source/layers.md index 951a67a..99629e8 100644 --- a/docs/source/layers.md +++ b/docs/source/layers.md @@ -46,8 +46,8 @@ compatible with layers from the hub. ### Using a function as a layer Sometimes it can be useful to make a function extensible, for example -because the function cannot be replaced by a layer. In such a case, you -can annotate the function with the `use_kernel_func_from_hub` function: +because the function cannot be replaced by a layer. In such cases, you +can annotate the function with the `use_kernel_func_from_hub` decorator: ```python @use_kernel_func_from_hub("silu_and_mul") @@ -57,11 +57,10 @@ def silu_and_mul(x: torch.Tensor) -> torch.Tensor: ``` This will replace the function by an instantiated `torch.nn.Module` -(singleton) that calls the function itself in its forward method. So, the -'function module' +(singleton) that calls the function itself in its forward method. **Note:** for kernelization to see the function, it must be a member of -another `torch.nn.Module` that is past of the model. For example: +another `torch.nn.Module` that is part of the model. For example: ```python class FeedForward(nn.Module): @@ -188,7 +187,7 @@ with use_kernel_mapping(kernel_layer_mapping): This ensures that the mapping is not active anymore outside the `with`-scope. -If the layer is stateless (it does not use member variables _or_ it was +If the layer is stateless (it does not use member variables in its forward _or_ it was originally a function that was converted into a kernel layer with `use_kernel_func_from_hub`), it can also be mapped to a kernel function: