diff --git a/pytest.ini b/pytest.ini
index 376a549..d8fc63c 100644
--- a/pytest.ini
+++ b/pytest.ini
@@ -3,3 +3,4 @@ markers =
     cuda_only: marks tests that should only hosts with CUDA GPUs
     rocm_only: marks tests that should only run on hosts with ROCm GPUs
     darwin_only: marks tests that should only run on macOS
+    xpu_only: marks tests that should only run on hosts with Intel XPUs
diff --git a/src/kernels/layer.py b/src/kernels/layer.py
index 7952f9d..8eedb3a 100644
--- a/src/kernels/layer.py
+++ b/src/kernels/layer.py
@@ -87,7 +87,7 @@ class Device:

     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "rocm").
+            The device type (e.g., "cuda", "mps", "rocm", "xpu").
         properties ([`CUDAProperties`], *optional*):
             Device-specific properties. Currently only [`CUDAProperties`] is
             supported for CUDA devices.
@@ -106,6 +106,9 @@ class Device:

         # MPS device for Apple Silicon
         mps_device = Device(type="mps")
+
+        # XPU device (e.g., Intel(R) Data Center GPU Max 1550)
+        xpu_device = Device(type="xpu")
         ```
     """

@@ -125,6 +128,8 @@ def create_repo(self) -> _DeviceRepos:
             return _ROCMRepos()
         elif self.type == "mps":
             return _MPSRepos()
+        elif self.type == "xpu":
+            return _XPURepos()
         else:
             raise ValueError(f"Unknown device type: {self.type}")

@@ -447,6 +452,26 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
         ...


+class _XPURepos(_DeviceRepos):
+    _repos: Dict[Mode, LayerRepositoryProtocol]
+
+    def __init__(self):
+        super().__init__()
+        self._repos = {}
+
+    @property
+    def repos(
+        self,
+    ) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
+        return self._repos
+
+    def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
+        if device.type != "xpu":
+            raise ValueError(f"Device type must be 'xpu', got {device.type}")
+
+        self._repos = repos
+
+
 class _MPSRepos(_DeviceRepos):
     _repos: Dict[Mode, LayerRepositoryProtocol]

@@ -531,7 +556,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps"}
+    supported_devices = {"cuda", "rocm", "mps", "xpu"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
         )
@@ -789,7 +814,7 @@ def kernelize(
             `Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
+            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
             The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.
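Note (not part of the patch): a minimal sketch of how the new "xpu" device type is exercised end to end, assuming the entry points `Device`, `LayerRepository`, `Mode`, `kernelize`, `use_kernel_forward_from_hub`, and `use_kernel_mapping` are importable from the `kernels` package as they are in the tests below; the `kernels-community/liger_kernels` repo and `LigerRMSNorm` layer name are the ones referenced by the new test mapping.

```python
# Sketch only: mirrors the new XPU test mapping; not part of this diff.
import torch
import torch.nn as nn

from kernels import (  # assumed top-level exports, as used in tests/test_layer.py
    Device,
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.variance_epsilon) * self.weight


# Map the decorated layer to a Hub kernel for XPU devices only. Device(type="xpu")
# now resolves to the new _XPURepos container via Device.create_repo.
mapping = {
    "LigerRMSNorm": {
        Device(type="xpu"): LayerRepository(
            repo_id="kernels-community/liger_kernels",
            layer_name="LigerRMSNorm",
        )
    }
}

with use_kernel_mapping(mapping):
    rms_norm = RMSNorm(1024).to("xpu")
    # "xpu" now passes _validate_device_type and selects the XPU repos at kernelize time.
    rms_norm = kernelize(rms_norm, mode=Mode.INFERENCE, device="xpu")
    y = rms_norm(torch.randn(4, 16, 1024, device="xpu"))
```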
diff --git a/tests/conftest.py b/tests/conftest.py
index 04e705e..6d9d379 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,6 +13,11 @@
     and torch.version.hip is not None
     and torch.cuda.device_count() > 0
 )
+has_xpu = (
+    hasattr(torch.version, "xpu")
+    and torch.version.xpu is not None
+    and torch.xpu.device_count() > 0
+)


 def pytest_runtest_setup(item):
@@ -22,3 +27,5 @@ def pytest_runtest_setup(item):
         pytest.skip("skipping ROCm-only test on host without ROCm")
     if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
         pytest.skip("skipping macOS-only test on non-macOS platform")
+    if "xpu_only" in item.keywords and not has_xpu:
+        pytest.skip("skipping XPU-only test on host without XPU")
diff --git a/tests/test_layer.py b/tests/test_layer.py
index 87ea9c3..0d17ce1 100644
--- a/tests/test_layer.py
+++ b/tests/test_layer.py
@@ -46,11 +46,37 @@
             layer_name="SiluAndMul",
         )
     },
+    "LigerRMSNorm": {
+        "xpu": LayerRepository(
+            repo_id="kernels-community/liger_kernels",
+            layer_name="LigerRMSNorm",  # Triton
+        )
+    },
 }

 register_kernel_mapping(kernel_layer_mapping)


+class RMSNorm(nn.Module):
+    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
+        super().__init__()
+        # Used to check that we called hub kernel.
+        self.n_calls = 0
+        self.weight = nn.Parameter(weight)
+        self.variance_epsilon = eps
+
+    def forward(self, x: torch.Tensor):
+        self.n_calls += 1
+        var = x.pow(2).mean(-1, keepdim=True)
+        x_norm = x * torch.rsqrt(var + self.variance_epsilon)
+        return x_norm * self.weight
+
+
+@use_kernel_forward_from_hub("LigerRMSNorm")
+class RMSNormWithKernel(RMSNorm):
+    pass
+
+
 class SiluAndMul(nn.Module):
     def __init__(self):
         super().__init__()
@@ -90,6 +116,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
         return super().forward(input)


+@pytest.fixture
+def device():
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch, "xpu") and torch.xpu.is_available():
+        return "xpu"
+
+    pytest.skip("No CUDA or XPU")
+
+
 def test_arg_kinds():
     @use_kernel_forward_from_hub("ArgKind")
     class ArgKind(nn.Module):
@@ -147,6 +183,31 @@ def test_hub_forward_rocm():
     assert silu_and_mul_with_kernel.n_calls in [0, 1]


+@pytest.mark.xpu_only
+def test_hub_forward_xpu():
+    torch.manual_seed(0)
+
+    hidden_size = 1024
+    weight = torch.ones(hidden_size, device="xpu")
+    rms_norm = RMSNorm(weight).to("xpu")
+    X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
+    Y = rms_norm(X)
+
+    rms_norm_with_kernel = kernelize(
+        RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
+    )
+    Y_kernel = rms_norm_with_kernel(X)
+
+    torch.testing.assert_close(Y_kernel, Y)
+
+    assert rms_norm.n_calls == 1
+    assert rms_norm_with_kernel.n_calls == 0
+
+
+@pytest.mark.skipif(
+    hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
+    reason="Skip on xpu devices",
+)
 def test_rocm_kernel_mapping():
     """Test that ROCm shorthand device mapping works correctly."""
     kernel_layer_mapping = {
@@ -234,16 +295,16 @@ class SiluAndMulWithKernelFallback(SiluAndMul):
     kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


-def test_local_layer_repo():
+def test_local_layer_repo(device):
     # Fetch a kernel to the local cache.
     package_name, path = install_kernel("kernels-test/backward-marker-test", "main")

-    linear = TorchLinearWithCounter(32, 32).to("cuda")
+    linear = TorchLinearWithCounter(32, 32).to(device)

     with use_kernel_mapping(
         {
             "Linear": {
-                "cuda": LocalLayerRepository(
+                device: LocalLayerRepository(
                     # install_kernel will give the fully-resolved path.
                     repo_path=path.parent.parent,
                     package_name=package_name,
@@ -255,7 +316,7 @@ def test_local_layer_repo():
     ):
         kernelize(linear, mode=Mode.INFERENCE)

-    X = torch.randn(10, 32, device="cuda")
+    X = torch.randn(10, 32, device=device)
     linear(X)

     assert linear.n_calls == 0
@@ -323,6 +384,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
     }

     extra_mapping1 = {
@@ -340,6 +402,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }

@@ -358,6 +421,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (
@@ -371,6 +435,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (
@@ -393,6 +458,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
         "TestKernel",
     }
     assert (
@@ -404,6 +470,7 @@ def test_mapping_contexts():
         "SiluAndMul",
         "SiluAndMulStringDevice",
         "SiluAndMulNoCompile",
+        "LigerRMSNorm",
     }


@@ -923,7 +990,7 @@ def test_kernel_modes_cross_fallback():
     assert linear.n_calls == 2


-def test_layer_versions():
+def test_layer_versions(device):
     @use_kernel_forward_from_hub("Version")
     class Version(nn.Module):
         def forward(self) -> str:
@@ -934,20 +1001,20 @@ def forward(self) -> str:
     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                 )
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.2.0"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version="<1.0.0",
@@ -955,13 +1022,13 @@ def forward(self) -> str:
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.2.0"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version="<0.2.0",
@@ -969,13 +1036,13 @@ def forward(self) -> str:
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.1.1"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version=">0.1.0,<0.2.0",
@@ -983,13 +1050,13 @@ def forward(self) -> str:
             }
         }
     ):
-        version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
+        version = kernelize(version, device=device, mode=Mode.INFERENCE)
         assert version() == "0.1.1"

     with use_kernel_mapping(
         {
             "Version": {
-                Device(type="cuda"): LayerRepository(
+                Device(type=device): LayerRepository(
                     repo_id="kernels-test/versions",
                     layer_name="Version",
                     version=">0.2.0",
@@ -998,13 +1065,13 @@ def forward(self) -> str:
         }
     ):
         with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
-            kernelize(version, device="cuda", mode=Mode.INFERENCE)
+            kernelize(version, device=device, mode=Mode.INFERENCE)

     with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
         use_kernel_mapping(
             {
                 "Version": {
-                    Device(type="cuda"): LayerRepository(
+                    Device(type=device): LayerRepository(
                         repo_id="kernels-test/versions",
                         layer_name="Version",
                         revision="v0.1.0",