Skip to content

Commit

Permalink
Add tracing for pattern application in a ApplyPatternAction
Browse files Browse the repository at this point in the history
Reviewed By: rriddle

Differential Revision: https://reviews.llvm.org/D144816
  • Loading branch information
joker-eph committed Apr 11, 2023
1 parent 84eed78 commit e24b91b
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 26 deletions.
22 changes: 22 additions & 0 deletions mlir/include/mlir/Rewrite/PatternApplicator.h
Expand Up @@ -16,13 +16,35 @@

#include "mlir/Rewrite/FrozenRewritePatternSet.h"

#include "mlir/IR/Action.h"

namespace mlir {
class PatternRewriter;

namespace detail {
class PDLByteCodeMutableState;
} // namespace detail

/// This is the type of Action that is dispatched when a pattern is applied.
/// It captures the pattern to apply on top of the usual context.
class ApplyPatternAction : public tracing::ActionImpl<ApplyPatternAction> {
public:
using Base = tracing::ActionImpl<ApplyPatternAction>;
ApplyPatternAction(ArrayRef<IRUnit> irUnits, const Pattern &pattern)
: Base(irUnits), pattern(pattern) {}
static constexpr StringLiteral tag = "apply-pattern-action";
static constexpr StringLiteral desc =
"Encapsulate the application of rewrite patterns";

void print(raw_ostream &os) const override {
os << "`" << tag << "`\n"
<< " pattern: " << pattern.getDebugName() << '\n';
}

private:
const Pattern &pattern;
};

/// This class manages the application of a group of rewrite patterns, with a
/// user-provided cost model.
class PatternApplicator {
Expand Down
64 changes: 38 additions & 26 deletions mlir/lib/Rewrite/PatternApplicator.cpp
Expand Up @@ -185,35 +185,47 @@ LogicalResult PatternApplicator::matchAndRewrite(
// Try to match and rewrite this pattern. The patterns are sorted by
// benefit, so if we match we can immediately rewrite. For PDL patterns, the
// match has already been performed, we just need to rewrite.
rewriter.setInsertionPoint(op);
bool matched = false;
op->getContext()->executeAction<ApplyPatternAction>(
[&]() {
rewriter.setInsertionPoint(op);
#ifndef NDEBUG
// Operation `op` may be invalidated after applying the rewrite pattern.
Operation *dumpRootOp = getDumpRootOp(op);
// Operation `op` may be invalidated after applying the rewrite
// pattern.
Operation *dumpRootOp = getDumpRootOp(op);
#endif
if (pdlMatch) {
result = bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
} else {
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
<< bestPattern->getDebugName() << "\"\n");

const auto *pattern = static_cast<const RewritePattern *>(bestPattern);
result = pattern->matchAndRewrite(op, rewriter);

LLVM_DEBUG(llvm::dbgs() << "\"" << bestPattern->getDebugName()
<< "\" result " << succeeded(result) << "\n");
}

// Process the result of the pattern application.
if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
result = failure();
if (succeeded(result)) {
LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
if (pdlMatch) {
result =
bytecode->rewrite(rewriter, *pdlMatch, *mutableByteCodeState);
} else {
LLVM_DEBUG(llvm::dbgs() << "Trying to match \""
<< bestPattern->getDebugName() << "\"\n");

const auto *pattern =
static_cast<const RewritePattern *>(bestPattern);
result = pattern->matchAndRewrite(op, rewriter);

LLVM_DEBUG(llvm::dbgs()
<< "\"" << bestPattern->getDebugName() << "\" result "
<< succeeded(result) << "\n");
}

// Process the result of the pattern application.
if (succeeded(result) && onSuccess && failed(onSuccess(*bestPattern)))
result = failure();
if (succeeded(result)) {
LLVM_DEBUG(logSucessfulPatternApplication(dumpRootOp));
matched = true;
return;
}

// Perform any necessary cleanups.
if (onFailure)
onFailure(*bestPattern);
},
{op}, *bestPattern);
if (matched)
break;
}

// Perform any necessary cleanups.
if (onFailure)
onFailure(*bestPattern);
} while (true);

if (mutableByteCodeState)
Expand Down

0 comments on commit e24b91b

Please sign in to comment.