diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp
index a1419322afb3a8..bb7ed5cf05cec5 100644
--- a/mlir/lib/Dialect/Shape/IR/Shape.cpp
+++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp
@@ -987,11 +987,43 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
     return success();
   }
 };
+
+// Canonicalize
+// ```
+// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+// ```
+// to
+// ```
+// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+// ```
+struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
+  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tensor::CastOp op,
+                                PatternRewriter &rewriter) const override {
+    auto ty = op.getType().dyn_cast<RankedTensorType>();
+    if (!ty || ty.getRank() != 1)
+      return failure();
+
+    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
+    if (!shapeOfOp)
+      return failure();
+
+    // Argument type must be ranked and must not conflict.
+    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
+    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
+      return failure();
+
+    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
+    return success();
+  }
+};
 } // namespace
 
 void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                             MLIRContext *context) {
-  patterns.add<ShapeOfWithTensor>(context);
+  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir
index 39f17e9d253f6d..b0c12ea0b14995 100644
--- a/mlir/test/Dialect/Shape/canonicalize.mlir
+++ b/mlir/test/Dialect/Shape/canonicalize.mlir
@@ -648,7 +648,7 @@ func @f() {
   // CHECK: shape.cstr_broadcastable
   // CHECK-NEXT: consume.witness
   // CHECK-NEXT: return
-  %cs0 = shape.const_shape [8, 1] : !shape.shape 
+  %cs0 = shape.const_shape [8, 1] : !shape.shape
   %cs1 = shape.const_shape [1, 8] : !shape.shape
   %cs2 = shape.const_shape [1, -1] : !shape.shape
   %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
@@ -1144,3 +1144,47 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
   "use"(%0) : (tensor<?xindex>) -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
+  // CHECK: return %[[RESULT]] : tensor<?xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
+  return %1 : tensor<?xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
+func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
+  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
+  // CHECK: return %[[RESULT]] : tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
+  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
+
+// -----
+
+// CHECK-LABEL: @casted_extent_tensor
+func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
+  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
+  %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
+  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
+  return %1 : tensor<3xindex>
+}
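
As a usage sketch (not part of the patch): since the new pattern is registered via `ShapeOfOp::getCanonicalizationPatterns`, the generic canonicalizer picks it up, so it can be exercised with a plain `mlir-opt -canonicalize` run over a small reproducer. The file name below is illustrative; the test file's own RUN line remains the authoritative check.

```
// repro.mlir -- hypothetical standalone reproducer.
// Running `mlir-opt -canonicalize repro.mlir` should collapse the
// shape.shape_of / tensor.cast pair into a single shape.shape_of that
// directly produces tensor<?xindex>, making the cast dead.
func @repro(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
  return %1 : tensor<?xindex>
}
```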