278 changes: 196 additions & 82 deletions mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,12 @@
#include "mlir/Dialect/PDL/IR/PDLTypes.h"
#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/IR/TransformInterfaces.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Parser/Parser.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/ScopeExit.h"
#include "llvm/ADT/StringSet.h"

using namespace mlir;
Expand Down Expand Up @@ -226,78 +229,168 @@ LogicalResult transform::FuseOp::verify() {
// FuseIntoContainingOp
//===----------------------------------------------------------------------===//

static FailureOr<SmallVector<Operation *>> tileAndFuse(Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is now fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUse(Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
return failure();
return nullptr;

// Search the producer slices accessed within the containing operation.
// TODO: Generalize to more extract/insert/parallel_insert triples. Maybe
// evolve into an interface.
SmallVector<tensor::ExtractSliceOp> sliceOps;
for (Operation *user : tileableProducer->getUsers()) {
// TODO: Generalize to more extract/insert/parallel_insert triples.
// Maybe evolve into an interface.
auto it = llvm::find_if(tileableProducer->getUsers(), [&](Operation *user) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
if (!sliceOp)
continue;
if (!containingOp->isProperAncestor(sliceOp))
return sliceOp && containingOp->isProperAncestor(sliceOp);
});

// Check for a non-empty fusion opportunity.
if (it == tileableProducer->getUsers().end())
return nullptr;
auto sliceOpToTile = cast<tensor::ExtractSliceOp>(*it);

// Try to fuse the producer in-place.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOpToTile);

// Tile the producer.
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer))
return nullptr;

// Replace the extract op.
Operation *fusedOp = tiledProducer->getDefiningOp();
rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));
return fusedOp;
}

/// Find the first "extract" user of `producerOp` and tile it right before its
/// use. The tiled op is now fused under the `containingOp`.
/// Return this fused op on success or nullptr if anything fails.
static Operation *tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
Operation *producerOp, Operation *containingOp, RewriterBase &rewriter) {

auto foreachThreadOp = dyn_cast<scf::ForeachThreadOp>(containingOp);
if (!foreachThreadOp)
return nullptr;

auto tileableProducer = dyn_cast<TilingInterface>(producerOp);
if (!tileableProducer)
return nullptr;

// Search the producer slices accessed within the containing
// operation.
// TODO: Generalize to more extract/insert/parallel_insert triples.
// Maybe evolve into an interface.
OpOperand *pUse;
BlockArgument bbArg;
tensor::ExtractSliceOp sliceOpToTile;
// Only consider slices that may come from the containingOp args.
for (OpOperand &use : tileableProducer->getUses()) {
if (use.getOwner() != containingOp)
continue;
sliceOps.push_back(sliceOp);
pUse = &use;
bbArg = foreachThreadOp.getTiedBlockArgument(&use);
for (Operation *user : bbArg.getUsers()) {
auto sliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
if (!sliceOp)
continue;
if (!containingOp->isAncestor(sliceOp))
continue;
sliceOpToTile = sliceOp;
break;
}
if (sliceOpToTile)
break;
}

// Check for a non-empty list of fusion opportunities.
if (sliceOps.empty())
return failure();
if (!sliceOpToTile || !pUse)
return nullptr;

// Try to fuse the producer in-place.
SmallVector<Operation *> fusedOps;
for (tensor::ExtractSliceOp sliceOp : sliceOps) {
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOp);
// Ensure there is exactly one destination operand that we can replace the
// ForeachThreadOp bbArg with.
auto destinationOperands = tileableProducer.getDestinationOperands(rewriter);
if (destinationOperands.size() != 1)
return nullptr;

// Tile the producer.
FailureOr<Value> tiledProducer = tileableProducer.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOp.getMixedOffsets(),
sliceOp.getMixedSizes());
if (failed(tiledProducer))
return failure();
fusedOps.push_back(tiledProducer->getDefiningOp());
}
// Try to fuse the producer in-place.
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(sliceOpToTile);

// Replace the use in the tileableProducer before tiling, replace and then
// tile.
BlockAndValueMapping bvm;
bvm.map(destinationOperands.front(), bbArg);
auto tileableProducerClone =
cast<TilingInterface>(rewriter.clone(*tileableProducer, bvm));
auto scopeGuard =
llvm::make_scope_exit([&]() { rewriter.eraseOp(tileableProducerClone); });

// Tile the producer.
FailureOr<Value> tiledProducer =
tileableProducerClone.generateResultTileValue(
rewriter, /*resultNumber=*/0, sliceOpToTile.getMixedOffsets(),
sliceOpToTile.getMixedSizes());
if (failed(tiledProducer))
return nullptr;

// Replace the extract op.
for (const auto &en : enumerate(sliceOps))
rewriter.replaceOp(en.value(), fusedOps[en.index()]->getResult(0));
return fusedOps;
Operation *fusedOp = tiledProducer->getDefiningOp();
rewriter.replaceOp(sliceOpToTile, fusedOp->getResult(0));

