diff --git a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp index c3fb73acf5ef0..7607f7068d708 100644 --- a/mlir/lib/Interfaces/ControlFlowInterfaces.cpp +++ b/mlir/lib/Interfaces/ControlFlowInterfaces.cpp @@ -799,6 +799,44 @@ static llvm::EquivalenceClasses computeTiedSuccessorInputs( return tiedSuccessorInputs; } +/// Successor input mappings are edge-local: when control flow paths are pruned +/// by constants, some edge pairs may disappear and a pure edge-based tie +/// relation can miss structural couplings between op results and region block +/// arguments. For single-region region branch ops (e.g. `scf.for`), tie parent +/// successor inputs and region successor inputs by slot so canonicalizations +/// only erase such values together. +static void tieRegionAndParentSuccessorInputs( + RegionBranchOpInterface regionBranchOp, + llvm::EquivalenceClasses &tiedSuccessorInputs) { + if (regionBranchOp->getNumRegions() != 1) + return; + + ValueRange parentInputs = + regionBranchOp.getSuccessorInputs(RegionSuccessor::parent()); + if (parentInputs.empty()) + return; + + SmallVector regionInputs; + for (Region ®ion : regionBranchOp->getRegions()) { + ValueRange inputs = + regionBranchOp.getSuccessorInputs(RegionSuccessor(®ion)); + if (!inputs.empty()) + regionInputs.push_back(inputs); + } + if (regionInputs.empty()) + return; + + for (ValueRange inputs : regionInputs) { + unsigned commonInputCount = + std::min(parentInputs.size(), inputs.size()); + for (unsigned i = 0; i < commonInputCount; ++i) { + tiedSuccessorInputs.insert(parentInputs[i]); + tiedSuccessorInputs.insert(inputs[i]); + tiedSuccessorInputs.unionSets(parentInputs[i], inputs[i]); + } + } +} + /// Remove dead successor inputs from region branch ops. A successor input is /// dead if it has no uses. Successor inputs come in sets of tied values: if /// you remove one value from a set, you must remove all values from the set. @@ -856,6 +894,7 @@ struct RemoveDeadRegionBranchOpSuccessorInputs : public RewritePattern { regionBranchOp.getSuccessorOperandInputMapping(operandToInputs); llvm::EquivalenceClasses tiedSuccessorInputs = computeTiedSuccessorInputs(operandToInputs); + tieRegionAndParentSuccessorInputs(regionBranchOp, tiedSuccessorInputs); // Determine which values to remove and group them by block and operation. SmallVector valuesToRemove; diff --git a/mlir/test/Dialect/SCF/canonicalize.mlir b/mlir/test/Dialect/SCF/canonicalize.mlir index c324d34942bf8..8ae709d30bec5 100644 --- a/mlir/test/Dialect/SCF/canonicalize.mlir +++ b/mlir/test/Dialect/SCF/canonicalize.mlir @@ -2360,3 +2360,24 @@ func.func @fold_tensor_cast_into_forall_non_sequential_writes( // %0#0 contains %arg1 data; %0#1 contains %arg0 data. return %0#0, %0#1 : tensor, tensor } + +// ----- + +// CHECK-LABEL: func.func @single_iteration_loop_keeps_tied_inputs_valid +func.func @single_iteration_loop_keeps_tied_inputs_valid() { + // CHECK: %[[LB:.*]] = arith.constant 42 + %c42 = arith.constant 42 : index + %c43 = arith.constant 43 : index + %c1 = arith.constant 1 : index + // CHECK: %[[INIT:.*]] = "test.init" + %init = "test.init"() : () -> i32 + // CHECK-NOT: scf.for + // CHECK: %[[VAL:.*]] = "test.op"(%[[LB]], %[[INIT]]) + %0 = scf.for %i = %c42 to %c43 step %c1 iter_args(%arg = %init) -> (i32) { + %1 = "test.op"(%i, %init) : (index, i32) -> i32 + scf.yield %1 : i32 + } + // CHECK: "test.consume"(%[[VAL]]) : (i32) -> () + "test.consume"(%0) : (i32) -> () + return +}