diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
index f15c63c166e0a..89ef51f922cad 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/MaskedloadToLoad.cpp
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;
 
 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();
 
   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();
 
   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();
 
   return success();
 }
@@ -83,10 +82,11 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
   LogicalResult matchAndRewrite(vector::MaskedLoadOp maskedOp,
                                 PatternRewriter &rewriter) const override {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
-      return failure();
+      return rewriter.notifyMatchFailure(maskedOp, "already rewritten");
 
-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
-      return failure();
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
+      return rewriter.notifyMatchFailure(
+          maskedOp, "isn't a load from a fat buffer resource");
     }
 
     // Check if this is either a full inbounds load or an empty, oob load. If
@@ -176,9 +176,14 @@ struct FullMaskedLoadToConditionalLoad
 
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
-      return failure();
+      return rewriter.notifyMatchFailure(loadOp,
+                                         "isn't loading a broadcasted scalar");
     }
 
     Value cond = maybeCond.value();
@@ -203,6 +208,15 @@ struct FullMaskedStoreToConditionalStore
 
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    // that writes to x < num_records of offset wouldn't trap, which is
+    // something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
     FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
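
Note (not part of the patch): a minimal sketch of how the patterns touched here might be driven standalone, e.g. from a small test pass. The entry-point name populateAmdgpuMaskedloadToLoadPatterns, its (RewritePatternSet &, PatternBenefit) signature, and the applyPatternsGreedily driver are assumptions about this tree; substitute whatever this file actually exports.

```cpp
// Sketch only: populateAmdgpuMaskedloadToLoadPatterns and
// applyPatternsGreedily are assumed to exist under these names here.
#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Applies MaskedLoadLowering, FullMaskedLoadToConditionalLoad, and
// FullMaskedStoreToConditionalStore to all masked transfers under `root`.
static mlir::LogicalResult lowerMaskedTransfers(mlir::Operation *root) {
  mlir::RewritePatternSet patterns(root->getContext());
  mlir::amdgpu::populateAmdgpuMaskedloadToLoadPatterns(patterns,
                                                       /*benefit=*/1);
  // With the bail-out added in this patch, loads from fat-buffer memrefs are
  // claimed by MaskedLoadLowering and no longer overlap with the
  // conditional-load pattern.
  return mlir::applyPatternsGreedily(root, std::move(patterns));
}
```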