// Replace the use in containingOp.
rewriter.startRootUpdate(fusedOp);
containingOp->setOperand(pUse->getOperandNumber(),
destinationOperands.front());
rewriter.finalizeRootUpdate(fusedOp);

return fusedOp;
}

static FailureOr<SmallVector<Operation *>>
cloneAndFuse(Operation *producerOp, Operation *containingOp,
RewriterBase &rewriter) {
static Operation *cloneAndFuseFirstUse(Operation *producerOp,
Operation *containingOp,
RewriterBase &rewriter) {
// Gather all uses inside the containing op.
SmallVector<OpOperand *> uses;
for (OpResult result : producerOp->getOpResults())
for (OpOperand &use : result.getUses())
if (containingOp->isProperAncestor(use.getOwner()))
for (OpResult result : producerOp->getOpResults()) {
for (OpOperand &use : result.getUses()) {
if (containingOp->isProperAncestor(use.getOwner())) {
uses.push_back(&use);
continue;
}
// Cannot clone and fuse if the use is fom the containing op itself: fail.
if (containingOp == use.getOwner())
return nullptr;
}
}

// Check for a non-empty list of fusion opportunities.
if (uses.empty())
return failure();
return nullptr;

// Clone and fuse inside the containing op.
SmallVector<Operation *> fusedOps;
Operation *fusedOp = nullptr;
for (OpOperand *use : uses) {
// Parallel insert slice is not a valid clone destination.
// TODO: Generalize to other type of ops.
assert(!isa<tensor::ParallelInsertSliceOp>(use->getOwner()) &&
"Parallel insert slice is not a valid clone destination");
unsigned resultNumber = use->get().cast<OpResult>().getResultNumber();
OpBuilder::InsertionGuard guard(rewriter);
rewriter.setInsertionPoint(use->getOwner());
Operation *cloned = rewriter.clone(*producerOp);
fusedOp = rewriter.clone(*producerOp);
rewriter.updateRootInPlace(
use->getOwner(), [&] { use->set(cloned->getOpResult(resultNumber)); });
fusedOps.push_back(cloned);
use->getOwner(), [&] { use->set(fusedOp->getOpResult(resultNumber)); });
break;
}

return fusedOps;
return fusedOp;
}

