[MLIR][Shape] Canonicalize casted dynamic extent tensor
Differential Revision: https://reviews.llvm.org/D99161
frgossen committed Mar 29, 2021
1 parent c6e5c46 commit 630afc6
Showing 2 changed files with 78 additions and 2 deletions.
34 changes: 33 additions & 1 deletion mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -987,11 +987,43 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict.
    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
-  patterns.add<ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
}

//===----------------------------------------------------------------------===//
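The patterns are registered via ShapeOfOp::getCanonicalizationPatterns, so they take effect whenever the canonicalizer runs (for example under the standard -canonicalize pass). For context, here is a minimal sketch, not part of this commit, of how the same patterns could be applied directly with MLIR's greedy pattern rewrite driver; the helper canonicalizeShapeOfOps below is hypothetical.

// Illustrative sketch (not from this commit): apply the ShapeOfOp
// canonicalization patterns, including ShapeOfCastedExtentTensor, with the
// greedy pattern rewrite driver.
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Hypothetical helper: canonicalize shape.shape_of ops in `module`.
static void canonicalizeShapeOfOps(ModuleOp module) {
  MLIRContext *context = module.getContext();
  RewritePatternSet patterns(context);
  // Collects ShapeOfCastedExtentTensor and ShapeOfWithTensor.
  shape::ShapeOfOp::getCanonicalizationPatterns(patterns, context);
  // Apply the patterns greedily until no further rewrites apply.
  (void)applyPatternsAndFoldGreedily(module.getOperation(),
                                     std::move(patterns));
}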
46 changes: 45 additions & 1 deletion mlir/test/Dialect/Shape/canonicalize.mlir
@@ -648,7 +648,7 @@ func @f() {
  // CHECK: shape.cstr_broadcastable
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [8, 1] : !shape.shape
  %cs1 = shape.const_shape [1, 8] : !shape.shape
  %cs2 = shape.const_shape [1, -1] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
@@ -1144,3 +1144,47 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
"use"(%0) : (tensor<?xindex>) -> ()
return
}

// -----

// CHECK-LABEL: @casted_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
  // CHECK: return %[[RESULT]] : tensor<?xindex>
  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
  return %1 : tensor<?xindex>
}

// -----

// CHECK-LABEL: @casted_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
  // CHECK: return %[[RESULT]] : tensor<3xindex>
  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  return %1 : tensor<3xindex>
}

// -----

// CHECK-LABEL: @casted_extent_tensor
func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
  %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  return %1 : tensor<3xindex>
}

// -----

// CHECK-LABEL: @casted_extent_tensor
func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
  %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  return %1 : tensor<3xindex>
}
