Expand Up
@@ -17,9 +17,12 @@
#include " mlir/Dialect/PDL/IR/PDLTypes.h"
#include " mlir/Dialect/Transform/IR/TransformDialect.h"
#include " mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include " mlir/IR/BlockAndValueMapping.h"
#include " mlir/Interfaces/TilingInterface.h"
#include " mlir/Parser/Parser.h"
#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
#include " llvm/ADT/STLExtras.h"
#include " llvm/ADT/ScopeExit.h"
#include " llvm/ADT/StringSet.h"
using namespace mlir ;
Expand Down
Expand Up
@@ -226,78 +229,168 @@ LogicalResult transform::FuseOp::verify() {
// FuseIntoContainingOp
// ===----------------------------------------------------------------------===//
static FailureOr<SmallVector<Operation *>> tileAndFuse (Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
// / Find the first "extract" user of `producerOp` and tile it right before its
// / use. The tiled op is now fused under the `containingOp`.
// / Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUse (Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
return failure () ;
return nullptr ;
// Search the producer slices accessed within the containing operation.
// TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
// evolve into an interface.
SmallVector<tensor::ExtractSliceOp> sliceOps;
for (Operation *user : tileableProducer->getUsers ()) {
// TODO: Generalize to more extract/insert/parallel_insert triples.
// Maybe evolve into an interface.
auto it = llvm::find_if (tileableProducer->getUsers (), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
if (!sliceOp)
continue ;
if (!containingOp->isProperAncestor (sliceOp))
return sliceOp && containingOp->isProperAncestor (sliceOp);
});
// Check for a non-empty fusion opportunity.
if (it == tileableProducer->getUsers ().end ())
return nullptr ;
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);
// Try to fuse the producer in-place.
OpBuilder::InsertionGuard guard (rewriter);
rewriter.setInsertionPoint (sliceOpToTile);
// Tile the producer.
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue (
rewriter, /* resultNumber=*/ 0 , sliceOpToTile.getMixedOffsets (),
sliceOpToTile.getMixedSizes ());
if (failed (tiledProducer))
return nullptr ;
// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp ();
rewriter.replaceOp (sliceOpToTile, fusedOp->getResult (0 ));
return fusedOp;
}
/// Tile and fuse `producerOp` when it is consumed as an operand of
/// `containingOp` itself and read inside the body through the tied block
/// argument.
///
/// `containingOp` must be an `scf::ForeachThreadOp`. The function looks for
/// the first use of the producer owned by `containingOp` whose tied block
/// argument has a `tensor.extract_slice` user inside `containingOp`. It then:
///   1. clones the producer right before that slice, with the producer's
///      single destination operand remapped to the block argument;
///   2. tiles the clone to the slice's offsets/sizes and replaces the slice
///      with the tiled result;
///   3. rewires the `containingOp` operand to the producer's destination
///      operand, so the containing op no longer uses the producer's result.
/// The temporary clone is erased when the function returns.
///
/// Returns the op defining the tiled value on success. Returns nullptr if
/// `containingOp` is not an `scf::ForeachThreadOp`, if `producerOp` does not
/// implement TilingInterface, if no suitable slice is found, if the producer
/// does not have exactly one destination operand, or if tiling fails.
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
    Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {
  auto foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(containingOp);
  if (!foreachThreadOp)
    return nullptr;

  auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
  if (!tileableProducer)
    return nullptr;

  // Search the producer slices accessed within the containing
  // operation.
  // TODO: Generalize to more extract/insert/parallel_insert triples.
  // Maybe evolve into an interface.
  // `pUse` must start out null: it is tested below even when no use of the
  // producer belongs to `containingOp`.
  OpOperand *pUse = nullptr;
  BlockArgument bbArg;
  tensor::ExtractSliceOp sliceOpToTile;

  // Only consider slices that may come from the containingOp args.
  for (OpOperand &use : tileableProducer->getUses()) {
    if (use.getOwner() != containingOp)
      continue;
    pUse = &use;
    bbArg = foreachThreadOp.getTiedBlockArgument(&use);
    for (Operation *user : bbArg.getUsers()) {
      auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
      if (!sliceOp)
        continue;
      if (!containingOp->isAncestor(sliceOp))
        continue;
      sliceOpToTile = sliceOp;
      break;
    }
    if (sliceOpToTile)
      break;
  }

  // Check for a non-empty fusion opportunity.
  if (!sliceOpToTile || !pUse)
    return nullptr;

  // Ensure there is exactly one destination operand that we can replace the
  // ForeachThreadOp bbArg with.
  auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
  if (destinationOperands.size() != 1)
    return nullptr;

  // Try to fuse the producer in-place: the guard restores the rewriter's
  // insertion point when this function returns.
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(sliceOpToTile);

  // Clone the producer with its destination operand remapped to the block
  // argument, then tile the clone instead of the original producer.
  BlockAndValueMapping bvm;
  bvm.map(destinationOperands.front(), bbArg);
  auto tileableProducerClone =
      cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
  // The clone only exists to be tiled; erase it on every exit path.
  auto scopeGuard =
      llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

  // Tile the producer.
  FailureOr<Value> tiledProducer =
      tileableProducerClone.generateResultTileValue(
          rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
          sliceOpToTile.getMixedSizes());
  if (failed(tiledProducer))
    return nullptr;

  // Replace the extract op.
  Operation *fusedOp = tiledProducer->getDefiningOp();
  rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));

  // Replace the use in containingOp. The op being modified in place is
  // `containingOp`, so that is the op the rewriter is notified about.
  rewriter.updateRootInPlace(containingOp, [&]() {
    containingOp->setOperand(pUse->getOperandNumber(),
                             destinationOperands.front());
  });

  return fusedOp;
}
static FailureOr<SmallVector< Operation *>>
cloneAndFuse (Operation *producerOp, Operation *containingOp,
RewriterBase &rewriter) {
static Operation *cloneAndFuseFirstUse (Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
for (OpResult result : producerOp->getOpResults ())
for (OpOperand &use : result.getUses ())
if (containingOp->isProperAncestor (use.getOwner ()))
for (OpResult result : producerOp->getOpResults ()) {
for (OpOperand &use : result.getUses ()) {
if (containingOp->isProperAncestor (use.getOwner ())) {
uses.push_back (&use);
continue ;
}
// Cannot clone and fuse if the use is fom the containing op itself: fail.
if (containingOp == use.getOwner ())
return nullptr ;
}
}
// Check for a non-empty list of fusion opportunities.
if (uses.empty ())
return failure () ;
return nullptr ;
// Clone and fuse inside the containing op.
SmallVector< Operation *> fusedOps ;
Operation *fusedOp = nullptr ;
for (OpOperand *use : uses) {
// Parallel insert slice is not a valid clone destination.
// TODO: Generalize to other type of ops.
assert (!isa<tensor::ParallelInsertSliceOp>(use->getOwner ()) &&
" Parallel insert slice is not a valid clone destination" );
unsigned resultNumber = use->get ().cast <OpResult>().getResultNumber ();
OpBuilder::InsertionGuard guard (rewriter);
rewriter.setInsertionPoint (use->getOwner ());
Operation *cloned = rewriter.clone (*producerOp);
fusedOp = rewriter.clone (*producerOp);
rewriter.updateRootInPlace (
use->getOwner (), [&] { use->set (cloned ->getOpResult (resultNumber)); });
fusedOps. push_back (cloned) ;
use->getOwner (), [&] { use->set (fusedOp ->getOpResult (resultNumber)); });
break ;
}
return fusedOps ;
return fusedOp ;
}
DiagnosedSilenceableFailure
Expand All
@@ -312,7 +405,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
}
for (Operation *producerOp : producerOps) {
if (producerOp->getNumResults () != 1 ) {
Diagnostic diag (producerOp->getLoc (), DiagnosticSeverity::Note );
Diagnostic diag (producerOp->getLoc (), DiagnosticSeverity::Remark );
diag << " op with != 1 results not supported" ;
return DiagnosedSilenceableFailure::silenceableFailure (std::move (diag));
}
Expand All
@@ -331,15 +424,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value ();
bool hasUseInContainingOp =
any_of (producerOp->getUsers (), [&](Operation *op) {
return containingOp->isProperAncestor (op);
// The containing op may be a user of producerOp: use isAncestor.
int64_t numUsesInContainingOp =
llvm::count_if (producerOp->getUsers (), [&](Operation *op) {
return containingOp->isAncestor (op);
});
// TODO: When resolving the TODO below (no duplicate ops), take an op that
// has no use among the remaining producers. This is a topological
// TODO: When resolving the TODO below (no duplicate ops), take an op
// that has no use among the remaining producers. This is a topological
// sorting.
if (hasUseInContainingOp) {
remainingProducers.erase (remainingProducers.begin () + it.index ());
if (numUsesInContainingOp > 0 ) {
if (numUsesInContainingOp == 1 )
remainingProducers.erase (remainingProducers.begin () + it.index ());
return producerOp;
}
}
Expand All
@@ -350,29 +445,42 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
while (!remainingProducers.empty ()) {
auto nextProducer = getNextProducer ();
if (failed (nextProducer)) {
Diagnostic diag (containingOp->getLoc (), DiagnosticSeverity::Note );
Diagnostic diag (containingOp->getLoc (), DiagnosticSeverity::Remark );
diag << " could not fuse ops into container" ;
return DiagnosedSilenceableFailure::silenceableFailure (std::move (diag));
}
Operation *producerOp = *nextProducer;
// TODO: If there are multiple uses of the producer in the containing op, we
// currently tile/clone the op multiple times (once per use). In some cases,
// we can tile/clone once and reuse the value for each use. Futhermore,
// producers should then be traversed according to a topological sorting.
auto tiled = tileAndFuse (producerOp, containingOp, rewriter);
if (succeeded (tiled))
fusedOps.append (*tiled);
auto cloned = cloneAndFuse (producerOp, containingOp, rewriter);
if (succeeded (cloned))
fusedOps.append (*cloned);
if (failed (tiled) && failed (cloned)) {
Diagnostic diag (producerOp->getLoc (), DiagnosticSeverity::Note);
diag << " could not fuse into containing op" ;
return DiagnosedSilenceableFailure::silenceableFailure (std::move (diag));
// TODO: If there are multiple uses of the producer in the containing op,
// we currently tile/clone the op multiple times (once per use). In some
// cases, we can tile/clone once and reuse the value for each use.
// Furthermore, producers should then be traversed according to a
// topological sorting.
Operation *tiled =
tileAndFuseFirstExtractUse (producerOp, containingOp, rewriter);
if (tiled) {
fusedOps.push_back (tiled);
continue ;
}
Operation *tiledContainingOpOperand =
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument (
producerOp, containingOp, rewriter);
if (tiledContainingOpOperand) {
fusedOps.push_back (tiledContainingOpOperand);
continue ;
}
Operation *cloned =
cloneAndFuseFirstUse (producerOp, containingOp, rewriter);
if (cloned) {
fusedOps.push_back (cloned);
continue ;
}
Diagnostic diag (producerOp->getLoc (), DiagnosticSeverity::Remark);
diag << " could not fuse " << *producerOp << " into " << *containingOp;
return DiagnosedSilenceableFailure::silenceableFailure (std::move (diag));
}
results.set (getFusedOp ().cast <OpResult>(), fusedOps);
Expand Down
Expand Up
@@ -458,18 +566,18 @@ transform::MatchOp::apply(transform::TransformResults &results,
SmallVector<Operation *> res;
auto matchFun = [&](Operation *op) {
if (getOps ().has_value () && !strs.contains (op->getName ().getStringRef ()))
return WalkResult::advance () ;
return ;
// Interfaces cannot be matched by name, just by ID.
// So we specifically encode the interfaces we care about for this op.
if (getInterface ().has_value ()) {
auto iface = getInterface ().value ();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
!isa<linalg::LinalgOp>(op))
return WalkResult::advance () ;
return ;
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
isa<TilingInterface>(op))
return WalkResult::advance () ;
return ;
}
// Check if all specified attributes match.
Expand All
@@ -480,15 +588,21 @@ transform::MatchOp::apply(transform::TransformResults &results,
attr.getName () == getOpsAttrName ())
continue ;
if (!op->hasAttr (attr.getName ()))
return WalkResult::advance () ;
return ;
if (op->getAttr (attr.getName ()) != attr.getValue ())
return WalkResult::advance () ;
return ;
}
}
if (getFilterResultType ().has_value ()) {
Type t = getFilterResultType ().value ();
if (op->getNumResults () != 1 || op->getResultTypes ().front () != t)
return ;
}
// All constraints are satisfied.
res.push_back (op);
return WalkResult::advance () ;
return ;
};
payloadOps.front ()->walk (matchFun);
Expand Down
Expand Up
@@ -620,9 +734,9 @@ LogicalResult transform::PadOp::verify() {
extractFromI64ArrayAttr (getPaddingDimensions ());
if (any_of (paddingDimensions,
[](int64_t paddingDimension) { return paddingDimension < 0 ; })) {
return emitOpError ()
<< " expects padding_dimensions to contain positive integers, found "
<< getPaddingDimensions ();
return emitOpError () << " expects padding_dimensions to contain positive "
" integers, found "
<< getPaddingDimensions ();
}
SmallVector<int64_t > hoistPaddings =
Expand Down
Expand Up
@@ -693,8 +807,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
transform::TransformState &state) {
LinalgTilingOptions tilingOptions;
tilingOptions.scalarizeDynamicDims ();
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
// sizes and asserts that it is not already set.
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
// tile sizes and asserts that it is not already set.
SmallVector<int64_t > emptyTileSizes;
LinalgTilingPattern pattern (getContext (), tilingOptions);
SimpleRewriter rewriter (getContext ());
Expand Down
Expand Up
@@ -841,8 +955,8 @@ LogicalResult SplitOp::verify() {
if ((static_cast <int64_t >(getStaticSplitPoint ()) !=
ShapedType::kDynamicSize ) ^
(getDynamicSplitPoint () == nullptr )) {
return emitOpError ()
<< " expects either a dynamic or a static split point to be provided" ;
return emitOpError () << " expects either a dynamic or a static split "
" point to be provided" ;
}
return success ();
}
Expand Down
Expand Up
@@ -1196,8 +1310,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
// ===----------------------------------------------------------------------===//
namespace {
// / Registers new ops and declares PDL as dependent dialect since the additional
// / ops are using PDL types for operands and results.
// / Registers new ops and declares PDL as dependent dialect since the
// / additional ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
Expand Down