Skip to content

Conversation

silee2
Copy link
Contributor

@silee2 silee2 commented Sep 3, 2025

Add lowering support for LLVM load / store ops with XeVM cache control attributes.

@llvmbot
Copy link
Member

llvmbot commented Sep 3, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-llvm

Author: Sang Ik Lee (silee2)

Changes

Add lowering support for LLVM load / store ops with XeVM cache control attributes.


Full diff: https://github.com/llvm/llvm-project/pull/156768.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h (+3-1)
  • (modified) mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp (+175-108)
  • (modified) mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp (-4)
  • (modified) mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir (+19)
  • (modified) mlir/test/Target/LLVMIR/xevm.mlir (+32)
diff --git a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
index 7ffdbd4307f9e..f591407b602df 100644
--- a/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
+++ b/mlir/include/mlir/Conversion/XeVMToLLVM/XeVMToLLVM.h
@@ -11,6 +11,7 @@
 #include <memory>
 
 namespace mlir {
+class ConversionTarget;
 class DialectRegistry;
 class LLVMTypeConverter;
 class RewritePatternSet;
@@ -19,7 +20,8 @@ class Pass;
 #define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
 #include "mlir/Conversion/Passes.h.inc"
 
-void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
+void populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
+                                          RewritePatternSet &patterns);
 
 void registerConvertXeVMToLLVMInterface(DialectRegistry &registry);
 } // namespace mlir
diff --git a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
index 4dfcb2b43c19c..41f5d6d908c6c 100644
--- a/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
+++ b/mlir/lib/Conversion/XeVMToLLVM/XeVMToLLVM.cpp
@@ -98,127 +98,175 @@ std::string mangle(StringRef baseName, ArrayRef<Type> types,
   return os.str();
 }
 
