diff --git a/mlir/lib/Transforms/RemoveDeadValues.cpp b/mlir/lib/Transforms/RemoveDeadValues.cpp index 66f369e8a5f65..12a47ba2fb65a 100644 --- a/mlir/lib/Transforms/RemoveDeadValues.cpp +++ b/mlir/lib/Transforms/RemoveDeadValues.cpp @@ -720,6 +720,28 @@ static void cleanUpDeadVals(MLIRContext *ctx, RDVFinalCleanupList &list) { // When erasing a terminator, insert an unreachable op in its place. ub::UnreachableOp::create(rewriter, op->getLoc()); } + + // Before erasing the operation, replace all result values with live-uses by + // ub.poison values. This is important to maintain IR validity. For example, + // if we have an op with one of its results used by another op, erasing the + // op without replacing its corresponding result would leave us with a + // dangling operand in the user op. By replacing the result with a ub.poison + // value, we ensure that the user op still has a valid operand, even though + // it's a poison value which will be cleaned up later if it can be cleaned + // up. This keeps the IR valid for further simplification and + // canonicalization. + auto opResults = op->getResults(); + for (Value opResult : opResults) { + // Early continue for the case where the op result has no uses. No need to + // create a poison op here. + if (opResult.use_empty()) + continue; + + rewriter.setInsertionPoint(op); + Value poisonedValue = createPoisonedValues(rewriter, opResult).front(); + rewriter.replaceAllUsesWith(opResult, poisonedValue); + } + op->dropAllUses(); rewriter.eraseOp(op); } diff --git a/mlir/test/Transforms/remove-dead-values.mlir b/mlir/test/Transforms/remove-dead-values.mlir index ae83eac0c376f..87e77b2eb700f 100644 --- a/mlir/test/Transforms/remove-dead-values.mlir +++ b/mlir/test/Transforms/remove-dead-values.mlir @@ -796,3 +796,33 @@ func.func @scf_while_dead_iter_args() -> i32 { } return %result#0 : i32 } + +// ----- + +// CHECK-LABEL: func.func @replace_dead_operation_results_with_poison +func.func @replace_dead_operation_results_with_poison(%0: vector<1xindex>) -> vector<1xindex> { + %1 = scf.while (%arg0 = %0) : (vector<1xindex>) -> vector<1xindex> { + %cond = arith.constant true + scf.condition(%cond) %arg0 : vector<1xindex> + } do { + ^bb0(%arg0: vector<1xindex>): + scf.yield %arg0 : vector<1xindex> + } + %2 = scf.while (%arg0 = %1) : (vector<1xindex>) -> vector<1xindex> { + // Check that the binary value in condition is replaced with poison, and + // the condition itself is well-formed IR. This prevents a crash in the + // canonicalization phase which happens after the dead value removal phase. + // Also check that only used results of an erased op are replaced with ub.poison. + // CHECK-CANONICALIZE: %[[COND:.*]] = ub.poison : i1 + // CHECK-CANONICALIZE-NEXT: %[[NEXT:.*]] = ub.poison : vector<1xindex> + // CHECK-CANONICALIZE-NEXT: scf.condition(%[[COND]]) %[[NEXT]] + // CHECK-CANONICALIZE-NOT: ub.poison : i32 + // CHECK-CANONICALIZE-NOT: "test.three" + %cond, %unused, %next = "test.three"(%1) : (vector<1xindex>) -> (i1, i32, vector<1xindex>) + scf.condition(%cond) %next : vector<1xindex> + } do { + ^bb0(%arg0: vector<1xindex>): + scf.yield %arg0 : vector<1xindex> + } + return %2 : vector<1xindex> +}