[mlir][tensor] Fold tensor.cast into tensor.collapse_shape op
This commit folds a `tensor.cast` op into a `tensor.collapse_shape` op
when the following two conditions are met:
1. the `tensor.collapse_shape` op consumes the result of the `tensor.cast` op.
2. the `tensor.cast` op casts to a more dynamic version of the source tensor.
This is added as a canonicalization pattern on the `tensor.collapse_shape` op.

Signed-Off-By: Gaurav Shukla <gaurav@nod-labs.com>

Reviewed By: mravishankar

Differential Revision: https://reviews.llvm.org/D130650
Shukla-Gaurav authored and Prashant Kumar committed Jul 28, 2022
1 parent 8a61749 commit 7d6ef5c
Showing 2 changed files with 50 additions and 4 deletions.
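For illustration only (not part of the commit): a minimal sketch of the rewrite the new pattern enables, using made-up shapes and value names. The final cast keeps the original collapse result type; folding it further with any downstream casts is left to the existing cast canonicalizations.

Before canonicalization:

  %0 = tensor.cast %arg0 : tensor<4x6x16xf32> to tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>

After folding the cast into the collapse, the collapse consumes the statically shaped source and a cast restores the original result type:

  %0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<4x6x16xf32> into tensor<24x16xf32>
  %1 = tensor.cast %0 : tensor<24x16xf32> to tensor<?x?xf32>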
mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (36 additions, 4 deletions)
@@ -928,6 +928,36 @@ struct FoldReshapeWithFromElements : OpRewritePattern<TensorReshapeOp> {
  }
};

// Fold CastOp into CollapseShapeOp when adding static information.
struct FoldCollapseOfCastOp : public OpRewritePattern<CollapseShapeOp> {
  using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CollapseShapeOp collapseShapeOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = collapseShapeOp.getSrc().getDefiningOp<tensor::CastOp>();
    if (!tensor::canFoldIntoConsumerOp(castOp))
      return failure();

    RankedTensorType srcType =
        castOp.getSource().getType().cast<RankedTensorType>();
    RankedTensorType newResultType = computeTensorReshapeCollapsedType(
        srcType, collapseShapeOp.getReassociationMaps());

    if (newResultType == collapseShapeOp.getResultType()) {
      rewriter.updateRootInPlace(collapseShapeOp, [&]() {
        collapseShapeOp.getSrcMutable().assign(castOp.getSource());
      });
    } else {
      auto newOp = rewriter.create<CollapseShapeOp>(
          collapseShapeOp.getLoc(), newResultType, castOp.getSource(),
          collapseShapeOp.getReassociation());
      rewriter.replaceOpWithNewOp<tensor::CastOp>(
          collapseShapeOp, collapseShapeOp.getResultType(), newOp);
    }
    return success();
  }
};

} // namespace

void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -940,10 +970,12 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,

void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                                  MLIRContext *context) {
  results.add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
              ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
              FoldReshapeWithConstant<CollapseShapeOp>,
              FoldReshapeWithFromElements<CollapseShapeOp>>(context);
  results
      .add<ComposeReassociativeReshapeOps<CollapseShapeOp>,
           ComposeCollapseOfExpandOp<CollapseShapeOp, ExpandShapeOp>,
           FoldReshapeWithConstant<CollapseShapeOp>,
           FoldReshapeWithFromElements<CollapseShapeOp>, FoldCollapseOfCastOp>(
          context);
}

OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
mlir/test/Dialect/Tensor/canonicalize.mlir (14 additions, 0 deletions)
@@ -673,6 +673,20 @@ func.func @compose_expand_of_expand_of_zero_dim(%arg0 : tensor<f32>)

// -----

// CHECK-LABEL: func.func @collapse_of_cast(
// CHECK-SAME: %[[IN:.*]]: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
// CHECK-NEXT: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[IN]] {{\[}}[0, 1], [2]] : tensor<8x12x32xf32> into tensor<96x32xf32>
// CHECK-NEXT: %[[CAST:.*]] = tensor.cast %[[COLLAPSE]] : tensor<96x32xf32> to tensor<?x32xf32>
// CHECK-NEXT: return %[[CAST]] : tensor<?x32xf32>
func.func @collapse_of_cast(%t: tensor<8x12x32xf32>) -> tensor<?x32xf32> {
  %0 = tensor.cast %t : tensor<8x12x32xf32> to tensor<?x?x?xf32>
  %1 = tensor.collapse_shape %0 [[0, 1], [2]] : tensor<?x?x?xf32> into tensor<?x?xf32>
  %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<?x32xf32>
  return %2 : tensor<?x32xf32>
}

// -----

func.func @fold_collapse_of_expand(%arg0 : tensor<12x4xf32>) -> tensor<12x4xf32> {
  %0 = tensor.expand_shape %arg0 [[0, 1], [2]]
      : tensor<12x4xf32> into tensor<3x4x4xf32>
