Skip to content

Commit

Permalink
[mlir][linalg] Make fusion on tensor rewriter friendly (NFC).
Browse files Browse the repository at this point in the history
Let the calling pass or pattern replace the uses of the original root operation. Internally, the tileAndFuse still replaces uses and updates operands but only of newly created operations.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110169
  • Loading branch information
Tobias Gysi committed Sep 27, 2021
1 parent d5629b5 commit e158b56
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
4 changes: 3 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -199,9 +199,11 @@ class TileLoopNest {

/// Fuse the producer of `rootOpOperand` into the tile loop nest. Returns the
/// fused producer of fails if fusion is not possible.
// TODO: add replace uses callback to support passes and patterns.
FailureOr<LinalgOp> fuseProducer(OpBuilder &b, OpOperand *rootOpOperand);

/// Returns the replacement results for the original untiled root operation.
ValueRange getRootOpReplacementResults();

/// Returns the tiled root operation.
LinalgOp getRootOp() { return rootOp; }

Expand Down
22 changes: 18 additions & 4 deletions mlir/lib/Dialect/Linalg/Transforms/FusionOnTensors.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,15 @@ LogicalResult TileLoopNest::tileRootOp(OpBuilder &b,
.setLoopType(LinalgTilingLoopType::Loops);
Optional<TiledLinalgOp> tiledRootOp = tileLinalgOp(b, rootOp, tilingOptions);

// Replace all uses of the root operation.
// Exit if tiling the root operation fails.
if (!tiledRootOp.hasValue())
return failure();
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);

// Replace all uses of the root operation if it has been tiled before. All
// uses of the original untiled root operation are updated by the calling pass
// or pattern.
if (!isEmpty())
rootOp->replaceAllUsesWith(tiledRootOp->tensorResults);

// Update the root operation and append the loops and tile loop dimensions.
rootOp = tiledRootOp->op;
Expand Down Expand Up @@ -323,6 +328,11 @@ FailureOr<LinalgOp> TileLoopNest::fuseProducer(OpBuilder &b,
return clonedOp;
}

ValueRange TileLoopNest::getRootOpReplacementResults() {
assert(!isEmpty() && "expect tile loop nest to be non-empty");
return loopOps.front()->getOpResults();
}

//===----------------------------------------------------------------------===//
// Tile and fuse entry-points.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -433,9 +443,13 @@ struct LinalgTileAndFuseTensorOps
"expect the tile interchange permutes the root loops");

// Tile `rootOp` and fuse its producers.
if (failed(tileConsumerAndFuseProducers(b, rootOp, rootTileSizes,
rootInterchange)))
FailureOr<TileLoopNest> tileLoopNest =
tileConsumerAndFuseProducers(b, rootOp, rootTileSizes, rootInterchange);
if (failed(tileLoopNest))
return notifyFailure("tileConsumerAndFuseProducers failed unexpectedly");

// Replace all uses of the tiled loop operation.
rootOp->replaceAllUsesWith(tileLoopNest->getRootOpReplacementResults());
}
};
} // namespace
Expand Down

0 comments on commit e158b56

Please sign in to comment.