Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 46 additions & 44 deletions mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -180,23 +180,20 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(
return;
}

/// Given the results of getConstant{Lower,Upper}Bound() or getConstantStep()
/// on a LoopLikeInterface return the lower/upper bound for that result if
/// possible.
auto getLoopBoundFromFold = [&](std::optional<OpFoldResult> loopBound,
Type boundType, Block *block, bool getUpper) {
/// Given a lower bound, upper bound, or step from a LoopLikeInterface return
/// the lower/upper bound for that result if possible.
auto getLoopBoundFromFold = [&](OpFoldResult loopBound, Type boundType,
Block *block, bool getUpper) {
unsigned int width = ConstantIntRanges::getStorageBitwidth(boundType);
if (loopBound.has_value()) {
if (auto attr = dyn_cast<Attribute>(*loopBound)) {
if (auto bound = dyn_cast_or_null<IntegerAttr>(attr))
return bound.getValue();
} else if (auto value = llvm::dyn_cast_if_present<Value>(*loopBound)) {
const IntegerValueRangeLattice *lattice =
getLatticeElementFor(getProgramPointBefore(block), value);
if (lattice != nullptr && !lattice->getValue().isUninitialized())
return getUpper ? lattice->getValue().getValue().smax()
: lattice->getValue().getValue().smin();
}
if (auto attr = dyn_cast<Attribute>(loopBound)) {
if (auto bound = dyn_cast<IntegerAttr>(attr))
return bound.getValue();
} else if (auto value = llvm::dyn_cast<Value>(loopBound)) {
const IntegerValueRangeLattice *lattice =
getLatticeElementFor(getProgramPointBefore(block), value);
if (lattice != nullptr && !lattice->getValue().isUninitialized())
return getUpper ? lattice->getValue().getValue().smax()
: lattice->getValue().getValue().smin();
}
// Given the results of getConstant{Lower,Upper}Bound()
// or getConstantStep() on a LoopLikeInterface return the lower/upper
Expand All @@ -207,38 +204,43 @@ void IntegerRangeAnalysis::visitNonControlFlowArguments(

// Infer bounds for loop arguments that have static bounds
if (auto loop = dyn_cast<LoopLikeOpInterface>(op)) {
std::optional<Value> iv = loop.getSingleInductionVar();
if (!iv) {
std::optional<llvm::SmallVector<Value>> maybeIvs =
loop.getLoopInductionVars();
if (!maybeIvs) {
return SparseForwardDataFlowAnalysis ::visitNonControlFlowArguments(
op, successor, argLattices, firstIndex);
}
Block *block = iv->getParentBlock();
std::optional<OpFoldResult> lowerBound = loop.getSingleLowerBound();
std::optional<OpFoldResult> upperBound = loop.getSingleUpperBound();
std::optional<OpFoldResult> step = loop.getSingleStep();
APInt min = getLoopBoundFromFold(lowerBound, iv->getType(), block,
/*getUpper=*/false);
APInt max = getLoopBoundFromFold(upperBound, iv->getType(), block,
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
getLoopBoundFromFold(step, iv->getType(), block, /*getUpper=*/true);

if (stepVal.isNegative()) {
std::swap(min, max);
} else {
// Correct the upper bound by subtracting 1 so that it becomes a <=
// bound, because loops do not generally include their upper bound.
max -= 1;
}
// This shouldn't be returning nullopt if there are indunction variables.
SmallVector<OpFoldResult> lowerBounds = *loop.getLoopLowerBounds();
SmallVector<OpFoldResult> upperBounds = *loop.getLoopUpperBounds();
SmallVector<OpFoldResult> steps = *loop.getLoopSteps();
for (auto [iv, lowerBound, upperBound, step] :
llvm::zip_equal(*maybeIvs, lowerBounds, upperBounds, steps)) {
Block *block = iv.getParentBlock();
APInt min = getLoopBoundFromFold(lowerBound, iv.getType(), block,
/*getUpper=*/false);
APInt max = getLoopBoundFromFold(upperBound, iv.getType(), block,
/*getUpper=*/true);
// Assume positivity for uniscoverable steps by way of getUpper = true.
APInt stepVal =
getLoopBoundFromFold(step, iv.getType(), block, /*getUpper=*/true);

if (stepVal.isNegative()) {
std::swap(min, max);
} else {
// Correct the upper bound by subtracting 1 so that it becomes a <=
// bound, because loops do not generally include their upper bound.
max -= 1;
}

// If we infer the lower bound to be larger than the upper bound, the
// resulting range is meaningless and should not be used in further
// inferences.
if (max.sge(min)) {
IntegerValueRangeLattice *ivEntry = getLatticeElement(*iv);
auto ivRange = ConstantIntRanges::fromSigned(min, max);
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
// If we infer the lower bound to be larger than the upper bound, the
// resulting range is meaningless and should not be used in further
// inferences.
if (max.sge(min)) {
IntegerValueRangeLattice *ivEntry = getLatticeElement(iv);
auto ivRange = ConstantIntRanges::fromSigned(min, max);
propagateIfChanged(ivEntry, ivEntry->join(IntegerValueRange{ivRange}));
}
}
return;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,19 @@ func.func @propagate_from_block_to_iterarg(%arg0: index, %arg1: i1) {
}
return
}

// CHECK-LABEL: func @multiple_loop_ivs
func.func @multiple_loop_ivs(%arg0: memref<?x64xi32>) {
%ub1 = test.with_bounds { umin = 1 : index, umax = 32 : index,
smin = 1 : index, smax = 32 : index } : index
%c0_i32 = arith.constant 0 : i32
// CHECK: scf.forall
scf.forall (%arg1, %arg2) in (%ub1, 64) {
// CHECK: test.reflect_bounds {smax = 31 : index, smin = 0 : index, umax = 31 : index, umin = 0 : index}
%1 = test.reflect_bounds %arg1 : index
// CHECK-NEXT: test.reflect_bounds {smax = 63 : index, smin = 0 : index, umax = 63 : index, umin = 0 : index}
%2 = test.reflect_bounds %arg2 : index
memref.store %c0_i32, %arg0[%1, %2] : memref<?x64xi32>
}
return
}
Loading