Skip to content

Conversation

@sahas3
Copy link
Member

@sahas3 sahas3 commented Nov 19, 2025

This PR relaxes the validation checks to allow input/output data to have dynamic batch dimensions.

@llvmbot
Copy link
Member

llvmbot commented Nov 19, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tosa

Author: Sayan Saha (sahas3)

Changes

This PR relaxes the validation checks to allow input/output data to have dynamic batch dimensions.


Full diff: https://github.com/llvm/llvm-project/pull/168764.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+7-2)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp (+14-4)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+23)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir (+21)
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 0bec0da3f4320..022476a2f44cf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
     ShapedType weightType = cast<ShapedType>(weight.getType());
     ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
 
-    if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
-          resultType.hasStaticShape())) {
+    // Any dimensions other than batchSize cannot be dynamic for input/output
+    for (unsigned int i = 1; i < 4; ++i) {
+      if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
+        return failure();
+    }
+
+    if (!weightType.hasStaticShape()) {
       return failure();
     }
 
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b0abad5..8b23fd1341bc5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ class TransposeConvNonStridedConverter
     if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
       return failure();
 
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // Any dimensions other than batchSize cannot be dynamic for input/output
+    for (unsigned int i = 1; i < 4; ++i) {
+      if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+        return failure();
+    }
+
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return failure();
 
     int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ class TransposeConvStridedConverter
     if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
       return rewriter.notifyMatchFailure(op, "non-one stride found.");
 
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+    // Any dimensions other than batchSize cannot be dynamic for input/output
+    for (unsigned int i = 1; i < 4; ++i) {
+      if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+        return failure();
+    }
+
+    if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
       return failure();
 
     int64_t batch = inputTy.getDimSize(0);
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index c7eeb5281679b..d4c4595e84ee0 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar
   %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
   return %0 : tensor<4x10x10x6xi32>
 }
+
+// -----
+// CHECK-LABEL:   func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(
+// CHECK-SAME:      %[[INP:.*]]: tensor<?x10x10x2xf32>,
+// CHECK-SAME:      %[[WTS:.*]]: tensor<1x1x2x3xf32>,
+// CHECK-SAME:      %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+// CHECK:           %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK:           %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK:           %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape  {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK:           %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32>
+// CHECK:           %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32>
+// CHECK:           %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32>
+// CHECK:           %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32>
+// CHECK:           %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
+// CHECK:           %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32>
+// CHECK:           return %[[RES]]
+func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>
+  return %0 : tensor<?x10x10x6xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 810135f6f531b..61ca0aedf6a46 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
     (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
   "func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
 }
+
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> {
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32>
+  return %0 : tensor<?x18x19x5xf32>
+}
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> {
+  %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32>
+  return %0 : tensor<?x35x47x5xf32>
+}

@sahas3 sahas3 requested a review from Jerry-Ge November 19, 2025 19:39
@sahas3
Copy link
Member Author

sahas3 commented Nov 19, 2025

@Jerry-Ge, @Tai78641 FYI, the added LIT test for depthwise conv in this PR actually locks down the fix in #168564 since the decomposition generates a tosa.add op for tensor<?x10x10x6xf32> (outputValue) and tensor<?xf32> (bias) in

if (EqualizeRanks(rewriter, op.getLoc(), outputValue, bias).failed()) {

@github-actions
Copy link

🐧 Linux x64 Test Results

  • 7101 tests passed
  • 594 tests skipped

@Tai78641
Copy link
Contributor

LGTM

Copy link
Member

@Jerry-Ge Jerry-Ge left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@sahas3 sahas3 merged commit def8ecb into llvm:main Nov 20, 2025
13 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants