[tosa]: Relax dynamic dimension checks for batch in conv decompositions #168764
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa

Author: Sayan Saha (sahas3)

Changes: This PR relaxes the validation checks in the TOSA convolution decompositions so that the input and output tensors may have a dynamic batch dimension.

Full diff: https://github.com/llvm/llvm-project/pull/168764.diff

4 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index 0bec0da3f4320..022476a2f44cf 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
ShapedType weightType = cast<ShapedType>(weight.getType());
ShapedType resultType = cast<ShapedType>(op.getOutput().getType());
- if (!(inputType.hasStaticShape() && weightType.hasStaticShape() &&
- resultType.hasStaticShape())) {
+ // Any dimensions other than batchSize cannot be dynamic for input/output
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightType.hasStaticShape()) {
return failure();
}
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
index dc5c51b0abad5..8b23fd1341bc5 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp
@@ -49,8 +49,13 @@ class TransposeConvNonStridedConverter
if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
return failure();
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ // Any dimensions other than batchSize cannot be dynamic for input/output
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t kernelHeight = weightTy.getDimSize(1);
@@ -113,8 +118,13 @@ class TransposeConvStridedConverter
if (llvm::all_of(stride, [](int64_t v) { return v == 1; }))
return rewriter.notifyMatchFailure(op, "non-one stride found.");
- if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
- !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+ // Any dimensions other than batchSize cannot be dynamic for input/output
+ for (unsigned int i = 1; i < 4; ++i) {
+ if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i))
+ return failure();
+ }
+
+ if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape())
return failure();
int64_t batch = inputTy.getDimSize(0);
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
index c7eeb5281679b..d4c4595e84ee0 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir
@@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar
%0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32>
return %0 : tensor<4x10x10x6xi32>
}
+
+// -----
+// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(
+// CHECK-SAME: %[[INP:.*]]: tensor<?x10x10x2xf32>,
+// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>,
+// CHECK-SAME: %[[BIAS:.*]]: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
+// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor<?x10x10x2xf32>, !tosa.shape<5>) -> tensor<?x10x10x2x1xf32>
+// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32>
+// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor<?x10x10x2x1xf32>, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor<?x10x10x2x3xf32>
+// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor<?x10x10x2x3xf32>, !tosa.shape<4>) -> tensor<?x10x10x6xf32>
+// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor<?xf32>, !tosa.shape<4>) -> tensor<1x1x1x?xf32>
+// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor<?x10x10x6xf32>, tensor<1x1x1x?xf32>) -> tensor<?x10x10x6xf32>
+// CHECK: return %[[RES]]
+func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor<?x10x10x2xf32>, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor<?xf32>) -> tensor<?x10x10x6xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>} : (tensor<?x10x10x2xf32>, tensor<1x1x2x3xf32>, tensor<?xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x10x10x6xf32>
+ return %0 : tensor<?x10x10x6xf32>
+}
diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
index 810135f6f531b..61ca0aedf6a46 100644
--- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir
@@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 :
(tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32>
"func.return" (%2) : (tensor<1x19x2x1xi32>) -> ()
}
+
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor<?x16x14x3xf32>, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x18x19x5xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<?x16x14x3xf32>, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x18x19x5xf32>
+ return %0 : tensor<?x18x19x5xf32>
+}
+
+// -----
+// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch
+// CHECK: tosa.conv2d
+// CHECK-NOT: tosa.transpose_conv2d
+func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor<?x17x15x3xf32>, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor<?x35x47x5xf32> {
+ %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 2, 3>} : (tensor<?x17x15x3xf32>, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<?x35x47x5xf32>
+ return %0 : tensor<?x35x47x5xf32>
+}
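For reference, all three patterns replace the blanket `hasStaticShape()` checks with the same idiom: only the batch dimension (dim 0) of the rank-4 NHWC input/output may be dynamic, while the weight (and, for the transpose-conv patterns, bias) operands must remain fully static. A minimal standalone sketch of that check is below; the helper name `hasOnlyDynamicBatch` is illustrative and not part of the patch:

```cpp
#include "mlir/IR/BuiltinTypes.h"

// Illustrative helper (not in the patch): returns true when at most the
// batch dimension (dim 0) of a rank-4 NHWC tensor is dynamic. A dynamic
// batch is tolerable because the emitted tosa.const_shape reshape targets
// use -1 for that dimension, whereas the spatial (H, W) and channel (C)
// extents feed into computed constants and must be static.
static bool hasOnlyDynamicBatch(mlir::ShapedType type) {
  for (unsigned i = 1; i < 4; ++i)
    if (type.isDynamicDim(i))
      return false;
  return true;
}
```

With such a helper, each pattern's guard would read `if (!hasOnlyDynamicBatch(inputTy) || !hasOnlyDynamicBatch(resultTy)) return failure();`, followed by the existing static-shape checks on the weight and bias.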
@Jerry-Ge, @Tai78641 FYI, the added LIT test for depthwise conv in this PR actually locks down the fix in #168564 since the decomposition generates a
Jerry-Ge left a comment:

LGTM