Skip to content

Conversation

silee2
Copy link
Contributor

@silee2 silee2 commented Oct 8, 2025

XeVM to LLVM pass: Add conversion patterns for XeVM id ops.

Target OpenCL functions described here:
https://registry.khronos.org/OpenCL/sdk/3.0/docs/man/html/get_group_id.html

@llvmbot llvmbot added the mlir label Oct 8, 2025
@llvmbot
Copy link
Member

llvmbot commented Oct 8, 2025

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

XeVM to LLVM pass: Add conversion patterns for XeVM id ops.

Target OpenCL functions described here:
https://registry.khronos.org/OpenCL/sdk/3.0/docs/man/html/get_group_id.html


Full diff: https://github.com/llvm/llvm-project/pull/162536.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+147-1)
  • (modified) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+146-1)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f449d90eb67a5..4214d09c3e5ef 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -714,6 +714,137 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// GPU index id operations
+//===----------------------------------------------------------------------===//
+/*
+// Launch Config ops
+//   dimidx - x, y, x - is fixed to i32
+//   return type is set by XeVM type converter
+// get_local_id
+xevm::WorkitemIdXOp;
+xevm::WorkitemIdYOp;
+xevm::WorkitemIdZOp;
+// get_local_size
+xevm::WorkgroupDimXOp;
+xevm::WorkgroupDimYOp;
+xevm::WorkgroupDimZOp;
+// get_group_id
+xevm::WorkgroupIdXOp;
+xevm::WorkgroupIdYOp;
+xevm::WorkgroupIdZOp;
+// get_num_groups
+xevm::GridDimXOp;
+xevm::GridDimYOp;
+xevm::GridDimZOp;
+// get_global_id : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name and dimension argument for each op.
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
+  return {"get_local_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
+  return {"get_local_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
+  return {"get_local_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
+  return {"get_local_size", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
+  return {"get_local_size", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
+  return {"get_local_size", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
+  return {"get_group_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
+  return {"get_group_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
+  return {"get_group_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
+  return {"get_num_groups", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
+  return {"get_num_groups", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
+  return {"get_num_groups", 2};
+}
+/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
+/// a constant argument for the dimension - x, y or z.
+template <typename OpType>
+class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Location loc = op->getLoc();
+    std::pair<StringRef, int64_t> config = getConfig(op);
+    std::string baseName = config.first.str();
+    Type dimTy = rewriter.getI32Type();
+    int64_t dim = config.second;
+    Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+                                            static_cast<int64_t>(dim));
+    std::string func = mangle(baseName, {dimTy}, {true});
+    Type resTy = op.getType();
+    auto call =
+        createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
+                                 noUnwindWillReturnAttrs, op.getOperation());
+    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/noModRef,
+        /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+    call.setMemoryEffectsAttr(memAttr);
+    rewriter.replaceOp(op, call);
+    return success();
+  }
+};
+
+/*
+// Subgroup ops
+// get_sub_group_local_id
+xevm::LaneIdOp;
+// get_sub_group_id
+xevm::SubgroupIdOp;
+// get_sub_group_size
+xevm::SubgroupSizeOp;
+// get_num_sub_groups : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name for each op.
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
+static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
+static StringRef getConfig(xevm::SubgroupSizeOp) {
+  return "get_sub_group_size";
+}
+template <typename OpType>
+class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    std::string func = mangle(getConfig(op).str(), {});
+    Type resTy = op.getType();
+    auto call =
+        createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
+                                 noUnwindWillReturnAttrs, op.getOperation());
+    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+        /*other=*/noModRef,
+        /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+    call.setMemoryEffectsAttr(memAttr);
+    rewriter.replaceOp(op, call);
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Pass Definition
 //===----------------------------------------------------------------------===//
@@ -775,7 +906,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
                LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
                LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
                BlockLoadStore1DToOCLPattern<BlockLoadOp>,
