diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index 387521a0e7245..205b0987ff98a 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -16,6 +16,49 @@ include "mlir/Dialect/PDL/IR/PDLTypes.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 include "mlir/IR/OpBase.td"
 
+def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Decomposes named complex operations, such as higher-dimensional
+    (depthwise) convolutions, into combinations of lower-dimensional
+    equivalents when possible. The operand handle must point to a list of
+    such operations. The returned handle points to the main produced
+    computational operation, such as the lower-dimensional convolution.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
+def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
+    [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+     TransformOpInterface, TransformEachOpTrait]> {
+  let description = [{
+    Transforms a named structured operation into the generic form with an
+    explicitly attached region. The operand handle must point to a list of
+    structured operations; it is consumed by the transformation and is not
+    expected to be used afterwards. The resulting handle points to the list
+    of equivalent generic operations, in the same order as the original named
+    operations.
+  }];
+
+  let arguments = (ins PDL_Operation:$target);
+  let results = (outs PDL_Operation:$transformed);
+  let assemblyFormat = "$target attr-dict";
+
+  let extraClassDeclaration = [{
+    ::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
+        ::mlir::linalg::LinalgOp target);
+  }];
+}
+
 def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
     [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
      TransformOpInterface, TransformEachOpTrait]> {
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 6d058e03fafed..3db28c32f740c 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -708,6 +708,56 @@ struct LinalgPaddingPattern : public OpInterfaceRewritePattern<LinalgOp> {
   LinalgPaddingOptions options;
 };
 
+/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
+/// convolution ops.
+struct DownscaleSizeOneWindowed2DConvolution final
+    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
+  DownscaleSizeOneWindowed2DConvolution(
+      MLIRContext *context,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context, benefit),
+        filter(std::move(f)) {}
+
+  FailureOr<Conv1DNwcWcfOp>
+  returningMatchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                           PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
+/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
+/// dimensions into 1-D depthwise convolution ops.
+struct DownscaleDepthwiseConv2DNhwcHwcOp final
+    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
+  DownscaleDepthwiseConv2DNhwcHwcOp(
+      MLIRContext *context,
+      LinalgTransformationFilter f = LinalgTransformationFilter(),
+      PatternBenefit benefit = 1)
+      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
+        filter(std::move(f)) {}
+
+  FailureOr<DepthwiseConv1DNwcWcOp>
+  returningMatchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                           PatternRewriter &rewriter) const;
+
+  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
+                                PatternRewriter &rewriter) const override {
+    return returningMatchAndRewrite(convOp, rewriter);
+  }
+
+private:
+  /// LinalgTransformMarker handles special attribute manipulations.
+  LinalgTransformationFilter filter;
+};
+
 struct LinalgFusionOptions {
   /// List of operands indices to use for fusion.
   llvm::SmallSet<unsigned, 1> indicesToFuse = {};
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index f80ba4fc286f7..b081e241a848d 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -50,6 +50,68 @@ class SimpleRewriter : public PatternRewriter {
 };
 } // namespace
 
+/// Attempts to apply the pattern specified as template argument to the given
+/// operation. The pattern is expected to have a `returningMatchAndRewrite`
+/// function that returns the "main" result or failure. Returns failure if the
+/// pattern failed to apply. Extra arguments are forwarded to the pattern
+/// constructor.
+template <typename PatternTy, typename... Args>
+static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
+  // Check if the given operation has the type expected by the pattern.
+  using OpTy = typename llvm::function_traits<
+      decltype(&PatternTy::returningMatchAndRewrite)>::template arg_t<0>;
+  auto op = dyn_cast<OpTy>(operation);
+  if (!op)
+    return failure();
+
+  // Apply the pattern directly to the op.
+  PatternTy pattern(operation->getContext(), std::forward<Args>(args)...);
+  SimpleRewriter rewriter(operation->getContext());
+  rewriter.setInsertionPoint(operation);
+  auto result = pattern.returningMatchAndRewrite(op, rewriter);
+  if (failed(result))
+    return failure();
+  return cast<LinalgOp>(result->getOperation());
+}
+
+//===----------------------------------------------------------------------===//
+// DecomposeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
+  FailureOr<LinalgOp> windowed =
+      tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
+  if (succeeded(windowed))
+    return windowed;
+
+  FailureOr<LinalgOp> depthwise =
+      tryApply<DownscaleDepthwiseConv2DNhwcHwcOp>(target);
+  if (succeeded(depthwise))
+    return depthwise;
+
+  InFlightDiagnostic diag = emitError() << "failed to apply";
+  diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+  return diag;
+}
+
+//===----------------------------------------------------------------------===//
+// GeneralizeOp
+//===----------------------------------------------------------------------===//
+
+FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
+  // Exit early if no transformation is needed.
+  if (isa<GenericOp>(target))
+    return target;
+
+  FailureOr<LinalgOp> generic = tryApply<LinalgGeneralizationPattern>(target);
+  if (succeeded(generic))
+    return generic;
+
+  InFlightDiagnostic diag = emitError() << "failed to apply";
+  diag.attachNote(target.getLoc()) << "attempted to apply to this op";
+  return diag;
+}
+
 //===----------------------------------------------------------------------===//
 // InterchangeOp
 //===----------------------------------------------------------------------===//
@@ -70,15 +132,7 @@ FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
     return diag;
   }
 
-  GenericOpInterchangePattern pattern(getContext(), interchangeVector);
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  FailureOr<GenericOp> result =
-      pattern.returningMatchAndRewrite(genericTarget, rewriter);
-  if (failed(result))
-    return failure();
-
-  return cast<LinalgOp>(result->getOperation());
+  return tryApply<GenericOpInterchangePattern>(target, interchangeVector);
 }
 
 LogicalResult transform::InterchangeOp::verify() {
@@ -147,18 +201,15 @@ FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
   paddingOptions.setHoistPaddings(extractI64Array(getHoistPaddings()));
   paddingOptions.setTransposePaddings(transposePaddings);
 
-  LinalgPaddingPattern pattern(getContext(), paddingOptions);
-  SimpleRewriter rewriter(getContext());
-  rewriter.setInsertionPoint(target);
-  FailureOr<LinalgOp> patternResult =
-      pattern.returningMatchAndRewrite(target, rewriter);
-  if (failed(patternResult)) {
-    InFlightDiagnostic diag = emitError()
-                              << "failed to apply pattern to target op";
-    diag.attachNote(target.getLoc()) << "target op";
-    return diag;
-  }
-  return patternResult;
+  FailureOr<LinalgOp> result =
+      tryApply<LinalgPaddingPattern>(target, paddingOptions);
+  if (succeeded(result))
+    return result;
+
+  InFlightDiagnostic diag = emitError()
+                            << "failed to apply pattern to target op";
+  diag.attachNote(target.getLoc()) << "target op";
+  return diag;
 }
 
 LogicalResult transform::PadOp::verify() {
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 7fed6c0428fb2..6b347561a09e0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -945,7 +945,6 @@ LogicalResult ExtractSliceOfPadTensorSwapPattern::matchAndRewrite(
   return success();
 }
 
-namespace {
 // The following are patterns for downscaling convolution ops with size-1
 // window dimensions.
 //
@@ -954,179 +953,145 @@ namespace {
 // and then turning back to named ops. But for now it's fine to have a few
 // patterns matching special ops to get started.
 
-/// Rewrites 2-D convolution ops with size-1 window dimensions into 1-D
-/// convolution ops.
-struct DownscaleSizeOneWindowed2DConvolution final
-    : public OpRewritePattern<linalg::Conv2DNhwcHwcfOp> {
-  DownscaleSizeOneWindowed2DConvolution(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<linalg::Conv2DNhwcHwcfOp>(context, benefit),
-        filter(std::move(f)) {}
-
-  LogicalResult matchAndRewrite(linalg::Conv2DNhwcHwcfOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(filter.checkAndNotify(rewriter, convOp)))
-      return failure();
-    if (convOp.hasBufferSemantics())
-      return failure(); // To be implemented
-
-    Value input = convOp.inputs().front();
-    Value kernel = convOp.inputs().back();
-    Value output = convOp.outputs().front();
-
-    auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-    auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
-    auto kernelShape = kernelType.getShape();
-    auto outputShape = outputType.getShape();
-
-    // Only handle the case where at least one of the window dimensions is
-    // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-    int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (khSize == 1 && ohSize == 1);
-    bool removeW = (kwSize == 1 && owSize == 1);
-    if (!removeH && !removeW)
-      return failure();
-
-    // Get new shapes and types for all operands by removing the size-1
-    // dimension.
-    using RTTBuilder = RankedTensorType::Builder;
-    RankedTensorType newInputType =
-        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newKernelType =
-        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-    RankedTensorType newOutputType =
-        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
-    // Rank-reduce operands.
-    Location loc = convOp.getLoc();
-    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, input, newInputType);
-    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, kernel, newKernelType);
-    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, output, newOutputType);
-
-    // Rank-reduce strides and dilations too.
-    // TODO: dropDim 1-liner helper.
-    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
-    strides.erase(strides.begin() + (removeH ? 0 : 1));
-    auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-    auto dilations =
-        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
-    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-    auto conv1DOp = rewriter.create<Conv1DNwcWcfOp>(
-        loc, newOutputType, ValueRange{newInput, newKernel},
-        ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-    // Insert back.
-    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-        rewriter, loc, conv1DOp.getResult(0), output);
-    rewriter.replaceOp(convOp, inserted);
-
-    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
-    return success();
-  };
-
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
-
-/// Rewrites 2-D depthwise convolution ops with size-1 (w, kw) or (h, kh)
-/// dimensions into 1-D depthwise convolution ops.
-struct DownscaleDepthwiseConv2DNhwcHwcOp final
-    : public OpRewritePattern<DepthwiseConv2DNhwcHwcOp> {
-  DownscaleDepthwiseConv2DNhwcHwcOp(
-      MLIRContext *context,
-      LinalgTransformationFilter f = LinalgTransformationFilter(),
-      PatternBenefit benefit = 1)
-      : OpRewritePattern<DepthwiseConv2DNhwcHwcOp>(context, benefit),
-        filter(std::move(f)) {}
-
-  LogicalResult matchAndRewrite(DepthwiseConv2DNhwcHwcOp convOp,
-                                PatternRewriter &rewriter) const override {
-    if (failed(filter.checkAndNotify(rewriter, convOp)))
-      return failure();
-    if (convOp.hasBufferSemantics())
-      return failure(); // To be implemented
-
-    Value input = convOp.inputs().front();
-    Value kernel = convOp.inputs().back();
-    Value output = convOp.outputs().front();
-
-    auto inputType = input.getType().dyn_cast<RankedTensorType>();
-    auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
-    auto outputType = output.getType().dyn_cast<RankedTensorType>();
-
-    auto kernelShape = kernelType.getShape();
-    auto outputShape = outputType.getShape();
-
-    // Only handle the case where at least one of the window dimensions is
-    // of size 1. Other cases can rely on tiling to reduce to such cases.
-    int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
-    int64_t ohSize = outputShape[1], owSize = outputShape[2];
-    bool removeH = (khSize == 1 && ohSize == 1);
-    bool removeW = (kwSize == 1 && owSize == 1);
-    if (!removeH && !removeW)
-      return failure();
+FailureOr<Conv1DNwcWcfOp>
+DownscaleSizeOneWindowed2DConvolution::returningMatchAndRewrite(
+    linalg::Conv2DNhwcHwcfOp convOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, convOp)))
+    return failure();
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.inputs().front();
+  Value kernel = convOp.inputs().back();
+  Value output = convOp.outputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
 
-    // Get new shapes and types for all operands by removing the size-1
-    // dimension.
-    using RTTBuilder = RankedTensorType::Builder;
-    RankedTensorType newInputType =
-        RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
-    RankedTensorType newKernelType =
-        RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
-    RankedTensorType newOutputType =
-        RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
-
-    // Rank-reduce operands.
-    Location loc = convOp.getLoc();
-    Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, input, newInputType);
-    Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, kernel, newKernelType);
-    Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
-        rewriter, loc, output, newOutputType);
-
-    // Rank-reduce strides and dilations too.
-    // TODO: dropDim 1-liner helper.
-    auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
-    strides.erase(strides.begin() + (removeH ? 0 : 1));
-    auto stridesAttr = rewriter.getI64VectorAttr(strides);
-
-    auto dilations =
-        llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
-    dilations.erase(dilations.begin() + (removeH ? 0 : 1));
-    auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
-
-    auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
-        loc, newOutputType, ValueRange{newInput, newKernel},
-        ValueRange{newOutput}, stridesAttr, dilationsAttr);
-
-    // Insert back.
-    Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
-        rewriter, loc, conv1DOp.getResult(0), output);
-    rewriter.replaceOp(convOp, inserted);
-
-    filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
-    return success();
-  };
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  // Rank-reduce strides and dilations too.
+  // TODO: dropDim 1-liner helper.
+  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+  strides.erase(strides.begin() + (removeH ? 0 : 1));
+  auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+  auto conv1DOp = rewriter.create<Conv1DNwcWcfOp>(
+      loc, newOutputType, ValueRange{newInput, newKernel},
+      ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+  return conv1DOp;
+}
 
-private:
-  /// LinalgTransformMarker handles special attribute manipulations.
-  LinalgTransformationFilter filter;
-};
+FailureOr<DepthwiseConv1DNwcWcOp>
+DownscaleDepthwiseConv2DNhwcHwcOp::returningMatchAndRewrite(
+    DepthwiseConv2DNhwcHwcOp convOp, PatternRewriter &rewriter) const {
+  if (failed(filter.checkAndNotify(rewriter, convOp)))
+    return failure();
+  if (convOp.hasBufferSemantics())
+    return failure(); // To be implemented.
+
+  Value input = convOp.inputs().front();
+  Value kernel = convOp.inputs().back();
+  Value output = convOp.outputs().front();
+
+  auto inputType = input.getType().dyn_cast<RankedTensorType>();
+  auto kernelType = kernel.getType().dyn_cast<RankedTensorType>();
+  auto outputType = output.getType().dyn_cast<RankedTensorType>();
+
+  auto kernelShape = kernelType.getShape();
+  auto outputShape = outputType.getShape();
+
+  // Only handle the case where at least one of the window dimensions is
+  // of size 1. Other cases can rely on tiling to reduce to such cases.
+  int64_t khSize = kernelShape[0], kwSize = kernelShape[1];
+  int64_t ohSize = outputShape[1], owSize = outputShape[2];
+  bool removeH = (khSize == 1 && ohSize == 1);
+  bool removeW = (kwSize == 1 && owSize == 1);
+  if (!removeH && !removeW)
+    return failure();
 
-} // namespace
+  // Get new shapes and types for all operands by removing the size-1
+  // dimension.
+  using RTTBuilder = RankedTensorType::Builder;
+  RankedTensorType newInputType =
+      RTTBuilder(inputType).dropDim((removeH ? 1 : 2));
+  RankedTensorType newKernelType =
+      RTTBuilder(kernelType).dropDim((removeH ? 0 : 1));
+  RankedTensorType newOutputType =
+      RTTBuilder(outputType).dropDim(removeH ? 1 : 2);
+
+  // Rank-reduce operands.
+  Location loc = convOp.getLoc();
+  Value newInput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, input, newInputType);
+  Value newKernel = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, kernel, newKernelType);
+  Value newOutput = tensor::createCanonicalRankReducingExtractSliceOp(
+      rewriter, loc, output, newOutputType);
+
+  // Rank-reduce strides and dilations too.
+  // TODO: dropDim 1-liner helper.
+  auto strides = llvm::to_vector<4>(convOp.strides().getValues<int64_t>());
+  strides.erase(strides.begin() + (removeH ? 0 : 1));
+  auto stridesAttr = rewriter.getI64VectorAttr(strides);
+
+  auto dilations = llvm::to_vector<4>(convOp.dilations().getValues<int64_t>());
+  dilations.erase(dilations.begin() + (removeH ? 0 : 1));
+  auto dilationsAttr = rewriter.getI64VectorAttr(dilations);
+
+  auto conv1DOp = rewriter.create<DepthwiseConv1DNwcWcOp>(
+      loc, newOutputType, ValueRange{newInput, newKernel},
+      ValueRange{newOutput}, stridesAttr, dilationsAttr);
+
+  // Insert back.
+  Value inserted = tensor::createCanonicalRankReducingInsertSliceOp(
+      rewriter, loc, conv1DOp.getResult(0), output);
+  rewriter.replaceOp(convOp, inserted);
+
+  filter.replaceLinalgTransformationFilter(rewriter, conv1DOp);
+  return conv1DOp;
+}
 
 void linalg::populateDecomposeConvolutionPatterns(
     RewritePatternSet &patterns, const LinalgTransformationFilter &filter,
diff --git a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
index 70e39be5289da..e5a2a473150cc 100644
--- a/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
+++ b/mlir/python/mlir/dialects/_structured_transform_ops_ext.py
@@ -69,6 +69,28 @@ def _get_int_int_array_attr(
   return ArrayAttr.get([_get_int_array_attr(value) for value in values])
 
 
+class DecomposeOp:
+  """Specialization for DecomposeOp class."""
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        loc=loc,
+        ip=ip)
+
+
+class GeneralizeOp:
+  """Specialization for GeneralizeOp class."""
+
+  def __init__(self, target: Union[Operation, Value], *, loc=None, ip=None):
+    super().__init__(
+        pdl.OperationType.get(),
+        _get_op_result_or_value(target),
+        loc=loc,
+        ip=ip)
+
+
 class InterchangeOp:
   """Specialization for InterchangeOp class."""
 
diff --git a/mlir/test/Dialect/Linalg/transform-op-decompose.mlir b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
new file mode 100644
index 0000000000000..e80c3b1078d6d
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-decompose.mlir
@@ -0,0 +1,75 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter --split-input-file %s | FileCheck %s
+
+// CHECK-LABEL: @conv_2d_nhwc_hwcf
+// CHECK-SAME: %[[ARG0:.+]]: tensor<?x1x?x?xf32>,
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x?x?x?xf32>
+// CHECK-SAME: %[[ARG2:.+]]: tensor<?x1x?x?xf32>
+func.func @conv_2d_nhwc_hwcf(%input: tensor<?x1x?x?xf32>, %filter: tensor<1x?x?x?xf32>, %init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32> {
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICE2:.+]] = tensor.extract_slice %[[ARG2]]
+  // CHECK: %[[SLICERES:.+]] = linalg.conv_1d_nwc_wcf
+  // CHECK: %[[RES:.+]] = tensor.insert_slice %[[SLICERES]] into %[[ARG2]]
+  %0 = linalg.conv_2d_nhwc_hwcf {dilations = dense<1> : tensor<2xi64>,
+                                 strides = dense<1> : tensor<2xi64>}
+    ins (%input, %filter: tensor<?x1x?x?xf32>, tensor<1x?x?x?xf32>)
+    outs (%init: tensor<?x1x?x?xf32>) -> tensor<?x1x?x?xf32>
+  // CHECK: return %[[RES]]
+  return %0 : tensor<?x1x?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.decompose %0
+  }
+}
+
+// -----
+
+// CHECK-LABEL: @depthwise_conv_2d_nhwc_hwc
+// CHECK-SAME: %[[ARG0:.+]]: tensor<1x1x113x96xf32>
+// CHECK-SAME: %[[ARG1:.+]]: tensor<1x3x96xf32>
+func.func @depthwise_conv_2d_nhwc_hwc(%input: tensor<1x1x113x96xf32>, %filter: tensor<1x3x96xf32>) -> tensor<1x1x56x96xf32> {
+  // CHECK: %[[RES:.+]] = linalg.init_tensor
+  %init = linalg.init_tensor [1, 1, 56, 96] : tensor<1x1x56x96xf32>
+  // CHECK: %[[SLICE0:.+]] = tensor.extract_slice %[[ARG0]]
+  // CHECK: %[[SLICE1:.+]] = tensor.extract_slice %[[ARG1]]
+  // CHECK: %[[SLICERES:.+]] = tensor.extract_slice %[[RES]]
+  // CHECK: %[[OPRES:.+]] = linalg.depthwise_conv_1d_nwc_wc
+  // CHECK-SAME: ins(%[[SLICE0]], %[[SLICE1]]
+  // CHECK-SAME: outs(%[[SLICERES]]
+  // CHECK: %[[INSERTED:.+]] = tensor.insert_slice %[[OPRES]] into %[[RES]]
+  %0 = linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<2xi64>}
+         ins(%input, %filter: tensor<1x1x113x96xf32>, tensor<1x3x96xf32>)
+         outs(%init: tensor<1x1x56x96xf32>) -> tensor<1x1x56x96xf32>
+  // CHECK: return %[[INSERTED]]
+  return %0: tensor<1x1x56x96xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.depthwise_conv_2d_nhwc_hwc"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.decompose %0
+  }
+}
diff --git a/mlir/test/Dialect/Linalg/transform-op-generalize.mlir b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
new file mode 100644
index 0000000000000..1a20cf7502cab
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/transform-op-generalize.mlir
@@ -0,0 +1,28 @@
+// RUN: mlir-opt --test-transform-dialect-interpreter %s | FileCheck %s
+
+// CHECK-LABEL: func.func @generalize_unary
+func.func @generalize_unary(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
+
+  // CHECK-NOT: linalg.elemwise_unary
+  // CHECK: linalg.generic
+  %0 = linalg.elemwise_unary ins(%arg0 : tensor<?x?xf32>)
+                             outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
+  return %0 : tensor<?x?xf32>
+}
+
+transform.with_pdl_patterns {
+^bb0(%arg0: !pdl.operation):
+  pdl.pattern @pdl_target : benefit(1) {
+    %args = operands
+    %results = types
+    %0 = pdl.operation "linalg.elemwise_unary"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
+    // TODO: we don't want this, but it is the required terminator for pdl.pattern
+    rewrite %0 with "transform.dialect"
+  }
+
+  transform.sequence %arg0 {
+  ^bb1(%arg1: !pdl.operation):
+    %0 = pdl_match @pdl_target in %arg1
+    %1 = transform.structured.generalize %0
+  }
+}
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 463dec10d7bd5..a34b03fb9d0bc 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -16,6 +16,28 @@ def run(f):
   return f
 
 
+@run
+def testDecompose():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.DecomposeOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testDecompose
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.decompose
+
+
+@run
+def testGeneralize():
+  sequence = transform.SequenceOp()
+  with InsertionPoint(sequence.body):
+    structured.GeneralizeOp(sequence.bodyTarget)
+    transform.YieldOp()
+  # CHECK-LABEL: TEST: testGeneralize
+  # CHECK: transform.sequence
+  # CHECK: transform.structured.generalize
+
+
 @run
 def testInterchange():
   sequence = transform.SequenceOp()
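
For reference, a minimal sketch of how the two new transform ops can be chained, modeled directly on the test files added above: the PDL pattern and its "transform.dialect" rewrite terminator are the same workaround used in those tests, the pattern name @conv_target is illustrative, and it is assumed that the handle produced by transform.structured.decompose (the 1-D convolution) is a valid operand for transform.structured.generalize, since both ops operate on handles to structured operations.

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
  pdl.pattern @conv_target : benefit(1) {
    %args = operands
    %results = types
    %0 = pdl.operation "linalg.conv_2d_nhwc_hwcf"(%args : !pdl.range<value>) -> (%results : !pdl.range<type>)
    // Required terminator for pdl.pattern; see the TODO in the tests above.
    rewrite %0 with "transform.dialect"
  }

  transform.sequence %arg0 {
  ^bb1(%arg1: !pdl.operation):
    %0 = pdl_match @conv_target in %arg1
    // First rewrite the matched 2-D convolution with a unit window dimension
    // into its 1-D equivalent, then replace that named op with its
    // linalg.generic form.
    %1 = transform.structured.decompose %0
    %2 = transform.structured.generalize %1
  }
}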