Skip to content

[MLIR][XeVM] Update cache control values and metadata format.#175274

Merged
silee2 merged 7 commits into
llvm:main from
silee2:fixXeVMCacheControl
Feb 18, 2026
Merged

[MLIR][XeVM] Update cache control values and metadata format.#175274
silee2 merged 7 commits into
llvm:main from
silee2:fixXeVMCacheControl

Conversation

@silee2
Copy link
Copy Markdown
Contributor

@silee2 silee2 commented Jan 10, 2026

Fix incorrect cache control metadata values and format.

Cache control metadata is now always attached to getelementptr op.
@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Jan 10, 2026

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Sang Ik Lee (silee2)

Changes

Cache control metadata is now always attached to getelementptr op.


Patch is 36.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175274.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+68-64)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp (+1-1)
  • (modified) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+67-54)
  • (modified) mlir/test/Target/LLVMIR/xevm.mlir (+11-29)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 20a420dfda65c..13ea9ba26d07c 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -99,26 +99,22 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
 static int32_t getL1CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case LoadCacheControl::L1UC_L2UC_L3UC:
-  case LoadCacheControl::L1UC_L2UC_L3C:
-  case LoadCacheControl::L1UC_L2C_L3UC:
-  case LoadCacheControl::L1UC_L2C_L3C:
-    control = 1;
-    break;
   case LoadCacheControl::L1C_L2UC_L3UC:
   case LoadCacheControl::L1C_L2UC_L3C:
   case LoadCacheControl::L1C_L2C_L3UC:
   case LoadCacheControl::L1C_L2C_L3C:
-    control = 2;
+    control = 1;
     break;
   case LoadCacheControl::L1S_L2UC_L3UC:
   case LoadCacheControl::L1S_L2UC_L3C:
   case LoadCacheControl::L1S_L2C_L3UC:
   case LoadCacheControl::L1S_L2C_L3C:
-    control = 3;
+    control = 2;
     break;
   case LoadCacheControl::INVALIDATE_READ:
-    control = 4;
+    control = 3;
+    break;
+  default:
     break;
   }
   return control;
@@ -127,16 +123,15 @@ static int32_t getL1CacheControl(LoadCacheControl cc) {
 static int32_t getL1CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case StoreCacheControl::L1UC_L2UC_L3UC:
-  case StoreCacheControl::L1UC_L2UC_L3WB:
-  case StoreCacheControl::L1UC_L2WB_L3UC:
-  case StoreCacheControl::L1UC_L2WB_L3WB:
-    control = 1;
-    break;
   case StoreCacheControl::L1WT_L2UC_L3UC:
   case StoreCacheControl::L1WT_L2UC_L3WB:
   case StoreCacheControl::L1WT_L2WB_L3UC:
   case StoreCacheControl::L1WT_L2WB_L3WB:
+    control = 1;
+    break;
+  case StoreCacheControl::L1WB_L2UC_L3UC:
+  case StoreCacheControl::L1WB_L2WB_L3UC:
+  case StoreCacheControl::L1WB_L2UC_L3WB:
     control = 2;
     break;
   case StoreCacheControl::L1S_L2UC_L3UC:
@@ -145,10 +140,7 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
   case StoreCacheControl::L1S_L2WB_L3WB:
     control = 3;
     break;
-  case StoreCacheControl::L1WB_L2UC_L3UC:
-  case StoreCacheControl::L1WB_L2WB_L3UC:
-  case StoreCacheControl::L1WB_L2UC_L3WB:
-    control = 4;
+  default:
     break;
   }
   return control;
@@ -157,24 +149,18 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
 static int32_t getL3CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case LoadCacheControl::L1UC_L2UC_L3UC:
