Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions mlir/include/mlir/IR/PatternMatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -675,9 +675,9 @@ class RewriterBase : public OpBuilder {
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
virtual void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
Expand Down
21 changes: 21 additions & 0 deletions mlir/include/mlir/Transforms/DialectConversion.h
Original file line number Diff line number Diff line change
Expand Up @@ -903,6 +903,27 @@ class ConversionPatternRewriter final : public PatternRewriter {
replaceAllUsesWith(from, ValueRange{to});
}

/// Replace the uses of `from` with `to` for which the `functor` returns
/// "true". The conversion driver will try to reconcile all type mismatches
/// that still exist at the end of the conversion with materializations.
/// This function supports both 1:1 and 1:N replacements.
///
/// Note: The functor is also applied to builtin.unrealized_conversion_cast
/// ops that may have been inserted by the conversion driver. Some uses may
/// have been wrapped in unrealized_conversion_cast ops due to type changes.
///
/// Note: This function is not supported in rollback mode. Calling it in
/// rollback mode will trigger an assertion. Furthermore, the
/// `allUsesReplaced` flag is not supported yet.
void replaceUsesWithIf(Value from, Value to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr) override {
replaceUsesWithIf(from, ValueRange{to}, functor, allUsesReplaced);
}
void replaceUsesWithIf(Value from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);

/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
Expand Down
12 changes: 10 additions & 2 deletions mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,16 @@ static void restoreByValRefArgumentType(
Type resTy = typeConverter.convertType(
cast<TypeAttr>(byValRefAttr->getValue()).getValue());

Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
rewriter.replaceAllUsesWith(arg, valueArg);
auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
if (!rewriter.getConfig().allowPatternRollback) {
rewriter.replaceAllUsesExcept(arg, loadOp, loadOp);
} else {
// replaceAllUsesExcept is not supported in rollback mode. The rollback
// mode implementation has a workaround: certain replacements that would
// cause a dominance violation are skipped.
// TODO: Remove workaround.
rewriter.replaceAllUsesWith(arg, loadOp);
}
}
}

Expand Down
88 changes: 53 additions & 35 deletions mlir/lib/Transforms/Utils/DialectConversion.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -976,9 +976,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
void replaceOp(Operation *op, SmallVector<SmallVector<Value>> &&newValues);

/// Replace the uses of the given value with the given values. The specified
/// converter is used to build materializations (if necessary).
void replaceAllUsesWith(Value from, ValueRange to,
const TypeConverter *converter);
/// converter is used to build materializations (if necessary). If `functor`
/// is specified, only the uses that the functor returns "true" for are
/// replaced.
void replaceValueUses(Value from, ValueRange to,
const TypeConverter *converter,
function_ref<bool(OpOperand &)> functor = nullptr);

/// Erase the given block and its contents.
void eraseBlock(Block *block);
Expand Down Expand Up @@ -1202,12 +1205,14 @@ void BlockTypeConversionRewrite::rollback() {
getNewBlock()->replaceAllUsesWith(getOrigBlock());
}

/// Replace all uses of `from` with `repl`.
static void performReplaceValue(RewriterBase &rewriter, Value from,
Value repl) {
void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
if (!repl)
return;

if (isa<BlockArgument>(repl)) {
// `repl` is a block argument. Directly replace all uses.
rewriter.replaceAllUsesWith(from, repl);
rewriter.replaceAllUsesWith(value, repl);
return;
}

Expand Down Expand Up @@ -1236,19 +1241,14 @@ static void performReplaceValue(RewriterBase &rewriter, Value from,
// `ConversionPatternRewriter` API with the normal `RewriterBase` API.
Operation *replOp = repl.getDefiningOp();
Block *replBlock = replOp->getBlock();
rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
Operation *user = operand.getOwner();
return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
bool result =
user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
return result;
});
}

void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
if (!repl)
return;
performReplaceValue(rewriter, value, repl);
}

void ReplaceValueRewrite::rollback() {
rewriterImpl.mapping.erase({value});
#ifndef NDEBUG
Expand Down Expand Up @@ -1646,7 +1646,7 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
/*outputTypes=*/origArgType, /*originalType=*/Type(), converter,
/*isPureTypeConversion=*/false)
.front();
replaceAllUsesWith(origArg, mat, converter);
replaceValueUses(origArg, mat, converter);
continue;
}

Expand All @@ -1655,14 +1655,14 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
assert(inputMap->size == 0 &&
"invalid to provide a replacement value when the argument isn't "
"dropped");
replaceAllUsesWith(origArg, inputMap->replacementValues, converter);
replaceValueUses(origArg, inputMap->replacementValues, converter);
continue;
}

// This is a 1->1+ mapping.
auto replArgs =
newBlock->getArguments().slice(inputMap->inputNo, inputMap->size);
replaceAllUsesWith(origArg, replArgs, converter);
replaceValueUses(origArg, replArgs, converter);
}

if (config.allowPatternRollback)
Expand Down Expand Up @@ -1962,8 +1962,24 @@ void ConversionPatternRewriterImpl::replaceOp(
op->walk([&](Operation *op) { replacedOps.insert(op); });
}

void ConversionPatternRewriterImpl::replaceAllUsesWith(
Value from, ValueRange to, const TypeConverter *converter) {
void ConversionPatternRewriterImpl::replaceValueUses(
Value from, ValueRange to, const TypeConverter *converter,
function_ref<bool(OpOperand &)> functor) {
LLVM_DEBUG({
logger.startLine() << "** Replace Value : '" << from << "'";
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
logger.getOStream() << " (in region of '" << parentOp->getName()
<< "' (" << parentOp << ")";
} else {
logger.getOStream() << " (unlinked block)";
}
}
if (functor) {
logger.getOStream() << ", conditional replacement";
}
});

