Skip to content

Conversation

krzysz00
Copy link
Contributor

  1. Fix the fact that the full masked load pattern (which creates an if statement) could overlap with the buffer load handling pattern, since they didn't have distinct pattern benefits and were relying on order of addition to the pattern set for priority (which isn't reliable).

While I was here, add more cases to the broadcast value recognizer

  • Since this pattern often runs after broadcast lowering, recognize splat vectors.
  • Recognize broadcasts of unit vectors and convert them to the scalar case by constructing an extract()
  • Look through shape_cast ops

1. Fix the fact that the full masked load pattern (which creates an if
statement) could overlap with the buffer load handling pattern, since
they didn't have distinct pattern benefits and were relying on order
of addition to the pattern set for priority (which isn't reliable).

While I was here, add more cases to the broadcast value recognizer
- Since this pattern often runs after broadcast lowering, recognize
splat vectors.
- Recognize broadcasts of unit vectors and convert them to the scalar
case by constructing an extract()
- Look through shape_cast ops
@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2025

@llvm/pr-subscribers-mlir-gpu
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-amdgpu

@llvm/pr-subscribers-backend-amdgpu

Author: Krzysztof Drewniak (krzysz00)

Changes
  1. Fix the fact that the full masked load pattern (which creates an if statement) could overlap with the buffer load handling pattern, since they didn't have distinct pattern benefits and were relying on order of addition to the pattern set for priority (which isn't reliable).

While I was here, add more cases to the broadcast value recognizer

  • Since this pattern often runs after broadcast lowering, recognize splat vectors.
  • Recognize broadcasts of unit vectors and convert them to the scalar case by constructing an extract()
  • Look through shape_cast ops

Full diff: https://github.com/llvm/llvm-project/pull/159635.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp (+38-14)
  • (modified) mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir (+60)
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..85b9543f97c27 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }
 
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
+  auto splatOp = val.getDefiningOp<vector::SplatOp>();
+  if (splatOp)
+    return splatOp.getInput();
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
   return broadcastOp.getSource();
 }
 
@@ -85,14 +96,14 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
       return failure();
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
       return failure();
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
-    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+    if (succeeded(getFullMask(rewriter, maskedOp.getMask()))) {
       Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                  maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
@@ -176,7 +187,11 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
@@ -203,7 +218,16 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    // that writes to x < num_records of offset wouldn't trap, which is
+    // something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index f1d0ad545539a..20084dc72fd01 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -167,3 +167,63 @@ func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index,
 // CHECK-NOT: vector.maskedstore
 // CHECK: scf.if %[[PRED]]
 // CHECK:   vector.store
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_splat
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_splat(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.splat %arg2 : vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK:   %[[LOAD:.+]] = vector.load
+// CHECK:   scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK:   scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1xi1> to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0] : i1 from vector<1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK:   %[[LOAD:.+]] = vector.load
+// CHECK:   scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK:   scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_2d_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1x1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_2d_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1x1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1x1xi1> to vector<2x2xi1>
+  %1 = vector.shape_cast %0 : vector<2x2xi1> to vector<4xi1>
+  %2 = vector.maskedload %arg0[%arg1, %arg1], %1, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %2 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0, 0] : i1 from vector<1x1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK:   %[[LOAD:.+]] = vector.load
+// CHECK:   scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK:   scf.yield %[[PASSTHRU]]

Comment on lines +69 to +71
auto splatOp = val.getDefiningOp<vector::SplatOp>();
if (splatOp)
return splatOp.getInput();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vector::SplatOp is deprecated please do not add support for it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What deprecation? There are parents that canonicalize broadcasts to splats still, no?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, check the documentation for vector.splat https://mlir.llvm.org/docs/Dialects/Vector/#vectorsplat-vectorsplatop

Comment on lines +75 to +82
if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
return failure();
SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
Value scalarSource = vector::ExtractOp::create(
rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
return scalarSource;
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we should do this. Every pattern cannot handle the world. This pattern is supposed to be ran after unrolling/flattening.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But after unrolling or flattening that's aren't any broadcasts?

We might need to move broadcast lowering much, much later in out downstream pipeline if you want to keep this as is?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unrolling/Flattening preserve broadcasts from scalar -> vector. If they are not, we should fix it.

Comment on lines 38 to +47
if (!memRefType)
return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
return failure();

Attribute addrSpace = memRefType.getMemorySpace();
if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
return rewriter.notifyMatchFailure(maskedOp, "no address space");
return failure();

if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
amdgpu::AddressSpace::FatRawBuffer)
return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
return failure();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why remove the error messages?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now the actual pattern doesn't have failure remarks.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Refactoring so I get a shared buffer address space utility

I can put an error message in the buffer loads pattern though

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good

@jerryyin
Copy link
Member

Having read through the existing review comments it looks like after addressing all comments, what's left in this PR adds buffer address space check. Should probably update title/PR description to that. I've no further comments beyond what Kunwar already has, and will let Kunwar decide for approval.

@krzysz00
Copy link
Contributor Author

On further inspection of vector broadcast lowering, it doesn't do splats anymore, and agreed that we probably don't want to handle the general case - or pierce shape casts, which'll probably be gone by the point we're running this. So closing in favor of a simpler fix

@krzysz00 krzysz00 closed this Sep 19, 2025
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants