diff --git a/pytest.ini b/pytest.ini
index 78581f4..a942116 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -4,4 +4,5 @@ markers =
     rocm_only: marks tests that should only run on hosts with ROCm GPUs
     darwin_only: marks tests that should only run on macOS
     xpu_only: marks tests that should only run on hosts with Intel XPUs
+    npu_only: marks tests that should only run on Ascend NPUs
     token: enable tests that require a write token
diff --git a/src/kernels/layer.py b/src/kernels/layer.py
index 9032b79..f3e5265 100644
--- a/src/kernels/layer.py
+++ b/src/kernels/layer.py
@@ -87,7 +87,7 @@ class Device:
 
     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "rocm", "xpu").
+            The device type (e.g., "cuda", "mps", "npu", "rocm", "xpu").
         properties ([`CUDAProperties`], *optional*):
             Device-specific properties. Currently only [`CUDAProperties`] is
             supported for CUDA devices.
@@ -109,6 +109,9 @@ class Device:
 
         # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
         xpu_device = Device(type="xpu")
+
+        # NPU device (Huawei Ascend)
+        npu_device = Device(type="npu")
     ```
     """
 
@@ -130,6 +133,8 @@ def create_repo(self) -> _DeviceRepos:
             return _MPSRepos()
         elif self.type == "xpu":
            return _XPURepos()
+        elif self.type == "npu":
+            return _NPURepos()
         else:
             raise ValueError(f"Unknown device type: {self.type}")
 
@@ -472,6 +477,26 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         self._repos = repos
 
 
+class _NPURepos(_DeviceRepos):
+    _repos: Dict[Mode, LayerRepositoryProtocol]
+
+    def __init__(self):
+        super().__init__()
+        self._repos = {}
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
+        return self._repos
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
+        if device.type != "npu":
+            raise ValueError(f"Device type must be 'npu', got {device.type}")
+
+        self._repos = repos
+
+
 class _MPSRepos(_DeviceRepos):
     _repos: Dict[Mode, LayerRepositoryProtocol]
 
@@ -556,7 +581,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps", "xpu"}
+    supported_devices = {"cuda", "mps", "npu", "rocm", "xpu"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
         )
@@ -814,7 +839,7 @@ def kernelize(
             `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training
             with `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
+            The device type to load kernels for. Supported device types are: "cuda", "mps", "npu", "rocm", "xpu".
             The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.
@@ -838,7 +863,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
                 return F.silu(x[..., :d]) * x[..., d:]
 
         mapping = {
-            "LayerNorm": {
+            "SiluAndMul": {
                 "cuda": LayerRepository(
                     repo_id="kernels-community/activation",
                     layer_name="SiluAndMul",
diff --git a/src/kernels/utils.py b/src/kernels/utils.py
index c956f4f..2bae1c1 100644
--- a/src/kernels/utils.py
+++ b/src/kernels/utils.py
@@ -35,6 +35,13 @@ def _get_cache_dir() -> Optional[str]:
 CACHE_DIR: Optional[str] = _get_cache_dir()
 
 
+def _get_privateuse_backend_name() -> Optional[str]:
+    import torch
+    if hasattr(torch._C, "_get_privateuse1_backend_name"):
+        return torch._C._get_privateuse1_backend_name()
+    return None
+
+
 def build_variant() -> str:
     import torch
 
@@ -49,9 +56,13 @@ def build_variant() -> str:
     elif torch.version.xpu is not None:
         version = torch.version.xpu
         compute_framework = f"xpu{version[0:4]}{version[5:6]}"
+    elif _get_privateuse_backend_name() == "npu":
+        from torch_npu.utils.collect_env import get_cann_version
+        cann_major, cann_minor = get_cann_version()[0], get_cann_version()[2]
+        compute_framework = f"cann{cann_major}{cann_minor}"
     else:
         raise AssertionError(
-            "Torch was not compiled with CUDA, Metal, XPU, or ROCm enabled."
+            "Torch was not compiled with CUDA, Metal, XPU, NPU, or ROCm enabled."
         )
 
     torch_version = parse(torch.__version__)
diff --git a/tests/conftest.py b/tests/conftest.py
index 49f2d5e..82fcea0 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -3,6 +3,8 @@
 import pytest
 import torch
 
+from kernels.utils import _get_privateuse_backend_name
+
 has_cuda = (
     hasattr(torch.version, "cuda")
     and torch.version.cuda is not None
@@ -18,6 +20,9 @@
     and torch.version.xpu is not None
     and torch.xpu.device_count() > 0
 )
+has_npu = (
+    _get_privateuse_backend_name() == "npu"
+)
 
 
 def pytest_addoption(parser):
@@ -37,5 +42,7 @@ def pytest_runtest_setup(item):
         pytest.skip("skipping macOS-only test on non-macOS platform")
     if "xpu_only" in item.keywords and not has_xpu:
         pytest.skip("skipping XPU-only test on host without XPU")
+    if "npu_only" in item.keywords and not has_npu:
+        pytest.skip("skipping NPU-only test on host without NPU")
     if "token" in item.keywords and not item.config.getoption("--token"):
         pytest.skip("need --token option to run this test")
diff --git a/tests/test_kernel_locking.py b/tests/test_kernel_locking.py
index e4691b2..7daaa88 100644
--- a/tests/test_kernel_locking.py
+++ b/tests/test_kernel_locking.py
@@ -35,6 +35,7 @@ def test_load_locked():
     load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock")
 
 
+@pytest.mark.cuda_only
 def test_layer_locked():
     project_dir = Path(__file__).parent / "layer_locking"
 
diff --git a/tests/test_layer.py b/tests/test_layer.py
index 7bfffca..6d0a8b8 100644
--- a/tests/test_layer.py
+++ b/tests/test_layer.py
@@ -21,14 +21,21 @@
     _KERNEL_MAPPING,
     _validate_layer,
 )
-from kernels.utils import install_kernel
+from kernels.utils import (
+    _get_privateuse_backend_name,
+    install_kernel,
+)
 
 kernel_layer_mapping = {
     "SiluAndMul": {
         Device(type="cuda"): LayerRepository(
             repo_id="kernels-community/activation",
             layer_name="SiluAndMul",
-        )
+        ),
+        "npu": LayerRepository(
+            repo_id="kernels-ext-npu/SwiGlu",
+            layer_name="SwiGlu",
+        ),
     },
     "SiluAndMulNoCompile": {
         "cuda": LayerRepository(
@@ -122,8 +129,10 @@ def device():
         return "cuda"
     elif hasattr(torch, "xpu") and torch.xpu.is_available():
         return "xpu"
+    elif _get_privateuse_backend_name() == "npu":
+        return "npu"
 
-    pytest.skip("No CUDA or XPU")
+    pytest.skip("No CUDA, NPU or XPU")
 
 
 def test_arg_kinds():
@@ -204,10 +213,33 @@ def test_hub_forward_xpu():
     assert rms_norm_with_kernel.n_calls == 0
 
 
+@pytest.mark.npu_only
+def test_hub_forward_npu():
+    torch.manual_seed(0)
+
+    silu_and_mul = SiluAndMul()
+    X = torch.randn((32, 64), device="npu")
+    Y = silu_and_mul(X)
+
+    silu_and_mul_with_kernel = kernelize(
+        SiluAndMulWithKernel(), device="npu", mode=Mode.INFERENCE
+    )
+    Y_kernel = silu_and_mul_with_kernel(X)
+
+    torch.testing.assert_close(Y_kernel, Y)
+
+    assert silu_and_mul.n_calls == 1
+    assert silu_and_mul_with_kernel.n_calls == 0
+
+
 @pytest.mark.skipif(
     hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
     reason="Skip on xpu devices",
 )
+@pytest.mark.skipif(
+    _get_privateuse_backend_name() == "npu",
+    reason="Skip on npu devices",
+)
 def test_rocm_kernel_mapping():
     """Test that ROCm shorthand device mapping works correctly."""
     kernel_layer_mapping = {
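
A minimal usage sketch (not part of the patch) of the NPU path added above. It assumes an Ascend host with torch_npu installed and reuses the kernels-ext-npu/SwiGlu repository from the test mapping; register_kernel_mapping and use_kernel_forward_from_hub are pre-existing kernels APIs, not introduced by this diff.

import torch
import torch.nn.functional as F
from torch import nn

# torch_npu registers the "npu" privateuse1 backend with PyTorch;
# newer releases may auto-load it, older ones need the explicit import.
import torch_npu  # noqa: F401

from kernels import (
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)


@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        d = x.shape[-1] // 2
        return F.silu(x[..., :d]) * x[..., d:]


# Map the layer to an Ascend kernel; after this change both
# Device(type="npu") and the "npu" string shorthand are accepted.
register_kernel_mapping(
    {
        "SiluAndMul": {
            Device(type="npu"): LayerRepository(
                repo_id="kernels-ext-npu/SwiGlu",
                layer_name="SwiGlu",
            ),
        }
    }
)

model = kernelize(SiluAndMul(), device="npu", mode=Mode.INFERENCE)
y = model(torch.randn((32, 64), device="npu"))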