diff --git a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h index 9820a91291fdb..a6d914dfae4ab 100644 --- a/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h +++ b/mlir/include/mlir/Analysis/DataFlow/IntegerRangeAnalysis.h @@ -61,6 +61,21 @@ class IntegerRangeAnalysis ArrayRef operands, ArrayRef results) override; + // Override visitRegionSuccessors to add a visit cap on loop-carried + // lattice elements, preventing non-convergence on scf.while loops with + // dynamic bounds. + // + // Without this cap, loop-carried values whose ranges grow by +1 per + // worklist visit (e.g. [0,0]->[0,1]->[0,2]->...) require O(2^31) + // iterations to converge for i32. The existing widening in + // visitOperation only catches op results yielded directly to a + // terminator, not values propagated through nested region ops like + // scf.if. + void visitRegionSuccessors(ProgramPoint *point, + RegionBranchOpInterface branch, + RegionSuccessor successor, + ArrayRef lattices) override; + /// Visit block arguments or operation results of an operation with region /// control-flow for which values are not defined by region control-flow. This /// function calls `InferIntRangeInterface` to provide values for block @@ -70,6 +85,14 @@ class IntegerRangeAnalysis Operation *op, const RegionSuccessor &successor, ValueRange nonSuccessorInputs, ArrayRef nonSuccessorInputLattices) override; + +private: + // Maximum lattice updates per (loop, element) before forcing max-range. + static constexpr int64_t kMaxLoopVisits = 4; + + // Per-(loop-op, lattice-element) visit counter. + DenseMap, int64_t> + loopVisits; }; /// Succeeds if an op can be converted to its unsigned equivalent without diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 818450e2bc696..5f3e81fe9cb7c 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -137,6 +137,87 @@ LogicalResult IntegerRangeAnalysis::visitOperation( return success(); } +void IntegerRangeAnalysis::visitRegionSuccessors( + ProgramPoint *point, RegionBranchOpInterface branch, + RegionSuccessor successor, + ArrayRef lattices) { + + Operation *branchOp = branch.getOperation(); + bool isLoop = isa(branchOp); + + // For non-loop regions (scf.if, scf.index_switch, etc.), delegate to + // the upstream implementation — no visit cap needed. + if (!isLoop) { + SparseForwardDataFlowAnalysis::visitRegionSuccessors( + point, branch, successor, lattices); + return; + } + + // For loops: replicate base class implementation with a visit cap. + const auto *predecessors = + getOrCreateFor(point, point); + assert(predecessors->allPredecessorsKnown() && + "unexpected unresolved region successors"); + + for (Operation *op : predecessors->getKnownPredecessors()) { + std::optional operands; + if (op == branch) { + operands = branch.getEntrySuccessorOperands(successor); + } else if (auto regionTerminator = + dyn_cast(op)) { + operands = regionTerminator.getSuccessorOperands(successor); + } + if (!operands) { + setAllToEntryStates(lattices); + return; + } + + ValueRange inputs = predecessors->getSuccessorInputs(op); + assert(inputs.size() == operands->size() && + "expected the same number of successor inputs as operands"); + + unsigned firstIndex = 0; + if (inputs.size() != lattices.size()) { + if (successor.isParent()) { + if (!inputs.empty()) + firstIndex = cast(inputs.front()).getResultNumber(); + } else { + if (!inputs.empty()) + firstIndex = cast(inputs.front()).getArgNumber(); + } + } + + for (auto [oper, lattice] : + llvm::zip(*operands, ArrayRef(lattices).drop_front(firstIndex))) { + auto key = std::make_pair(branchOp, lattice); + int64_t &visits = loopVisits[key]; + + if (visits >= kMaxLoopVisits) { + // Force to max-range (lattice top) — guarantees convergence. + auto *intLattice = static_cast(lattice); + ChangeResult changed = + intLattice->join(IntegerValueRange::getMaxRange(oper)); + propagateIfChanged(intLattice, changed); + LLVM_DEBUG({ + if (changed == ChangeResult::Change) { + LDBG() << "Forcing max-range after " << visits << " visits for "; + oper.printAsOperand(llvm::dbgs(), {}); + llvm::dbgs() << "\n"; + } + }); + continue; + } + + // Normal join with visit tracking. + ChangeResult changed = + lattice->join(*getLatticeElementFor(point, oper)); + propagateIfChanged(lattice, changed); + if (changed == ChangeResult::Change) + ++visits; + } + } +} + void IntegerRangeAnalysis::visitNonControlFlowArguments( Operation *op, const RegionSuccessor &successor, ValueRange nonSuccessorInputs, diff --git a/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir new file mode 100644 index 0000000000000..d8dabd253c698 --- /dev/null +++ b/mlir/test/Dialect/Arith/int-range-analysis-convergence.mlir @@ -0,0 +1,70 @@ +// IntegerRangeAnalysis non-convergence on scf.while with dynamic bounds. +// +// The carry range ratchets [0,0]->[0,1]->[0,2]->... without bound. +// Two nested scf.if layers with differing arith chains (addi, muli) +// bounded by remui create enough worklist cascade to prevent the +// solver's back-to-back convergence shortcut from firing. +// +// After the fix (visit cap in visitRegionSuccessors), the analysis +// converges in bounded time. +// +// RUN: mlir-opt -int-range-optimizations %s -o /dev/null + +func.func @grouped_gemm_while_hang(%n: i32, %flag: i1) -> i32 { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c3 = arith.constant 3 : i32 + %c7 = arith.constant 7 : i32 + %c127 = arith.constant 127 : i32 + %init = arith.cmpi slt, %c0, %n : i32 + + %res:2 = scf.while (%a0 = %c0, %cond = %init) : (i32, i1) -> (i32, i1) { + scf.condition(%cond) %a0, %cond : i32, i1 + } do { + ^bb0(%b0: i32, %bc: i1): + %t0 = arith.addi %b0, %c1 : i32 + %ic = arith.cmpi slt, %t0, %n : i32 + + %inner:2 = scf.while (%i0 = %t0, %iic = %ic) : (i32, i1) -> (i32, i1) { + scf.condition(%iic) %i0, %iic : i32, i1 + } do { + ^bb1(%j0: i32, %jc: i1): + + // Layer 0: branches must differ to prevent folding. + // remui bounds ranges to [0,126], preventing overflow-cascade + // convergence. Both branches must have ops (not just passthrough) + // to generate enough worklist items. + %L0 = scf.if %flag -> (i32) { + %a0_0 = arith.addi %j0, %c1 : i32 + %a0_1 = arith.muli %a0_0, %c7 : i32 + %a0_r = arith.remui %a0_1, %c127 : i32 + scf.yield %a0_r : i32 + } else { + %b0_0 = arith.addi %j0, %c3 : i32 + %b0_1 = arith.muli %b0_0, %c7 : i32 + %b0_r = arith.remui %b0_1, %c127 : i32 + scf.yield %b0_r : i32 + } + + // Layer 1: second nested scf.if feeds from layer 0. + %L1 = scf.if %flag -> (i32) { + %a1_0 = arith.addi %L0, %c1 : i32 + %a1_1 = arith.muli %a1_0, %c7 : i32 + %a1_r = arith.remui %a1_1, %c127 : i32 + scf.yield %a1_r : i32 + } else { + %b1_0 = arith.addi %L0, %c3 : i32 + %b1_1 = arith.muli %b1_0, %c7 : i32 + %b1_r = arith.remui %b1_1, %c127 : i32 + scf.yield %b1_r : i32 + } + + %nic = arith.cmpi slt, %L1, %n : i32 + scf.yield %L1, %nic : i32, i1 + } + + %nc = arith.cmpi slt, %inner#0, %n : i32 + scf.yield %inner#0, %nc : i32, i1 + } + return %res#0 : i32 +}