diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index bbfbd2e9736a1..397e322a64dea 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -291,6 +291,9 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcOp); CONV_OP_SPECIALIZER(linalg::Conv2DNhwgcGfhwcQOp); CONV_OP_SPECIALIZER(linalg::Conv3DOp); + CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfOp); + CONV_OP_SPECIALIZER(linalg::Conv3DNdhwcDhwcfQOp); + CONV_OP_SPECIALIZER(linalg::Conv3DNcdhwFcdhwOp); // ----------------------------- // Depthwise Convolution ops. // ----------------------------- @@ -302,6 +305,8 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcQOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNhwcHwcmQOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNcdhwCdhwOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp); // ----------------------------- // Pooling ops. diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 1244be90390e2..5c4a359dac4a4 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -569,7 +569,7 @@ class ConvMatcherBuilder { } /// Match body pattern. This should be called last. 
- bool matchBody(bool zeroPointOffset = false) { + bool matchBody(bool containsZeroPointOffset = false) { if (!matched) return false; Block *body = op.getBlock(); @@ -577,7 +577,7 @@ class ConvMatcherBuilder { switch (poolingType) { case PoolingType::None: return bodyMatcherForConvolutionOps(yieldOp.getOperand(0), body, - zeroPointOffset); + containsZeroPointOffset); case PoolingType::MaxSigned: return bodyMatcherForMaxSignedPoolOps(yieldOp.getOperand(0), body); case PoolingType::MaxUnsigned: @@ -762,7 +762,7 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, H, W, F}}) - .matchBody(/*zeroPointOffset=*/true); + .matchBody(/*containsZeroPointOffset=*/true); } // #inputMap = affine_map<(N, H, W, F, h, w, c) -> (N, H + h, W + w, c)> @@ -825,7 +825,7 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, H, W, F}}) - .matchBody(/*zeroPointOffset=*/true); + .matchBody(/*containsZeroPointOffset=*/true); } // #inputMap = affine_map<(N, F, H, W, c, h, w) -> (N, c, H + h, W + w)> @@ -888,7 +888,7 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, F, H, W}}) - .matchBody(/*zeroPointOffset=*/true); + .matchBody(/*containsZeroPointOffset=*/true); } // #inputMap = affine_map<(N, G, F, H, W, c, h, w) -> (N, G, c, H + h, W + w)> @@ -987,7 +987,7 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, G, F, H, W}}) - .matchBody(/*zeroPointOffset=*/true); + .matchBody(/*containsZeroPointOffset=*/true); } // #inputMap = affine_map<(N, H, W, G, F, h, w, c) -> (N, H + h, W + w, G, c)> @@ -1054,7 +1054,7 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, H, W, G, F}}) - .matchBody(/*zeroPointOffset=*/true); + .matchBody(/*containsZeroPointOffset=*/true); } // #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> @@ -1088,6 +1088,114 @@ bool isaConvolutionOpOfType(LinalgOp op, .matchBody(); } +// #inputMap = 
affine_map<(N, D, H, W, F, d, h, w, c) +// -> (N, D + d, H + h, W + w, c)> +// #filterMap = affine_map<(N, D, H, W, F, d, h, w, c) -> (d, h, w, c, F)> +// #outputMap = affine_map<(N, D, H, W, F, d, h, w, c) -> (N, D, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr F = m.dim(4); + AffineExpr d = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + AffineExpr c = m.dim(8); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), c}, + /*filterMap=*/{d, h, w, c, F}, + /*outputMap=*/{N, D, H, W, F}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, D, H, W, F, d, h, w, c) +// -> (N, D + d, H + h, W + w, c)> +// #filterMap = affine_map<(N, D, H, W, F, d, h, w, c) -> (d, h, w, c, F)> +// #scalarMap = affine_map<(N, D, H, W, F, d, h, w, c) -> ()> +// #outputMap = affine_map<(N, D, H, W, F, d, h, w, c) -> (N, D, H, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr F = m.dim(4); + AffineExpr d = m.dim(5); + AffineExpr h = m.dim(6); + AffineExpr w = m.dim(7); + AffineExpr c = 
m.dim(8); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), c}, + /*filterMap=*/{d, h, w, c, F}, + /*scalarMap=*/{}, + /*scalarMap=*/{}, + /*outputMap=*/{N, D, H, W, F}}) + .matchBody(/*containsZeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, F, D, H, W, c, d, h, w) +// -> (N, c, D + d, H + h, W + w)> +// #filterMap = affine_map<(N, F, D, H, W, c, d, h, w) -> (F, c, d, h, w)> +// #outputMap = affine_map<(N, F, D, H, W, c, d, h, w) -> (N, F, D, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr F = m.dim(1); + AffineExpr D = m.dim(2); + AffineExpr H = m.dim(3); + AffineExpr W = m.dim(4); + AffineExpr c = m.dim(5); + AffineExpr d = m.dim(6); + AffineExpr h = m.dim(7); + AffineExpr w = m.dim(8); + + return m.matchStride(/*iDim=*/2, /*fDim=*/2, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/3, /*oDim=*/3, /*idx=*/1) + .matchStride(/*iDim=*/4, /*fDim=*/4, /*oDim=*/4, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, c, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{F, c, d, h, w}, + /*outputMap=*/{N, F, D, H, W}}) + .matchBody(); +} + // #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)> // #filterMap = affine_map<(N, W, C, w) -> (C, w)> // #outputMap = affine_map<(N, W, C, w) -> (N, C, W)> @@ -1254,7 +1362,7 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, H, W, C}}) - .matchBody(/*zeroPointOffset=*/true); + 
.matchBody(/*containsZeroPointOffset=*/true); } // #inputMap = affine_map<(N, H, W, C, CM, h, w) -> (N, H + h, W + w, C)> @@ -1317,7 +1425,77 @@ bool isaConvolutionOpOfType( /*scalarMap=*/{}, /*scalarMap=*/{}, /*outputMap=*/{N, H, W, C, CM}}) - .matchBody(/*zeroPointOffset=*/true); + .matchBody(/*containsZeroPointOffset=*/true); +} + +// #inputMap = affine_map<(N, D, H, W, d, h, w, C) +// -> (N, D + d, H + h, W + w, C)> +// #filterMap = affine_map<(N, D, H, W, d, h, w, C) +// -> (d, h, w, C)> +// #outputMap = affine_map<(N, D, H, W, d, h, w, C) +// -> (N, D, H, W, C)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr d = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + AffineExpr C = m.dim(7); + + return m.matchStride(/*iDim=*/1, /*fDim=*/0, /*oDim=*/1, /*idx=*/0) + .matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/1) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2), C}, + /*filterMap=*/{d, h, w, C}, + /*outputMap=*/{N, D, H, W, C}}) + .matchBody(); +} + +// #inputMap = affine_map<(N, D, H, W, d, h, w, C) +// -> (N, C, D + d, H + h, W + w)> +// #filterMap = affine_map<(N, D, H, W, d, h, w, C) -> (C, d, h, w)> +// #outputMap = affine_map<(N, D, H, W, d, h, w, C) -> (N, C, D, H, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + ConvMatcherBuilder m(op, /*spatialRank=*/3, dilations, 
strides); + AffineExpr N = m.dim(0); + AffineExpr D = m.dim(1); + AffineExpr H = m.dim(2); + AffineExpr W = m.dim(3); + AffineExpr d = m.dim(4); + AffineExpr h = m.dim(5); + AffineExpr w = m.dim(6); + AffineExpr C = m.dim(7); + + return m.matchStride(/*iDim=*/2, /*fDim=*/1, /*oDim=*/2, /*idx=*/0) + .matchStride(/*iDim=*/3, /*fDim=*/2, /*oDim=*/3, /*idx=*/1) + .matchStride(/*iDim=*/4, /*fDim=*/3, /*oDim=*/4, /*idx=*/2) + .matchMaps({/*inputMap=*/{N, C, m.strided(D, d, 0), m.strided(H, h, 1), + m.strided(W, w, 2)}, + /*filterMap=*/{C, d, h, w}, + /*outputMap=*/{N, C, D, H, W}}) + .matchBody(); } // #inputMap = affine_map<(N, D, H, W, CM, d, h, w, C) diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 432fdd12f540d..ac9a33b0528b0 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -218,6 +218,45 @@ func.func @conv_3d(%in : tensor, %filter : tensor, %out : // ----- +func.func @conv_3d_ndhwc_dhwcf(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_3d_ndhwc_dhwcf + {dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ndhwc_dhwcf +// CHECK: linalg.conv_3d_ndhwc_dhwcf +// CHECK-SAME: dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64> + +// ----- + +func.func @conv_3d_ndhwc_dhwcf_q(%input: tensor, %filter: tensor, %output: tensor, %zp_input: i32, %zp_filter: i32) -> tensor { + %0 = linalg.conv_3d_ndhwc_dhwcf_q + {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>} + ins (%input, %filter, %zp_input, %zp_filter : tensor, tensor, i32, i32) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ndhwc_dhwcf_q +// CHECK: linalg.conv_3d_ndhwc_dhwcf_q +// 
CHECK-SAME: dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64> + +// ----- + +func.func @conv_3d_ncdhw_fcdhw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_3d_ncdhw_fcdhw + {dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides = dense<[4, 5, 6]> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d_ncdhw_fcdhw +// CHECK: linalg.conv_3d_ncdhw_fcdhw +// CHECK-SAME: dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides = dense<[4, 5, 6]> : tensor<3xi64> + +// ----- + // ------------------------------- // Depthwise Convolution ops - 1D. // ------------------------------- @@ -334,6 +373,32 @@ func.func @depthwise_conv_2d_nhwc_hwcm_q(%input: tensor, %filter: te // Depthwise Convolution ops - 3D. // ------------------------------- +func.func @depthwise_conv_3d_ndhwc_dhwc(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ndhwc_dhwc + {dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ndhwc_dhwc +// CHECK: linalg.depthwise_conv_3d_ndhwc_dhwc +// CHECK-SAME: dilations = dense<2> : tensor<3xi64>, strides = dense<3> : tensor<3xi64> + +// ----- + +func.func @depthwise_conv_3d_ncdhw_cdhw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_3d_ncdhw_cdhw + {dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides = dense<[4, 5, 6]> : tensor<3xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_3d_ncdhw_cdhw +// CHECK: linalg.depthwise_conv_3d_ncdhw_cdhw +// CHECK-SAME: dilations = dense<[1, 2, 3]> : tensor<3xi64>, strides = dense<[4, 5, 6]> : tensor<3xi64> + +// ----- + func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: 
tensor, %output: tensor) -> tensor { %0 = linalg.depthwise_conv_3d_ndhwc_dhwcm {dilations = dense<1> : tensor<3xi64>, strides = dense<1> : tensor<3xi64>}