Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/kernels/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ class Device:

Args:
type (`str`):
The device type (e.g., "cuda", "mps", "cpu").
The device type (e.g., "cuda", "mps", "rocm").
properties ([`CUDAProperties`], *optional*):
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.

Expand Down Expand Up @@ -531,7 +531,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):

def _validate_device_type(device_type: str) -> None:
"""Validate that the device type is supported."""
supported_devices = {"cuda", "rocm", "mps", "cpu"}
supported_devices = {"cuda", "rocm", "mps"}
if device_type not in supported_devices:
raise ValueError(
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
Expand Down Expand Up @@ -789,7 +789,7 @@ def kernelize(
The mode that the kernel is going to be used in. For example, `Mode.TRAINING | Mode.TORCH_COMPILE`
kernelizes the model for training with `torch.compile`.
device (`Union[str, torch.device]`, *optional*):
The device type to load kernels for. Supported device types are: "cuda", "rocm", "mps", "cpu".
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
The device type will be inferred from the model parameters when not provided.
use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the original forward method of modules when no compatible kernel could be found.
Expand Down
12 changes: 4 additions & 8 deletions tests/test_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,24 +110,20 @@ def forward(

@pytest.mark.cuda_only
@pytest.mark.parametrize("cls", [SiluAndMulWithKernel, SiluAndMulStringDevice])
@pytest.mark.parametrize("device", ["cuda", "cpu"])
def test_hub_forward(cls, device):
def test_hub_forward(cls):
torch.random.manual_seed(0)

silu_and_mul = SiluAndMul()
X = torch.randn((32, 64), device=device)
X = torch.randn((32, 64), device="cuda")
Y = silu_and_mul(X)

silu_and_mul_with_kernel = kernelize(cls(), device=device, mode=Mode.INFERENCE)
silu_and_mul_with_kernel = kernelize(cls(), device="cuda", mode=Mode.INFERENCE)
Y_kernel = silu_and_mul_with_kernel(X)

torch.testing.assert_close(Y_kernel, Y)

assert silu_and_mul.n_calls == 1
if device == "cuda":
assert silu_and_mul_with_kernel.n_calls == 0
else:
assert silu_and_mul_with_kernel.n_calls == 1
assert silu_and_mul_with_kernel.n_calls == 0


@pytest.mark.rocm_only
Expand Down
Loading