-
Notifications
You must be signed in to change notification settings - Fork 11.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver #77258
[MLIR] Handle materializeConstant failure in GreedyPatternRewriteDriver #77258
Conversation
@llvm/pr-subscribers-mlir Author: Billy Zhu (zyx-billy) ChangesGreedyPatternRewriteDriver needs to handle failures of Full diff: https://github.com/llvm/llvm-project/pull/77258.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 82438e2bf706c1..146091f96b8fd3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
+ changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
SmallVector<Value> replacements;
+ bool materializationSucceeded = true;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
+
+ if (!constOp) {
+ // If materialization fails, cleanup any operations generated for
+ // the previous results.
+ llvm::SmallDenseSet<Operation *> replacementOps;
+ std::transform(replacements.begin(), replacements.end(),
+ replacementOps.begin(), [](Value replacement) {
+ return replacement.getDefiningOp();
+ });
+ for (Operation *op : replacementOps)
+ eraseOp(op);
+
+ materializationSucceeded = false;
+ break;
+ }
+
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
- replaceOp(op, replacements);
+
+ if (materializationSucceeded) {
+ replaceOp(op, replacements);
+ changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ if (config.scope && failed(verify(config.scope->getParentOp())))
+ llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- continue;
+ continue;
+ }
}
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 47a19bb598c25e..9b578e6c2631a7 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
// CHECK-NEXT: scf.yield %[[ALLOC3_2]]
// CHECK: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @test_materialize_failure
+func.func @test_materialize_failure() -> i64 {
+ %const = index.constant 1234
+ // Cannot materialize this castu's output constant.
+ // CHECK: index.castu
+ %u = index.castu %const : index to i64
+ return %u: i64
+}
|
@llvm/pr-subscribers-mlir-core Author: Billy Zhu (zyx-billy) ChangesGreedyPatternRewriteDriver needs to handle failures of Full diff: https://github.com/llvm/llvm-project/pull/77258.diff 2 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 82438e2bf706c1..146091f96b8fd3 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -434,10 +434,10 @@ bool GreedyPatternRewriteDriver::processWorklist() {
SmallVector<OpFoldResult> foldResults;
if (succeeded(op->fold(foldResults))) {
LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
- changed = true;
if (foldResults.empty()) {
// Op was modified in-place.
notifyOperationModified(op);
+ changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
if (config.scope && failed(verify(config.scope->getParentOp())))
llvm::report_fatal_error("IR failed to verify after folding");
@@ -451,6 +451,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
OpBuilder::InsertionGuard g(*this);
setInsertionPoint(op);
SmallVector<Value> replacements;
+ bool materializationSucceeded = true;
for (auto [ofr, resultType] :
llvm::zip_equal(foldResults, op->getResultTypes())) {
if (auto value = ofr.dyn_cast<Value>()) {
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() {
// Materialize Attributes as SSA values.
Operation *constOp = op->getDialect()->materializeConstant(
*this, ofr.get<Attribute>(), resultType, op->getLoc());
+
+ if (!constOp) {
+ // If materialization fails, cleanup any operations generated for
+ // the previous results.
+ llvm::SmallDenseSet<Operation *> replacementOps;
+ std::transform(replacements.begin(), replacements.end(),
+ replacementOps.begin(), [](Value replacement) {
+ return replacement.getDefiningOp();
+ });
+ for (Operation *op : replacementOps)
+ eraseOp(op);
+
+ materializationSucceeded = false;
+ break;
+ }
+
assert(constOp->hasTrait<OpTrait::ConstantLike>() &&
"materializeConstant produced op that is not a ConstantLike");
assert(constOp->getResultTypes()[0] == resultType &&
"materializeConstant produced incorrect result type");
replacements.push_back(constOp->getResult(0));
}
- replaceOp(op, replacements);
+
+ if (materializationSucceeded) {
+ replaceOp(op, replacements);
+ changed = true;
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- if (config.scope && failed(verify(config.scope->getParentOp())))
- llvm::report_fatal_error("IR failed to verify after folding");
+ if (config.scope && failed(verify(config.scope->getParentOp())))
+ llvm::report_fatal_error("IR failed to verify after folding");
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
- continue;
+ continue;
+ }
}
}
diff --git a/mlir/test/Transforms/canonicalize.mlir b/mlir/test/Transforms/canonicalize.mlir
index 47a19bb598c25e..9b578e6c2631a7 100644
--- a/mlir/test/Transforms/canonicalize.mlir
+++ b/mlir/test/Transforms/canonicalize.mlir
@@ -1224,3 +1224,14 @@ func.func @clone_nested_region(%arg0: index, %arg1: index, %arg2: index) -> memr
// CHECK-NEXT: scf.yield %[[ALLOC3_2]]
// CHECK: memref.dealloc %[[ALLOC1]]
// CHECK-NEXT: return %[[ALLOC2]]
+
+// -----
+
+// CHECK-LABEL: func @test_materialize_failure
+func.func @test_materialize_failure() -> i64 {
+ %const = index.constant 1234
+ // Cannot materialize this castu's output constant.
+ // CHECK: index.castu
+ %u = index.castu %const : index to i64
+ return %u: i64
+}
|
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() { | |||
// Materialize Attributes as SSA values. | |||
Operation *constOp = op->getDialect()->materializeConstant( | |||
*this, ofr.get<Attribute>(), resultType, op->getLoc()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
materializeConstant
will send a notifyOperationInserted
which is not correct if we erase the ops again. We need a new temporary OpBuilder
instead of *this
and if all constants materialized successfully, trigger a manual this->notifyOperationInserted
for each materialized constant.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, you are erasing the ops with eraseOp(op);
, which will send a notifyOperationRemoved
. So it should be fine.
@@ -462,18 +463,38 @@ bool GreedyPatternRewriteDriver::processWorklist() { | |||
// Materialize Attributes as SSA values. | |||
Operation *constOp = op->getDialect()->materializeConstant( | |||
*this, ofr.get<Attribute>(), resultType, op->getLoc()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, you are erasing the ops with eraseOp(op);
, which will send a notifyOperationRemoved
. So it should be fine.
return replacement.getDefiningOp(); | ||
}); | ||
for (Operation *op : replacementOps) | ||
eraseOp(op); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Maybe assert that materialized constant does not have any uses, to protect against "smart" materializers that try to reuse an existing op.
(We should fail here anyway, but then we have a nicer error message.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah good point. Let me add the check.
…er (llvm#77258) Make GreedyPatternRewriteDriver handle failures of `materializeConstant` gracefully. Previously it was not checking whether the returned op was null and crashing. This PR handles it similarly to how OperationFolder does it.
Make GreedyPatternRewriteDriver handle failures of
materializeConstant
gracefully. Previously it was not checking whether the returned op was null and crashing. This PR handles it similarly to how OperationFolder does it.