diff --git a/src/kernels/layer.py b/src/kernels/layer.py
index 47d918c..4e67561 100644
--- a/src/kernels/layer.py
+++ b/src/kernels/layer.py
@@ -87,7 +87,7 @@ class Device:

     Args:
         type (`str`):
-            The device type (e.g., "cuda", "mps", "cpu").
+            The device type (e.g., "cuda", "mps", "rocm").
         properties ([`CUDAProperties`], *optional*):
             Device-specific properties. Currently only [`CUDAProperties`] is
             supported for CUDA devices.
@@ -531,7 +531,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):

 def _validate_device_type(device_type: str) -> None:
     """Validate that the device type is supported."""
-    supported_devices = {"cuda", "rocm", "mps", "cpu"}
+    supported_devices = {"cuda", "rocm", "mps"}
     if device_type not in supported_devices:
         raise ValueError(
             f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -789,7 +789,7 @@ def kernelize(
             The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
             kernelizes the model for training with `torch.compile`.
         device (`Union[str, torch.device]`, *optional*):
-            The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
+            The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
             The device type will be inferred from the model parameters when not provided.
         use_fallback (`bool`, *optional*, defaults to `True`):
             Whether to use the original forward method of modules when no compatible kernel could be found.
diff --git a/tests/test_layer.py b/tests/test_layer.py
index 8b920a5..a6b224c 100644
--- a/tests/test_layer.py
+++ b/tests/test_layer.py
@@ -110,24 +110,20 @@ def forward(


 @pytest.mark.cuda_only
 @pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
-@pytest.mark.parametrize("device", ["cuda", "cpu"])
-def test_hub_forward(cls, device):
+def test_hub_forward(cls):
     torch.random.manual_seed(0)

     silu_and_mul = SiluAndMul()
-    X = torch.randn((32, 64), device=device)
+    X = torch.randn((32, 64), device="cuda")
     Y = silu_and_mul(X)

-    silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
+    silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
     Y_kernel = silu_and_mul_with_kernel(X)
     torch.testing.assert_close(Y_kernel, Y)

     assert silu_and_mul.n_calls == 1
-    if device == "cuda":
-        assert silu_and_mul_with_kernel.n_calls == 0
-    else:
-        assert silu_and_mul_with_kernel.n_calls == 1
+    assert silu_and_mul_with_kernel.n_calls == 0


 @pytest.mark.rocm_only
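
A minimal sketch (not part of the diff) of the user-visible effect of this change, assuming `kernelize` and `Mode` are importable from the top-level `kernels` package and that `kernelize` routes its `device` argument through `_validate_device_type`: "cuda", "rocm", and "mps" remain valid, while passing `device="cpu"` now raises a `ValueError` instead of being accepted.

```python
# Hypothetical usage sketch, not part of the diff. Assumes `kernels` exports
# `kernelize` and `Mode`, and that `kernelize(..., device=...)` validates the
# device string via `_validate_device_type`.
import torch
from kernels import Mode, kernelize

model = torch.nn.Sequential(torch.nn.Linear(64, 64)).to("cuda")

# Still accepted after this change: "cuda", "rocm", "mps".
kernelized = kernelize(model, device="cuda", mode=Mode.INFERENCE)

# "cpu" was dropped from the supported set, so this should now raise:
#   ValueError: Unsupported device type 'cpu'. Supported device types are: cuda, mps, rocm
try:
    kernelize(model, device="cpu", mode=Mode.INFERENCE)
except ValueError as exc:
    print(exc)
```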