-template <bool isLoad, typename OpType>
-int32_t getL1CacheControl(OpType op) {
+static int32_t getL1CacheControl(LoadCacheControl cc) {
   int32_t control = 0;
-  if constexpr (isLoad) {
-    switch (*op.getCacheControl()) {
-    case LoadCacheControl::L1UC_L2UC_L3UC:
-    case LoadCacheControl::L1UC_L2UC_L3C:
-    case LoadCacheControl::L1UC_L2C_L3UC:
-    case LoadCacheControl::L1UC_L2C_L3C:
-      control = 1;
-      break;
-    case LoadCacheControl::L1C_L2UC_L3UC:
-    case LoadCacheControl::L1C_L2UC_L3C:
-    case LoadCacheControl::L1C_L2C_L3UC:
-    case LoadCacheControl::L1C_L2C_L3C:
-      control = 2;
-      break;
-    case LoadCacheControl::L1S_L2UC_L3UC:
-    case LoadCacheControl::L1S_L2UC_L3C:
-    case LoadCacheControl::L1S_L2C_L3UC:
-    case LoadCacheControl::L1S_L2C_L3C:
-      control = 3;
-      break;
-    case LoadCacheControl::INVALIDATE_READ:
-      control = 4;
-      break;
-    }
-  } else {
-    switch (*op.getCacheControl()) {
-    case StoreCacheControl::L1UC_L2UC_L3UC:
-    case StoreCacheControl::L1UC_L2UC_L3WB:
-    case StoreCacheControl::L1UC_L2WB_L3UC:
-    case StoreCacheControl::L1UC_L2WB_L3WB:
-      control = 1;
-      break;
-    case StoreCacheControl::L1WT_L2UC_L3UC:
-    case StoreCacheControl::L1WT_L2UC_L3WB:
-    case StoreCacheControl::L1WT_L2WB_L3UC:
-    case StoreCacheControl::L1WT_L2WB_L3WB:
-      control = 2;
-      break;
-    case StoreCacheControl::L1S_L2UC_L3UC:
-    case StoreCacheControl::L1S_L2UC_L3WB:
-    case StoreCacheControl::L1S_L2WB_L3UC:
-    case StoreCacheControl::L1S_L2WB_L3WB:
-      control = 3;
-      break;
-    case StoreCacheControl::L1WB_L2UC_L3UC:
-    case StoreCacheControl::L1WB_L2WB_L3UC:
-    case StoreCacheControl::L1WB_L2UC_L3WB:
-      control = 4;
-      break;
-    }
+  switch (cc) {
+  case LoadCacheControl::L1UC_L2UC_L3UC:
+  case LoadCacheControl::L1UC_L2UC_L3C:
+  case LoadCacheControl::L1UC_L2C_L3UC:
+  case LoadCacheControl::L1UC_L2C_L3C:
+    control = 1;
+    break;
+  case LoadCacheControl::L1C_L2UC_L3UC:
+  case LoadCacheControl::L1C_L2UC_L3C:
+  case LoadCacheControl::L1C_L2C_L3UC:
+  case LoadCacheControl::L1C_L2C_L3C:
+    control = 2;
+    break;
+  case LoadCacheControl::L1S_L2UC_L3UC:
+  case LoadCacheControl::L1S_L2UC_L3C:
+  case LoadCacheControl::L1S_L2C_L3UC:
+  case LoadCacheControl::L1S_L2C_L3C:
+    control = 3;
+    break;
+  case LoadCacheControl::INVALIDATE_READ:
+    control = 4;
+    break;
   }
   return control;
 }
 
-template <bool isLoad, typename OpType>
-int32_t getL3CacheControl(OpType op) {
+static int32_t getL1CacheControl(StoreCacheControl cc) {
   int32_t control = 0;
-  if constexpr (isLoad) {
-    switch (*op.getCacheControl()) {
-    case LoadCacheControl::L1UC_L2UC_L3UC:
-    case LoadCacheControl::L1UC_L2C_L3UC:
-    case LoadCacheControl::L1C_L2UC_L3UC:
-    case LoadCacheControl::L1C_L2C_L3UC:
-    case LoadCacheControl::L1S_L2UC_L3UC:
-    case LoadCacheControl::L1S_L2C_L3UC:
-      control = 1;
-      break;
-    case LoadCacheControl::L1UC_L2UC_L3C:
-    case LoadCacheControl::L1UC_L2C_L3C:
-    case LoadCacheControl::L1C_L2UC_L3C:
-    case LoadCacheControl::L1C_L2C_L3C:
-    case LoadCacheControl::L1S_L2UC_L3C:
-    case LoadCacheControl::L1S_L2C_L3C:
-      control = 2;
-      break;
-    case LoadCacheControl::INVALIDATE_READ:
-      control = 4;
-      break;
-    }
-  } else {
-    switch (*op.getCacheControl()) {
-    case StoreCacheControl::L1UC_L2UC_L3UC:
-    case StoreCacheControl::L1UC_L2WB_L3UC:
-    case StoreCacheControl::L1WT_L2UC_L3UC:
-    case StoreCacheControl::L1WT_L2WB_L3UC:
-    case StoreCacheControl::L1S_L2UC_L3UC:
-    case StoreCacheControl::L1S_L2WB_L3UC:
-    case StoreCacheControl::L1WB_L2UC_L3UC:
-    case StoreCacheControl::L1WB_L2WB_L3UC:
-      control = 1;
-      break;
-    case StoreCacheControl::L1UC_L2UC_L3WB:
-    case StoreCacheControl::L1UC_L2WB_L3WB:
-    case StoreCacheControl::L1WT_L2UC_L3WB:
-    case StoreCacheControl::L1WT_L2WB_L3WB:
-    case StoreCacheControl::L1S_L2UC_L3WB:
-    case StoreCacheControl::L1S_L2WB_L3WB:
-    case StoreCacheControl::L1WB_L2UC_L3WB:
-      control = 2;
-      break;
-    }
+  switch (cc) {
+  case StoreCacheControl::L1UC_L2UC_L3UC:
+  case StoreCacheControl::L1UC_L2UC_L3WB:
+  case StoreCacheControl::L1UC_L2WB_L3UC:
+  case StoreCacheControl::L1UC_L2WB_L3WB:
+    control = 1;
+    break;
+  case StoreCacheControl::L1WT_L2UC_L3UC:
+  case StoreCacheControl::L1WT_L2UC_L3WB:
+  case StoreCacheControl::L1WT_L2WB_L3UC:
+  case StoreCacheControl::L1WT_L2WB_L3WB:
+    control = 2;
+    break;
+  case StoreCacheControl::L1S_L2UC_L3UC:
+  case StoreCacheControl::L1S_L2UC_L3WB:
+  case StoreCacheControl::L1S_L2WB_L3UC:
+  case StoreCacheControl::L1S_L2WB_L3WB:
+    control = 3;
+    break;
+  case StoreCacheControl::L1WB_L2UC_L3UC:
+  case StoreCacheControl::L1WB_L2WB_L3UC:
+  case StoreCacheControl::L1WB_L2UC_L3WB:
+    control = 4;
+    break;
+  }
+  return control;
+}
+
+static int32_t getL3CacheControl(LoadCacheControl cc) {
+  int32_t control = 0;
+  switch (cc) {
+  case LoadCacheControl::L1UC_L2UC_L3UC:
+  case LoadCacheControl::L1UC_L2C_L3UC:
+  case LoadCacheControl::L1C_L2UC_L3UC:
+  case LoadCacheControl::L1C_L2C_L3UC:
+  case LoadCacheControl::L1S_L2UC_L3UC:
+  case LoadCacheControl::L1S_L2C_L3UC:
+    control = 1;
+    break;
+  case LoadCacheControl::L1UC_L2UC_L3C:
+  case LoadCacheControl::L1UC_L2C_L3C:
+  case LoadCacheControl::L1C_L2UC_L3C:
+  case LoadCacheControl::L1C_L2C_L3C:
+  case LoadCacheControl::L1S_L2UC_L3C:
+  case LoadCacheControl::L1S_L2C_L3C:
+    control = 2;
+    break;
+  case LoadCacheControl::INVALIDATE_READ:
+    control = 4;
+    break;
+  }
+  return control;
+}
+
+static int32_t getL3CacheControl(StoreCacheControl cc) {
+  int32_t control = 0;
+  switch (cc) {
+  case StoreCacheControl::L1UC_L2UC_L3UC:
+  case StoreCacheControl::L1UC_L2WB_L3UC:
+  case StoreCacheControl::L1WT_L2UC_L3UC:
+  case StoreCacheControl::L1WT_L2WB_L3UC:
+  case StoreCacheControl::L1S_L2UC_L3UC:
+  case StoreCacheControl::L1S_L2WB_L3UC:
+  case StoreCacheControl::L1WB_L2UC_L3UC:
+  case StoreCacheControl::L1WB_L2WB_L3UC:
+    control = 1;
+    break;
+  case StoreCacheControl::L1UC_L2UC_L3WB:
+  case StoreCacheControl::L1UC_L2WB_L3WB:
+  case StoreCacheControl::L1WT_L2UC_L3WB:
+  case StoreCacheControl::L1WT_L2WB_L3WB:
+  case StoreCacheControl::L1S_L2UC_L3WB:
+  case StoreCacheControl::L1S_L2WB_L3WB:
+  case StoreCacheControl::L1WB_L2UC_L3WB:
+    control = 2;
+    break;
   }
   return control;
 }
 
+static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
+  return op.getCacheControl();
+}
+
+static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
+  return op.getCacheControl();
+}
+
+static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
+  return op.getCacheControl();
+}
+
+static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
+  return op.getCacheControl();
+}
+
+static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
+  if (op->hasAttr("cache_control")) {
+    auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
+    if (!attr)
+      return std::nullopt;
+    return std::optional<LoadCacheControl>(attr.getValue());
+  }
+  return std::nullopt;
+}
+
+static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
+  if (op->hasAttr("cache_control")) {
+    auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
+    if (!attr)
+      return std::nullopt;
+    return std::optional<StoreCacheControl>(attr.getValue());
+  }
+  return std::nullopt;
+}
+
+template <typename OpType>
+int32_t getL1CacheControl(OpType op) {
+  return getL1CacheControl(*getCacheControl(op));
+}
+
+template <typename OpType>
+int32_t getL3CacheControl(OpType op) {
+  return getL3CacheControl(*getCacheControl(op));
+}
+
 template <bool isLoad, typename OpType>
 static std::optional<ArrayAttr>
 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
