diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp index 0bec0da3f4320..022476a2f44cf 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp @@ -33,8 +33,13 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern { ShapedType weightType = cast(weight.getType()); ShapedType resultType = cast(op.getOutput().getType()); - if (!(inputType.hasStaticShape() && weightType.hasStaticShape() && - resultType.hasStaticShape())) { + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputType.isDynamicDim(i) || resultType.isDynamicDim(i)) + return failure(); + } + + if (!weightType.hasStaticShape()) { return failure(); } diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp index dc5c51b0abad5..8b23fd1341bc5 100644 --- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp +++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeTransposeConv.cpp @@ -49,8 +49,13 @@ class TransposeConvNonStridedConverter if (llvm::any_of(stride, [](int64_t v) { return v != 1; })) return failure(); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t kernelHeight = weightTy.getDimSize(1); @@ -113,8 +118,13 @@ class TransposeConvStridedConverter if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) return rewriter.notifyMatchFailure(op, "non-one stride found."); - if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() || - !biasTy.hasStaticShape() || !resultTy.hasStaticShape()) + // Any dimensions other than batchSize cannot be dynamic for input/output + for (unsigned int i = 1; i < 4; ++i) { + if (inputTy.isDynamicDim(i) || resultTy.isDynamicDim(i)) + return failure(); + } + + if (!weightTy.hasStaticShape() || !biasTy.hasStaticShape()) return failure(); int64_t batch = inputTy.getDimSize(0); diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir index c7eeb5281679b..d4c4595e84ee0 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir @@ -98,3 +98,26 @@ func.func @depthwise_conv2d_no_const_zero_point(%arg0: tensor<4x10x10x2xi8>, %ar %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %arg3, %arg4 {acc_type = i32, pad = array, stride = array, dilation = array} : (tensor<4x10x10x2xi8>, tensor<1x1x2x3xi8>, tensor<6xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<4x10x10x6xi32> return %0 : tensor<4x10x10x6xi32> } + +// ----- +// CHECK-LABEL: func.func @depthwise_conv2d_as_mul_dynamic_batch_bias( +// CHECK-SAME: %[[INP:.*]]: tensor, +// CHECK-SAME: %[[WTS:.*]]: tensor<1x1x2x3xf32>, +// CHECK-SAME: %[[BIAS:.*]]: tensor) -> tensor { +// CHECK: %[[BIAS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[RES_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 6]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[MUL_SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[WTS_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2, 3]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[INP_EXPANDED_SHAPE:.*]] = tosa.const_shape {values = dense<[-1, 10, 10, 2, 1]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK: %[[INP_RESHAPED:.*]] = tosa.reshape %[[INP]], %[[INP_EXPANDED_SHAPE]] : (tensor, !tosa.shape<5>) -> tensor +// CHECK: %[[WTS_RESHAPED:.*]] = tosa.reshape %[[WTS]], %[[WTS_EXPANDED_SHAPE]] : (tensor<1x1x2x3xf32>, !tosa.shape<5>) -> tensor<1x1x1x2x3xf32> +// CHECK: %[[MUL:.*]] = tosa.mul %[[INP_RESHAPED]], %[[WTS_RESHAPED]], %[[MUL_SHIFT]] : (tensor, tensor<1x1x1x2x3xf32>, tensor<1xi8>) -> tensor +// CHECK: %[[RES_RESHAPED:.*]] = tosa.reshape %[[MUL]], %[[RES_EXPANDED_SHAPE]] : (tensor, !tosa.shape<4>) -> tensor +// CHECK: %[[BIAS_RESHAPED:.*]] = tosa.reshape %[[BIAS]], %[[BIAS_EXPANDED_SHAPE]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x?xf32> +// CHECK: %[[RES:.*]] = tosa.add %[[RES_RESHAPED]], %[[BIAS_RESHAPED]] : (tensor, tensor<1x1x1x?xf32>) -> tensor +// CHECK: return %[[RES]] +func.func @depthwise_conv2d_as_mul_dynamic_batch_bias(%arg0: tensor, %arg1: tensor<1x1x2x3xf32>, %arg2: tensor) -> tensor { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.depthwise_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, pad = array, stride = array, dilation = array} : (tensor, tensor<1x1x2x3xf32>, tensor, tensor<1xf32>, tensor<1xf32>) -> tensor + return %0 : tensor +} diff --git a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir index 810135f6f531b..61ca0aedf6a46 100644 --- a/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir +++ b/mlir/test/Dialect/Tosa/tosa-decompose-transpose-conv.mlir @@ -181,3 +181,24 @@ func.func @transpose_conv2d_strided_overpad(%arg0 : tensor<1x16x1x1xi8>, %arg1 : (tensor<1x16x1x1xi8>, tensor<1x2x1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x19x2x1xi32> "func.return" (%2) : (tensor<1x19x2x1xi32>) -> () } + + +// ----- +// CHECK-LABEL: @transpose_conv2d_non_strided_dynamic_batch +// CHECK: tosa.conv2d +// CHECK-NOT: tosa.transpose_conv2d +func.func @transpose_conv2d_non_strided_dynamic_batch(%arg0: tensor, %arg1: tensor<5x3x6x3xf32>, %arg2: tensor<5xf32>) -> tensor { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array, stride = array} : (tensor, tensor<5x3x6x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + return %0 : tensor +} + +// ----- +// CHECK-LABEL: @transpose_conv2d_strided_dynamic_batch +// CHECK: tosa.conv2d +// CHECK-NOT: tosa.transpose_conv2d +func.func @transpose_conv2d_strided_dynamic_batch(%arg0: tensor, %arg1: tensor<5x3x5x3xf32>, %arg2: tensor<5xf32>) -> tensor { + %zp = "tosa.const"() <{values = dense<0.0> : tensor<1xf32>}> : () -> tensor<1xf32> + %0 = tosa.transpose_conv2d %arg0, %arg1, %arg2, %zp, %zp {acc_type = f32, out_pad = array, stride = array} : (tensor, tensor<5x3x5x3xf32>, tensor<5xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor + return %0 : tensor +}