1 change: 1 addition & 0 deletions pytest.ini
@@ -3,3 +3,4 @@ markers =
cuda_only: marks tests that should only run on hosts with CUDA GPUs
rocm_only: marks tests that should only run on hosts with ROCm GPUs
darwin_only: marks tests that should only run on macOS
xpu_only: marks tests that should only run on hosts with Intel XPUs
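For reference, a minimal sketch of how the new marker would be applied in a test, assuming the `pytest_runtest_setup` hook added to `tests/conftest.py` below; the test name and body are hypothetical:

```python
import pytest
import torch


@pytest.mark.xpu_only  # skipped by pytest_runtest_setup on hosts without an Intel XPU
def test_matmul_on_xpu():
    x = torch.randn(8, 8, device="xpu")
    assert (x @ x).shape == (8, 8)
```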
31 changes: 28 additions & 3 deletions src/kernels/layer.py
@@ -87,7 +87,7 @@ class Device:

Args:
type (`str`):
The device type (e.g., "cuda", "mps", "rocm").
The device type (e.g., "cuda", "mps", "rocm", "xpu").
properties ([`CUDAProperties`], *optional*):
Device-specific properties. Currently only [`CUDAProperties`] is supported for CUDA devices.

@@ -106,6 +106,9 @@ class Device:

# MPS device for Apple Silicon
mps_device = Device(type="mps")

# XPU device (e.g., Intel(R) Data Center GPU Max 1550)
xpu_device = Device(type="xpu")
```
"""

@@ -125,6 +128,8 @@ def create_repo(self) -> _DeviceRepos:
return _ROCMRepos()
elif self.type == "mps":
return _MPSRepos()
elif self.type == "xpu":
return _XPURepos()
else:
raise ValueError(f"Unknown device type: {self.type}")

@@ -447,6 +452,26 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
...


class _XPURepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]

def __init__(self):
super().__init__()
self._repos = {}

@property
def repos(
self,
) -> Optional[Dict[Mode, LayerRepositoryProtocol]]:
return self._repos

def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):
if device.type != "xpu":
raise ValueError(f"Device type must be 'xpu', got {device.type}")

self._repos = repos


class _MPSRepos(_DeviceRepos):
_repos: Dict[Mode, LayerRepositoryProtocol]

@@ -531,7 +556,7 @@ def insert(self, device: Device, repos: Dict[Mode, LayerRepositoryProtocol]):

def _validate_device_type(device_type: str) -> None:
"""Validate that the device type is supported."""
supported_devices = {"cuda", "rocm", "mps"}
supported_devices = {"cuda", "rocm", "mps", "xpu"}
if device_type not in supported_devices:
raise ValueError(
f"Unsupported device type '{device_type}'. Supported device types are: {', '.join(sorted(supported_devices))}"
@@ -789,7 +814,7 @@ def kernelize(
`Mode.TRAINING | Mode.TORCH_COMPILE` kernelizes the model for training with
`torch.compile`.
device (`Union[str, torch.device]`, *optional*):
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm".
The device type to load kernels for. Supported device types are: "cuda", "mps", "rocm", "xpu".
The device type will be inferred from the model parameters when not provided.
use_fallback (`bool`, *optional*, defaults to `True`):
Whether to use the original forward method of modules when no compatible kernel could be found.
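To illustrate how the new "xpu" device type is consumed from the public API, here is a minimal sketch assuming the `kernels` package exports the names used in the tests below and that `kernels-community/liger_kernels` provides a `LigerRMSNorm` layer:

```python
import torch
from torch import nn

from kernels import (
    Device,
    LayerRepository,
    Mode,
    kernelize,
    register_kernel_mapping,
    use_kernel_forward_from_hub,
)

# Map the "LigerRMSNorm" layer name to a Hub kernel for XPU devices.
register_kernel_mapping(
    {
        "LigerRMSNorm": {
            Device(type="xpu"): LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        }
    }
)


@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.variance_epsilon) * self.weight


# On an XPU host, kernelize() swaps the forward above for the Hub kernel;
# with use_fallback=True (the default) other hosts keep the original forward.
rms_norm = kernelize(RMSNorm(1024).to("xpu"), mode=Mode.INFERENCE, device="xpu")
y = rms_norm(torch.randn(4, 16, 1024, device="xpu", dtype=torch.float32))
```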
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -13,6 +13,11 @@
and torch.version.hip is not None
and torch.cuda.device_count() > 0
)
has_xpu = (
hasattr(torch.version, "xpu")
and torch.version.xpu is not None
and torch.xpu.device_count() > 0
)


def pytest_runtest_setup(item):
@@ -22,3 +27,5 @@ def pytest_runtest_setup(item):
pytest.skip("skipping ROCm-only test on host without ROCm")
if "darwin_only" in item.keywords and not sys.platform.startswith("darwin"):
pytest.skip("skipping macOS-only test on non-macOS platform")
if "xpu_only" in item.keywords and not has_xpu:
pytest.skip("skipping XPU-only test on host without XPU")
99 changes: 83 additions & 16 deletions tests/test_layer.py
@@ -46,11 +46,37 @@
layer_name="SiluAndMul",
)
},
"LigerRMSNorm": {
"xpu": LayerRepository(
repo_id="kernels-community/liger_kernels",
layer_name="LigerRMSNorm", # Triton
)
},
}

register_kernel_mapping(kernel_layer_mapping)


class RMSNorm(nn.Module):
def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
super().__init__()
# Used to check that we called hub kernel.
self.n_calls = 0
self.weight = nn.Parameter(weight)
self.variance_epsilon = eps

def forward(self, x: torch.Tensor):
self.n_calls += 1
var = x.pow(2).mean(-1, keepdim=True)
x_norm = x * torch.rsqrt(var + self.variance_epsilon)
return x_norm * self.weight


@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNormWithKernel(RMSNorm):
pass


class SiluAndMul(nn.Module):
def __init__(self):
super().__init__()
@@ -90,6 +116,16 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
return super().forward(input)


@pytest.fixture
def device():
if torch.cuda.is_available():
return "cuda"
elif hasattr(torch, "xpu") and torch.xpu.is_available():
return "xpu"

pytest.skip("No CUDA or XPU")


def test_arg_kinds():
@use_kernel_forward_from_hub("ArgKind")
class ArgKind(nn.Module):
@@ -147,6 +183,31 @@ def test_hub_forward_rocm():
assert silu_and_mul_with_kernel.n_calls in [0, 1]


@pytest.mark.xpu_only
def test_hub_forward_xpu():
torch.manual_seed(0)

hidden_size = 1024
weight = torch.ones(hidden_size, device="xpu")
rms_norm = RMSNorm(weight).to("xpu")
X = torch.randn(4, 16, hidden_size, device="xpu", dtype=torch.float32)
Y = rms_norm(X)

rms_norm_with_kernel = kernelize(
RMSNormWithKernel(weight), mode=Mode.INFERENCE, device="xpu"
)
Y_kernel = rms_norm_with_kernel(X)

torch.testing.assert_close(Y_kernel, Y)

assert rms_norm.n_calls == 1
assert rms_norm_with_kernel.n_calls == 0


@pytest.mark.skipif(
hasattr(torch, "xpu") and getattr(torch.xpu, "is_available", lambda: False)(),
reason="Skip on xpu devices",
)
def test_rocm_kernel_mapping():
"""Test that ROCm shorthand device mapping works correctly."""
kernel_layer_mapping = {
@@ -234,16 +295,16 @@ class SiluAndMulWithKernelFallback(SiluAndMul):
kernelize(silu_and_mul, device="cuda", mode=Mode.INFERENCE)


def test_local_layer_repo():
def test_local_layer_repo(device):
# Fetch a kernel to the local cache.
package_name, path = install_kernel("kernels-test/backward-marker-test", "main")

linear = TorchLinearWithCounter(32, 32).to("cuda")
linear = TorchLinearWithCounter(32, 32).to(device)

with use_kernel_mapping(
{
"Linear": {
"cuda": LocalLayerRepository(
device: LocalLayerRepository(
# install_kernel will give the fully-resolved path.
repo_path=path.parent.parent,
package_name=package_name,
Expand All @@ -255,7 +316,7 @@ def test_local_layer_repo():
):
kernelize(linear, mode=Mode.INFERENCE)

X = torch.randn(10, 32, device="cuda")
X = torch.randn(10, 32, device=device)
linear(X)
assert linear.n_calls == 0

@@ -323,6 +384,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
}

extra_mapping1 = {
@@ -340,6 +402,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}

@@ -358,6 +421,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
assert (
@@ -371,6 +435,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
assert (
@@ -393,6 +458,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
"TestKernel",
}
assert (
@@ -404,6 +470,7 @@ def test_mapping_contexts():
"SiluAndMul",
"SiluAndMulStringDevice",
"SiluAndMulNoCompile",
"LigerRMSNorm",
}


@@ -923,7 +990,7 @@ def test_kernel_modes_cross_fallback():
assert linear.n_calls == 2


def test_layer_versions():
def test_layer_versions(device):
@use_kernel_forward_from_hub("Version")
class Version(nn.Module):
def forward(self) -> str:
@@ -934,62 +1001,62 @@ def forward(self) -> str:
with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.2.0"

with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<1.0.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.2.0"

with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version="<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.1.1"

with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.1.0,<0.2.0",
)
}
}
):
version = kernelize(version, device="cuda", mode=Mode.INFERENCE)
version = kernelize(version, device=device, mode=Mode.INFERENCE)
assert version() == "0.1.1"

with use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
version=">0.2.0",
@@ -998,13 +1065,13 @@ def forward(self) -> str:
}
):
with pytest.raises(ValueError, match=r"No version.*satisfies requirement"):
kernelize(version, device="cuda", mode=Mode.INFERENCE)
kernelize(version, device=device, mode=Mode.INFERENCE)

with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"):
use_kernel_mapping(
{
"Version": {
Device(type="cuda"): LayerRepository(
Device(type=device): LayerRepository(
repo_id="kernels-test/versions",
layer_name="Version",
revision="v0.1.0",