diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h index 7ffdbd4307f9e..f591407b602df 100644 --- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h +++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h @@ -11,6 +11,7 @@ #include namespace mlir { +class ConversionTarget; class DialectRegistry; class LLVMTypeConverter; class RewritePatternSet; @@ -19,7 +20,8 @@ class Pass; #define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS #include "mlir/Conversion/Passes.h.inc" -void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns); +void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, + RewritePatternSet &patterns); void registerConvertXeVMToLLVMInterface(DialectRegistry ®istry); } // namespace mlir diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp index 4dfcb2b43c19c..0f90acf0d9c39 100644 --- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp +++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp @@ -98,127 +98,179 @@ std::string mangle(StringRef baseName, ArrayRef types, return os.str(); } -template -int32_t getL1CacheControl(OpType op) { +static int32_t getL1CacheControl(LoadCacheControl cc) { int32_t control = 0; - if constexpr (isLoad) { - switch (*op.getCacheControl()) { - case LoadCacheControl::L1UC_L2UC_L3UC: - case LoadCacheControl::L1UC_L2UC_L3C: - case LoadCacheControl::L1UC_L2C_L3UC: - case LoadCacheControl::L1UC_L2C_L3C: - control = 1; - break; - case LoadCacheControl::L1C_L2UC_L3UC: - case LoadCacheControl::L1C_L2UC_L3C: - case LoadCacheControl::L1C_L2C_L3UC: - case LoadCacheControl::L1C_L2C_L3C: - control = 2; - break; - case LoadCacheControl::L1S_L2UC_L3UC: - case LoadCacheControl::L1S_L2UC_L3C: - case LoadCacheControl::L1S_L2C_L3UC: - case LoadCacheControl::L1S_L2C_L3C: - control = 3; - break; - case LoadCacheControl::INVALIDATE_READ: - control = 4; - break; - } - } else { - switch (*op.getCacheControl()) { - case StoreCacheControl::L1UC_L2UC_L3UC: - case StoreCacheControl::L1UC_L2UC_L3WB: - case StoreCacheControl::L1UC_L2WB_L3UC: - case StoreCacheControl::L1UC_L2WB_L3WB: - control = 1; - break; - case StoreCacheControl::L1WT_L2UC_L3UC: - case StoreCacheControl::L1WT_L2UC_L3WB: - case StoreCacheControl::L1WT_L2WB_L3UC: - case StoreCacheControl::L1WT_L2WB_L3WB: - control = 2; - break; - case StoreCacheControl::L1S_L2UC_L3UC: - case StoreCacheControl::L1S_L2UC_L3WB: - case StoreCacheControl::L1S_L2WB_L3UC: - case StoreCacheControl::L1S_L2WB_L3WB: - control = 3; - break; - case StoreCacheControl::L1WB_L2UC_L3UC: - case StoreCacheControl::L1WB_L2WB_L3UC: - case StoreCacheControl::L1WB_L2UC_L3WB: - control = 4; - break; - } + switch (cc) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1UC_L2C_L3C: + control = 1; + break; + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1C_L2C_L3C: + control = 2; + break; + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3UC: + case LoadCacheControl::L1S_L2C_L3C: + control = 3; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; } return control; } -template -int32_t getL3CacheControl(OpType op) { +static int32_t getL1CacheControl(StoreCacheControl cc) { int32_t control = 0; - if constexpr (isLoad) { - switch (*op.getCacheControl()) { - case LoadCacheControl::L1UC_L2UC_L3UC: - case LoadCacheControl::L1UC_L2C_L3UC: - case LoadCacheControl::L1C_L2UC_L3UC: - case LoadCacheControl::L1C_L2C_L3UC: - case LoadCacheControl::L1S_L2UC_L3UC: - case LoadCacheControl::L1S_L2C_L3UC: - control = 1; - break; - case LoadCacheControl::L1UC_L2UC_L3C: - case LoadCacheControl::L1UC_L2C_L3C: - case LoadCacheControl::L1C_L2UC_L3C: - case LoadCacheControl::L1C_L2C_L3C: - case LoadCacheControl::L1S_L2UC_L3C: - case LoadCacheControl::L1S_L2C_L3C: - control = 2; - break; - case LoadCacheControl::INVALIDATE_READ: - control = 4; - break; - } - } else { - switch (*op.getCacheControl()) { - case StoreCacheControl::L1UC_L2UC_L3UC: - case StoreCacheControl::L1UC_L2WB_L3UC: - case StoreCacheControl::L1WT_L2UC_L3UC: - case StoreCacheControl::L1WT_L2WB_L3UC: - case StoreCacheControl::L1S_L2UC_L3UC: - case StoreCacheControl::L1S_L2WB_L3UC: - case StoreCacheControl::L1WB_L2UC_L3UC: - case StoreCacheControl::L1WB_L2WB_L3UC: - control = 1; - break; - case StoreCacheControl::L1UC_L2UC_L3WB: - case StoreCacheControl::L1UC_L2WB_L3WB: - case StoreCacheControl::L1WT_L2UC_L3WB: - case StoreCacheControl::L1WT_L2WB_L3WB: - case StoreCacheControl::L1S_L2UC_L3WB: - case StoreCacheControl::L1S_L2WB_L3WB: - case StoreCacheControl::L1WB_L2UC_L3WB: - control = 2; - break; - } + switch (cc) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1UC_L2WB_L3WB: + control = 1; + break; + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1WT_L2WB_L3WB: + control = 2; + break; + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1S_L2WB_L3WB: + control = 3; + break; + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 4; + break; } return control; } -template +static int32_t getL3CacheControl(LoadCacheControl cc) { + int32_t control = 0; + switch (cc) { + case LoadCacheControl::L1UC_L2UC_L3UC: + case LoadCacheControl::L1UC_L2C_L3UC: + case LoadCacheControl::L1C_L2UC_L3UC: + case LoadCacheControl::L1C_L2C_L3UC: + case LoadCacheControl::L1S_L2UC_L3UC: + case LoadCacheControl::L1S_L2C_L3UC: + control = 1; + break; + case LoadCacheControl::L1UC_L2UC_L3C: + case LoadCacheControl::L1UC_L2C_L3C: + case LoadCacheControl::L1C_L2UC_L3C: + case LoadCacheControl::L1C_L2C_L3C: + case LoadCacheControl::L1S_L2UC_L3C: + case LoadCacheControl::L1S_L2C_L3C: + control = 2; + break; + case LoadCacheControl::INVALIDATE_READ: + control = 4; + break; + } + return control; +} + +static int32_t getL3CacheControl(StoreCacheControl cc) { + int32_t control = 0; + switch (cc) { + case StoreCacheControl::L1UC_L2UC_L3UC: + case StoreCacheControl::L1UC_L2WB_L3UC: + case StoreCacheControl::L1WT_L2UC_L3UC: + case StoreCacheControl::L1WT_L2WB_L3UC: + case StoreCacheControl::L1S_L2UC_L3UC: + case StoreCacheControl::L1S_L2WB_L3UC: + case StoreCacheControl::L1WB_L2UC_L3UC: + case StoreCacheControl::L1WB_L2WB_L3UC: + control = 1; + break; + case StoreCacheControl::L1UC_L2UC_L3WB: + case StoreCacheControl::L1UC_L2WB_L3WB: + case StoreCacheControl::L1WT_L2UC_L3WB: + case StoreCacheControl::L1WT_L2WB_L3WB: + case StoreCacheControl::L1S_L2UC_L3WB: + case StoreCacheControl::L1S_L2WB_L3WB: + case StoreCacheControl::L1WB_L2UC_L3WB: + control = 2; + break; + } + return control; +} + +static std::optional getCacheControl(PrefetchOp op) { + return op.getCacheControl(); +} + +static std::optional getCacheControl(BlockLoad2dOp op) { + return op.getCacheControl(); +} + +static std::optional getCacheControl(BlockPrefetch2dOp op) { + return op.getCacheControl(); +} + +static std::optional getCacheControl(BlockStore2dOp op) { + return op.getCacheControl(); +} + +static std::optional getCacheControl(LLVM::LoadOp op) { + if (op->hasAttr("cache_control")) { + auto attr = op->getAttrOfType("cache_control"); + if (!attr) + return std::nullopt; + return std::optional(attr.getValue()); + } + return std::nullopt; +} + +static std::optional getCacheControl(LLVM::StoreOp op) { + if (op->hasAttr("cache_control")) { + auto attr = op->getAttrOfType("cache_control"); + if (!attr) + return std::nullopt; + return std::optional(attr.getValue()); + } + return std::nullopt; +} + +template +int32_t getL1CacheControl(OpType op) { + return getL1CacheControl(*getCacheControl(op)); +} + +template +int32_t getL3CacheControl(OpType op) { + return getL3CacheControl(*getCacheControl(op)); +} + +template static std::optional getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) { - if (!op.getCacheControl()) + if (!getCacheControl(op)) return {}; constexpr int32_t decorationCacheControlArity{4}; constexpr int32_t loadCacheControlKey{6442}; constexpr int32_t storeCacheControlKey{6443}; + constexpr bool isLoad = std::is_same_v || + std::is_same_v || + std::is_same_v || + std::is_same_v; const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey}; SmallVector decorationsL1{ - controlKey, 0, getL1CacheControl(op), 0}; + controlKey, 0, getL1CacheControl(op), 0}; SmallVector decorationsL3{ - controlKey, 1, getL3CacheControl(op), 0}; + controlKey, 1, getL3CacheControl(op), 0}; auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1); auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3); @@ -398,7 +450,7 @@ class PrefetchToOCLPattern : public OpConversionPattern { rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, {}, funcAttr, op.getOperation()); if (std::optional optCacheControls = - getCacheControlMetadata(rewriter, op)) + getCacheControlMetadata(rewriter, op)) call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); rewriter.eraseOp(op); return success(); @@ -557,7 +609,7 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()), argTypes, args, paramAttrs, funcAttr, op.getOperation()); if (std::optional optCacheControls = - getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) { + getCacheControlMetadata(rewriter, op)) { call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); } if constexpr (isLoad) @@ -568,6 +620,21 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern { return success(); } }; +template +class LLVMLoadStoreToOCLPattern : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(OpType op, typename OpType::Adaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!op->hasAttr("cache_control")) + return failure(); + std::optional optCacheControls = + getCacheControlMetadata(rewriter, op); + op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls); + op->removeAttr("cache_control"); + return success(); + } +}; //===----------------------------------------------------------------------===// // Pass Definition @@ -583,10 +650,8 @@ struct ConvertXeVMToLLVMPass void runOnOperation() override { ConversionTarget target(getContext()); - target.addLegalDialect(); - target.addIllegalDialect(); RewritePatternSet patterns(&getContext()); - populateXeVMToLLVMConversionPatterns(patterns); + populateXeVMToLLVMConversionPatterns(target, patterns); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) signalPassFailure(); @@ -611,7 +676,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { void populateConvertToLLVMConversionPatterns( ConversionTarget &target, LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) const final { - populateXeVMToLLVMConversionPatterns(patterns); + populateXeVMToLLVMConversionPatterns(target, patterns); } }; } // namespace @@ -620,12 +685,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface { // Pattern Population //===----------------------------------------------------------------------===// -void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) { +void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target, + RewritePatternSet &patterns) { + target.addDynamicallyLegalDialect( + [](Operation *op) { return !op->hasAttr("cache_control"); }); + target.addIllegalDialect(); patterns.add, LoadStorePrefetchToOCLPattern, LoadStorePrefetchToOCLPattern, - MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>( - patterns.getContext()); + MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern, + LLVMLoadStoreToOCLPattern, + LLVMLoadStoreToOCLPattern>(patterns.getContext()); } void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry ®istry) { diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp index 73b166d045d5b..7e9318ad3c019 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp @@ -55,10 +55,6 @@ class XeVMDialectLLVMIRTranslationInterface return handleDecorationCacheControl(instructions.front(), cacheControlsArray.getValue()); } - auto func = dyn_cast(op); - if (!func) - return failure(); - return success(); } diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir index bdbb12bbe0cbb..8f60a0797652b 100644 --- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir +++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir @@ -242,3 +242,22 @@ llvm.func @prefetch(%ptr: !llvm.ptr<1>) { llvm.return } +// ----- +// CHECK-LABEL: llvm.func @llvm.load +llvm.func @llvm.load(%a: !llvm.ptr<1>) -> i32 { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32 + // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32 + %val = llvm.load %a {cache_control=#xevm.load_cache_control} : !llvm.ptr<1> -> i32 + llvm.return %val : i32 +} + +// ----- +// CHECK-LABEL: llvm.func @llvm.store +llvm.func @llvm.store(%a: !llvm.ptr<1>, %val: i32) { + // CHECK: xevm.DecorationCacheControl = + // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32 + // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32 + llvm.store %val, %a {cache_control=#xevm.store_cache_control} : i32, !llvm.ptr<1> + llvm.return +} diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir index a3dd0b6c17914..112d923607060 100644 --- a/mlir/test/Target/LLVMIR/xevm.mlir +++ b/mlir/test/Target/LLVMIR/xevm.mlir @@ -19,3 +19,35 @@ module { // CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} // CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} +// ----- +module { + // CHECK-LABEL: define i32 @load(ptr addrspace(1) + // CHECK-SAME: %[[ARG0:.*]]) { + llvm.func @load(%arg0: !llvm.ptr<1>) -> i32 { + // CHECK: load i32, ptr addrspace(1) %[[ARG0]], align 4, + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + %0 = llvm.load %arg0 {xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} : !llvm.ptr<1> -> i32 + llvm.return %0 : i32 + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0} +// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0} + +// ----- +module { + // CHECK-LABEL: define void @store(ptr addrspace(1) + // CHECK-SAME: %[[ARG0:.*]], i32 %[[ARG1:.*]]) { + llvm.func @store(%arg0: !llvm.ptr<1>, %arg1: i32) { + // CHECK: store i32 %[[ARG1]], ptr addrspace(1) %[[ARG0]], align 4, + // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]] + llvm.store %arg1, %arg0 {xevm.DecorationCacheControl = [[6443 : i32, 0 : i32, 2 : i32, 0 : i32], [6443 : i32, 1 : i32, 2 : i32, 0 : i32]]} : i32, !llvm.ptr<1> + llvm.return + } +} + +// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]} +// CHECK: ![[DECO2]] = !{i32 6443, i32 0, i32 2, i32 0} +// CHECK: ![[DECO3]] = !{i32 6443, i32 1, i32 2, i32 0} +