
Add flash attention forward op #11815

Closed
wants to merge 5 commits

Conversation

harsh-nod
Contributor

This patch adds the flash attention forward op to LinalgExt. For tiling, we expose the two outer dimensions corresponding to batch and sequence length.

@benvanik
Collaborator

Drive-by: including specific algorithms from papers for specific models is a slippery slope - is there a plan to remove this once it can be represented more generally? e.g. what's missing in linalg/linalg_ext from being able to model this without having a concrete op defined and how is that going to be approached? A region op that allows for declarative tiling or something? (don't expect an answer, but want to plant that seed and ensure someone is thinking about it this year)

Conceptually this is something near the frontend above hlo/tosa/etc so having it plumbed deep feels dirty. General things like scans, sorts, and even winograd have wider applicability but this seems transformer specific. I support the land/iterate/use-it-to-figure-out-how-to-build-the-general-mechanism flow given how beneficial this can be and just want to make sure that's something that's being thought about before we end up with 100 of these each used by one or two models that we need to support for all of time - that should be the frontend's job :)

@harsh-nod
Contributor Author

Great points @benvanik! This definitely falls under the bucket of the land/iterate/use-it-to-figure-out-how-to-build-the-general-mechanism flow. @MaheshRavishankar had some ideas of how we can create a general interface for these kinds of ops and even move these ops completely out of IREE (something about IREE not even knowing about them, but @MaheshRavishankar can clarify). We were just waiting to flesh this out to get a better sense of what that interface would look like.

@benvanik
Collaborator

Awesome, that sounds great!
When I saw flash attention last year it seemed like just the kind of stuff we want to do in lots of places and I'm excited to see how we can make this something that can be applied to lots of different algorithms :)

@MaheshRavishankar
Contributor

Drive-by: including specific algorithms from papers for specific models is a slippery slope - is there a plan to remove this once it can be represented more generally? e.g. what's missing in linalg/linalg_ext from being able to model this without having a concrete op defined and how is that going to be approached? A region op that allows for declarative tiling or something? (don't expect an answer, but want to plant that seed and ensure someone is thinking about it this year)

Conceptually this is something near the frontend above hlo/tosa/etc so having it plumbed deep feels dirty. General things like scans, sorts, and even winograd have wider applicability but this seems transformer specific. I support the land/iterate/use-it-to-figure-out-how-to-build-the-general-mechanism flow given how beneficial this can be and just want to make sure that's something that's being thought about before we end up with 100 of these each used by one or two models that we need to support for all of time - that should be the frontend's job :)

Yeah, this is meant to be something closer to hlo/tosa/etc. (so it would go into the user-defined preprocessing flow that we've discussed briefly). The only thing this op needs is a couple of interfaces for IREE to not actually know what this op is, but just use the interface. One of them is the TilingInterface; the other one is what we want to explore and have a couple of examples for to get a better idea of the interface. That's why the op is in LinalgExt. It's just a placeholder for the interface...

@powderluv
Collaborator

Thank you for the pragmatic land/iterate/use-it-to-figure-out-how-to-build-the-general-mechanism flow. It helps us be competitive now while also building the right thing long term.

@benvanik
Collaborator

yep! and linalg_ext is a great place to experiment with these things!

@nicolasvasilache
Contributor

Cool, very exciting stuff!

I'd love to see the lowering to Linalg ops be a "functional style" transformation that produces the k ops (i.e. the 4-6 softmax ops + the 2 gemm ops); I can make direct use of that with a generic transform dialect ExpandOp and immediately apply some strategy to map to threads.

This would save the pain of the initial 8-op matcher and get us right into the block-level transform.

@ftynse also to see that part of the design / simplification space.

@nicolasvasilache
Contributor

Note, you could also do that lowering through a softmax op and have the softmax lower to the 4-6 ops version and let things compose.

@MaheshRavishankar
Contributor

Note, you could also do that lowering through a softmax op and have the softmax lower to the 4-6 ops version and let things compose.

Yes! Yes! (a thousand times yes!!)

@harsh-nod
Contributor Author

harsh-nod commented Jan 13, 2023

I just uploaded a patch to add the lowering. @nicolasvasilache - not sure if this is a functional lowering but currently this op lowers to the following.

module {
  func.func @_flash_attention_fwd_dispatch_0() {
    %cst = arith.constant 0.000000e+00 : f32
    %cst_0 = arith.constant -1.000000e+02 : f32
    %c64 = arith.constant 64 : index
    %c192 = arith.constant 192 : index
    %c1024 = arith.constant 1024 : index
    %c0 = arith.constant 0 : index
    %0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf32>>
    %1 = hal.interface.binding.subspan set(0) binding(1) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf32>>
    %2 = hal.interface.binding.subspan set(0) binding(2) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf32>>
    %3 = hal.interface.binding.subspan set(0) binding(3) type(storage_buffer) offset(%c0) alignment(64) : !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf32>>
    %workgroup_id_x = hal.interface.workgroup.id[0] : index
    %workgroup_count_x = hal.interface.workgroup.count[0] : index
    %workgroup_id_y = hal.interface.workgroup.id[1] : index
    %workgroup_count_y = hal.interface.workgroup.count[1] : index
    scf.for %arg0 = %workgroup_id_y to %c192 step %workgroup_count_y {
      %4 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_id_x]
      %5 = affine.apply affine_map<()[s0] -> (s0 * 64)>()[%workgroup_count_x]
      scf.for %arg1 = %4 to %c1024 step %5 {
        %6 = flow.dispatch.tensor.load %3, offsets = [%arg0, %arg1, 0], sizes = [1, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf32>> -> tensor<1x64x64xf32>
        %7 = flow.dispatch.tensor.load %0, offsets = [%arg0, %arg1, 0], sizes = [1, 64, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf32>> -> tensor<1x64x64xf32>
        %8 = flow.dispatch.tensor.load %1, offsets = [%arg0, 0, 0], sizes = [1, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf32>> -> tensor<1x1024x64xf32>
        %9 = flow.dispatch.tensor.load %2, offsets = [%arg0, 0, 0], sizes = [1, 1024, 64], strides = [1, 1, 1] : !flow.dispatch.tensor<readonly:tensor<192x1024x64xf32>> -> tensor<1x1024x64xf32>
        %10 = tensor.empty() : tensor<64xf32>
        %11 = linalg.fill ins(%cst_0 : f32) outs(%10 : tensor<64xf32>) -> tensor<64xf32>
        %cast = tensor.cast %11 : tensor<64xf32> to tensor<?xf32>
        %12 = linalg.fill ins(%cst : f32) outs(%10 : tensor<64xf32>) -> tensor<64xf32>
        %cast_1 = tensor.cast %12 : tensor<64xf32> to tensor<?xf32>
        %13:3 = scf.for %arg2 = %c0 to %c1024 step %c64 iter_args(%arg3 = %6, %arg4 = %cast, %arg5 = %cast_1) -> (tensor<1x64x64xf32>, tensor<?xf32>, tensor<?xf32>) {
          %extracted_slice = tensor.extract_slice %8[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<1x1024x64xf32> to tensor<64x64xf32>
          %extracted_slice_2 = tensor.extract_slice %9[0, %arg2, 0] [1, 64, 64] [1, 1, 1] : tensor<1x1024x64xf32> to tensor<64x64xf32>
          %extracted_slice_3 = tensor.extract_slice %7[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<1x64x64xf32> to tensor<64x64xf32>
          %extracted_slice_4 = tensor.extract_slice %arg3[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<1x64x64xf32> to tensor<64x64xf32>
          %14 = tensor.empty() : tensor<64x64xf32>
          %transposed = linalg.transpose ins(%extracted_slice : tensor<64x64xf32>) outs(%14 : tensor<64x64xf32>) permutation = [1, 0]
          %15 = linalg.fill ins(%cst : f32) outs(%14 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %16 = linalg.matmul ins(%extracted_slice_3, %transposed : tensor<64x64xf32>, tensor<64x64xf32>) outs(%15 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %17 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["reduction", "parallel"]} ins(%16 : tensor<64x64xf32>) outs(%10 : tensor<64xf32>) {
          ^bb0(%in: f32, %out: f32):
            %26 = arith.maxf %in, %out : f32
            linalg.yield %26 : f32
          } -> tensor<64xf32>
          %18 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>], iterator_types = ["reduction", "parallel"]} ins(%16 : tensor<64x64xf32>) outs(%10 : tensor<64xf32>) {
          ^bb0(%in: f32, %out: f32):
            %26 = arith.addf %in, %out : f32
            linalg.yield %26 : f32
          } -> tensor<64xf32>
          %cast_5 = tensor.cast %arg4 : tensor<?xf32> to tensor<64xf32>
          %19 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%17 : tensor<64xf32>) outs(%cast_5 : tensor<64xf32>) {
          ^bb0(%in: f32, %out: f32):
            %26 = arith.maxf %out, %in : f32
            linalg.yield %26 : f32
          } -> tensor<64xf32>
          %20 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<64xf32>) outs(%cast_5 : tensor<64xf32>) {
          ^bb0(%in: f32, %out: f32):
            %26 = arith.subf %out, %in : f32
            %27 = math.exp %26 : f32
            linalg.yield %27 : f32
          } -> tensor<64xf32>
          %21 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%19 : tensor<64xf32>) outs(%17 : tensor<64xf32>) {
          ^bb0(%in: f32, %out: f32):
            %26 = arith.subf %out, %in : f32
            %27 = math.exp %26 : f32
            linalg.yield %27 : f32
          } -> tensor<64xf32>
          %cast_6 = tensor.cast %arg5 : tensor<?xf32> to tensor<64xf32>
          %22 = linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%18, %20, %21 : tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) outs(%cast_6 : tensor<64xf32>) {
          ^bb0(%in: f32, %in_9: f32, %in_10: f32, %out: f32):
            %26 = arith.mulf %in_9, %out : f32
            %27 = arith.mulf %in_10, %in : f32
            %28 = arith.addf %26, %27 : f32
            linalg.yield %28 : f32
          } -> tensor<64xf32>
          %23 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%17, %21, %22 : tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) outs(%16 : tensor<64x64xf32>) {
          ^bb0(%in: f32, %in_9: f32, %in_10: f32, %out: f32):
            %26 = arith.subf %out, %in : f32
            %27 = math.exp %26 : f32
            %28 = arith.mulf %27, %in_9 : f32
            %29 = arith.divf %28, %in_10 : f32
            linalg.yield %29 : f32
          } -> tensor<64x64xf32>
          %24 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%cast_6, %22, %20 : tensor<64xf32>, tensor<64xf32>, tensor<64xf32>) outs(%extracted_slice_4 : tensor<64x64xf32>) {
          ^bb0(%in: f32, %in_9: f32, %in_10: f32, %out: f32):
            %26 = arith.mulf %in, %in_10 : f32
            %27 = arith.mulf %out, %26 : f32
            %28 = arith.divf %27, %in_9 : f32
            linalg.yield %28 : f32
          } -> tensor<64x64xf32>
          %25 = linalg.matmul ins(%23, %extracted_slice_2 : tensor<64x64xf32>, tensor<64x64xf32>) outs(%24 : tensor<64x64xf32>) -> tensor<64x64xf32>
          %inserted_slice = tensor.insert_slice %25 into %arg3[0, 0, 0] [1, 64, 64] [1, 1, 1] : tensor<64x64xf32> into tensor<1x64x64xf32>
          %cast_7 = tensor.cast %19 : tensor<64xf32> to tensor<?xf32>
          %cast_8 = tensor.cast %22 : tensor<64xf32> to tensor<?xf32>
          scf.yield %inserted_slice, %cast_7, %cast_8 : tensor<1x64x64xf32>, tensor<?xf32>, tensor<?xf32>
        }
        flow.dispatch.tensor.store %13#0, %3, offsets = [%arg0, %arg1, 0], sizes = [1, 64, 64], strides = [1, 1, 1] : tensor<1x64x64xf32> -> !flow.dispatch.tensor<writeonly:tensor<192x1024x64xf32>>
      }
    }
    return
  }
}

This patch adds the flash attention forward op to LinalgExt. For tiling, we expose the two outer dimensions corresponding to batch and sequence length.

This patch adds a pass to lower the flash attention forward op to a tiled version of 2 matmuls + softmax that can then be run on a particular backend.

Also fix bugs and update corresponding unit tests.
@ftynse ftynse left a comment
Contributor

What Nicolas means by "functional-style" is expressing the top-level rewrite rule as a function that takes a FlashAttnForwardOp and returns the most significant ops the rewrite produced. This can also be done in a pattern, see https://github.com/llvm/llvm-project/blob/ebfb1ddbe18457723ea62f4ecc94667a44ca4569/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp#L680 for example. Thus, one can take the result of the rewrite and directly apply another rewrite rule on it, without walking the IR to match it and without going through the rewriter driver. If we were in a sufficiently advanced functional language, we would have been able to do a function composition of the two rewrite rules (provided they return their most significant ops) to obtain a single rewrite rule.
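
As a minimal sketch (the struct, function, and pattern names below are hypothetical illustrations, not code from this PR, and the usual MLIR namespaces are assumed), such a functional-style entry point could look roughly like:

// Significant ops produced by the decomposition, returned so a follow-up
// transformation (or a transform dialect strategy) can consume them directly
// without re-matching the IR.
struct DecomposedFlashAttention {
  linalg::MatmulOp firstMatmul;
  Operation *softmax; // the softmax op, or the generics it expands to
  linalg::MatmulOp secondMatmul;
};

// Functional-style rewrite: takes the op, performs the rewrite through the
// given rewriter, and returns the significant ops it produced (or failure).
FailureOr<DecomposedFlashAttention>
decomposeFlashAttentionFwd(RewriterBase &rewriter, FlashAttentionFwdOp op);

// A thin pattern wrapper can still expose the same rewrite to a driver.
struct DecomposeFlashAttentionFwdPattern
    : public OpRewritePattern<FlashAttentionFwdOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(FlashAttentionFwdOp op,
                                PatternRewriter &rewriter) const override {
    return failure(failed(decomposeFlashAttentionFwd(rewriter, op)));
  }
};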

Comment on lines +2765 to +2770
if (getNumInputs() != 3) {
  return op->emitOpError("expected three input operands");
}
if (getNumOutputs() != 1) {
  return op->emitOpError("expected one output operand");
}
Contributor

Why aren't these specified in the ODS instead of using variadics?

Contributor

This is just because they implement the LinalgExtInterface.... We could probably do it better and reduce what we use the interface for.... But the interface expects inputs and outputs. It was done to just make a few things easier (short term).

Contributor

The interface just requires that the getInputs and getOutputs methods exist; it doesn't specify how they are implemented. The ODS part could just have:

let extraClassDeclaration = [{
  ValueRange getInputs() { return getOperands().take_front(3); }
  ValueRange getOutputs() { return getOperands().drop_front(3).take_front(); }
}];

and have proper accessors and verifiers generated.

Contributor Author

Makes sense. Thanks! Will incorporate into subsequent patches.

Contributor Author

Just a follow-up on this: the generated file for LinalgExtInterfaces specifies

ValueRange getInputs(const Concept *impl, ::mlir::Operation *tablegen_opaque_val);

but

return getOperands().take_front(3);

returns an mlir::OperandRange. Do you know how to convert one to the other?
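
(For context, and not taken from this thread: an OperandRange converts implicitly to a ValueRange, so a method declared to return ValueRange can usually just return the OperandRange directly. A tiny, hypothetical illustration:)

// `firstThreeOperands` is only illustrative; the OperandRange -> ValueRange
// conversion happens implicitly at the return.
ValueRange firstThreeOperands(Operation *op) {
  return op->getOperands().take_front(3);
}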


Comment on lines +1150 to +1158
Value query() {
  return getInputOperand(0)->get();
}
Value key() {
  return getInputOperand(1)->get();
}
Value value() {
  return getInputOperand(2)->get();
}
Contributor

Can't the op just have:

let arguments = (ins AnyShaped:$query, AnyShaped:$key, AnyShaped:$value, ...)

?

Contributor Author

I believe that we need inputs and outputs because these ops implement the LinalgExtInterface.

Contributor

See my suggestion above. The interface cannot require ops to have ODS arguments with specific names or kinds. FWIW, it's better to explicitly implement interface methods instead of relying on implicitly generated methods to fit what the interface requires.

Contributor Author

Agreed. Will do.

Contributor Author

Looking into this more, it seems like we can't have
let arguments = (ins AnyShaped:$query, AnyShaped:$key, AnyShaped:$value, ...)
because of the AttrSizedOperandSegments trait that every LinalgExtOp needs to satisfy. Since the variadic is still required, I don't think it makes sense to separate out into query, key and value because each one of those will be required to be variadic. I propose we stick with the current formulation until AttrSizedOperandSegments is removed from LinalgExtOps. What do you think?

Contributor

Yeah... we should drop it from the LinalgExt base op and move it into individual ops.... Just one of the things that fell through the cracks...

Contributor

Could we create a bug to track these cleanups? I don't want to block this PR on legacy refactorings.

Comment on lines 2771 to 2777
const int64_t keyRank = getKeyRank();
if (keyRank != 3) {
  return op->emitError("expected key tensor to have rank 3");
}
const ShapedType queryType = getQueryType();
const ShapedType keyType = getKeyType();
const ShapedType valueType = getValueType();
Contributor

Nit: we don't use this kind of const in LLVM/MLIR; the code in iree-dialect is intended for upstreaming, so it's better to align on that.

Contributor

We don't use const in IREE either.

Contributor Author

Sure, will remove.

LogicalResult matchAndRewrite(FlashAttentionFwdOp attnOp,
                              PatternRewriter &rewriter) const override {
  Location loc = attnOp.getLoc();
  rewriter.setInsertionPoint(attnOp);
Contributor

It is highly desirable to use OpBuilder::InsertionGuard before modifying the insertion point of the rewriter passed by reference from above; the caller may not expect modifications to the insertion point.
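
A minimal sketch of what that would look like in the pattern quoted above (the body past the insertion-point change is elided and only indicative):

LogicalResult matchAndRewrite(FlashAttentionFwdOp attnOp,
                              PatternRewriter &rewriter) const override {
  // Restores the caller's insertion point when this scope exits.
  OpBuilder::InsertionGuard guard(rewriter);
  Location loc = attnOp.getLoc();
  rewriter.setInsertionPoint(attnOp);
  // ... rest of the rewrite, using `loc` and `rewriter`, elided ...
  return success();
}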

Contributor Author

Sounds good, will change. Thanks!

yieldOp, ValueRange{secondLoopNest.results[0]});
}

attnOp.getResults()[0].replaceAllUsesWith(firstLoopNest.results[0]);
Contributor

Direct IR manipulations such as RAUW are disallowed in rewrite patterns. This may work accidentally, but may lead to catastrophic issues if this pattern is combined with other patterns or if the logic of the rewriter changes. This should use rewriter APIs instead.

Contributor

+1

This needs an explicit enumeration of uses and performing the replacement under an updateRootInPlace or startRootUpdate/finalizeRootUpdate pair.

There is, very unfortunately, zero API support to guide users to do the right thing atm.
We unfortunately have such anti-patterns in SCF and other places..

see my comment here: https://discourse.llvm.org/t/rfc-canonicalizerpass-convergence-error-handling/67333/32

Contributor Author

I see, thanks! I was not aware of that. Will modify as per Nicolas' suggestions.

Contributor Author

I was looking at the source for RAUW

  /// Find uses of `from` and replace it with `to`
  void RewriterBase::replaceAllUsesWith(Value from, Value to) {
    for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
      Operation *op = operand.getOwner();
      updateRootInPlace(op, [&]() { operand.set(to); });
    }
  }

This looks exactly like Nicolas' suggestion? Or am I missing something? Why is RAUW unsafe while looping over uses + updateRootInPlace safe?

Contributor

Yes, this is the right thing to use. It was added recently and I don't know if IREE's version of LLVM has caught up with it.

The implementation of this new method is similar to RAUW but additionally calls updateRootInPlace. This notifies the derived rewriter that an op was modified. Different rewriters may be using that for different purposes, typically updating some internal data structures. Not notifying them about the change will make that internal state invalid wrt the IR. Specifically, the greedy rewriter (depending on the options) will consider the ops it knows were modified again, in case modifications allowed more patterns to apply. Without notification, it wouldn't, and the result would be different. This happens to work stably because the flow here doesn't actually need the fixed-point greedy rewriter, as pointed out below. This isn't catastrophic, but other rewriters, such as the tracking one, may have a pointer to ops that satisfy certain invariants that will be violated by unreported in-place modifications, leading to a nasty crash.

Contributor Author

Thanks for the explanation. Also reminds me that I need to go back and re-read the source of greedy rewriter :)

Contributor

It's being actively worked on upstream, so maybe give it some time to stabilize.

Comment on lines 404 to 436
RewritePatternSet patterns(&getContext());
patterns.insert<ReifyFlashAttentionFwdTransform>(context);
if (failed(
        applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) {
  return signalPassFailure();
Contributor

Not for this patch to solve, but I really wonder if the complexity of the greedy pattern rewriter is warranted just to lower one specific op with no pattern matching involved. Sounds like a simple walk over the IR with an IRRewriter would be largely sufficient.
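
For illustration, a sketch of what the walk-based version could look like, assuming a functional helper such as the hypothetical decomposeFlashAttentionFwd mentioned earlier in this review (all names here are illustrative only):

static LogicalResult decomposeAllFlashAttention(func::FuncOp funcOp) {
  IRRewriter rewriter(funcOp.getContext());
  // Collect the ops first so the rewrites don't invalidate the walk.
  SmallVector<FlashAttentionFwdOp> worklist;
  funcOp.walk([&](FlashAttentionFwdOp op) { worklist.push_back(op); });
  for (FlashAttentionFwdOp op : worklist) {
    rewriter.setInsertionPoint(op);
    if (failed(decomposeFlashAttentionFwd(rewriter, op)))
      return failure();
  }
  return success();
}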

Contributor

+1 to that. It's a habit in MLIR that we (including me) have to unlearn.

Contributor Author

Same here. It is one of those handy tools that you habitually gravitate towards. Will change.

@nicolasvasilache
Contributor

@harsh-nod hmm.. I would expect a significantly simpler form and test that should involve no loops at all.

We should have a test where flash attention lowers to exactly matmul + softmax (new op) + matmul.

We should also have a test where softmax lowers to this form (or an equivalent one, but please do not mix parallel and reduction results like it is shown here).

You should adapt the above to make the ops that require it rank-polymorphic enough (maybe softmax has multiple leading parallel dimensions while flash attention does not, depends on the usage, I am unsure).

Once we have the ops introduced with these simple lowerings + tests (1 or 2 PRs for that), then we should start looking at lowering to loops, tiling etc.

All this should compose.

@stellaraccident
Collaborator

Note, you could also do that lowering through a softmax op and have the softmax lower to the 4-6 ops version and let things compose.

Signal boosting this.

I think the right way to do this is to also implement the softmax fusion op. Otherwise, we will be doing various raisings to recover it, and this is a good chance to not entrench that further.

@MaheshRavishankar MaheshRavishankar left a comment
Contributor

Sorry for the delay in review. Mostly looks fine, but a big part of the TileAndDecompose... is hard to read; some more comments would help. I'll do another round shortly.

public:
  using OpRewritePattern::OpRewritePattern;

  LogicalResult matchAndRewrite(FlashAttentionFwdOp attnOp,
Contributor

Could we add some comments about what is going on here... It will be easier to parse.

Contributor Author

Sure, I can put in a description of what is going on.

@MaheshRavishankar
Contributor

Signal boosting comments above on why the explicit scf.for is introduced... This seems to be code very similar to tiling. We could just use tiling for it...

@harsh-nod
Contributor Author

What Nicolas means by "functional-style" is expressing the top-level rewrite rule as a function that takes a FlashAttnForwardOp and returns the most significant ops the rewrite produced. This can also be done in a pattern, see https://github.com/llvm/llvm-project/blob/ebfb1ddbe18457723ea62f4ecc94667a44ca4569/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp#L680 for example. Thus, one can take the result of the rewrite and directly apply another rewrite rule on it, without walking the IR to match it and without going through the rewriter driver. If we were in a sufficiently advanced functional language, we would have been able to do a function composition of the two rewrite rules (provided they return their most significant ops) to obtain a single rewrite rule.

Thanks that makes sense. Seems like that would also be useful for transform dialect manipulations.

@harsh-nod
Contributor Author

@harsh-nod hmm .. I would expect a significantly simpler form and test that should involve no loops at all.

We should have a test where flash attention lowers to exactly matmul + softmax (new op) + matmul.

We should also have a test where softmax lowers to this form (or an equivalent one, but please do not mix parallel and reduction results like it is shown here).

You should adapt the above to make the ops that require it rank-polymorphic enough (maybe softmax has multiple leading parallel dimensions while flash attention does not, depends on the usage, I am unsure).

Once we have the ops introduced with these simple lowerings + tests (1 or 2 PRs for that), then we should start looking at lowering to loops, tiling etc.

All this should compose.

When I first read through the flash attention paper, it seemed like a somewhat straightforward application of tiling and fusion to the attention operator (matmul + softmax + matmul), with the only caveats being that I wasn't sure whether the tiling for the softmax operator (splitting it into a max and sum statistic) was implemented and whether the fusion of the softmax + two matmuls would lower properly. The code that I have here implements the tiling + fusion of the entire operator in one pass. I do agree that treating this as a composable problem makes sense and is definitely the way to go.

So to that end, I propose the following set of patches

  1. Creating a linalg_ext.softmax op + tests
  2. Creating a linalg_ext.attention op which lowers to (matmul + softmax + matmul) + tests
  3. Tiling of softmax op + tests
  4. Fusion of matmul + softmax + matmul + tests
  5. Performance optimization patches

What do you think?

In parallel, I propose we keep this patch as a reference because in some ways this is where we are headed (as it already has tiled and fused all 3 operators together). I will address Alex's comments and get this patch functional as a baseline while I put up the other patches.

@harsh-nod
Contributor Author

Note, you could also do that lowering through a softmax op and have the softmax lower to the 4-6 ops version and let things compose.

Signal boosting this.

I think the right way to do this is to also implement the softmax fusion op. Otherwise, we will be doing various raisings to recover it, and this is a good chance to not entrench that further.

Totally agree. Will start putting up patches as per proposed plan outlined above.

@harsh-nod
Contributor Author

Signal boosting comments above on why the explicit scf.for is introduced.. This seems to be code very similar to tiling. We could just use tiling for it....

Sure. Do you have any pointers on what functions I should be using instead?

@MaheshRavishankar
Contributor

@harsh-nod hmm .. I would expect a significantly simpler form and test that should involve no loops at all.
We should have a test where flash attention lowers to exactly matmul + softmax (new op) + matmul.
We should also have a test where softmax lowers to this form (or an equivalent one, but please do not mix parallel and reduction results like it is shown here).
You should adapt the above to make the ops that require it rank-polymorphic enough (maybe softmax has multiple leading parallel dimensions while flash attention does not, depends on the usage, I am unsure).
Once we have the ops introduced with these simple lowerings + tests (1 or 2 PRs for that), then we should start looking at lowering to loops, tiling etc.
All this should compose.

When I first read through the flash attention paper, it seemed like a somewhat straightforward application of tiling and fusion to the attention operator (matmul + softmax + matmul), with the only caveats being that I wasn't sure whether the tiling for the softmax operator (splitting it into a max and sum statistic) was implemented and whether the fusion of the softmax + two matmuls would lower properly. The code that I have here implements the tiling + fusion of the entire operator in one pass. I do agree that treating this as a composable problem makes sense and is definitely the way to go.

Maybe we can hop on a GVC and chat about that. If the softmax operator you describe below implements the tiling interface, it will get tiled and fused. But the question is: do we need to split the max and sum statistic? That is the part that I need to understand better... I am happy to chat more and figure out how to orchestrate this.

So to that end, I propose the following set of patches

  1. Creating a linalg_ext.softmax op + tests
  2. Creating a linalg_ext.attention op which lowers to (matmul + softmax + matmul) + tests
  3. Tiling of softmax op + tests
  4. Fusion of matmul + softmax + matmul + tests
  5. Performance optimization patches

This would be really awesome! If you are willing to do this, it would be fantastic indeed (I hope you can pull in some resources to get this working!).

What do you think?

In parallel, I propose we keep this patch as a reference because in some ways this is where we are headed (as it already has tiled and fused all 3 operators together). I will address Alex's comments and get this patch functional as a baseline while I put up the other patches.

@harsh-nod
Contributor Author

@harsh-nod hmm .. I would expect a significantly simpler form and test that should involve no loops at all.
We should have a test where flash attention lowers to exactly matmul + softmax (new op) + matmul.
We should also have a test where softmax lowers to this form (or an equivalent one, but please do not mix parallel and reduction results like it is shown here).
You should adapt the above to make the ops that require it rank-polymorphic enough (maybe softmax has multiple leading parallel dimensions while flash attention does not, depends on the usage, I am unsure).
Once we have the ops introduced with these simple lowerings + tests (1 or 2 PRs for that), then we should start looking at lowering to loops, tiling etc.
All this should compose.

When I first read through the flash attention paper, it seemed like a somewhat straightforward application of tiling and fusion to the attention operator (matmul + softmax + matmul), with the only caveats being that I wasn't sure whether the tiling for the softmax operator (splitting it into a max and sum statistic) was implemented and whether the fusion of the softmax + two matmuls would lower properly. The code that I have here implements the tiling + fusion of the entire operator in one pass. I do agree that treating this as a composable problem makes sense and is definitely the way to go.

Maybe we can hop on a GVC and chat about that. If the softmax operator you describe below implements the tiling interface, it will get tiled and fused. But the question is: do we need to split the max and sum statistic? That is the part that I need to understand better... I am happy to chat more and figure out how to orchestrate this.

So to that end, I propose the following set of patches

  1. Creating a linalg_ext.softmax op + tests
  2. Creating a linalg_ext.attention op which lowers to (matmul + softmax + matmul) + tests
  3. Tiling of softmax op + tests
  4. Fusion of matmul + softmax + matmul + tests
  5. Performance optimization patches

This would be really awesome! If you are willing to do this, it would be fantastic indeed (I hope you can pull in some resources to get this working!).

What do you think?
In parallel, I propose we keep this patch as a reference because in some ways this is where we are headed (as it already has tiled and fused all 3 operators together). I will address Alex's comments and get this patch functional as a baseline while I put up the other patches.

@MaheshRavishankar - sure let's chat on GVC and I can give more details. I will ping on the internal channel to schedule a time. This is my primary focus right now but we can chat more about this on GVC as well.

@nicolasvasilache
Contributor

So to that end, I propose the following set of patches

  1. Creating a linalg_ext.softmax op + tests
  2. Creating a linalg_ext.attention op which lowers to (matmul + softmax + matmul) + tests
  3. Tiling of softmax op + tests
  4. Fusion of matmul + softmax + matmul + tests
  5. Performance optimization patches

Would be great, yes!

I would just modulate point 4, as I think you could just tile linalg_ext.attention directly and then lower to (matmul + softmax + matmul); that would give you some of the fusion you are after at the first level (I think... I haven't looked at the problem deeply here, but I had some basic version with 8 ops that did that, though I may also have written a batched version).

The remaining fusion + point 5 would probably be better done with transform dialect stuff once it is easy to match these ops.

@harsh-nod
Contributor Author

So to that end, I propose the following set of patches

  1. Creating a linalg_ext.softmax op + tests
  2. Creating a linalg_ext.attention op which lowers to (matmul + softmax + matmul) + tests
  3. Tiling of softmax op + tests
  4. Fusion of matmul + softmax + matmul + tests
  5. Performance optimization patches

Would be great, yes!

I would just modulate point 4, as I think you could just tile linalg_ext.attention directly and then lower to (matmul + softmax + matmul); that would give you some of the fusion you are after at the first level (I think... I haven't looked at the problem deeply here, but I had some basic version with 8 ops that did that, though I may also have written a batched version).

The remaining fusion + point 5 would probably be better done with transform dialect stuff once it is easy to match these ops.

Sounds good! Will start along this plan and revisit if I run into any issues.

@nicolasvasilache
Contributor

nicolasvasilache commented Jan 18, 2023 via email

Also fix a bug that was updating the wrong iter args. This now lowers through the CUDA backend.

The patch now outputs the correct result when using the CPU backend.
} // namespace

namespace {
struct TileAndDecomposeFlashAttentionTransformPass
Contributor

I'm trying to understand why we cannot tile using the TilingInterface and the existing tileAndDistribute pass. Is there a good place to read about that? I thought the first level of tiling would just tile along a common parallel dimension.

Contributor Author

If you look at the attention patch #11928, this gets us pretty close to the original flash attention paper. Here, we do an entire row of the first matmul, then the softmax, and then the second matmul. The problem is that in general that row can be quite long, so we need to tile that row and do it one tile at a time. This means also maintaining statistics about the other linalg generics. Would be more than happy to go over it in GVC. Let me know what's a good time.
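
To make the "statistics" concrete, here is a plain scalar sketch (not the code in this patch, and all names are illustrative) of the per-row running quantities the tiled softmax carries between tiles, following the recurrence in the paper: a running max m, a running sum of exponentials l, and an unnormalized accumulator acc that is rescaled whenever the max changes.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

// Merge one tile of attention scores (a slice of one row of Q*K^T) and the
// matching rows of V into the running statistics. Initially m = -inf, l = 0,
// acc = 0; the final output row is acc / l once every tile has been merged.
void mergeTile(const std::vector<float> &scores,
               const std::vector<std::vector<float>> &vRows,
               float &m, float &l, std::vector<float> &acc) {
  float mTile = *std::max_element(scores.begin(), scores.end());
  float mNew = std::max(m, mTile);

  // Contribution of this tile, computed against the new max.
  float lTile = 0.0f;
  std::vector<float> accTile(acc.size(), 0.0f);
  for (std::size_t j = 0; j < scores.size(); ++j) {
    float p = std::exp(scores[j] - mNew);
    lTile += p;
    for (std::size_t d = 0; d < acc.size(); ++d)
      accTile[d] += p * vRows[j][d];
  }

  // Rescale the previously accumulated statistics to the new max, then add.
  float scale = std::exp(m - mNew);
  l = l * scale + lTile;
  for (std::size_t d = 0; d < acc.size(); ++d)
    acc[d] = acc[d] * scale + accTile[d];
  m = mNew;
}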

Contributor Author

Also here is a link to the paper: https://arxiv.org/pdf/2205.14135.pdf

Contributor

@ThomasRaoux might need to use the Reduction tiling interface you added, but with slight changes to how the initialization and merge are used... that part is to be resolved going forward.

Contributor Author

I was looking at the PartialReductionOpInterface here (https://github.com/llvm/llvm-project/blob/b41eb9601cf1aa6e33eedbb1c923b8ed0dcdcd42/mlir/include/mlir/Interfaces/TilingInterface.td#L159), but it's not obvious how to extend this to multiple ops (matmuls + linalg generics, with no fixup needed for the matmuls but needed for the generics).

@hanhanW hanhanW removed their request for review March 6, 2023 18:17
@benvanik benvanik removed their request for review May 1, 2023 18:31
@harsh-nod
Contributor Author

This has been superseded by many other patches. Closing since it is no longer required.

@harsh-nod harsh-nod closed this Jun 6, 2023