-  case LoadCacheControl::L1UC_L2C_L3UC:
-  case LoadCacheControl::L1C_L2UC_L3UC:
-  case LoadCacheControl::L1C_L2C_L3UC:
-  case LoadCacheControl::L1S_L2UC_L3UC:
-  case LoadCacheControl::L1S_L2C_L3UC:
-    control = 1;
-    break;
   case LoadCacheControl::L1UC_L2UC_L3C:
   case LoadCacheControl::L1UC_L2C_L3C:
   case LoadCacheControl::L1C_L2UC_L3C:
   case LoadCacheControl::L1C_L2C_L3C:
   case LoadCacheControl::L1S_L2UC_L3C:
   case LoadCacheControl::L1S_L2C_L3C:
-    control = 2;
+    control = 1;
     break;
   case LoadCacheControl::INVALIDATE_READ:
-    control = 4;
+    control = 3;
+    break;
+  default:
     break;
   }
   return control;
@@ -183,16 +169,6 @@ static int32_t getL3CacheControl(LoadCacheControl cc) {
 static int32_t getL3CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case StoreCacheControl::L1UC_L2UC_L3UC:
-  case StoreCacheControl::L1UC_L2WB_L3UC:
-  case StoreCacheControl::L1WT_L2UC_L3UC:
-  case StoreCacheControl::L1WT_L2WB_L3UC:
-  case StoreCacheControl::L1S_L2UC_L3UC:
-  case StoreCacheControl::L1S_L2WB_L3UC:
-  case StoreCacheControl::L1WB_L2UC_L3UC:
-  case StoreCacheControl::L1WB_L2WB_L3UC:
-    control = 1;
-    break;
   case StoreCacheControl::L1UC_L2UC_L3WB:
   case StoreCacheControl::L1UC_L2WB_L3WB:
   case StoreCacheControl::L1WT_L2UC_L3WB:
@@ -202,6 +178,8 @@ static int32_t getL3CacheControl(StoreCacheControl cc) {
   case StoreCacheControl::L1WB_L2UC_L3WB:
     control = 2;
     break;
+  default:
+    break;
   }
   return control;
 }
@@ -265,7 +243,7 @@ static std::optional<ArrayAttr>
 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
   if (!getCacheControl(op))
     return {};
-  constexpr int32_t decorationCacheControlArity{4};
+  constexpr int32_t decorationCacheControlArity{3};
   constexpr int32_t loadCacheControlKey{6442};
   constexpr int32_t storeCacheControlKey{6443};
   constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
@@ -275,9 +253,9 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
                           std::is_same_v<OpType, PrefetchOp>;
   const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
-      controlKey, 0, getL1CacheControl<OpType>(op), 0};
+      controlKey, 0, getL1CacheControl<OpType>(op)};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
-      controlKey, 1, getL3CacheControl<OpType>(op), 0};
+      controlKey, 1, getL3CacheControl<OpType>(op)};
   auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
   auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
 
@@ -445,7 +423,16 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
     const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
     Value one =
         LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
-    SmallVector<Value> args{op.getPtr(), one};
+    Value ptrOp = op.getPtr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep = LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(),
+                                          rewriter.getI32Type(), ptrOp,
+                                          ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    SmallVector<Value> args{gep, one};
     SmallVector<Type> argTypes;
     for (auto arg : args)
       argTypes.push_back(arg.getType());
@@ -459,12 +446,9 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
         /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
     funcAttr.memEffectsAttr = memAttr;
 
-    LLVM::CallOp call = createDeviceFunctionCall(
-        rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
-        argTypes, args, {}, funcAttr, op.getOperation());
-    if (std::optional<ArrayAttr> optCacheControls =
-            getCacheControlMetadata(rewriter, op))
-      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    createDeviceFunctionCall(rewriter, fnName,
+                             LLVM::LLVMVoidType::get(rewriter.getContext()),
+                             argTypes, args, {}, funcAttr, op.getOperation());
     rewriter.eraseOp(op);
     return success();
   }
