Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
146 changes: 145 additions & 1 deletion mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -714,6 +714,135 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
}
};

//===----------------------------------------------------------------------===//
// GPU index id operations
//===----------------------------------------------------------------------===//
/*
// Launch Config ops
// dimidx - x, y, z - is fixed to i32
// return type is set by XeVM type converter
// get_local_id
xevm::WorkitemIdXOp;
xevm::WorkitemIdYOp;
xevm::WorkitemIdZOp;
// get_local_size
xevm::WorkgroupDimXOp;
xevm::WorkgroupDimYOp;
xevm::WorkgroupDimZOp;
// get_group_id
xevm::WorkgroupIdXOp;
xevm::WorkgroupIdYOp;
xevm::WorkgroupIdZOp;
// get_num_groups
xevm::GridDimXOp;
xevm::GridDimYOp;
xevm::GridDimZOp;
// get_global_id : to be added if needed
*/

// Helpers to get the OpenCL function name and dimension argument for each op.
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks like it has to be an extension to tablegen special register ops.
But AFAIU with Intel GPUs, without their own LLVM backend, xevm could similarly target not only OpenCL calls, but also SPIRV calls directly or potentially some other llvm-func-based interface to built-ins, so it is up to a specific conversion to set up the proper function name.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree. Since Intel GPUs do not have their own LLVM backend, this pass is used as an alternative to the tablegen-based translation with llvmBuilder.
Although XeVM to LLVM currently exists as a conversion,
in reality it is set up as a XeVM to LLVMOCL transform.
It is a transform because the source, XeVM, is an LLVM (extension) dialect and the target is LLVM as well.
SPIRV calls will require a separate transform, such as XeVM to LLVMSPV.

return {"get_local_id", 0};
}
/// `xevm::WorkitemIdYOp` maps to the OpenCL builtin `get_local_id` with
/// dimension index 1 (y).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
  return std::make_pair(StringRef("get_local_id"), int64_t{1});
}
/// `xevm::WorkitemIdZOp` maps to the OpenCL builtin `get_local_id` with
/// dimension index 2 (z).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
  return std::make_pair(StringRef("get_local_id"), int64_t{2});
}
/// `xevm::WorkgroupDimXOp` maps to the OpenCL builtin `get_local_size` with
/// dimension index 0 (x).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
  return std::make_pair(StringRef("get_local_size"), int64_t{0});
}
/// `xevm::WorkgroupDimYOp` maps to the OpenCL builtin `get_local_size` with
/// dimension index 1 (y).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
  return std::make_pair(StringRef("get_local_size"), int64_t{1});
}
/// `xevm::WorkgroupDimZOp` maps to the OpenCL builtin `get_local_size` with
/// dimension index 2 (z).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
  return std::make_pair(StringRef("get_local_size"), int64_t{2});
}
/// `xevm::WorkgroupIdXOp` maps to the OpenCL builtin `get_group_id` with
/// dimension index 0 (x).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
  return std::make_pair(StringRef("get_group_id"), int64_t{0});
}
/// `xevm::WorkgroupIdYOp` maps to the OpenCL builtin `get_group_id` with
/// dimension index 1 (y).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
  return std::make_pair(StringRef("get_group_id"), int64_t{1});
}
/// `xevm::WorkgroupIdZOp` maps to the OpenCL builtin `get_group_id` with
/// dimension index 2 (z).
static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
  return std::make_pair(StringRef("get_group_id"), int64_t{2});
}
/// `xevm::GridDimXOp` maps to the OpenCL builtin `get_num_groups` with
/// dimension index 0 (x).
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
  return std::make_pair(StringRef("get_num_groups"), int64_t{0});
}
/// `xevm::GridDimYOp` maps to the OpenCL builtin `get_num_groups` with
/// dimension index 1 (y).
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
  return std::make_pair(StringRef("get_num_groups"), int64_t{1});
}
/// `xevm::GridDimZOp` maps to the OpenCL builtin `get_num_groups` with
/// dimension index 2 (z).
static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
  return std::make_pair(StringRef("get_num_groups"), int64_t{2});
}
/// Lowers a launch-configuration op (`xevm.*` work-item id, work-group size,
/// work-group id or grid size) to an `llvm.call` of the matching OpenCL
/// builtin. The dimension (x, y or z) is passed as an i32 constant operand,
/// and the builtin name / dimension index are supplied by the `getConfig`
/// overload selected for `OpType`.
template <typename OpType>
class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor /*adaptor*/,
                  ConversionPatternRewriter &rewriter) const override {
    // Builtin base name (first) and dimension index (second) for this op.
    const std::pair<StringRef, int64_t> config = getConfig(op);
    const int64_t dimIndex = config.second;

    // Materialize the dimension index as an i32 constant operand.
    Type i32Ty = rewriter.getI32Type();
    Value dimArg =
        LLVM::ConstantOp::create(rewriter, op->getLoc(), i32Ty, dimIndex);

    // Mangle the builtin name for a single i32 parameter; the `{true}` flag
    // list is forwarded to `mangle` unchanged from the per-argument setup.
    std::string mangledName = mangle(config.first, {i32Ty}, {true});

    // Emit the call; the result type is whatever type the op itself carries.
    auto callOp = createDeviceFunctionCall(
        rewriter, mangledName, op.getType(), {i32Ty}, {dimArg}, {},
        noUnwindWillReturnAttrs, op.getOperation());

    // Mark the call as having no memory effects of any kind.
    callOp.setMemoryEffectsAttr(rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef));

    rewriter.replaceOp(op, callOp);
    return success();
  }
};

