diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index f449d90eb67a5..f2769846abbd9 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -714,6 +714,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern { } }; +//===----------------------------------------------------------------------===// +// GPU index id operations +//===----------------------------------------------------------------------===// +/* +// Launch Config ops +// dimidx - x, y, z - is fixed to i32 +// return type is set by XeVM type converter +// get_local_id +xevm::WorkitemIdXOp; +xevm::WorkitemIdYOp; +xevm::WorkitemIdZOp; +// get_local_size +xevm::WorkgroupDimXOp; +xevm::WorkgroupDimYOp; +xevm::WorkgroupDimZOp; +// get_group_id +xevm::WorkgroupIdXOp; +xevm::WorkgroupIdYOp; +xevm::WorkgroupIdZOp; +// get_num_groups +xevm::GridDimXOp; +xevm::GridDimYOp; +xevm::GridDimZOp; +// get_global_id : to be added if needed +*/ + +// Helpers to get the OpenCL function name and dimension argument for each op. +static std::pair getConfig(xevm::WorkitemIdXOp) { + return {"get_local_id", 0}; +} +static std::pair getConfig(xevm::WorkitemIdYOp) { + return {"get_local_id", 1}; +} +static std::pair getConfig(xevm::WorkitemIdZOp) { + return {"get_local_id", 2}; +} +static std::pair getConfig(xevm::WorkgroupDimXOp) { + return {"get_local_size", 0}; +} +static std::pair getConfig(xevm::WorkgroupDimYOp) { + return {"get_local_size", 1}; +} +static std::pair getConfig(xevm::WorkgroupDimZOp) { + return {"get_local_size", 2}; +} +static std::pair getConfig(xevm::WorkgroupIdXOp) { + return {"get_group_id", 0}; +} +static std::pair getConfig(xevm::WorkgroupIdYOp) { + return {"get_group_id", 1}; +} +static std::pair getConfig(xevm::WorkgroupIdZOp) { + return {"get_group_id", 2}; +} +static std::pair getConfig(xevm::GridDimXOp) { + return {"get_num_groups", 0}; +} +static std::pair getConfig(xevm::GridDimYOp) { + return {"get_num_groups", 1}; +} +static std::pair getConfig(xevm::GridDimZOp) { + return {"get_num_groups", 2}; +} +/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with +/// a constant argument for the dimension - x, y or z. +template +class LaunchConfigOpToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto [baseName, dim] = getConfig(op); + Type dimTy = rewriter.getI32Type(); + Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy, + static_cast(dim)); + std::string func = mangle(baseName, {dimTy}, {true}); + Type resTy = op.getType(); + auto call = + createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {}, + noUnwindWillReturnAttrs, op.getOperation()); + constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; + auto memAttr = rewriter.getAttr( + /*other=*/noModRef, + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + call.setMemoryEffectsAttr(memAttr); + rewriter.replaceOp(op, call); + return success(); + } +}; + +/* +// Subgroup ops +// get_sub_group_local_id +xevm::LaneIdOp; +// get_sub_group_id +xevm::SubgroupIdOp; +// get_sub_group_size +xevm::SubgroupSizeOp; +// get_num_sub_groups : to be added if needed +*/ + +// Helpers to get the OpenCL function name for each op. +static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; } +static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; } +static StringRef getConfig(xevm::SubgroupSizeOp) { + return "get_sub_group_size"; +} +template +class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::string func = mangle(getConfig(op).str(), {}); + Type resTy = op.getType(); + auto call = + createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {}, + noUnwindWillReturnAttrs, op.getOperation()); + constexpr auto noModRef = LLVM::ModRefInfo::NoModRef; + auto memAttr = rewriter.getAttr( + /*other=*/noModRef, + /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef); + call.setMemoryEffectsAttr(memAttr); + rewriter.replaceOp(op, call); + return success(); + } +}; + //===----------------------------------------------------------------------===// // Pass Definition //===----------------------------------------------------------------------===// @@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, LLVMLoadStoreToOCLPattern, LLVMLoadStoreToOCLPattern, BlockLoadStore1DToOCLPattern, - BlockLoadStore1DToOCLPattern>( + BlockLoadStore1DToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + LaunchConfigOpToOCLPattern, + SubgroupOpWorkitemOpToOCLPattern, + SubgroupOpWorkitemOpToOCLPattern, + SubgroupOpWorkitemOpToOCLPattern>( patterns.getContext()); } diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index b31a973ffd6a1..72e70ff519b77 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -35,7 +35,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32 // ----- // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt( llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> { - // CHECK: xevm.DecorationCacheControl = + // CHECK: xevm.DecorationCacheControl = // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y @@ -345,3 +345,148 @@ llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) { xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control}> : (!llvm.ptr<3>, i64) llvm.return } + +// ----- +// CHECK-LABEL: llvm.func @local_id.x() -> i32 { +llvm.func @local_id.x() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]]) + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32 + %1 = xevm.local_id.x : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @local_id.y() -> i32 { +llvm.func @local_id.y() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32 + %1 = xevm.local_id.y : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @local_id.z() -> i32 { +llvm.func @local_id.z() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32 + %1 = xevm.local_id.z : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @local_size.x() -> i32 { +llvm.func @local_size.x() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]]) + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32 + %1 = xevm.local_size.x : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @local_size.y() -> i32 { +llvm.func @local_size.y() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32 + %1 = xevm.local_size.y : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @local_size.z() -> i32 { +llvm.func @local_size.z() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32 + %1 = xevm.local_size.z : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @group_id.x() -> i32 { +llvm.func @group_id.x() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]]) + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32 + %1 = xevm.group_id.x : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @group_id.y() -> i32 { +llvm.func @group_id.y() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32 + %1 = xevm.group_id.y : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @group_id.z() -> i32 { +llvm.func @group_id.z() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32 + %1 = xevm.group_id.z : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @group_count.x() -> i32 { +llvm.func @group_count.x() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32 + // CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]]) + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32 + %1 = xevm.group_count.x : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @group_count.y() -> i32 { +llvm.func @group_count.y() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32 + %1 = xevm.group_count.y : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @group_count.z() -> i32 { +llvm.func @group_count.z() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32 + %1 = xevm.group_count.z : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return} +llvm.func @lane_id() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id() + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32 + %1 = xevm.lane_id : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return} +llvm.func @subgroup_size() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size() + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32 + %1 = xevm.subgroup_size : i32 + llvm.return %1 : i32 +} + +// ----- +// CHECK-LABEL: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return} +llvm.func @subgroup_id() -> i32 { + // CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id() + // CHECK-SAME: {function_type = !llvm.func, linkage = #llvm.linkage, + // CHECK-SAME: memory_effects = #llvm.memory_effects, + // CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32 + %1 = xevm.subgroup_id : i32 + llvm.return %1 : i32 +}