@@ -548,7 +532,16 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
         rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
     byteCoord = LLVM::InsertElementOp::create(
         rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
-    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
+    Value ptrOp = op.getPtr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep =
+        LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(), i32Type, ptrOp,
+                            ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    SmallVector<Value> args{gep, op.getBaseWidth(), op.getBaseHeight(),
                             op.getBasePitch(), byteCoord};
     SmallVector<Type> retTypes;
     Value spvLoadDstPtr;
@@ -624,10 +617,6 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
     LLVM::CallOp call = createDeviceFunctionCall(
         rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
         argTypes, args, paramAttrs, funcAttr, op.getOperation());
-    if (std::optional<ArrayAttr> optCacheControls =
-            getCacheControlMetadata(rewriter, op)) {
-      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
-    }
     if constexpr (isLoad)
       rewriter.replaceOp(
           op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
@@ -672,8 +661,17 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
     // arg1 - only if store : vector to store
     // Prepare arguments
     SmallVector<Value, 2> args{};
-    args.push_back(op.getPtr());
-    argTypes.push_back(op.getPtr().getType());
+    Value ptrOp = op.getPtr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep = LLVM::GEPOp::create(
+        rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+        ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    args.push_back(gep);
+    argTypes.push_back(gep.getType());
     isUnsigned.push_back(true);
     Type retType;
     if constexpr (isStore) {
@@ -695,10 +693,6 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
     LLVM::CallOp call =
         createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
                                  {}, funcAttr, op.getOperation());
-    if (std::optional<ArrayAttr> optCacheControls =
-            getCacheControlMetadata(rewriter, op)) {
-      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
-    }
     if constexpr (isStore)
       rewriter.eraseOp(op);
     else
@@ -715,10 +709,20 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     if (!op->hasAttr("cache_control"))
       return failure();
-    std::optional<ArrayAttr> optCacheControls =
-        getCacheControlMetadata(rewriter, op);
-    op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
-    op->removeAttr("cache_control");
+    constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
+    Value ptrOp = op.getAddr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep = LLVM::GEPOp::create(
+        rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+        ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    if constexpr (isStore)
+      rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.getValue(), gep);
+    else
+      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), gep);
     return success();
   }
 };
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
index 7e9318ad3c019..ba098aa5fde50 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -68,7 +68,7 @@ class XeVMDialectLLVMIRTranslationInterface
         attrs, std::back_inserter(decorations),
         [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
           auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
-          std::array<llvm::Metadata *, 4> metadata;
+          std::array<llvm::Metadata *, 3> metadata;
           llvm::transform(
               valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
                 return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 7f01526cb0a06..dab735c7df31f 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -15,17 +15,18 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
   // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
   // CHECK-SAME:   "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
   // CHECK-SAME:   will_return} :
   // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
   // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
-  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+  // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi16>
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
       pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -36,8 +37,8 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
 // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
 llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
   // CHECK: xevm.DecorationCacheControl =
-  // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
-  // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
+  // CHECK-SAME: 6442 : i32, 0 : i32, 0 : i32
+  // CHECK-SAME: 6442 : i32, 1 : i32, 0 : i32
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
       pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -56,17 +57,18 @@ llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_heig
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
   // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
   // CHECK-SAME:   "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
   // CHECK-SAME:   will_return}
   // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
   // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
-  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16>
+  // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<16xi16>
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false,
       pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
@@ -85,17 +87,18 @@ llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
   // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
   // CHECK-SAME:   "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
   // CHECK-SAME:   will_return} :
   // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
   // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
-  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
+  // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi32>
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false,
       pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
@@ -114,17 +117,18 @@ llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_hei
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32...
[truncated]

@llvmbot
Copy link
Copy Markdown
Member

llvmbot commented Jan 10, 2026

@llvm/pr-subscribers-mlir-llvm

Author: Sang Ik Lee (silee2)

Changes

Cache control metadata is now always attached to getelementptr op.


Patch is 36.52 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/175274.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+68-64)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp (+1-1)
  • (modified) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+67-54)
  • (modified) mlir/test/Target/LLVMIR/xevm.mlir (+11-29)
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 20a420dfda65c..13ea9ba26d07c 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -99,26 +99,22 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
 static int32_t getL1CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case LoadCacheControl::L1UC_L2UC_L3UC:
