-
Notifications
You must be signed in to change notification settings - Fork 10.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig
#82250
[mlir][Transforms] Encapsulate dialect conversion options in ConversionConfig
#82250
Conversation
@llvm/pr-subscribers-mlir-core @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThis commit adds a new A few existing options are moved to this objects, simplifying the dialect conversion API. Patch is 20.73 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/82250.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 5c91a9498b35d4..8da5dcb0be3fd0 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -24,6 +24,7 @@ namespace mlir {
// Forward declarations.
class Attribute;
class Block;
+struct ConversionConfig;
class ConversionPatternRewriter;
class MLIRContext;
class Operation;
@@ -770,7 +771,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
/// Conversion pattern rewriters must not be used outside of dialect
/// conversions. They apply some IR rewrites in a delayed fashion and could
/// bring the IR into an inconsistent state when used standalone.
- explicit ConversionPatternRewriter(MLIRContext *ctx);
+ explicit ConversionPatternRewriter(MLIRContext *ctx,
+ const ConversionConfig &config);
// Hide unsupported pattern rewriter API.
using OpBuilder::setListener;
@@ -1070,6 +1072,30 @@ class PDLConversionConfig final {
#endif // MLIR_ENABLE_PDL_IN_PATTERNMATCH
+//===----------------------------------------------------------------------===//
+// ConversionConfig
+//===----------------------------------------------------------------------===//
+
+/// Dialect conversion configuration.
+struct ConversionConfig {
+ /// An optional callback used to notify about match failure diagnostics during
+ /// the conversion. Diagnostics are only reported to this callback may only be
+ /// available in debug mode.
+ function_ref<void(Diagnostic &)> notifyCallback = nullptr;
+
+ /// Partial conversion only. All operations that are found not to be
+ /// legalizable are placed in this set. (Note that if there is an op
+ /// explicitly marked as illegal, the conversion terminates and the set will
+ /// not necessarily be complete.)
+ DenseSet<Operation *> *unlegalizedOps = nullptr;
+
+ /// Analysis conversion only. All operations that are found to be legalizable
+ /// are placed in this set. Note that no actual rewrites are applied to the
+ /// IR during an analysis conversion and only pre-existing operations are
+ /// added to the set.
+ DenseSet<Operation *> *legalizableOps = nullptr;
+};
+
//===----------------------------------------------------------------------===//
// Op Conversion Entry Points
//===----------------------------------------------------------------------===//
@@ -1083,19 +1109,16 @@ class PDLConversionConfig final {
/// Apply a partial conversion on the given operations and all nested
/// operations. This method converts as many operations to the target as
/// possible, ignoring operations that failed to legalize. This method only
-/// returns failure if there ops explicitly marked as illegal. If an
-/// `unconvertedOps` set is provided, all operations that are found not to be
-/// legalizable to the given `target` are placed within that set. (Note that if
-/// there is an op explicitly marked as illegal, the conversion terminates and
-/// the `unconvertedOps` set will not necessarily be complete.)
+/// returns failure if there ops explicitly marked as illegal.
LogicalResult
-applyPartialConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
+applyPartialConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps = nullptr);
+ ConversionConfig config = ConversionConfig());
LogicalResult
applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps = nullptr);
+ ConversionConfig config = ConversionConfig());
/// Apply a complete conversion on the given operations, and all nested
/// operations. This method returns failure if the conversion of any operation
@@ -1103,31 +1126,27 @@ applyPartialConversion(Operation *op, const ConversionTarget &target,
/// within 'ops'.
LogicalResult applyFullConversion(ArrayRef<Operation *> ops,
const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
LogicalResult applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns);
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
/// Apply an analysis conversion on the given operations, and all nested
/// operations. This method analyzes which operations would be successfully
/// converted to the target if a conversion was applied. All operations that
/// were found to be legalizable to the given 'target' are placed within the
-/// provided 'convertedOps' set; note that no actual rewrites are applied to the
-/// operations on success and only pre-existing operations are added to the set.
-/// This method only returns failure if there are unreachable blocks in any of
-/// the regions nested within 'ops'. There's an additional argument
-/// `notifyCallback` which is used for collecting match failure diagnostics
-/// generated during the conversion. Diagnostics are only reported to this
-/// callback may only be available in debug mode.
-LogicalResult applyAnalysisConversion(
- ArrayRef<Operation *> ops, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
-LogicalResult applyAnalysisConversion(
- Operation *op, ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+/// provided 'config.legalizableOps' set; note that no actual rewrites are
+/// applied to the operations on success. This method only returns failure if
+/// there are unreachable blocks in any of the regions nested within 'ops'.
+LogicalResult
+applyAnalysisConversion(ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
+LogicalResult
+applyAnalysisConversion(Operation *op, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config = ConversionConfig());
} // namespace mlir
#endif // MLIR_TRANSFORMS_DIALECTCONVERSION_H_
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 6cf178e149be7f..30fc2298b3deb3 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -224,6 +224,8 @@ class IRRewrite {
/// Erase the given block (unless it was already erased).
void eraseBlock(Block *block);
+ const ConversionConfig &getConfig() const;
+
const Kind kind;
ConversionPatternRewriterImpl &rewriterImpl;
};
@@ -723,9 +725,10 @@ static RewriteTy *findSingleRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter)
+ explicit ConversionPatternRewriterImpl(PatternRewriter &rewriter,
+ const ConversionConfig &config)
: rewriter(rewriter), eraseRewriter(rewriter.getContext()),
- notifyCallback(nullptr) {}
+ config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -931,10 +934,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// converting the arguments of blocks within that region.
DenseMap<Region *, const TypeConverter *> regionToConverter;
- /// This allows the user to collect the match failure message.
- function_ref<void(Diagnostic &)> notifyCallback;
-
- DenseSet<Operation *> *trackedOps = nullptr;
+ /// Dialect conversion configuration.
+ const ConversionConfig &config;
#ifndef NDEBUG
/// A set of operations that have pending updates. This tracking isn't
@@ -957,6 +958,10 @@ void IRRewrite::eraseBlock(Block *block) {
rewriterImpl.eraseRewriter.eraseBlock(block);
}
+const ConversionConfig &IRRewrite::getConfig() const {
+ return rewriterImpl.config;
+}
+
void BlockTypeConversionRewrite::commit() {
// Process the remapping for each of the original arguments.
for (unsigned i = 0, e = origBlock->getNumArguments(); i != e; ++i) {
@@ -1074,8 +1079,8 @@ void ReplaceOperationRewrite::commit() {
if (Value newValue =
rewriterImpl.mapping.lookupOrNull(result, result.getType()))
result.replaceAllUsesWith(newValue);
- if (rewriterImpl.trackedOps)
- rewriterImpl.trackedOps->erase(op);
+ if (getConfig().unlegalizedOps)
+ getConfig().unlegalizedOps->erase(op);
// Do not erase the operation yet. It may still be referenced in `mapping`.
op->getBlock()->getOperations().remove(op);
}
@@ -1510,8 +1515,8 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
Diagnostic diag(loc, DiagnosticSeverity::Remark);
reasonCallback(diag);
logger.startLine() << "** Failure : " << diag.str() << "\n";
- if (notifyCallback)
- notifyCallback(diag);
+ if (config.notifyCallback)
+ config.notifyCallback(diag);
});
}
@@ -1519,9 +1524,10 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
// ConversionPatternRewriter
//===----------------------------------------------------------------------===//
-ConversionPatternRewriter::ConversionPatternRewriter(MLIRContext *ctx)
+ConversionPatternRewriter::ConversionPatternRewriter(
+ MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(*this)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -1972,12 +1978,12 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
assert(rewriterImpl.pendingRootUpdates.empty() && "dangling root updates");
LLVM_DEBUG({
logFailure(rewriterImpl.logger, "pattern failed to match");
- if (rewriterImpl.notifyCallback) {
+ if (rewriterImpl.config.notifyCallback) {
Diagnostic diag(op->getLoc(), DiagnosticSeverity::Remark);
diag << "Failed to apply pattern \"" << pattern.getDebugName()
<< "\" on op:\n"
<< *op;
- rewriterImpl.notifyCallback(diag);
+ rewriterImpl.config.notifyCallback(diag);
}
});
rewriterImpl.resetState(curState);
@@ -2365,14 +2371,12 @@ namespace mlir {
struct OperationConverter {
explicit OperationConverter(const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- OpConversionMode mode,
- DenseSet<Operation *> *trackedOps = nullptr)
- : opLegalizer(target, patterns), mode(mode), trackedOps(trackedOps) {}
+ const ConversionConfig &config,
+ OpConversionMode mode)
+ : opLegalizer(target, patterns), config(config), mode(mode) {}
/// Converts the given operations to the conversion target.
- LogicalResult
- convertOperations(ArrayRef<Operation *> ops,
- function_ref<void(Diagnostic &)> notifyCallback = nullptr);
+ LogicalResult convertOperations(ArrayRef<Operation *> ops);
private:
/// Converts an operation with the given rewriter.
@@ -2409,14 +2413,11 @@ struct OperationConverter {
/// The legalizer to use when converting operations.
OperationLegalizer opLegalizer;
+ /// Dialect conversion configuration.
+ ConversionConfig config;
+
/// The conversion mode to use when legalizing operations.
OpConversionMode mode;
-
- /// A set of pre-existing operations. When mode == OpConversionMode::Analysis,
- /// this is populated with ops found to be legalizable to the target.
- /// When mode == OpConversionMode::Partial, this is populated with ops found
- /// *not* to be legalizable to the target.
- DenseSet<Operation *> *trackedOps;
};
} // namespace mlir
@@ -2430,28 +2431,27 @@ LogicalResult OperationConverter::convert(ConversionPatternRewriter &rewriter,
return op->emitError()
<< "failed to legalize operation '" << op->getName() << "'";
// Partial conversions allow conversions to fail iff the operation was not
- // explicitly marked as illegal. If the user provided a nonlegalizableOps
- // set, non-legalizable ops are included.
+ // explicitly marked as illegal. If the user provided a `unlegalizedOps`
+ // set, non-legalizable ops are added to that set.
if (mode == OpConversionMode::Partial) {
if (opLegalizer.isIllegal(op))
return op->emitError()
<< "failed to legalize operation '" << op->getName()
<< "' that was explicitly marked illegal";
- if (trackedOps)
- trackedOps->insert(op);
+ if (config.unlegalizedOps)
+ config.unlegalizedOps->insert(op);
}
} else if (mode == OpConversionMode::Analysis) {
// Analysis conversions don't fail if any operations fail to legalize,
// they are only interested in the operations that were successfully
// legalized.
- trackedOps->insert(op);
+ if (config.legalizableOps)
+ config.legalizableOps->insert(op);
}
return success();
}
-LogicalResult OperationConverter::convertOperations(
- ArrayRef<Operation *> ops,
- function_ref<void(Diagnostic &)> notifyCallback) {
+LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
if (ops.empty())
return success();
const ConversionTarget &target = opLegalizer.getTarget();
@@ -2472,10 +2472,8 @@ LogicalResult OperationConverter::convertOperations(
}
// Convert each operation and discard rewrites on failure.
- ConversionPatternRewriter rewriter(ops.front()->getContext());
+ ConversionPatternRewriter rewriter(ops.front()->getContext(), config);
ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
- rewriterImpl.notifyCallback = notifyCallback;
- rewriterImpl.trackedOps = trackedOps;
for (auto *op : toConvert)
if (failed(convert(rewriter, op)))
@@ -3461,56 +3459,51 @@ void mlir::registerConversionPDLFunctions(RewritePatternSet &patterns) {
//===----------------------------------------------------------------------===//
// Partial Conversion
-LogicalResult
-mlir::applyPartialConversion(ArrayRef<Operation *> ops,
- const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Partial,
- unconvertedOps);
+LogicalResult mlir::applyPartialConversion(
+ ArrayRef<Operation *> ops, const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Partial);
return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyPartialConversion(Operation *op, const ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> *unconvertedOps) {
- return applyPartialConversion(llvm::ArrayRef(op), target, patterns,
- unconvertedOps);
+ ConversionConfig config) {
+ return applyPartialConversion(llvm::ArrayRef(op), target, patterns, config);
}
//===----------------------------------------------------------------------===//
// Full Conversion
-LogicalResult
-mlir::applyFullConversion(ArrayRef<Operation *> ops, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Full);
+LogicalResult mlir::applyFullConversion(ArrayRef<Operation *> ops,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Full);
return opConverter.convertOperations(ops);
}
-LogicalResult
-mlir::applyFullConversion(Operation *op, const ConversionTarget &target,
- const FrozenRewritePatternSet &patterns) {
- return applyFullConversion(llvm::ArrayRef(op), target, patterns);
+LogicalResult mlir::applyFullConversion(Operation *op,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns,
+ ConversionConfig config) {
+ return applyFullConversion(llvm::ArrayRef(op), target, patterns, config);
}
//===----------------------------------------------------------------------===//
// Analysis Conversion
-LogicalResult
-mlir::applyAnalysisConversion(ArrayRef<Operation *> ops,
- ConversionTarget &target,
- const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback) {
- OperationConverter opConverter(target, patterns, OpConversionMode::Analysis,
- &convertedOps);
- return opConverter.convertOperations(ops, notifyCallback);
+LogicalResult mlir::applyAnalysisConversion(
+ ArrayRef<Operation *> ops, ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns, ConversionConfig config) {
+ OperationConverter opConverter(target, patterns, config,
+ OpConversionMode::Analysis);
+ return opConverter.convertOperations(ops);
}
LogicalResult
mlir::applyAnalysisConversion(Operation *op, ConversionTarget &target,
const FrozenRewritePatternSet &patterns,
- DenseSet<Operation *> &convertedOps,
- function_ref<void(Diagnostic &)> notifyCallback) {
- return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns,
- convertedOps, notifyCallback);
+ ConversionConfig config) {
+ return applyAnalysisConversion(llvm::ArrayRef(op), target, patterns, config);
}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 1c02232b8adbb1..c04d5cc80446f4 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1135,8 +1135,10 @@ struct TestLegalizePatternDriver
// Handle a partial conversion.
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
- if (failed(applyPartialConversion(
- getOperation(), target, std::move(patterns), &unlegalizedOps))) {
+ ConversionConfig config;
+ config.unlegalizedOps = &unlegalizedOps;
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns), config))) {
getOperation()->emitRemark() << "applyPartialConversion failed";
}
// Emit remarks for each legalizable operation.
@@ -1164,8 +1166,10 @@ struct TestLegalizePatternDriver
// Analyze the convertible operations.
DenseSet<Operation *> legalizedOps;
+ ConversionConfig config;
+ config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getO...
[truncated]
|
f3fe2f1
to
d33c3cf
Compare
819e5f9
to
577b5eb
Compare
…ionConfig` This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API.
577b5eb
to
58e2c18
Compare
@matthias-springer : the CI failure wasn't a false-positive, this PR broke the window pre-merge for FIR it seems. |
@matthias-springer : see the bot failure here https://lab.llvm.org/buildbot/#/builders/271/builds/4587 |
…`ConversionConfig` (llvm#82250)" This reverts commit 5f1319b. A FIR test is broken on Windows
This is reverted, please reland with a fix :) |
…ionConfig` This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API. This is a re-upload of #82250. This reverts commit 60fbd60.
…ionConfig` This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API. This is a re-upload of #82250. This reverts commit 60fbd60.
Oh, I must have missed that e-mail. I'm still a bit puzzled by the crash. It happens only on Windows and the fix (#83768) seems to be completely unrelated to my change, yet my commit clearly triggered it... |
…ionConfig` (#83754) This commit adds a new `ConversionConfig` struct that allows users to customize the dialect conversion. This configuration is similar to `GreedyRewriteConfig` for the greedy pattern rewrite driver. A few existing options are moved to this objects, simplifying the dialect conversion API. This is a re-upload of #82250. The Windows build breakage was fixed in #83768. This reverts commit 60fbd60.
This commit adds a new
ConversionConfig
struct that allows users to customize the dialect conversion. This configuration is similar toGreedyRewriteConfig
for the greedy pattern rewrite driver.A few existing options are moved to this objects, simplifying the dialect conversion API.