-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][Transforms][NFC] Simplify function signatures #155997
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms][NFC] Simplify function signatures #155997
Conversation
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesMany internal functions take a Patch is 26.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155997.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, ®ion->front(), &converter,
+ return applySignatureConversion(®ion->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(®ion->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, ®ion->front(), &converter,
- *conversion);
+ return applySignatureConversion(®ion->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget ⌖
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
// is 'BeforePatterns'. 'Never' will skip this.
const ConversionConfig &config = rewriter.getConfig();
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
// If the operation can't be legalized via patterns, try to fold it in-place
// if the folding mode is 'AfterPatterns'.
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
llvm::join(insertedBlockNames, ", ") + "}");
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure()...
[truncated]
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesMany internal functions take a Patch is 26.33 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/155997.diff 1 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index b6a216adfdd25..c0685f54731d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,9 +848,10 @@ static bool hasRewrite(R &&rewrites, Block *block) {
namespace mlir {
namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
- explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
+ explicit ConversionPatternRewriterImpl(ConversionPatternRewriter &rewriter,
const ConversionConfig &config)
- : context(ctx), config(config), notifyingRewriter(ctx, config.listener) {}
+ : rewriter(rewriter), config(config),
+ notifyingRewriter(rewriter.getContext(), config.listener) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -887,8 +888,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// is the tag used when describing a value within a diagnostic, e.g.
/// "operand".
LogicalResult remapValues(StringRef valueDiagTag,
- std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped);
/// Return "true" if the given operation is ignored, and does not need to be
@@ -918,8 +918,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// Convert the types of block arguments within the given region.
FailureOr<Block *>
- convertRegionTypes(ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ convertRegionTypes(Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion);
/// Apply the given signature conversion on the given block. The new block
@@ -929,8 +928,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// translate between the origin argument types and those specified in the
/// signature conversion.
Block *applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion);
/// Replace the results of the given operation with the given values and
@@ -1060,8 +1058,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
// State
//===--------------------------------------------------------------------===//
- /// MLIR context.
- MLIRContext *context;
+ /// The rewriter that is used to perform the conversion.
+ ConversionPatternRewriter &rewriter;
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
@@ -1258,16 +1256,17 @@ void UnresolvedMaterializationRewrite::rollback() {
}
void ConversionPatternRewriterImpl::applyRewrites() {
- // Commit all rewrites.
- IRRewriter rewriter(context, config.listener);
+ // Commit all rewrites. Use a new rewriter, so the modifications are not
+ // tracked for rollback purposes etc.
+ IRRewriter irRewriter(rewriter.getContext(), config.listener);
// Note: New rewrites may be added during the "commit" phase and the
// `rewrites` vector may reallocate.
for (size_t i = 0; i < rewrites.size(); ++i)
- rewrites[i]->commit(rewriter);
+ rewrites[i]->commit(irRewriter);
// Clean up all rewrites.
SingleEraseRewriter eraseRewriter(
- context, /*opErasedCallback=*/[&](Operation *op) {
+ rewriter.getContext(), /*opErasedCallback=*/[&](Operation *op) {
if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
unresolvedMaterializations.erase(castOp);
});
@@ -1412,8 +1411,7 @@ void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep,
}
LogicalResult ConversionPatternRewriterImpl::remapValues(
- StringRef valueDiagTag, std::optional<Location> inputLoc,
- PatternRewriter &rewriter, ValueRange values,
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
SmallVector<ValueVector> &remapped) {
remapped.reserve(llvm::size(values));
@@ -1484,8 +1482,7 @@ bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
//===----------------------------------------------------------------------===//
FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
- ConversionPatternRewriter &rewriter, Region *region,
- const TypeConverter &converter,
+ Region *region, const TypeConverter &converter,
TypeConverter::SignatureConversion *entryConversion) {
regionToConverter[region] = &converter;
if (region->empty())
@@ -1500,25 +1497,23 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
if (!conversion)
return failure();
// Convert the block with the computed signature.
- applySignatureConversion(rewriter, &block, &converter, *conversion);
+ applySignatureConversion(&block, &converter, *conversion);
}
// Convert the entry block. If an entry signature conversion was provided,
// use that one. Otherwise, compute the signature with the type converter.
if (entryConversion)
- return applySignatureConversion(rewriter, ®ion->front(), &converter,
+ return applySignatureConversion(®ion->front(), &converter,
*entryConversion);
std::optional<TypeConverter::SignatureConversion> conversion =
converter.convertBlockSignature(®ion->front());
if (!conversion)
return failure();
- return applySignatureConversion(rewriter, ®ion->front(), &converter,
- *conversion);
+ return applySignatureConversion(®ion->front(), &converter, *conversion);
}
Block *ConversionPatternRewriterImpl::applySignatureConversion(
- ConversionPatternRewriter &rewriter, Block *block,
- const TypeConverter *converter,
+ Block *block, const TypeConverter *converter,
TypeConverter::SignatureConversion &signatureConversion) {
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// A block cannot be converted multiple times.
@@ -2023,7 +2018,7 @@ void ConversionPatternRewriterImpl::notifyMatchFailure(
ConversionPatternRewriter::ConversionPatternRewriter(
MLIRContext *ctx, const ConversionConfig &config)
: PatternRewriter(ctx),
- impl(new detail::ConversionPatternRewriterImpl(ctx, config)) {
+ impl(new detail::ConversionPatternRewriterImpl(*this, config)) {
setListener(impl.get());
}
@@ -2100,7 +2095,7 @@ Block *ConversionPatternRewriter::applySignatureConversion(
assert(!impl->wasOpReplaced(block->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->applySignatureConversion(*this, block, converter, conversion);
+ return impl->applySignatureConversion(block, converter, conversion);
}
FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
@@ -2109,7 +2104,7 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
assert(!impl->wasOpReplaced(region->getParentOp()) &&
"attempting to apply a signature conversion to a block within a "
"replaced/erased op");
- return impl->convertRegionTypes(*this, region, converter, entryConversion);
+ return impl->convertRegionTypes(region, converter, entryConversion);
}
void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
@@ -2128,7 +2123,7 @@ void ConversionPatternRewriter::replaceUsesOfBlockArgument(BlockArgument from,
Value ConversionPatternRewriter::getRemappedValue(Value key) {
SmallVector<ValueVector> remappedValues;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, key,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, key,
remappedValues)))
return nullptr;
assert(remappedValues.front().size() == 1 && "1:N conversion not supported");
@@ -2141,7 +2136,7 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
if (keys.empty())
return success();
SmallVector<ValueVector> remapped;
- if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, *this, keys,
+ if (failed(impl->remapValues("value", /*inputLoc=*/std::nullopt, keys,
remapped)))
return failure();
for (const auto &values : remapped) {
@@ -2288,7 +2283,7 @@ ConversionPattern::matchAndRewrite(Operation *op,
// Remap the operands of the operation.
SmallVector<ValueVector> remapped;
- if (failed(rewriterImpl.remapValues("operand", op->getLoc(), rewriter,
+ if (failed(rewriterImpl.remapValues("operand", op->getLoc(),
op->getOperands(), remapped))) {
return failure();
}
@@ -2310,7 +2305,8 @@ class OperationLegalizer {
public:
using LegalizationAction = ConversionTarget::LegalizationAction;
- OperationLegalizer(const ConversionTarget &targetInfo,
+ OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns);
/// Returns true if the given operation is known to be illegal on the target.
@@ -2318,29 +2314,25 @@ class OperationLegalizer {
/// Attempt to legalize the given operation. Returns success if the operation
/// was legalized, failure otherwise.
- LogicalResult legalize(Operation *op, ConversionPatternRewriter &rewriter);
+ LogicalResult legalize(Operation *op);
/// Returns the conversion target in use by the legalizer.
const ConversionTarget &getTarget() { return target; }
private:
/// Attempt to legalize the given operation by folding it.
- LogicalResult legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithFold(Operation *op);
/// Attempt to legalize the given operation by applying a pattern. Returns
/// success if the operation was legalized, failure otherwise.
- LogicalResult legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter);
+ LogicalResult legalizeWithPattern(Operation *op);
/// Return true if the given pattern may be applied to the given operation,
/// false otherwise.
- bool canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter);
+ bool canApplyPattern(Operation *op, const Pattern &pattern);
/// Legalize the resultant IR after successfully applying the given pattern.
LogicalResult legalizePatternResult(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter,
const RewriterState &curState,
const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
@@ -2349,18 +2341,12 @@ class OperationLegalizer {
/// Legalizes the actions registered during the execution of a pattern.
LogicalResult
legalizePatternBlockRewrites(Operation *op,
- ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternCreatedOperations(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &newOps);
+ legalizePatternCreatedOperations(const SetVector<Operation *> &newOps);
LogicalResult
- legalizePatternRootUpdates(ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Operation *> &modifiedOps);
+ legalizePatternRootUpdates(const SetVector<Operation *> &modifiedOps);
//===--------------------------------------------------------------------===//
// Cost Model
@@ -2403,6 +2389,9 @@ class OperationLegalizer {
/// The current set of patterns that have been applied.
SmallPtrSet<const Pattern *, 8> appliedPatterns;
+ /// The rewriter to use when converting operations.
+ ConversionPatternRewriter &rewriter;
+
/// The legalization information provided by the target.
const ConversionTarget ⌖
@@ -2411,9 +2400,10 @@ class OperationLegalizer {
};
} // namespace
-OperationLegalizer::OperationLegalizer(const ConversionTarget &targetInfo,
+OperationLegalizer::OperationLegalizer(ConversionPatternRewriter &rewriter,
+ const ConversionTarget &targetInfo,
const FrozenRewritePatternSet &patterns)
- : target(targetInfo), applicator(patterns) {
+ : rewriter(rewriter), target(targetInfo), applicator(patterns) {
// The set of patterns that can be applied to illegal operations to transform
// them into legal ones.
DenseMap<OperationName, LegalizationPatterns> legalizerPatterns;
@@ -2427,9 +2417,7 @@ bool OperationLegalizer::isIllegal(Operation *op) const {
return target.isIllegal(op);
}
-LogicalResult
-OperationLegalizer::legalize(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalize(Operation *op) {
#ifndef NDEBUG
const char *logLineComment =
"//===-------------------------------------------===//\n";
@@ -2495,7 +2483,7 @@ OperationLegalizer::legalize(Operation *op,
// is 'BeforePatterns'. 'Never' will skip this.
const ConversionConfig &config = rewriter.getConfig();
if (config.foldingMode == DialectConversionFoldingMode::BeforePatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2505,7 +2493,7 @@ OperationLegalizer::legalize(Operation *op,
}
// Otherwise, we need to apply a legalization pattern to this operation.
- if (succeeded(legalizeWithPattern(op, rewriter))) {
+ if (succeeded(legalizeWithPattern(op))) {
LLVM_DEBUG({
logSuccess(logger, "");
logger.startLine() << logLineComment;
@@ -2516,7 +2504,7 @@ OperationLegalizer::legalize(Operation *op,
// If the operation can't be legalized via patterns, try to fold it in-place
// if the folding mode is 'AfterPatterns'.
if (config.foldingMode == DialectConversionFoldingMode::AfterPatterns) {
- if (succeeded(legalizeWithFold(op, rewriter))) {
+ if (succeeded(legalizeWithFold(op))) {
LLVM_DEBUG({
logSuccess(logger, "operation was folded");
logger.startLine() << logLineComment;
@@ -2541,9 +2529,7 @@ static T moveAndReset(T &obj) {
return result;
}
-LogicalResult
-OperationLegalizer::legalizeWithFold(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithFold(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
@@ -2577,14 +2563,14 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
- return legalize(op, rewriter);
+ return legalize(op);
// Insert a replacement for 'op' with the folded replacement values.
rewriter.replaceOp(op, replacementValues);
// Recursively legalize any new constant operations.
for (Operation *newOp : newOps) {
- if (failed(legalize(newOp, rewriter))) {
+ if (failed(legalize(newOp))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
newOp->getName()));
@@ -2629,9 +2615,7 @@ reportNewIrLegalizationFatalError(const Pattern &pattern,
llvm::join(insertedBlockNames, ", ") + "}");
}
-LogicalResult
-OperationLegalizer::legalizeWithPattern(Operation *op,
- ConversionPatternRewriter &rewriter) {
+LogicalResult OperationLegalizer::legalizeWithPattern(Operation *op) {
auto &rewriterImpl = rewriter.getImpl();
const ConversionConfig &config = rewriter.getConfig();
@@ -2663,7 +2647,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
// Functor that returns if the given pattern may be applied.
auto canApply = [&](const Pattern &pattern) {
- bool canApply = canApplyPattern(op, pattern, rewriter);
+ bool canApply = canApplyPattern(op, pattern);
if (canApply && config.listener)
config.listener->notifyPatternBegin(pattern, op);
return canApply;
@@ -2728,7 +2712,7 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
moveAndReset(rewriterImpl.patternModifiedOps);
SetVector<Block *> insertedBlocks =
moveAndReset(rewriterImpl.patternInsertedBlocks);
- auto result = legalizePatternResult(op, pattern, rewriter, curState, newOps,
+ auto result = legalizePatternResult(op, pattern, curState, newOps,
modifiedOps, insertedBlocks);
appliedPatterns.erase(&pattern);
if (failed(result)) {
@@ -2747,8 +2731,8 @@ OperationLegalizer::legalizeWithPattern(Operation *op,
onSuccess);
}
-bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
- ConversionPatternRewriter &rewriter) {
+bool OperationLegalizer::canApplyPattern(Operation *op,
+ const Pattern &pattern) {
LLVM_DEBUG({
auto &os = rewriter.getImpl().logger;
os.getOStream() << "\n";
@@ -2770,8 +2754,8 @@ bool OperationLegalizer::canApplyPattern(Operation *op, const Pattern &pattern,
}
LogicalResult OperationLegalizer::legalizePatternResult(
- Operation *op, const Pattern &pattern, ConversionPatternRewriter &rewriter,
- const RewriterState &curState, const SetVector<Operation *> &newOps,
+ Operation *op, const Pattern &pattern, const RewriterState &curState,
+ const SetVector<Operation *> &newOps,
const SetVector<Operation *> &modifiedOps,
const SetVector<Block *> &insertedBlocks) {
auto &impl = rewriter.getImpl();
@@ -2792,10 +2776,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Legalize each of the actions registered during application.
- if (failed(legalizePatternBlockRewrites(op, rewriter, impl, insertedBlocks,
- newOps)) ||
- failed(legalizePatternRootUpdates(rewriter, impl, modifiedOps)) ||
- failed(legalizePatternCreatedOperations(rewriter, impl, newOps))) {
+ if (failed(legalizePatternBlockRewrites(op, insertedBlocks, newOps)) ||
+ failed(legalizePatternRootUpdates(modifiedOps)) ||
+ failed(legalizePatternCreatedOperations(newOps))) {
return failure();
}
@@ -2804,10 +2787,9 @@ LogicalResult OperationLegalizer::legalizePatternResult(
}
LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
- Operation *op, ConversionPatternRewriter &rewriter,
- ConversionPatternRewriterImpl &impl,
- const SetVector<Block *> &insertedBlocks,
+ Operation *op, const SetVector<Block *> &insertedBlocks,
const SetVector<Operation *> &newOps) {
+ ConversionPatternRewriterImpl &impl = rewriter.getImpl();
SmallPtrSet<Operation *, 16> alreadyLegalized;
// If the pattern moved or created any blocks, make sure the types of block
@@ -2831,7 +2813,7 @@ LogicalResult OperationLegalizer::legalizePatternBlockRewrites(
"block"));
return failure()...
[truncated]
|
Many internal functions take a
ConversionPatternRewriter &
orConversionPatternRewriterImpl &
as a parameter. There's only a single instance of these classes, so it's better to store the reference in a field. This commit is in preparation of another PR that will require access toConversionPatternRewriter
in additional helper functions.Note: Public API does not change.