-  if (!op.getCacheControl())
+  if (!getCacheControl(op))
     return {};
   constexpr int32_t decorationCacheControlArity{4};
   constexpr int32_t loadCacheControlKey{6442};
   constexpr int32_t storeCacheControlKey{6443};
   const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
-      controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
+      controlKey, 0, getL1CacheControl<OpType>(op), 0};
   SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
-      controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
+      controlKey, 1, getL3CacheControl<OpType>(op), 0};
   auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
   auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
 
@@ -568,6 +616,22 @@ class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
     return success();
   }
 };
+template <typename OpType>
+class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
+  using OpConversionPattern<OpType>::OpConversionPattern;
+  LogicalResult
+  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (!op->hasAttr("cache_control"))
+      return failure();
+    constexpr bool isLoad = std::is_same_v<OpType, LLVM::LoadOp>;
+    std::optional<ArrayAttr> optCacheControls =
+        getCacheControlMetadata<isLoad>(rewriter, op);
+    op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
+    op->removeAttr("cache_control");
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // Pass Definition
@@ -583,10 +647,8 @@ struct ConvertXeVMToLLVMPass
 
   void runOnOperation() override {
     ConversionTarget target(getContext());
-    target.addLegalDialect<LLVM::LLVMDialect>();
-    target.addIllegalDialect<XeVMDialect>();
     RewritePatternSet patterns(&getContext());
-    populateXeVMToLLVMConversionPatterns(patterns);
+    populateXeVMToLLVMConversionPatterns(target, patterns);
     if (failed(applyPartialConversion(getOperation(), target,
                                       std::move(patterns))))
       signalPassFailure();
@@ -611,7 +673,7 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
   void populateConvertToLLVMConversionPatterns(
       ConversionTarget &target, LLVMTypeConverter &typeConverter,
       RewritePatternSet &patterns) const final {
-    populateXeVMToLLVMConversionPatterns(patterns);
+    populateXeVMToLLVMConversionPatterns(target, patterns);
   }
 };
 } // namespace
@@ -620,12 +682,17 @@ struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 // Pattern Population
 //===----------------------------------------------------------------------===//
 
-void ::mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
+void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
+                                                  RewritePatternSet &patterns) {
+  target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
+      [](Operation *op) { return !op->hasAttr("cache_control"); });
+  target.addIllegalDialect<XeVMDialect>();
   patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
                LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
                LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
-               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
-      patterns.getContext());
+               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
+               LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
+               LLVMLoadStoreToOCLPattern<LLVM::StoreOp>>(patterns.getContext());
 }
 
 void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
diff --git a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
index 73b166d045d5b..7e9318ad3c019 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/XeVM/XeVMToLLVMIRTranslation.cpp
@@ -55,10 +55,6 @@ class XeVMDialectLLVMIRTranslationInterface
       return handleDecorationCacheControl(instructions.front(),
                                           cacheControlsArray.getValue());
     }
-    auto func = dyn_cast<LLVM::LLVMFuncOp>(op);
-    if (!func)
-      return failure();
-
     return success();
   }
 
diff --git a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
index bdbb12bbe0cbb..8f60a0797652b 100644
--- a/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
+++ b/mlir/test/Conversion/XeVMToLLVM/xevm-to-llvm.mlir
@@ -242,3 +242,22 @@ llvm.func @prefetch(%ptr: !llvm.ptr<1>) {
   llvm.return
 }
 
