diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index c8e9ad8bd3346..d9d2164adb1ad 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -1612,19 +1612,25 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
     MLIRContext *context, ::std::optional<Location> location,
     TileOp::Adaptor adaptor,
     SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes) {
-  DenseIntElementsAttr multiplesAttr;
-  if (!matchPattern(adaptor.getMultiples(), m_Constant(&multiplesAttr)))
-    return failure();
-
-  SmallVector<int64_t> multiples = llvm::to_vector(
-      llvm::map_range(multiplesAttr.getValues<APInt>(),
-                      [](const APInt &val) { return val.getSExtValue(); }));
+  Type inputType = getElementTypeOrSelf(adaptor.getInput1().getType());
+  SmallVector<int64_t> multiples;
+  if (!tosa::getConstShapeValues(adaptor.getMultiples().getDefiningOp(),
+                                 multiples)) {
+    auto rank =
+        cast<tosa::shapeType>(adaptor.getMultiples().getType()).getRank();
+    SmallVector<int64_t> fallback(rank, ShapedType::kDynamic);
+    inferredReturnShapes.push_back(ShapedTypeComponents(fallback, inputType));
+    return success();
+  } else {
+    multiples = convertToMlirShape(multiples);
+  }
 
   ShapeAdaptor inputShape(adaptor.getInput1().getType());
   SmallVector<int64_t> outputShape;
   if (!inputShape.hasRank()) {
     outputShape.resize(multiples.size(), ShapedType::kDynamic);
-    inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+    inferredReturnShapes.push_back(
+        ShapedTypeComponents(outputShape, inputType));
     return success();
   } else if (static_cast<size_t>(inputShape.getRank()) != multiples.size())
     return failure();
@@ -1632,13 +1638,17 @@ LogicalResult tosa::TileOp::inferReturnTypeComponents(
   // Any non dynamic dimension can be multiplied to a known size.
   outputShape.reserve(multiples.size());
   for (int i = 0, s = inputShape.getRank(); i < s; i++) {
-    int64_t dim = inputShape.getDimSize(i);
-    if (dim != ShapedType::kDynamic)
-      dim *= multiples[i];
-    outputShape.push_back(dim);
+    if (multiples[i] == ShapedType::kDynamic) {
+      outputShape.push_back(ShapedType::kDynamic);
+    } else {
+      int64_t dim = inputShape.getDimSize(i);
+      if (dim != ShapedType::kDynamic)
+        dim *= multiples[i];
+      outputShape.push_back(dim);
+    }
   }
 
-  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape));
+  inferredReturnShapes.push_back(ShapedTypeComponents(outputShape, inputType));
   return success();
 }
 
diff --git a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
index 761e489bdeae5..b6fa5d6f4d0ec 100644
--- a/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir
@@ -599,6 +599,17 @@ func.func @test_tile(%arg0 : tensor<2x3x?xi32>) -> () {
 
 // -----
 
+// CHECK-LABEL: @test_tile_unknown_multiples
+func.func @test_tile_unknown_multiples(%arg0 : tensor<2x3x?xi32>) -> () {
+  // CHECK: %[[CST:.*]] = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  // CHECK: tosa.tile %arg0, %[[CST]] : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<4x?x?xi32>
+  %cst = tosa.const_shape {values = dense<[2, -1, 5]> : tensor<3xindex>} : () -> !tosa.shape<3>
+  %0 = tosa.tile %arg0, %cst : (tensor<2x3x?xi32>, !tosa.shape<3>) -> tensor<?x?x?xi32>
+  return
+}
+
+// -----
+
 // CHECK-LABEL: @test_transpose_static
 func.func @test_transpose_static(%arg0 : tensor<3x4x5xi32>) -> () {
   // CHECK: tosa.transpose %arg0 {perms = array<i32: 2, 1, 0>} : (tensor<3x4x5xi32>) -> tensor<5x4x3xi32>