diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index 70b56ca77b2da..a93e605445465 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( return; } - /// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep() - /// on a LoopLikeInterface return the lower/upper bound for that result if - /// possible. - auto getLoopBoundFromFold = [&](std::optional loopBound, - Type boundType, Block *block, bool getUpper) { + /// Given a lower bound, upper bound, or step from a LoopLikeInterface return + /// the lower/upper bound for that result if possible. + auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType, + Block *block, bool getUpper) { unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType); - if (loopBound.has_value()) { - if (auto attr = dyn_cast(*loopBound)) { - if (auto bound = dyn_cast_or_null(attr)) - return bound.getValue(); - } else if (auto value = llvm::dyn_cast_if_present(*loopBound)) { - const IntegerValueRangeLattice *lattice = - getLatticeElementFor(getProgramPointBefore(block), value); - if (lattice != nullptr && !lattice->getValue().isUninitialized()) - return getUpper ? lattice->getValue().getValue().smax() - : lattice->getValue().getValue().smin(); - } + if (auto attr = dyn_cast(loopBound)) { + if (auto bound = dyn_cast(attr)) + return bound.getValue(); + } else if (auto value = llvm::dyn_cast(loopBound)) { + const IntegerValueRangeLattice *lattice = + getLatticeElementFor(getProgramPointBefore(block), value); + if (lattice != nullptr && !lattice->getValue().isUninitialized()) + return getUpper ? lattice->getValue().getValue().smax() + : lattice->getValue().getValue().smin(); } // Given the results of getConstant{Lower,Upper}Bound() // or getConstantStep() on a LoopLikeInterface return the lower/upper @@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments( // Infer bounds for loop arguments that have static bounds if (auto loop = dyn_cast(op)) { - std::optional iv = loop.getSingleInductionVar(); - if (!iv) { + std::optional> maybeIvs = + loop.getLoopInductionVars(); + if (!maybeIvs) { return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments( op, successor, argLattices, firstIndex); } - Block *block = iv->getParentBlock(); - std::optional lowerBound = loop.getSingleLowerBound(); - std::optional upperBound = loop.getSingleUpperBound(); - std::optional step = loop.getSingleStep(); - APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block, - /*getUpper=*/false); - APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block, - /*getUpper=*/true); - // Assume positivity for uniscoverable steps by way of getUpper = true. - APInt stepVal = - getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true); - - if (stepVal.isNegative()) { - std::swap(min, max); - } else { - // Correct the upper bound by subtracting 1 so that it becomes a <= - // bound, because loops do not generally include their upper bound. - max -= 1; - } + // This shouldn't be returning nullopt if there are indunction variables. + SmallVector lowerBounds = *loop.getLoopLowerBounds(); + SmallVector upperBounds = *loop.getLoopUpperBounds(); + SmallVector steps = *loop.getLoopSteps(); + for (auto [iv, lowerBound, upperBound, step] : + llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) { + Block *block = iv.getParentBlock(); + APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block, + /*getUpper=*/false); + APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block, + /*getUpper=*/true); + // Assume positivity for uniscoverable steps by way of getUpper = true. + APInt stepVal = + getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true); + + if (stepVal.isNegative()) { + std::swap(min, max); + } else { + // Correct the upper bound by subtracting 1 so that it becomes a <= + // bound, because loops do not generally include their upper bound. + max -= 1; + } - // If we infer the lower bound to be larger than the upper bound, the - // resulting range is meaningless and should not be used in further - // inferences. - if (max.sge(min)) { - IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv); - auto ivRange = ConstantIntRanges::fromSigned(min, max); - propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); + // If we infer the lower bound to be larger than the upper bound, the + // resulting range is meaningless and should not be used in further + // inferences. + if (max.sge(min)) { + IntegerValueRangeLattice *ivEntry = getLatticeElement(iv); + auto ivRange = ConstantIntRanges::fromSigned(min, max); + propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange})); + } } return; } diff --git a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir index b98e8b07db5ce..c6344447d9f74 100644 --- a/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir +++ b/mlir/test/Interfaces/InferIntRangeInterface/infer-int-range-test-ops.mlir @@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) { } return } + +// CHECK-LABEL: func @multiple_loop_ivs +func.func @multiple_loop_ivs(%arg0: memref) { + %ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index, + smin = 1 : index, smax = 32 : index } : index + %c0_i32 = arith.constant 0 : i32 + // CHECK: scf.forall + scf.forall (%arg1, %arg2) in (%ub1, 64) { + // CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index} + %1 = test.reflect_bounds %arg1 : index + // CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index} + %2 = test.reflect_bounds %arg2 : index + memref.store %c0_i32, %arg0[%1, %2] : memref + } + return +}