From aa2a96a24ae3a8cc04635ab6ede474c5f2665053 Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Thu, 11 Jan 2024 21:31:03 -0800 Subject: [PATCH] [mlir][TilingInterface] Move TilingInterface tests to use transform dialect ops. (#77204) In the process, a couple of test transform dialect ops are added. These operations are not intended to be fully fleshed-out transformation ops; they are added purely for testing. A separate operation is added to `LinalgTransformOps.td` to convert a `TilingInterface` operation to loops using the `generateScalarImplementation` method implemented by the operation. Eventually this and other operations related to tiling using the `TilingInterface` need to move to a better place (i.e., out of the `Linalg` dialect). --- .../Linalg/TransformOps/LinalgTransformOps.td | 32 +- .../TransformOps/LinalgTransformOps.cpp | 55 +- .../SCF/Transforms/TileUsingInterface.cpp | 5 +- .../lower-to-loops-using-interface.mlir | 94 ++- .../tile-and-fuse-using-interface.mlir | 125 +++- .../tile-fuse-and-yield-using-interface.mlir | 16 +- .../tile-pad-using-interface.mlir | 99 ++- .../TilingInterface/tile-using-interface.mlir | 155 +++-- .../TilingInterface/tile-using-scfforall.mlir | 50 +- .../Interfaces/TilingInterface/CMakeLists.txt | 10 +- .../TilingInterface/TestTilingInterface.cpp | 620 ------------------ .../TestTilingInterfaceTransformOps.cpp | 267 ++++++++ .../TestTilingInterfaceTransformOps.td | 81 +++ .../Interfaces/TilingInterface/lit.local.cfg | 1 + mlir/tools/mlir-opt/mlir-opt.cpp | 4 +- 15 files changed, 856 insertions(+), 758 deletions(-) delete mode 100644 mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp create mode 100644 mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp create mode 100644 mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td create mode 100644 mlir/test/lib/Interfaces/TilingInterface/lit.local.cfg diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index bc257d17483e3..7d10ba0ae829e 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -293,7 +293,10 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse", let results = (outs TransformHandleTypeInterface:$transformed, Variadic<TransformHandleTypeInterface>:$loops); - let hasCustomAssemblyFormat = 1; + let assemblyFormat = [{ + $target ($tile_sizes^)? (`interchange` $tile_interchange^)? + attr-dict `:` functional-type(operands, results) + }]; let hasVerifier = 1; } @@ -1269,6 +1272,33 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize", +//===----------------------------------------------------------------------===// +// ConvertToLoopsOp +//===----------------------------------------------------------------------===// + +def ConvertToLoopsOp : Op<Transform_Dialect, "structured.convert_to_loops", + [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface, + TransformOpInterface, TransformEachOpTrait, + ReportTrackingListenerFailuresOpTrait]> { + let description = [{ + For operations that implement the `TilingInterface` and provide the + `generateScalarImplementation` method, lowers the operation to + loops. This operation does not return any handles.
+ }]; + + let arguments = (ins TransformHandleTypeInterface:$target); + let results = (outs); + + let assemblyFormat = [{ + $target attr-dict `:` type($target) + }]; + + let extraClassDeclaration = [{ + ::mlir::DiagnosedSilenceableFailure applyToOne( + ::mlir::transform::TransformRewriter &rewriter, + ::mlir::TilingInterface target, + ::mlir::transform::ApplyToEachResultList &results, + ::mlir::transform::TransformState &state); + }]; +} + + //===----------------------------------------------------------------------===// // DecomposeInterfaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 5254aac976f46..97d2b4a3be5c5 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -492,38 +492,6 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter, : DiagnosedSilenceableFailure::success(); } -ParseResult transform::FuseOp::parse(OpAsmParser &parser, - OperationState &result) { - OpAsmParser::UnresolvedOperand targetOperand; - if (parser.parseOperand(targetOperand) || - parser.parseOptionalAttrDict(result.attributes)) - return failure(); - - FunctionType trailingType; - SMLoc typeLoc; - if (parser.getCurrentLocation(&typeLoc) || - parser.parseColonType(trailingType)) { - return failure(); - } - if (trailingType.getNumInputs() != 1) - return parser.emitError(typeLoc) << "expected one input type"; - - result.addTypes(trailingType.getResults()); - if (parser.resolveOperand(targetOperand, trailingType.getInput(0), - result.operands)) - return failure(); - return success(); -} - -void transform::FuseOp::print(OpAsmPrinter &p) { - p << ' '; - p << getTarget(); - p.printOptionalAttrDict((*this)->getAttrs()); - p << " : "; - p.printFunctionalType(TypeRange(getOperand().getType()), - getResults().getTypes()); -} - LogicalResult transform::FuseOp::verify() { SmallVector<int64_t> permutation = extractFromIntegerArrayAttr<int64_t>(getTileInterchange()); @@ -2111,6 +2079,22 @@ transform::ScalarizeOp::applyToOne(transform::TransformRewriter &rewriter, return DiagnosedSilenceableFailure::success(); } +//===----------------------------------------------------------------------===// +// ConvertToLoopsOp +//===----------------------------------------------------------------------===// + +DiagnosedSilenceableFailure transform::ConvertToLoopsOp::applyToOne( + transform::TransformRewriter &rewriter, TilingInterface target, + transform::ApplyToEachResultList &results, + transform::TransformState &state) { + rewriter.setInsertionPoint(target); + FailureOr<SmallVector<scf::ForOp>> loops = + scf::lowerToLoopsUsingSCFForOp(rewriter, target); + if (failed(loops)) + return emitDefaultDefiniteFailure(target); + return DiagnosedSilenceableFailure::success(); +} + //===----------------------------------------------------------------------===// // RewriteInDestinationPassingStyleOp //===----------------------------------------------------------------------===// @@ -2620,7 +2604,12 @@ transform::TileUsingForOp::apply(transform::TransformRewriter &rewriter, } scf::SCFTilingOptions tilingOptions; - if (!tileSizes.empty()) { + if (tileSizes.empty()) { + tilingOptions.setTileSizeComputationFunction( + [](OpBuilder &, Operation *) -> SmallVector<OpFoldResult> { + return {}; + }); + } else { tilingOptions.setTileSizeComputationFunction([&, index = i](OpBuilder &b, Operation *) { SmallVector<OpFoldResult> sizes; diff --git 
a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index 22826cababe77..38e0625d7ce09 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -283,10 +283,7 @@ mlir::scf::tileUsingSCFForOp(RewriterBase &rewriter, TilingInterface op, // 1. Get the range of the loops that are represented by the operation. SmallVector<Range> iterationDomain = op.getIterationDomain(rewriter); size_t numLoops = iterationDomain.size(); - if (numLoops == 0) { - return rewriter.notifyMatchFailure( - op, "unable to tile op with no iteration domain"); - } + // 2. Materialize the tile sizes. Enforce the convention that "tiling by zero" // skips tiling a particular dimension. This convention is significantly // simpler to handle instead of adjusting affine maps to account for missing diff --git a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir index c8199c325abfe..7245498f641ec 100644 --- a/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/lower-to-loops-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=lower-to-scalar-using-scf-for -split-input-file %s | FileCheck %s +// RUN: mlir-opt -transform-interpreter -split-input-file -canonicalize -cse %s | FileCheck %s func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, %arg2 : memref<?x?xf32>) { @@ -6,13 +6,22 @@ func.func @gemm(%arg0 : memref<?x?xf32>, %arg1 : memref<?x?xf32>, outs(%arg2 : memref<?x?xf32>) return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %matmul : !transform.any_op + transform.yield + } +} // CHECK-LABEL: func @gemm // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref<?x?xf32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref<?x?xf32> // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: memref<?x?xf32> // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index -// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C1]] @@ -51,6 +60,15 @@ func.func @indexed_generic(%arg0 : memref<200x300xi32>, %arg1 : memref<300xi16>, } return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %generic : !transform.any_op + transform.yield + } +} // CHECK-LABEL: func @indexed_generic // CHECK-SAME: %[[ARG0:.+]]: memref<200x300xi32> // CHECK-SAME: %[[ARG1:.+]]: memref<300xi16> @@ -87,8 +105,18 @@ func.func @conv_strides_and_dilation(%arg0 : memref<?x?x?x?xf32>, %arg1 : memref<?x?x?x?xf32>, outs(%arg2 : memref<?x?x?x?xf32>) return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d1 + d4 * 3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d2 * 2 + d5 * 4)> + +module attributes {transform.with_named_sequence} { + transform.named_sequence 
@__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %conv : !transform.any_op + transform.yield + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECK: func @conv_strides_and_dilation( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref @@ -111,8 +139,8 @@ func.func @conv_strides_and_dilation(%arg0 : memref, %arg1 : memref // CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]] // CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]] // CHECK: scf.for %[[IV6:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]] -// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]]) -// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]], %[[IV6]]) +// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]]) +// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]]) // CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV6]]] // CHECK-DAG: %[[T10:.+]] = memref.load %[[ARG1]][%[[IV4]], %[[IV5]], %[[IV6]], %[[IV3]]] // CHECK-DAG: %[[T11:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] @@ -131,8 +159,18 @@ func.func @pool_strides_and_dilation(%arg0 : memref, %arg1 : memref outs(%arg2 : memref) return } -// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d1 + d4 * 3)> -// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1, d2, d3, d4, d5) -> (d2 * 2 + d5 * 4)> + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pool = transform.structured.match ops{["linalg.pooling_nhwc_max"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %pool : !transform.any_op + transform.yield + } +} + +// CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0, d1) -> (d0 + d1 * 3)> +// CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0, d1) -> (d0 * 2 + d1 * 4)> // CHECK: func @pool_strides_and_dilation // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: memref // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: memref @@ -153,8 +191,8 @@ func.func @pool_strides_and_dilation(%arg0 : memref, %arg1 : memref // CHECK: scf.for %[[IV3:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C1]] // CHECK: scf.for %[[IV4:[a-zA-Z0-9]+]] = %[[C0]] to %[[H]] step %[[C1]] // CHECK: scf.for %[[IV5:[a-zA-Z0-9]+]] = %[[C0]] to %[[W]] step %[[C1]] -// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]) -// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]], %[[IV4]], %[[IV5]]) +// CHECK-DAG: %[[I:.+]] = affine.apply #[[MAP0]](%[[IV1]], %[[IV4]]) +// CHECK-DAG: %[[J:.+]] = affine.apply #[[MAP1]](%[[IV2]], %[[IV5]]) // CHECK-DAG: %[[T8:.+]] = memref.load %[[ARG0]][%[[IV0]], %[[I]], %[[J]], %[[IV3]]] // CHECK-DAG: %[[T9:.+]] = memref.load %[[ARG2]][%[[IV0]], %[[IV1]], %[[IV2]], %[[IV3]]] // CHECK: %[[T10:.+]] = arith.maximumf %[[T9]], %[[T8]] @@ -172,6 +210,15 @@ func.func @map(%lhs: memref<64xf32>, } return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op 
{transform.readonly}) { + %map = transform.structured.match ops{["linalg.map"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %map : !transform.any_op + transform.yield + } +} // CHECK-LABEL: func.func @map( // CHECK-SAME: %[[LHS:[a-zA-Z0-9]+]]: memref<64xf32>, // CHECK-SAME: %[[RHS:[a-zA-Z0-9]+]]: memref<64xf32>, @@ -195,6 +242,15 @@ func.func @transpose(%arg0: memref<16x32x64xf32>, outs(%arg1 : memref<32x64x16xf32>) permutation = [1, 2, 0] return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %transpose = transform.structured.match ops{["linalg.transpose"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %transpose : !transform.any_op + transform.yield + } +} // CHECK-LABEL: func.func @transpose( // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>, // CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<32x64x16xf32>) @@ -223,6 +279,15 @@ func.func @reduce(%arg0: memref<16x32x64xf32>, } return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %reduce = transform.structured.match ops{["linalg.reduce"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %reduce : !transform.any_op + transform.yield + } +} // CHECK-LABEL: func.func @reduce( // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<16x32x64xf32>, // CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<16x64xf32> @@ -251,6 +316,15 @@ func.func @broadcast(%input: memref<8x32xf32>, dimensions = [1] func.return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %broadcast = transform.structured.match ops{["linalg.broadcast"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + transform.structured.convert_to_loops %broadcast : !transform.any_op + transform.yield + } +} // CHECK-LABEL: func.func @broadcast( // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: memref<8x32xf32>, // CHECK-SAME: %[[OUT:[a-zA-Z0-9]+]]: memref<8x16x32xf32> diff --git a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir index 2078b5b4dabb2..11ab30a7d237c 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-and-fuse-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=tile-consumer-and-fuse-producer-using-scf-for -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> tensor { %c0 = arith.constant 0 : index @@ -8,11 +8,20 @@ func.func @gemm_fill_fusion(%arg0 : tensor, %arg1 : tensor) -> %d1 = tensor.dim %arg1, %c1 : tensor %init = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %gemm = linalg.matmul {__internal_transform__ = "fusion"} - ins(%arg0, %arg1 : tensor, tensor) + %gemm = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor return %gemm : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match 
ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.fuse %matmul [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func.func @gemm_fill_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) @@ -43,11 +52,9 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, %d1 = tensor.dim %arg1, %c1 : tensor %init = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init : tensor) -> tensor - %gemm = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor) + %gemm = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor %generic = linalg.generic { - __internal_transform__ = "fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%gemm, %arg2 : tensor, tensor) outs(%init : tensor) { @@ -57,6 +64,16 @@ func.func @gemm_generic_fusion(%arg0 : tensor, %arg1 : tensor, } -> tensor return %generic : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.fuse %generic [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func.func @gemm_generic_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor, @@ -97,10 +114,22 @@ func.func @gemm_gemm_fusion(%lhs0 : tensor, %rhs0 : tensor, %r %d2 = tensor.dim %rhs1, %c1 : tensor %init1 = tensor.empty(%d0, %d2) : tensor %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor - %gemm1 = linalg.matmul {__internal_transform__ = "gemm_fusion"} + %gemm1 = linalg.matmul ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor return %gemm1 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %mm1, %mm2 = transform.split_handle %matmuls + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.structured.fuse %mm2 [10] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func.func @gemm_gemm_fusion( // CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, @@ -142,12 +171,10 @@ func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor %init0 = tensor.empty(%d0, %d1) : tensor %fill = linalg.fill ins(%cst : f32) outs(%init0 : tensor) -> tensor - %gemm = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor) + %gemm = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%fill : tensor) -> tensor %init1 = tensor.empty(%d1, %d0) : tensor %transpose = linalg.generic { - __internal_transform__ = "fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>], iterator_types = ["parallel", "parallel"]} ins(%gemm : tensor) outs(%init1 : tensor) { @@ -156,6 +183,16 @@ func.func @gemm_transpose_fusion(%arg0 : tensor, %arg1 : tensor tensor return %transpose : tensor } + +module 
attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.fuse %generic [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func.func @gemm_transpose_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) @@ -194,11 +231,9 @@ func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor %1 = linalg.fill ins(%cst : f32) outs(%0 : tensor) -> tensor - %2 = linalg.matmul - ins(%arg0, %arg1 : tensor, tensor) + %2 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%1 : tensor) -> tensor %3 = linalg.generic { - __internal_transform__ = "gemm_interchange_fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor) outs(%0 : tensor) { @@ -208,6 +243,16 @@ func.func @interchange_matmul_fusion(%arg0 : tensor, %arg1 : tensor tensor return %3 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.fuse %generic [10, 20] interchange[1, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func.func @interchange_matmul_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor) @@ -248,8 +293,7 @@ func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_transform__ = "gemm_plus_gemm_fusion"} + iterator_types = ["parallel", "parallel"]} ins(%2, %2 : tensor, tensor) outs(%5 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : @@ -258,8 +302,16 @@ func.func @matmul_plus_matmul(%arg0: tensor, %arg1: tensor, } -> tensor return %6 : tensor } -// This fuses as expected but the gemm operation is inlined twice. It should be CSE-d but isnt today. 
+module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.fuse %generic [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func @matmul_plus_matmul // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor @@ -301,8 +353,7 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor, %arg1: tensor (d0, d1)>, affine_map<(d0, d1) -> (d1, d0)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"], - __internal_transform__ = "gemm_plus_gemm_fusion"} + iterator_types = ["parallel", "parallel"]} ins(%2, %2 : tensor, tensor) outs(%5 : tensor) { ^bb0(%arg3 : f32, %arg4 : f32, %arg5 : f32) : @@ -311,6 +362,16 @@ func.func @matmul_plus_transpose_matmul(%arg0: tensor, %arg1: tensor tensor return %6 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.fuse %generic [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func @matmul_plus_transpose_matmul // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor // CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]+]]: tensor @@ -351,13 +412,22 @@ func.func @matmul_sequence_fusion(%arg0: tensor, %arg1: tensor outs(%arg2 : tensor) -> tensor // [M, N0] * [N0, N1] %1 = linalg.matmul ins(%0, %arg3 : tensor, tensor) outs(%arg4 : tensor) -> tensor // [M, N1] * [N1, N2] - %2 = linalg.matmul - {__internal_transform__ = "gemm_sequence_fusion"} - ins(%1, %arg5 : tensor, tensor) + %2 = linalg.matmul ins(%1, %arg5 : tensor, tensor) outs(%arg6 : tensor) -> tensor // [M, N2] * [N2, N3] return %2 : tensor } +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %mm1, %mm2, %mm3 = transform.split_handle %matmuls + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.structured.fuse %mm3 [10] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: #[[MAP:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> // CHECK: func @matmul_sequence_fusion( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]+]]: tensor @@ -425,7 +495,6 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> { linalg.yield %10, %9 : f32, f32 } -> (tensor<30xf32>, tensor<30x3xf32>) %6 = linalg.generic { - __internal_transform__ = "reduction_sequence_fusion", indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} @@ -436,6 +505,18 @@ func.func @reduction_sequence(%arg0: tensor<30x3xf32>) -> tensor<30x3xf32> { } -> tensor<30x3xf32> return %6 : tensor<30x3xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generics = 
transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %generic1, %generic2, %generic3 = transform.split_handle %generics + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + %a, %b = transform.structured.fuse %generic3 [10] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func @reduction_sequence(%[[ARG0:.+]]: tensor<30x3xf32>) // CHECK-DAG: %[[INIT0:.+]] = tensor.empty() : tensor<30xf32> // CHECK-DAG: %[[INIT1:.+]] = tensor.empty() : tensor<30x3xf32> diff --git a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir index f725d19e14a0c..3d353c068a9f9 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-fuse-and-yield-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=tile-consumer-fuse-and-yield-producer-using-scf-for -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor, %rhs0 : tensor, %rhs1 : tensor, %init0 : tensor, %init1 : tensor) @@ -13,10 +13,22 @@ func.func @gemm_gemm_fusion_yield_both(%lhs0 : tensor, %rhs0 : tensor, tensor) outs(%fill0 : tensor) -> tensor %d2 = tensor.dim %rhs1, %c1 : tensor %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor) -> tensor - %gemm1 = linalg.matmul {__internal_transform__ = "gemm_sequence_fusion_and_yield"} + %gemm1 = linalg.matmul ins(%gemm0, %rhs1 : tensor, tensor) outs(%fill1 : tensor) -> tensor return %gemm0, %gemm1 : tensor, tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmuls = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %mm1, %mm2 = transform.split_handle %matmuls + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + %a, %b = transform.test.fuse_and_yield %mm2 [10] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: func.func @gemm_gemm_fusion_yield_both( // CHECK-SAME: %[[LHS0:[a-zA-Z0-9]+]]: tensor // CHECK-SAME: %[[RHS0:[a-zA-Z0-9]+]]: tensor, diff --git a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir index cbc5d6c186d6d..05b7afdf0d1ca 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-pad-using-interface.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -resolve-shaped-type-result-dims -cse -split-input-file %s | FileCheck %s +// RUN: mlir-opt -transform-interpreter -cse -split-input-file %s | FileCheck %s // 2D tiling of dynamic 2D pad tensor op. 
func.func @dynamic_2d_pad_tensor(%input_tensor: tensor, @@ -6,22 +6,33 @@ func.func @dynamic_2d_pad_tensor(%input_tensor: tensor, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_transform__ = "pad_2dtiling"}: tensor to tensor + } : tensor to tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %pad [2, 3] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} + // CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 + 8)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 + 7)> // CHECK: func @dynamic_2d_pad_tensor( // CHECK-SAME: %[[IN:[a-zA-Z0-9]+]]: tensor // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index +// CHECK-DAG: %[[DIM_IN0:.+]] = tensor.dim %[[IN]], %[[C0]] +// CHECK-DAG: %[[DIM0:.+]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]] // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]] +// CHECK-DAG: %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]] // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index -// CHECK: %[[DIM_IN0:.+]] = tensor.dim %[[IN]], %[[C0]] -// CHECK: %[[DIM0:.+]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]] -// CHECK: %[[DIM_IN1:.+]] = tensor.dim %[[IN]], %[[C1]] -// CHECK: %[[DIM1:.+]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]] // CHECK: %[[RESULT:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[DIM0]] step %[[C2]] +// CHECK: %[[C3:.+]] = arith.constant 3 : index // CHECK: scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] = // CHECK: %[[SWAP_RESULT:.*]] = scf.if // CHECK: tensor.generate @@ -38,20 +49,30 @@ func.func @dynamic_2d_pad_tensor_inner_tiling(%input_tensor: tensor, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_transform__ = "pad_inner_tiling"}: tensor to tensor + } : tensor to tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.structured.tile_using_for %pad [0, 3] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[MAP0:.*]] = affine_map<()[s0] -> (s0 + 8)> // CHECK-DAG: #[[MAP1:.*]] = affine_map<()[s0] -> (s0 + 7)> // CHECK: func @dynamic_2d_pad_tensor_inner_tiling( // CHECK-SAME: %[[IN:.*]]: tensor // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[DIM_IN0:.*]] = tensor.dim %[[IN]], %[[C0]] +// CHECK-DAG: %[[DIM0:.*]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]] // CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[DIM_IN1:.*]] = tensor.dim %[[IN]], %[[C1]] +// CHECK-DAG: %[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]] // CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index -// CHECK: %[[DIM_IN0:.*]] = tensor.dim %[[IN]], %[[C0]] -// CHECK: %[[DIM0:.*]] = affine.apply #[[MAP0]]()[%[[DIM_IN0]]] -// CHECK: %[[DIM_IN1:.*]] = tensor.dim %[[IN]], %[[C1]] -// CHECK: 
%[[DIM1:.*]] = affine.apply #[[MAP1]]()[%[[DIM_IN1]]] // CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[DIM1]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] = // CHECK: %[[SWAP_RESULT:.*]] = scf.if // CHECK: tensor.generate @@ -68,17 +89,27 @@ func.func @static_pad_tensor(%input_tensor: tensor<7x9xf32>, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_transform__ = "pad_2dtiling"} : tensor<7x9xf32> to tensor<15x16xf32> + } : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %pad [2, 3] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @static_pad_tensor( // CHECK-SAME: %[[IN:.*]]: tensor<7x9xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index -// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK-DAG: %[[C15:.*]] = arith.constant 15 : index -// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index // CHECK: %[[RESULT:.*]] = scf.for {{.*}} = %[[C0]] to %[[C15]] step %[[C2]] +// CHECK-DAG: %[[C16:.*]] = arith.constant 16 : index +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index // CHECK: scf.for {{.*}} = %[[C0]] to %[[C16]] step %[[C3]] iter_args(%[[INNER_OUT:.*]] = // CHECK: %[[SWAP_RESULT:.*]] = scf.if // CHECK: tensor.generate @@ -95,9 +126,19 @@ func.func @static_pad_tensor_inner_tiling(%input_tensor: tensor<7x9xf32>, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32> + } : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.structured.tile_using_for %pad [0, 3] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @static_pad_tensor_inner_tiling( // CHECK-SAME: %[[IN:.*]]: tensor<7x9xf32> // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index @@ -122,9 +163,19 @@ func.func @dynamic_2d_pad_tensor_outer_tiling(%input_tensor: tensor, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_transform__ = "pad_outer_tiling"}: tensor to tensor + } : tensor to tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %pad [2, 3] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @dynamic_2d_pad_tensor_outer_tiling // ----- @@ -134,7 +185,17 @@ func.func 
@static_pad_tensor_outer_tiling(%input_tensor: tensor<7x9xf32>, %0 = tensor.pad %input_tensor low[3, 4] high[5, 3] { ^bb0(%arg1: index, %arg2: index): tensor.yield %pad_value : f32 - } {__internal_transform__ = "pad_inner_tiling"} : tensor<7x9xf32> to tensor<15x16xf32> + } : tensor<7x9xf32> to tensor<15x16xf32> return %0 : tensor<15x16xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %pad = transform.structured.match ops{["tensor.pad"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.structured.tile_using_for %pad [0, 3] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @static_pad_tensor_outer_tiling diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir index e99ffc88066d6..444232e9e1e2e 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-interface.mlir @@ -1,12 +1,21 @@ -// RUN: mlir-opt -test-tiling-interface=tile-using-scf-for -split-input-file %s | FileCheck %s +// RUN: mlir-opt --transform-interpreter --cse -split-input-file %s | FileCheck %s func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.matmul {__internal_transform__ = "simple_gemm"} - ins(%arg0, %arg1 : tensor, tensor) + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %matmul [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> // CHECK-LABEL: func.func @simple_matmul( @@ -16,13 +25,13 @@ func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] // CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] // CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] // CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]]) -// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]] +// CHECK-DAG: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]] +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK: %[[INNER:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] // CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) // CHECK: %[[TS_X:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]] @@ -45,11 +54,20 @@ func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref, %arg2 : memref) { - linalg.matmul {__internal_transform__ = "simple_gemm_memref"} - ins(%arg0, %arg1 : memref, memref) + linalg.matmul ins(%arg0, %arg1 
: memref, memref) outs(%arg2 : memref) return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c, %d = transform.structured.tile_using_for %matmul [10, 20, 30] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)> @@ -60,15 +78,15 @@ func.func @simple_matmul_memref(%arg0 : memref, %arg1 : memref // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index // CHECK-DAG: %[[M:.+]] = memref.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[K:.+]] = memref.dim %[[ARG0]], %[[C1]] // CHECK-DAG: %[[N:.+]] = memref.dim %[[ARG1]], %[[C1]] // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] -// CHECK: %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]] +// CHECK-DAG: %[[TS_M:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[M]]] +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] -// CHECK: %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]] +// CHECK-DAG: %[[TS_N:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[N]]] +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index // CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]] // CHECK: %[[TS_K:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[K]]] // CHECK-DAG: %[[LHS_TILE:.+]] = memref.subview %[[ARG0]] @@ -91,8 +109,7 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x %init1 = tensor.empty() : tensor<300x128x200xf32> %0:2 = linalg.generic { indexing_maps = [#map0, #map1, #map2], - iterator_types = ["parallel", "parallel", "parallel"]} - {__internal_transform__ = "parallel_generic_transpose"} + iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0 : tensor<128x200x300xf32>) outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) { ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): @@ -100,19 +117,29 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %generic [10, 0, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)> // CHECK-LABEL: func.func @multi_result( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>) // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK-DAG: %[[C128:.+]] = 
arith.constant 128 : index -// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index // CHECK-DAG: %[[INIT0:.+]] = tensor.empty() // CHECK-DAG: %[[INIT1:.+]] = tensor.empty() // CHECK: %[[OUTER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[C128]] step %[[C10]] // CHECK-SAME: iter_args(%[[ARG1:[a-zA-Z0-9]+]] = %[[INIT0]], %[[ARG2:[a-zA-Z0-9]+]] = %[[INIT1]]) -// CHECK: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]]) +// CHECK-DAG: %[[TS_Y:.+]] = affine.min #[[$MAP0]](%[[IV0]]) +// CHECK-DAG: %[[C300:.+]] = arith.constant 300 : index +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK: %[[INNER:[a-zA-Z0-9]+]]:2 = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[C300]] step %[[C20]] // CHECK-SAME: iter_args(%[[ARG3:[a-zA-Z0-9]+]] = %[[ARG1]], %[[ARG4:[a-zA-Z0-9]+]] = %[[ARG2]]) // CHECK-DAG: %[[ARG_TILE:.+]] = tensor.extract_slice %[[ARG0]] @@ -138,12 +165,21 @@ func.func @conv2D(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.conv_2d_nhwc_hwcf { strides = dense<[2, 3]> : tensor<2xi64>, - dilation = dense<[4, 5]> : tensor<2xi64>, - __internal_transform__ = "simple_conv"} + dilation = dense<[4, 5]> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c, %d = transform.structured.tile_using_for %conv [0, 0, 0, 0, 10, 20, 30] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)> @@ -158,8 +194,6 @@ func.func @conv2D(%arg0 : tensor, %arg1 : tensor, // CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index // CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index // CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index -// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index // CHECK-DAG: %[[N:.+]] = tensor.dim %[[INPUT]], %[[C0]] // CHECK-DAG: %[[C:.+]] = tensor.dim %[[INPUT]], %[[C3]] // CHECK-DAG: %[[P:.+]] = tensor.dim %[[FILTER]], %[[C0]] @@ -169,10 +203,12 @@ func.func @conv2D(%arg0 : tensor, %arg1 : tensor, // CHECK-DAG: %[[S:.+]] = tensor.dim %[[INIT]], %[[C2]] // CHECK: scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[P]] step %[[C10]] // CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[INIT]]) -// CHECK: %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]] +// CHECK-DAG: %[[TS_P:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[P]]] +// CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index // CHECK: scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[Q]] step %[[C20]] // CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) -// CHECK: %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]] +// CHECK-DAG: %[[TS_Q:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[Q]]] +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index // CHECK: scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[C]] step %[[C30]] // CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]]) // CHECK-DAG: %[[TS_C:.+]] = affine.min #[[$MAP2]](%[[IV2]])[%[[C]]] @@ -199,8 +235,7 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> %0 = linalg.generic { indexing_maps = 
[affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], - iterator_types = ["parallel", "parallel"]} - {__internal_transform__ = "indexed_semantics"} + iterator_types = ["parallel", "parallel"]} ins(%arg0: tensor) outs(%arg1: tensor) { ^bb0(%arg2: f32, %arg3: f32): @@ -214,6 +249,16 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> } -> (tensor) return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %generic [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK: #[[$MAP_ADD:.+]] = affine_map<(d0, d1) -> (d0 + d1)> // CHECK-LABEL: @indexed_semantics // CHECK: scf.for %[[I0:.+]] = %{{.*}} to %{{.*}} step %{{.*}} @@ -228,11 +273,20 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> func.func @interchange_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.matmul {__internal_transform__ = "gemm_interchange"} - ins(%arg0, %arg1 : tensor, tensor) + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c, %d = transform.structured.tile_using_for %matmul [10, 20, 30] interchange = [1, 2, 0] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)> // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> @@ -242,18 +296,18 @@ func.func @interchange_matmul(%arg0 : tensor, %arg1 : tensor, // CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor // CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index -// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index // CHECK-DAG: %[[C20:.+]] = arith.constant 20 : index -// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index // CHECK-DAG: %[[M:.+]] = tensor.dim %[[ARG0]], %[[C0]] // CHECK-DAG: %[[K:.+]] = tensor.dim %[[ARG0]], %[[C1]] // CHECK-DAG: %[[N:.+]] = tensor.dim %[[ARG1]], %[[C1]] // CHECK: %[[OUTER:[a-zA-Z0-9]+]] = scf.for %[[IV0:[a-zA-Z0-9]+]] = %[[C0]] to %[[N]] step %[[C20]] // CHECK-SAME: iter_args(%[[INIT0:.+]] = %[[ARG2]]) -// CHECK: %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]] +// CHECK-DAG: %[[TS_N:.+]] = affine.min #[[$MAP0]](%[[IV0]])[%[[N]]] +// CHECK-DAG: %[[C30:.+]] = arith.constant 30 : index // CHECK: %[[INNER1:[a-zA-Z0-9]+]] = scf.for %[[IV1:[a-zA-Z0-9]+]] = %[[C0]] to %[[K]] step %[[C30]] // CHECK-SAME: iter_args(%[[INIT1:.+]] = %[[INIT0]]) -// CHECK: %[[TS_K:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[K]]] +// CHECK-DAG: %[[TS_K:.+]] = affine.min #[[$MAP1]](%[[IV1]])[%[[K]]] +// CHECK-DAG: %[[C10:.+]] = arith.constant 10 : index // CHECK: %[[INNER2:[a-zA-Z0-9]+]] = scf.for %[[IV2:[a-zA-Z0-9]+]] = %[[C0]] to %[[M]] step %[[C10]] // CHECK-SAME: iter_args(%[[INIT2:.+]] = %[[INIT1]]) // CHECK-DAG: %[[TS_M:.+]] = affine.min 
#[[$MAP2]](%[[IV2]])[%[[M]]] @@ -276,10 +330,19 @@ func.func @interchange_matmul(%arg0 : tensor, %arg1 : tensor, // ----- func.func @linalg_copy_matmul(%a: memref, %b: memref) { - linalg.copy {__internal_transform__ = "simple_copy_memref"} - ins(%a : memref) outs(%b : memref) + linalg.copy ins(%a : memref) outs(%b : memref) return } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %copy = transform.structured.match ops{["linalg.copy"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b, %c = transform.structured.tile_using_for %copy [10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @linalg_copy_matmul( // CHECK: scf.for // CHECK: scf.for @@ -293,8 +356,7 @@ func.func @check_scalar_operation(%arg0 : tensor) -> tensor { %init = tensor.empty() : tensor %0 = linalg.generic { indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], - iterator_types = []} - {__internal_transform__ = "scalar_op"} + iterator_types = []} ins(%arg0 : tensor) outs(%init : tensor){ ^bb0(%b0 : f32, %b1 : f32): %1 = arith.mulf %b0, %b0 : f32 @@ -302,18 +364,26 @@ func.func @check_scalar_operation(%arg0 : tensor) -> tensor { } -> tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a = transform.structured.tile_using_for %generic [] + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @check_scalar_operation // CHECK-NOT: scf.for // CHECK: linalg.generic -// CHECK-SAME: __internal_transform__ = "scalar_op" // ----- func.func @check_scalar_memref_operation(%arg0 : memref, %arg1 : memref){ linalg.generic { indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], - iterator_types = []} - {__internal_transform__ = "scalar_op"} + iterator_types = []} ins(%arg0 : memref) outs(%arg1 : memref){ ^bb0(%b0 : f32, %b1 : f32): %1 = arith.mulf %b0, %b0 : f32 @@ -321,7 +391,16 @@ func.func @check_scalar_memref_operation(%arg0 : memref, %arg1 : memref !transform.any_op + %a = transform.structured.tile_using_for %generic [] + : (!transform.any_op) -> (!transform.any_op) + transform.yield + } +} // CHECK-LABEL: func @check_scalar_memref_operation // CHECK-NOT: scf.for // CHECK: linalg.generic -// CHECK-SAME: __internal_transform__ = "scalar_op" diff --git a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir index 314efde45720a..db0c1327e2fe0 100644 --- a/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir +++ b/mlir/test/Interfaces/TilingInterface/tile-using-scfforall.mlir @@ -1,12 +1,22 @@ -// RUN: mlir-opt -test-tiling-interface=tile-using-scf-forall -split-input-file %s | FileCheck %s +// RUN: mlir-opt -transform-interpreter -split-input-file --cse %s | FileCheck %s func.func @simple_matmul(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { - %0 = linalg.matmul {__internal_transform__ = "simple_gemm"} + %0 = linalg.matmul ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op 
{transform.readonly}) { + %matmul = transform.structured.match ops{["linalg.matmul"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %matmul [10, 20] mapping [#gpu.block, #gpu.block] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> // CHECK-DAG: #[[MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> // CHECK: func.func @simple_matmul( @@ -48,7 +58,6 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x %0:2 = linalg.generic { indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel"]} - {__internal_transform__ = "parallel_generic_transpose"} ins(%arg0 : tensor<128x200x300xf32>) outs(%init0, %init1 : tensor<128x300x200xf32>, tensor<300x128x200xf32>) { ^bb0(%b0 : f32, %b1 : f32, %b2 : f32): @@ -56,6 +65,16 @@ func.func @multi_result(%arg0 : tensor<128x200x300xf32>) -> (tensor<128x300x200x } -> (tensor<128x300x200xf32>, tensor<300x128x200xf32>) return %0#0, %0#1 : tensor<128x300x200xf32>, tensor<300x128x200xf32> } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %generic [10, 0, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0) -> (10, -d0 + 128)> // CHECK-LABEL: func.func @multi_result( // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x200x300xf32>) @@ -85,12 +104,21 @@ func.func @conv2D(%arg0 : tensor, %arg1 : tensor, %arg2 : tensor) -> tensor { %0 = linalg.conv_2d_nhwc_hwcf { strides = dense<[2, 3]> : tensor<2xi64>, - dilation = dense<[4, 5]> : tensor<2xi64>, - __internal_transform__ = "simple_conv"} + dilation = dense<[4, 5]> : tensor<2xi64>} ins(%arg0, %arg1 : tensor, tensor) outs(%arg2 : tensor) -> tensor return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %conv = transform.structured.match ops{["linalg.conv_2d_nhwc_hwcf"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %conv [0, 0, 0, 0, 10, 20, 30] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} // CHECK-DAG: #[[$MAP0:.+]] = affine_map<(d0)[s0] -> (10, -d0 + s0)> // CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0)[s0] -> (20, -d0 + s0)> // CHECK-DAG: #[[$MAP2:.+]] = affine_map<(d0)[s0] -> (30, -d0 + s0)> @@ -144,7 +172,6 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} - {__internal_transform__ = "indexed_semantics"} ins(%arg0: tensor) outs(%arg1: tensor) { ^bb0(%arg2: f32, %arg3: f32): @@ -158,6 +185,17 @@ func.func @indexed_semantics(%arg0: tensor, %arg1: tensor) -> } -> (tensor) return %0 : tensor } + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1 : !transform.any_op {transform.readonly}) { + %generic = transform.structured.match ops{["linalg.generic"]} in %arg1 + : (!transform.any_op) -> !transform.any_op + %a, %b = transform.test.tile_using_forall %generic 
[10, 20] + : (!transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } +} + // CHECK-LABEL: @indexed_semantics // CHECK: scf.forall (%[[I0:.+]], %[[I1:.+]]) = // CHECK: %[[INDEX0:.+]] = linalg.index 0 diff --git a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt index 5f974b6198983..6dc633c9e21a7 100644 --- a/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt +++ b/mlir/test/lib/Interfaces/TilingInterface/CMakeLists.txt @@ -1,5 +1,13 @@ +set(LLVM_TARGET_DEFINITIONS TestTilingInterfaceTransformOps.td) +mlir_tablegen(TestTilingInterfaceTransformOps.h.inc -gen-op-decls) +mlir_tablegen(TestTilingInterfaceTransformOps.cpp.inc -gen-op-defs) +add_public_tablegen_target(MLIRTestTilingInterfaceTransformOpsIncGen) + add_mlir_library(MLIRTilingInterfaceTestPasses - TestTilingInterface.cpp + TestTilingInterfaceTransformOps.cpp + + DEPENDS + MLIRTestTilingInterfaceTransformOpsIncGen EXCLUDE_FROM_LIBMLIR diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp deleted file mode 100644 index 798293bc1327e..0000000000000 --- a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterface.cpp +++ /dev/null @@ -1,620 +0,0 @@ -//===- TestTilingInterface.cpp - Test tiling using `TilingInterface` -----===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// -// This file implements a pass for testing tiling operations using -// `TilingInterface`. -// -//===----------------------------------------------------------------------===// - -#include -#include - -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Transforms/TilingInterfaceImpl.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/IR/TensorTilingInterfaceImpl.h" -#include "mlir/Interfaces/TilingInterface.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/ADT/TypeSwitch.h" - -using namespace mlir; - -// TODO: this file should disappear and instead tests should make use of the -// transform dialect. -namespace { - -/// Marker used as attribute name in generated Linalg rewriting transformations. -const StringLiteral kTransformMarker = "__internal_transform__"; - -/// Helper class to control application of linalg transformation patterns. -/// Control comes in 2 forms: -/// 1. attribute matching and setting behavior using the attribute named -/// `kTransformMarker`. This can be used to build a state machine -/// using attributes and incrementally applying patterns to advance states. -/// 2. filter function, which is a simple lambda on the Operation* that -/// returns a LogicalResult. 
-struct TransformationFilter { - using FilterFunction = std::function; - - explicit TransformationFilter( - ArrayRef matchDisjunction = {}, - std::optional replacement = std::nullopt); - - explicit TransformationFilter( - const FilterFunction &f, ArrayRef matchDisjunction = {}, - std::optional replacement = std::nullopt); - - TransformationFilter(TransformationFilter &&) = default; - TransformationFilter(const TransformationFilter &) = default; - LogicalResult checkAndNotify(PatternRewriter &rewriter, Operation *op) const; - void replaceTransformationFilter(PatternRewriter &rewriter, - Operation *op) const; - - TransformationFilter &addFilter(const FilterFunction &f) { - if (f) - filters.push_back(f); - return *this; - } - - template - TransformationFilter &addOpFilter() { - return addFilter( - [](Operation *op) { return success(isa(op)); }); - } - - TransformationFilter &addOpNameFilter(StringRef opName) { - return addFilter([opName](Operation *op) { - return success(op->getName().getStringRef() == opName); - }); - } - - TransformationFilter &setMatchByDefault() { - matchByDefault = true; - return *this; - } - -private: - SmallVector filters; - SmallVector matchDisjunction; - std::optional replacement; - /// When set to true, if the attribute is not set, it will be treated as - /// a match. Default is false. - bool matchByDefault; -}; - -TransformationFilter::TransformationFilter( - ArrayRef matchDisjunction, - std::optional replacement) - : matchDisjunction(matchDisjunction.begin(), matchDisjunction.end()), - replacement(replacement), matchByDefault(false) {} - -LogicalResult TransformationFilter::checkAndNotify(PatternRewriter &rewriter, - Operation *op) const { - if (llvm::any_of(filters, - [&](const FilterFunction &f) { return failed(f(op)); })) - return failure(); - - auto attr = op->template getAttrOfType(kTransformMarker); - - if (!attr) { - // 1. Has no filter case and matchDisjunction is empty. - if (matchDisjunction.empty() || matchByDefault) - return success(); - - // 2. Has no filter but was expecting a filter. - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << " does not have any filter from list: "; - interleaveComma(matchDisjunction, diag); - }); - } - - // 4. Match explicit filter. - for (auto filter : matchDisjunction) - if (attr.getValue() == filter) - return success(); - - // 5. Fail to match. - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << " does not have any filter from list: "; - interleaveComma(matchDisjunction, diag); - }); -} - -void TransformationFilter::replaceTransformationFilter( - PatternRewriter &rewriter, Operation *op) const { - if (replacement.has_value()) - op->setAttr(kTransformMarker, *replacement); - else - op->removeAttr(rewriter.getStringAttr(kTransformMarker)); -} - -/// Pattern for testing `TileUsingSCFForOp` pattern (that tiles operations using -/// the `TilingInterface` with `scf.for` ops for iterating over the tiles) while -/// using a `filter` to avoid recursive application. -struct TestTileUsingSCFForOp - : public OpInterfaceRewritePattern { - TestTileUsingSCFForOp(MLIRContext *context, scf::SCFTilingOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - /// Construct a generic pattern applied to `opName`. 
- TestTileUsingSCFForOp(StringRef opName, MLIRContext *context, - scf::SCFTilingOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - FailureOr tilingResult = - scf::tileUsingSCFForOp(rewriter, op, options); - if (failed(tilingResult)) - return rewriter.notifyMatchFailure(op, "failed to tile operation"); - - if (op->getNumResults()) { - rewriter.replaceOp(op, tilingResult->replacements); - } else { - rewriter.eraseOp(op); - } - - for (auto *tiledOp : tilingResult->tiledOps) - filter.replaceTransformationFilter(rewriter, tiledOp); - return success(); - } - -private: - scf::SCFTilingOptions options; - TransformationFilter filter; -}; - -/// Pattern for testing `tileUsingSCFForallOp` (that tiles operations using -/// the `TilingInterface` with `scf.forall` ops for iterating over the tiles) -/// while using a `filter` to avoid recursive application. -struct TestTileUsingSCFForallOp - : public OpInterfaceRewritePattern { - TestTileUsingSCFForallOp(MLIRContext *context, scf::SCFTilingOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - /// Construct a generic pattern applied to `opName`. - TestTileUsingSCFForallOp(StringRef opName, MLIRContext *context, - scf::SCFTilingOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - FailureOr tilingResult = - scf::tileUsingSCFForallOp(rewriter, op, options); - if (failed(tilingResult)) - return rewriter.notifyMatchFailure(op, "failed to tile operation"); - - if (op->getNumResults()) { - rewriter.replaceOp(op, tilingResult->replacements); - } else { - rewriter.eraseOp(op); - } - - for (auto *tiledOp : tilingResult->tiledOps) - filter.replaceTransformationFilter(rewriter, tiledOp); - return success(); - } - -private: - scf::SCFTilingOptions options; - TransformationFilter filter; -}; - -/// Pattern for testing `TileConsumerAndFuseProducersUsingSCFForOp` pattern -/// (that tiles and fuses operations using the `TilingInterface` with `scf.for` -/// ops for iterating over the tiles) while using a `filter` to avoid recursive -/// application. -struct TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp - : public OpInterfaceRewritePattern { - TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( - MLIRContext *context, scf::SCFTileAndFuseOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - /// Construct a generic pattern applied to `opName`. 
- TestTileConsumerAndFuseProducersGreedilyUsingSCFForOp( - StringRef opName, MLIRContext *context, - scf::SCFTileAndFuseOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, op))) - return failure(); - - FailureOr tileAndFuseResult = - scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(rewriter, op, - options); - if (failed(tileAndFuseResult)) { - return failure(); - } - // Replace the tiled op with replacements. - SmallVector replacements(op->getNumResults()); - for (const auto &result : llvm::enumerate(op->getResults())) { - replacements[result.index()] = - tileAndFuseResult->replacements.lookup(result.value()); - } - rewriter.replaceOp(op, replacements); - - filter.replaceTransformationFilter( - rewriter, tileAndFuseResult->tiledAndFusedOps.front()); - return success(); - } - -private: - scf::SCFTileAndFuseOptions options; - TransformationFilter filter; -}; - -/// Pattern to tile a consumer and fuse producer with it -/// while reconstructing the value of the fused producer -/// from within the loop nest to replace any external -/// uses of the producer. In general yielding the producer -/// this way requires a guarantee that the slice of the producer -/// is not computed redundantly within the tiled loops. An analysis that -/// figures it out has shown to be very complex. So this is left as a caller -/// side determination. In this test pattern it is assumed that the tile sizes -/// are selected such that all producers when fused into the tiled loops do no -/// have redundant computation. -struct TestTileConsumerFuseAndYieldProducerUsingSCFForOp - : public OpInterfaceRewritePattern { - - TestTileConsumerFuseAndYieldProducerUsingSCFForOp( - MLIRContext *context, scf::SCFTilingOptions options, - TransformationFilter filter = TransformationFilter(), - PatternBenefit benefit = 1) - : OpInterfaceRewritePattern(context, benefit), - options(std::move(options)), filter(std::move(filter)) {} - - LogicalResult matchAndRewrite(TilingInterface rootOp, - PatternRewriter &rewriter) const override { - if (failed(filter.checkAndNotify(rewriter, rootOp))) - return failure(); - - // Collect list of operations that can be tiled and fused. 
- llvm::SmallDenseSet tiledAndFusedOps = - collectTiledAndFusedOps(rootOp); - llvm::SmallDenseMap yielded; - auto isIgnoredUser = [&](Operation *user) { - return tiledAndFusedOps.count(user) || isa(user); - }; - for (Operation *op : tiledAndFusedOps) { - yielded[op] = llvm::any_of(op->getUsers(), [&](Operation *user) { - return !isIgnoredUser(user); - }); - } - - scf::SCFTileAndFuseOptions tileAndFuseOptions; - tileAndFuseOptions.setTilingOptions(options); - scf::SCFTileAndFuseOptions::ControlFnTy controlFn = - [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer, - bool isDestinationOperand) { - Operation *owner = originalProducer.getOwner(); - return std::make_tuple(true, - yielded.contains(owner) && yielded[owner]); - }; - tileAndFuseOptions.setFusionControlFn(controlFn); - - FailureOr tileAndFuseResult = - scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp( - rewriter, rootOp, tileAndFuseOptions); - if (failed(tileAndFuseResult)) { - return rewriter.notifyMatchFailure( - rootOp, "failed to tile and fuse with op as root"); - } - - for (auto it : tileAndFuseResult->replacements) { - Value origVal = it.first; - Value replacement = it.second; - rewriter.replaceUsesWithIf(origVal, replacement, [&](OpOperand &use) { - Operation *user = use.getOwner(); - return !isIgnoredUser(user) && - !tileAndFuseResult->loops.front()->isAncestor(user); - }); - } - - rewriter.eraseOp(rootOp); - for (auto tiledAndFusedOp : tileAndFuseResult->tiledAndFusedOps) - if (tiledAndFusedOp->hasAttr(kTransformMarker)) - filter.replaceTransformationFilter(rewriter, tiledAndFusedOp); - - return success(); - } - -private: - /// Starting from `op` walk all operands backwards to find all - /// potentially fusable operations, i.e. operations that implement - /// the `TilingInterface`. - llvm::SmallDenseSet - collectTiledAndFusedOps(Operation *op) const { - SmallVector worklist; - llvm::SmallDenseSet producers; - worklist.push_back(op); - producers.insert(op); - while (!worklist.empty()) { - Operation *current = worklist.pop_back_val(); - for (OpOperand &operand : current->getOpOperands()) { - Operation *producer = operand.get().getDefiningOp(); - if (!producer || !isa(producer) || - producers.count(producer)) - continue; - worklist.push_back(producer); - producers.insert(producer); - } - } - return producers; - } - - scf::SCFTilingOptions options; - TransformationFilter filter; -}; - -/// Pattern to lower operations that implement the `TilingInterface` to -/// loops/scalar IR using `scf.for`. -struct LowerToLoopsUsingSCFForOp - : public OpInterfaceRewritePattern { - using OpInterfaceRewritePattern::OpInterfaceRewritePattern; - - /// `matchAndRewrite` implementation that returns the significant transformed - /// pieces of IR. - LogicalResult matchAndRewrite(TilingInterface op, - PatternRewriter &rewriter) const override { - FailureOr> loops = - scf::lowerToLoopsUsingSCFForOp(rewriter, op); - if (failed(loops)) - return rewriter.notifyMatchFailure(op, "failed to lower to loops"); - rewriter.eraseOp(op); - return loops; - } -}; - -/// Test pass for testing the use of `TilingInterface`. 
-struct TestTilingInterfacePass - : public PassWrapper> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTilingInterfacePass) - - TestTilingInterfacePass() = default; - TestTilingInterfacePass(const TestTilingInterfacePass &pass) - : PassWrapper(pass) {} - void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); - linalg::registerTilingInterfaceExternalModels(registry); - tensor::registerTilingInterfaceExternalModels(registry); - } - StringRef getArgument() const final { return "test-tiling-interface"; } - StringRef getDescription() const final { - return "Test tiling using TilingInterface"; - } - - Option testTiling{ - *this, "tile-using-scf-for", - llvm::cl::desc( - "Test tiling using TilingInterface with scf.for operations"), - llvm::cl::init(false)}; - - Option testTilingForAll{ - *this, "tile-using-scf-forall", - llvm::cl::desc( - "Test tiling using TilingInterface with scf.forall operations"), - llvm::cl::init(false)}; - - Option testTileConsumerFuseAndYieldProducer{ - *this, "tile-consumer-fuse-and-yield-producer-using-scf-for", - llvm::cl::desc( - "Test tile and fuse transformation while yielding fused producer " - "replacements using TilingInterface with scf.for operations"), - llvm::cl::init(false)}; - - Option testTileConsumerAndFuseProducer{ - *this, "tile-consumer-and-fuse-producer-using-scf-for", - llvm::cl::desc("Test tile and fuse transformation using TilingInterface " - "with scf.for operations"), - llvm::cl::init(false)}; - - Option testLoweringToScalar{ - *this, "lower-to-scalar-using-scf-for", - llvm::cl::desc("Test lowering to scalar implementation using " - "TilingInterface with scf.for operations"), - llvm::cl::init(false)}; - - void runOnOperation() override; - -private: - void addTestPatterns(MLIRContext *context, RewritePatternSet &patterns); -}; -} // namespace - -static void addPatternForTiling(MLIRContext *context, - RewritePatternSet &patterns, - StringRef filterName, - ArrayRef tileSizes, - ArrayRef interchange = {}) { - scf::SCFTilingOptions tilingOptions; - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(context, tileSizes); - tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); - TransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); - patterns.add(context, tilingOptions, filter); -} - -static void addPatternForTilingUsingForall( - MLIRContext *context, RewritePatternSet &patterns, StringRef filterName, - ArrayRef tileSizes, - ArrayRef mapping = {}, - ArrayRef interchange = {}) { - scf::SCFTilingOptions tilingOptions; - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(context, tileSizes); - tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); - tilingOptions.setMapping(mapping); - TransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); - patterns.add(context, tilingOptions, filter); -} - -static void addPatternForTileFuseAndYield(MLIRContext *context, - RewritePatternSet &patterns, - StringRef filterName, - ArrayRef tileSizes, - ArrayRef interchange = {}) { - scf::SCFTilingOptions tilingOptions; - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(context, tileSizes); - tilingOptions.setTileSizes(tileSizesOfr).setInterchange(interchange); - TransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); - patterns.add( - context, tilingOptions, filter); -} - -static void addPatternForTileAndFuse(MLIRContext *context, - RewritePatternSet 
&patterns, - StringRef filterName, - ArrayRef tileSizes, - ArrayRef interchange = {}) { - scf::SCFTileAndFuseOptions tileAndFuseOptions; - SmallVector tileSizesOfr = - getAsIndexOpFoldResult(context, tileSizes); - tileAndFuseOptions.tilingOptions.setTileSizes(tileSizesOfr) - .setInterchange(interchange); - TransformationFilter filter(StringAttr::get(context, filterName), - StringAttr::get(context, "tiled")); - patterns.add( - context, tileAndFuseOptions, filter); -} - -void TestTilingInterfacePass::addTestPatterns(MLIRContext *context, - RewritePatternSet &patterns) { - if (testTiling) { - // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTiling(context, patterns, "simple_gemm", {10, 20}); - // 2. Tiling M, N and K of `linalg.matmul` on buffers. - addPatternForTiling(context, patterns, "simple_gemm_memref", {10, 20, 30}); - // 3. Tiling 3D parallel generic op which implements a transpose - addPatternForTiling(context, patterns, "parallel_generic_transpose", - {10, 0, 20}); - // 4. Tiling 2D conv op. - addPatternForTiling(context, patterns, "simple_conv", - {0, 0, 0, 0, 10, 20, 30}); - // 5. Tiling a simple op with `linalg.index` inside. - addPatternForTiling(context, patterns, "indexed_semantics", {10, 20}); - // 6. Tiling + interchange of an operation - addPatternForTiling(context, patterns, "gemm_interchange", {10, 20, 30}, - {1, 2, 0}); - // 7. Tiling for 2D pad tensor operations. - addPatternForTiling(context, patterns, "pad_2dtiling", {2, 3}); - // 8. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling(context, patterns, "pad_inner_tiling", {0, 3}); - // 9. Tiling inner dimension of 2d pad tensor operations. - addPatternForTiling(context, patterns, "pad_outer_tiling", {2, 3}); - // 10. Tiling M and N dims of `linalg.copy` on memrefs. - addPatternForTiling(context, patterns, "simple_copy_memref", {10, 20}); - // 11. Tiling scalar operations. - addPatternForTiling(context, patterns, "scalar_op", {}); - return; - } - if (testTilingForAll) { - // 1. Tiling M and N dims of `linalg.matmul` on tensors. - addPatternForTilingUsingForall( - context, patterns, "simple_gemm", {10, 20}, - {gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimY), - gpu::GPUBlockMappingAttr::get(context, gpu::MappingId::DimX)}); - // 2. Tiling 3D parallel generic op which implements a transpose. - addPatternForTilingUsingForall(context, patterns, - "parallel_generic_transpose", {10, 0, 20}); - // 3. Tiling 2D conv op. - addPatternForTilingUsingForall(context, patterns, "simple_conv", - {0, 0, 0, 0, 10, 20, 30}); - // 4. Tiling a simple op with `linalg.index` inside. - addPatternForTilingUsingForall(context, patterns, "indexed_semantics", - {10, 20}); - return; - } - if (testTileConsumerAndFuseProducer) { - // 1. Tile and fuse of gemm with fill producer and bias-add consumer. - addPatternForTileAndFuse(context, patterns, "fusion", {10, 20}); - // 2. Tile and fuse sequence of GEMMs, by fusing only along M. - addPatternForTileAndFuse(context, patterns, "gemm_fusion", {10}); - // 3. Tile and fuse gemm with consumer + interchange of tiled loops. - addPatternForTileAndFuse(context, patterns, "gemm_interchange_fusion", - {10, 20}, {1, 0}); - // 4. Tile and fuse matmul + transpose(matmul). Will introduce redundant - // computations. - addPatternForTileAndFuse(context, patterns, "gemm_plus_gemm_fusion", - {10, 20}); - // 5. Tile and fuse a sequence of GEMMs by tiling and fusing only along M - // dimension. 
-    addPatternForTileAndFuse(context, patterns, "gemm_sequence_fusion", {10});
-    // 6. Fusion of back-to-back-reduction ops
-    addPatternForTileAndFuse(context, patterns, "reduction_sequence_fusion",
-                             {10});
-    return;
-  }
-  if (testTileConsumerFuseAndYieldProducer) {
-    // 1. Fusion of back-to-back-reduction ops
-    addPatternForTileFuseAndYield(context, patterns,
-                                  "gemm_sequence_fusion_and_yield", {10});
-    return;
-  }
-  if (testLoweringToScalar) {
-    patterns.add<LowerToLoopsUsingSCFForOp>(context);
-  }
-}
-
-void TestTilingInterfacePass::runOnOperation() {
-  MLIRContext *context = &getContext();
-
-  RewritePatternSet tilingPatterns(context);
-  addTestPatterns(context, tilingPatterns);
-  if (failed(applyPatternsAndFoldGreedily(getOperation(),
-                                          std::move(tilingPatterns))))
-    return signalPassFailure();
-}
-
-namespace mlir {
-namespace test {
-void registerTestTilingInterface() {
-  PassRegistration<TestTilingInterfacePass>();
-}
-} // namespace test
-} // namespace mlir
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
new file mode 100644
index 0000000000000..cc450f4564951
--- /dev/null
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.cpp
@@ -0,0 +1,267 @@
+//===- TestTilingInterfaceTransformOps.cpp - Test `TilingInterface` ------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines transform dialect operations used for testing
+// TilingInterface
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Index/IR/IndexDialect.h"
+#include "mlir/Dialect/SCF/Transforms/TileUsingInterface.h"
+#include "mlir/Dialect/Transform/IR/TransformAttrs.h"
+#include "mlir/Dialect/Transform/IR/TransformDialect.h"
+#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
+#include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/OpImplementation.h"
+#include "mlir/Interfaces/TilingInterface.h"
+
+#define GET_OP_CLASSES
+#include "TestTilingInterfaceTransformOps.h.inc"
+
+using namespace mlir;
+using namespace mlir::transform;
+
+//===----------------------------------------------------------------------===//
+// TestFuseAndYieldOp
+//===----------------------------------------------------------------------===//
+
+static llvm::SmallDenseSet<Operation *> collectTiledAndFusedOps(Operation *op) {
+  SmallVector<Operation *> worklist;
+  llvm::SmallDenseSet<Operation *> producers;
+  worklist.push_back(op);
+  producers.insert(op);
+  while (!worklist.empty()) {
+    Operation *current = worklist.pop_back_val();
+    for (OpOperand &operand : current->getOpOperands()) {
+      Operation *producer = operand.get().getDefiningOp();
+      if (!producer || !isa<TilingInterface>(producer) ||
+          producers.contains(producer))
+        continue;
+      worklist.push_back(producer);
+      producers.insert(producer);
+    }
+  }
+  return producers;
+}
+
+/// Apply a tile and fuse transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
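+/// Fused producers whose results are also used outside the tiled loop nest
+/// are yielded from the generated loops, so that those external uses can be
+/// replaced with values computed inside the nest.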
+template <typename Range>
+static LogicalResult
+applyTileAndFuseToAll(RewriterBase &rewriter, Operation *transformOp,
+                      Range &&payloadOps, unsigned numLoops,
+                      ArrayRef<OpFoldResult> tileSizes,
+                      ArrayRef<int64_t> interchange,
+                      transform::TransformResults &transformResults) {
+  SmallVector<Operation *> tiledOps;
+  SmallVector<SmallVector<Operation *>> loopOps(numLoops);
+
+  for (Operation *target : payloadOps) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+    DominanceInfo dominanceInfo(tilingInterfaceOp);
+
+    llvm::SmallDenseSet<Operation *> tiledAndFusedOps =
+        collectTiledAndFusedOps(tilingInterfaceOp);
+    llvm::DenseSet<Operation *> yieldReplacementsFor;
+    for (auto op : tiledAndFusedOps) {
+      if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+            return dominanceInfo.properlyDominates(tilingInterfaceOp, user);
+          })) {
+        yieldReplacementsFor.insert(op);
+      }
+    }
+
+    scf::SCFTilingOptions tilingOptions;
+    tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+
+    scf::SCFTileAndFuseOptions tileAndFuseOptions;
+    tileAndFuseOptions.setTilingOptions(tilingOptions);
+
+    scf::SCFTileAndFuseOptions::ControlFnTy controlFn =
+        [&](tensor::ExtractSliceOp candidateSliceOp, OpResult originalProducer,
+            bool isDestinationOperand) {
+          Operation *owner = originalProducer.getOwner();
+          bool yieldProducerReplacement = yieldReplacementsFor.contains(owner);
+          return std::make_tuple(true, yieldProducerReplacement);
+        };
+    tileAndFuseOptions.setFusionControlFn(controlFn);
+
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTileAndFuseResult> tiledResults =
+        scf::tileConsumerAndFuseProducerGreedilyUsingSCFForOp(
+            rewriter, tilingInterfaceOp, tileAndFuseOptions);
+    if (failed(tiledResults))
+      return failure();
+
+    // Perform the replacement of tiled and fused values.
+    SmallVector<Operation *> opsToReplace{target};
+    llvm::append_range(opsToReplace, tiledResults->fusedProducers);
+    for (Operation *toReplace : opsToReplace) {
+      for (OpResult res : toReplace->getResults())
+        if (auto replacement = tiledResults->replacements.lookup(res)) {
+          Operation *replacementOp = replacement.getDefiningOp();
+          rewriter.replaceUsesWithIf(
+              res, replacement, [&](mlir::OpOperand &use) {
+                Operation *user = use.getOwner();
+                return dominanceInfo.properlyDominates(replacementOp, user) &&
+                       user->getParentOp() == replacementOp->getParentOp();
+              });
+        }
+
+      if (toReplace->use_empty()) {
+        rewriter.eraseOp(toReplace);
+      }
+    }
+
+    // Report back the relevant handles to the transform op.
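+    // Result 0 of the transform op aggregates the tiled root ops across all
+    // payload ops; result i + 1 aggregates the i-th generated loop of each
+    // payload op, i.e. one result handle per loop depth.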
+    tiledOps.push_back(tiledResults->tiledAndFusedOps.front());
+    assert(tiledResults->loops.size() == numLoops &&
+           "Mismatched number of loops, tile and fuse transform should have "
+           "failed");
+    for (unsigned int i = 0; i < numLoops; ++i)
+      loopOps[i].push_back(tiledResults->loops[i]);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledOps);
+  for (unsigned int i = 0; i < numLoops; ++i)
+    transformResults.set(transformOp->getOpResult(i + 1), loopOps[i]);
+
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::TestFuseAndYieldOp::apply(
+    transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  SmallVector<int64_t> tileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  SmallVector<int64_t> tileInterchange =
+      extractFromIntegerArrayAttr<int64_t>(getTileInterchange());
+
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+
+  LogicalResult result = applyTileAndFuseToAll(
+      rewriter, getOperation(), state.getPayloadOps(getTarget()),
+      tileSizes.size() - llvm::count(tileSizes, 0), tileSizesOfr,
+      tileInterchange, transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+//===----------------------------------------------------------------------===//
+// TestTileUsingForallOp
+//===----------------------------------------------------------------------===//
+
+/// Apply a tiling transformation to all payload ops and store both the
+/// tiled operation as well as the created tile loops.
+template <typename Range>
+static LogicalResult
+applyTileToAll(RewriterBase &rewriter, Operation *transformOp,
+               Range &&payloadOps, ArrayRef<OpFoldResult> tileSizes,
+               ArrayRef<int64_t> interchange, std::optional<ArrayAttr> mapping,
+               transform::TransformResults &transformResults) {
+  SmallVector<Operation *> tiledOps;
+  SmallVector<Operation *> loopOps;
+
+  for (Operation *target : payloadOps) {
+    auto tilingInterfaceOp = dyn_cast<TilingInterface>(target);
+    if (!tilingInterfaceOp)
+      return transformOp->emitError("only TilingInterface ops are supported");
+    scf::SCFTilingOptions tilingOptions;
+    tilingOptions.setTileSizes(tileSizes).setInterchange(interchange);
+    if (mapping) {
+      auto mappingAttrs =
+          llvm::map_to_vector(mapping.value(), [](Attribute attr) {
+            return cast<DeviceMappingAttrInterface>(attr);
+          });
+      tilingOptions.setMapping(mappingAttrs);
+    }
+
+    rewriter.setInsertionPoint(target);
+    FailureOr<scf::SCFTilingResult> tiledResults =
+        scf::tileUsingSCFForallOp(rewriter, tilingInterfaceOp, tilingOptions);
+    if (failed(tiledResults))
+      return failure();
+
+    // Perform the replacement of tiled and fused values.
+    rewriter.replaceOp(tilingInterfaceOp, tiledResults->replacements);
+
+    // Report back the relevant handles to the transform op.
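+    // Here, unlike the tile-and-fuse path above, every generated loop becomes
+    // its own result handle of the transform op, in generation order.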
+    tiledOps.push_back(tiledResults->tiledOps.front());
+    for (Operation *loop : tiledResults->loops)
+      loopOps.push_back(loop);
+  }
+
+  transformResults.set(transformOp->getOpResult(0), tiledOps);
+  for (auto [index, loop] : llvm::enumerate(loopOps))
+    transformResults.set(transformOp->getOpResult(index + 1), {loop});
+
+  return success();
+}
+
+DiagnosedSilenceableFailure transform::TestTileUsingForallOp::apply(
+    transform::TransformRewriter &rewriter,
+    mlir::transform::TransformResults &transformResults,
+    mlir::transform::TransformState &state) {
+  SmallVector<int64_t> tileSizes =
+      extractFromIntegerArrayAttr<int64_t>(getTileSizes());
+  SmallVector<int64_t> interchange =
+      extractFromIntegerArrayAttr<int64_t>(getInterchange());
+  SmallVector<OpFoldResult> tileSizesOfr =
+      getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
+
+  LogicalResult result =
+      applyTileToAll(rewriter, getOperation(), state.getPayloadOps(getTarget()),
+                     tileSizesOfr, interchange, getMapping(), transformResults);
+  return failed(result) ? DiagnosedSilenceableFailure::definiteFailure()
+                        : DiagnosedSilenceableFailure::success();
+}
+
+void transform::TestTileUsingForallOp::getEffects(
+    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
+  consumesHandle(getTarget(), effects);
+  producesHandle(getTiledOp(), effects);
+  producesHandle(getLoops(), effects);
+  modifiesPayload(effects);
+}
+
+#define GET_OP_CLASSES
+#include "TestTilingInterfaceTransformOps.cpp.inc"
+
+namespace {
+class TestTilingInterfaceDialectExtension
+    : public transform::TransformDialectExtension<
+          TestTilingInterfaceDialectExtension> {
+public:
+  using Base::Base;
+
+  void init() {
+    declareDependentDialect<affine::AffineDialect>();
+    declareDependentDialect<index::IndexDialect>();
+    declareDependentDialect<scf::SCFDialect>();
+    declareDependentDialect<tensor::TensorDialect>();
+
+    registerTransformOps<
+#define GET_OP_LIST
+#include "TestTilingInterfaceTransformOps.cpp.inc"
+        >();
+  }
+};
+} // namespace
+
+namespace test {
+void registerTestTilingInterfaceTransformDialectExtension(
+    DialectRegistry &registry) {
+  registry.addExtensions<TestTilingInterfaceDialectExtension>();
+}
+} // namespace test
diff --git a/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
new file mode 100644
index 0000000000000..6e9354198896a
--- /dev/null
+++ b/mlir/test/lib/Interfaces/TilingInterface/TestTilingInterfaceTransformOps.td
@@ -0,0 +1,81 @@
+//===- TestTilingInterfaceTransformOps.td -----------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TEST_TILINGINTERFACE_TRANSFORM_OPS
+#define TEST_TILINGINTERFACE_TRANSFORM_OPS
+
+include "mlir/Dialect/SCF/IR/DeviceMappingInterface.td"
+include "mlir/Dialect/Transform/IR/TransformDialect.td"
+include "mlir/Dialect/Transform/IR/TransformInterfaces.td"
+include "mlir/Dialect/Transform/IR/TransformTypes.td"
+include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/OpBase.td"
+
+// The operations in this file are meant for testing the tiling interface
+// transformations using scf operations. Over time these testing options
+// might be useful transformations in their own right; if so, move them over
+// as transform ops in the main repo (and find a proper place for them).
+
+def TestFuseAndYieldOp : Op<Transform_Dialect, "test.fuse_and_yield",
+       [FunctionalStyleTransformOpTrait, MemoryEffectsOpInterface,
+        DeclareOpInterfaceMethods<TransformOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Tiles the operations pointed to by the target handle and greedily fuses
+    their producers, using the options provided as attributes. Fused
+    producers whose results are used outside the generated loops are yielded
+    from the loop nest as well.
+
+    On success returns the tiled operations as well as the generated loops.
+    Emits a definite failure if tiling fails.
+  }];
+
+  let arguments =
+    (ins TransformHandleTypeInterface:$target,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+         DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange);
+  let results = (outs TransformHandleTypeInterface:$transformed,
+                      Variadic<TransformHandleTypeInterface>:$loops);
+
+  let assemblyFormat = [{
+    $target ($tile_sizes^)? (`interchange` $tile_interchange^)?
+    attr-dict `:` functional-type(operands, results)
+  }];
+}
+
+def TestTileUsingForallOp : Op<Transform_Dialect, "test.tile_using_forall",
+       [DeclareOpInterfaceMethods<TransformOpInterface>,
+        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+        ReportTrackingListenerFailuresOpTrait]> {
+  let description = [{
+    Test operation used to test tiling using the TilingInterface with
+    scf.forall for the loop constructs. This is similar to
+    `transform.structured.tile_using_for`. Use of this operation is an
+    intermediate state and will be replaced in due course with either
+    `transform.structured.tile_using_for` or
+    `transform.structured.tile_using_forall`.
+
+    On success returns the tiled operations as well as the generated loops.
+    Emits a definite failure if tiling fails.
+  }];
+
+  let arguments = (ins TransformHandleTypeInterface:$target,
+                   DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
+                   DefaultValuedOptionalAttr<I64ArrayAttr, "{}">:$interchange,
+                   OptionalAttr<DeviceMappingArrayAttr>:$mapping);
+  let results = (outs TransformHandleTypeInterface:$tiled_op,
+                      Variadic<TransformHandleTypeInterface>:$loops);
+
+  let assemblyFormat = [{
+    $target ($tile_sizes^)? (`interchange` $interchange^)?
+    (`mapping` $mapping^)?
+ attr-dict `:` functional-type(operands, results) + }]; +} + + +#endif // TEST_TILINGINTERFACE_TRANSFORM_OPS diff --git a/mlir/test/lib/Interfaces/TilingInterface/lit.local.cfg b/mlir/test/lib/Interfaces/TilingInterface/lit.local.cfg new file mode 100644 index 0000000000000..65a7f202dc82a --- /dev/null +++ b/mlir/test/lib/Interfaces/TilingInterface/lit.local.cfg @@ -0,0 +1 @@ +config.suffixes.remove(".td") diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 5c6a72881ddf4..428bdd9691e09 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -133,7 +133,6 @@ void registerTestShapeMappingPass(); void registerTestSliceAnalysisPass(); void registerTestTensorCopyInsertionPass(); void registerTestTensorTransforms(); -void registerTestTilingInterface(); void registerTestTopologicalSortAnalysisPass(); void registerTestTransformDialectEraseSchedulePass(); void registerTestTransformDialectInterpreterPass(); @@ -152,6 +151,7 @@ void registerTestPDLLPasses(); namespace test { void registerTestDialect(DialectRegistry &); void registerTestDynDialect(DialectRegistry &); +void registerTestTilingInterfaceTransformDialectExtension(DialectRegistry &); void registerTestTransformDialectExtension(DialectRegistry &); } // namespace test @@ -255,7 +255,6 @@ void registerTestPasses() { mlir::test::registerTestSliceAnalysisPass(); mlir::test::registerTestTensorCopyInsertionPass(); mlir::test::registerTestTensorTransforms(); - mlir::test::registerTestTilingInterface(); mlir::test::registerTestTopologicalSortAnalysisPass(); mlir::test::registerTestTransformDialectEraseSchedulePass(); mlir::test::registerTestTransformDialectInterpreterPass(); @@ -292,6 +291,7 @@ int main(int argc, char **argv) { #ifdef MLIR_INCLUDE_TESTS ::test::registerTestDialect(registry); ::test::registerTestTransformDialectExtension(registry); + ::test::registerTestTilingInterfaceTransformDialectExtension(registry); ::test::registerTestDynDialect(registry); #endif return mlir::asMainReturnCode(mlir::MlirOptMain(