DiagnosedSilenceableFailure
Expand All @@ -312,7 +405,7 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
}
for (Operation *producerOp : producerOps) {
if (producerOp->getNumResults() != 1) {
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "op with != 1 results not supported";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}
Expand All @@ -331,15 +424,17 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
auto getNextProducer = [&]() -> FailureOr<Operation *> {
for (const auto &it : enumerate(remainingProducers)) {
Operation *producerOp = it.value();
bool hasUseInContainingOp =
any_of(producerOp->getUsers(), [&](Operation *op) {
return containingOp->isProperAncestor(op);
// The containing op may be a user of producerOp: use isAncestor.
int64_t numUsesInContainingOp =
llvm::count_if(producerOp->getUsers(), [&](Operation *op) {
return containingOp->isAncestor(op);
});
// TODO: When resolving the TODO below (no duplicate ops), take an op that
// has no use among the remaining producers. This is a topological
// TODO: When resolving the TODO below (no duplicate ops), take an op
// that has no use among the remaining producers. This is a topological
// sorting.
if (hasUseInContainingOp) {
remainingProducers.erase(remainingProducers.begin() + it.index());
if (numUsesInContainingOp > 0) {
if (numUsesInContainingOp == 1)
remainingProducers.erase(remainingProducers.begin() + it.index());
return producerOp;
}
}
Expand All @@ -350,29 +445,42 @@ transform::FuseIntoContainingOp::apply(transform::TransformResults &results,
while (!remainingProducers.empty()) {
auto nextProducer = getNextProducer();
if (failed(nextProducer)) {
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Note);
Diagnostic diag(containingOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse ops into container";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}

Operation *producerOp = *nextProducer;
// TODO: If there are multiple uses of the producer in the containing op, we
// currently tile/clone the op multiple times (once per use). In some cases,
// we can tile/clone once and reuse the value for each use. Futhermore,
// producers should then be traversed according to a topological sorting.
auto tiled = tileAndFuse(producerOp, containingOp, rewriter);
if (succeeded(tiled))
fusedOps.append(*tiled);

auto cloned = cloneAndFuse(producerOp, containingOp, rewriter);
if (succeeded(cloned))
fusedOps.append(*cloned);

if (failed(tiled) && failed(cloned)) {
Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Note);
diag << "could not fuse into containing op";
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
// TODO: If there are multiple uses of the producer in the containing op,
// we currently tile/clone the op multiple times (once per use). In some
// cases, we can tile/clone once and reuse the value for each use.
// Futhermore, producers should then be traversed according to a
// topological sorting.
Operation *tiled =
tileAndFuseFirstExtractUse(producerOp, containingOp, rewriter);
if (tiled) {
fusedOps.push_back(tiled);
continue;
}

Operation *tiledContainingOpOperand =
tileAndFuseFirstExtractUseThroughContainingOpBlockArgument(
producerOp, containingOp, rewriter);
if (tiledContainingOpOperand) {
fusedOps.push_back(tiledContainingOpOperand);
continue;
}

Operation *cloned =
cloneAndFuseFirstUse(producerOp, containingOp, rewriter);
if (cloned) {
fusedOps.push_back(cloned);
continue;
}

Diagnostic diag(producerOp->getLoc(), DiagnosticSeverity::Remark);
diag << "could not fuse " << *producerOp << "into " << *containingOp;
return DiagnosedSilenceableFailure::silenceableFailure(std::move(diag));
}

results.set(getFusedOp().cast<OpResult>(), fusedOps);
Expand Down Expand Up @@ -458,18 +566,18 @@ transform::MatchOp::apply(transform::TransformResults &results,
SmallVector<Operation *> res;
auto matchFun = [&](Operation *op) {
if (getOps().has_value() && !strs.contains(op->getName().getStringRef()))
return WalkResult::advance();
return;

// Interfaces cannot be matched by name, just by ID.
// So we specifically encode the interfaces we care about for this op.
if (getInterface().has_value()) {
auto iface = getInterface().value();
if (iface == transform::MatchInterfaceEnum::LinalgOp &&
!isa<linalg::LinalgOp>(op))
return WalkResult::advance();
return;
if (iface == transform::MatchInterfaceEnum::TilingInterface &&
isa<TilingInterface>(op))
return WalkResult::advance();
return;
}

// Check if all specified attributes match.
Expand All @@ -480,15 +588,21 @@ transform::MatchOp::apply(transform::TransformResults &results,
attr.getName() == getOpsAttrName())
continue;
if (!op->hasAttr(attr.getName()))
return WalkResult::advance();
return;
if (op->getAttr(attr.getName()) != attr.getValue())
return WalkResult::advance();
return;
}
}

if (getFilterResultType().has_value()) {
Type t = getFilterResultType().value();
if (op->getNumResults() != 1 || op->getResultTypes().front() != t)
return;
}

// All constraints are satisfied.
res.push_back(op);
return WalkResult::advance();
return;
};

payloadOps.front()->walk(matchFun);
Expand Down Expand Up @@ -620,9 +734,9 @@ LogicalResult transform::PadOp::verify() {
extractFromI64ArrayAttr(getPaddingDimensions());
if (any_of(paddingDimensions,
[](int64_t paddingDimension) { return paddingDimension < 0; })) {
return emitOpError()
<< "expects padding_dimensions to contain positive integers, found "
<< getPaddingDimensions();
return emitOpError() << "expects padding_dimensions to contain positive "
"integers, found "
<< getPaddingDimensions();
}

SmallVector<int64_t> hoistPaddings =
Expand Down Expand Up @@ -693,8 +807,8 @@ transform::ScalarizeOp::applyToOne(linalg::LinalgOp target,
transform::TransformState &state) {
LinalgTilingOptions tilingOptions;
tilingOptions.scalarizeDynamicDims();
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
// sizes and asserts that it is not already set.
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the
// tile sizes and asserts that it is not already set.
SmallVector<int64_t> emptyTileSizes;
LinalgTilingPattern pattern(getContext(), tilingOptions);
SimpleRewriter rewriter(getContext());
Expand Down Expand Up @@ -841,8 +955,8 @@ LogicalResult SplitOp::verify() {
if ((static_cast<int64_t>(getStaticSplitPoint()) !=
ShapedType::kDynamicSize) ^
(getDynamicSplitPoint() == nullptr)) {
return emitOpError()
<< "expects either a dynamic or a static split point to be provided";
return emitOpError() << "expects either a dynamic or a static split "
"point to be provided";
}
return success();
}
Expand Down Expand Up @@ -1196,8 +1310,8 @@ transform::VectorizeOp::applyToOne(Operation *target,
//===----------------------------------------------------------------------===//

namespace {
/// Registers new ops and declares PDL as dependent dialect since the additional
/// ops are using PDL types for operands and results.
/// Registers new ops and declares PDL as dependent dialect since the
/// additional ops are using PDL types for operands and results.
class LinalgTransformDialectExtension
: public transform::TransformDialectExtension<
LinalgTransformDialectExtension> {
Expand Down
20 changes: 20 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-match.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,26 @@ transform.with_pdl_patterns {

// -----

func.func @by_type() {
%0 = arith.constant 0: i32
// expected-remark @below {{matched op name}}
%1 = arith.constant 1.0 : f32
return
}

transform.with_pdl_patterns {
^bb0(%arg0: !pdl.operation):
transform.sequence %arg0 failures(propagate) {
^bb1(%arg1: !pdl.operation):
%match_name = transform.structured.match
ops{["arith.constant"]} filter_result_type = f32 in %arg1
transform.test_print_remark_at_operand %match_name, "matched op name"
transform.test_consume_operand %match_name
}
}

// -----

#map0 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d0, d2)>
func.func @match_complex_attribute(%arg0: tensor<12x128x32xf32>)
Expand Down