Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
return op.getCacheControl();
}

static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
return op.getCacheControl();
}

static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
return op.getCacheControl();
}
Expand All @@ -222,6 +226,10 @@ static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
return op.getCacheControl();
}

static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
return op.getCacheControl();
}

static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
if (op->hasAttr("cache_control")) {
auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
Expand Down Expand Up @@ -263,6 +271,7 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
std::is_same_v<OpType, BlockPrefetch2dOp> ||
std::is_same_v<OpType, LLVM::LoadOp> ||
std::is_same_v<OpType, BlockLoadOp> ||
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
Expand Down Expand Up @@ -618,6 +627,77 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
return success();
}
};

template <typename OpType>
class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
// Get OpenCL function name
// https://registry.khronos.org/OpenCL/extensions/
// intel/cl_intel_subgroup_local_block_io.html
std::string funcName{"intel_sub_group_block_"};
// Value or Result type can be vector or scalar
Type valOrResTy;
if constexpr (isStore) {
funcName += "write_u";
valOrResTy = op.getVal().getType();
} else {
funcName += "read_u";
valOrResTy = op.getType();
}
// Get element type of the vector/scalar
VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
funcName += getTypeMangling(elemType);
if (vecTy)
funcName += std::to_string(vecTy.getNumElements());
SmallVector<Type, 2> argTypes{};
// XeVM BlockLoad/StoreOp always use signless integer types
// but OpenCL builtins expect unsigned types
// use unsigned types for mangling
SmallVector<bool, 2> isUnsigned{};
// arg0: pointer to the src/dst address
// arg1 - only if store : vector to store
// Prepare arguments
SmallVector<Value, 2> args{};
args.push_back(op.getPtr());
argTypes.push_back(op.getPtr().getType());
isUnsigned.push_back(true);
Type retType;
if constexpr (isStore) {
args.push_back(op.getVal());
argTypes.push_back(op.getVal().getType());
isUnsigned.push_back(true);
retType = LLVM::LLVMVoidType::get(rewriter.getContext());
} else {
retType = valOrResTy;
}
funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
"PU3AS" +
std::to_string(op.getPtr().getType().getAddressSpace());
funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
if constexpr (isStore)
funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};

LLVM::CallOp call =
createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
{}, funcAttr, op.getOperation());
if (std::optional<ArrayAttr> optCacheControls =
getCacheControlMetadata(rewriter, op)) {
call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
}
if constexpr (isStore)
rewriter.eraseOp(op);
else
rewriter.replaceOp(op, call->getResult(0));
return success();
}
};

template <typename OpType>
class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
using OpConversionPattern<OpType>::OpConversionPattern;
Expand Down Expand Up @@ -693,7 +773,10 @@ void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
BlockLoadStore1DToOCLPattern<BlockLoadOp>,
BlockLoadStore1DToOCLPattern<BlockStoreOp>>(
patterns.getContext());
}

void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
Expand Down
84 changes: 84 additions & 0 deletions mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -261,3 +261,87 @@ llvm.func @llvm.store(%a: !llvm.ptr<1>, %val: i32) {
llvm.store %val, %a {cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>} : i32, !llvm.ptr<1>
llvm.return
}

// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS1t
// CHECK: llvm.func @blockload_as1(%[[ARG0:.*]]: !llvm.ptr<1>)
llvm.func @blockload_as1(%ptr: !llvm.ptr<1>) -> vector<8xi16> {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z30intel_sub_group_block_read_us8PU3AS1t(%[[ARG0]])
// CHECK-SAME: {function_type = !llvm.func<vector<8xi16> (ptr<1>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: no_unwind, sym_name = "_Z30intel_sub_group_block_read_us8PU3AS1t",
// CHECK-SAME: visibility_ = 0 : i64, will_return, xevm.DecorationCacheControl =
// CHECK-SAME: [6442 : i32, 0 : i32, 1 : i32, 0 : i32],
// CHECK-SAME: [6442 : i32, 1 : i32, 1 : i32, 0 : i32]
%loaded_a = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>) -> vector<8xi16>
llvm.return %loaded_a : vector<8xi16>
}

// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z31intel_sub_group_block_read_uc16PU3AS3h(!llvm.ptr<3>)
// CHECK: llvm.func @blockload_as3(%[[ARG0:.*]]: !llvm.ptr<3>)
llvm.func @blockload_as3(%ptr: !llvm.ptr<3>) -> vector<16xi8> {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z31intel_sub_group_block_read_uc16PU3AS3h(%[[ARG0]])
// CHECK-SAME: {function_type = !llvm.func<vector<16xi8> (ptr<3>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: no_unwind, sym_name = "_Z31intel_sub_group_block_read_uc16PU3AS3h", visibility_ = 0 : i64,
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
// CHECK-SAME: [6442 : i32, 0 : i32, 1 : i32, 0 : i32],
// CHECK-SAME: [6442 : i32, 1 : i32, 1 : i32, 0 : i32]
%loaded_a = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<3>) -> vector<16xi8>
llvm.return %loaded_a : vector<16xi8>
}

// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z29intel_sub_group_block_read_ucPU3AS3h(!llvm.ptr<3>)
// CHECK: llvm.func @blockload_scalar(%[[ARG0:.*]]: !llvm.ptr<3>)
llvm.func @blockload_scalar(%ptr: !llvm.ptr<3>) -> i8 {
// CHECK: %[[VAR0:.*]] = llvm.call spir_funccc @_Z29intel_sub_group_block_read_ucPU3AS3h(%[[ARG0]])
// CHECK-SAME: {function_type = !llvm.func<i8 (ptr<3>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: no_unwind, sym_name = "_Z29intel_sub_group_block_read_ucPU3AS3h", visibility_ = 0 : i64,
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
// CHECK-SAME: [6442 : i32, 0 : i32, 1 : i32, 0 : i32],
// CHECK-SAME: [6442 : i32, 1 : i32, 1 : i32, 0 : i32]
%loaded_a = xevm.blockload %ptr <{cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<3>) -> i8
llvm.return %loaded_a : i8
}

// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS1jDv8_j
// CHECK: llvm.func @blockstore_as1(%[[ARG0:.*]]: !llvm.ptr<1>, %[[ARG1:.*]]: vector<8xi32>) {
llvm.func @blockstore_as1(%ptr: !llvm.ptr<1>, %data: vector<8xi32>) {
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ui8PU3AS1jDv8_j(%[[ARG0]], %[[ARG1]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, vector<8xi32>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: no_unwind, sym_name = "_Z31intel_sub_group_block_write_ui8PU3AS1jDv8_j", visibility_ = 0 : i64,
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
// CHECK-SAME: [6443 : i32, 0 : i32, 2 : i32, 0 : i32],
// CHECK-SAME: [6443 : i32, 1 : i32, 2 : i32, 0 : i32]
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<1>, vector<8xi32>)
llvm.return
}

// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z31intel_sub_group_block_write_ul2PU3AS3mDv2_m
// CHECK: llvm.func @blockstore_as3(%[[ARG0:.*]]: !llvm.ptr<3>, %[[ARG1:.*]]: vector<2xi64>) {
llvm.func @blockstore_as3(%ptr: !llvm.ptr<3>, %data: vector<2xi64>) {
// CHECK: llvm.call spir_funccc @_Z31intel_sub_group_block_write_ul2PU3AS3mDv2_m(%[[ARG0]], %[[ARG1]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<3>, vector<2xi64>)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: no_unwind, sym_name = "_Z31intel_sub_group_block_write_ul2PU3AS3mDv2_m", visibility_ = 0 : i64,
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
// CHECK-SAME: [6443 : i32, 0 : i32, 2 : i32, 0 : i32],
// CHECK-SAME: [6443 : i32, 1 : i32, 2 : i32, 0 : i32]
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, vector<2xi64>)
llvm.return
}

// -----
// CHECK-LABEL: llvm.func spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm
// CHECK: llvm.func @blockstore_scalar(%[[ARG0:.*]]: !llvm.ptr<3>, %[[ARG1:.*]]: i64) {
llvm.func @blockstore_scalar(%ptr: !llvm.ptr<3>, %data: i64) {
// CHECK: llvm.call spir_funccc @_Z30intel_sub_group_block_write_ulPU3AS3mm(%[[ARG0]], %[[ARG1]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<3>, i64)>, linkage = #llvm.linkage<external>,
// CHECK-SAME: no_unwind, sym_name = "_Z30intel_sub_group_block_write_ulPU3AS3mm", visibility_ = 0 : i64,
// CHECK-SAME: will_return, xevm.DecorationCacheControl =
// CHECK-SAME: [6443 : i32, 0 : i32, 2 : i32, 0 : i32],
// CHECK-SAME: [6443 : i32, 1 : i32, 2 : i32, 0 : i32]
xevm.blockstore %ptr, %data <{cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>}> : (!llvm.ptr<3>, i64)
llvm.return
}