Merged
23 changes: 19 additions & 4 deletions mlir/lib/Dialect/MemRef/Transforms/ResolveShapedTypeResultDims.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Interfaces/InferTypeOpInterface.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -131,11 +132,25 @@ struct IterArgsToInitArgs : public OpRewritePattern<tensor::DimOp> {
     auto blockArg = dyn_cast<BlockArgument>(dimOp.getSource());
     if (!blockArg)
       return failure();
-    auto loopLikeOp =
-        dyn_cast<LoopLikeOpInterface>(blockArg.getParentBlock()->getParentOp());
-    if (!loopLikeOp)
+    // TODO: Enable this for LoopLikeOpInterface. This is restricted to
+    // scf.forall because an scf.for init arg's shape might change in the
+    // loop body. For example:
+    // ```
+    // %0 = tensor.empty(%c1) : tensor<?xf32>
+    // %r = scf.for %iv = %c0 to %c10 step %c1 iter_args(%arg0 = %0)
+    //     -> (tensor<?xf32>) {
+    //   %1 = tensor.dim %arg0, %c0 : tensor<?xf32>
+    //   %2 = arith.addi %c1, %1 : index
+    //   %3 = tensor.empty(%2) : tensor<?xf32>
+    //   scf.yield %3 : tensor<?xf32>
+    // }
+    // ```
+    auto forAllOp =
+        dyn_cast<scf::ForallOp>(blockArg.getParentBlock()->getParentOp());
+    if (!forAllOp)
       return failure();
-    Value initArg = loopLikeOp.getTiedLoopInit(blockArg)->get();
+    Value initArg = forAllOp.getTiedLoopInit(blockArg)->get();
     rewriter.modifyOpInPlace(
         dimOp, [&]() { dimOp.getSourceMutable().assign(initArg); });
     return success();

Review thread on this hunk:

Member: I think it would make sense to move this pattern to the same file.

Contributor: I think that pattern is wrong, for the same reason this one is being restricted to forall only: the init arg shape can vary from iteration to iteration. There it is added as a canonicalization, which is even worse.

Member: This pattern has an additional isShapePreserving check that makes it safe. By the way, that check is quite conservative; it could be improved by checking for DestinationStyleOpInterface instead of hard-coding a few ops in the analysis.

Member Author: @MaheshRavishankar @matthias-springer I'll be moving the pattern from LoopCanonicalization to resolveShapedTypeResultDims and updating it to process scf.forall. That sounds reasonable to me.

Member Author: I'll be merging this and refactoring in upcoming PRs.
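The counterexample in the code comment can be illustrated outside MLIR. Below is a minimal plain-Python sketch (hypothetical, not part of the patch) of why forwarding a `tensor.dim` on an `scf.for` iter_arg to the loop's init value would be unsound: the iteration argument's shape can differ from the init value's shape after the first iteration.

```python
# Plain-Python analogue of the scf.for counterexample above: the iteration
# argument's length (the "dim") grows on every trip, so it does not stay
# equal to the init value's length.

def run_loop(trip_count: int):
    init = [0.0]                 # %0 = tensor.empty(%c1) : tensor<?xf32>
    arg = init
    dims_seen = []
    for _ in range(trip_count):  # scf.for with iter_args(%arg0 = %0)
        d = len(arg)             # %1 = tensor.dim %arg0, %c0
        dims_seen.append(d)
        arg = [0.0] * (d + 1)    # %3 = tensor.empty(%2); scf.yield %3
    return dims_seen, len(init)

dims_seen, init_dim = run_loop(3)
# dims_seen is [1, 2, 3]; rewriting every dim query to the init value
# would instead have produced the constant 1 on every iteration.
```

An `scf.forall` shared out (init) arg, by contrast, keeps its shape across the body, which is why the pattern is safe when restricted to that op.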