[MLIR] Add fusability query to TilingInterface #166502
```cpp
@@ -359,6 +359,52 @@ struct LinalgOpTilingInterface
    /// Inline the op payload and store the result.
    return inlinePayload(builder, linalgOp, ivs, indexedValues);
  }

  bool isOpFusableWithConsumerSlice(Operation *op, unsigned resultNumber,
                                    ArrayRef<OpFoldResult> offsets,
                                    ArrayRef<OpFoldResult> sizes) const {
    // Fusable only when the shapes-to-loops map exists, i.e. the iteration
    // space can be recovered from the operand shapes.
    return static_cast<bool>(cast<LinalgOp>(op).getShapesToLoopsMap());
```
|
Contributor
I am not sure I follow the logic here.
Contributor
Author
The op's iteration space isn't invertible if the ShapesToLoopsMap doesn't exist (e.g. for exotic indexing maps that include mods and divs). For linalg ops, support for such maps is unimplemented (we could emit loops directly), but we still need to reflect the unimplemented case here.
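The non-invertibility described here can be illustrated outside of MLIR. The sketch below (plain C++, an illustration rather than PR code) compares a 1-D operand accessed through the identity index expression `(d0)` with one accessed through `(d0 floordiv 2)`:

```cpp
#include <cassert>

// Extent of the shape dimension accessed by the identity index
// expression d0: loop indices 0..n-1 touch positions 0..n-1, so the
// operand extent uniquely determines the loop bound n.
int extentIdentity(int n) { return n; }

// Extent accessed by `d0 floordiv 2`: loop indices 0..n-1 touch
// positions 0..(n-1)/2, so distinct loop bounds can produce the same
// operand extent.
int extentFloorDiv2(int n) { return (n - 1) / 2 + 1; }
```

Under the identity map, distinct loop bounds always give distinct extents, but `extentFloorDiv2(7)` and `extentFloorDiv2(8)` both evaluate to 4: the loop bound cannot be recovered from the operand shape alone, which is the situation where no shapes-to-loops map exists.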
```cpp
  }

  bool isOpFusableWithProducerSlices(
      Operation *op, ArrayRef<unsigned> operandNumbers,
      ArrayRef<SmallVector<OpFoldResult>> allOffsets,
      ArrayRef<SmallVector<OpFoldResult>> allSizes) const {
    auto linalgOp = cast<LinalgOp>(op);
```
|
Contributor
Again, I am not sure why we need this. In theory it is feasible to fuse with any producer, irrespective of which slice of the operand is used. This seems like a cost function rather than a structural requirement for fusion.
Contributor
Author
The ShapesToLoopsMap check is for the same reason as above.

This is certainly not true for all operations: not all tensor operands and dimensions are partitionable. For example, the innermost dimension of a
Contributor
Again, got confused by the naming. This is for consumer fusion cases, I think. Isn't this the same logic that should be here https://github.com/iree-org/llvm-project/blob/7e88b401a4c731fa04524ccbc1542fce56609507/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp#L156 ? Maybe we use that / common this out?
Contributor
Author
I can make it common. When I first wrote this I didn't understand why we were ignoring non-AffineDimExprs in the matcher above and made this one a little more restrictive, but I can apply the same update there.
Contributor
This is based on how the ShapesToLoopsMap works: dimensions can be used in arbitrary AffineExprs, but the loop shapes are determined only by those accesses made through an AffineDimExpr. I am using the same logic in that method.
```cpp
    SmallVector<AffineMap> indexingMaps =
        llvm::map_to_vector(operandNumbers, [&](unsigned operandNumber) {
          OpOperand &opOperand = linalgOp->getOpOperand(operandNumber);
          return linalgOp.getMatchingIndexingMap(&opOperand);
        });
    // First verify that the iteration domain on operand subranges is well
    // defined.
    if (!linalgOp.getShapesToLoopsMap())
      return false;
    // Next verify that operand slices are consistent.
    DenseMap<unsigned, OpFoldResult> mappedOffsets, mappedSizes;
    for (auto [indexingMap, offsets, sizes] :
         llvm::zip_equal(indexingMaps, allOffsets, allSizes)) {
      for (auto [resultExpr, offset, size] :
           llvm::zip_equal(indexingMap.getResults(), offsets, sizes)) {
        auto dimExpr = dyn_cast<AffineDimExpr>(resultExpr);
        if (!dimExpr)
          return false;
        unsigned position = dimExpr.getPosition();
        auto it = mappedOffsets.find(position);
        if (it != mappedOffsets.end()) {
          OpFoldResult seenOffset = it->second;
          OpFoldResult seenSize = mappedSizes.lookup(position);
          if (seenOffset != offset || seenSize != size)
            return false;
        } else {
          mappedOffsets[position] = offset;
          mappedSizes[position] = size;
        }
      }
    }
    return true;
  }
};
```
|
|
New test file (`@@ -0,0 +1,49 @@`):

```mlir
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics

func.func @fusable_with_matching_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %c20 = arith.constant 20 : index

  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
  %slice1 = tensor.insert_slice %arg1 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>

  // expected-remark @+1 {{can be fused with producer tensor.insert_slice ops}}
  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
      outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>

  return %result : tensor<100x200xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
    %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
    transform.test.query_producer_fusability %add : !transform.any_op
    transform.yield
  }
}

// -----

func.func @not_fusable_with_different_offsets(%arg0: tensor<10x20xf32>, %arg1: tensor<10x20xf32>, %dest: tensor<100x200xf32>) -> tensor<100x200xf32> {
  %c0 = arith.constant 0 : index
  %c10 = arith.constant 10 : index
  %c20 = arith.constant 20 : index

  %slice0 = tensor.insert_slice %arg0 into %dest[%c0, %c0] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>
  %slice1 = tensor.insert_slice %arg1 into %dest[%c10, %c20] [10, 20] [1, 1] : tensor<10x20xf32> into tensor<100x200xf32>

  // expected-remark @+1 {{cannot be fused with producer tensor.insert_slice ops}}
  %result = linalg.add ins(%slice0, %slice1 : tensor<100x200xf32>, tensor<100x200xf32>)
      outs(%dest : tensor<100x200xf32>) -> tensor<100x200xf32>

  return %result : tensor<100x200xf32>
}

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg: !transform.any_op) {
    %add = transform.structured.match ops{["linalg.add"]} in %arg : (!transform.any_op) -> !transform.any_op
    transform.test.query_producer_fusability %add : !transform.any_op
    transform.yield
  }
}
```
Contributor

I think the consumer slice API has been modified to handle multiple results at the same time. I think this needs to do the same.
Contributor
Author

`getIterationDomainTileFromResultTile` still only looks at a single result at a time, so this does too. If producing multiple results via a single tiling call is implemented, I would expect the existing method to be updated to reflect that too, so I'd rather keep them consistent for now.

Contributor

Ok, I think I was confused by the terminology. This is basically saying whether you can fuse an operation with its consumer slice (as the name correctly reads), but this is for producer fusion, not the consumer fusion cases.
Contributor
Author

Consumer is in reference to the slice here, so:

- This method: tilableOp -> consumer slice
- Consumer fusion: loop op -> consumer tilable op

Consumer is unsurprisingly overloaded. I'm open to better naming suggestions if you have one (ditto below).
Contributor

Nope, your naming makes sense.