-  case LoadCacheControl::L1UC_L2UC_L3C:
-  case LoadCacheControl::L1UC_L2C_L3UC:
-  case LoadCacheControl::L1UC_L2C_L3C:
-    control = 1;
-    break;
   case LoadCacheControl::L1C_L2UC_L3UC:
   case LoadCacheControl::L1C_L2UC_L3C:
   case LoadCacheControl::L1C_L2C_L3UC:
   case LoadCacheControl::L1C_L2C_L3C:
-    control = 2;
+    control = 1;
     break;
   case LoadCacheControl::L1S_L2UC_L3UC:
   case LoadCacheControl::L1S_L2UC_L3C:
   case LoadCacheControl::L1S_L2C_L3UC:
   case LoadCacheControl::L1S_L2C_L3C:
-    control = 3;
+    control = 2;
     break;
   case LoadCacheControl::INVALIDATE_READ:
-    control = 4;
+    control = 3;
+    break;
+  default:
     break;
   }
   return control;
@@ -127,16 +123,15 @@ static int32_t getL1CacheControl(LoadCacheControl cc) {
 static int32_t getL1CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case StoreCacheControl::L1UC_L2UC_L3UC:
-  case StoreCacheControl::L1UC_L2UC_L3WB:
-  case StoreCacheControl::L1UC_L2WB_L3UC:
-  case StoreCacheControl::L1UC_L2WB_L3WB:
-    control = 1;
-    break;
   case StoreCacheControl::L1WT_L2UC_L3UC:
   case StoreCacheControl::L1WT_L2UC_L3WB:
   case StoreCacheControl::L1WT_L2WB_L3UC:
   case StoreCacheControl::L1WT_L2WB_L3WB:
+    control = 1;
+    break;
+  case StoreCacheControl::L1WB_L2UC_L3UC:
+  case StoreCacheControl::L1WB_L2WB_L3UC:
+  case StoreCacheControl::L1WB_L2UC_L3WB:
     control = 2;
     break;
   case StoreCacheControl::L1S_L2UC_L3UC:
@@ -145,10 +140,7 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
   case StoreCacheControl::L1S_L2WB_L3WB:
     control = 3;
     break;
-  case StoreCacheControl::L1WB_L2UC_L3UC:
-  case StoreCacheControl::L1WB_L2WB_L3UC:
-  case StoreCacheControl::L1WB_L2UC_L3WB:
-    control = 4;
+  default:
     break;
   }
   return control;
@@ -157,24 +149,18 @@ static int32_t getL1CacheControl(StoreCacheControl cc) {
 static int32_t getL3CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case LoadCacheControl::L1UC_L2UC_L3UC:
-  case LoadCacheControl::L1UC_L2C_L3UC:
-  case LoadCacheControl::L1C_L2UC_L3UC:
-  case LoadCacheControl::L1C_L2C_L3UC:
-  case LoadCacheControl::L1S_L2UC_L3UC:
-  case LoadCacheControl::L1S_L2C_L3UC:
-    control = 1;
-    break;
   case LoadCacheControl::L1UC_L2UC_L3C:
   case LoadCacheControl::L1UC_L2C_L3C:
   case LoadCacheControl::L1C_L2UC_L3C:
   case LoadCacheControl::L1C_L2C_L3C:
   case LoadCacheControl::L1S_L2UC_L3C:
   case LoadCacheControl::L1S_L2C_L3C:
-    control = 2;
+    control = 1;
     break;
   case LoadCacheControl::INVALIDATE_READ:
-    control = 4;
+    control = 3;
+    break;
+  default:
     break;
   }
   return control;