/*
// Subgroup ops
// get_sub_group_local_id
xevm::LaneIdOp;
// get_sub_group_id
xevm::SubgroupIdOp;
// get_sub_group_size
xevm::SubgroupSizeOp;
// get_num_sub_groups : to be added if needed
*/

// Per-op lookup of the OpenCL builtin name used to lower each subgroup op.
/// `xevm::LaneIdOp` lowers to `get_sub_group_local_id`.
static StringRef getConfig(xevm::LaneIdOp) {
  return StringRef("get_sub_group_local_id");
}
/// `xevm::SubgroupIdOp` lowers to `get_sub_group_id`.
static StringRef getConfig(xevm::SubgroupIdOp) {
  return StringRef("get_sub_group_id");
}
/// `xevm::SubgroupSizeOp` lowers to `get_sub_group_size`.
static StringRef getConfig(xevm::SubgroupSizeOp) {
  return StringRef("get_sub_group_size");
}
/// Lowers a subgroup query op (`xevm.lane_id`, `xevm.subgroup_id`,
/// `xevm.subgroup_size`) to an `llvm.call` of the matching zero-argument
/// OpenCL builtin. The builtin name is supplied by the `getConfig` overload
/// selected for `OpType`.
template <typename OpType>
class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Pass the StringRef straight to `mangle`, as the launch-config pattern
    // does; going through a temporary std::string only adds an allocation.
    std::string func = mangle(getConfig(op), {});
    // The call produces the same type the op itself carries.
    Type resTy = op.getType();
    auto call =
        createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
                                 noUnwindWillReturnAttrs, op.getOperation());
    // Mark the call as having no memory effects of any kind.
    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/noModRef,
        /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
    call.setMemoryEffectsAttr(memAttr);
    rewriter.replaceOp(op, call);
    return success();
  }
};

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -775,7 +904,22 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
BlockLoadStore1DToOCLPattern<BlockStoreOp>,
LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
LaunchConfigOpToOCLPattern<GridDimXOp>,
LaunchConfigOpToOCLPattern<GridDimYOp>,
LaunchConfigOpToOCLPattern<GridDimZOp>,
SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
patterns.getContext());
}

Expand Down
147 changes: 146 additions & 1 deletion mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
// CHECK: xevm.DecorationCacheControl =
// CHECK: xevm.DecorationCacheControl =
// CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
// CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
Expand Down Expand Up @@ -345,3 +345,148 @@ llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
llvm.return
}

