diff --git a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp index 249a74b007dce..c2485a08932dd 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp @@ -245,14 +245,22 @@ specializeToConvOp(RewriterBase &rewriter, GenericOp genericOp, ArrayRef dilations, ArrayRef strides) { SmallVector inputs = genericOp.getDpsInputs(); ValueRange outputs = genericOp.getDpsInits(); - SmallVector indexingMaps = genericOp.getIndexingMapsArray(); SmallVector resultTypes = genericOp.hasPureTensorSemantics() ? TypeRange(ValueRange(outputs)) : TypeRange{}; - Attribute stridesAttr = rewriter.getI64TensorAttr(strides); - Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); - LinalgOp namedOp = rewriter.replaceOpWithNewOp( - genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + LinalgOp namedOp; + // Ops with no dilations and no strides. + if constexpr (std::is_same_v || + std::is_same_v || + std::is_same_v) { + namedOp = rewriter.replaceOpWithNewOp(genericOp, resultTypes, + inputs, outputs); + } else { + Attribute stridesAttr = rewriter.getI64TensorAttr(strides); + Attribute dilationsAttr = rewriter.getI64TensorAttr(dilations); + namedOp = rewriter.replaceOpWithNewOp( + genericOp, resultTypes, inputs, outputs, stridesAttr, dilationsAttr); + } return namedOp; } @@ -265,9 +273,19 @@ static FailureOr specializeLinalgConvolutions(RewriterBase &rewriter, return specializeToConvOp(rewriter, genericOp, dilations, \ strides); \ // ----------------------------- + // Convolution ops. + // ----------------------------- + CONV_OP_SPECIALIZER(linalg::Conv1DOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNwcWcfOp); + CONV_OP_SPECIALIZER(linalg::Conv1DNcwFcwOp); + CONV_OP_SPECIALIZER(linalg::Conv2DOp); + CONV_OP_SPECIALIZER(linalg::Conv3DOp); + // ----------------------------- // Depthwise Convolution ops. // ----------------------------- + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNcwCwOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcOp); + CONV_OP_SPECIALIZER(linalg::DepthwiseConv1DNwcWcmOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv2DNchwChwOp); CONV_OP_SPECIALIZER(linalg::DepthwiseConv3DNdhwcDhwcmOp); // ----------------------------- diff --git a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp index 5dd5e1b055f0d..6b85e6ba0ede2 100644 --- a/mlir/lib/Dialect/Linalg/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Linalg/Utils/Utils.cpp @@ -390,7 +390,7 @@ static bool matchConvDimAddExprPattern(ArrayAttr indexingMaps, unsigned iDim, unsigned inputMapIdx = 0, filterMapIdx = 1, outputMapIdx = indexingMaps.size() - 1; AffineExpr inpExpr = getAffineMapDim(indexingMaps, inputMapIdx, iDim); - auto addExpr = dyn_cast(inpExpr); + auto addExpr = dyn_cast_or_null(inpExpr); if (!addExpr || addExpr.getKind() != AffineExprKind::Add) return false; @@ -434,6 +434,263 @@ static bool convLayoutMatches(ArrayRef> mapListExpected, }))); } +// #inputMap = affine_map<(W, w) -> (W + w)> +// #filterMap = affine_map<(W, w) -> (w)> +// #outputMap = affine_map<(W, w) -> (W)> +template <> +bool isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(1, 1); + *strides = SmallVector(1, 1); + MLIRContext *context = op->getContext(); + AffineExpr W = getAffineDimExpr(0, context); + AffineExpr w = getAffineDimExpr(1, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, (*dilations)[0], (*strides)[0])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{W * (*strides)[0] + w * (*dilations)[0]}, + /*filterMap=*/{w}, + /*outputMap=*/{W}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + +// #inputMap = affine_map<(N, W, F, w, c) -> (N, W + w, c)> +// #filterMap = affine_map<(N, W, F, w, c) -> (w, c, F)> +// #outputMap = affine_map<(N, W, F, w, c) -> (N, W, F)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(1, 1); + *strides = SmallVector(1, 1); + MLIRContext *context = op->getContext(); + AffineExpr N = getAffineDimExpr(0, context); + AffineExpr W = getAffineDimExpr(1, context); + AffineExpr F = getAffineDimExpr(2, context); + AffineExpr w = getAffineDimExpr(3, context); + AffineExpr c = getAffineDimExpr(4, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, (*dilations)[0], (*strides)[0])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], c}, + /*filterMap=*/{w, c, F}, + /*outputMap=*/{N, W, F}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + +// #inputMap = affine_map<(N, F, W, c, w) -> (N, c, W + w)> +// #filterMap = affine_map<(N, F, W, c, w) -> (F, c, w)> +// #outputMap = affine_map<(N, F, W, c, w) -> (N, F, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(1, 1); + *strides = SmallVector(1, 1); + MLIRContext *context = op->getContext(); + AffineExpr N = getAffineDimExpr(0, context); + AffineExpr F = getAffineDimExpr(1, context); + AffineExpr W = getAffineDimExpr(2, context); + AffineExpr c = getAffineDimExpr(3, context); + AffineExpr w = getAffineDimExpr(4, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, (*dilations)[0], (*strides)[0])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{N, c, W * (*strides)[0] + w * (*dilations)[0]}, + /*filterMap=*/{F, c, w}, + /*outputMap=*/{N, F, W}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + +// #inputMap = affine_map<(H, W, h, w) -> (H + h, W + w)> +// #filterMap = affine_map<(H, W, h, w) -> (h, w)> +// #outputMap = affine_map<(H, W, h, w) -> (H, W)> +template <> +bool isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(2, 1); + *strides = SmallVector(2, 1); + MLIRContext *context = op->getContext(); + AffineExpr H = getAffineDimExpr(0, context); + AffineExpr W = getAffineDimExpr(1, context); + AffineExpr h = getAffineDimExpr(2, context); + AffineExpr w = getAffineDimExpr(3, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: H * stride + h * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, (*dilations)[0], (*strides)[0])) + return false; + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, (*dilations)[1], (*strides)[1])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{H * (*strides)[0] + h * (*dilations)[0], + W * (*strides)[1] + w * (*dilations)[1]}, + /*filterMap=*/{h, w}, + /*outputMap=*/{H, W}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + +// #inputMap = affine_map<(D, H, W, d, h, w) -> (D + d, H + h, W + w)> +// #filterMap = affine_map<(D, H, W, d, h, w) -> (d, h, w)> +// #outputMap = affine_map<(D, H, W, d, h, w) -> (D, H, W)> +template <> +bool isaConvolutionOpOfType(LinalgOp op, + SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(3, 1); + *strides = SmallVector(3, 1); + MLIRContext *context = op->getContext(); + AffineExpr D = getAffineDimExpr(0, context); + AffineExpr H = getAffineDimExpr(1, context); + AffineExpr W = getAffineDimExpr(2, context); + AffineExpr d = getAffineDimExpr(3, context); + AffineExpr h = getAffineDimExpr(4, context); + AffineExpr w = getAffineDimExpr(5, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: D * stride + d * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/0, /*fDim=*/0, + /*oDim=*/0, (*dilations)[0], (*strides)[0])) + return false; + // Match: H * stride + h * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/1, + /*oDim=*/1, (*dilations)[1], (*strides)[1])) + return false; + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/2, + /*oDim=*/2, (*dilations)[2], (*strides)[2])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{D * (*strides)[0] + d * (*dilations)[0], + H * (*strides)[1] + h * (*dilations)[1], + W * (*strides)[2] + w * (*dilations)[2]}, + /*filterMap=*/{d, h, w}, + /*outputMap=*/{D, H, W}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + +// #inputMap = affine_map<(N, W, C, w) -> (N, C, W + w)> +// #filterMap = affine_map<(N, W, C, w) -> (C, w)> +// #outputMap = affine_map<(N, W, C, w) -> (N, C, W)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(1, 1); + *strides = SmallVector(1, 1); + MLIRContext *context = op->getContext(); + AffineExpr N = getAffineDimExpr(0, context); + AffineExpr W = getAffineDimExpr(1, context); + AffineExpr C = getAffineDimExpr(2, context); + AffineExpr w = getAffineDimExpr(3, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/2, /*fDim=*/1, + /*oDim=*/2, (*dilations)[0], (*strides)[0])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{N, C, W * (*strides)[0] + w * (*dilations)[0]}, + /*filterMap=*/{C, w}, + /*outputMap=*/{N, C, W}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + // #inputMap = affine_map<(N, W, C, w) -> (N, W + w, C)> // #filterMap = affine_map<(N, W, C, w) -> (w, C)> // #outputMap = affine_map<(N, W, C, w) -> (N, W, C)> @@ -474,6 +731,47 @@ bool isaConvolutionOpOfType( return bodyMatcherForConvolutionOps(yieldVal, body); } +// #inputMap = affine_map<(N, W, C, CM, w) -> (N, W + w, C)> +// #filterMap = affine_map<(N, W, C, CM, w) -> (w, C, CM)> +// #outputMap = affine_map<(N, W, C, CM, w) -> (N, W, C, CM)> +template <> +bool isaConvolutionOpOfType( + LinalgOp op, SmallVector *dilations, + SmallVector *strides) { + if (isa(op)) + return true; + + assert(isaConvolutionOpInterface(op) && + "expected op to implement ConvolutionOpInterface"); + + *dilations = SmallVector(1, 1); + *strides = SmallVector(1, 1); + MLIRContext *context = op->getContext(); + AffineExpr N = getAffineDimExpr(0, context); + AffineExpr W = getAffineDimExpr(1, context); + AffineExpr C = getAffineDimExpr(2, context); + AffineExpr CM = getAffineDimExpr(3, context); + AffineExpr w = getAffineDimExpr(4, context); + ArrayAttr indexingMaps = op.getIndexingMaps(); + // First fetch dilations/strides :- + // Match: W * stride + w * dilation + if (!matchConvDimAddExprPattern(indexingMaps, /*iDim=*/1, /*fDim=*/0, + /*oDim=*/1, (*dilations)[0], (*strides)[0])) + return false; + // Match expected indexing maps + if (!convLayoutMatches( + {/*inputMap=*/{N, W * (*strides)[0] + w * (*dilations)[0], C}, + /*filterMap=*/{w, C, CM}, + /*outputMap=*/{N, W, C, CM}}, + indexingMaps, context)) + return false; + // Match body + Block *body = op.getBlock(); + auto yieldOp = cast(body->getTerminator()); + Value yieldVal = yieldOp.getOperand(0); + return bodyMatcherForConvolutionOps(yieldVal, body); +} + // #inputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H + h, W + w)> // #filterMap = affine_map<(N, H, W, C, h, w) -> (C, h, w)> // #outputMap = affine_map<(N, H, W, C, h, w) -> (N, C, H, W)> diff --git a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir index 8f22cc749bee9..4b2d42a3ae4e0 100644 --- a/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir +++ b/mlir/test/Dialect/Linalg/convolution/roundtrip-convolution.mlir @@ -1,8 +1,87 @@ // The following test examples of linalg convolution named ops lowered to linalg.generic and then // lifted back up to named op. +// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test. + // RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s --implicit-check-not=linalg.generic -// NOTE: Most tests in this file use dynamic shapes as the underlying transformations don't modify shapes. There's one exception that's added as a smoke test. +// ----------------------------- +// Convolution ops. +// ----------------------------- +func.func @conv_1d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { + %0 = linalg.conv_1d + ins(%in, %filter : tensor, tensor) + outs(%out : tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_1d +// CHECK: linalg.conv_1d + +// ----- + +func.func @conv_1d_nwc_wcf(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_1d_nwc_wcf + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_1d_nwc_wcf +// CHECK: linalg.conv_1d_nwc_wcf +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @conv_1d_ncw_fcw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.conv_1d_ncw_fcw + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_1d_ncw_fcw +// CHECK: linalg.conv_1d_ncw_fcw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + +func.func @conv_2d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { + %0 = linalg.conv_2d + ins(%in, %filter : tensor, tensor) + outs(%out: tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_2d +// CHECK: linalg.conv_2d + +// ----- + +func.func @conv_3d(%in : tensor, %filter : tensor, %out : tensor) -> tensor { + %0 = linalg.conv_3d + ins(%in, %filter : tensor, tensor) + outs(%out : tensor) -> tensor + return %0 : tensor +} +// CHECK: @conv_3d +// CHECK: linalg.conv_3d + +// ----- + +// ----------------------------- +// Depthwise Convolution ops. +// ----------------------------- +func.func @depthwise_conv_1d_ncw_cw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_1d_ncw_cw + {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_1d_ncw_cw +// CHECK: linalg.depthwise_conv_1d_ncw_cw +// CHECK-SAME: dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64> + +// ----- + func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: tensor<3x8xi8>, %output: tensor<1x10x8xi32>) -> tensor<1x10x8xi32> { %0 = linalg.depthwise_conv_1d_nwc_wc {dilations = dense<3> : tensor<1xi64>, strides = dense<2> : tensor<1xi64>} @@ -16,6 +95,19 @@ func.func @depthwise_conv_1d_nwc_wc_static(%input: tensor<1x25x8xi8>, %filter: t // ----- +func.func @depthwise_conv_1d_nwc_wcm(%input: tensor, %filter: tensor, %output: tensor) -> tensor { + %0 = linalg.depthwise_conv_1d_nwc_wcm + {dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64>} + ins (%input, %filter: tensor, tensor) + outs (%output: tensor) -> tensor + return %0 : tensor +} +// CHECK: @depthwise_conv_1d_nwc_wcm +// CHECK: linalg.depthwise_conv_1d_nwc_wcm +// CHECK-SAME: dilations = dense<1> : tensor<1xi64>, strides = dense<1> : tensor<1xi64> + +// ----- + func.func @depthwise_conv_2d_nchw_chw(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.depthwise_conv_2d_nchw_chw {dilations = dense<[2,3]> : vector<2xi64>, strides = dense<[4,5]> : vector<2xi64>} @@ -42,6 +134,9 @@ func.func @depthwise_conv_3d_ndhwc_dhwcm(%input: tensor, %filter: // ----- +// ----------------------------- +// Pooling ops. +// ----------------------------- func.func @pooling_nhwc_max(%input: tensor, %filter: tensor, %output: tensor) -> tensor { %0 = linalg.pooling_nhwc_max {dilations = dense<1> : tensor<2xi64>, strides = dense<1> : tensor<2xi64>}