[mlir][AMDGPU] Improve masked_load(..., broadcast(...), ...) handling #159635
Changes from all commits
```diff
@@ -33,19 +33,18 @@ using namespace mlir::amdgpu;

 /// This pattern supports lowering of: `vector.maskedload` to `vector.load`
 /// and `arith.select` if the memref is in buffer address space.
-static LogicalResult baseInBufferAddrSpace(PatternRewriter &rewriter,
-                                           vector::MaskedLoadOp maskedOp) {
-  auto memRefType = dyn_cast<MemRefType>(maskedOp.getBase().getType());
+static LogicalResult hasBufferAddressSpace(Type type) {
+  auto memRefType = dyn_cast<MemRefType>(type);
   if (!memRefType)
-    return rewriter.notifyMatchFailure(maskedOp, "not a memref source");
+    return failure();

   Attribute addrSpace = memRefType.getMemorySpace();
   if (!isa_and_nonnull<amdgpu::AddressSpaceAttr>(addrSpace))
-    return rewriter.notifyMatchFailure(maskedOp, "no address space");
+    return failure();

   if (dyn_cast<amdgpu::AddressSpaceAttr>(addrSpace).getValue() !=
       amdgpu::AddressSpace::FatRawBuffer)
-    return rewriter.notifyMatchFailure(maskedOp, "not in buffer address space");
+    return failure();

   return success();
 }
```
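Since the refactored helper returns a bare `failure()`, any diagnostics now have to live with the patterns that call it. As a sketch of what the review discussion at the bottom of this page converges on, a hypothetical caller-side wrapper (`checkBufferBase` is not part of this patch) could reattach a remark:

```cpp
// Hypothetical, not in the patch: reuse the type-only check while keeping a
// match-failure remark for pattern-application debugging.
static LogicalResult checkBufferBase(PatternRewriter &rewriter,
                                     vector::MaskedLoadOp maskedOp) {
  if (failed(hasBufferAddressSpace(maskedOp.getBase().getType())))
    return rewriter.notifyMatchFailure(
        maskedOp, "base is not a memref in the fat raw buffer address space");
  return success();
}
```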
```diff
@@ -62,13 +61,25 @@ static Value createVectorLoadForMaskedLoad(OpBuilder &builder, Location loc,
   return load;
 }

-/// Check if the given value comes from a broadcasted i1 condition.
-static FailureOr<Value> matchFullMask(OpBuilder &b, Value val) {
+/// If the given value is the broadcast of a non-constant scalar, return that
+/// scalar, extracting it from length-1 vectors if necessary.
+static FailureOr<Value> getFullMask(RewriterBase &rw, Value val) {
+  while (auto shapeCast = val.getDefiningOp<vector::ShapeCastOp>())
+    val = shapeCast.getSource();
+  auto splatOp = val.getDefiningOp<vector::SplatOp>();
+  if (splatOp)
+    return splatOp.getInput();
```
Review thread on lines +69 to +71 (the `vector::SplatOp` handling):

- Reviewer: `vector::SplatOp` is deprecated; please do not add support for it.
- Author: What deprecation? There are patterns that canonicalize broadcasts to splats still, no?
- Reviewer: No, check the documentation for `vector.splat`: https://mlir.llvm.org/docs/Dialects/Vector/#vectorsplat-vectorsplatop
```diff
   auto broadcastOp = val.getDefiningOp<vector::BroadcastOp>();
   if (!broadcastOp)
     return failure();
-  if (isa<VectorType>(broadcastOp.getSourceType()))
-    return failure();
+  if (auto sourceVecType = dyn_cast<VectorType>(broadcastOp.getSourceType())) {
+    if (sourceVecType.isScalable() || sourceVecType.getNumElements() != 1)
+      return failure();
+    SmallVector<int64_t> indices(sourceVecType.getRank(), 0);
+    Value scalarSource = vector::ExtractOp::create(
+        rw, broadcastOp.getLoc(), broadcastOp.getSource(), indices);
+    return scalarSource;
+  }
```
Review thread on lines +75 to +82 (extracting the scalar from length-1 vector broadcasts):

- Reviewer: I don't think we should do this. Every pattern cannot handle the world. This pattern is supposed to be run after unrolling/flattening.
- Author: But after unrolling or flattening there aren't any broadcasts? We might need to move broadcast lowering much, much later in our downstream pipeline if you want to keep this as-is.
- Reviewer: Unrolling/flattening preserve broadcasts from scalar -> vector. If they do not, we should fix it.
```diff
   return broadcastOp.getSource();
 }
```
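To make the helper concrete, here is a sketch of the def-use chains it now sees through; the value names and vector shapes below are made up for illustration:

```cpp
// Illustrative only (hypothetical shapes). Given:
//   %s    = ...                  : i1                    // scalar condition
//   %b    = vector.broadcast %s  : i1 to vector<2x4xi1>
//   %mask = vector.shape_cast %b : vector<2x4xi1> to vector<8xi1>
// getFullMask peels the shape_cast, finds the scalar broadcast, and returns
// %s. If the broadcast source is instead a non-scalable length-1 vector
// (e.g. vector<1xi1>), the helper materializes a vector.extract of element 0
// and returns that scalar.
FailureOr<Value> cond = getFullMask(rewriter, maskedOp.getMask());
if (succeeded(cond)) {
  // *cond is an i1 equal to every lane of the mask.
}
```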
```diff
@@ -85,14 +96,14 @@ struct MaskedLoadLowering final : OpRewritePattern<vector::MaskedLoadOp> {
     if (maskedOp->hasAttr(kMaskedloadNeedsMask))
       return failure();

-    if (failed(baseInBufferAddrSpace(rewriter, maskedOp))) {
+    if (failed(hasBufferAddressSpace(maskedOp.getBase().getType()))) {
       return failure();
     }

     // Check if this is either a full inbounds load or an empty, oob load. If
     // so, take the fast path and don't generate an if condition, because we
     // know doing the oob load is always safe.
-    if (succeeded(matchFullMask(rewriter, maskedOp.getMask()))) {
+    if (succeeded(getFullMask(rewriter, maskedOp.getMask()))) {
       Value load = createVectorLoadForMaskedLoad(rewriter, maskedOp.getLoc(),
                                                  maskedOp, /*passthru=*/true);
       rewriter.replaceOp(maskedOp, load);
```
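`createVectorLoadForMaskedLoad` is defined above this hunk. Based on the file's doc comment (lowering `vector.maskedload` to `vector.load` and `arith.select`), a plausible sketch of the fast path it emits looks like the following; this is an approximation, not the exact implementation:

```cpp
// Sketch: unconditional load from the buffer, then select the passthrough
// lanes back in. This is safe on fat raw buffers because out-of-bounds
// reads do not trap.
static Value fastPathLoadSketch(OpBuilder &b, Location loc,
                                vector::MaskedLoadOp maskedOp, bool passthru) {
  Value load =
      vector::LoadOp::create(b, loc, maskedOp.getVectorType(),
                             maskedOp.getBase(), maskedOp.getIndices());
  if (passthru)
    load = arith::SelectOp::create(b, loc, maskedOp.getMask(), load,
                                   maskedOp.getPassThru());
  return load;
}
```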
```diff
@@ -176,7 +187,11 @@ struct FullMaskedLoadToConditionalLoad
   LogicalResult matchAndRewrite(vector::MaskedLoadOp loadOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, loadOp.getMask());
+    if (succeeded(hasBufferAddressSpace(loadOp.getBase().getType())))
+      return rewriter.notifyMatchFailure(
+          loadOp, "buffer loads are handled by a more specialized pattern");
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, loadOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
```
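The rest of this pattern falls outside the hunk; presumably it guards the load with the extracted scalar condition. A minimal sketch under that assumption, using `scf.if` to yield either the loaded vector or the passthrough:

```cpp
// Sketch (assumed continuation, not shown in the diff).
Value cond = *maybeCond;
auto ifOp = scf::IfOp::create(
    rewriter, loadOp.getLoc(), cond,
    /*thenBuilder=*/
    [&](OpBuilder &b, Location loc) {
      Value load =
          createVectorLoadForMaskedLoad(b, loc, loadOp, /*passthru=*/false);
      scf::YieldOp::create(b, loc, load);
    },
    /*elseBuilder=*/
    [&](OpBuilder &b, Location loc) {
      scf::YieldOp::create(b, loc, loadOp.getPassThru());
    });
rewriter.replaceOp(loadOp, ifOp.getResults());
```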
```diff
@@ -203,7 +218,16 @@ struct FullMaskedStoreToConditionalStore
   LogicalResult matchAndRewrite(vector::MaskedStoreOp storeOp,
                                 PatternRewriter &rewriter) const override {
-    FailureOr<Value> maybeCond = matchFullMask(rewriter, storeOp.getMask());
+    // A condition-free implementation of fully masked stores requires
+    // 1) an accessor for the num_records field on buffer resources/fat pointers
+    // 2) knowledge that said field will always be set accurately - that is,
+    //    that writes to x < num_records of offset wouldn't trap, which is
+    //    something a pattern user would need to assert or we'd need to prove.
+    //
+    // Therefore, conditional stores to buffers still go down this path at
+    // present.
+
+    FailureOr<Value> maybeCond = getFullMask(rewriter, storeOp.getMask());
     if (failed(maybeCond)) {
       return failure();
     }
```
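Given the comment above, the continuation (outside the hunk) presumably still wraps the store in a conditional rather than storing unconditionally. A sketch under that assumption:

```cpp
// Sketch (assumed continuation): only execute the store when the scalar
// condition holds; there are no results, so no else region is needed.
Value cond = *maybeCond;
scf::IfOp::create(rewriter, storeOp.getLoc(), cond,
                  /*thenBuilder=*/[&](OpBuilder &b, Location loc) {
                    vector::StoreOp::create(b, loc, storeOp.getValueToStore(),
                                            storeOp.getBase(),
                                            storeOp.getIndices());
                    scf::YieldOp::create(b, loc);
                  });
rewriter.eraseOp(storeOp);
```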
Review thread on the removed diagnostics:

- Reviewer: Why remove the error messages? Now the actual pattern doesn't have failure remarks.
- Author: Refactoring so I get a shared buffer address space utility. I can put an error message in the buffer loads pattern, though.
- Reviewer: Sounds good.