[mlir][linalg] Enable fuse consumer #85528

Merged
merged 1 commit into from
Apr 22, 2024

Conversation

cxy-1993
Contributor

This patch adds support for consumer fusion to the tiling interface and implements consumer fusion in FuseIntoContainingOp.

  • Add interface method 'getIterDomainTilePositionFromOperandPosition' to the tiling interface, which gets the iteration domain position from an operand position.
  • Add interface method 'getTiledImplementationFromOperandPosition' to the tiling interface, which generates a tiled implementation according to an operand position.
  • Implement the above two methods and support consumer fusion in FuseIntoContainingOp.
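
A minimal sketch of the idea behind the two new interface methods, assuming static integer offsets and sizes and an operand whose indexing map is a projected permutation. `Tile`, `iterDomainTileFromOperandTile`, and `resultTileFromIterDomainTile` are hypothetical names used only for illustration; they are not part of MLIR, and the real methods operate on `OpFoldResult` values and build IR.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for a tile: one offset and one size per dimension.
struct Tile {
  std::vector<int64_t> offsets;
  std::vector<int64_t> sizes;
};

// Maps a tile of one operand to a tile of the iteration domain.
// `operandDimToLoop[i]` names the loop that indexes operand dimension i
// (assumption: a projected permutation); `loopBounds` is each loop's extent.
Tile iterDomainTileFromOperandTile(
    const std::vector<unsigned> &operandDimToLoop,
    const std::vector<int64_t> &loopBounds, const Tile &operandTile) {
  assert(operandDimToLoop.size() == operandTile.offsets.size());
  assert(operandTile.offsets.size() == operandTile.sizes.size());
  Tile result;
  // Start from the untiled iteration domain.
  result.offsets.assign(loopBounds.size(), 0);
  result.sizes = loopBounds;
  // Loops referenced by the operand inherit the operand tile's position.
  for (std::size_t dim = 0; dim < operandDimToLoop.size(); ++dim) {
    unsigned loop = operandDimToLoop[dim];
    assert(loop < loopBounds.size());
    result.offsets[loop] = operandTile.offsets[dim];
    result.sizes[loop] = operandTile.sizes[dim];
  }
  return result;
}

// Maps an iteration-domain tile to the tile of a result that is indexed by
// the loops listed in `resultDimToLoop`.
Tile resultTileFromIterDomainTile(const std::vector<unsigned> &resultDimToLoop,
                                  const Tile &iterTile) {
  Tile result;
  for (unsigned loop : resultDimToLoop) {
    assert(loop < iterTile.offsets.size());
    result.offsets.push_back(iterTile.offsets[loop]);
    result.sizes.push_back(iterTile.sizes[loop]);
  }
  return result;
}
```

For a matmul with loops (m, n, k) of extent (128, 64, 32), a tile of operand A (indexed by m and k) at offsets (16, 0) with sizes (8, 32) maps to the iteration-domain tile at offsets (16, 0, 0) with sizes (8, 64, 32), and from there to a result tile at offsets (16, 0) with sizes (8, 64).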

@llvmbot
Collaborator

llvmbot commented Mar 16, 2024

@llvm/pr-subscribers-mlir-tensor
@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: donald chen (cxy-1993)

Patch is 40.78 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/85528.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td (+29-19)
  • (modified) mlir/include/mlir/Interfaces/TilingInterface.td (+55)
  • (modified) mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp (+307-62)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (+75-21)
  • (modified) mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir (+124-1)
diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
index bdeab55091b9f3..2c501a3ecb14f7 100644
--- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td
@@ -310,51 +310,61 @@ def FuseIntoContainingOp :
           ["allowsRepeatedHandleOperands"]>,
        DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
        ReportTrackingListenerFailuresOpTrait]> {
-  let summary = "Fuse a producer into a containing operation.";
+  let summary = "Fuse a target into a containing operation.";
 
   let description = [{
-    Fuses the `producer_op` into the `containing_op`.
+    Fuses the `target_op` into the `containing_op`.
     Returns a handle to the fused ops and the `new_containing_op`.
 
-    The producer is typically a slice of a tileable op (i.e., implements
-    TilingInterface). In that case, this transform computes the accessed
-    producer slice inside of the containing op ("tile and fuse") and if required,
-    creates a new containing op with outputs from the fused producer. Otherwise,
-    the entire producer is cloned inside the containing op ("clone and fuse").
+    This operation supports fusing either a producer or a consumer. We will
+    refer to the value connecting the containing operation and the target
+    operation as the "bridge" below.
+
+    When fusing a producer, the bridge is typically a slice of a tileable op
+    (i.e., implements TilingInterface). In that case, this transform computes
+    the accessed bridge slice inside of the containing op ("tile and fuse") and
+    if required, creates a new containing op with outputs from the fused target.
+    Otherwise, the entire target is cloned inside the containing op ("clone
+    and fuse").
+
+    When fusing a consumer, the bridge is a result of the containing op and an
+    operand of a tileable op (i.e., implements TilingInterface). In this case,
+    this transform computes the accessed bridge slice inside the containing op
+    ("tile and fuse") and creates a new containing op with the consumer's output.
 
     The containing op handle must be associated with exactly one payload op. The
-    producer op handle may be associated with multiple payload ops. This
-    transform fuses producers one-by-one, always picking an unspecified producer
+    target op handle may be associated with multiple payload ops. This
+    transform fuses targets one-by-one, always picking an unspecified target
     that has at least one use inside the containing op among the
-    producers. A producer can be listed multiple times in the handle.
+    targets. A target can be listed multiple times in the handle.
 
-    Note: If a producer has multiple uses inside the containing op, it is
+    Note: If a target has multiple uses inside the containing op, it is
     currently tiled and/or cloned multiple times into the containing op.
     TODO: Reuse already fused OpResults instead of tiling/cloning a second time
-    when possible. Fuse producers according to a topological sorting to achieve
+    when possible. Fuse targets according to a topological sorting to achieve
     the largest amount of reuse.
 
     #### Return modes
 
-    If at least one producer could not be fused, this operation produces a
+    If at least one target could not be fused, this operation produces a
     silenceable failure.  This is the case when tiling fails or when no
-    producer op could be found among the remaining producers that has at least
-    one use within the containing op. I.e., "producers" that are not consumed
+    target op could be found among the remaining targets that has at least
+    one use within the containing op. I.e., "targets" that are not consumed
     within the containing op are rejected by this operation.
 
-    This operation consumes the producer handle.
+    This operation consumes the target handle.
     This operation only reads the containing op handle.
   }];
 
-  let arguments = (ins TransformHandleTypeInterface:$producer_op,
+  let arguments = (ins TransformHandleTypeInterface:$target_op,
                        TransformHandleTypeInterface:$containing_op);
   let results = (outs TransformHandleTypeInterface:$fused_op,
                       TransformHandleTypeInterface:$new_containing_op);
-  let assemblyFormat = "$producer_op `into` $containing_op attr-dict "
+  let assemblyFormat = "$target_op `into` $containing_op attr-dict "
                        " `:` functional-type(operands, results)";
 
   let builders = [
-    OpBuilder<(ins "Value":$producerOp, "Value":$containingOp)>
+    OpBuilder<(ins "Value":$targetOp, "Value":$containingOp)>
   ];
 }
 
diff --git a/mlir/include/mlir/Interfaces/TilingInterface.td b/mlir/include/mlir/Interfaces/TilingInterface.td
index 66382f29c24249..4c62d45822ad44 100644
--- a/mlir/include/mlir/Interfaces/TilingInterface.td
+++ b/mlir/include/mlir/Interfaces/TilingInterface.td
@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return {};
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to return iterator domain position computed by the
+          input operand position.
+        }],
+        /*retType=*/"LogicalResult",
+        /*methodName=*/"getIterDomainTilePositionFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult> ":$offsets,
+          "ArrayRef<OpFoldResult> ":$sizes,
+          "SmallVector<OpFoldResult> &":$iterDomainOffsets,
+          "SmallVector<OpFoldResult> &":$iterDomainSizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to return the position of the result tile computed by the tiled operation.
@@ -96,6 +115,42 @@ def TilingInterface : OpInterface<"TilingInterface"> {
           return failure();
         }]
       >,
+      InterfaceMethod<
+        /*desc=*/[{
+          Method to generate the tiled implementation of an operation from
+          operand position.
+
+          Generates the IR that generate the tiled implementation of an
+          operation from operand position.  The `offsets` and `sizes`
+          describe the tile of the operand required. This is different from
+          `getTiledImplementation` which generates the tiled
+          implementation of the operation given a tile of the
+          iteration space. This method generates a tiled
+          implementation of the operation based on the position of the
+          operand required. This method enables fusion consumer by using
+          tile and fuse. The method returns failure if the operation
+          can't be tiled to generate the operand tile. In practical terms
+          this implies it cannot be tiled and fused with its producers.
+
+          - `offsets` provides the offset of the tile in the coordinate system
+            of the original iteration space, i.e., if an iteration space
+            dimension had non-zero offset, it must be included in the offset
+            provided here (as opposed to zero-based offset "relative" to the
+            iteration space).
+          - `sizes` provides the size of the tile.
+        }],
+        /*retType=*/"FailureOr<TilingResult>",
+        /*methodName=*/"getTiledImplementationFromOperandPosition",
+        /*args=*/(ins
+          "OpBuilder &":$b,
+          "unsigned":$operandNumber,
+          "ArrayRef<OpFoldResult>":$offsets,
+          "ArrayRef<OpFoldResult>":$sizes),
+        /*methodBody=*/"",
+        /*defaultImplementation=*/[{
+          return failure();
+        }]
+      >,
       InterfaceMethod<
         /*desc=*/[{
           Method to generate the code that produces a tile of the result.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index ae28049f02e391..ecffb910b236e8 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -546,9 +546,9 @@ LogicalResult transform::FuseOp::verify() {
 
 void transform::FuseIntoContainingOp::build(OpBuilder &builder,
                                             OperationState &result,
-                                            Value producerOp,
+                                            Value targetOp,
                                             Value containingOp) {
-  result.addOperands({producerOp, containingOp});
+  result.addOperands({targetOp, containingOp});
   auto resultType = transform::AnyOpType::get(builder.getContext());
   result.addTypes({resultType, resultType});
 }
@@ -631,6 +631,223 @@ static Operation *replaceForAllWithNewSignature(
   return newforallOp;
 }
 
+static std::tuple<SmallVector<Operation *>, Operation *>
+tileAndFuseParallelInsertSlice(RewriterBase &rewriter, Diagnostic &diag,
+                               Operation *consumerOp, Operation *containingOp) {
+  // Check consumer has tiling interface.
+  LLVM_DEBUG(DBGS() << "Try to fuse a consumer\n");
+  auto tileableConsumer = dyn_cast<TilingInterface>(consumerOp);
+  if (!tileableConsumer) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer is not a TileableInterface: " << *consumerOp;
+    return {};
+  }
+
+  // Check containing op is "scf::ForallOp".
+  auto forallOp = dyn_cast<scf::ForallOp>(containingOp);
+  if (!forallOp) {
+    diag.attachNote(containingOp->getLoc())
+        << "containing op is not a scf.forall: " << *containingOp;
+    return {};
+  }
+
+  // Check dominance.
+  DominanceInfo domInfo(
+      containingOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>());
+  if (llvm::any_of(consumerOp->getOperands(), [&](Value v) {
+        return v.getDefiningOp() != containingOp &&
+               !domInfo.properlyDominates(v, containingOp);
+      })) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer's operand can't dominate containing op";
+    return {};
+  }
+
+  // Check that the consumer doesn't use more than one result of containingOp.
+  Value bridge(nullptr);
+  SmallVector<unsigned> operandNums;
+  for (auto [idx, opd] : llvm::enumerate((consumerOp->getOperands()))) {
+    if (opd.getDefiningOp() == containingOp) {
+      operandNums.push_back(idx);
+      if (!bridge) {
+        bridge = opd;
+      } else if (bridge != opd) {
+        diag.attachNote(consumerOp->getLoc())
+            << "consumer's operand use more than one containingOp's result";
+        return {};
+      }
+    }
+  }
+
+  // TODO: We have to init result of consumer before scf.forall, use
+  //       DestinationStyleOpInterface to get result shape from init for now.
+  //       Add support for other op such as op has InferTypeOpInterface.
+  // Check consumer has DestinationStyleOpInterface.
+  auto dstOp = dyn_cast<DestinationStyleOpInterface>(consumerOp);
+  if (!dstOp) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer op should have destination style op interface";
+    return {};
+  }
+
+  // Check that the consumer doesn't use scf.forall's output as init.
+  SmallVector<Value> dpsInits = llvm::to_vector<4>(
+      llvm::map_range(dstOp.getDpsInits(), [](Value v) { return v; }));
+  if (llvm::is_contained(dpsInits, bridge)) {
+    diag.attachNote(consumerOp->getLoc())
+        << "consumer op take result of scf.forall as init";
+    return {};
+  }
+
+  // Check result was inserted only once.
+  int64_t bridgeResultIdx = cast<OpResult>(bridge).getResultNumber();
+  auto bridgeBlockArg = forallOp.getRegionOutArgs()[bridgeResultIdx];
+  scf::InParallelOp terminatorOp = forallOp.getTerminator();
+
+  tensor::ParallelInsertSliceOp targetInsertOp(nullptr);
+  for (Operation &op : terminatorOp.getRegion().front().getOperations()) {
+    auto parallelInsertSliceOp = cast<tensor::ParallelInsertSliceOp>(op);
+    if (parallelInsertSliceOp.getDest() == bridgeBlockArg) {
+      if (!targetInsertOp) {
+        targetInsertOp = parallelInsertSliceOp;
+      } else {
+        diag.attachNote(containingOp->getLoc())
+            << "containingOp's result inserted multi time";
+        return {};
+      }
+    }
+  }
+
+  if (!targetInsertOp) {
+    diag.attachNote(containingOp->getLoc())
+        << "containingOp's result was not inserted";
+    return {};
+  }
+
+  SmallVector<OpFoldResult> offsets = targetInsertOp.getMixedOffsets();
+  SmallVector<OpFoldResult> sizes = targetInsertOp.getMixedSizes();
+  SmallVector<OpFoldResult> strides = targetInsertOp.getMixedStrides();
+
+  // Check all insert stride is 1.
+  if (llvm::any_of(strides, [](OpFoldResult foldRes) {
+        if (auto attr = foldRes.dyn_cast<Attribute>()) {
+          return cast<IntegerAttr>(attr).getInt() != 1;
+        }
+        return true;
+      })) {
+    diag.attachNote(containingOp->getLoc())
+        << "containingOp's result yield with stride";
+    return {};
+  }
+
+  Location loc = forallOp.getLoc();
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(terminatorOp);
+
+  SmallVector<OpFoldResult> iterDomainOffsets, iterDomainSizes;
+
+  // Try to get iter domain position from input position.
+  if (failed(tileableConsumer.getIterDomainTilePositionFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes, iterDomainOffsets,
+          iterDomainSizes))) {
+    diag.attachNote(consumerOp->getLoc())
+        << "can't get iter domain position from input position";
+    return {};
+  }
+
+  // Try to get the position of every consumer result from iter domain position.
+  llvm::SmallVector<std::pair<llvm::SmallVector<OpFoldResult>,
+                              llvm::SmallVector<OpFoldResult>>>
+      resultPositions(consumerOp->getNumResults());
+  for (auto [idx, v] : llvm::enumerate(consumerOp->getResults())) {
+    if (failed(tileableConsumer.getResultTilePosition(
+            rewriter, idx, iterDomainOffsets, iterDomainSizes,
+            resultPositions[idx].first, resultPositions[idx].second))) {
+      diag.attachNote(consumerOp->getLoc())
+          << "can't get result domain position from iter domain position";
+      return {};
+    }
+  }
+
+  // All checks passed, try to fuse consumer.
+  // Create tiled implementation of the consumer op.
+  FailureOr<TilingResult> tileAndFuseResult =
+      tileableConsumer.getTiledImplementationFromOperandPosition(
+          rewriter, operandNums.front(), offsets, sizes);
+  if (failed(tileAndFuseResult)) {
+    diag.attachNote(consumerOp->getLoc()) << "get tiled implementation failed";
+    return {};
+  }
+
+  auto tiledOps = tileAndFuseResult->tiledOps;
+  if (tiledOps.size() != 1) {
+    diag.attachNote(tileableConsumer->getLoc())
+        << "failed to tile consumer op: " << *tileableConsumer;
+    return {};
+  }
+
+  // Replace the tiled op's operands.
+  for (auto operandNum : operandNums) {
+    tiledOps[0]->setOperand(operandNum, targetInsertOp.getSource());
+  }
+  rewriter.replaceUsesWithIf(bridge, forallOp.getOutputs()[bridgeResultIdx],
+                             [&](OpOperand &use) {
+                               Operation *op = use.getOwner();
+                               return forallOp->isProperAncestor(op);
+                             });
+
+  SmallVector<Value> newOuts(forallOp.getOutputs());
+  newOuts.append(dpsInits);
+
+  // Create new scf.forall op.
+  rewriter.setInsertionPoint(forallOp);
+  auto newforallOp = rewriter.create<scf::ForallOp>(
+      loc, forallOp.getMixedLowerBound(), forallOp.getMixedUpperBound(),
+      forallOp.getMixedStep(), newOuts, forallOp.getMapping());
+  rewriter.eraseBlock(newforallOp.getBody());
+  newforallOp.getRegion().takeBody(forallOp.getRegion());
+
+  for (auto v : dpsInits) {
+    newforallOp.getBody()->addArgument(v.getType(), v.getLoc());
+    auto bbArgs = newforallOp.getBody()->getArguments();
+    rewriter.replaceUsesWithIf(v, bbArgs.back(), [&](OpOperand &use) {
+      Operation *op = use.getOwner();
+      return newforallOp->isProperAncestor(op);
+    });
+  }
+
+  // Fix terminator.
+  scf::InParallelOp newTerminatorOp = newforallOp.getTerminator();
+  SmallVector<Operation *> yieldingOps = llvm::to_vector<4>(llvm::map_range(
+      newTerminatorOp.getYieldingOps(), [](Operation &op) { return &op; }));
+  Operation *firstYieldOp = yieldingOps.front();
+  rewriter.setInsertionPoint(firstYieldOp);
+  auto bbArgs = newforallOp.getBody()->getArguments();
+  for (auto [idx, v] : llvm::enumerate(tiledOps[0]->getResults())) {
+    SmallVector<OpFoldResult> strides(resultPositions[idx].first.size(),
+                                      rewriter.getIndexAttr(1));
+    rewriter.create<tensor::ParallelInsertSliceOp>(
+        firstYieldOp->getLoc(), v,
+        bbArgs[forallOp.getRank() + forallOp.getOutputs().size() + idx],
+        resultPositions[idx].first, resultPositions[idx].second, strides);
+  }
+
+  // Replace the result of forall and consumer op.
+  for (auto result : llvm::enumerate(forallOp.getResults())) {
+    rewriter.replaceAllUsesWith(result.value(),
+                                newforallOp->getResult(result.index()));
+  }
+
+  for (auto consumerResult : llvm::enumerate(consumerOp->getResults())) {
+    rewriter.replaceAllUsesWith(
+        consumerResult.value(),
+        newforallOp->getResult(forallOp.getOutputs().size() +
+                               consumerResult.index()));
+  }
+
+  return std::make_tuple(tileAndFuseResult->tiledOps, newforallOp);
+}
+
 /// Find the first "extract" user of `producerOp` and tile it right before its
 /// use. The tiled op is fused under the `containingOp`.
 /// Return this fused op on success or nullptr if anything fails.
@@ -880,7 +1097,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
                                        transform::TransformResults &results,
                                        transform::TransformState &state) {
   SmallVector<Operation *> fusedOps;
-  auto producerOps = state.getPayloadOps(getProducerOp());
+  auto targetOps = state.getPayloadOps(getTargetOp());
   auto containingOps = state.getPayloadOps(getContainingOp());
   if (!llvm::hasSingleElement(containingOps)) {
     return emitDefiniteFailure()
@@ -890,69 +1107,115 @@ transform::FuseIntoContainingOp::apply(transform::TransformRewriter &rewriter,
   Operation *containingOp = *containingOps.begin();
 
   // If nothing to fuse, propagate success.
-  if (std::empty(producerOps)) {
+  if (std::empty(targetOps)) {
     results.set(cast<OpResult>(getFusedOp()), SmallVector<mlir::Operation *>{});
     results.set(cast<OpResult>(getNewContainingOp()), {containingOp});
     return DiagnosedSilenceableFailure::success();
   }
 
-  // Helper function to find the next producer that should be fused. Take any
-  // producer that has a use inside the containing op.
-  SetVector<Operation *> remainingProducers(producerOps.begin(),
-                                            producerOps.end());
-  auto getNextProducer = [&]() -> FailureOr<Operation *> {
-    for (const auto &it : enumerate(remainingProducers)) {
-      Operation *producerOp = it.value();
-      // The containing op may be a user of producerOp: use isAncestor.
+  // Helper function to find the next target that should be fused. Take any
+  // target that has a use inside the containing op. Return the target
+  // operation and a bool indicating whether this target op is a producer.
+  SetVector<Operation *> remainingTargets(targetOps.begin(), targetOps.end());
+  auto getNextTarget = [&]() -> FailureOr<std::pair<Operation *, bool>> {
+    for (const auto &it : enumerate(remainingTargets)) {
+      Operation *targetOp = it.value();
+      // The containing op may be a user of targetOp: use isAncestor.
       int64_t numUsesInContainingOp =
-          llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
+          llvm::count_if(targetOp->getUsers(), [&](Operation *op) {
             return containingOp->isAncestor(op);
           });
       // TODO: When resolving the TODO below (no duplicate ops), take an op
-      // that has no use among the remaining producers. This is a topological
+      // that has no use among the remaining targets. Th...
[truncated]
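
The bridge check in `tileAndFuseParallelInsertSlice` above can be modelled outside MLIR roughly as follows. This is only a sketch under stated assumptions: `Value`, `Bridge`, and `findBridge` are hypothetical stand-ins (SSA values are reduced to integer ids), and unlike the patch this version also reports the case where no operand comes from the containing op.

```cpp
#include <cassert>
#include <optional>
#include <vector>

// Hypothetical stand-in for an SSA value: its id and the id of its definer.
struct Value {
  int id;
  int definingOp;
};

// The single value connecting the containing op to the consumer, plus every
// consumer operand position that uses it.
struct Bridge {
  int valueId;
  std::vector<unsigned> operandNumbers;
};

// Returns the bridge, or nullopt when the consumer uses no result of
// `containingOp` or uses two different results of it.
std::optional<Bridge> findBridge(const std::vector<Value> &consumerOperands,
                                 int containingOp) {
  std::optional<Bridge> bridge;
  for (unsigned idx = 0; idx < consumerOperands.size(); ++idx) {
    const Value &operand = consumerOperands[idx];
    if (operand.definingOp != containingOp)
      continue;
    if (!bridge)
      bridge = Bridge{operand.id, {}};
    else if (bridge->valueId != operand.id)
      return std::nullopt; // More than one result of the containing op.
    bridge->operandNumbers.push_back(idx);
  }
  return bridge;
}
```

The same result may appear at several operand positions, which is why the sketch collects a list of positions, mirroring `operandNums` in the patch.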

@cxy-1993
Contributor Author

@MaheshRavishankar Could you please help me review this patch?

@cxy-1993 cxy-1993 force-pushed the main branch 2 times, most recently from 7ca5ae5 to c79772d Compare March 16, 2024 14:33
@MaheshRavishankar
Contributor

Thanks! Yes, I will definitely review the patch. I am in the middle of a crunch week. I'll probably get to it in about 10 days, if that is ok.

@cxy-1993
Contributor Author

> Thanks! Yes, I will definitely review the patch. I am in the middle of a crunch week. I'll probably get to it in about 10 days, if that is ok.

Don't worry, there's no rush. Take your time.

Contributor

@MaheshRavishankar left a comment

Thanks for the changes. I looked at the changes to the TilingInterface; those look good. Just left a few nits.

I missed this, but could you look at the implementation in https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp instead? That file does tile consumer and fuse producer using the LoopLikeOpInterface and works for both scf.for and scf.forall. The path that is here in LinalgTransforms is supposed to be deprecated (by me when I get some time). Last time I touched it, I made things fairly modular and readable (afaik), and just following that but instead fusing with consumers would be a great help. If added here, it will just be on my list of things to deprecate and move to the above file anyway.

I haven't looked at the LinalgTransforms.* changes (because I am also not an expert in the transform dialect). I think it would be better to make these methods available outside of the transform dialect, callable as standalone utility functions (as is done in the implementation I linked to above).

@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
InterfaceMethod<
/*desc=*/[{
Method to return iterator domain position computed by the
Contributor

I think this isn't as clear. Please see the description for getResultTilePosition; maybe use something closer to that (or improve both if you have a better description).

Contributor Author

DONE

input operand position.
}],
/*retType=*/"LogicalResult",
/*methodName=*/"getIterDomainTilePositionFromOperandPosition",
Contributor

Nit: Maybe rename to getIterationDomainTileFromOperandTile?

Contributor Author

DONE rename to getIterationDomainTileFromOperandTile

Method to generate the tiled implementation of an operation from
operand position.

Generates the IR that generate the tiled implementation of an
Contributor

Nit :

Generates the IR that computes the tiled implementation of an operation based on
the operand tile

Contributor Author

DONE

@@ -74,6 +74,25 @@ def TilingInterface : OpInterface<"TilingInterface"> {
return {};
}]
>,
InterfaceMethod<
Contributor

Nit: Just a matter of reading. At least for me, fusing producer with consumer is more natural and is then easy to make the leap to fuse consumer with producer. So maybe position these methods textually after the corresponding methods that deal with the producer -> consumer fusion.

Contributor Author

DONE, moved methods after producer fusion

`getTiledImplementation` which generates the tiled
implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the position of the
Contributor

Nit:

This method generates a tiled implementation of the operation based on the tile of the operand required.

Contributor Author

DONE

@@ -132,6 +132,59 @@ struct LinalgOpTilingInterface
return TilingResult{{tiledOp}, SmallVector<Value>(tiledOp->getResults())};
}

void getMappedOffsetAndSize(Operation *op, OpBuilder &b,
Contributor

Nit: just make the operand LinalgOp linalgOp type instead of Operation *op. You can then do

auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());

Contributor Author

Thanks for your comments! It's crunch week for me now, I'll address your review comments ASAP.

Contributor Author

DONE

@MaheshRavishankar
Contributor

Hi, some of this was used in #88712. There are three things needed here:

  1. We need the interface methods.
  2. We need the base implementation that can fuse consumers for both scf.for and scf.forall (and hopefully unified under a LoopLikeOpInterface).
  3. We need the transform dialect ops.

To coordinate between these two PRs, can we first land the interface methods? The implementation of (2) here for just scf.forall has some of the same issues that I highlighted in the other PR. You can wait for that to land (and help contribute/review) and then add (3). So can we split this PR first to just land the interface methods?

@cxy-1993
Contributor Author

cxy-1993 commented Apr 17, 2024

> Hi, some of this was used in #88712. There are three things needed here:
>
>   1. We need the interface methods.
>   2. We need the base implementation that can fuse consumers for both scf.for and scf.forall (and hopefully unified under a LoopLikeOpInterface).
>   3. We need the transform dialect ops.
>
> To coordinate between these two PRs, can we first land the interface methods? The implementation of (2) here for just scf.forall has some of the same issues that I highlighted in the other PR. You can wait for that to land (and help contribute/review) and then add (3). So can we split this PR first to just land the interface methods?

Thank you for your review and reminders. I have already extracted the interface part separately and addressed your review comments accordingly. Could you please help review it again?

Member

@ftynse left a comment

This needs tests.

implementation of the operation given a tile of the
iteration space. This method generates a tiled
implementation of the operation based on the tile of the
operand required. This method enables fusion consumer by using
Member

Suggested change
- operand required. This method enables fusion consumer by using
+ operand required. This method enables consumer fusion by using

Contributor Author

done

"unsigned":$operandNumber,
"ArrayRef<OpFoldResult> ":$offsets,
"ArrayRef<OpFoldResult> ":$sizes,
"SmallVector<OpFoldResult> &":$iterDomainOffsets,
Member

I see that this is cargo-culted from the method above, but we should be using SmallVectorImpl to avoid hardcoding the implicit number of stack elements. SmallVector is in fact SmallVector<6> in this case.

Contributor Author

Agree. I changed parameter type to SmallVectorImpl as well as the code I cargo-culted from.
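
A toy model of the point made in this thread, assuming nothing from LLVM itself: `VecImpl`, `Vec`, and `fillZeroOffsets` are hypothetical names, and the storage is a plain `std::vector` standing in for an inline buffer. The sketch only shows why an output parameter typed as the size-erased base class accepts callers regardless of the inline element count they chose.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Size-erased base: owns the interface, knows nothing about the inline size.
template <typename T>
class VecImpl {
public:
  void push_back(const T &value) { storage.push_back(value); }
  std::size_t size() const { return storage.size(); }
  const T &operator[](std::size_t i) const { return storage[i]; }

protected:
  explicit VecImpl(std::size_t inlineCapacity) {
    storage.reserve(inlineCapacity);
  }

private:
  std::vector<T> storage; // Stand-in for inline buffer plus heap fallback.
};

// Derived class: only adds the inline-capacity template parameter.
template <typename T, std::size_t N>
class Vec : public VecImpl<T> {
public:
  Vec() : VecImpl<T>(N) {}
};

// Taking the base class by reference works for every choice of N, so the
// signature does not hardcode a particular inline size.
inline void fillZeroOffsets(VecImpl<long> &out, std::size_t rank) {
  for (std::size_t i = 0; i < rank; ++i)
    out.push_back(0);
}
```

Had `fillZeroOffsets` taken `Vec<long, 6> &` instead, a caller holding a `Vec<long, 2>` could not pass it without copying.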

Method to return the position of iteration domain tile computed by the
tiled operation.
}],
/*retType=*/"LogicalResult",
Member

Nit: prefer globally-qualified types in .td, ::mlir::LogicalResult. Generated code may be #included in another namespace.

Contributor Author

nice catch, done
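
The namespace concern in this thread can be reproduced with plain C++ name lookup. The sketch below uses toy types only: both `LogicalResult` structs and the `vendor` namespace are hypothetical stand-ins, not the real MLIR declarations. It shows how a relative `mlir::` can bind to a different namespace once the code sits inside another namespace that has its own nested `mlir`.

```cpp
#include <type_traits>

namespace mlir {
struct LogicalResult { bool succeeded; }; // toy stand-in
} // namespace mlir

namespace vendor {
namespace mlir {
struct LogicalResult { int code; }; // unrelated type with the same name
} // namespace mlir

// Relative lookup starts in `vendor` and finds vendor::mlir first.
using Relative = mlir::LogicalResult;
// A leading `::` always starts from the global namespace.
using Absolute = ::mlir::LogicalResult;
} // namespace vendor

static_assert(
    std::is_same<vendor::Relative, vendor::mlir::LogicalResult>::value,
    "relative name binds to the nested namespace");
static_assert(std::is_same<vendor::Absolute, ::mlir::LogicalResult>::value,
              "globally-qualified name binds to the intended type");
```

This is why spelling the return type as `::mlir::LogicalResult` in the .td file keeps the generated code unambiguous wherever it is included.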

ArrayRef<OpFoldResult> sizes,
SmallVector<OpFoldResult> &mappedOffsets,
SmallVector<OpFoldResult> &mappedSizes) const {
auto numLoops = linalgOp.getNumLoops();
Member

Nit: expand auto unless the type is obvious from line-level context, e.g., there is a cast such as below.

Contributor Author

done

Comment on lines 148 to 150
for (const auto &range : llvm::enumerate(iterationDomain)) {
mappedOffsets[range.index()] = range.value().offset;
mappedSizes[range.index()] = range.value().size;
Member

Suggested change
- for (const auto &range : llvm::enumerate(iterationDomain)) {
- mappedOffsets[range.index()] = range.value().offset;
- mappedSizes[range.index()] = range.value().size;
+ for (auto &&[index, value] : llvm::enumerate(iterationDomain)) {
+ mappedOffsets[index] = value.offset;
+ mappedSizes[index] = value.size;

Contributor Author

done

mappedSizes[range.index()] = range.value().size;
}
}
for (const auto &resultExpr : llvm::enumerate(indexingMap.getResults())) {
Member

Same as above.

Contributor Author

done

AffineMap indexingMap =
linalgOp.getMatchingIndexingMap(&op->getOpOperand(operandNumber));
if (!indexingMap.isProjectedPermutation()) {
return op->emitOpError(
Member

Nit: emitOpError is for errors that are related to the validity of the op, which this is not. Use emitError instead.

Contributor Author

done

@cxy-1993
Contributor Author

> This needs tests.

Thanks for your time. I have modified the code based on your comments (as well as the original code).

As for tests, I originally implemented a transform op in the linalg dialect and added tests for it. After the discussion above, we decided to land only the interface part for now. Tests for the interface methods require the transform op implementation, so can we land the interface first and add the tests together with the transform op later?

@MaheshRavishankar
Contributor

@Abhishek-Varma, please take a look at this PR and approve it if it looks fine to you; you can then land your PR on top of it.

This looks fine to me.

Contributor

@Abhishek-Varma left a comment

Thanks for this @cxy-1993 - this has been a really good nudge for fusing consumers!

A few comments to address before merging. LGTM otherwise.

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (outdated, resolved)
mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp (outdated, resolved)
mlir/include/mlir/Interfaces/TilingInterface.td (outdated, resolved)
@cxy-1993
Contributor Author

> Thanks for this @cxy-1993 - this has been a really good nudge for fusing consumers!
>
> A few comments to address before merging. LGTM otherwise.

Thanks for the comments, @Abhishek-Varma. The code has been modified accordingly. I don't have write access to this repo; please help me merge this patch into main.

This patch adds support for consumer fusion to the tiling interface.

- Add interface method 'getIterationDomainTileFromOperandTile' to the tiling
  interface, which gets the iteration domain position from the operand position.
- Add interface method 'getTiledImplementationFromOperandTile' to the tiling
  interface, which generates a tiled implementation according to the operand position.
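
For readers new to the idea, the mapping from an operand tile back to an iteration-domain tile can be sketched in plain C++. This is a simplified illustration that mirrors the loops quoted in the review above, not the code that landed. The `Tile` struct, the function name, and the encoding of a projected-permutation indexing map as a list of loop indices are all assumptions made for this example, and sizes are plain integers rather than `OpFoldResult` values.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical tile type: one offset and one size per dimension.
struct Tile {
  std::vector<std::int64_t> offsets;
  std::vector<std::int64_t> sizes;
};

// operandDimToLoop[i] is the loop dimension read by operand dimension i.
// For a projected permutation, every loop appears at most once.
// Returns false if the map is not a projected permutation or ranks mismatch.
bool iterationTileFromOperandTile(const std::vector<unsigned> &operandDimToLoop,
                                  const Tile &operandTile,
                                  const Tile &fullDomain, Tile &mapped) {
  const std::size_t rank = operandDimToLoop.size();
  if (operandTile.offsets.size() != rank || operandTile.sizes.size() != rank ||
      fullDomain.offsets.size() != fullDomain.sizes.size())
    return false;

  // Loops that the operand does not index keep their full range.
  Tile result = fullDomain;
  std::vector<bool> seen(fullDomain.offsets.size(), false);
  for (std::size_t i = 0; i < rank; ++i) {
    const unsigned loop = operandDimToLoop[i];
    if (loop >= seen.size() || seen[loop])
      return false; // out of range or repeated: not a projected permutation
    seen[loop] = true;
    result.offsets[loop] = operandTile.offsets[i];
    result.sizes[loop] = operandTile.sizes[i];
  }
  mapped = result;
  return true;
}
```

As a usage example, take a matmul-like operand indexed by `(m, n, k) -> (m, k)`, encoded here as `{0, 2}`. An operand tile with offsets `{4, 8}` and sizes `{2, 16}` inside a `32 x 64 x 128` iteration domain maps to offsets `{4, 0, 8}` and sizes `{2, 64, 16}`: the `n` loop is untouched by this operand, so it keeps its full range.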
@Abhishek-Varma Abhishek-Varma merged commit 2a47ee0 into llvm:main Apr 22, 2024
4 checks passed
@nicolasvasilache
Contributor

@Abhishek-Varma, @cxy-1993, @ftynse gave a clear review that this needs tests.

It is not acceptable to land randomly untested code!

Please revert, add tests and resend for review.

@MaheshRavishankar
Contributor

Definitely needs tests, but I didn't see a need for a revert. It should be fine to add tests after the fact. @cxy-1993 and @Abhishek-Varma, could we add some tests here?

@ftynse
Member

ftynse commented Apr 22, 2024

For practicality reasons, I'd prefer a revert and a new PR that includes tests as opposed to just adding tests. This will help us, the reviewers, ascertain that the tests are actually testing what they should be testing, there is enough coverage, etc. It also makes it easy to bisect and revert the change should the need arise.

@nicolasvasilache
Contributor

nicolasvasilache commented Apr 22, 2024

> Definitely needs tests, but I didn't see a need for revert. Should be fine to add tests after the fact. @cxy-1993 and @Abhishek-Varma could we add some tests here

yeah, let's instead be consistent with what we've been doing so far: https://mlir.llvm.org/getting_started/TestingGuide/

Additionally, I believe this should not have landed for the same reason I am objecting to #88712; I'll just copy-paste it here:

Now, I am afraid I cannot subscribe to any large code changes to tiling transforms until #77874 is addressed to my satisfaction. I already gave a pass 6 months ago and another one 3 months ago when things were supposed to be addressed "in short order".

We're now past tech-debt reduction time: first let's address [mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. #77874,

@MaheshRavishankar
Contributor

We can't block code progress because of that. I have been reviewing all the code and have been keeping track of all the changes. This is also a strict addition that isn't part of what needs to be undone from the legacy code.

It is still on my plate to address this. Everyone has time commitments that they have to balance. I will get to it in short order, but we can't hold things hostage because of that.

@stellaraccident
Contributor

stellaraccident commented Apr 22, 2024

> > Definitely needs tests, but I didn't see a need for revert. Should be fine to add tests after the fact. @cxy-1993 and @Abhishek-Varma could we add some tests here
>
> yeah, let's instead be consistent with what we've been doing so far: https://mlir.llvm.org/getting_started/TestingGuide/

Let's keep the bar high. Alex's suggestion is the right one: revert and land with appropriate tests.

> Additionally, I believe this should also not have landed for the same reason as I am objecting to #88712 I'll just copy paste here:
>
> Now, I am afraid I cannot subscribe to any large code changes to tiling transforms until #77874 is addressed to my satisfaction. I already gave a pass 6 months ago and another one 3 months ago when things were supposed to be addressed "in short order".
>
> We're now past tech-debt reduction time: first let's address [mlir][TilingInterface] Use LoopLikeOpInterface in tiling using SCF to unify tiling with scf.for and scf.forall. #77874,

Can we get a f2f to talk about this separately? I've had my eye on this for a long time too, and there is a significant hairball that (believe it or not) is being undone. It is just hard and taking more iteration than expected. (I imagine if we talked through it, we might be able to see a shortcut)

ftynse added a commit that referenced this pull request Apr 23, 2024
Reverts #85528. This was committed without tests,
despite reviewers requesting tests to be added. The post-commit
discussion leans towards revert, which would be consistent with the
policy.
@ftynse
Member

ftynse commented Apr 23, 2024

Reverted in f220c35. Please add tests, open a new PR, and request review from the same people.

@cxy-1993
Contributor Author

New PR: #89893. It only adds tests compared to this version. Please help review it, thanks. @MaheshRavishankar @ftynse @Abhishek-Varma @nicolasvasilache
