diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index c1210eef4e589..4d03b7b2b2064 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -1706,14 +1706,20 @@ struct ShapeOfOpToConstShapeOp : public OpRewritePattern { auto type = llvm::dyn_cast(op.getArg().getType()); if (!type || !type.hasStaticShape()) return failure(); + + Type resultType = op.getResult().getType(); Location loc = op.getLoc(); + Type constResType = + isa(resultType) + ? resultType + : RankedTensorType::get({type.getRank()}, rewriter.getIndexType()); Value constShape = - ConstShapeOp::create(rewriter, loc, + ConstShapeOp::create(rewriter, loc, constResType, rewriter.getIndexTensorAttr(type.getShape())) .getResult(); - if (constShape.getType() != op.getResult().getType()) - constShape = tensor::CastOp::create(rewriter, loc, - op.getResult().getType(), constShape); + if (constShape.getType() != resultType) + constShape = + tensor::CastOp::create(rewriter, loc, resultType, constShape); rewriter.replaceOp(op, constShape); return success(); } diff --git a/mlir/test/Dialect/Shape/canonicalize.mlir b/mlir/test/Dialect/Shape/canonicalize.mlir index f3c25b8c8100e..22add87ff3ed4 100644 --- a/mlir/test/Dialect/Shape/canonicalize.mlir +++ b/mlir/test/Dialect/Shape/canonicalize.mlir @@ -1626,3 +1626,13 @@ func.func @shape_of_0d(%arg0: tensor) -> tensor { %0 = shape.shape_of %arg0 : tensor -> tensor return %0 : tensor } + +// ----- + +// CHECK-LABEL: func @shape_of_static_with_shape_result( +func.func @shape_of_static_with_shape_result(%arg0: tensor<3xf32>) -> !shape.shape { + // CHECK: %[[const:.*]] = shape.const_shape [3] : !shape.shape + // CHECK: return %[[const]] : !shape.shape + %0 = shape.shape_of %arg0 : tensor<3xf32> -> !shape.shape + return %0 : !shape.shape +}