From 54c3b3b4375195914c369f8ed631e73cbd2d5177 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Tue, 18 Nov 2025 11:23:09 -0500 Subject: [PATCH 1/2] [tosa] : Enhance EqualizeRanks to handle dynamic dimensions. --- mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index 62c015a85ee36..bb52d15026367 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -70,6 +70,8 @@ namespace { // If lower=[a], higher=[a, a], [a] reshaped into [1, a]. // If lower=[a], target=[a, b, a], [a] reshaped into [1, 1, a]. // If lower=[], target=[a, b, c], [] reshaped into [1, 1, 1]. +// If lower=[c], higher=[?, ?, c], [c] reshaped into [1, 1, c]. +// If lower=[?], higher=[?, ?, ?], [?] reshaped into [1, 1, ?]. LogicalResult computeReshapeOutput(ArrayRef higherRankShape, ArrayRef lowerRankShape, @@ -87,7 +89,12 @@ computeReshapeOutput(ArrayRef higherRankShape, higherRankDim = higherRankShape[i + rankDiff]; lowerRankDim = lowerRankShape[i]; - if (lowerRankDim != 1 && higherRankDim != 1 && + auto isKnownStaticShapeNotEqualToOne = [](int64_t dim) { + return dim != 1 && dim != ShapedType::kDynamic; + }; + + if (isKnownStaticShapeNotEqualToOne(lowerRankDim) && + isKnownStaticShapeNotEqualToOne(higherRankDim) && lowerRankDim != higherRankDim) return failure(); From 8820877b77e3855284eed0c3744e93d927d228c4 Mon Sep 17 00:00:00 2001 From: Sayan Saha Date: Wed, 19 Nov 2025 08:21:07 -0500 Subject: [PATCH 2/2] Update function name to be more readable. --- mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp index bb52d15026367..36e89405210a6 100644 --- a/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp +++ b/mlir/lib/Dialect/Tosa/Utils/ConversionUtils.cpp @@ -89,12 +89,12 @@ computeReshapeOutput(ArrayRef higherRankShape, higherRankDim = higherRankShape[i + rankDiff]; lowerRankDim = lowerRankShape[i]; - auto isKnownStaticShapeNotEqualToOne = [](int64_t dim) { + auto isStaticDimAndNotEqualToOne = [](int64_t dim) { return dim != 1 && dim != ShapedType::kDynamic; }; - if (isKnownStaticShapeNotEqualToOne(lowerRankDim) && - isKnownStaticShapeNotEqualToOne(higherRankDim) && + if (isStaticDimAndNotEqualToOne(lowerRankDim) && + isStaticDimAndNotEqualToOne(higherRankDim) && lowerRankDim != higherRankDim) return failure();