-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[MLIR][XeVM] XeVM to LLVM: Add conversion patterns for id ops #162536
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
@llvm/pr-subscribers-mlir Author: Sang Ik Lee (silee2) ChangesXeVM to LLVM pass: Add conversion patterns for XeVM id ops. Target OpenCL functions described here: Full diff: https://github.com/llvm/llvm-project/pull/162536.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index f449d90eb67a5..4214d09c3e5ef 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -714,6 +714,137 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
}
};
+//===----------------------------------------------------------------------===//
+// GPU index id operations
+//===----------------------------------------------------------------------===//
+/*
+// Launch Config ops
+// dimidx - x, y, x - is fixed to i32
+// return type is set by XeVM type converter
+// get_local_id
+xevm::WorkitemIdXOp;
+xevm::WorkitemIdYOp;
+xevm::WorkitemIdZOp;
+// get_local_size
+xevm::WorkgroupDimXOp;
+xevm::WorkgroupDimYOp;
+xevm::WorkgroupDimZOp;
+// get_group_id
+xevm::WorkgroupIdXOp;
+xevm::WorkgroupIdYOp;
+xevm::WorkgroupIdZOp;
+// get_num_groups
+xevm::GridDimXOp;
+xevm::GridDimYOp;
+xevm::GridDimZOp;
+// get_global_id : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name and dimension argument for each op.
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
+ return {"get_local_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
+ return {"get_local_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
+ return {"get_local_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
+ return {"get_local_size", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
+ return {"get_local_size", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
+ return {"get_local_size", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
+ return {"get_group_id", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
+ return {"get_group_id", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
+ return {"get_group_id", 2};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
+ return {"get_num_groups", 0};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
+ return {"get_num_groups", 1};
+}
+static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
+ return {"get_num_groups", 2};
+}
+/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
+/// a constant argument for the dimension - x, y or z.
+template <typename OpType>
+class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ Location loc = op->getLoc();
+ std::pair<StringRef, int64_t> config = getConfig(op);
+ std::string baseName = config.first.str();
+ Type dimTy = rewriter.getI32Type();
+ int64_t dim = config.second;
+ Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
+ static_cast<int64_t>(dim));
+ std::string func = mangle(baseName, {dimTy}, {true});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
+/*
+// Subgroup ops
+// get_sub_group_local_id
+xevm::LaneIdOp;
+// get_sub_group_id
+xevm::SubgroupIdOp;
+// get_sub_group_size
+xevm::SubgroupSizeOp;
+// get_num_sub_groups : to be added if needed
+*/
+
+// Helpers to get the OpenCL function name for each op.
+static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
+static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
+static StringRef getConfig(xevm::SubgroupSizeOp) {
+ return "get_sub_group_size";
+}
+template <typename OpType>
+class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
+ using OpConversionPattern<OpType>::OpConversionPattern;
+ LogicalResult
+ matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ std::string func = mangle(getConfig(op).str(), {});
+ Type resTy = op.getType();
+ auto call =
+ createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
+ noUnwindWillReturnAttrs, op.getOperation());
+ constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
+ auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
+ /*other=*/noModRef,
+ /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
+ call.setMemoryEffectsAttr(memAttr);
+ rewriter.replaceOp(op, call);
+ return success();
+ }
+};
+
//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
@@ -775,7 +906,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
- BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
+ BlockLoadStore1DToOCLPattern<BlockStoreOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
+ LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
+ LaunchConfigOpToOCLPattern<GridDimXOp>,
+ LaunchConfigOpToOCLPattern<GridDimYOp>,
+ LaunchConfigOpToOCLPattern<GridDimZOp>,
+ SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
+ SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
patterns.getContext());
}
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index b31a973ffd6a1..72e70ff519b77 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -35,7 +35,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
- // CHECK: xevm.DecorationCacheControl =
+ // CHECK: xevm.DecorationCacheControl =
// CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
// CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
@@ -345,3 +345,148 @@ llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
llvm.return
}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.x() -> i32 {
+llvm.func @local_id.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.local_id.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.y() -> i32 {
+llvm.func @local_id.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.local_id.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_id.z() -> i32 {
+llvm.func @local_id.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.local_id.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.x() -> i32 {
+llvm.func @local_size.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.local_size.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.y() -> i32 {
+llvm.func @local_size.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.local_size.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @local_size.z() -> i32 {
+llvm.func @local_size.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.local_size.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.x() -> i32 {
+llvm.func @group_id.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.group_id.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.y() -> i32 {
+llvm.func @group_id.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.group_id.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_id.z() -> i32 {
+llvm.func @group_id.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.group_id.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.x() -> i32 {
+llvm.func @group_count.x() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
+ // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
+ // CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32
+ %1 = xevm.group_count.x : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.y() -> i32 {
+llvm.func @group_count.y() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
+ %1 = xevm.group_count.y : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @group_count.z() -> i32 {
+llvm.func @group_count.z() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
+ %1 = xevm.group_count.z : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
+llvm.func @lane_id() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
+ // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32
+ %1 = xevm.lane_id : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
+llvm.func @subgroup_size() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size()
+ // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32
+ %1 = xevm.subgroup_size : i32
+ llvm.return %1 : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
+llvm.func @subgroup_id() -> i32 {
+ // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
+ // CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
+ // CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
+ // CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32
+ %1 = xevm.subgroup_id : i32
+ llvm.return %1 : i32
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
//===----------------------------------------------------------------------===// | ||
/* | ||
// Launch Config ops | ||
// dimidx - x, y, x - is fixed to i32 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you mean
// dimidx - x, y, x - is fixed to i32 | |
// dimidx - x, y, z - is fixed to i32 |
?
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, | ||
ConversionPatternRewriter &rewriter) const override { | ||
Location loc = op->getLoc(); | ||
std::pair<StringRef, int64_t> config = getConfig(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What about structured bindings? Smth like
std::pair<StringRef, int64_t> config = getConfig(op); | |
auto [baseName, dim] = getConfig(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for suggestion. Totally forgot that was possible with pair.
*/ | ||
|
||
// Helpers to get the OpenCL function name and dimension argument for each op. | ||
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This looks like it has to be an extension to tablegen special register ops.
But AFAIU with Intel GPUs, without their own LLVM backend, xevm could similarly target not only OpenCL calls, but also SPIRV calls directly or potentially some other llvm-func-based interface to built-ins, so it is up to a specific conversion to set up the proper function name.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agree. Intel GPU, without it's own LLVM backend, this pass is used as an alternative to tablegen based translation with llvmBuilder
.
Although, XeVM to LLVM currently exist as a conversion,
In reality it is set up as a XeVM to LLVMOCL transform.
It is a transform since the source XeVM is LLVM (extension) dialect and the target is LLVM as well.
For SPIRV calls, it will require a separate transform like XeVM to LLVMSPV.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
XeVM to LLVM pass: Add conversion patterns for XeVM id ops.
Target OpenCL functions described here:
https://registry.khronos.org/OpenCL/sdk/3.0/docs/man/html/get_group_id.html