// -----
// Lowering of xevm.local_id.x: expects the dimension constant 0 followed by a
// call to the mangled OpenCL builtin get_local_id, with every memory effect
// set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func @local_id.x() -> i32 {
llvm.func @local_id.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z12get_local_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.local_id.x : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.local_id.y: expects the dimension constant 1 and a call to
// the mangled OpenCL builtin get_local_id that consumes that constant.
// CHECK-LABEL: llvm.func @local_id.y() -> i32 {
llvm.func @local_id.y() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
%1 = xevm.local_id.y : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.local_id.z: expects the dimension constant 2 and a call to
// the mangled OpenCL builtin get_local_id that consumes that constant.
// CHECK-LABEL: llvm.func @local_id.z() -> i32 {
llvm.func @local_id.z() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z12get_local_idj(%[[VAR0]])
%1 = xevm.local_id.z : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.local_size.x: expects the dimension constant 0 followed by
// a call to the mangled OpenCL builtin get_local_size, with every memory
// effect set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func @local_size.x() -> i32 {
llvm.func @local_size.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z14get_local_sizej", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.local_size.x : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.local_size.y: expects the dimension constant 1 and a call
// to the mangled OpenCL builtin get_local_size that consumes that constant.
// CHECK-LABEL: llvm.func @local_size.y() -> i32 {
llvm.func @local_size.y() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
%1 = xevm.local_size.y : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.local_size.z: expects the dimension constant 2 and a call
// to the mangled OpenCL builtin get_local_size that consumes that constant.
// CHECK-LABEL: llvm.func @local_size.z() -> i32 {
llvm.func @local_size.z() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z14get_local_sizej(%[[VAR0]])
%1 = xevm.local_size.z : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.group_id.x: expects the dimension constant 0 followed by a
// call to the mangled OpenCL builtin get_group_id, with every memory effect
// set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func @group_id.x() -> i32 {
llvm.func @group_id.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z12get_group_idj", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.group_id.x : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.group_id.y: expects the dimension constant 1 and a call to
// the mangled OpenCL builtin get_group_id that consumes that constant.
// CHECK-LABEL: llvm.func @group_id.y() -> i32 {
llvm.func @group_id.y() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
%1 = xevm.group_id.y : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.group_id.z: expects the dimension constant 2 and a call to
// the mangled OpenCL builtin get_group_id that consumes that constant.
// CHECK-LABEL: llvm.func @group_id.z() -> i32 {
llvm.func @group_id.z() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z12get_group_idj(%[[VAR0]])
%1 = xevm.group_id.z : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.group_count.x: expects the dimension constant 0 followed by
// a call to the mangled OpenCL builtin get_num_groups, with every memory
// effect set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func @group_count.x() -> i32 {
llvm.func @group_count.x() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: %[[VAR1:.*]] = llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
// CHECK-SAME: {function_type = !llvm.func<i32 (i32)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z14get_num_groupsj", visibility_ = 0 : i64, will_return} : (i32) -> i32
%1 = xevm.group_count.x : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.group_count.y: expects the dimension constant 1 and a call
// to the mangled OpenCL builtin get_num_groups that consumes that constant.
// CHECK-LABEL: llvm.func @group_count.y() -> i32 {
llvm.func @group_count.y() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
%1 = xevm.group_count.y : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.group_count.z: expects the dimension constant 2 and a call
// to the mangled OpenCL builtin get_num_groups that consumes that constant.
// CHECK-LABEL: llvm.func @group_count.z() -> i32 {
llvm.func @group_count.z() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.mlir.constant(2 : i32) : i32
// CHECK: llvm.call spir_funccc @_Z14get_num_groupsj(%[[VAR0]])
%1 = xevm.group_count.z : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.lane_id: expects a declaration of the mangled OpenCL
// builtin get_sub_group_local_id and a zero-argument call to it, with every
// memory effect set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func spir_funccc @_Z22get_sub_group_local_id() -> i32 attributes {no_unwind, will_return}
llvm.func @lane_id() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z22get_sub_group_local_id()
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z22get_sub_group_local_id", visibility_ = 0 : i64, will_return} : () -> i32
%1 = xevm.lane_id : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.subgroup_size: expects a declaration of the mangled OpenCL
// builtin get_sub_group_size and a zero-argument call to it, with every
// memory effect set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func spir_funccc @_Z18get_sub_group_size() -> i32 attributes {no_unwind, will_return}
llvm.func @subgroup_size() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z18get_sub_group_size()
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z18get_sub_group_size", visibility_ = 0 : i64, will_return} : () -> i32
%1 = xevm.subgroup_size : i32
llvm.return %1 : i32
}

// -----
// Lowering of xevm.subgroup_id: expects a declaration of the mangled OpenCL
// builtin get_sub_group_id and a zero-argument call to it, with every memory
// effect set to none and the no_unwind / will_return attributes present.
// CHECK-LABEL: llvm.func spir_funccc @_Z16get_sub_group_id() -> i32 attributes {no_unwind, will_return}
llvm.func @subgroup_id() -> i32 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z16get_sub_group_id()
// CHECK-SAME: {function_type = !llvm.func<i32 ()>, linkage = #llvm.linkage<external>,
// CHECK-SAME: memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>,
// CHECK-SAME: no_unwind, sym_name = "_Z16get_sub_group_id", visibility_ = 0 : i64, will_return} : () -> i32
%1 = xevm.subgroup_id : i32
llvm.return %1 : i32
}