diff --git a/mlir/lib/Analysis/SliceWalk.cpp b/mlir/lib/Analysis/SliceWalk.cpp index 5c5a68ef11b36..d48c9a1bb4ad1 100644 --- a/mlir/lib/Analysis/SliceWalk.cpp +++ b/mlir/lib/Analysis/SliceWalk.cpp @@ -68,11 +68,17 @@ mlir::getControlFlowPredecessors(Value value) { if (!regionOp) return std::nullopt; // Add the control flow predecessor operands to the work list. - RegionSuccessor region = RegionSuccessor::parent(); + RegionSuccessor parentSuccessor = RegionSuccessor::parent(); + // Find the position of `opResult` in the successor inputs of the parent. + // `getPredecessorValues` indexes into the successor inputs, not into the + // op results directly, since some results may not be successor inputs. + ValueRange successorInputs = regionOp.getSuccessorInputs(parentSuccessor); + auto it = llvm::find(successorInputs, opResult); + if (it == successorInputs.end()) + return std::nullopt; SmallVector predecessorOperands; - // TODO (#175168): This assumes that there are no non-successor-inputs - // in front of the op result. - regionOp.getPredecessorValues(region, opResult.getResultNumber(), + regionOp.getPredecessorValues(parentSuccessor, + std::distance(successorInputs.begin(), it), predecessorOperands); return predecessorOperands; } @@ -83,12 +89,20 @@ mlir::getControlFlowPredecessors(Value value) { if (block->isEntryBlock()) { if (auto regionBranchOp = dyn_cast(block->getParentOp())) { - RegionSuccessor region(blockArg.getParentRegion()); + RegionSuccessor regionSuccessor(blockArg.getParentRegion()); + // Find the position of `blockArg` in the successor inputs of the region. + // `getPredecessorValues` indexes into the successor inputs, not into the + // block arguments directly, since some block arguments may not be + // successor inputs (e.g., block arguments produced by the terminator). + ValueRange successorInputs = + regionBranchOp.getSuccessorInputs(regionSuccessor); + auto it = llvm::find(successorInputs, blockArg); + if (it == successorInputs.end()) + return std::nullopt; SmallVector predecessorOperands; - // TODO (#175168): This assumes that there are no non-successor-inputs - // in front of the block argument. - regionBranchOp.getPredecessorValues(region, blockArg.getArgNumber(), - predecessorOperands); + regionBranchOp.getPredecessorValues( + regionSuccessor, std::distance(successorInputs.begin(), it), + predecessorOperands); return predecessorOperands; } // If the interface is not implemented, there are no control flow diff --git a/mlir/test/Analysis/test-control-flow-predecessors.mlir b/mlir/test/Analysis/test-control-flow-predecessors.mlir new file mode 100644 index 0000000000000..e345363e0dec6 --- /dev/null +++ b/mlir/test/Analysis/test-control-flow-predecessors.mlir @@ -0,0 +1,41 @@ +// RUN: mlir-opt --mlir-disable-threading -pass-pipeline="builtin.module(any(test-control-flow-predecessors))" %s 2>&1 | FileCheck %s + +// Test that getControlFlowPredecessors correctly handles values that are not +// successor inputs (issue #175168). Before the fix, these cases would either +// crash (out-of-bounds index) or silently return wrong results. + +// Test case 1: scf.for with no iter_args. +// The induction variable (%iv) is block arg #0 and is NOT in +// getSuccessorInputs(body region), since there are no iter_args. +// Before the fix, getControlFlowPredecessors(%iv) would call +// getPredecessorValues with index 0 on an empty successor-inputs range, +// causing an out-of-bounds access. + +// CHECK-LABEL: Control flow predecessors for '"test_scf_for_no_iter_args"' +// CHECK: 'scf.for': region #0 block arg #0: no predecessors +func.func @test_scf_for_no_iter_args(%lb: index, %ub: index, %step: index) { + scf.for %iv = %lb to %ub step %step { + } + return +} + +// Test case 2: test.loop_with_extra_result. +// - result #0 (extraResult): NOT in getSuccessorInputs(parent) — no predecessors. +// - result #1 (iterResult): IS in getSuccessorInputs(parent) — has predecessors. +// - body block arg #0 (extraArg): NOT in getSuccessorInputs(body) — no predecessors. +// - body block arg #1 (iterArg): IS in getSuccessorInputs(body) — has predecessors. +// Before the fix, querying extraResult/extraArg would use the wrong index into +// getSuccessorInputs, returning a predecessor for the wrong value. + +// CHECK-LABEL: Control flow predecessors for '"test_loop_with_extra_result"' +// CHECK: 'test.loop_with_extra_result': result #0: no predecessors +// CHECK: 'test.loop_with_extra_result': result #1: 1 predecessor(s) +// CHECK: 'test.loop_with_extra_result': region #0 block arg #0: no predecessors +// CHECK: 'test.loop_with_extra_result': region #0 block arg #1: 2 predecessor(s) +func.func @test_loop_with_extra_result(%init: i32) { + %extra, %iter = test.loop_with_extra_result %init : i32 -> (i32, i32) { + ^bb0(%extra_arg: i32, %iter_arg: i32): + test.loop_with_extra_result_yield %iter_arg : i32 + } + return +} diff --git a/mlir/test/lib/Analysis/TestSlice.cpp b/mlir/test/lib/Analysis/TestSlice.cpp index 7e8320dbf3ec3..c7fc6df8cc7c6 100644 --- a/mlir/test/lib/Analysis/TestSlice.cpp +++ b/mlir/test/lib/Analysis/TestSlice.cpp @@ -6,9 +6,12 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Analysis/SliceWalk.h" #include "mlir/Analysis/TopologicalSortUtils.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/ControlFlowInterfaces.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" using namespace mlir; @@ -42,12 +45,60 @@ struct TestTopologicalSortPass } }; +/// Pass to test getControlFlowPredecessors from SliceWalk. +struct TestControlFlowPredecessorsPass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestControlFlowPredecessorsPass) + + StringRef getArgument() const final { + return "test-control-flow-predecessors"; + } + StringRef getDescription() const final { + return "Test getControlFlowPredecessors from SliceWalk."; + } + + void runOnOperation() override { + FunctionOpInterface func = getOperation(); + llvm::errs() << "Control flow predecessors for '" << func.getNameAttr() + << "':\n"; + func->walk([](Operation *op) { + if (!isa(op)) + return; + for (OpResult result : op->getResults()) { + auto predecessors = mlir::getControlFlowPredecessors(result); + llvm::errs() << " '" << op->getName() << "': result #" + << result.getResultNumber() << ": "; + if (!predecessors) + llvm::errs() << "no predecessors\n"; + else + llvm::errs() << predecessors->size() << " predecessor(s)\n"; + } + for (Region ®ion : op->getRegions()) { + if (region.empty()) + continue; + for (BlockArgument arg : region.front().getArguments()) { + auto predecessors = mlir::getControlFlowPredecessors(arg); + llvm::errs() << " '" << op->getName() << "': region #" + << region.getRegionNumber() << " block arg #" + << arg.getArgNumber() << ": "; + if (!predecessors) + llvm::errs() << "no predecessors\n"; + else + llvm::errs() << predecessors->size() << " predecessor(s)\n"; + } + } + }); + } +}; + } // namespace namespace mlir { namespace test { void registerTestSliceAnalysisPass() { PassRegistration(); + PassRegistration(); } } // namespace test } // namespace mlir diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index ed5dc5bead78a..358110d87b317 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1313,6 +1313,37 @@ LoopBlockTerminatorOp::getMutableSuccessorOperands(RegionSuccessor successor) { return getNextIterArgMutable(); } +//===----------------------------------------------------------------------===// +// LoopWithExtraResultOp / LoopWithExtraResultYieldOp +//===----------------------------------------------------------------------===// + +void LoopWithExtraResultOp::getSuccessorRegions( + RegionBranchPoint point, SmallVectorImpl ®ions) { + // Parent always enters the body; the body can loop back or exit to parent. + regions.emplace_back(&getBody()); + if (!point.isParent()) + regions.push_back(RegionSuccessor::parent()); +} + +ValueRange +LoopWithExtraResultOp::getSuccessorInputs(RegionSuccessor successor) { + // When branching to the parent, only iterResult (result #1) is a successor + // input; extraResult (#0) is not. Similaly when branching to the body, only + // body block arg #1 (iterArg) is a successor input; arg #0 is not. + if (successor.isParent()) + return getResults().drop_front(1); + return getBody().getArguments().drop_front(1); +} + +OperandRange LoopWithExtraResultOp::getEntrySuccessorOperands(RegionSuccessor) { + return MutableOperandRange(getInitMutable()); +} + +MutableOperandRange +LoopWithExtraResultYieldOp::getMutableSuccessorOperands(RegionSuccessor) { + return getIterArgMutable(); +} + //===----------------------------------------------------------------------===// // TestCrashingReturnOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 4c9e6b3fe9e45..83408bdb0d0a4 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2731,6 +2731,33 @@ def LoopBlockTerminatorOp : TEST_Op<"loop_block_term", }]; } +// An op that has a result and a block argument that are NOT successor inputs, +// used to test the fix for issue #175168 in getControlFlowPredecessors. +// - result #0 (extraResult): NOT a successor input to parent +// - result #1 (iterResult): IS a successor input to parent +// - body block arg #0: NOT a successor input to the body region +// - body block arg #1: IS a successor input to the body region +// This mirrors the scf.for induction variable / iter_arg structure. +def LoopWithExtraResultYieldOp : TEST_Op<"loop_with_extra_result_yield", + [DeclareOpInterfaceMethods, + Pure, Terminator]> { + let arguments = (ins I32:$iterArg); + let assemblyFormat = "$iterArg `:` type($iterArg) attr-dict"; +} + +def LoopWithExtraResultOp : TEST_Op<"loop_with_extra_result", + [DeclareOpInterfaceMethods, + RecursiveMemoryEffects]> { + let results = (outs I32:$extraResult, I32:$iterResult); + let arguments = (ins I32:$init); + let regions = (region SizedRegion<1>:$body); + let assemblyFormat = [{ + $init `:` type($init) `->` `(` type($extraResult) `,` type($iterResult) `)` $body attr-dict + }]; +} + def TestNoTerminatorOp : TEST_Op<"switch_with_no_break", [ NoTerminator, DeclareOpInterfaceMethods