diff --git a/src/kernels/__init__.py b/src/kernels/__init__.py
index 9bd5ed4..461cfd0 100644
--- a/src/kernels/__init__.py
+++ b/src/kernels/__init__.py
@@ -13,6 +13,7 @@
     register_kernel_mapping,
     replace_kernel_forward_from_hub,
     use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
     use_kernel_mapping,
 )
 from kernels.utils import (
@@ -42,5 +43,6 @@
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
     "use_kernel_forward_from_hub",
+    "use_kernel_func_from_hub",
     "use_kernel_mapping",
 ]
diff --git a/src/kernels/layer.py b/src/kernels/layer.py
index f3e5265..6e94794 100644
--- a/src/kernels/layer.py
+++ b/src/kernels/layer.py
@@ -436,6 +436,7 @@ def __str__(self) -> str:
 
 
 _CACHED_LAYER: Dict[LayerRepositoryProtocol, Type["nn.Module"]] = {}
+_CACHED_KERNEL_MODULE: Dict[LayerRepositoryProtocol, ModuleType] = {}
 
 
 class _DeviceRepos(ABC):
@@ -982,6 +983,74 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             use_fallback=use_fallback,
         )
 
+    # Second pass: replace kernel functions on modules whose class was
+    # decorated with `use_kernel_func_from_hub`.
+    for _, module in model.named_modules():
+        module_class = type(module)
+        if not hasattr(module_class, "_kernel_func_mappings"):
+            continue
+
+        func_mappings = module_class._kernel_func_mappings
+
+        for attr_name, func_name in func_mappings.items():
+            if _DISABLE_KERNEL_MAPPING:
+                continue
+
+            kernel = _KERNEL_MAPPING.get().get(str(func_name))
+
+            if kernel is None:
+                warnings.warn(
+                    "\n"
+                    f"No kernel mapping found for function `{func_name}`. "
+                    "Defaulting to original implementation."
+                )
+                continue
+
+            # Get the kernel options for the current device type.
+            property_repos = kernel.get(device_type.type)
+
+            if property_repos is None:
+                if not use_fallback:
+                    raise ValueError(
+                        f"No function mapping for `{func_name}` with device type `{device_type}`"
+                    )
+                continue
+
+            repos = property_repos.repos
+
+            if repos is None:
+                if not use_fallback:
+                    raise ValueError(
+                        f"No function mapping for `{func_name}` device `{device_type}` with the right properties"
+                    )
+                continue
+
+            repo_with_mode = _select_repository(
+                repos,
+                mode=mode,
+            )
+
+            if repo_with_mode is None:
+                if not use_fallback:
+                    raise ValueError(
+                        f"No repository for `{func_name}` for configuration mode={mode}"
+                    )
+                continue
+
+            repo, _ = repo_with_mode
+
+            logging.info(f"Using function `{func_name}` from repo {repo}")
+
+            try:
+                kernel_func = _get_kernel_function(repo, func_name)
+                setattr(module, attr_name, kernel_func)
+            except ValueError as e:
+                if not use_fallback:
+                    raise
+                warnings.warn(
+                    f"Failed to load kernel function `{func_name}`: {e}. "
+                    "Defaulting to original implementation."
+                )
+
     return model
@@ -1031,6 +1100,61 @@ def decorator(cls):
 
     return decorator
 
 
+def use_kernel_func_from_hub(func_name: str):
+    """
+    Decorator factory that marks a function attribute as replaceable by a kernel implementation.
+
+    The returned decorator records the attribute on the class; during
+    kernelization, if a registered kernel provides a matching function, it
+    replaces the original implementation.
+
+    Args:
+        func_name (`str`):
+            The name of the function to use for kernel lookup in registered mappings.
+
+    Returns:
+        `Callable`: A decorator function that can be applied to classes.
+
+    Example:
+        ```python
+        import torch
+        import torch.nn as nn
+
+        from kernels import Mode, kernelize, use_kernel_func_from_hub
+
+        def my_custom_func(x):
+            # Original (non-kernel) implementation.
+            return x
+
+        @use_kernel_func_from_hub("custom_fn")
+        class MyCustomLayer(nn.Module):
+            def __init__(self, hidden_size):
+                super().__init__()
+                self.hidden_size = hidden_size
+                self.custom_fn = my_custom_func
+
+            def forward(self, x: torch.Tensor):
+                # `self.custom_fn` may be replaced during kernelization.
+                return self.custom_fn(x)
+
+        model = MyCustomLayer(768)
+
+        # The custom_fn can now be kernelized:
+        # model = kernelize(model, mode=Mode.TRAINING | Mode.TORCH_COMPILE, device="cuda")
+        ```
+    """
+
+    def decorator(cls):
+        # Copy any inherited mapping so that decorating a subclass does not
+        # mutate the dict owned by a decorated parent class.
+        if "_kernel_func_mappings" not in cls.__dict__:
+            cls._kernel_func_mappings = dict(getattr(cls, "_kernel_func_mappings", {}))
+        cls._kernel_func_mappings[func_name] = func_name
+        return cls
+
+    return decorator
+
+
 def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
     """Get a layer from a kernel."""
@@ -1045,6 +1169,32 @@ def _get_kernel_layer(repo: LayerRepositoryProtocol) -> Type["nn.Module"]:
 
     return layer
 
 
+def _get_kernel_function(repo: LayerRepositoryProtocol, func_name: str):
+    """Get a function from a kernel."""
+
+    # Check the cache first to avoid re-downloading the kernel module.
+    kernel = _CACHED_KERNEL_MODULE.get(repo)
+    if kernel is None:
+        kernel = repo.load()
+        _CACHED_KERNEL_MODULE[repo] = kernel
+
+    # The repo's `layer_name` doubles as the name of the function in the kernel.
+    actual_func_name = repo.layer_name
+
+    # Prefer `kernel.functions` when the kernel module exposes it.
+    if hasattr(kernel, "functions"):
+        func = getattr(kernel.functions, actual_func_name, None)
+        if func is not None:
+            return func
+
+    # Fall back to looking up the function directly on the kernel module.
+    func = getattr(kernel, actual_func_name, None)
+    if func is None:
+        raise ValueError(
+            f"Function `{actual_func_name}` (mapped from `{func_name}`) not found in kernel repo {repo}."
+        )
+
+    return func
+
+
 def _validate_layer(*, check_cls, cls, repo: LayerRepositoryProtocol):
     import torch.nn as nn
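A minimal end-to-end sketch of how the new decorator composes with the existing mapping registry. The repo id and kernel function name below are hypothetical; per `_get_kernel_function`, the repository's `layer_name` doubles as the name of the function loaded from the kernel module.

```python
import torch
import torch.nn as nn

from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_func_from_hub,
)

# Hypothetical mapping: on CUDA devices, the name `custom_fn` resolves to a
# function called `fast_custom_fn` in the (made-up) Hub repo below.
register_kernel_mapping(
    {
        "custom_fn": {
            "cuda": LayerRepository(
                repo_id="your-org/your-kernels",  # hypothetical repo id
                layer_name="fast_custom_fn",
            )
        }
    }
)

@use_kernel_func_from_hub("custom_fn")
class MyCustomLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.custom_fn = lambda x: x  # original implementation

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.custom_fn(x)

model = kernelize(MyCustomLayer(), mode=Mode.INFERENCE, device="cuda")
# `model.custom_fn` now points at the Hub-provided function; with the default
# `use_fallback=True`, the original implementation is kept when no mapping
# matches the device or mode.
```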