Skip to content

Conversation

@matthias-springer
Copy link
Member

@matthias-springer matthias-springer commented Nov 26, 2025

Remove a workaround in the implementation of replaceAllUsesWith in the no-rollback dialect conversion. This workaround was necessary for restoreByValRefArgumentType in the func-to-llvm lowering because there was no support for replaceAllUsesExcept. Support for this API has been added to the no-rollback driver, so the workaround can be dropped from that driver. The workaround is still in place for the rollback driver.

Depends on #169606.

@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Nov 26, 2025
@llvmbot
Copy link
Member

llvmbot commented Nov 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

Changes

Replace a workaround in the implementation of replaceAllUsesWith in the no-rollback dialect conversion. This workaround was necessary for restoreByValRefArgumentType in the func-to-llvm lowering because there was no support for replaceAllUsesExcept. Support for this API has been added to the no-rollback driver, so the workaround can be dropped from that driver. The workaround is still in place for the rollback driver.

Depends on #169606.


Full diff: https://github.com/llvm/llvm-project/pull/169609.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp (+10-2)
  • (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+12-21)
  • (modified) mlir/test/Transforms/test-convert-func-op.mlir (+2-1)
  • (modified) mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp (+10-1)
diff --git a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
index 2220f61ed8a07..ddd94f5d03042 100644
--- a/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
+++ b/mlir/lib/Conversion/FuncToLLVM/FuncToLLVM.cpp
@@ -283,8 +283,16 @@ static void restoreByValRefArgumentType(
     Type resTy = typeConverter.convertType(
         cast<TypeAttr>(byValRefAttr->getValue()).getValue());
 
-    Value valueArg = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
-    rewriter.replaceAllUsesWith(arg, valueArg);
+    auto loadOp = LLVM::LoadOp::create(rewriter, arg.getLoc(), resTy, arg);
+    if (!rewriter.getConfig().allowPatternRollback) {
+      rewriter.replaceAllUsesExcept(arg, loadOp, loadOp);
+    } else {
+      // replaceAllUsesExcept is not supported in rollback mode. The rollback
+      // mode implementation has a workaround: certain replacements that would
+      // cause a dominance violation are skipped.
+      // TODO: Remove workaround.
+      rewriter.replaceAllUsesWith(arg, loadOp);
+    }
   }
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index c9f1596c07cbe..ccc5b7cb6f229 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1205,17 +1205,14 @@ void BlockTypeConversionRewrite::rollback() {
   getNewBlock()->replaceAllUsesWith(getOrigBlock());
 }
 
-/// Replace all uses of `from` with `repl`.
-static void
-performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
-                    function_ref<bool(OpOperand &)> functor = nullptr) {
+void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
+  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
+  if (!repl)
+    return;
+
   if (isa<BlockArgument>(repl)) {
     // `repl` is a block argument. Directly replace all uses.
-    if (functor) {
-      rewriter.replaceUsesWithIf(from, repl, functor);
-    } else {
-      rewriter.replaceAllUsesWith(from, repl);
-    }
+    rewriter.replaceAllUsesWith(value, repl);
     return;
   }
 
@@ -1244,23 +1241,14 @@ performReplaceValue(RewriterBase &rewriter, Value from, Value repl,
   // `ConversionPatternRewriter` API with the normal `RewriterBase` API.
   Operation *replOp = repl.getDefiningOp();
   Block *replBlock = replOp->getBlock();
-  rewriter.replaceUsesWithIf(from, repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(value, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     bool result =
         user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
-    if (functor)
-      result &= functor(operand);
     return result;
   });
 }
 
-void ReplaceValueRewrite::commit(RewriterBase &rewriter) {
-  Value repl = rewriterImpl.findOrBuildReplacementValue(value, converter);
-  if (!repl)
-    return;
-  performReplaceValue(rewriter, value, repl);
-}
-
 void ReplaceValueRewrite::rollback() {
   rewriterImpl.mapping.erase({value});
 #ifndef NDEBUG
@@ -2000,8 +1988,11 @@ void ConversionPatternRewriterImpl::replaceValueUses(
     Value repl = repls.front();
     if (!repl)
       return;
-
-    performReplaceValue(r, from, repl, functor);
+    if (functor) {
+      r.replaceUsesWithIf(from, repl, functor);
+    } else {
+      r.replaceAllUsesWith(from, repl);
+    }
     return;
   }
 
diff --git a/mlir/test/Transforms/test-convert-func-op.mlir b/mlir/test/Transforms/test-convert-func-op.mlir
index 180f16a32991b..14c15ecbe77f0 100644
--- a/mlir/test/Transforms/test-convert-func-op.mlir
+++ b/mlir/test/Transforms/test-convert-func-op.mlir
@@ -1,4 +1,5 @@
-// RUN: mlir-opt %s -test-convert-func-op --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=1" --split-input-file | FileCheck %s
+// RUN: mlir-opt %s -test-convert-func-op="allow-pattern-rollback=0" --split-input-file | FileCheck %s
 
 // CHECK-LABEL: llvm.func @add
 func.func @add(%arg0: i32, %arg1: i32) -> i32 attributes { llvm.emit_c_interface } {
diff --git a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
index 75168dde93130..897b11b65b6f2 100644
--- a/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
+++ b/mlir/test/lib/Conversion/FuncToLLVM/TestConvertFuncOp.cpp
@@ -68,6 +68,9 @@ struct TestConvertFuncOp
     : public PassWrapper<TestConvertFuncOp, OperationPass<ModuleOp>> {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConvertFuncOp)
 
+  TestConvertFuncOp() = default;
+  TestConvertFuncOp(const TestConvertFuncOp &other) : PassWrapper(other) {}
+
   void getDependentDialects(DialectRegistry &registry) const final {
     registry.insert<LLVM::LLVMDialect>();
   }
@@ -92,10 +95,16 @@ struct TestConvertFuncOp
     patterns.add<ReturnOpConversion>(typeConverter);
 
     LLVMConversionTarget target(getContext());
+    ConversionConfig config;
+    config.allowPatternRollback = allowPatternRollback;
     if (failed(applyPartialConversion(getOperation(), target,
-                                      std::move(patterns))))
+                                      std::move(patterns), config)))
       signalPassFailure();
   }
+
+  Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+                                    llvm::cl::desc("Allow pattern rollback"),
+                                    llvm::cl::init(true)};
 };
 
 } // namespace

Base automatically changed from users/matthias-springer/replace_uses_functor to main November 27, 2025 01:54
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

5 participants