Conversation

@YangKai0616 (Contributor)

I wrote a simple script to test:

#!/usr/bin/env python

import torch
import torch.nn as nn

from kernels import (
    Mode,
    LayerRepository,
    use_kernel_mapping,
    kernelize,
    use_kernel_forward_from_hub,
)

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    var = x.pow(2).mean(-1, keepdim=True)
    x_norm = x * torch.rsqrt(var + eps)
    return x_norm * weight


@use_kernel_forward_from_hub("LigerRMSNorm")
class MyRMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor):
        return rms_norm_ref(x, self.weight, self.variance_epsilon)


def main():
    if not hasattr(torch, "xpu"):
        print("[SKIP] Current PyTorch does not include torch.xpu API.")
        return
    if not torch.xpu.is_available():
        print("[SKIP] No XPU device detected.")
        return

    device = torch.device("xpu")
    hidden = 1024
    batch = 4
    length = 16

    # Model and input
    torch.manual_seed(0)
    model = MyRMSNorm(hidden).to(device)
    x = torch.randn(batch, length, hidden, device=device, dtype=torch.float32)

    # Kernel mapping
    mapping = {
        "LigerRMSNorm": {
            "xpu": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        }
    }

    print("Registering kernel mapping and performing kernelize...")
    with use_kernel_mapping(mapping):
        km = kernelize(model, mode=Mode.INFERENCE, device="xpu")

        # Forward
        y_kernel = km(x)

    # Reference implementation (reuse the same weight)
    ref = rms_norm_ref(x, model.weight)
    max_abs_diff = (y_kernel - ref).abs().max().item()
    print(f"Max absolute difference: {max_abs_diff:.3e}")

    # Tolerance based on experience (kernel may fuse ops / change precision)
    tol = 5e-4
    if max_abs_diff < tol:
        print("[PASS] Kernel output matches the reference implementation.")
    else:
        print("[WARN] Difference exceeds threshold; kernel may not be loaded correctly or precision differs.")

    # Print forward origin (debug only)
    print(f"Current forward callable: {km.forward} (type={type(km.forward)})")


if __name__ == "__main__":
    main()

The test results:

Registering kernel mapping and performing kernelize...
Fetching 17 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 17/17 [00:00<00:00, 188135.01it/s]
Max absolute difference: 4.768e-07
[PASS] Kernel output matches the reference implementation.
Current forward callable: <bound method LigerRMSNorm.forward of MyRMSNorm()> (type=<class 'method'>)

YangKai0616 changed the title from "Add support for XPU for kernel calls" to "[XPU] Add support for XPU to kernel calls" on Sep 10, 2025
@YangKai0616 (Contributor, Author)

Hi @danieldk, should I add a test case to test_layer.py like the ROCm test does?

@danieldk (Member)

> Hi @danieldk, should I add a test case to test_layer.py like the ROCm test does?

Hey. Yeah, having a test like that for an XPU repo would be great!
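
For reference, here is a minimal sketch of what such a device-gated test could look like, assuming a plain pytest style rather than the actual test_layer.py conventions; the test name, tensor shapes, and tolerances below are illustrative only:

import pytest
import torch
import torch.nn as nn

from kernels import (
    LayerRepository,
    Mode,
    kernelize,
    use_kernel_forward_from_hub,
    use_kernel_mapping,
)


@use_kernel_forward_from_hub("LigerRMSNorm")
class RMSNorm(nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor):
        var = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(var + self.variance_epsilon) * self.weight


@pytest.mark.skipif(
    not (hasattr(torch, "xpu") and torch.xpu.is_available()),
    reason="requires an Intel XPU device",
)
def test_liger_rms_norm_xpu():
    mapping = {
        "LigerRMSNorm": {
            "xpu": LayerRepository(
                repo_id="kernels-community/liger_kernels",
                layer_name="LigerRMSNorm",
            )
        }
    }
    model = RMSNorm(1024).to("xpu")
    x = torch.randn(4, 16, 1024, device="xpu")
    # Compute the eager reference before kernelize swaps the forward.
    reference = model(x)

    with use_kernel_mapping(mapping):
        kernelized = kernelize(model, mode=Mode.INFERENCE, device="xpu")
        y = kernelized(x)

    torch.testing.assert_close(y, reference, atol=5e-4, rtol=1e-4)

The reference output is computed before kernelize so the eager forward is compared against the hub kernel's forward, and the skipif guard keeps the test from running on hosts without an XPU device.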

@YangKai0616 (Contributor, Author) commented on Sep 11, 2025

> Hey. Yeah, having a test like that for an XPU repo would be great!

Done. Please review, thanks! (I have tested test_layer.py on both Intel PVC 1550 and NVIDIA A100.)

@danieldk (Member)

Merged in #142. Thanks a lot for the PR!
