[MLIR][XeVM] Update cache control values and metadata format. #175274
Conversation
Cache control metadata is now always attached to a getelementptr op.
|
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Sang Ik Lee (silee2) Changes: Cache control metadata is now always attached to getelementptr op. Patch is 36.52 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/175274.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 20a420dfda65c..13ea9ba26d07c 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -99,26 +99,22 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
static int32_t getL1CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case LoadCacheControl::L1UC_L2UC_L3UC:
- case LoadCacheControl::L1UC_L2UC_L3C:
- case LoadCacheControl::L1UC_L2C_L3UC:
- case LoadCacheControl::L1UC_L2C_L3C:
- control = 1;
- break;
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3UC:
case LoadCacheControl::L1C_L2C_L3C:
- control = 2;
+ control = 1;
break;
case LoadCacheControl::L1S_L2UC_L3UC:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3UC:
case LoadCacheControl::L1S_L2C_L3C:
- control = 3;
+ control = 2;
break;
case LoadCacheControl::INVALIDATE_READ:
- control = 4;
+ control = 3;
+ break;
+ default:
break;
}
return control;
@@ -127,16 +123,15 @@ static int32_t getL1CacheControl(LoadCacheControl cc) {
static int32_t getL1CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case StoreCacheControl::L1UC_L2UC_L3UC:
- case StoreCacheControl::L1UC_L2UC_L3WB:
- case StoreCacheControl::L1UC_L2WB_L3UC:
- case StoreCacheControl::L1UC_L2WB_L3WB:
- control = 1;
- break;
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3UC:
case StoreCacheControl::L1WT_L2WB_L3WB:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WB_L2UC_L3UC:
+ case StoreCacheControl::L1WB_L2WB_L3UC:
+ case StoreCacheControl::L1WB_L2UC_L3WB:
control = 2;
break;
case StoreCacheControl::L1S_L2UC_L3UC:
@@ -145,10 +140,7 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
case StoreCacheControl::L1S_L2WB_L3WB:
control = 3;
break;
- case StoreCacheControl::L1WB_L2UC_L3UC:
- case StoreCacheControl::L1WB_L2WB_L3UC:
- case StoreCacheControl::L1WB_L2UC_L3WB:
- control = 4;
+ default:
break;
}
return control;
@@ -157,24 +149,18 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
static int32_t getL3CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case LoadCacheControl::L1UC_L2UC_L3UC:
- case LoadCacheControl::L1UC_L2C_L3UC:
- case LoadCacheControl::L1C_L2UC_L3UC:
- case LoadCacheControl::L1C_L2C_L3UC:
- case LoadCacheControl::L1S_L2UC_L3UC:
- case LoadCacheControl::L1S_L2C_L3UC:
- control = 1;
- break;
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3C:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3C:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3C:
- control = 2;
+ control = 1;
break;
case LoadCacheControl::INVALIDATE_READ:
- control = 4;
+ control = 3;
+ break;
+ default:
break;
}
return control;
@@ -183,16 +169,6 @@ static int32_t getL3CacheControl(LoadCacheControl cc) {
static int32_t getL3CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case StoreCacheControl::L1UC_L2UC_L3UC:
- case StoreCacheControl::L1UC_L2WB_L3UC:
- case StoreCacheControl::L1WT_L2UC_L3UC:
- case StoreCacheControl::L1WT_L2WB_L3UC:
- case StoreCacheControl::L1S_L2UC_L3UC:
- case StoreCacheControl::L1S_L2WB_L3UC:
- case StoreCacheControl::L1WB_L2UC_L3UC:
- case StoreCacheControl::L1WB_L2WB_L3UC:
- control = 1;
- break;
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3WB:
case StoreCacheControl::L1WT_L2UC_L3WB:
@@ -202,6 +178,8 @@ static int32_t getL3CacheControl(StoreCacheControl cc) {
case StoreCacheControl::L1WB_L2UC_L3WB:
control = 2;
break;
+ default:
+ break;
}
return control;
}
@@ -265,7 +243,7 @@ static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
if (!getCacheControl(op))
return {};
- constexpr int32_t decorationCacheControlArity{4};
+ constexpr int32_t decorationCacheControlArity{3};
constexpr int32_t loadCacheControlKey{6442};
constexpr int32_t storeCacheControlKey{6443};
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
@@ -275,9 +253,9 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
- controlKey, 0, getL1CacheControl<OpType>(op), 0};
+ controlKey, 0, getL1CacheControl<OpType>(op)};
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
- controlKey, 1, getL3CacheControl<OpType>(op), 0};
+ controlKey, 1, getL3CacheControl<OpType>(op)};
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
@@ -445,7 +423,16 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
Value one =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
- SmallVector<Value> args{op.getPtr(), one};
+ Value ptrOp = op.getPtr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep = LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(),
+ rewriter.getI32Type(), ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ SmallVector<Value> args{gep, one};
SmallVector<Type> argTypes;
for (auto arg : args)
argTypes.push_back(arg.getType());
@@ -459,12 +446,9 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
/*targetMem1=*/LLVM::ModRefInfo::NoModRef);
funcAttr.memEffectsAttr = memAttr;
- LLVM::CallOp call = createDeviceFunctionCall(
- rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
- argTypes, args, {}, funcAttr, op.getOperation());
- if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op))
- call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ createDeviceFunctionCall(rewriter, fnName,
+ LLVM::LLVMVoidType::get(rewriter.getContext()),
+ argTypes, args, {}, funcAttr, op.getOperation());
rewriter.eraseOp(op);
return success();
}
@@ -548,7 +532,16 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
byteCoord = LLVM::InsertElementOp::create(
rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
- SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
+ Value ptrOp = op.getPtr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep =
+ LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(), i32Type, ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ SmallVector<Value> args{gep, op.getBaseWidth(), op.getBaseHeight(),
op.getBasePitch(), byteCoord};
SmallVector<Type> retTypes;
Value spvLoadDstPtr;
@@ -624,10 +617,6 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
argTypes, args, paramAttrs, funcAttr, op.getOperation());
- if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op)) {
- call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
- }
if constexpr (isLoad)
rewriter.replaceOp(
op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
@@ -672,8 +661,17 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
// arg1 - only if store : vector to store
// Prepare arguments
SmallVector<Value, 2> args{};
- args.push_back(op.getPtr());
- argTypes.push_back(op.getPtr().getType());
+ Value ptrOp = op.getPtr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep = LLVM::GEPOp::create(
+ rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ args.push_back(gep);
+ argTypes.push_back(gep.getType());
isUnsigned.push_back(true);
Type retType;
if constexpr (isStore) {
@@ -695,10 +693,6 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
LLVM::CallOp call =
createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
{}, funcAttr, op.getOperation());
- if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op)) {
- call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
- }
if constexpr (isStore)
rewriter.eraseOp(op);
else
@@ -715,10 +709,20 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
if (!op->hasAttr("cache_control"))
return failure();
- std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op);
- op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
- op->removeAttr("cache_control");
+ constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
+ Value ptrOp = op.getAddr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep = LLVM::GEPOp::create(
+ rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ if constexpr (isStore)
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.getValue(), gep);
+ else
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), gep);
return success();
}
};
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
index 7e9318ad3c019..ba098aa5fde50 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -68,7 +68,7 @@ class XeVMDialectLLVMIRTranslationInterface
attrs, std::back_inserter(decorations),
[&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
- std::array<llvm::Metadata *, 4> metadata;
+ std::array<llvm::Metadata *, 3> metadata;
llvm::transform(
valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 7f01526cb0a06..dab735c7df31f 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -15,17 +15,18 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
// CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
// CHECK-SAME: will_return} :
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi16>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -36,8 +37,8 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
// CHECK: xevm.DecorationCacheControl =
- // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
- // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
+ // CHECK-SAME: 6442 : i32, 0 : i32, 0 : i32
+ // CHECK-SAME: 6442 : i32, 1 : i32, 0 : i32
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -56,17 +57,18 @@ llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_heig
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
// CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
// CHECK-SAME: will_return}
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16>
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<16xi16>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false,
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
@@ -85,17 +87,18 @@ llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
// CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
// CHECK-SAME: will_return} :
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi32>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false,
pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
@@ -114,17 +117,18 @@ llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_hei
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32...
[truncated]
|
|
@llvm/pr-subscribers-mlir-llvm Author: Sang Ik Lee (silee2) Changes: Cache control metadata is now always attached to getelementptr op. Patch is 36.52 KiB, truncated to 20.00 KiB below; full version: https://github.com/llvm/llvm-project/pull/175274.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 20a420dfda65c..13ea9ba26d07c 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -99,26 +99,22 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
static int32_t getL1CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case LoadCacheControl::L1UC_L2UC_L3UC:
- case LoadCacheControl::L1UC_L2UC_L3C:
- case LoadCacheControl::L1UC_L2C_L3UC:
- case LoadCacheControl::L1UC_L2C_L3C:
- control = 1;
- break;
case LoadCacheControl::L1C_L2UC_L3UC:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3UC:
case LoadCacheControl::L1C_L2C_L3C:
- control = 2;
+ control = 1;
break;
case LoadCacheControl::L1S_L2UC_L3UC:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3UC:
case LoadCacheControl::L1S_L2C_L3C:
- control = 3;
+ control = 2;
break;
case LoadCacheControl::INVALIDATE_READ:
- control = 4;
+ control = 3;
+ break;
+ default:
break;
}
return control;
@@ -127,16 +123,15 @@ static int32_t getL1CacheControl(LoadCacheControl cc) {
static int32_t getL1CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case StoreCacheControl::L1UC_L2UC_L3UC:
- case StoreCacheControl::L1UC_L2UC_L3WB:
- case StoreCacheControl::L1UC_L2WB_L3UC:
- case StoreCacheControl::L1UC_L2WB_L3WB:
- control = 1;
- break;
case StoreCacheControl::L1WT_L2UC_L3UC:
case StoreCacheControl::L1WT_L2UC_L3WB:
case StoreCacheControl::L1WT_L2WB_L3UC:
case StoreCacheControl::L1WT_L2WB_L3WB:
+ control = 1;
+ break;
+ case StoreCacheControl::L1WB_L2UC_L3UC:
+ case StoreCacheControl::L1WB_L2WB_L3UC:
+ case StoreCacheControl::L1WB_L2UC_L3WB:
control = 2;
break;
case StoreCacheControl::L1S_L2UC_L3UC:
@@ -145,10 +140,7 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
case StoreCacheControl::L1S_L2WB_L3WB:
control = 3;
break;
- case StoreCacheControl::L1WB_L2UC_L3UC:
- case StoreCacheControl::L1WB_L2WB_L3UC:
- case StoreCacheControl::L1WB_L2UC_L3WB:
- control = 4;
+ default:
break;
}
return control;
@@ -157,24 +149,18 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
static int32_t getL3CacheControl(LoadCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case LoadCacheControl::L1UC_L2UC_L3UC:
- case LoadCacheControl::L1UC_L2C_L3UC:
- case LoadCacheControl::L1C_L2UC_L3UC:
- case LoadCacheControl::L1C_L2C_L3UC:
- case LoadCacheControl::L1S_L2UC_L3UC:
- case LoadCacheControl::L1S_L2C_L3UC:
- control = 1;
- break;
case LoadCacheControl::L1UC_L2UC_L3C:
case LoadCacheControl::L1UC_L2C_L3C:
case LoadCacheControl::L1C_L2UC_L3C:
case LoadCacheControl::L1C_L2C_L3C:
case LoadCacheControl::L1S_L2UC_L3C:
case LoadCacheControl::L1S_L2C_L3C:
- control = 2;
+ control = 1;
break;
case LoadCacheControl::INVALIDATE_READ:
- control = 4;
+ control = 3;
+ break;
+ default:
break;
}
return control;
@@ -183,16 +169,6 @@ static int32_t getL3CacheControl(LoadCacheControl cc) {
static int32_t getL3CacheControl(StoreCacheControl cc) {
int32_t control = 0;
switch (cc) {
- case StoreCacheControl::L1UC_L2UC_L3UC:
- case StoreCacheControl::L1UC_L2WB_L3UC:
- case StoreCacheControl::L1WT_L2UC_L3UC:
- case StoreCacheControl::L1WT_L2WB_L3UC:
- case StoreCacheControl::L1S_L2UC_L3UC:
- case StoreCacheControl::L1S_L2WB_L3UC:
- case StoreCacheControl::L1WB_L2UC_L3UC:
- case StoreCacheControl::L1WB_L2WB_L3UC:
- control = 1;
- break;
case StoreCacheControl::L1UC_L2UC_L3WB:
case StoreCacheControl::L1UC_L2WB_L3WB:
case StoreCacheControl::L1WT_L2UC_L3WB:
@@ -202,6 +178,8 @@ static int32_t getL3CacheControl(StoreCacheControl cc) {
case StoreCacheControl::L1WB_L2UC_L3WB:
control = 2;
break;
+ default:
+ break;
}
return control;
}
@@ -265,7 +243,7 @@ static std::optional<ArrayAttr>
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
if (!getCacheControl(op))
return {};
- constexpr int32_t decorationCacheControlArity{4};
+ constexpr int32_t decorationCacheControlArity{3};
constexpr int32_t loadCacheControlKey{6442};
constexpr int32_t storeCacheControlKey{6443};
constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
@@ -275,9 +253,9 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
std::is_same_v<OpType, PrefetchOp>;
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
- controlKey, 0, getL1CacheControl<OpType>(op), 0};
+ controlKey, 0, getL1CacheControl<OpType>(op)};
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
- controlKey, 1, getL3CacheControl<OpType>(op), 0};
+ controlKey, 1, getL3CacheControl<OpType>(op)};
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
@@ -445,7 +423,16 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
Value one =
LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
- SmallVector<Value> args{op.getPtr(), one};
+ Value ptrOp = op.getPtr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep = LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(),
+ rewriter.getI32Type(), ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ SmallVector<Value> args{gep, one};
SmallVector<Type> argTypes;
for (auto arg : args)
argTypes.push_back(arg.getType());
@@ -459,12 +446,9 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
/*targetMem1=*/LLVM::ModRefInfo::NoModRef);
funcAttr.memEffectsAttr = memAttr;
- LLVM::CallOp call = createDeviceFunctionCall(
- rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
- argTypes, args, {}, funcAttr, op.getOperation());
- if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op))
- call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ createDeviceFunctionCall(rewriter, fnName,
+ LLVM::LLVMVoidType::get(rewriter.getContext()),
+ argTypes, args, {}, funcAttr, op.getOperation());
rewriter.eraseOp(op);
return success();
}
@@ -548,7 +532,16 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
byteCoord = LLVM::InsertElementOp::create(
rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
- SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
+ Value ptrOp = op.getPtr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep =
+ LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(), i32Type, ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ SmallVector<Value> args{gep, op.getBaseWidth(), op.getBaseHeight(),
op.getBasePitch(), byteCoord};
SmallVector<Type> retTypes;
Value spvLoadDstPtr;
@@ -624,10 +617,6 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
LLVM::CallOp call = createDeviceFunctionCall(
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
argTypes, args, paramAttrs, funcAttr, op.getOperation());
- if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op)) {
- call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
- }
if constexpr (isLoad)
rewriter.replaceOp(
op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
@@ -672,8 +661,17 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
// arg1 - only if store : vector to store
// Prepare arguments
SmallVector<Value, 2> args{};
- args.push_back(op.getPtr());
- argTypes.push_back(op.getPtr().getType());
+ Value ptrOp = op.getPtr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep = LLVM::GEPOp::create(
+ rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ args.push_back(gep);
+ argTypes.push_back(gep.getType());
isUnsigned.push_back(true);
Type retType;
if constexpr (isStore) {
@@ -695,10 +693,6 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
LLVM::CallOp call =
createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
{}, funcAttr, op.getOperation());
- if (std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op)) {
- call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
- }
if constexpr (isStore)
rewriter.eraseOp(op);
else
@@ -715,10 +709,20 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
ConversionPatternRewriter &rewriter) const override {
if (!op->hasAttr("cache_control"))
return failure();
- std::optional<ArrayAttr> optCacheControls =
- getCacheControlMetadata(rewriter, op);
- op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
- op->removeAttr("cache_control");
+ constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
+ Value ptrOp = op.getAddr();
+ // Create getelementptr op to attach cache control metadata
+ // element type doesn't matter here as we use zero index, so use i32
+ LLVM::GEPOp gep = LLVM::GEPOp::create(
+ rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+ ArrayRef<LLVM::GEPArg>{0});
+ if (std::optional<ArrayAttr> optCacheControls =
+ getCacheControlMetadata(rewriter, op))
+ gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+ if constexpr (isStore)
+ rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.getValue(), gep);
+ else
+ rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), gep);
return success();
}
};
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
index 7e9318ad3c019..ba098aa5fde50 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -68,7 +68,7 @@ class XeVMDialectLLVMIRTranslationInterface
attrs, std::back_inserter(decorations),
[&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
- std::array<llvm::Metadata *, 4> metadata;
+ std::array<llvm::Metadata *, 3> metadata;
llvm::transform(
valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 7f01526cb0a06..dab735c7df31f 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -15,17 +15,18 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
// CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
// CHECK-SAME: will_return} :
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi16>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -36,8 +37,8 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
// CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
// CHECK: xevm.DecorationCacheControl =
- // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
- // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
+ // CHECK-SAME: 6442 : i32, 0 : i32, 0 : i32
+ // CHECK-SAME: 6442 : i32, 1 : i32, 0 : i32
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -56,17 +57,18 @@ llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_heig
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(16 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
// CHECK-SAME: "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
// CHECK-SAME: will_return}
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16>
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<16xi16>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false,
pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
@@ -85,17 +87,18 @@ llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
// CHECK-SAME: linkage = #llvm.linkage<external>, no_unwind, sym_name =
// CHECK-SAME: "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
// CHECK-SAME: will_return} :
// CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
// CHECK-SAME: !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
- // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
+ // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi32>
%loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
<{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false,
pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
@@ -114,17 +117,18 @@ llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_hei
// CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
// CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
- // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
- // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+ // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+ // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+ // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
// CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
- // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+ // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
// CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32...
[truncated]
|
|
Out of curiosity, is there a way to prevent the GEP instruction from being eliminated? |
Not directly, but it can be done by combining two things: generating the GEP instruction as late as possible, and disabling folding in the pattern drivers. |
Fix incorrect cache control metadata values and format.