@@ -183,16 +169,6 @@ static int32_t getL3CacheControl(LoadCacheControl cc) {
 static int32_t getL3CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
   switch (cc) {
-  case StoreCacheControl::L1UC_L2UC_L3UC:
-  case StoreCacheControl::L1UC_L2WB_L3UC:
-  case StoreCacheControl::L1WT_L2UC_L3UC:
-  case StoreCacheControl::L1WT_L2WB_L3UC:
-  case StoreCacheControl::L1S_L2UC_L3UC:
-  case StoreCacheControl::L1S_L2WB_L3UC:
-  case StoreCacheControl::L1WB_L2UC_L3UC:
-  case StoreCacheControl::L1WB_L2WB_L3UC:
-    control = 1;
-    break;
   case StoreCacheControl::L1UC_L2UC_L3WB:
   case StoreCacheControl::L1UC_L2WB_L3WB:
   case StoreCacheControl::L1WT_L2UC_L3WB:
@@ -202,6 +178,8 @@ static int32_t getL3CacheControl(StoreCacheControl cc) {
   case StoreCacheControl::L1WB_L2UC_L3WB:
     control = 2;
     break;
+  default:
+    break;
   }
   return control;
 }
@@ -265,7 +243,7 @@ static std::optional<ArrayAttr>
 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
   if (!getCacheControl(op))
     return {};
-  constexpr int32_t decorationCacheControlArity{4};
+  constexpr int32_t decorationCacheControlArity{3};
   constexpr int32_t loadCacheControlKey{6442};
   constexpr int32_t storeCacheControlKey{6443};
   constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
@@ -275,9 +253,9 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
                           std::is_same_v<OpType, PrefetchOp>;
   const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
-      controlKey, 0, getL1CacheControl<OpType>(op), 0};
+      controlKey, 0, getL1CacheControl<OpType>(op)};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
-      controlKey, 1, getL3CacheControl<OpType>(op), 0};
+      controlKey, 1, getL3CacheControl<OpType>(op)};
   auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
   auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
 
@@ -445,7 +423,16 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
     const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
     Value one =
         LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
-    SmallVector<Value> args{op.getPtr(), one};
+    Value ptrOp = op.getPtr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep = LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(),
+                                          rewriter.getI32Type(), ptrOp,
+                                          ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    SmallVector<Value> args{gep, one};
     SmallVector<Type> argTypes;
     for (auto arg : args)
       argTypes.push_back(arg.getType());
@@ -459,12 +446,9 @@ class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
         /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
     funcAttr.memEffectsAttr = memAttr;
 
-    LLVM::CallOp call = createDeviceFunctionCall(
-        rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
-        argTypes, args, {}, funcAttr, op.getOperation());
-    if (std::optional<ArrayAttr> optCacheControls =
-            getCacheControlMetadata(rewriter, op))
-      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    createDeviceFunctionCall(rewriter, fnName,
+                             LLVM::LLVMVoidType::get(rewriter.getContext()),
+                             argTypes, args, {}, funcAttr, op.getOperation());
     rewriter.eraseOp(op);
     return success();
   }
