Add flash attention forward op #11815
Conversation
Drive-by: including specific algorithms from papers for specific models is a slippery slope - is there a plan to remove this once it can be represented more generally? e.g. what's missing in linalg/linalg_ext from being able to model this without having a concrete op defined and how is that going to be approached? A region op that allows for declarative tiling or something? (don't expect an answer, but want to plant that seed and ensure someone is thinking about it this year) Conceptually this is something near the frontend above hlo/tosa/etc so having it plumbed deep feels dirty. General things like scans, sorts, and even winograd have wider applicability but this seems transformer specific. I support the land/iterate/use-it-to-figure-out-how-to-build-the-general-mechanism flow given how beneficial this can be and just want to make sure that's something that's being thought about before we end up with 100 of these each used by one or two models that we need to support for all of time - that should be the frontend's job :)
Great points @benvanik! This definitely falls under the bucket of the land/iterate/use-it-to-figure-out-how-to-build-the-general-mechanism flow. @MaheshRavishankar had some ideas of how we can create a general interface for these kinds of ops and even move these ops completely out of IREE (something about IREE not even knowing about them, but @MaheshRavishankar can clarify). We were just waiting to flesh this out to get a better sense of what that interface would look like.
Awesome, that sounds great!
Yeah, this is meant to be something closer to hlo/tosa/etc. (so it would go into the user-defined preprocessing flow that we've discussed briefly). The only thing this op needs is a couple of interfaces so that IREE would not actually know what this op is, but just use the interface. One of them is the
Thank you for the pragmatic
yep! and linalg_ext is a great place to experiment with these things!
Cool, very exciting stuff! I'd love to see the lowering to Linalg ops be a "functional style" transformation that produces k ops (i.e. the 4-6 softmax ops + the 2 gemm ops); I can make direct use of that with a generic transform dialect ExpandOp and immediately apply some strategy to map to threads. This would save the pain of the initial 8-op matcher and get us right into the block-level transform. @ftynse also to see that part of the design / simplification space.
Note, you could also do that lowering through a softmax op and have the softmax lower to the 4-6 ops version and let things compose. |
Yes! Yes! (a thousand times yes!!) |
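To make the "4-6 ops version" concrete, here is a minimal NumPy sketch (illustrative only, not the patch's MLIR) of softmax decomposed into the chain of reduction/elementwise ops a lowering would emit:

```python
import numpy as np

def softmax_decomposed(x):
    # Row-wise softmax as a small chain of reduction/elementwise ops:
    m = np.max(x, axis=-1, keepdims=True)  # 1. reduce-max (numerical stability)
    s = x - m                              # 2. subtract the max (broadcast)
    e = np.exp(s)                          # 3. exponentiate elementwise
    z = np.sum(e, axis=-1, keepdims=True)  # 4. reduce-sum
    return e / z                           # 5. divide (broadcast)

x = np.linspace(-2.0, 2.0, 12).reshape(3, 4)
y = softmax_decomposed(x)
assert np.allclose(y.sum(axis=-1), 1.0)   # each row is a probability distribution
```

Whether this counts as 4 or 6 ops depends on how the broadcasting subtract/divide steps are materialized as separate generics, hence the "4-6" range in the discussion.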
I just uploaded a patch to add the lowering. @nicolasvasilache - not sure if this is a functional lowering but currently this op lowers to the following.
This patch adds the flash attention forward op to LinalgExt. For tiling, we expose the two outer dimensions corresponding to batch and sequence length.
This patch adds a pass to lower the flash attention forward op to a tiled version of 2 matmuls + softmax that can then be run on a particular backend.
Also fix bugs and update corresponding unit tests.
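For reference, the computation the tiled lowering targets can be sketched in NumPy (names, shapes, and the tile size are illustrative, not taken from the patch): two matmuls with the softmax folded in via running row-max and row-sum statistics, so no full score row is ever materialized:

```python
import numpy as np

def attention_reference(q, k, v):
    # Untiled reference: softmax(q @ k^T) @ v.
    s = q @ k.T
    p = np.exp(s - s.max(axis=-1, keepdims=True))
    return (p / p.sum(axis=-1, keepdims=True)) @ v

def flash_attention_fwd(q, k, v, tile=4):
    # One pass over tiles of keys/values, keeping per-row statistics
    # (running max m, running sum-of-exponentials l) so the softmax
    # never needs the full score row at once.
    n, d = q.shape
    o = np.zeros((n, d))
    m = np.full((n, 1), -np.inf)
    l = np.zeros((n, 1))
    for j in range(0, k.shape[0], tile):
        kj, vj = k[j:j + tile], v[j:j + tile]
        s = q @ kj.T                                   # first matmul (tiled)
        m_new = np.maximum(m, s.max(axis=-1, keepdims=True))
        alpha = np.exp(m - m_new)                      # rescale old statistics
        p = np.exp(s - m_new)                          # partial softmax numerator
        l = alpha * l + p.sum(axis=-1, keepdims=True)
        o = alpha * o + p @ vj                         # second matmul, accumulated
        m = m_new
    return o / l

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((8, 4)) for _ in range(3))
assert np.allclose(flash_attention_fwd(q, k, v, tile=3), attention_reference(q, k, v))
```

The two loop dimensions exposed for tiling in the patch (batch and sequence length) would sit outside this per-row computation.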
What Nicolas means by "functional-style" is expressing the top-level rewrite rule as a function that takes a FlashAttnForwardOp and returns the most significant ops the rewrite produced. This can also be done in a pattern, see https://github.com/llvm/llvm-project/blob/ebfb1ddbe18457723ea62f4ecc94667a44ca4569/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp#L680 for example. Thus, one can take the result of the rewrite and directly apply another rewrite rule on it, without walking the IR to match and without going through the rewrite driver. If we were in a sufficiently advanced functional language, we would have been able to do a function composition of the two rewrite rules (provided they return their most significant ops) to obtain a single rewrite rule.
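A toy Python analogy (not MLIR API; all names are made up) of why returning the significant ops matters: each rewrite becomes a plain function, so two rewrites compose into one without re-matching the IR:

```python
# Toy model: a "rewrite rule" is a function from the matched op to the
# significant ops it produced, so rules chain by ordinary composition.
def lower_attention(op):
    # Pretend lowering of a flash-attention op into its constituent ops.
    return ["matmul1", "softmax", "matmul2"]

def tile_each(ops):
    # A follow-up rewrite applied directly to the ops the first one returned.
    return [f"tiled({op})" for op in ops]

def compose(f, g):
    return lambda x: g(f(x))

pipeline = compose(lower_attention, tile_each)
assert pipeline("flash_attn_fwd") == ["tiled(matmul1)", "tiled(softmax)", "tiled(matmul2)"]
```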
if (getNumInputs() != 3) {
  return op->emitOpError("expected three input operands");
}
if (getNumOutputs() != 1) {
  return op->emitOpError("expected one output operand");
}
Why aren't these specified in the ODS instead of using variadics?
This is just because they implement the LinalgExtInterface.... We could probably do it better and reduce what we use the interface for.... But the interface expects inputs and outputs. It was done to just make a few things easier (short term).
The interface just requires that getInputs and getOutputs methods exist; it doesn't specify how they are implemented. The ODS part could just have:
let extraClassDeclaration = [{
ValueRange getInputs() { return getOperands().take_front(3); }
ValueRange getOutputs() { return getOperands().drop_front(3).take_front(); }
}];
and have proper accessors and verifiers generated.
Makes sense. Thanks! Will incorporate into subsequent patches.
Just a follow-up on this: the generated file for LinalgExtInterfaces specifies
ValueRange getInputs(const Concept *impl, ::mlir::Operation *tablegen_opaque_val);
but
return getOperands().take_front(3);
returns an mlir::OperandRange. Do you know how to convert one to the other?
ValueRange is implicitly constructible from OperandRange - https://github.com/llvm/llvm-project/blob/5a7f47cc021bd7a19cb70c9a30755d6b3cb67431/mlir/include/mlir/IR/ValueRange.h#L374
Value query() {
  return getInputOperand(0)->get();
}
Value key() {
  return getInputOperand(1)->get();
}
Value value() {
  return getInputOperand(2)->get();
}
Can't the op just have:
let arguments = (ins AnyShaped:$query, AnyShaped:$key, AnyShaped:$value, ...)
?
I believe that we need inputs and outputs because these ops implement the LinalgExtInterface.
See my suggestion above. The interface cannot require ops to have ODS arguments with specific names or kinds. FWIW, it's better to explicitly implement interface methods instead of relying on implicitly generated methods to fit what the interface requires.
Agreed. Will do.
Looking into this more, it seems like we can't have
let arguments = (ins AnyShaped:$query, AnyShaped:$key, AnyShaped:$value, ...)
because of the AttrSizedOperandSegments trait that every LinalgExtOp needs to satisfy:
iree/llvm-external-projects/iree-dialects/include/iree-dialects/Dialect/LinalgExt/IR/LinalgExtOps.td, line 48 (at 52c5d14):
[AttrSizedOperandSegments,
Yeah... we should drop it from the LinalgExt base op and move it into individual ops.... Just one of the things that fell through the cracks...
Could we create a bug to track these cleanups? I don't want to block this PR on legacy refactorings.
const int64_t keyRank = getKeyRank();
if (keyRank != 3) {
  return op->emitError("expected key tensor to have rank 3");
}
const ShapedType queryType = getQueryType();
const ShapedType keyType = getKeyType();
const ShapedType valueType = getValueType();
Nit: we don't use this kind of const in LLVM/MLIR; the code in iree-dialects is intended for upstreaming, so it's better to align on that.
We don't use const in IREE either.
Sure, will remove.
LogicalResult matchAndRewrite(FlashAttentionFwdOp attnOp,
                              PatternRewriter &rewriter) const override {
  Location loc = attnOp.getLoc();
  rewriter.setInsertionPoint(attnOp);
It is highly desirable to use OpBuilder::InsertionGuard before modifying the insertion point of a rewriter passed by reference from above; the caller may not expect modifications to the insertion point.
Sounds good, will change. Thanks!
      yieldOp, ValueRange{secondLoopNest.results[0]});
}

attnOp.getResults()[0].replaceAllUsesWith(firstLoopNest.results[0]);
Direct IR manipulations such as RAUW are disallowed in rewrite patterns. This may work accidentally, but may lead to catastrophic issues if this pattern is combined with other patterns or if the logic of the rewriter changes. This should use rewriter APIs instead.
+1
This needs an explicit enumeration of uses, with the replacement performed under an updateRootInPlace or startRootUpdate/finalizeRootUpdate pair.
There is very unfortunately zero API support to guide users to do the right thing atm.
We unfortunately have such anti-patterns in SCF and other places...
see my comment here: https://discourse.llvm.org/t/rfc-canonicalizerpass-convergence-error-handling/67333/32
I see thanks! I was not aware of that. Will modify as per Nicolas' suggestions.
I was looking at the source for RAUW
/// Find uses of `from` and replace it with `to`
void RewriterBase::replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
updateRootInPlace(op, [&]() { operand.set(to); });
}
}
This looks exactly like Nicolas' suggestion? Or am I missing something? Why is RAUW unsafe while looping over uses + updateRootInPlace safe?
Yes, this is the right thing to use. It was added recently and I don't know if IREE's version of LLVM has caught up with it.
The implementation of this new method is similar to RAUW but additionally calls updateRootInPlace. This notifies the derived rewriter that an op was modified. Different rewriters may be using that for different purposes, typically updating some internal data structures. Not notifying them about the change will make that internal state invalid wrt the IR. Specifically, the greedy rewriter (depending on the options) will consider the ops it knows were modified again, in case modifications allowed more patterns to apply. Without notification, it wouldn't, and the result would be different. This happens to work stably because the flow here doesn't actually need the fixed-point greedy rewriter, as pointed out below. This isn't catastrophic, but other rewriters, such as the tracking one, may have a pointer to ops that satisfy certain invariants that will be violated by unreported in-place modifications, leading to a nasty crash.
Thanks for the explanation. Also reminds me that I need to go back and re-read the source of greedy rewriter :)
It's being actively worked on upstream, so maybe give it some time to stabilize.
RewritePatternSet patterns(&getContext());
patterns.insert<ReifyFlashAttentionFwdTransform>(context);
if (failed(
        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
  return signalPassFailure();
Not for this patch to solve, but I really wonder if the complexity of the greedy pattern rewriter is warranted just to lower one specific op with no pattern matching involved. Sounds like a simple walk over the IR with an IRRewriter would be largely sufficient.
+1 to that. It's a habit in MLIR that we (including me) have to unlearn.
Same here. It is one of those handy tools that you habitually gravitate towards. Will change.
@harsh-nod hmm... I would expect a significantly simpler form and test that should involve no loops at all. We should have a test where flash attention lowers to exactly that form. We should also have a test where softmax lowers to this form (or an equivalent one, but please do not mix parallel and reduction results like is shown here). You should adapt the above to make the ops that require it rank-polymorphic enough (maybe softmax has multiple leading parallel dimensions while flash attention does not; depends on the usage, I am unsure). Once we have the ops introduced with these simple lowerings + tests (1 or 2 PRs for that), then we should start looking at lowering to loops, tiling etc. All this should compose.
Signal boosting this. I think the right way to do this is to also implement the softmax fusion op. Otherwise, we will be doing various raisings to recover it, and this is a good chance to not entrench that further.
Sorry for the delay in review. Mostly looks fine, but a big part of the TileAndDecompose... is hard to read; some more comments would help. I'll do another round shortly.
public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FlashAttentionFwdOp attnOp,
Could we add some IR comments about what is going on here... Will be easier to parse.
Sure, I can put in a description of what is going on.
Signal boosting comments above on why the explicit
Thanks, that makes sense. Seems like that would also be useful for transform dialect manipulations.
When I first read through the flash attention paper, it seemed like a somewhat straightforward application of So to that end, I propose the following set of patches
What do you think? In parallel, I propose we keep this patch as a reference because in some ways this is where we are headed
Totally agree. Will start putting up patches as per the proposed plan outlined above.
Sure. Do you have any pointers on what functions I should be using instead?
Maybe we can hop on a GVC and chat about that. If the softmax operator you describe below implements the tiling interface, it will get tiled and fused. But the question is: do we need to split the max and sum statistics? That is the part that I need to understand better.... I am happy to chat more and figure out how to orchestrate this.
This would be really awesome! If you are willing to do this it would be fantastic indeed (I hope you can pull in some resources to get this working!).
@MaheshRavishankar - sure, let's chat on GVC and I can give more details. I will ping on the internal channel to schedule a time. This is my primary focus right now but we can chat more about this on GVC as well.
Would be great, yes! I would just modulate point 4., as I think you could just tile
The remaining fusion + 5. would probably be better done with transform dialect stuff once it is easy to match these ops.
Sounds good! Will start along this plan and revisit if I run into any issues.
Ah yes, it's been added recently to the API, it didn't use to exist.
Using that is good.
Also fix a bug that was updating the wrong iter args. This now lowers through the CUDA backend.
The patch now outputs the correct result when using the CPU backend.
} // namespace

namespace {
struct TileAndDecomposeFlashAttentionTransformPass
I'm trying to understand why we cannot tile using the TilingInterface and the existing tileAndDistribute pass. Is there a good place to read about that? I thought the first level of tiling would just tile along a common parallel dimension.
If you look at the attention patch #11928, that gets us pretty close to the original flash attention paper. Here, we do an entire row of the first matmul, then the softmax and then the second matmul. The problem is that in general that row can be quite long. So we need to tile that row and do it one tile at a time. This means also maintaining statistics across the other linalg generics. Would be more than happy to go over it in GVC. Let me know what's a good time.
Also here is a link to the paper: https://arxiv.org/pdf/2205.14135.pdf
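The "statistics" in question are just a per-row running max and sum of exponentials; combining the statistics of two tiles of the same rows only needs a rescale, which is the fixup the surrounding generics require. A NumPy sketch (names are illustrative, not from the patch):

```python
import numpy as np

def partial_stats(s):
    # Per-tile softmax statistics for each row: max m and sum-of-exp l.
    m = s.max(axis=-1, keepdims=True)
    return m, np.exp(s - m).sum(axis=-1, keepdims=True)

def merge_stats(m1, l1, m2, l2):
    # Combine two tiles' statistics by rescaling both to the joint max.
    m = np.maximum(m1, m2)
    return m, l1 * np.exp(m1 - m) + l2 * np.exp(m2 - m)

s = np.arange(12.0).reshape(2, 6)
m1, l1 = partial_stats(s[:, :3])
m2, l2 = partial_stats(s[:, 3:])
m, l = merge_stats(m1, l1, m2, l2)
# Merged statistics equal those computed over the full rows.
assert np.allclose(m, s.max(axis=-1, keepdims=True))
assert np.allclose(l, np.exp(s - m).sum(axis=-1, keepdims=True))
```

Because the merge is associative, tiles can be processed in any order or distributed, which is what makes the row tiling legal.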
@ThomasRaoux might need to use the reduction tiling interface you added, but with slight changes to how the initialization and merge are used... that part is to be resolved going forward.
I was looking at the PartialReductionOpInterface here (https://github.com/llvm/llvm-project/blob/b41eb9601cf1aa6e33eedbb1c923b8ed0dcdcd42/mlir/include/mlir/Interfaces/TilingInterface.td#L159) but it's not obvious how to extend this to multiple ops (matmuls + linalg generics, with no fixup needed for matmuls but needed for generics).
Has been superseded by many other patches. Closing since no longer required.