Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver #77258

Conversation

zyx-billy
Copy link
Contributor

@zyx-billy zyx-billy commented Jan 7, 2024

Make GreedyPatternRewriteDriver handle failures of materializeConstant gracefully. Previously it was not checking whether the returned op was null and crashing. This PR handles it similarly to how OperationFolder does it.

@zyx-billy zyx-billy marked this pull request as ready for review January 7, 2024 22:16
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir labels Jan 7, 2024
@llvmbot
Copy link
Collaborator

llvmbot commented Jan 7, 2024

@llvm/pr-subscribers-mlir

Author: Billy Zhu (zyx-billy)

Changes

GreedyPatternRewriteDriver needs to handle failures of materializeConstant gracefully. Previously it was not checking whether the returned op was null and crashing.


Full diff: https://github.com/llvm/llvm-project/pull/77258.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+26-5)
  • (modified) mlir/test/Transforms/canonicalize.mlir (+11)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 82438e2bf706c1..146091f96b8fd3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       SmallVector<OpFoldResult> foldResults;
       if (succeeded(op->fold(foldResults))) {
         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-        changed = true;
         if (foldResults.empty()) {
           // Op was modified in-place.
           notifyOperationModified(op);
+          changed = true;
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
           if (config.scope && failed(verify(config.scope->getParentOp())))
             llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         OpBuilder::InsertionGuard g(*this);
         setInsertionPoint(op);
         SmallVector<Value> replacements;
+        bool materializationSucceeded = true;
         for (auto [ofr, resultType] :
              llvm::zip_equal(foldResults, op->getResultTypes())) {
           if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           // Materialize Attributes as SSA values.
           Operation *constOp = op->getDialect()->materializeConstant(
               *this, ofr.get<Attribute>(), resultType, op->getLoc());
+
+          if (!constOp) {
+            // If materialization fails, cleanup any operations generated for
+            // the previous results.
+            llvm::SmallDenseSet<Operation *> replacementOps;
+            std::transform(replacements.begin(), replacements.end(),
+                           replacementOps.begin(), [](Value replacement) {
+                             return replacement.getDefiningOp();
+                           });
+            for (Operation *op : replacementOps)
+              eraseOp(op);
+
+            materializationSucceeded = false;
+            break;
+          }
+
           assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
                  "materializeConstant produced op that is not a ConstantLike");
           assert(constOp->getResultTypes()[0] == resultType &&
                  "materializeConstant produced incorrect result type");
           replacements.push_back(constOp->getResult(0));
         }
-        replaceOp(op, replacements);
+
+        if (materializationSucceeded) {
+          replaceOp(op, replacements);
+          changed = true;
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-        if (config.scope && failed(verify(config.scope->getParentOp())))
-          llvm::report_fatal_error("IR failed to verify after folding");
+          if (config.scope && failed(verify(config.scope->getParentOp())))
+            llvm::report_fatal_error("IR failed to verify after folding");
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-        continue;
+          continue;
+        }
       }
     }
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 47a19bb598c25e..9b578e6c2631a7 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
 // CHECK-NEXT: scf.yield %[[ALLOC3_2]]
 //      CHECK: memref.dealloc %[[ALLOC1]]
 // CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @test_materialize_failure
+func.func @test_materialize_failure() -> i64 {
+  %const = index.constant 1234
+  // Cannot materialize this castu's output constant.
+  // CHECK: index.castu
+  %u = index.castu %const : index to i64
+  return %u: i64
+}

@llvmbot
Copy link
Collaborator

llvmbot commented Jan 7, 2024

@llvm/pr-subscribers-mlir-core

Author: Billy Zhu (zyx-billy)

Changes

GreedyPatternRewriteDriver needs to handle failures of materializeConstant gracefully. Previously it was not checking whether the returned op was null and crashing.


Full diff: https://github.com/llvm/llvm-project/pull/77258.diff

2 Files Affected:

  • (modified) mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp (+26-5)
  • (modified) mlir/test/Transforms/canonicalize.mlir (+11)
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 82438e2bf706c1..146091f96b8fd3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       SmallVector<OpFoldResult> foldResults;
       if (succeeded(op->fold(foldResults))) {
         LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-        changed = true;
         if (foldResults.empty()) {
           // Op was modified in-place.
           notifyOperationModified(op);
+          changed = true;
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
           if (config.scope && failed(verify(config.scope->getParentOp())))
             llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
         OpBuilder::InsertionGuard g(*this);
         setInsertionPoint(op);
         SmallVector<Value> replacements;
+        bool materializationSucceeded = true;
         for (auto [ofr, resultType] :
              llvm::zip_equal(foldResults, op->getResultTypes())) {
           if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
           // Materialize Attributes as SSA values.
           Operation *constOp = op->getDialect()->materializeConstant(
               *this, ofr.get<Attribute>(), resultType, op->getLoc());
+
+          if (!constOp) {
+            // If materialization fails, cleanup any operations generated for
+            // the previous results.
+            llvm::SmallDenseSet<Operation *> replacementOps;
+            std::transform(replacements.begin(), replacements.end(),
+                           replacementOps.begin(), [](Value replacement) {
+                             return replacement.getDefiningOp();
+                           });
+            for (Operation *op : replacementOps)
+              eraseOp(op);
+
+            materializationSucceeded = false;
+            break;
+          }
+
           assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
                  "materializeConstant produced op that is not a ConstantLike");
           assert(constOp->getResultTypes()[0] == resultType &&
                  "materializeConstant produced incorrect result type");
           replacements.push_back(constOp->getResult(0));
         }
-        replaceOp(op, replacements);
+
+        if (materializationSucceeded) {
+          replaceOp(op, replacements);
+          changed = true;
 #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-        if (config.scope && failed(verify(config.scope->getParentOp())))
-          llvm::report_fatal_error("IR failed to verify after folding");
+          if (config.scope && failed(verify(config.scope->getParentOp())))
+            llvm::report_fatal_error("IR failed to verify after folding");
 #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
-        continue;
+          continue;
+        }
       }
     }
 
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 47a19bb598c25e..9b578e6c2631a7 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
 // CHECK-NEXT: scf.yield %[[ALLOC3_2]]
 //      CHECK: memref.dealloc %[[ALLOC1]]
 // CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @test_materialize_failure
+func.func @test_materialize_failure() -> i64 {
+  %const = index.constant 1234
+  // Cannot materialize this castu's output constant.
+  // CHECK: index.castu
+  %u = index.castu %const : index to i64
+  return %u: i64
+}

@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

materializeConstant will send a notifyOperationInserted which is not correct if we erase the ops again. We need a new temporary OpBuilder instead of *this and if all constants materialized successfully, trigger a manual this->notifyOperationInserted for each materialized constant.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you are erasing the ops with eraseOp(op);, which will send a notifyOperationRemoved. So it should be fine.

@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, you are erasing the ops with eraseOp(op);, which will send a notifyOperationRemoved. So it should be fine.

return replacement.getDefiningOp();
});
for (Operation *op : replacementOps)
eraseOp(op);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe assert that materialized constant does not have any uses, to protect against "smart" materializers that try to reuse an existing op.

(We should fail here anyway, but then we have a nicer error message.)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah good point. Let me add the check.

@zyx-billy zyx-billy merged commit eb42868 into llvm:main Jan 8, 2024
4 checks passed
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…er (llvm#77258)

Make GreedyPatternRewriteDriver handle failures of `materializeConstant`
gracefully. Previously it was not checking whether the returned op was
null and crashing. This PR handles it similarly to how OperationFolder
does it.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
mlir:core MLIR Core Infrastructure mlir
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants