From d9c26b9d560f4362503b8f0ec97a52a0a36a57ce Mon Sep 17 00:00:00 2001
From: Hsiangkai Wang
Date: Thu, 11 Jul 2024 14:45:36 +0100
Subject: [PATCH] [mlir][linalg] Add transform operator for Winograd Conv2D
 algorithm (#96182)

Add a transform operation structured.winograd_conv2d to convert
linalg.conv_2d_nhwc_fhwc to Linalg Winograd operations.

Reviewers: ftynse, Max191, GeorgeARM, nicolasvasilache, MaheshRavishankar, dcaballe, rengolin

Reviewed By: ftynse, Max191

Pull Request: https://github.com/llvm/llvm-project/pull/96182
---
 .../Linalg/TransformOps/LinalgTransformOps.td | 51 +++++++++++++
 .../Dialect/Linalg/Transforms/Transforms.h    |  7 ++
 .../TransformOps/LinalgTransformOps.cpp       | 31 ++++++++
 .../Linalg/Transforms/WinogradConv2D.cpp      |  9 ++-
 .../Linalg/transform-winograd-conv2d.mlir     | 76 +++++++++++++++++++
 5 files changed, 173 insertions(+), 1 deletion(-)
 create mode 100644 mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 866275cedf68b..ecc86999006db 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -2646,4 +2646,55 @@ def MapCopyToThreadsOp :
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// Winograd Conv2D
+//===----------------------------------------------------------------------===//
+
+def WinogradConv2DOp : Op<Transform_Dialect,
+    "structured.winograd_conv2d",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait,
+     ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    The Winograd Conv2D algorithm converts a linalg Conv2D operation into a
+    batched matrix multiply. Before the matrix multiply, it converts the filter
+    and the input into a format suitable for batched matrix multiply. After the
+    matrix multiply, it converts the output to the final result tensor.
+
+    The algorithm F(m x m, r x r) is
+
+    Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A
+
+    The size of the output Y is m x m. The size of the filter g is r x r. The
+    size of the input d is (m + r - 1) x (m + r - 1). A^T, A, G^T, G, B^T, and
+    B are transformation matrices.
+
+    #### Return modes:
+
+    This operation produces a silenceable failure if `target` is unsupported.
+    Otherwise, the operation succeeds and returns a handle to the sequence of
+    operations that replaces the original convolution.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                       I64Attr:$m,
+                       I64Attr:$r);
+  let results = (outs TransformHandleTypeInterface:$transformed);
+
+  let assemblyFormat =
+    "$target attr-dict `:` functional-type($target, results)";
+
+  let builders = [
+    OpBuilder<(ins "Value":$target)>
+  ];
+
+  let extraClassDeclaration = [{
+    ::mlir::DiagnosedSilenceableFailure applyToOne(
+        ::mlir::transform::TransformRewriter &rewriter,
+        ::mlir::linalg::LinalgOp target,
+        ::mlir::transform::ApplyToEachResultList &results,
+        ::mlir::transform::TransformState &state);
+  }];
+}
+
 #endif // LINALG_TRANSFORM_OPS
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 80b1f2ec363eb..eac6eb4387a0f 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1332,6 +1332,13 @@ FailureOr<Operation *>
 transposeBatchMatmul(RewriterBase &rewriter, linalg::BatchMatmulOp op,
                      bool transposeLHS = true);
 
+/// Convert linalg.conv_2d_nhwc_fhwc to the Winograd Conv2D algorithm
+/// F(m x m, r x r). m is the dimension size of the output and r is the
+/// dimension size of the filter.
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r);
+
 //===----------------------------------------------------------------------===//
 // Rewrite patterns wrapping transformations.
 // TODO: every single such pattern should be a close to noop wrapper around a
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 4eb334f8bbbfa..bffe7a4e7d62c 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -3711,6 +3711,37 @@ DiagnosedSilenceableFailure transform::MapCopyToThreadsOp::applyToOne(
   return DiagnosedSilenceableFailure::success();
 }
 
+//===----------------------------------------------------------------------===//
+// WinogradConv2DOp
+//===----------------------------------------------------------------------===//
+
+DiagnosedSilenceableFailure transform::WinogradConv2DOp::applyToOne(
+    transform::TransformRewriter &rewriter, linalg::LinalgOp target,
+    transform::ApplyToEachResultList &results,
+    transform::TransformState &state) {
+  rewriter.setInsertionPoint(target);
+  FailureOr<Operation *> maybeTransformed = failure();
+  bool supported = TypeSwitch<Operation *, bool>(target)
+                       .Case([&](linalg::Conv2DNhwcFhwcOp op) {
+                         maybeTransformed =
+                             winogradConv2D(rewriter, op, getM(), getR());
+                         return true;
+                       })
+                       .Default([&](Operation *op) { return false; });
+
+  if (!supported) {
+    return emitSilenceableError()
+           << "this operation is not supported to convert to Winograd Conv2D";
+  }
+
+  if (supported && failed(maybeTransformed)) {
+    return emitSilenceableError() << "apply Winograd Conv2D failed";
+  }
+
+  results.push_back(*maybeTransformed);
+  return DiagnosedSilenceableFailure::success();
+}
+
 #include "mlir/Dialect/Linalg/TransformOps/LinalgTransformOpsEnums.cpp.inc"
 
 #define GET_OP_CLASSES
diff --git a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
index 76742f2a824e7..9b8fa7cf6bac1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/WinogradConv2D.cpp
@@ -15,7 +15,9 @@
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 #include "llvm/Support/MathExtras.h"
 
 namespace mlir {
@@ -156,7 +158,6 @@ winogradConv2DHelper(RewriterBase &rewriter, linalg::Conv2DNhwcFhwcOp convOp,
   auto filterType = cast<ShapedType>(filter.getType());
   auto outputType = cast<ShapedType>(output.getType());
 
-  // TODO: Should we support dynamic shapes?
   if (!inputType.hasStaticShape())
     return rewriter.notifyMatchFailure(convOp,
                                        "expected a static shape for the input");
@@ -316,6 +317,12 @@ class WinogradConv2DNhwcFhwc final
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
+FailureOr<Operation *> winogradConv2D(RewriterBase &rewriter,
+                                      linalg::Conv2DNhwcFhwcOp op, int64_t m,
+                                      int64_t r) {
+  return winogradConv2DHelper(rewriter, op, m, r);
+}
+
 void populateWinogradConv2DPatterns(RewritePatternSet &patterns, int64_t m,
                                     int64_t r) {
   MLIRContext *context = patterns.getContext();
diff --git a/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
new file mode 100644
index 0000000000000..c10e0ccebfd7c
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-winograd-conv2d.mlir
@@ -0,0 +1,76 @@
+// RUN: mlir-opt %s -transform-interpreter -canonicalize --split-input-file -verify-diagnostics | FileCheck %s
+
+func.func @conv2d(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @conv2d
+// CHECK:       linalg.winograd_filter_transform m(4) r(3)
+// CHECK:       linalg.winograd_input_transform m(4) r(3)
+// CHECK:       linalg.batch_matmul
+// CHECK:       linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @conv2d_unaligned(%arg0: tensor<2x11x11x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x11x11x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x9x9x2xf32>) -> tensor<2x9x9x2xf32>
+  return %0 : tensor<2x9x9x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// CHECK-LABEL: func.func @conv2d_unaligned
+// CHECK:       linalg.winograd_filter_transform m(4) r(3)
+// CHECK:       tensor.pad
+// CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:       linalg.winograd_input_transform m(4) r(3)
+// CHECK:       tensor.pad
+// CHECK-SAME:  low[0, 0, 0, 0] high[0, 3, 3, 0]
+// CHECK:       linalg.winograd_output_transform m(4) r(3)
+
+// -----
+
+func.func @conv2d_unsupported(%arg0: tensor<2x10x10x5xf32>, %arg1: tensor<3x3x5x2xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32> {
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x10x10x5xf32>, tensor<3x3x5x2xf32>) outs(%arg3 : tensor<2x8x8x2xf32>) -> tensor<2x8x8x2xf32>
+  return %0 : tensor<2x8x8x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{this operation is not supported to convert to Winograd Conv2D}}
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
+
+// -----
+
+func.func @conv2d(%arg0: tensor<2x?x?x5xf32>, %arg1: tensor<2x3x3x5xf32>, %arg2: tensor<1xf32>, %arg3: tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32> {
+  %0 = linalg.conv_2d_nhwc_fhwc {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>} ins(%arg0, %arg1 : tensor<2x?x?x5xf32>, tensor<2x3x3x5xf32>) outs(%arg3 : tensor<2x?x?x2xf32>) -> tensor<2x?x?x2xf32>
+  return %0 : tensor<2x?x?x2xf32>
+}
+
+module attributes {transform.with_named_sequence} {
+  transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
+    %0 = transform.structured.match ops{["linalg.conv_2d_nhwc_fhwc"]} in %arg1 : (!transform.any_op) -> !transform.any_op
+    // expected-error @+1 {{apply Winograd Conv2D failed}}
+    %1 = transform.structured.winograd_conv2d %0 { m = 4, r = 3 } : (!transform.any_op) -> (!transform.any_op)
+    transform.yield
+  }
+}
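
Illustrative note, not part of the patch: the op description above relies on the Winograd identity Y = A^T x [(G x g x G^T) @ (B^T x d x B)] x A, where @ is an element-wise product. The standalone C++ sketch below checks that identity numerically for the small instance F(2 x 2, 3 x 3), using the widely published Lavin-Gray transformation matrices as an assumption; the MLIR implementation selects its own matrices for the m and r given to structured.winograd_conv2d (m = 4, r = 3 in the tests above) and realizes the per-tile steps as linalg.winograd_filter_transform, linalg.winograd_input_transform, linalg.batch_matmul, and linalg.winograd_output_transform.

// Minimal sketch, assuming the standard F(2 x 2, 3 x 3) Winograd constants
// (Lavin & Gray); not taken from the MLIR implementation.
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdio>

// C = A (P x Q) times B (Q x R).
template <std::size_t P, std::size_t Q, std::size_t R>
std::array<std::array<double, R>, P>
matmul(const std::array<std::array<double, Q>, P> &A,
       const std::array<std::array<double, R>, Q> &B) {
  std::array<std::array<double, R>, P> C{};
  for (std::size_t i = 0; i < P; ++i)
    for (std::size_t k = 0; k < Q; ++k)
      for (std::size_t j = 0; j < R; ++j)
        C[i][j] += A[i][k] * B[k][j];
  return C;
}

int main() {
  // F(2 x 2, 3 x 3) transformation matrices B^T, G, and A^T.
  std::array<std::array<double, 4>, 4> BT = {
      {{1, 0, -1, 0}, {0, 1, 1, 0}, {0, -1, 1, 0}, {0, 1, 0, -1}}};
  std::array<std::array<double, 3>, 4> G = {
      {{1, 0, 0}, {0.5, 0.5, 0.5}, {0.5, -0.5, 0.5}, {0, 0, 1}}};
  std::array<std::array<double, 4>, 2> AT = {{{1, 1, 1, 0}, {0, 1, -1, -1}}};

  // Arbitrary r x r filter g and (m + r - 1) x (m + r - 1) input tile d.
  std::array<std::array<double, 3>, 3> g = {{{1, 2, 3}, {4, 5, 6}, {7, 8, 9}}};
  std::array<std::array<double, 4>, 4> d = {
      {{1, 2, 0, 1}, {3, 1, 2, 2}, {0, 1, 1, 3}, {2, 0, 1, 1}}};

  // Explicit transposes B, G^T, and A.
  std::array<std::array<double, 4>, 4> B{};
  std::array<std::array<double, 4>, 3> GT{};
  std::array<std::array<double, 2>, 4> A{};
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      B[i][j] = BT[j][i];
  for (std::size_t i = 0; i < 3; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      GT[i][j] = G[j][i];
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 2; ++j)
      A[i][j] = AT[j][i];

  // U = G x g x G^T (transformed filter), V = B^T x d x B (transformed input),
  // M = U @ V (element-wise product), Y = A^T x M x A (output tile).
  auto U = matmul(matmul(G, g), GT);
  auto V = matmul(matmul(BT, d), B);
  std::array<std::array<double, 4>, 4> M{};
  for (std::size_t i = 0; i < 4; ++i)
    for (std::size_t j = 0; j < 4; ++j)
      M[i][j] = U[i][j] * V[i][j];
  auto Y = matmul(matmul(AT, M), A);

  // Reference: direct 2 x 2 convolution (correlation) of d with g, i.e. a
  // single-channel, single-batch slice of what linalg.conv_2d_nhwc_fhwc does.
  for (std::size_t i = 0; i < 2; ++i)
    for (std::size_t j = 0; j < 2; ++j) {
      double ref = 0;
      for (std::size_t u = 0; u < 3; ++u)
        for (std::size_t v = 0; v < 3; ++v)
          ref += g[u][v] * d[i + u][j + v];
      assert(std::fabs(Y[i][j] - ref) < 1e-9);
    }
  std::printf("Winograd F(2x2, 3x3) matches direct convolution\n");
  return 0;
}

Built with any C++11 (or later) compiler, the program should print the success message, mirroring the m x m output, r x r filter, and (m + r - 1) x (m + r - 1) input size relationships stated in the op description.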