From 6a9eb1d7b95cb224f38d2c3f29089d2e3b5640f1 Mon Sep 17 00:00:00 2001
From: Krzysztof Drewniak
Date: Thu, 18 Sep 2025 19:34:26 +0000
Subject: [PATCH] [mlir][AMDGPU] Improve masked_load(..., broadcast(...), ...) handling

Fix the overlap between the full masked load pattern (which creates an
scf.if) and the buffer load handling pattern: the two had no distinct
pattern benefits and relied on their order of addition to the pattern
set for priority, which isn't reliable. The conditional-load pattern
now explicitly bails out when the base is in the buffer address space.

While here, add more cases to the broadcast value recognizer:
- Recognize splat vectors, since this pattern often runs after
  broadcast lowering.
- Recognize broadcasts of unit vectors and reduce them to the scalar
  case by constructing a vector.extract.
- Look through vector.shape_cast ops.
---
 .../AMDGPU/Transforms/MaskedloadToLoad.cpp    | 52 +++++++++++-----
 .../Dialect/AMDGPU/maskedload-to-load.mlir    | 60 +++++++++++++++++++
 2 files changed, 98 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..85b9543f97c27 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }
 
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
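+///
+/// For example (illustrative): given a mask defined by
+///   %m = vector.broadcast %c : vector<1xi1> to vector<4xi1>
+/// the returned value is the i1 scalar produced by a newly created
+/// `vector.extract %c[0]`; broadcasts of multi-element or scalable vectors
+/// are rejected.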
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
+  auto splatOp = val.getDefiningOp<vector::SplatOp>();
+  if (splatOp)
+    return splatOp.getInput();
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
   return broadcastOp.getSource();
 }
 
@@ -85,14 +96,14 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
 
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
       return failure();
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
       return failure();
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
-    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+    if (succeeded(getFullMask(rewriter, maskedOp.getMask()))) {
       Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                  maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
@@ -176,7 +187,11 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
@@ -203,7 +218,16 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field is always set accurately - that is, that
+    //    writes at offsets x < num_records will not trap, which is something
+    //    a pattern user would need to assert or we would need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index f1d0ad545539a..20084dc72fd01 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -167,3 +167,63 @@ func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index,
 // CHECK-NOT: vector.maskedstore
 // CHECK: scf.if %[[PRED]]
 // CHECK: vector.store
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_splat
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_splat(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.splat %arg2 : vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1xi1> to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0] : i1 from vector<1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_2d_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1x1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_2d_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1x1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1x1xi1> to vector<2x2xi1>
+  %1 = vector.shape_cast %0 : vector<2x2xi1> to vector<4xi1>
+  %2 = vector.maskedload %arg0[%arg1, %arg1], %1, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %2 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0, 0] : i1 from vector<1x1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
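
Illustrative sketch (not part of the patch above; the function name and IR
below are hypothetical and only mirror the existing buffer tests): with the
explicit notifyMatchFailure added to FullMaskedLoadToConditionalLoad, a
masked load whose base is a fat_raw_buffer memref and whose mask is a
broadcast scalar is handled by MaskedLoadLowering rather than by the
conditional pattern, so it lowers to an unconditional vector.load followed by
an arith.select on the mask instead of an scf.if, because issuing the
out-of-bounds buffer load is known to be safe.

  func.func @buffer_full_mask(%mem: memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, %idx: index, %cond: i1, %pass: vector<4xf16>) -> vector<4xf16> {
    %mask = vector.broadcast %cond : i1 to vector<4xi1>
    %v = vector.maskedload %mem[%idx, %idx], %mask, %pass : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xi1>, vector<4xf16> into vector<4xf16>
    return %v : vector<4xf16>
  }
  // Expected lowering (sketch), with no scf.if:
  //   %loaded = vector.load %mem[%idx, %idx] : memref<8x8xf16, #amdgpu.address_space<fat_raw_buffer>>, vector<4xf16>
  //   %res = arith.select %mask, %loaded, %pass : vector<4xi1>, vector<4xf16>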