@@ -548,7 +532,16 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
         rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
     byteCoord = LLVM::InsertElementOp::create(
         rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
-    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
+    Value ptrOp = op.getPtr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep =
+        LLVM::GEPOp::create(rewriter, loc, ptrOp.getType(), i32Type, ptrOp,
+                            ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    SmallVector<Value> args{gep, op.getBaseWidth(), op.getBaseHeight(),
                             op.getBasePitch(), byteCoord};
     SmallVector<Type> retTypes;
     Value spvLoadDstPtr;
@@ -624,10 +617,6 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
     LLVM::CallOp call = createDeviceFunctionCall(
         rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
         argTypes, args, paramAttrs, funcAttr, op.getOperation());
-    if (std::optional<ArrayAttr> optCacheControls =
-            getCacheControlMetadata(rewriter, op)) {
-      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
-    }
     if constexpr (isLoad)
       rewriter.replaceOp(
           op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
@@ -672,8 +661,17 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
     // arg1 - only if store : vector to store
     // Prepare arguments
     SmallVector<Value, 2> args{};
-    args.push_back(op.getPtr());
-    argTypes.push_back(op.getPtr().getType());
+    Value ptrOp = op.getPtr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep = LLVM::GEPOp::create(
+        rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+        ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    args.push_back(gep);
+    argTypes.push_back(gep.getType());
     isUnsigned.push_back(true);
     Type retType;
     if constexpr (isStore) {
@@ -695,10 +693,6 @@ class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
     LLVM::CallOp call =
         createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
                                  {}, funcAttr, op.getOperation());
-    if (std::optional<ArrayAttr> optCacheControls =
-            getCacheControlMetadata(rewriter, op)) {
-      call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
-    }
     if constexpr (isStore)
       rewriter.eraseOp(op);
     else
@@ -715,10 +709,20 @@ class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
                   ConversionPatternRewriter &rewriter) const override {
     if (!op->hasAttr("cache_control"))
       return failure();
-    std::optional<ArrayAttr> optCacheControls =
-        getCacheControlMetadata(rewriter, op);
-    op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
-    op->removeAttr("cache_control");
+    constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
+    Value ptrOp = op.getAddr();
+    // Create getelementptr op to attach cache control metadata
+    // element type doesn't matter here as we use zero index, so use i32
+    LLVM::GEPOp gep = LLVM::GEPOp::create(
+        rewriter, op.getLoc(), ptrOp.getType(), rewriter.getI32Type(), ptrOp,
+        ArrayRef<LLVM::GEPArg>{0});
+    if (std::optional<ArrayAttr> optCacheControls =
+            getCacheControlMetadata(rewriter, op))
+      gep->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    if constexpr (isStore)
+      rewriter.replaceOpWithNewOp<LLVM::StoreOp>(op, op.getValue(), gep);
+    else
+      rewriter.replaceOpWithNewOp<LLVM::LoadOp>(op, op.getType(), gep);
     return success();
   }
 };
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
index 7e9318ad3c019..ba098aa5fde50 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -68,7 +68,7 @@ class XeVMDialectLLVMIRTranslationInterface
         attrs, std::back_inserter(decorations),
         [&ctx, i32Ty](Attribute attr) -> llvm::Metadata * {
           auto valuesArray = dyn_cast<ArrayAttr>(attr).getValue();
-          std::array<llvm::Metadata *, 4> metadata;
+          std::array<llvm::Metadata *, 3> metadata;
           llvm::transform(
               valuesArray, metadata.begin(), [i32Ty](Attribute valueAttr) {
                 return llvm::ConstantAsMetadata::get(llvm::ConstantInt::get(
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index 7f01526cb0a06..dab735c7df31f 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -15,17 +15,18 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
   // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
   // CHECK-SAME:   "_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
   // CHECK-SAME:   will_return} :
   // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
   // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
-  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi16>
+  // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi16>
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
       pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -36,8 +37,8 @@ llvm.func @blockload2d(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32
 // CHECK-LABEL: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x1cPU3AS1viiiDv2_iPt(
 llvm.func @blockload2d_cache_control(%a: !llvm.ptr<1>, %base_width_a: i32, %base_height_a: i32, %base_pitch_a: i32, %x: i32, %y: i32) -> vector<8xi16> {
   // CHECK: xevm.DecorationCacheControl =
-  // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
-  // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
+  // CHECK-SAME: 6442 : i32, 0 : i32, 0 : i32
+  // CHECK-SAME: 6442 : i32, 1 : i32, 0 : i32
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=1 : i32, transpose=false,
       pack_register=false, cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi16>
@@ -56,17 +57,18 @@ llvm.func @blockload2d_v_blocks(%a: !llvm.ptr<1>, %base_width_a: i32, %base_heig
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(16 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i16 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(16 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i16 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
   // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
   // CHECK-SAME:   "_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt", visibility_ = 0 : i64,
   // CHECK-SAME:   will_return}
   // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
   // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
-  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<16xi16>
+  // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<16xi16>
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=8 : i32, v_blocks=2 : i32, transpose=false,
       pack_register=false}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<16xi16>
@@ -85,17 +87,18 @@ llvm.func @blockload2d_pack_register(%a: !llvm.ptr<1>, %base_width_a: i32, %base
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32, vector<2xi32>, ptr)>,
   // CHECK-SAME:   linkage = #llvm.linkage<external>, no_unwind, sym_name =
   // CHECK-SAME:   "_Z52intel_sub_group_2d_block_read_transform_16b_16r16x1cPU3AS1viiiDv2_iPj", visibility_ = 0 : i64,
   // CHECK-SAME:   will_return} :
   // CHECK-SAME: (!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>,
   // CHECK-SAME:  !llvm.ptr {llvm.nonnull, llvm.writeonly}) -> ()
-  // CHECK: %[[VAR7:.*]] = llvm.load %[[VAR6]] : !llvm.ptr -> vector<8xi32>
+  // CHECK: %[[VAR8:.*]] = llvm.load %[[VAR7]] : !llvm.ptr -> vector<8xi32>
   %loaded_a = xevm.blockload2d %a, %base_width_a, %base_height_a, %base_pitch_a, %x, %y
     <{elem_size_in_bits=16 : i32, tile_width=16 : i32, tile_height=16 : i32, v_blocks=1 : i32, transpose=false,
       pack_register=true}> : (!llvm.ptr<1>, i32, i32, i32, i32, i32) -> vector<8xi32>
@@ -114,17 +117,18 @@ llvm.func @blockload2d_transpose(%a: !llvm.ptr<1>, %base_width_a: i32, %base_hei
   // CHECK: %[[VAR2:.*]] = llvm.mlir.constant(1 : i32) : i32
   // CHECK: %[[VAR3:.*]] = llvm.insertelement %[[ARG4]], %[[VAR0]][%[[VAR1]] : i32] : vector<2xi32>
   // CHECK: %[[VAR4:.*]] = llvm.insertelement %[[ARG5]], %[[VAR3]][%[[VAR2]] : i32] : vector<2xi32>
-  // CHECK: %[[VAR5:.*]] = llvm.mlir.constant(8 : i32) : i32
-  // CHECK: %[[VAR6:.*]] = llvm.alloca %[[VAR5]] x i32 : (i32) -> !llvm.ptr
+  // CHECK: %[[VAR5:.*]] = llvm.getelementptr %[[ARG0]][0] : (!llvm.ptr<1>) -> !llvm.ptr<1>, i32
+  // CHECK: %[[VAR6:.*]] = llvm.mlir.constant(8 : i32) : i32
+  // CHECK: %[[VAR7:.*]] = llvm.alloca %[[VAR6]] x i32 : (i32) -> !llvm.ptr
   // CHECK: llvm.call spir_funccc @_Z51intel_sub_group_2d_block_read_transpose_32b_16r8x1cPU3AS1viiiDv2_iPj(
-  // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR6]])
+  // CHECK-SAME: %[[VAR5]], %[[ARG1]], %[[ARG2]], %[[ARG3]], %[[VAR4]], %[[VAR7]])
   // CHECK-SAME: {function_type = !llvm.func<void (ptr<1>, i32, i32, i32...
[truncated]

@charithaintc
Copy link
Copy Markdown
Contributor

Out of curiosity, is there a way to stop the GEP instruction from being eliminated?

@silee2
Copy link
Copy Markdown
Contributor Author

silee2 commented Feb 5, 2026

Out of curiosity, is there a way to stop the GEP instruction from being eliminated?

Not directly, but it can be done by combining two things: generating the GEP instruction as late as possible and disabling folding in the pattern drivers.

@silee2 silee2 requested a review from fabianmcg as a code owner February 5, 2026 21:21
@silee2 silee2 merged commit f481bf1 into llvm:main Feb 18, 2026
10 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants