[mlir][linalg] Enable fuse consumer #85528
@llvm/pr-subscribers-mlir-tensor @llvm/pr-subscribers-mlir
Author: donald chen (cxy-1993)
Changes
This patch adds support for consumer fusion to the tiling interface, and implements fuse consumers on FuseIntoContainingOp.
Patch is 40.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85528.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bdeab55091b9f3..2c501a3ecb14f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -310,51 +310,61 @@ def FuseIntoContainingOp :
["allowsRepeatedHandleOperands"]>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
ReportTrackingListenerFailuresOpTrait]> {
- let summary = "Fuse a producer into a containing operation.";
+ let summary = "Fuse a target into a containing operation.";
let description = [{
- Fuses the `producer_op` into the `containing_op`.
+ Fuses the `target_op` into the `containing_op`.
Returns a handle to the fused ops and the `new_containing_op`.
- The producer is typically a slice of a tileable op (i.e., implements
- TilingInterface). In that case, this transform computes the accessed
- producer slice inside of the containing op ("tile and fuse") and if required,
- creates a new containing op with outputs from the fused producer. Otherwise,
- the entire producer is cloned inside the containing op ("clone and fuse").
+ This operation supports fusion of producer or fusion of consumer. We will
+ refer to the value connecting the containing operation and the target
+ operation as the "bridge" below.
+
+ When fusing a producer, the bridge is typically a slice of a tileable op (i.e.,
+ implements TilingInterface). In that case, this transform computes the
+ accessed bridge slice inside of the containing op ("tile and fuse") and
+ if required, creates a new containing op with outputs from the fused target.
+ Otherwise, the entire target is cloned inside the containing op ("clone
+ and fuse").
+
+ When fusing a consumer, the bridge is the result of the containing op and an operand
+ of a tileable op (i.e., implements TilingInterface). In this case, this
+ transform computes the accessed bridge slice inside the containing op ("tile
+ and fuse") and creates a new containing op with the consumer's output.
The containing op handle must be associated with exactly one payload op. The
- producer op handle may be associated with multiple payload ops. This
- transform fuses producers one-by-one, always picking an unspecified producer
+ target op handle may be associated with multiple payload ops. This
+ transform fuses targets one-by-one, always picking an unspecified target
that has at least one use inside the containing op among the
- producers. A producer can be listed multiple times in the handle.
+ targets. A target can be listed multiple times in the handle.
- Note: If a producer has multiple uses inside the containing op, it is
+ Note: If a target has multiple uses inside the containing op, it is
currently tiled and/or cloned multiple times into the containing op.
TODO: Reuse already fused OpResults instead of tiling/cloning a second time
- when possible. Fuse producers according to a topological sorting to achieve
+ when possible. Fuse targets according to a topological sorting to achieve
the largest amount of reuse.
#### Return modes
- If at least one producer could not be fused, this operation produces a
+ If at least one target could not be fused, this operation produces a
silenceable failure. This is the case when tiling fails or when no
- producer op could be found among the remaining producers that has at least
- one use within the containing op. I.e., "producers" that are not consumed
+ target op could be found among the remaining targets that has at least
+ one use within the containing op. I.e., "targets" that are not consumed
within the containing op are rejected by this operation.
- This operation consumes the producer handle.
+ This operation consumes the target handle.
This operation only reads the containing op handle.
}];
- let arguments = (ins TransformHandleTypeInterface:$producer_op,
+ let arguments = (ins TransformHandleTypeInterface:$target_op,
TransformHandleTypeInterface:$containing_op);
let results = (outs TransformHandleTypeInterface:$fused_op,
TransformHandleTypeInterface:$new_containing_op);
- let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+ let assemblyFormat = "$target_op `into` $containing_op attr-dict "
" `:` functional-type(operands, results)";
let builders = [
- OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+ OpBuilder<(ins "Value":$targetOp, "Value":$containingOp)>
];
}
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return iterator domain position computed by the
+ input operand position.
+ }],
+ /*retType=*/"LogicalResult",
+ /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVector<OpFoldResult> &":$iterDomainSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ operand position.
+
+ Generates the IR that generate the tiled implementation of an
+ operation from operand position. The `offsets` and `sizes`
+ describe the tile of the operand required. This is different from
+ `getTiledImplementation` which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. This method generates a tiled
+ implementation of the operation based on the position of the
+ operand required. This method enables fusion consumer by using
+ tile and fuse. The method returns failure if the operation
+ can't be tiled to generate the operand tile. In practical terms
+ this implies it cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile in the coordinate system
+ of the original iteration space, i.e., if an iteration space
+ dimension had non-zero offset, it must be included in the offset
+ provided here (as opposed to zero-based offset "relative" to the
+ iteration space).
+ - `sizes` provides the size of the tile.
+ }],
+ /*retType=*/"FailureOr<TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandPosition",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..ecffb910b236e8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -546,9 +546,9 @@ LogicalResult transform::FuseOp::verify() {
void transform::FuseIntoContainingOp::build(OpBuilder &builder,
OperationState &result,
- Value producerOp,
+ Value targetOp,
Value containingOp) {
- result.addOperands({producerOp, containingOp});
+ result.addOperands({targetOp, containingOp});
auto resultType = transform::AnyOpType::get(builder.getContext());
result.addTypes({resultType, resultType});
}
@@ -631,6 +631,223 @@ static Operation *replaceForAllWithNewSignature(
return newforallOp;
}
+static std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseParallelInsertSlice(RewriterBase &rewriter, Diagnostic &diag,
+ Operation *consumerOp, Operation *containingOp) {
+ // Check consumer has tiling interface.
+ LLVM_DEBUG(DBGS() << "Try to fuse a consumer\n");
+ auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+ if (!tileableConsumer) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer is not a TileableInterface: " << *consumerOp;
+ return {};
+ }
+
+ // Check containing op is "scf::ForallOp".
+ auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+ if (!forallOp) {
+ diag.attachNote(containingOp->getLoc())
+ << "containing op is not a scf.forall: " << containingOp;
+ return {};
+ }
+
+ // Check dominance.
+ DominanceInfo domInfo(
+ containingOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>());
+ if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+ return v.getDefiningOp() != containingOp &&
+ !domInfo.properlyDominates(v, containingOp);
+ })) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer's operand can't dominate containing op";
+ return {};
+ }
+
+ // Check that the consumer doesn't use more than one result of containingOp.
+ Value bridge(nullptr);
+ SmallVector<unsigned> operandNums;
+ for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+ if (opd.getDefiningOp() == containingOp) {
+ operandNums.push_back(idx);
+ if (!bridge) {
+ bridge = opd;
+ } else if (bridge != opd) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer's operand use more than one containingOp's result";
+ return {};
+ }
+ }
+ }
+
+ // TODO: We have to init result of consumer before scf.forall, use
+ // DestinationStyleOpInterface to get result shape from init for now.
+ // Add support for other op such as op has InferTypeOpInterface.
+ // Check consumer has DestinationStyleOpInterface.
+ auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+ if (!dstOp) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer op should have destination style op interface";
+ return {};
+ }
+
+ // Check that the consumer doesn't use scf.forall's output as init.
+ SmallVector<Value> dpsInits = llvm::to_vector<4>(
+ llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+ if (llvm::is_contained(dpsInits, bridge)) {
+ diag.attachNote(consumerOp->getLoc())
+ << "consumer op take result of scf.forall as init";
+ return {};
+ }
+
+ // Check result was inserted only once.
+ int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+ auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+ scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+ tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+ for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+ auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+ if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+ if (!targetInsertOp) {
+ targetInsertOp = parallelInsertSliceOp;
+ } else {
+ diag.attachNote(containingOp->getLoc())
+ << "containingOp's result inserted multi time";
+ return {};
+ }
+ }
+ }
+
+ if (!targetInsertOp) {
+ diag.attachNote(containingOp->getLoc())
+ << "containingOp's result was not inserted";
+ return {};
+ }
+
+ SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+ SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+ SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+ // Check all insert stride is 1.
+ if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+ if (auto attr = foldRes.dyn_cast<Attribute>()) {
+ return cast<IntegerAttr>(attr).getInt() != 1;
+ }
+ return true;
+ })) {
+ diag.attachNote(containingOp->getLoc())
+ << "containingOp's result yield with stride";
+ return {};
+ }
+
+ Location loc = forallOp.getLoc();
+ OpBuilder::InsertionGuard g(rewriter);
+ rewriter.setInsertionPoint(terminatorOp);
+
+ SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+ // Try to get iter domain position from input position.
+ if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+ iterDomainSizes))) {
+ diag.attachNote(consumerOp->getLoc())
+ << "can't get iter domain position from input position";
+ return {};
+ }
+
+ // Try to get all containing op result's position from iter domain position.
+ llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+ llvm::SmallVector<OpFoldResult>>>
+ resultPositions(consumerOp->getNumResults());
+ for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+ if (failed(tileableConsumer.getResultTilePosition(
+ rewriter, idx, iterDomainOffsets, iterDomainSizes,
+ resultPositions[idx].first, resultPositions[idx].second))) {
+ diag.attachNote(consumerOp->getLoc())
+ << "can't get result domain position from iter domain position";
+ return {};
+ }
+ }
+
+ // All check passed, try to fuse consumer.
+ // Create tiled implementation of containing op.
+ FailureOr<TilingResult> tileAndFuseResult =
+ tileableConsumer.getTiledImplementationFromOperandPosition(
+ rewriter, operandNums.front(), offsets, sizes);
+ if (failed(tileAndFuseResult)) {
+ diag.attachNote(consumerOp->getLoc()) << "get tiled implementation failed";
+ return {};
+ }
+
+ auto tiledOps = tileAndFuseResult->tiledOps;
+ if (failed(tileAndFuseResult) || tiledOps.size() != 1) {
+ diag.attachNote(tileableConsumer->getLoc())
+ << "failed to tile consumer op: " << *tileableConsumer;
+ return {};
+ }
+
+ // Replace tiled op's operand .
+ for (auto operandNum : operandNums) {
+ tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+ }
+ rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+ [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return forallOp->isProperAncestor(op);
+ });
+
+ SmallVector<Value> newOuts(forallOp.getOutputs());
+ newOuts.append(dpsInits);
+
+ // Create new scf.forall op.
+ rewriter.setInsertionPoint(forallOp);
+ auto newforallOp = rewriter.create<scf::ForallOp>(
+ loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+ forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+ rewriter.eraseBlock(newforallOp.getBody());
+ newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+ for (auto v : dpsInits) {
+ newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+ Operation *op = use.getOwner();
+ return newforallOp->isProperAncestor(op);
+ });
+ }
+
+ // Fix terminator.
+ scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+ SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+ newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+ Operation *firstYieldOp = yieldingOps.front();
+ rewriter.setInsertionPoint(firstYieldOp);
+ auto bbArgs = newforallOp.getBody()->getArguments();
+ for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+ SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+ rewriter.getIndexAttr(1));
+ rewriter.create<tensor::ParallelInsertSliceOp>(
+ firstYieldOp->getLoc(), v,
+ bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+ resultPositions[idx].first, resultPositions[idx].second, strides);
+ }
+
+ // Replace the result of forall and consumer op.
+ for (auto result : llvm::enumerate(forallOp.getResults())) {
+ rewriter.replaceAllUsesWith(result.value(),
+ newforallOp->getResult(result.index()));
+ }
+
+ for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+ rewriter.replaceAllUsesWith(
+ consumerResult.value(),
+ newforallOp->getResult(forallOp.getOutputs().size() +
+ consumerResult.index()));
+ }
+
+ return std::make_tuple(tileAndFuseResult->tiledOps, newforallOp);
+}
+
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
@@ -880,7 +1097,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
SmallVector<Operation *> fusedOps;
- auto producerOps = state.getPayloadOps(getProducerOp());
+ auto targetOps = state.getPayloadOps(getTargetOp());
auto containingOps = state.getPayloadOps(getContainingOp());
if (!llvm::hasSingleElement(containingOps)) {
return emitDefiniteFailure()
@@ -890,69 +1107,115 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
Operation *containingOp = *containingOps.begin();
// If nothing to fuse, propagate success.
- if (std::empty(producerOps)) {
+ if (std::empty(targetOps)) {
results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
return DiagnosedSilenceableFailure::success();
}
- // Helper function to find the next producer that should be fused. Take any
- // producer that has a use inside the containing op.
- SetVector<Operation *> remainingProducers(producerOps.begin(),
- producerOps.end());
- auto getNextProducer = [&]() -> FailureOr<Operation *> {
- for (const auto &it : enumerate(remainingProducers)) {
- Operation *producerOp = it.value();
- // The containing op may be a user of producerOp: use isAncestor.
+ // Helper function to find the next target that should be fused. Take any
+ // target that has a use inside the containing op. Return target operation
+ // and a bool variable indicate if this target op is a producer.
+ SetVector<Operation *> remainingTargets(targetOps.begin(), targetOps.end());
+ auto getNextTarget = [&]() -> FailureOr<std::pair<Operation *, bool>> {
+ for (const auto &it : enumerate(remainingTargets)) {
+ Operation *targetOp = it.value();
+ // The containing op may be a user of targetOp: use isAncestor.
int64_t numUsesInContainingOp =
- llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+ llvm::count_if(targetOp->getUsers(), [&](Operation *op) {
return containingOp->isAncestor(op);
});
// TODO: When resolving the TODO below (no duplicate ops), take an op
- // that has no use among the remaining producers. This is a topological
+ // that has no use among the remaining targets. Th...
[truncated]
@MaheshRavishankar Could you please help me review this patch?
7ca5ae5 to c79772d
Thanks! Yes, I will definitely review the patch. I am in the middle of a crunch week. I'll probably get to it in about 10 days, if that is OK.
Don't worry, there's no rush. Take your time.
Thanks for the changes. I looked at the changes to the `TilingInterface`... those look good. Just left a few nits.
I missed this, but could you look at the implementation in https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp instead? That file does tile consumer and fuse producer using the `LoopLikeOpInterface` and works for both `scf.for` and `scf.forall`. The path that is here in `LinalgTransforms` is supposed to be deprecated (by me when I get some time). Last time I touched it, I made things fairly modular and readable (afaik), and just following that but instead fusing with consumers would be a great help. If added here, it will just be on my list of things to deprecate and move to the above file anyway.
I haven't looked at the `LinalgTransforms.*` changes (because I am also not an expert in the transform dialect). I think it would be better to make these methods available outside of the transform dialect, callable as standalone utility functions (like is done in the implementation I linked to above).
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return iterator domain position computed by the
I think this isn't as clear. Please see the description for `getResultTilePosition`; maybe something closer to that (or improve both if you have a better description).
DONE
input operand position.
}],
/*retType=*/"LogicalResult",
/*methodName=*/"getIterDomainTilePositionFromOperandPosition",
Nit: Maybe rename to `getIterationDomainTileFromOperandTile`?
DONE rename to getIterationDomainTileFromOperandTile
Method to generate the tiled implementation of an operation from
operand position.

Generates the IR that generate the tiled implementation of an
Nit: Generates the IR that computes the tiled implementation of an operation based on the operand tile
DONE
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
InterfaceMethod<
Nit: Just a matter of reading. At least for me, fusing producer with consumer is more natural and is then easy to make the leap to fuse consumer with producer. So maybe position these methods textually after the corresponding methods that deal with the producer -> consumer fusion.
DONE, moved the methods after producer fusion
`getTiledImplementation` which generates the tiled
implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the position of the
Nit: This method generates a tiled implementation of the operation based on the tile of the operand required.
DONE
@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
Nit: just make the operand `LinalgOp linalgOp` type instead of `Operation *op`. You can then do `auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());`
Thanks for your comments! It's crunch week for me now, I'll address your review comments ASAP.
DONE
Hi, some of this was used in #88712. There are two things needed here..
To coordinate between these two PRs, can we first land the interface methods? The implementation of (2) here for just `scf.forall` has some of the same issues that I highlighted in the other PR... You can wait for that to land (and help contribute/review) and then add (3). So can we split this PR first to just land the interface methods?
Thank you for your review and reminders. I have already extracted the interface part separately and addressed your review comments accordingly. Could you please help review it again?
This needs tests.
implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the tile of the
operand required. This method enables fusion consumer by using
- operand required. This method enables fusion consumer by using
+ operand required. This method enables consumer fusion by using
done
"unsigned":$operandNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVector<OpFoldResult> &":$iterDomainOffsets,
I see that this is cargo-culted from the method above, but we should be using `SmallVectorImpl` to avoid hardcoding the implicit number of stack elements. `SmallVector` is in fact `SmallVector<6>` in this case.
Agree. I changed parameter type to SmallVectorImpl as well as the code I cargo-culted from.
Method to return the position of iteration domain tile computed by the
tiled operation.
}],
/*retType=*/"LogicalResult",
Nit: prefer globally-qualified types in .td, `::mlir::LogicalResult`. Generated code may be #included in another namespace.
nice catch, done
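The qualification nit above can be illustrated outside MLIR. The following is a minimal sketch, not code from the patch: `mlir::LogicalResult` here is a hypothetical stand-in struct, and `client`, `Unrelated`, and `generatedMethod` are made-up names. It shows why a leading `::` keeps a generated declaration valid even when it is included inside a namespace that has its own nested `mlir`.

```cpp
#include <cassert>

namespace mlir {
// Hypothetical stand-in for the real mlir::LogicalResult.
struct LogicalResult {
  bool succeeded;
};
} // namespace mlir

namespace client {
// An unrelated nested namespace that happens to be named `mlir`.
namespace mlir {
struct Unrelated {};
} // namespace mlir

// Imagine the next declaration came from a generated .inc file included here.
// Spelled as `mlir::LogicalResult`, the name `mlir` would resolve to
// `client::mlir`, which has no `LogicalResult`, and compilation would fail.
// The leading `::` forces lookup to start at the global namespace.
inline ::mlir::LogicalResult generatedMethod() {
  return ::mlir::LogicalResult{true};
}
} // namespace client
```

Calling `client::generatedMethod()` returns the global `mlir::LogicalResult`; dropping the leading `::` in this sketch would turn the declaration into a compile error.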
ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &mappedOffsets,
SmallVector<OpFoldResult> &mappedSizes) const {
auto numLoops = linalgOp.getNumLoops();
Nit: expand `auto` unless the type is obvious from line-level context, e.g., there is a cast such as below.
done
for (const auto &range : llvm::enumerate(iterationDomain)) {
mappedOffsets[range.index()] = range.value().offset;
mappedSizes[range.index()] = range.value().size;
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- mappedOffsets[range.index()] = range.value().offset;
- mappedSizes[range.index()] = range.value().size;
+ for (auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[index] = value.offset;
+ mappedSizes[index] = value.size;
done
mappedSizes[range.index()] = range.value().size;
}
}
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
Same as above.
done
AffineMap indexingMap =
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
if (!indexingMap.isProjectedPermutation()) {
return op->emitOpError(
Nit: `emitOpError` is for errors that are related to the validity of the op, which this is not. Use `emitError` instead.
done
Thanks for your time. I have modified the code based on your comments (as well as the original code). As for tests, I originally implemented a transform op in the linalg dialect and added tests. After the discussion above, we decided to only add the interface part for now. The tests for these parts require the implementation of the transform op, so can we implement the interface first and add tests for the transform op later?
@Abhishek-Varma please take a look at this PR and approve if this looks fine to you and you can land your PR on top of this. For me this looks fine.
Thanks for this @cxy-1993 - this has been a really good nudge for fusing consumers!
A few comments to address before merging. LGTM otherwise.
Thanks for the comment @Abhishek-Varma. The code has been modified according to your comment. I don't have write access to this repo; please help me merge this patch into master.
This patch adds support for consumer fusion to the tiling interface.
- Add interface method 'getIterationDomainTileFromOperandTile' to the tiling interface, which gets the iteration domain position from the operand position.
- Add interface method 'getTiledImplementationFromOperandTile' to the tiling interface, which generates the tiled implementation according to the operand position.
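To make the first of these interface methods concrete, here is a minimal, MLIR-free sketch of mapping an operand tile back to an iteration-domain tile. It assumes the operand's indexing map is a projected permutation, as the review snippets above check, and it uses plain integers where the real interface uses `OpFoldResult`. The names `LoopRange`, `Tile`, and `iterationDomainTileFromOperandTile` are hypothetical and are not the API added by the patch.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <vector>

// Hypothetical stand-ins for MLIR's Range and offset/size lists (static only).
struct LoopRange {
  int64_t offset;
  int64_t size;
};

struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// operandDimToLoop[d] is the loop that indexes operand dimension d, i.e. the
// results of a projected-permutation indexing map. Returns std::nullopt when
// the map is not a projected permutation or the ranks disagree.
std::optional<Tile> iterationDomainTileFromOperandTile(
    const std::vector<LoopRange> &iterationDomain,
    const std::vector<unsigned> &operandDimToLoop, const Tile &operandTile) {
  const std::size_t numLoops = iterationDomain.size();
  const std::size_t rank = operandDimToLoop.size();
  if (operandTile.offsets.size() != rank || operandTile.sizes.size() != rank)
    return std::nullopt;

  // Start from the full iteration domain: loops the operand does not index
  // keep their whole range.
  Tile result;
  for (const LoopRange &range : iterationDomain) {
    result.offsets.push_back(range.offset);
    result.sizes.push_back(range.size);
  }

  // Loops that index the operand take the operand tile's offset and size.
  std::vector<bool> seen(numLoops, false);
  for (std::size_t d = 0; d < rank; ++d) {
    const unsigned loop = operandDimToLoop[d];
    if (loop >= numLoops || seen[loop])
      return std::nullopt; // Out of range or repeated loop.
    seen[loop] = true;
    result.offsets[loop] = operandTile.offsets[d];
    result.sizes[loop] = operandTile.sizes[d];
  }
  return result;
}
```

For a matmul-like op with loops (m, n, k) over ranges of size 128, 64, and 32, an operand indexed by (m, k) with tile offsets {16, 0} and sizes {8, 32} maps to iteration-domain offsets {16, 0, 0} and sizes {8, 64, 32}: the n loop, which the operand does not touch, keeps its full range.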
@Abhishek-Varma, @cxy-1993, @ftynse gave a clear review that this needs tests. It is not acceptable to land randomly untested code! Please revert, add tests and resend for review.
Definitely needs tests, but I didn't see a need for revert. Should be fine to add tests after the fact. @cxy-1993 and @Abhishek-Varma could we add some tests here?
For practicality reasons, I'd prefer a revert and a new PR that includes tests as opposed to just adding tests. This will help us, the reviewers, ascertain that the tests are actually testing what they should be testing, there is enough coverage, etc. It also makes it easy to bisect and revert the change should the need arise.
yeah, let's instead be consistent with what we've been doing so far: https://mlir.llvm.org/getting_started/TestingGuide/ Additionally, I believe this should also not have landed for the same reason as I am objecting to #88712. I'll just copy-paste here: Now, I am afraid I cannot subscribe to any large code changes to tiling transforms until #77874 is addressed to my satisfaction. I already gave a pass 6 months ago and another one 3 months ago when things were supposed to be addressed "in short order". We're now past tech-debt reduction time: first let's address [mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. #77874,
We can't block code progress because of that. I have been reviewing all the code and have been keeping track of all the changes. This is also a strict addition that isn't part of what needs to be undone from the legacy code. It is still on my plate to address this. Everyone has time commitments that they have to balance. I will get to it in short order, but we can't hold things hostage because of that.
Let's keep the bar high. Alex's suggestion is the right one: revert and land with appropriate tests.
Can we get a f2f to talk about this separately? I've had my eye on this for a long time too, and there is a significant hairball that (believe it or not) is being undone. It is just hard and taking more iteration than expected. (I imagine if we talked through it, we might be able to see a shortcut)
Reverts #85528. This was committed without tests, despite reviewers requesting tests to be added. The post-commit discussion leans towards revert, which would be consistent with the policy.
Reverted in f220c35. Please add tests, open a new PR, and request review from the same people.
New PR: #89893. It just adds tests compared to this version. Please help me review it, thanks. @MaheshRavishankar @ftynse @Abhishek-Varma @nicolasvasilache