2 changes: 2 additions & 0 deletions src/kernels/__init__.py
@@ -13,6 +13,7 @@
    register_kernel_mapping,
    replace_kernel_forward_from_hub,
    use_kernel_forward_from_hub,
    use_kernel_func_from_hub,
    use_kernel_mapping,
)
from kernels.utils import (
@@ -42,5 +43,6 @@
    "register_kernel_mapping",
    "replace_kernel_forward_from_hub",
    "use_kernel_forward_from_hub",
    "use_kernel_func_from_hub",
    "use_kernel_mapping",
]
150 changes: 150 additions & 0 deletions src/kernels/layer.py
@@ -436,6 +436,7 @@ def __str__(self) -> str:


_CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
_CACHED_KERNEL_MODULE: Dict[LayerRepositoryProtocol, ModuleType] = {}


class _DeviceRepos(ABC):
@@ -982,6 +983,74 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
            use_fallback=use_fallback,
        )

    # Second pass: replace kernel functions
    for _, module in model.named_modules():
        module_class = type(module)
        if not hasattr(module_class, "_kernel_func_mappings"):
            continue

        func_mappings = module_class._kernel_func_mappings

        for attr_name, func_name in func_mappings.items():
            if _DISABLE_KERNEL_MAPPING:
                continue

            kernel = _KERNEL_MAPPING.get().get(str(func_name))

            if kernel is None:
                warnings.warn(
                    "\n"
                    f"No kernel mapping found for function `{func_name}`. "
                    f"Defaulting to original implementation."
                )
                continue

            # Get kernel options for the device
            property_repos = kernel.get(device_type.type)

            if property_repos is None:
                if not use_fallback:
                    raise ValueError(
                        f"No function mapping for `{func_name}` with device type `{device_type}`"
                    )
                continue

            repos = property_repos.repos

            if repos is None:
                if not use_fallback:
                    raise ValueError(
                        f"No function mapping for `{func_name}` device `{device_type}` with the right properties"
                    )
                continue

            repo_with_mode = _select_repository(
                repos,
                mode=mode,
            )

            if repo_with_mode is None:
                if not use_fallback:
                    raise ValueError(
                        f"No repository for `{func_name}` for configuration mode={mode}"
                    )
                continue

            repo, repo_mode = repo_with_mode

            logging.info(f"Using function `{func_name}` from repo {repo}")

            try:
                kernel_func = _get_kernel_function(repo, func_name)
                setattr(module, attr_name, kernel_func)
            except ValueError as e:
                if not use_fallback:
                    raise
                warnings.warn(
                    f"Failed to load kernel function `{func_name}`: {e}. "
                    f"Defaulting to original implementation."
                )

    return model


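For context, the lookup in this pass walks the same mapping that `register_kernel_mapping` populates for layers, keyed here by the function name passed to `use_kernel_func_from_hub`. A minimal sketch of the registration side, assuming function mappings are registered exactly like layer mappings (the repository id below is hypothetical):

```python
from kernels import LayerRepository, register_kernel_mapping

# Hypothetical registration: the outer key is the name given to
# `use_kernel_func_from_hub`, the inner key is the device type, and
# `layer_name` doubles as the function name looked up inside the kernel.
register_kernel_mapping(
    {
        "custom_fn": {
            "cuda": LayerRepository(
                repo_id="your-org/custom-kernels",  # hypothetical repository
                layer_name="custom_fn",
            ),
        }
    }
)
```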
@@ -1031,6 +1100,61 @@ def decorator(cls):
    return decorator


def use_kernel_func_from_hub(func_name: str):
    """
    Decorator factory that marks a function attribute as replaceable by a kernel implementation.

    The returned decorator marks a class as having a replaceable function attribute. During
    kernelization, if a registered kernel mapping provides a matching function, it replaces the
    original implementation on the module instance.

    Args:
        func_name (`str`):
            The name of the function to use for kernel lookup in registered mappings.

    Returns:
        `Callable`: A decorator that can be applied to classes.

    Example:
        ```python
        import torch
        import torch.nn as nn

        from kernels import use_kernel_func_from_hub
        from kernels import Mode, kernelize

        def MyCustomFunc(x):
            # original implementation
            return x

        @use_kernel_func_from_hub("custom_fn")
        class MyCustomLayer(nn.Module):
            def __init__(self, hidden_size):
                super().__init__()
                self.hidden_size = hidden_size
                self.custom_fn = MyCustomFunc

            def forward(self, x: torch.Tensor):
                # use self.custom_fn in forward
                y = self.custom_fn(x)
                return y

        model = MyCustomLayer(768)

        # The custom_fn can now be kernelized:
        # model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")
        ```
    """

    def decorator(cls):
        if not hasattr(cls, "_kernel_func_mappings"):
            cls._kernel_func_mappings = {}
        cls._kernel_func_mappings[func_name] = func_name
        return cls

    return decorator

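To round out the docstring example above: once a mapping for `"custom_fn"` is registered (as in the registration sketch shown earlier), the decorated module can be kernelized directly. A hedged end-to-end sketch; the device and mode are illustrative, and `MyCustomLayer` is the class from the docstring:

```python
import torch
from kernels import Mode, kernelize

model = MyCustomLayer(768)
# If a mapping for "custom_fn" exists for the target device, this swaps in
# the hub-provided function; otherwise the fallback/warning path above keeps
# the original implementation.
model = kernelize(model, mode=Mode.INFERENCE, device="cuda")
y = model(torch.randn(2, 768, device="cuda"))
```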

def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
    """Get a layer from a kernel."""

@@ -1045,6 +1169,32 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
    return layer


def _get_kernel_function(repo: LayerRepositoryProtocol, func_name: str):
    """Get a function from a kernel."""

    # Check cache first to avoid re-downloading
    kernel = _CACHED_KERNEL_MODULE.get(repo)
    if kernel is None:
        kernel = repo.load()
        _CACHED_KERNEL_MODULE[repo] = kernel

    # Use the layer_name from repo as the actual function name in the kernel
    actual_func_name = repo.layer_name

    # Try to get function from kernel.functions first (if it exists)
    if hasattr(kernel, "functions"):
        func = getattr(kernel.functions, actual_func_name, None)
        if func is not None:
            return func

    # Fall back to looking for the function directly in the kernel module
    func = getattr(kernel, actual_func_name, None)
    if func is None:
        raise ValueError(f"Function `{actual_func_name}` not found in kernel repo {repo}.")

    return func


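For reference, a kernel repository satisfies this lookup by exporting a callable under the name stored in `repo.layer_name`, either at module top level or under a `functions` namespace. A hypothetical sketch of such a kernel module:

```python
# Hypothetical layout of a kernel repository's module; either form is found
# by the lookup above when `repo.layer_name == "custom_fn"`.

def custom_fn(x):
    # device-specific implementation would live here
    return x

class functions:  # stand-in for a `functions` namespace a kernel might expose
    custom_fn = staticmethod(custom_fn)
```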
def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
    import torch.nn as nn