if (!config.allowPatternRollback) {
SmallVector<Value> toConv = llvm::to_vector(to);
SmallVector<Value> repls =
Expand All @@ -1972,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
Value repl = repls.front();
if (!repl)
return;

performReplaceValue(r, from, repl);
if (functor) {
r.replaceUsesWithIf(from, repl, functor);
} else {
r.replaceAllUsesWith(from, repl);
}
return;
}

Expand All @@ -1992,6 +2011,8 @@ void ConversionPatternRewriterImpl::replaceAllUsesWith(
replacedValues.insert(from);
#endif // NDEBUG

assert(!functor &&
"conditional value replacement is not supported in rollback mode");
mapping.map(from, to);
appendRewrite<ReplaceValueRewrite>(from, converter);
}
Expand Down Expand Up @@ -2190,18 +2211,15 @@ FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
}

void ConversionPatternRewriter::replaceAllUsesWith(Value from, ValueRange to) {
LLVM_DEBUG({
impl->logger.startLine() << "** Replace Value : '" << from << "'";
if (auto blockArg = dyn_cast<BlockArgument>(from)) {
if (Operation *parentOp = blockArg.getOwner()->getParentOp()) {
impl->logger.getOStream() << " (in region of '" << parentOp->getName()
<< "' (" << parentOp << ")\n";
} else {
impl->logger.getOStream() << " (unlinked block)\n";
}
}
});
impl->replaceAllUsesWith(from, to, impl->currentTypeConverter);
impl->replaceValueUses(from, to, impl->currentTypeConverter);
}

void ConversionPatternRewriter::replaceUsesWithIf(
Value from, ValueRange to, function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced) {
assert(!allUsesReplaced &&
"allUsesReplaced is not supported in a dialect conversion");
impl->replaceValueUses(from, to, impl->currentTypeConverter, functor);
}

Value ConversionPatternRewriter::getRemappedValue(Value key) {
Expand Down
3 changes: 2 additions & 1 deletion mlir/test/Transforms/test-convert-func-op.mlir
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s
// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s

// CHECK-LABEL: llvm.func @add
func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
Expand Down
23 changes: 23 additions & 0 deletions mlir/test/Transforms/test-legalizer-no-rollback.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns="allow-pattern-rollback=0" -verify-diagnostics %s | FileCheck %s

// CHECK-LABEL: @conditional_replacement(
// CHECK-SAME: %[[arg0:.*]]: i43)
// CHECK: %[[cast1:.*]] = "test.cast"(%[[arg0]]) : (i43) -> i42
// CHECK: %[[legal:.*]] = "test.legal_op"() : () -> i42
// CHECK: %[[cast2:.*]] = "test.cast"(%[[legal]], %[[legal]]) : (i42, i42) -> i42
// Uses were replaced for dummy_user_1.
// CHECK: "test.dummy_user_1"(%[[cast2]]) {replace_uses} : (i42) -> ()
// Uses were also replaced for dummy_user_2, but not by value_replace. The uses
// were replaced due to the block signature conversion.
// CHECK: "test.dummy_user_2"(%[[cast1]]) : (i42) -> ()
// CHECK: "test.value_replace"(%[[cast1]], %[[legal]]) {conditional, is_legal} : (i42, i42) -> ()
func.func @conditional_replacement(%arg0: i42) {
%repl = "test.legal_op"() : () -> (i42)
// expected-remark @+1 {{is not legalizable}}
"test.dummy_user_1"(%arg0) {replace_uses} : (i42) -> ()
// expected-remark @+1 {{is not legalizable}}
"test.dummy_user_2"(%arg0) {} : (i42) -> ()
// Perform a conditional 1:N replacement.
"test.value_replace"(%arg0, %repl) {conditional} : (i42, i42) -> ()
"test.return"() : () -> ()
}
11 changes: 10 additions & 1 deletion mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ struct TestConvertFuncOp
: public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)

TestConvertFuncOp() = default;
TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {}

void getDependentDialects(DialectRegistry &registry) const final {
registry.insert<LLVM::LLVMDialect>();
}
Expand All @@ -92,10 +95,16 @@ struct TestConvertFuncOp
patterns.add<ReturnOpConversion>(typeConverter);

LLVMConversionTarget target(getContext());
ConversionConfig config;
config.allowPatternRollback = allowPatternRollback;
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
std::move(patterns), config)))
signalPassFailure();
}

Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
llvm::cl::desc("Allow pattern rollback"),
llvm::cl::init(true)};
};

} // namespace
Expand Down
8 changes: 7 additions & 1 deletion mlir/test/lib/Dialect/Test/TestPatterns.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,13 @@ struct TestValueReplace : public ConversionPattern {
// Replace the first operand with 2x the second operand.
Value from = op->getOperand(0);
Value repl = op->getOperand(1);
rewriter.replaceAllUsesWith(from, {repl, repl});
if (op->hasAttr("conditional")) {
rewriter.replaceUsesWithIf(from, {repl, repl}, [=](OpOperand &use) {
return use.getOwner()->hasAttr("replace_uses");
});
} else {
rewriter.replaceAllUsesWith(from, {repl, repl});
}
rewriter.modifyOpInPlace(op, [&] {
// If the "trigger_rollback" attribute is set, keep the op illegal, so
// that a rollback is triggered.
Expand Down