-               BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+               BlockLoadStore1DToOCLPattern<BlockStoreOp>,
+               LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
+               LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
+               LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
+               LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
+               LaunchConfigOpToOCLPattern<GridDimXOp>,
+               LaunchConfigOpToOCLPattern<GridDimYOp>,
+               LaunchConfigOpToOCLPattern<GridDimZOp>,
+               SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
+               SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
+               SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
       patterns.getContext());
 }
 
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index b31a973ffd6a1..72e70ff519b77 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -35,7 +35,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
 // -----
 // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
 llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
-  // CHECK: xevm.DecorationCacheControl = 
+  // CHECK: xevm.DecorationCacheControl =
   // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
   // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
@@ -345,3 +345,148 @@ llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
   xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
   llvm.return
 }
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.x() -> i32 {
+llvm.func @local_id.x() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
+  // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+  %1 = xevm.local_id.x : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.y() -> i32 {
+llvm.func @local_id.y() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  %1 = xevm.local_id.y : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.z() -> i32 {
+llvm.func @local_id.z() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %1 = xevm.local_id.z : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.x() -> i32 {
+llvm.func @local_size.x() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
+  // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32
+  %1 = xevm.local_size.x : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.y() -> i32 {
+llvm.func @local_size.y() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  %1 = xevm.local_size.y : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.z() -> i32 {
+llvm.func @local_size.z() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %1 = xevm.local_size.z : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.x() -> i32 {
+llvm.func @group_id.x() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
+  // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+  %1 = xevm.group_id.x : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.y() -> i32 {
+llvm.func @group_id.y() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  %1 = xevm.group_id.y : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.z() -> i32 {
+llvm.func @group_id.z() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %1 = xevm.group_id.z : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.x() -> i32 {
+llvm.func @group_count.x() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+  // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
+  // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+  %1 = xevm.group_count.x : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.y() -> i32 {
+llvm.func @group_count.y() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+  %1 = xevm.group_count.y : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.z() -> i32 {
+llvm.func @group_count.z() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+  %1 = xevm.group_count.z : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
+llvm.func @lane_id() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+  // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32
+  %1 = xevm.lane_id : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
+llvm.func @subgroup_size() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size()
+  // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32
+  %1 = xevm.subgroup_size : i32
+  llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
+llvm.func @subgroup_id() -> i32 {
+  // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+  // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+  // CHECK-SAME:  memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+  // CHECK-SAME:  no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32
+  %1 = xevm.subgroup_id : i32
+  llvm.return %1 : i32
+}

Copy link
Contributor

@akroviakov akroviakov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

//===----------------------------------------------------------------------===//
/*
// Launch Config ops
// dimidx - x, y, x - is fixed to i32
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you mean

Suggested change
// dimidx - x, y, x - is fixed to i32
// dimidx - x, y, z - is fixed to i32

?

matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
std::pair<StringRef, int64_t> config = getConfig(op);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about structured bindings? Smth like

Suggested change
std::pair<StringRef, int64_t> config = getConfig(op);
auto [baseName, dim] = getConfig(op);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for suggestion. Totally forgot that was possible with pair.

*/

// Helpers to get the OpenCL function name and dimension argument for each op.
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it has to be an extension to tablegen special register ops.
But AFAIU with Intel GPUs, without their own LLVM backend, xevm could similarly target not only OpenCL calls, but also SPIRV calls directly or potentially some other llvm-func-based interface to built-ins, so it is up to a specific conversion to set up the proper function name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Intel GPU, without it's own LLVM backend, this pass is used as an alternative to tablegen based translation with llvmBuilder.
Although, XeVM to LLVM currently exist as a conversion,
In reality it is set up as a XeVM to LLVMOCL transform.
It is a transform since the source XeVM is LLVM (extension) dialect and the target is LLVM as well.
For SPIRV calls, it will require a separate transform like XeVM to LLVMSPV.

Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants