
[mlir][linalg] Enable fuse consumer #89893

Closed
wants to merge 1 commit into from

Conversation

cxy-1993
Contributor

This patch adds support for consumer fusion to the tiling interface and implements consumer fusion on FuseIntoContainingOp.

  • Add the interface method 'getIterDomainTilePositionFromOperandPosition' to the tiling interface, which gets the iteration domain position from an operand position.
  • Add the interface method 'getTiledImplementationFromOperandPosition' to the tiling interface, which generates the tiled implementation according to an operand position.
  • Implement the above two methods and support consumer fusion for FuseIntoContainingOp.
  • Add tests for the fuse consumer interface.
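The mapping these methods describe, from a tile of an operand to a tile of the iteration domain, can be sketched outside MLIR. The following Python model is purely illustrative (all names are hypothetical; the real implementation is the C++ in TilingInterfaceImpl.cpp): it mimics how offsets and sizes propagate through a projected-permutation indexing map, with loop dimensions not accessed by the operand keeping their full extent.

```python
# Hypothetical Python model of the offset/size mapping used for consumer
# fusion; names and data layout are illustrative, not the MLIR C++ API.

def map_operand_tile_to_iter_domain(indexing_map, iter_domain, offsets, sizes):
    """indexing_map: list of loop-dim positions accessed by the operand
    (a projected permutation, e.g. [1, 0] or [0, 2]).
    iter_domain: list of (offset, size) per loop dimension.
    offsets/sizes: the operand tile, one entry per indexing_map result.
    """
    # Dimensions not accessed by the operand keep their full extent.
    mapped_offsets = [off for off, _ in iter_domain]
    mapped_sizes = [sz for _, sz in iter_domain]
    for result_idx, dim_pos in enumerate(indexing_map):
        mapped_offsets[dim_pos] = offsets[result_idx]
        mapped_sizes[dim_pos] = sizes[result_idx]
    return mapped_offsets, mapped_sizes

# A transposed operand (map (d0, d1) -> (d1, d0)) in a 2-loop op:
offs, szs = map_operand_tile_to_iter_domain(
    indexing_map=[1, 0],
    iter_domain=[(0, 128), (0, 64)],
    offsets=[8, 16],   # tile offsets in operand space
    sizes=[32, 4],     # tile sizes in operand space
)
print(offs, szs)  # [16, 8] [4, 32]
```

Note how the tile on loop d1 comes from operand dimension 0 and vice versa, which is exactly the bookkeeping `getMappedOffsetAndSize` performs in the patch.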

@llvmbot
Collaborator

llvmbot commented Apr 24, 2024

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-core
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-tensor

Author: donald chen (cxy-1993)

Patch is 33.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89893.diff

8 Files Affected:

  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+61-6)
  • (modified) mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (+2-2)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+80-26)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp (+6-6)
  • (added) mlir/test/Dialect/Linalg/test-fuse-consumer.mlir (+103)
  • (modified) mlir/test/lib/Dialect/Linalg/CMakeLists.txt (+1)
  • (added) mlir/test/lib/Dialect/Linalg/TestLinalgFuseConsumer.cpp (+294)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..84f7dec2f4003d 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           The method returns the operation that is the tiled
           implementation.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"getTiledImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           by the tiled implementation. Expects the same `offsets` and `sizes` as
           used to obtain the tiled implementation of the operation.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"getResultTilePosition",
         /*args=*/(ins
           "OpBuilder &":$b,
           "unsigned":$resultNumber,
           "ArrayRef<OpFoldResult> ":$offsets,
           "ArrayRef<OpFoldResult> ":$sizes,
-          "SmallVector<OpFoldResult> &":$resultOffsets,
-          "SmallVector<OpFoldResult> &":$resultSizes),
+          "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return the position of iteration domain tile computed by the
+          tiled operation.
+        }],
+        /*retType=*/"::mlir::LogicalResult",
+        /*methodName=*/"getIterationDomainTileFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
         /*methodBody=*/"",
         /*defaultImplementation=*/[{
           return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
             iteration space).
           - `sizes` provides the size of the tile.
         }],
-        /*retType=*/"FailureOr<TilingResult>",
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
         /*methodName=*/"generateResultTileValue",
         /*args=*/(ins
           "OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand tile position.
+
+          Generates the IR that computes the tiled implementation of an
+          operation from operand tile.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the tile of the
+          operand required. This method enables consumer fusion by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<::mlir::TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandTile",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           transformations are done, this method can be used to lower to scalar
           code that can then be lowered to LLVM or SPIR-V dialects.
         }],
-        /*retType=*/"LogicalResult",
+        /*retType=*/"::mlir::LogicalResult",
         /*methodName=*/"generateScalarImplementation",
         /*args=*/(ins
             "OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..e9999c34d0face 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
 
 LogicalResult SoftmaxOp::getResultTilePosition(
     OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
-    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
-    SmallVector<OpFoldResult> &resultSizes) {
+    ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+    SmallVectorImpl<OpFoldResult> &resultSizes) {
   if (resultNumber == 0) {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..71e9c3771dcded 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
         }));
   }
 
-  // Instantiate the tiled implementation of the operation.
+  /// Instantiate the tiled implementation of the operation.
   FailureOr<TilingResult>
   getTiledImplementation(Operation *op, OpBuilder &b,
                          ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
     return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
   }
 
-  // Return the details of the output tile generated by the tiled
-  // implementation.
+  void
+  getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+                         ArrayRef<OpFoldResult> offsets,
+                         ArrayRef<OpFoldResult> sizes,
+                         SmallVectorImpl<OpFoldResult> &mappedOffsets,
+                         SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+    unsigned numLoops = linalgOp.getNumLoops();
+    auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+    mappedOffsets.resize(numLoops);
+    mappedSizes.resize(numLoops);
+    if (!indexingMap.isPermutation()) {
+      SmallVector<Range> iterationDomain =
+          tilingInterfaceOp.getIterationDomain(b);
+      for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+        mappedOffsets[index] = value.offset;
+        mappedSizes[index] = value.size;
+      }
+    }
+    for (const auto &&[index, value] :
+         llvm::enumerate(indexingMap.getResults())) {
+      unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+      mappedOffsets[dimPosition] = offsets[index];
+      mappedSizes[dimPosition] = sizes[index];
+    }
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
+  LogicalResult getIterationDomainTileFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+      SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+      SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+    auto linalgOp = cast<LinalgOp>(op);
+
+    // Check that the indexing map used for the operand is a projected
+    // permutation. This could be relaxed with a more general approach that can
+    // map the offsets and sizes from the operand to iteration space tiles
+    // (filling in full extent for dimensions not used to access the result).
+    AffineMap indexingMap =
+        linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+    if (!indexingMap.isProjectedPermutation()) {
+      return emitError(op->getLoc(),
+                       "unhandled get iter domain position when operand is not "
+                       "accessed using a permuted projection");
+    }
+
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           iterDomainOffsets, iterDomainSizes);
+    return success();
+  }
+
+  /// Return the details of the output tile generated by the tiled
+  /// implementation.
   LogicalResult
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     Location loc = op->getLoc();
     LinalgOp linalgOp = cast<LinalgOp>(op);
 
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
     return success();
   }
 
+  FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+      Operation *op, OpBuilder &b, unsigned operandNumber,
+      ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    auto tilingInterfaceOp = cast<TilingInterface>(op);
+    if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+            b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+      return emitError(
+          op->getLoc(),
+          "unable to obtain the iter domain position of the operation.");
+    }
+    return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+                                                    mappedSizes);
+  }
+
   FailureOr<TilingResult>
   generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
                           ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
           "unhandled tiled implementation generation when result is not "
           "accessed using a permuted projection");
     }
-
-    auto numLoops = linalgOp.getNumLoops();
+    SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+    getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+                           mappedOffsets, mappedSizes);
     auto tilingInterfaceOp = cast<TilingInterface>(op);
-    SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
-        iterationTileSizes(numLoops);
-    if (!indexingMap.isPermutation()) {
-      SmallVector<Range> iterationDomain =
-          tilingInterfaceOp.getIterationDomain(b);
-      for (const auto &range : llvm::enumerate(iterationDomain)) {
-        iterationTileOffsets[range.index()] = range.value().offset;
-        iterationTileSizes[range.index()] = range.value().size;
-      }
-    }
-    for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
-      unsigned dimPosition =
-          cast<AffineDimExpr>(resultExpr.value()).getPosition();
-      iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
-      iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
-    }
-
     FailureOr<TilingResult> tilingResult =
-        tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
-                                                 iterationTileSizes);
+        tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+    if (failed(tilingResult))
+      return failure();
+
     if (tilingResult->tiledOps.size() != 1)
       return op->emitOpError("failed to generate tiled implementation");
 
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec566..296c5fc7a5c2bd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets.assign(offsets.begin(), offsets.end());
     resultSizes.assign(sizes.begin(), sizes.end());
     return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     // The iteration domain is over outer dimensions of packed layout. In this
     // context, the outer dimensions of `resultOffsets` are `offsets`. The
     // inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
   getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
                         ArrayRef<OpFoldResult> offsets,
                         ArrayRef<OpFoldResult> sizes,
-                        SmallVector<OpFoldResult> &resultOffsets,
-                        SmallVector<OpFoldResult> &resultSizes) const {
+                        SmallVectorImpl<OpFoldResult> &resultOffsets,
+                        SmallVectorImpl<OpFoldResult> &resultSizes) const {
     resultOffsets = llvm::to_vector(offsets);
     resultSizes = llvm::to_vector(sizes);
     return success();
diff --git a/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
new file mode 100644
index 00000000000000..e7edbf0b2c25d4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-fuse-consumer | FileCheck %s
+
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+// CHECK-LABEL: func.func @fuse_tileable_consumer
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+  // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+  %0 = tensor.empty(%arg0) : tensor<?xf32>
+  %1 = affine.apply #map()[%arg0]
+  // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+  %2 = tensor.empty() : tensor<64xf32>
+  // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+  %3 = tensor.empty() : tensor<64xf32>
+  // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+  %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+    %6 = affine.apply #map1(%arg3)[%arg0]
+    %7 = affine.min #map2(%arg3)[%arg0]
+    // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+    // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+    %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+    // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+    // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+      // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+      tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+    }
+  } {"containing"}
+  // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+  %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>, "consumer"} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+  // CHECK: return %[[RES]]#1
+  return %5 : tensor<64xf32>
+}
+// -----
+
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func.func @fuse_consumer_multi_output
+// CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+  %cst = arith.constant 0.000000e+00 : f32
+  // CHECK: %[[INIT:.*]] = linalg.fill
+  %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+  // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+  %1 = tensor.empty() : tensor<123x789xf32>
+  // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+  %2 = tensor.empty() : tensor<789x123xf32>
+  // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+  %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+    %5 = affine.min #map(%arg3)
+    %6 = affine.min #map1(%arg4)
+    %7 = affine.apply #map2(%arg3)
+    %8 = affine.apply #map3(%arg4)
+    %9 = affine.apply #map2(%arg3)
+    %10 = affine.apply #map3(%arg4)
+    // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+    %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+    // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+    %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+    // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+    %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+    // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+    %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+    // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+    // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+    // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+    %12 = affine.apply #map2(%arg3)
+    %13 = affine.apply #map3(%arg4)
+    scf.forall.in_parallel {
+      // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[...
[truncated]
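As a hedged, non-MLIR illustration of the control flow that getTiledImplementationFromOperandTile adds in the (truncated) diff above, first resolving the operand tile to an iteration-domain tile and then delegating to the existing getTiledImplementation, here is a Python sketch in which every class and method name is hypothetical:

```python
# Illustrative Python model only; the real code is the MLIR C++ in
# TilingInterfaceImpl.cpp. All names here are hypothetical.

class TilingError(Exception):
    pass

class TileableOp:
    """Toy stand-in for an op implementing the tiling interface."""

    def __init__(self, indexing_maps, iter_domain):
        self.indexing_maps = indexing_maps  # operand number -> loop dims
        self.iter_domain = iter_domain      # per-loop (offset, size)

    def iteration_domain_tile_from_operand_tile(self, operand, offsets, sizes):
        dims = self.indexing_maps[operand]
        # Reject anything that is not a projected permutation of loop dims.
        if len(set(dims)) != len(dims):
            raise TilingError("operand not accessed via projected permutation")
        mapped_off = [off for off, _ in self.iter_domain]
        mapped_sz = [sz for _, sz in self.iter_domain]
        for i, d in enumerate(dims):
            mapped_off[d], mapped_sz[d] = offsets[i], sizes[i]
        return mapped_off, mapped_sz

    def tiled_implementation(self, offsets, sizes):
        # Stand-in for IR generation: just describe the tile produced.
        return {"offsets": offsets, "sizes": sizes}

    def tiled_implementation_from_operand_tile(self, operand, offsets, sizes):
        # The two-step flow the patch adds: operand tile ->
        # iteration-domain tile -> ordinary tiled implementation.
        off, sz = self.iteration_domain_tile_from_operand_tile(
            operand, offsets, sizes)
        return self.tiled_implementation(off, sz)

op = TileableOp(indexing_maps={0: [0, 1]}, iter_domain=[(0, 64), (0, 64)])
tile = op.tiled_implementation_from_operand_tile(0, [8, 0], [16, 64])
print(tile)  # {'offsets': [8, 0], 'sizes': [16, 64]}
```

The key design point mirrored here is that the new method does not re-implement tiling; it only translates coordinates and reuses the existing entry point, so failures in the translation step propagate as the interface's failure result.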

@cxy-1993
Contributor Author

This PR is a re-submission of #85528, with changes to include additional tests. Please help me review this patch. @MaheshRavishankar @ftynse @Abhishek-Varma @nicolasvasilache

@cxy-1993 cxy-1993 force-pushed the cxy_fuse_comsumer_interface branch from cb24711 to 6b331ce Compare April 24, 2024 09:23
@banach-space
Contributor

For those of us less familiar with the fusion logic, could you please provide a high level example with "BEFORE" and "AFTER"? You can use pseudo IR for that (I can use the implementation to reason about the finer details).

@nujaa
Contributor

nujaa commented Apr 24, 2024

Could the behaviour of -test-linalg-fuse-consumer be reproduced in the transform dialect? E.g. with TileUsingForallOp + FuseOp?

@MaheshRavishankar
Contributor

Thanks @cxy-1993 for the change. I have comments, but I just want to state at the outset that these aren't a reaction to this PR or to the effort to follow through on recommendations, but to an unfortunate evolution of things.

What is being done here is what I was trying to avoid. To test the interface methods, there is now an implementation of tile and fuse in a test folder. I don't see how this is useful.
#88712 is the place to test such a thing, which is why I didn't ask for tests on the original PR. So even if this lands this way, all it does is add more things to delete after the fact. I don't have the heart to block this change, but I really don't think this is a good outcome here.

Member

@ftynse ftynse left a comment


Thanks for iterating on this!

It appears that some code was copy-pasted from #88712. At least, this is my impression after leaving multiple comments for identical issues in the code. If it is the case, please make sure all those issues are also addressed. You are also most welcome to collaborate with the author of that PR on a single branch and send one, unified PR.

if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
Member

Nit: const && may not mean what you think it does here, see https://isocpp.org/blog/2012/11/universal-references-in-c11-scott-meyers. We also don't use const for IR objects https://mlir.llvm.org/docs/Rationale/UsageOfConst/.

Contributor Author

The "const" has been removed. Thanks for sharing the knowledge about &&. Here I want to use references to avoid copies, and I don't care whether it's an lvalue or an rvalue reference; it looks like an lvalue reference is used here.

#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
// CHECK-LABEL: func.func @fuse_tileable_consumer
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
Member

Nit: I'd match as just .+, there is no need to restrict symbols to only numbers and lowercase letters (and there may be underscores in names!)

Contributor Author

I tried your suggestion, but matching breaks: because values of the same type appear on the same line, FileCheck matches extra characters and the test fails.

Member

I see, thanks for trying. Maybe let's update to [0-9A-Za-z_]+ then, which should capture all supported names.

@ftynse
Member

ftynse commented Apr 25, 2024

What is being done here is what I was trying to avoid. To test the interface methods now there is an implementation of tile and fuse in a test folder. I don't see how this is useful.

I agree that this isn't exactly useful, because it introduces a lot of additional logic beyond what actually needs to be tested. Moreover, the newly added tests only test this behavior indirectly, and may end up rather hard to debug should the behavior regress.

There is a much simpler testing strategy: implement a test pass that calls the newly added interface methods on all operations that implement the interface and outputs the results of the call as diagnostic remarks (attributes can be printed, values can be identified by their locations). This is straightforward and fits into a couple dozen lines of code. We could also consider a unit test. It was rather obvious to me, and I probably should have elaborated on that...
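To make this strategy concrete: such a pass would walk the ops, invoke the new interface method, and report each outcome as a diagnostic remark. A purely illustrative Python model of that idea (not MLIR's C++ pass infrastructure; all names are invented):

```python
# Illustrative-only Python model of the proposed test pass; MLIR's real
# test pass would be C++ emitting diagnostics. All names are hypothetical.

class FakeOp:
    def __init__(self, name, dims, domain):
        self.name, self.dims, self.domain = name, dims, domain

    def iter_domain_tile_from_operand_tile(self, offsets, sizes):
        # Model the interface's failure() for non projected-permutation maps.
        if len(set(self.dims)) != len(self.dims):
            return None
        off = [o for o, _ in self.domain]
        sz = [s for _, s in self.domain]
        for i, d in enumerate(self.dims):
            off[d], sz[d] = offsets[i], sizes[i]
        return off, sz

def run_test_pass(ops, offsets, sizes):
    """Call the interface method on every op and collect remark strings."""
    remarks = []
    for op in ops:
        result = op.iter_domain_tile_from_operand_tile(offsets, sizes)
        if result is None:
            remarks.append(f"{op.name}: failed to compute iter domain tile")
        else:
            remarks.append(f"{op.name}: offsets={result[0]} sizes={result[1]}")
    return remarks

ops = [FakeOp("copy", [0, 1], [(0, 8), (0, 8)]),
       FakeOp("broadcast-like", [0, 0], [(0, 8), (0, 8)])]
for r in run_test_pass(ops, [2, 4], [4, 4]):
    print(r)
```

The point of this shape is that each test checks the interface method's output directly, rather than checking it indirectly through a full tile-and-fuse transformation.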

@ftynse
Member

ftynse commented Apr 25, 2024

So even if this is landed this way, all this is doing is adding more things to delete after the fact.

I don't think this code is going to be deleted, but rather moved. So not much waste outside of reviewer time.

This patch adds support for consumer fusion to the tiling interface, and
implements consumer fusion on FuseIntoContainingOp.

- Add interface method 'getIterDomainTilePositionFromOperandPosition' to
the tiling interface, which gets the iteration domain position from an
operand position.
- Add interface method 'getTiledImplementationFromOperandPosition' to
the tiling interface, which generates the tiled implementation according
to an operand position.
- Implement the above two methods and support consumer fusion for
FuseIntoContainingOp.
@cxy-1993 cxy-1993 force-pushed the cxy_fuse_comsumer_interface branch from 6b331ce to f36a1f6 Compare April 25, 2024 13:59
@cxy-1993 cxy-1993 requested a review from ftynse April 25, 2024 14:09
@cxy-1993
Contributor Author

For those of us less familiar with the fusion logic, could you please provide a high level example with "BEFORE" and "AFTER"? You can use pseudo IR for that (I can use the implementation to reason about the finer details).

Thank you for your suggestion. A complete transform/pass submission should indeed include relevant descriptions. However, as discussed above, the implementation here is a test that is about to be moved or deleted, so I feel it's unnecessary to provide an extensive description of inputs and outputs here. We can refine that in #88712.

@cxy-1993
Contributor Author

Could the behaviour of -test-linalg-fuse-consumer be reproduced in the transform dialect? E.g. with TileUsingForallOp + FuseOp?

Yes. Considering it's about to be moved or deleted, I think this can be addressed in the next submission (https://github.com/llvm/llvm-project/pull/88712).

Copy link
Contributor

@Abhishek-Varma Abhishek-Varma left a comment


Just adding this here as a proposal to consider, if everyone thinks this would indeed be a better way forward:

I believe having just the infra of TilingInterface.td + TilingInterfaceImpl.cpp (basically the reverted changes) as the separate base commit in the in-flight PR 88712 would be the most preferable way to go.

My reasons:

  1. The PR I added indeed makes use of the Interface methods to achieve the fusion for both scf.for and scf.forall.
  2. The same PR also adds the necessary tests based on a test transform dialect op.
  3. Most of the review comments by @ftynse have already been addressed - so @cxy-1993 can avoid having to redo the same set of things here.

After 88712 lands the only set of actionable item left would be to define an official transform dialect that does the same job as the test transform dialect I've already added in my PR - something perhaps @cxy-1993 can take up thereafter.

This avoids any unnecessary code movements and makes it cleaner.

@ftynse
Member

ftynse commented Apr 25, 2024

Just adding this here as a proposal to consider, if everyone thinks this would indeed be a better way forward:

This would be fine with me. However, #88712 has changes requested by @nicolasvasilache with an ongoing discussion that may prevent this from landing for a longer time.

@cxy-1993
Contributor Author

Just adding this here as a proposal to consider, if everyone thinks this would indeed be a better way forward: […]

The proposal is also fine with me. Since @MaheshRavishankar and @ftynse both agree on this, I'll close this PR. Please cherry-pick the interface part into your PR, @Abhishek-Varma. Thanks!

BTW, thanks again to @ftynse for reviewing this patch.