+// -----
+// CHECK-LABEL: llvm.func @llvm.load
+llvm.func @llvm.load(%a: !llvm.ptr<1>) -> i32 {
+  // CHECK: xevm.DecorationCacheControl =
+  // CHECK-SAME: 6442 : i32, 0 : i32, 1 : i32, 0 : i32
+  // CHECK-SAME: 6442 : i32, 1 : i32, 1 : i32, 0 : i32
+  %val = llvm.load %a {cache_control=#xevm.load_cache_control<L1uc_L2uc_L3uc>} : !llvm.ptr<1> -> i32
+  llvm.return %val : i32
+}
+
+// -----
+// CHECK-LABEL: llvm.func @llvm.store
+llvm.func @llvm.store(%a: !llvm.ptr<1>, %val: i32) {
+  // CHECK: xevm.DecorationCacheControl =
+  // CHECK-SAME: 6443 : i32, 0 : i32, 2 : i32, 0 : i32
+  // CHECK-SAME: 6443 : i32, 1 : i32, 2 : i32, 0 : i32
+  llvm.store %val, %a {cache_control=#xevm.store_cache_control<L1wt_L2uc_L3wb>} : i32, !llvm.ptr<1>
+  llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/xevm.mlir b/mlir/test/Target/LLVMIR/xevm.mlir
index a3dd0b6c17914..112d923607060 100644
--- a/mlir/test/Target/LLVMIR/xevm.mlir
+++ b/mlir/test/Target/LLVMIR/xevm.mlir
@@ -19,3 +19,35 @@ module {
 // CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
 // CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
 
+// -----
+module {
+  // CHECK-LABEL: define i32 @load(ptr addrspace(1)
+  // CHECK-SAME: %[[ARG0:.*]]) {
+  llvm.func @load(%arg0: !llvm.ptr<1>) -> i32 {
+    // CHECK: load i32, ptr addrspace(1) %[[ARG0]], align 4,
+    // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+    %0 = llvm.load %arg0 {xevm.DecorationCacheControl = [[6442 : i32, 0 : i32, 1 : i32, 0 : i32], [6442 : i32, 1 : i32, 1 : i32, 0 : i32]]} : !llvm.ptr<1> -> i32
+    llvm.return %0 : i32
+  }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6442, i32 0, i32 1, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6442, i32 1, i32 1, i32 0}
+
+// -----
+module {
+  // CHECK-LABEL: define void @store(ptr addrspace(1)
+  // CHECK-SAME: %[[ARG0:.*]], i32 %[[ARG1:.*]]) {
+  llvm.func @store(%arg0: !llvm.ptr<1>, %arg1: i32) {
+    // CHECK: store i32 %[[ARG1]], ptr addrspace(1) %[[ARG0]], align 4,
+    // CHECK-SAME: !spirv.DecorationCacheControlINTEL ![[DECO1:.*]]
+    llvm.store %arg1, %arg0 {xevm.DecorationCacheControl = [[6443 : i32, 0 : i32, 2 : i32, 0 : i32], [6443 : i32, 1 : i32, 2 : i32, 0 : i32]]} : i32, !llvm.ptr<1>
+    llvm.return
+  }
+}
+
+// CHECK: ![[DECO1]] = !{![[DECO2:.*]], ![[DECO3:.*]]}
+// CHECK: ![[DECO2]] = !{i32 6443, i32 0, i32 2, i32 0}
+// CHECK: ![[DECO3]] = !{i32 6443, i32 1, i32 2, i32 0}
+

@silee2
Copy link
Contributor Author

silee2 commented Sep 3, 2025

FYI, @akroviakov

@silee2 silee2 changed the title [MLIR][XeVM] Add lower llvm load store cache control [MLIR][XeVM] Add lowering for llvm load store ops with XeVM cache control Sep 3, 2025
Copy link
Contributor

@akroviakov akroviakov left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically, we say that llvm dialect's nontemporal and volatile are not enough to cover our load/store caching options, so at ***-to-xevm we have inject a custom attribute (which is the same as we have for xevm block ops) to llvm's load/store? Looks fine to me.

ConversionPatternRewriter &rewriter) const override {
if (!op->hasAttr("cache_control"))
return failure();
constexpr bool isLoad = std::is_same_v<OpType, LLVM::LoadOp>;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can it be pushed into getCacheControlMetadata() and get rid of the template argument there? Also applies to LoadStorePrefetchToOCLPattern

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

LogicalResult
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
if (!op->hasAttr("cache_control"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about defining this attr name similarly to XeVMDialect::getCacheControlsAttrName()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea, and I also thought about it but that would touch dialect files.
Will do in a separate PR to keep this one small.

return control;
}

static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we collapse these functions into a type-restricted templated version?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tried but hot sure how to do it. The case is a bit tricky because return type depends on input type, but not something that can be inferred by the compiler or easily expressed with declspec.
If there is a clever way, will do in a separate PR alone with attr name mentioned above.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, I meant it at least for cases that have a matching return type, a separate PR is fine.

@silee2 silee2 requested a review from mshahneo September 8, 2025 17:29
Copy link
Contributor

@Jianhui-Li Jianhui-Li left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@silee2 silee2 merged commit 3b6bd49 into llvm:main Sep 9, 2025
9 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants