diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index d0b6274ddc0a1d..37687337e10b81 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -1029,56 +1029,49 @@ class TransposeConvConverter
     getValuesFromIntArrayAttribute(op.stride().cast<ArrayAttr>(), stride);
     getValuesFromIntArrayAttribute(op.dilation().cast<ArrayAttr>(), dilation);
 
-    // We have not solved for stride / dilation yet. Dilation should be
-    // straight forward but stride is more complicated. Linalg work is likely
-    // required for efficient implementation.
-    if (llvm::any_of(stride, [](int64_t v) { return v != 1; }))
-      return failure();
-    if (llvm::any_of(dilation, [](int64_t v) { return v != 1; }))
-      return failure();
-
-    if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
-        !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
-      return failure();
+    // If striding is all 1 we can modify padding and reverse the kernel along
+    // the x/y direction to make it a regular convolution. This is much simpler
+    // than handling striding.
+    if (llvm::all_of(stride, [](int64_t v) { return v == 1; })) {
+      if (!inputTy.hasStaticShape() || !weightTy.hasStaticShape() ||
+          !biasTy.hasStaticShape() || !resultTy.hasStaticShape())
+        return failure();
+
+      int64_t kernelHeight = (weightTy.getDimSize(1) - 1) * dilation[0] + 1;
+      int64_t kernelWidth = (weightTy.getDimSize(2) - 1) * dilation[1] + 1;
+      int64_t requiredInputHeight = resultTy.getDimSize(1) + kernelHeight - 1;
+      int64_t requiredInputWidth = resultTy.getDimSize(2) + kernelWidth - 1;
+
+      llvm::SmallVector<int64_t> convPad(4, 0);
+      convPad[0] = kernelHeight - 1 - pad[0];
+      convPad[2] = kernelWidth - 1 - pad[1];
+      convPad[1] = requiredInputHeight - convPad[0] - inputTy.getDimSize(1);
+      convPad[3] = requiredInputWidth - convPad[2] - inputTy.getDimSize(2);
+
+      auto reverse1 = rewriter.create<tosa::ReverseOp>(
+          loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
+      auto reverse2 = rewriter.create<tosa::ReverseOp>(
+          loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
+
+      Value conv2d;
+      if (op.quantization_info().hasValue()) {
+        conv2d = rewriter.create<tosa::Conv2DOp>(
+            loc, resultTy, input, reverse2, bias,
+            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+            rewriter.getI64ArrayAttr(dilation),
+            op.quantization_info().getValue());
+      } else {
+        conv2d = rewriter.create<tosa::Conv2DOp>(
+            loc, resultTy, input, reverse2, bias,
+            rewriter.getI64ArrayAttr(convPad), rewriter.getI64ArrayAttr(stride),
+            rewriter.getI64ArrayAttr(dilation));
+      }
 
-    int64_t inputHeight = inputTy.getDimSize(1);
-    int64_t inputWidth = inputTy.getDimSize(2);
-    int64_t kernelHeight = weightTy.getDimSize(1);
-    int64_t kernelWidth = weightTy.getDimSize(2);
-    int64_t outputHeight = resultTy.getDimSize(1);
-    int64_t outputWidth = resultTy.getDimSize(2);
-
-    int64_t requiredInputHeight = outputHeight + kernelHeight - 1;
-    int64_t requiredInputWidth = outputWidth + kernelWidth - 1;
-
-    llvm::SmallVector<int64_t> newPad(4, 0);
-    newPad[0] = kernelHeight - 1 - pad[0];
-    newPad[2] = kernelWidth - 1 - pad[1];
-
-    newPad[1] = requiredInputHeight - newPad[0] - inputHeight;
-    newPad[3] = requiredInputWidth - newPad[2] - inputWidth;
-
-    auto reverse1 = rewriter.create<tosa::ReverseOp>(
-        loc, weightTy, weight, rewriter.getI64IntegerAttr(1));
-    auto reverse2 = rewriter.create<tosa::ReverseOp>(
-        loc, weightTy, reverse1, rewriter.getI64IntegerAttr(2));
-
-    Value conv2d;
-    if (op.quantization_info().hasValue()) {
-      conv2d = rewriter.create<tosa::Conv2DOp>(
-          loc, resultTy, input, reverse2, bias,
-          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr(dilation),
-          op.quantization_info().getValue());
-    } else {
-      conv2d = rewriter.create<tosa::Conv2DOp>(
-          loc, resultTy, input, reverse2, bias,
-          rewriter.getI64ArrayAttr(newPad), rewriter.getI64ArrayAttr(stride),
-          rewriter.getI64ArrayAttr(dilation));
+      rewriter.replaceOp(op, conv2d);
+      return success();
     }
 
-    rewriter.replaceOp(op, conv2d);
-    return success();
+    return failure();
   }
 };
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 554fa2d4ff1515..376a9103de76ae 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -1274,6 +1274,16 @@ func @transpose_conv(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
   return
 }
 
+// -----
+
+// CHECK-LABEL: @transpose_conv_dilated
+func @transpose_conv_dilated(%arg0 : tensor<1x12x12x2xf32>, %arg1 : tensor<4x3x3x2xf32>, %arg2 : tensor<4xf32>) -> () {
+  // CHECK: [[PAD:%.+]] = linalg.pad_tensor %arg0 low[0, 4, 4, 0] high[0, 4, 4, 0]
+  // CHECK: linalg.conv_2d_input_nhwc_filter_ohwi_poly {dilations = dense<2> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins([[PAD]], {{%.+}} : tensor<1x20x20x2xf32>, tensor<4x3x3x2xf32>)
+  %0 = "tosa.transpose_conv2d"(%arg0, %arg1, %arg2) {dilation = [2, 2], out_pad = [0, 0], out_shape = [1, 16, 16, 4], stride = [1, 1]} : (tensor<1x12x12x2xf32>, tensor<4x3x3x2xf32>, tensor<4xf32>) -> tensor<1x16x16x4xf32>
+  return
+}
+
 // -----
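Note on the padding arithmetic in the patch: the effective extent of a dilated kernel is (k - 1) * dilation + 1, a stride-1 convolution producing `out` output pixels therefore needs `out + kernel - 1` input pixels, and the top/left pads of the equivalent forward convolution are `kernel - 1 - out_pad`, with the bottom/right pads absorbing whatever remains. The standalone C++ sketch below is not part of the patch; it just re-derives, from the shapes in the @transpose_conv_dilated test, the low[0, 4, 4, 0]/high[0, 4, 4, 0] padding and the 1x20x20x2 padded-input shape that the CHECK lines expect. All variable names are illustrative.

#include <array>
#include <cassert>
#include <cstdint>
#include <cstdio>

int main() {
  // Shapes from the test: input 1x12x12x2, weight 4x3x3x2 (OHWI),
  // output 1x16x16x4, dilation = [2, 2], out_pad = [0, 0], stride = [1, 1].
  const int64_t inputH = 12, inputW = 12;
  const int64_t weightH = 3, weightW = 3;
  const int64_t outputH = 16, outputW = 16;
  const std::array<int64_t, 2> dilation = {2, 2};
  const std::array<int64_t, 2> pad = {0, 0}; // the transpose conv's out_pad

  // Effective extent of a dilated kernel: (k - 1) * d + 1.
  const int64_t kernelH = (weightH - 1) * dilation[0] + 1; // 5
  const int64_t kernelW = (weightW - 1) * dilation[1] + 1; // 5

  // A stride-1 conv emitting `out` pixels consumes out + kernel - 1 inputs.
  const int64_t requiredInputH = outputH + kernelH - 1; // 20
  const int64_t requiredInputW = outputW + kernelW - 1; // 20

  // Same formulas as convPad[0..3] in the patch: top/left pads are
  // kernel - 1 - pad; bottom/right pads make up the remaining difference.
  std::array<int64_t, 4> convPad = {0, 0, 0, 0};
  convPad[0] = kernelH - 1 - pad[0];                 // top    = 4
  convPad[2] = kernelW - 1 - pad[1];                 // left   = 4
  convPad[1] = requiredInputH - convPad[0] - inputH; // bottom = 4
  convPad[3] = requiredInputW - convPad[2] - inputW; // right  = 4

  // Matches the CHECK lines: linalg.pad_tensor ... low[0, 4, 4, 0]
  // high[0, 4, 4, 0], yielding a 1x20x20x2 padded input.
  assert((convPad == std::array<int64_t, 4>{4, 4, 4, 4}));
  assert(inputH + convPad[0] + convPad[1] == requiredInputH);
  std::printf("convPad (top, bottom, left, right) = [%lld, %lld, %lld, %lld]\n",
              static_cast<long long>(convPad[0]),
              static_cast<long long>(convPad[1]),
              static_cast<long long>(convPad[2]),
              static_cast<long long>(convPad[3]));
  return 0;
}

With these four pads in place, a stride-1 conv over the spatially reversed kernel reproduces the transpose conv, which is why the stride-1 path in the lowering needs only two tosa.reverse ops plus a tosa.conv2d rather than a dedicated deconvolution kernel.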