[mlir][linalg] Enable fuse consumer #89893
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tensor
Author: donald chen (cxy-1993)
Changes
This patch adds support for consumer fusion to the tiling interface, and implements fuse consumers on FuseIntoContainingOp.
Patch is 33.54 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/89893.diff 8 Files Affected:
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..84f7dec2f4003d 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -63,7 +63,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
The method returns the operation that is the tiled
implementation.
}],
- /*retType=*/"FailureOr<TilingResult>",
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"getTiledImplementation",
/*args=*/(ins
"OpBuilder &":$b,
@@ -82,15 +82,34 @@ def TilingInterface : OpInterface<"TilingInterface"> {
by the tiled implementation. Expects the same `offsets` and `sizes` as
used to obtain the tiled implementation of the operation.
}],
- /*retType=*/"LogicalResult",
+ /*retType=*/"::mlir::LogicalResult",
/*methodName=*/"getResultTilePosition",
/*args=*/(ins
"OpBuilder &":$b,
"unsigned":$resultNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
- "SmallVector<OpFoldResult> &":$resultOffsets,
- "SmallVector<OpFoldResult> &":$resultSizes),
+ "SmallVectorImpl<OpFoldResult> &":$resultOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$resultSizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to return the position of iteration domain tile computed by the
+ tiled operation.
+ }],
+ /*retType=*/"::mlir::LogicalResult",
+ /*methodName=*/"getIterationDomainTileFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult> ":$offsets,
+ "ArrayRef<OpFoldResult> ":$sizes,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainOffsets,
+ "SmallVectorImpl<OpFoldResult> &":$iterDomainSizes),
/*methodBody=*/"",
/*defaultImplementation=*/[{
return failure();
@@ -119,7 +138,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
iteration space).
- `sizes` provides the size of the tile.
}],
- /*retType=*/"FailureOr<TilingResult>",
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
/*methodName=*/"generateResultTileValue",
/*args=*/(ins
"OpBuilder &":$b,
@@ -131,6 +150,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return failure();
}]
>,
+ InterfaceMethod<
+ /*desc=*/[{
+ Method to generate the tiled implementation of an operation from
+ operand tile position.
+
+ Generates the IR that computes the tiled implementation of an
+ operation from operand tile. The `offsets` and `sizes`
+ describe the tile of the operand required. This is different from
+ `getTiledImplementation` which generates the tiled
+ implementation of the operation given a tile of the
+ iteration space. This method generates a tiled
+ implementation of the operation based on the tile of the
+ operand required. This method enables consumer fusion by using
+ tile and fuse. The method returns failure if the operation
+ can't be tiled to generate the operand tile. In practical terms
+ this implies it cannot be tiled and fused with its producers.
+
+ - `offsets` provides the offset of the tile in the coordinate system
+ of the original iteration space, i.e., if an iteration space
+ dimension had non-zero offset, it must be included in the offset
+ provided here (as opposed to zero-based offset "relative" to the
+ iteration space).
+ - `sizes` provides the size of the tile.
+ }],
+ /*retType=*/"FailureOr<::mlir::TilingResult>",
+ /*methodName=*/"getTiledImplementationFromOperandTile",
+ /*args=*/(ins
+ "OpBuilder &":$b,
+ "unsigned":$operandNumber,
+ "ArrayRef<OpFoldResult>":$offsets,
+ "ArrayRef<OpFoldResult>":$sizes),
+ /*methodBody=*/"",
+ /*defaultImplementation=*/[{
+ return failure();
+ }]
+ >,
InterfaceMethod<
/*desc=*/[{
Generates the scalar implementation of the operation.
@@ -142,7 +197,7 @@ def TilingInterface : OpInterface<"TilingInterface"> {
transformations are done, this method can be used to lower to scalar
code that can then be lowered to LLVM or SPIR-V dialects.
}],
- /*retType=*/"LogicalResult",
+ /*retType=*/"::mlir::LogicalResult",
/*methodName=*/"generateScalarImplementation",
/*args=*/(ins
"OpBuilder &":$b,
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index 9c5c58fa1fabfb..e9999c34d0face 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2425,8 +2425,8 @@ SoftmaxOp::getTiledImplementation(OpBuilder &builder,
LogicalResult SoftmaxOp::getResultTilePosition(
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
- ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) {
+ ArrayRef<OpFoldResult> sizes, SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) {
if (resultNumber == 0) {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
diff --git a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
index bd870d4f982e5d..71e9c3771dcded 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp
@@ -110,7 +110,7 @@ struct LinalgOpTilingInterface
}));
}
- // Instantiate the tiled implementation of the operation.
+ /// Instantiate the tiled implementation of the operation.
FailureOr<TilingResult>
getTiledImplementation(Operation *op, OpBuilder &b,
ArrayRef<OpFoldResult> offsets,
@@ -132,14 +132,66 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}
- // Return the details of the output tile generated by the tiled
- // implementation.
+ void
+ getMappedOffsetAndSize(LinalgOp linalgOp, OpBuilder &b, AffineMap indexingMap,
+ ArrayRef<OpFoldResult> offsets,
+ ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &mappedOffsets,
+ SmallVectorImpl<OpFoldResult> &mappedSizes) const {
+ unsigned numLoops = linalgOp.getNumLoops();
+ auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
+ mappedOffsets.resize(numLoops);
+ mappedSizes.resize(numLoops);
+ if (!indexingMap.isPermutation()) {
+ SmallVector<Range> iterationDomain =
+ tilingInterfaceOp.getIterationDomain(b);
+ for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[index] = value.offset;
+ mappedSizes[index] = value.size;
+ }
+ }
+ for (const auto &&[index, value] :
+ llvm::enumerate(indexingMap.getResults())) {
+ unsigned dimPosition = cast<AffineDimExpr>(value).getPosition();
+ mappedOffsets[dimPosition] = offsets[index];
+ mappedSizes[dimPosition] = sizes[index];
+ }
+ }
+
+ /// Return the details of the output tile generated by the tiled
+ /// implementation.
+ LogicalResult getIterationDomainTileFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
+ SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
+ SmallVectorImpl<OpFoldResult> &iterDomainSizes) const {
+ auto linalgOp = cast<LinalgOp>(op);
+
+ // Check that the indexing map used for the operand is a projected
+ // permutation. This could be relaxed with a more general approach that can
+ // map the offsets and sizes from the operand to iteration space tiles
+ // (filling in full extent for dimensions not used to access the result).
+ AffineMap indexingMap =
+ linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
+ if (!indexingMap.isProjectedPermutation()) {
+ return emitError(op->getLoc(),
+ "unhandled get iter domain position when operand is not "
+ "accessed using a permuted projection");
+ }
+
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ iterDomainOffsets, iterDomainSizes);
+ return success();
+ }
+
+ /// Return the details of the output tile generated by the tiled
+ /// implementation.
LogicalResult
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
Location loc = op->getLoc();
LinalgOp linalgOp = cast<LinalgOp>(op);
@@ -160,6 +212,21 @@ struct LinalgOpTilingInterface
return success();
}
+ FailureOr<TilingResult> getTiledImplementationFromOperandTile(
+ Operation *op, OpBuilder &b, unsigned operandNumber,
+ ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes) const {
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ auto tilingInterfaceOp = cast<TilingInterface>(op);
+ if (failed(tilingInterfaceOp.getIterationDomainTileFromOperandTile(
+ b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
+ return emitError(
+ op->getLoc(),
+ "unable to obtain the iter domain position of the operation.");
+ }
+ return tilingInterfaceOp.getTiledImplementation(b, mappedOffsets,
+ mappedSizes);
+ }
+
FailureOr<TilingResult>
generateResultTileValue(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
@@ -177,29 +244,16 @@ struct LinalgOpTilingInterface
"unhandled tiled implementation generation when result is not "
"accessed using a permuted projection");
}
-
- auto numLoops = linalgOp.getNumLoops();
+ SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
+ getMappedOffsetAndSize(linalgOp, b, indexingMap, offsets, sizes,
+ mappedOffsets, mappedSizes);
auto tilingInterfaceOp = cast<TilingInterface>(op);
- SmallVector<OpFoldResult> iterationTileOffsets(numLoops),
- iterationTileSizes(numLoops);
- if (!indexingMap.isPermutation()) {
- SmallVector<Range> iterationDomain =
- tilingInterfaceOp.getIterationDomain(b);
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- iterationTileOffsets[range.index()] = range.value().offset;
- iterationTileSizes[range.index()] = range.value().size;
- }
- }
- for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
- unsigned dimPosition =
- cast<AffineDimExpr>(resultExpr.value()).getPosition();
- iterationTileOffsets[dimPosition] = offsets[resultExpr.index()];
- iterationTileSizes[dimPosition] = sizes[resultExpr.index()];
- }
-
FailureOr<TilingResult> tilingResult =
- tilingInterfaceOp.getTiledImplementation(b, iterationTileOffsets,
- iterationTileSizes);
+ tilingInterfaceOp.getTiledImplementation(b, mappedOffsets, mappedSizes);
+
+ if (failed(tilingResult))
+ return failure();
+
if (tilingResult->tiledOps.size() != 1)
return op->emitOpError("failed to generate tiled implementation");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
index d25efcf50ec566..296c5fc7a5c2bd 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorTilingInterfaceImpl.cpp
@@ -61,8 +61,8 @@ struct PadOpTiling : public TilingInterface::ExternalModel<PadOpTiling, PadOp> {
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets.assign(offsets.begin(), offsets.end());
resultSizes.assign(sizes.begin(), sizes.end());
return success();
@@ -199,8 +199,8 @@ struct PackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
// The iteration domain is over outer dimensions of packed layout. In this
// context, the outer dimensions of `resultOffsets` are `offsets`. The
// inner dimensions of `resultOffsets` are zeros because tiling is not
@@ -452,8 +452,8 @@ struct UnPackOpTiling
getResultTilePosition(Operation *op, OpBuilder &b, unsigned resultNumber,
ArrayRef<OpFoldResult> offsets,
ArrayRef<OpFoldResult> sizes,
- SmallVector<OpFoldResult> &resultOffsets,
- SmallVector<OpFoldResult> &resultSizes) const {
+ SmallVectorImpl<OpFoldResult> &resultOffsets,
+ SmallVectorImpl<OpFoldResult> &resultSizes) const {
resultOffsets = llvm::to_vector(offsets);
resultSizes = llvm::to_vector(sizes);
return success();
diff --git a/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
new file mode 100644
index 00000000000000..e7edbf0b2c25d4
--- /dev/null
+++ b/mlir/test/Dialect/Linalg/test-fuse-consumer.mlir
@@ -0,0 +1,103 @@
+// RUN: mlir-opt %s -split-input-file -test-linalg-fuse-consumer | FileCheck %s
+
+#map = affine_map<()[s0] -> (64 ceildiv s0)>
+#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
+#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
+// CHECK-LABEL: func.func @fuse_tileable_consumer
+// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
+// CHECK-SAME: %[[IN:[0-9a-z]+]]: tensor<64xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<64xf32>
+func.func @fuse_tileable_consumer(%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>) -> tensor<64xf32> {
+ // CHECK: %[[SLICE:.*]] = tensor.empty(%[[CHUNK_SIZE]]) : tensor<?xf32>
+ %0 = tensor.empty(%arg0) : tensor<?xf32>
+ %1 = affine.apply #map()[%arg0]
+ // CHECK: %[[EMPTY0:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+ %2 = tensor.empty() : tensor<64xf32>
+ // CHECK: %[[EMPTY1:[0-9a-z]+]] = tensor.empty() : tensor<64xf32>
+ %3 = tensor.empty() : tensor<64xf32>
+ // CHECK: %[[RES:[0-9a-z]+]]:2 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[OUT]], %[[LOOP_ARG1:.*]] = %[[EMPTY1]]
+ %4 = scf.forall (%arg3) in (%1) shared_outs(%arg4 = %arg2) -> (tensor<64xf32>) {
+ %6 = affine.apply #map1(%arg3)[%arg0]
+ %7 = affine.min #map2(%arg3)[%arg0]
+ // CHECK: %[[T0:.*]] = tensor.extract_slice %[[LOOP_ARG0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ %extracted_slice = tensor.extract_slice %arg4[%6] [%7] [1] : tensor<64xf32> to tensor<?xf32>
+ // CHECK: %[[T1:[0-9a-z]+]] = linalg.elemwise_unary
+ %8 = linalg.elemwise_unary ins(%0 : tensor<?xf32>) outs(%extracted_slice : tensor<?xf32>) -> tensor<?xf32>
+
+ // CHECK: %[[T2:.*]] = tensor.extract_slice %[[EMPTY0]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T3:.*]] = tensor.extract_slice %[[LOOP_ARG1]][%{{.*}}] [%{{.*}}] [{{.*}}]
+ // CHECK: %[[T4:.*]] = linalg.elemwise_binary {{.*}} ins(%[[T1]], %[[T2]] : {{.*}} outs(%[[T3]]
+
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[T4]] into %[[LOOP_ARG1]]
+ // CHECK: tensor.parallel_insert_slice %[[T1]] into %[[LOOP_ARG0]]
+ tensor.parallel_insert_slice %8 into %arg4[%6] [%7] [1] : tensor<?xf32> into tensor<64xf32>
+ }
+ } {"containing"}
+ // CHECK: %[[ORI_OUTPUT:.*]] = linalg.elemwise_binary
+ %5 = linalg.elemwise_binary {fun = #linalg.binary_fn<add>, "consumer"} ins(%4, %2 : tensor<64xf32>, tensor<64xf32>) outs(%3 : tensor<64xf32>) -> tensor<64xf32>
+ // CHECK: return %[[RES]]#1
+ return %5 : tensor<64xf32>
+}
+// -----
+
+#map = affine_map<(d0) -> (d0 * -50 + 123, 50)>
+#map1 = affine_map<(d0) -> (d0 * -16 + 789, 16)>
+#map2 = affine_map<(d0) -> (d0 * 50)>
+#map3 = affine_map<(d0) -> (d0 * 16)>
+#map4 = affine_map<(d0, d1) -> (d0, d1)>
+#map5 = affine_map<(d0, d1) -> (d1, d0)>
+// CHECK-LABEL: func.func @fuse_consumer_multi_output
+// CHECK-SAME: %[[IN0:[0-9a-z]+]]: tensor<123x456xf32>
+// CHECK-SAME: %[[IN1:[0-9a-z]+]]: tensor<456x789xf32>
+// CHECK-SAME: %[[OUT:[0-9a-z]+]]: tensor<123x789xf32>
+func.func @fuse_consumer_multi_output(%arg0: tensor<123x456xf32>, %arg1: tensor<456x789xf32>, %arg2: tensor<123x789xf32>) -> (tensor<123x789xf32>, tensor<789x123xf32>) {
+ %cst = arith.constant 0.000000e+00 : f32
+ // CHECK: %[[INIT:.*]] = linalg.fill
+ %0 = linalg.fill ins(%cst : f32) outs(%arg2 : tensor<123x789xf32>) -> tensor<123x789xf32>
+ // CHECK: %[[EMPTY0:.*]] = tensor.empty() : tensor<123x789xf32>
+ %1 = tensor.empty() : tensor<123x789xf32>
+ // CHECK: %[[EMPTY1:.*]] = tensor.empty() : tensor<789x123xf32>
+ %2 = tensor.empty() : tensor<789x123xf32>
+ // CHECK: %[[RES:[0-9a-z]+]]:3 = scf.forall {{.*}} shared_outs(%[[LOOP_ARG0:.*]] = %[[INIT]], %[[LOOP_ARG1:.*]] = %[[EMPTY0]], %[[LOOP_ARG2:.*]] = %[[EMPTY1]]
+ %3 = scf.forall (%arg3, %arg4) in (3, 50) shared_outs(%arg5 = %0) -> (tensor<123x789xf32>) {
+ %5 = affine.min #map(%arg3)
+ %6 = affine.min #map1(%arg4)
+ %7 = affine.apply #map2(%arg3)
+ %8 = affine.apply #map3(%arg4)
+ %9 = affine.apply #map2(%arg3)
+ %10 = affine.apply #map3(%arg4)
+ // CHECK: %[[EXTRACT_IN0:.*]] = tensor.extract_slice %[[IN0]]
+ %extracted_slice = tensor.extract_slice %arg0[%7, 0] [%5, 456] [1, 1] : tensor<123x456xf32> to tensor<?x456xf32>
+ // CHECK: %[[EXTRACT_IN1:.*]] = tensor.extract_slice %[[IN1]]
+ %extracted_slice_0 = tensor.extract_slice %arg1[0, %8] [456, %6] [1, 1] : tensor<456x789xf32> to tensor<456x?xf32>
+ // CHECK: %[[EXTRACT_OUT:.*]] = tensor.extract_slice %[[LOOP_ARG0]]
+ %extracted_slice_1 = tensor.extract_slice %arg5[%9, %10] [%5, %6] [1, 1] : tensor<123x789xf32> to tensor<?x?xf32>
+ // CHECK: %[[MATMUL_RES:.*]] = linalg.matmul ins(%[[EXTRACT_IN0]], %[[EXTRACT_IN1]] {{.*}} outs(%[[EXTRACT_OUT]]
+ %11 = linalg.matmul ins(%extracted_slice, %extracted_slice_0 : tensor<?x456xf32>, tensor<456x?xf32>) outs(%extracted_slice_1 : tensor<?x?xf32>) -> tensor<?x?xf32>
+
+ // CHECK: %[[EXTRACT_EMPTY0:.*]] = tensor.extract_slice %[[LOOP_ARG1]]
+ // CHECK: %[[EXTRACT_EMPTY1:.*]] = tensor.extract_slice %[[LOOP_ARG2]]
+ // CHECK: %[[GENERIC_RES:.*]]:2 = linalg.generic {{.*}} ins(%[[MATMUL_RES]] : tensor<?x?xf32>) outs(%[[EXTRACT_EMPTY0]], %[[EXTRACT_EMPTY1]]
+
+ %12 = affine.apply #map2(%arg3)
+ %13 = affine.apply #map3(%arg4)
+ scf.forall.in_parallel {
+ // CHECK: tensor.parallel_insert_slice %[[GENERIC_RES]]#0 into %[...
[truncated]
47aa35f to cb24711
This PR is a re-submission of #85528, with changes to include additional tests. Please help me review this patch. @MaheshRavishankar @ftynse @Abhishek-Varma @nicolasvasilache
cb24711 to 6b331ce
For those of us less familiar with the fusion logic, could you please provide a high-level example with "BEFORE" and "AFTER"? You can use pseudo IR for that (I can use the implementation to reason about the finer details).
Could the behaviour of |
Thanks @cxy-1993 for the change. I have comments, but just want to state at the outset that these aren't a reaction to this PR or the effort to follow through on recommendations, but just an unfortunate evolution of things. What is being done here is what I was trying to avoid. To test the interface methods, there is now an implementation of tile and fuse in a test folder. I don't see how this is useful.
Thanks for iterating on this!
It appears that some code was copy-pasted from #88712. At least, this is my impression after leaving multiple comments for identical issues in the code. If it is the case, please make sure all those issues are also addressed. You are also most welcome to collaborate with the author of that PR on a single branch and send one, unified PR.
if (!indexingMap.isPermutation()) {
SmallVector<Range> iterationDomain =
tilingInterfaceOp.getIterationDomain(b);
for (const auto &&[index, value] : llvm::enumerate(iterationDomain)) {
Nit: `const &&` may not mean what you think it does here, see https://isocpp.org/blog/2012/11/universal-references-in-c11-scott-meyers. We also don't use `const` for IR objects https://mlir.llvm.org/docs/Rationale/UsageOfConst/.
"const" has been removed. Thanks for sharing knowledge about `&&`. Here, I want to use references to avoid copying, and I don't care whether it's an lvalue reference or an rvalue reference. It looks like an lvalue reference is used here.
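As a side note on the `&&` point above, the following is a minimal standalone sketch (plain C++, no MLIR or LLVM types; all function names are hypothetical and chosen only for illustration). It shows that `auto &&` adapts to the value category of its initializer, while `const auto &&` is always an rvalue reference to const and therefore cannot bind to an lvalue at all.

```cpp
#include <type_traits>
#include <utility>
#include <vector>

// `auto &&` is a forwarding reference: initialized from an lvalue, it
// deduces an lvalue reference.
inline bool autoRefRefBindsLvalueAsLvalueRef() {
  int x = 0;
  auto &&ref = x;
  (void)ref;
  return std::is_lvalue_reference<decltype(ref)>::value;
}

// Initialized from a temporary, the same spelling deduces an rvalue
// reference (the temporary's lifetime is extended to that of `ref`).
inline bool autoRefRefBindsTemporaryAsRvalueRef() {
  auto &&ref = std::vector<int>{1, 2, 3};
  (void)ref;
  return std::is_rvalue_reference<decltype(ref)>::value;
}

// `const auto &&` is NOT a forwarding reference. It is always an rvalue
// reference to a const object, so it only accepts rvalues.
inline bool constAutoRefRefIsRvalueRefToConst() {
  const auto &&ref = std::vector<int>{1, 2, 3};
  (void)ref;
  using Referred = std::remove_reference<decltype(ref)>::type;
  return std::is_rvalue_reference<decltype(ref)>::value &&
         std::is_const<Referred>::value;
  // int x = 0;
  // const auto &&bad = x;  // would not compile: cannot bind an lvalue
}
```

Whether a loop such as the one quoted above compiles with `const auto &&` therefore depends on whether dereferencing the range's iterator yields a temporary or an lvalue; plain `auto &&` works in both cases.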
#map1 = affine_map<(d0)[s0] -> (d0 * s0)>
#map2 = affine_map<(d0)[s0] -> (-(d0 * s0) + 64, s0)>
// CHECK-LABEL: func.func @fuse_tileable_consumer
// CHECK-SAME: %[[CHUNK_SIZE:[0-9a-z]+]]: index
Nit: I'd match as just `.+`, there is no need to restrict symbols to only numbers and lowercase letters (and there may be underscores in names!)
I tried your method, but matching fails: FileCheck matches extra characters because the same data type appears more than once on the same line, which causes errors.
I see, thanks for trying. Maybe let's update to `[0-9A-Za-z_]+` then, which should capture all supported names.
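The over-matching described in this thread can be reproduced outside FileCheck. The sketch below is only an analogy: it uses `std::regex` (ECMAScript syntax) rather than FileCheck's own regex engine, and `captureName` is a hypothetical helper. It searches a line shaped like the test's function signature, where `tensor<64xf32>` appears twice, and returns what the name pattern captured.

```cpp
#include <regex>
#include <string>

// Hypothetical helper: search `line` for "%<name>: tensor<64xf32>" and
// return the text captured by `namePattern`, or "" if nothing matches.
inline std::string captureName(const std::string &line,
                               const std::string &namePattern) {
  std::regex re("%(" + namePattern + "): tensor<64xf32>");
  std::smatch match;
  if (!std::regex_search(line, match, re))
    return "";
  return match[1].str();
}

// A line with the same type on two arguments, as in the test signature.
inline std::string signatureLine() {
  return "%arg0: index, %arg1: tensor<64xf32>, %arg2: tensor<64xf32>";
}
```

With `.+` the greedy match starts at the first `%` and runs up to the last `: tensor<64xf32>`, so the capture is `arg0: index, %arg1: tensor<64xf32>, %arg2`. With `[0-9A-Za-z_]+` the capture is just `arg1`, because the character class cannot cross `:`, `,`, or spaces.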
I agree that this isn't exactly useful, because it introduces a lot of additional logic beyond what actually needs to be tested. Moreover, the newly added tests only test this behavior indirectly, and may end up rather hard to debug should the behavior regress. There is a much simpler testing strategy: implement a test pass that calls the newly added interface methods on all operations that implement the interface and outputs the results of the call as diagnostic remarks (attributes can be printed, values can be identified by their locations). This is straightforward and fits into a couple dozen lines of code. We could also consider a unit test. It was rather obvious to me, and I probably should have elaborated on that...
I don't think this code is going to be deleted, but rather moved. So not much waste outside of reviewer time.
This patch adds support for consumer fusion to the tiling interface, and implements fuse consumers on FuseIntoContainingOp.
- Add interface method 'getIterDomainTilePositionFromOperandPosition' to the tiling interface, which gets the iteration domain position from the operand position.
- Add interface method 'getTiledImplementationFromOperandPosition' to the tiling interface, which generates the tiled implementation according to the operand position.
- Implement the above two methods and support consumer fusion for FuseIntoContainingOp.
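To make the operand-to-iteration-domain mapping concrete, here is a simplified standalone model of what `getMappedOffsetAndSize` in the diff above does. It is a sketch under stated assumptions, not the MLIR implementation: `TileRange` and `mapOperandTileToIterationDomain` are hypothetical names, offsets and sizes are plain integers instead of `OpFoldResult`, and the projected-permutation indexing map is written as a list of loop indices, one per operand dimension.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for one dimension of a tile (not mlir::Range).
struct TileRange {
  std::int64_t offset;
  std::int64_t size;
};

// `loopOfOperandDim[i]` is the loop dimension that operand dimension `i`
// is indexed by, e.g. {0, 2} models affine_map<(d0, d1, d2) -> (d0, d2)>.
inline std::vector<TileRange> mapOperandTileToIterationDomain(
    const std::vector<TileRange> &iterationDomain,
    const std::vector<std::size_t> &loopOfOperandDim,
    const std::vector<TileRange> &operandTile) {
  // Start from the full extent, so loops the operand does not index keep
  // their whole range. (The patch skips this step when the map is a full
  // permutation, since every entry is overwritten anyway.)
  std::vector<TileRange> mapped = iterationDomain;
  // Overwrite the loops that the operand does index with the tile.
  for (std::size_t i = 0; i < loopOfOperandDim.size(); ++i)
    mapped[loopOfOperandDim[i]] = operandTile[i];
  return mapped;
}
```

For a matmul with loops (m, n, k) of extents 123, 789, 456, a tile of the left-hand operand at offsets (50, 0) with sizes (50, 456) maps to the iteration-domain tile m = [50, +50), n = [0, +789), k = [0, +456). The second interface method then only has to request the tiled implementation for that iteration-domain tile.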
6b331ce to f36a1f6
Thank you for your suggestion. A complete transform/pass submission should indeed include relevant descriptions. However, as discussed above, the implementation here is a test that is about to be moved or deleted, so I feel it's unnecessary to describe its inputs and outputs in detail here. We can refine that in #88712.
Yes. Considering it's about to be moved or deleted, I think this can be addressed in the next submission (https://github.com/llvm/llvm-project/pull/88712).
Just adding this here as a proposal to consider, if everyone thinks this would indeed be a better way forward:
I believe having just the infra of `TilingInterface.td` + `TilingInterfaceImpl.cpp` (basically the reverted changes) as a separate base commit in the in-flight PR 88712 would be the preferable way to go.
Reasons:
- The PR I added indeed makes use of the interface methods to achieve the fusion for both scf.for and scf.forall.
- The same PR also adds the necessary tests based on a test transform dialect op.
- Most of the review comments by @ftynse have already been addressed, so @cxy-1993 can avoid having to redo the same set of things here.
After 88712 lands, the only actionable item left would be to define an official transform dialect op that does the same job as the test transform dialect op I've already added in my PR, something perhaps @cxy-1993 can take up thereafter.
This avoids any unnecessary code movement and makes it cleaner.
This would be fine with me. However, #88712 has changes requested by @nicolasvasilache with an ongoing discussion that may prevent this from landing for a longer time.
The proposal is also fine with me. Since @MaheshRavishankar and @ftynse both agree on this, I'll close this PR. Please cherry-pick the interface part into your PR, @Abhishek-Varma, thanks! BTW, thanks again to @ftynse for reviewing this patch.