diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..85b9543f97c27 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }
 
-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
+  auto splatOp = val.getDefiningOp<vector::SplatOp>();
+  if (splatOp)
+    return splatOp.getInput();
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
   return broadcastOp.getSource();
 }
 
@@ -85,14 +96,14 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
      return failure();
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
       return failure();
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
-    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+    if (succeeded(getFullMask(rewriter, maskedOp.getMask()))) {
       Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                  maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
@@ -176,7 +187,11 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
@@ -203,7 +218,16 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers,
+    // 2) knowledge that said field is always set accurately, that is, that
+    //    writes at offsets below num_records will not trap, which is something
+    //    a pattern user would need to assert or we would need to prove.
+    //
+    // Therefore, fully masked stores to buffers are still lowered through this
+    // conditional-store path for now.
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
diff --git a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
index f1d0ad545539a..20084dc72fd01 100644
--- a/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
+++ b/mlir/test/Dialect/AMDGPU/maskedload-to-load.mlir
@@ -167,3 +167,63 @@ func.func @full_mask_maskedstore_to_store(%arg0: memref<8x8xf16>, %arg1: index,
 // CHECK-NOT: vector.maskedstore
 // CHECK: scf.if %[[PRED]]
 // CHECK: vector.store
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_splat
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PRED:.+]]: i1,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_splat(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.splat %arg2 : vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1xi1> to vector<4xi1>
+  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %1 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0] : i1 from vector<1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
+
+// -----
+
+// CHECK-LABEL: func.func @full_select_maskedload_to_load_2d_unit_vector_pred
+// CHECK-SAME: %[[MEM:.+]]: memref<8x8xf16>,
+// CHECK-SAME: %[[IDX:.+]]: index,
+// CHECK-SAME: %[[PREDVEC:.+]]: vector<1x1xi1>,
+// CHECK-SAME: %[[PASSTHRU:.+]]: vector<4xf16>)
+func.func @full_select_maskedload_to_load_2d_unit_vector_pred(%arg0: memref<8x8xf16>, %arg1: index, %arg2: vector<1x1xi1>, %arg3: vector<4xf16>) -> vector<4xf16> {
+  %0 = vector.broadcast %arg2 : vector<1x1xi1> to vector<2x2xi1>
+  %1 = vector.shape_cast %0 : vector<2x2xi1> to vector<4xi1>
+  %2 = vector.maskedload %arg0[%arg1, %arg1], %1, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
+  return %2 : vector<4xf16>
+}
+// CHECK-NOT: vector.maskedload
+// CHECK: %[[PRED:.+]] = vector.extract %[[PREDVEC]][0, 0] : i1 from vector<1x1xi1>
+// CHECK: scf.if %[[PRED]]
+// CHECK: %[[LOAD:.+]] = vector.load
+// CHECK: scf.yield %[[LOAD]]
+// CHECK: else
+// CHECK: scf.yield %[[PASSTHRU]]
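
Note (illustration, not part of the diff): the shape of the IR the new splat test is expected to produce, inferred from the CHECK lines above. Since the memref has no buffer address space, `FullMaskedLoadToConditionalLoad` fires: `getFullMask` reduces the `vector.splat` mask to the scalar `%arg2`, and the masked load becomes a conditional plain load with the passthru on the else branch. Exact SSA value numbering may differ from the actual pass output.

// Input (test case @full_select_maskedload_to_load_splat):
func.func @full_select_maskedload_to_load_splat(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
  %0 = vector.splat %arg2 : vector<4xi1>
  %1 = vector.maskedload %arg0[%arg1, %arg1], %0, %arg3 : memref<8x8xf16>, vector<4xi1>, vector<4xf16> into vector<4xf16>
  return %1 : vector<4xf16>
}

// Approximate output after the patterns in MaskedloadToLoad.cpp run:
func.func @full_select_maskedload_to_load_splat(%arg0: memref<8x8xf16>, %arg1: index, %arg2: i1, %arg3: vector<4xf16>) -> vector<4xf16> {
  %0 = scf.if %arg2 -> (vector<4xf16>) {
    %1 = vector.load %arg0[%arg1, %arg1] : memref<8x8xf16>, vector<4xf16>
    scf.yield %1 : vector<4xf16>
  } else {
    scf.yield %arg3 : vector<4xf16>
  }
  return %0 : vector<4xf16>
}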