diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index 62c015a85ee36..36e89405210a6 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -70,6 +70,8 @@ namespace { // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c]. +// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?]. LogicalResult computeReshapeOutput(ArrayRef higherRankShape, ArrayRef lowerRankShape, @@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef higherRankShape, higherRankDim = higherRankShape[i + rankDiff]; lowerRankDim = lowerRankShape[i]; - if (lowerRankDim != 1 && higherRankDim != 1 && + auto isStaticDimAndNotEqualToOne = [](int64_t dim) { + return dim != 1 && dim != ShapedType::kDynamic; + }; + + if (isStaticDimAndNotEqualToOne(lowerRankDim) && + isStaticDimAndNotEqualToOne(higherRankDim) && lowerRankDim != higherRankDim) return failure();