Skip to content

Commit

Permalink
Revert "[mlir][SCF] ValueBoundsConstraintSet: Support scf.if (bra…
Browse files Browse the repository at this point in the history
…nches) (#85895)"

This reverts commit 6b30ffe.

gcc7 bot is broken
  • Loading branch information
joker-eph committed Apr 5, 2024
1 parent e5e1bc0 commit 8487e05
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 334 deletions.
43 changes: 9 additions & 34 deletions mlir/include/mlir/Interfaces/ValueBoundsOpInterface.h
Expand Up @@ -203,26 +203,6 @@ class ValueBoundsConstraintSet
std::optional<int64_t> dim1 = std::nullopt,
std::optional<int64_t> dim2 = std::nullopt);

/// Traverse the IR starting from the given value/dim and populate constraints
/// as long as the stop condition holds. Also process all values/dims that are
/// already on the worklist.
void populateConstraints(Value value, std::optional<int64_t> dim);

/// Comparison operator for `ValueBoundsConstraintSet::compare`.
enum ComparisonOperator { LT, LE, EQ, GT, GE };

/// Try to prove that, based on the current state of this constraint set
/// (i.e., without analyzing additional IR or adding new constraints), the
/// "lhs" value/dim is LE/LT/EQ/GT/GE than the "rhs" value/dim.
///
/// Return "true" if the specified relation between the two values/dims was
/// proven to hold. Return "false" if the specified relation could not be
/// proven. This could be because the specified relation does in fact not hold
/// or because there is not enough information in the constraint set. In other
/// words, if we do not know for sure, this function returns "false".
bool compare(Value lhs, std::optional<int64_t> lhsDim, ComparisonOperator cmp,
Value rhs, std::optional<int64_t> rhsDim);

/// Compute whether the given values/dimensions are equal. Return "failure" if
/// equality could not be determined.
///
Expand Down Expand Up @@ -294,13 +274,13 @@ class ValueBoundsConstraintSet

ValueBoundsConstraintSet(MLIRContext *ctx, StopConditionFn stopCondition);

/// Given an affine map with a single result (and map operands), add a new
/// column to the constraint set that represents the result of the map.
/// Traverse additional IR starting from the map operands as needed (as long
/// as the stop condition is not satisfied). Also process all values/dims that
/// are already on the worklist. Return the position of the newly added
/// column.
int64_t populateConstraints(AffineMap map, ValueDimList mapOperands);
/// Populates the constraint set for a value/map without actually computing
/// the bound. Returns the position for the value/map (via the return value
/// and `posOut` output parameter).
int64_t populateConstraintsSet(Value value,
std::optional<int64_t> dim = std::nullopt);
int64_t populateConstraintsSet(AffineMap map, ValueDimList mapOperands,
int64_t *posOut = nullptr);

/// Iteratively process all elements on the worklist until an index-typed
/// value or shaped value meets `stopCondition`. Such values are not processed
Expand All @@ -315,19 +295,14 @@ class ValueBoundsConstraintSet
/// value/dimension exists in the constraint set.
int64_t getPos(Value value, std::optional<int64_t> dim = std::nullopt) const;

/// Return an affine expression that represents column `pos` in the constraint
/// set.
AffineExpr getPosExpr(int64_t pos);

/// Insert a value/dimension into the constraint set. If `isSymbol` is set to
/// "false", a dimension is added. The value/dimension is added to the
/// worklist if `addToWorklist` is set.
/// worklist.
///
/// Note: There are certain affine restrictions wrt. dimensions. E.g., they
/// cannot be multiplied. Furthermore, bounds can only be queried for
/// dimensions but not for symbols.
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true,
bool addToWorklist = true);
int64_t insert(Value value, std::optional<int64_t> dim, bool isSymbol = true);

/// Insert an anonymous column into the constraint set. The column is not
/// bound to any value/dimension. If `isSymbol` is set to "false", a dimension
Expand Down
61 changes: 0 additions & 61 deletions mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
Expand Up @@ -111,66 +111,6 @@ struct ForOpInterface
}
};

struct IfOpInterface
: public ValueBoundsOpInterface::ExternalModel<IfOpInterface, IfOp> {

static void populateBounds(scf::IfOp ifOp, Value value,
std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
unsigned int resultNum = cast<OpResult>(value).getResultNumber();
Value thenValue = ifOp.thenYield().getResults()[resultNum];
Value elseValue = ifOp.elseYield().getResults()[resultNum];

// Populate constraints for the yielded value (and all values on the
// backward slice, as long as the current stop condition is not satisfied).
cstr.populateConstraints(thenValue, dim);
cstr.populateConstraints(elseValue, dim);
auto boundsBuilder = cstr.bound(value);
if (dim)
boundsBuilder[*dim];

// Compare yielded values.
// If thenValue <= elseValue:
// * result <= elseValue
// * result >= thenValue
if (cstr.compare(thenValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
elseValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(thenValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(elseValue, dim);
} else {
cstr.bound(value) >= thenValue;
cstr.bound(value) <= elseValue;
}
}
// If elseValue <= thenValue:
// * result <= thenValue
// * result >= elseValue
if (cstr.compare(elseValue, dim,
ValueBoundsConstraintSet::ComparisonOperator::LE,
thenValue, dim)) {
if (dim) {
cstr.bound(value)[*dim] >= cstr.getExpr(elseValue, dim);
cstr.bound(value)[*dim] <= cstr.getExpr(thenValue, dim);
} else {
cstr.bound(value) >= elseValue;
cstr.bound(value) <= thenValue;
}
}
}

void populateBoundsForIndexValue(Operation *op, Value value,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<IfOp>(op), value, /*dim=*/std::nullopt, cstr);
}

void populateBoundsForShapedValueDim(Operation *op, Value value, int64_t dim,
ValueBoundsConstraintSet &cstr) const {
populateBounds(cast<IfOp>(op), value, dim, cstr);
}
};

} // namespace
} // namespace scf
} // namespace mlir
Expand All @@ -179,6 +119,5 @@ void mlir::scf::registerValueBoundsOpInterfaceExternalModels(
DialectRegistry &registry) {
registry.addExtension(+[](MLIRContext *ctx, scf::SCFDialect *dialect) {
scf::ForOp::attachInterface<scf::ForOpInterface>(*ctx);
scf::IfOp::attachInterface<scf::IfOpInterface>(*ctx);
});
}
14 changes: 5 additions & 9 deletions mlir/lib/Dialect/Vector/IR/ScalableValueBoundsConstraintSet.cpp
Expand Up @@ -59,24 +59,20 @@ ScalableValueBoundsConstraintSet::computeScalableBound(
ScalableValueBoundsConstraintSet scalableCstr(
value.getContext(), stopCondition ? stopCondition : defaultStopCondition,
vscaleMin, vscaleMax);
int64_t pos = scalableCstr.insert(value, dim, /*isSymbol=*/false);
scalableCstr.processWorklist();
int64_t pos = scalableCstr.populateConstraintsSet(value, dim);

// Project out all columns apart from vscale and the starting point
// (value/dim). This should result in constraints in terms of vscale only.
// Project out all variables apart from vscale.
// This should result in constraints in terms of vscale only.
auto projectOutFn = [&](ValueDim p) {
bool isStartingPoint =
p.first == value &&
p.second == dim.value_or(ValueBoundsConstraintSet::kIndexValue);
return p.first != scalableCstr.getVscaleValue() && !isStartingPoint;
return p.first != scalableCstr.getVscaleValue();
};
scalableCstr.projectOut(projectOutFn);

assert(scalableCstr.cstr.getNumDimAndSymbolVars() ==
scalableCstr.positionToValueDim.size() &&
"inconsistent mapping state");

// Check that the only columns left are vscale and the starting point.
// Check that the only symbols left are vscale.
for (int64_t i = 0; i < scalableCstr.cstr.getNumDimAndSymbolVars(); ++i) {
if (i == pos)
continue;
Expand Down
143 changes: 30 additions & 113 deletions mlir/lib/Interfaces/ValueBoundsOpInterface.cpp
Expand Up @@ -110,47 +110,25 @@ AffineExpr ValueBoundsConstraintSet::getExpr(Value value,
assertValidValueDim(value, dim);
#endif // NDEBUG

// Check if the value/dim is statically known. In that case, an affine
// constant expression should be returned. This allows us to support
// multiplications with constants. (Multiplications of two columns in the
// constraint set is not supported.)
std::optional<int64_t> constSize = std::nullopt;
auto shapedType = dyn_cast<ShapedType>(value.getType());
if (shapedType) {
// Static dimension: return constant directly.
if (shapedType.hasRank() && !shapedType.isDynamicDim(*dim))
constSize = shapedType.getDimSize(*dim);
} else if (auto constInt = ::getConstantIntValue(value)) {
constSize = *constInt;
return builder.getAffineConstantExpr(shapedType.getDimSize(*dim));
} else {
// Constant index value: return directly.
if (auto constInt = ::getConstantIntValue(value))
return builder.getAffineConstantExpr(*constInt);
}

// If the value/dim is already mapped, return the corresponding expression
// directly.
// Dynamic value: add to constraint set.
ValueDim valueDim = std::make_pair(value, dim.value_or(kIndexValue));
if (valueDimToPosition.contains(valueDim)) {
// If it is a constant, return an affine constant expression. Otherwise,
// return an affine expression that represents the respective column in the
// constraint set.
if (constSize)
return builder.getAffineConstantExpr(*constSize);
return getPosExpr(getPos(value, dim));
}

if (constSize) {
// Constant index value/dim: add column to the constraint set, add EQ bound
// and return an affine constant expression without pushing the newly added
// column to the worklist.
(void)insert(value, dim, /*isSymbol=*/true, /*addToWorklist=*/false);
if (shapedType)
bound(value)[*dim] == *constSize;
else
bound(value) == *constSize;
return builder.getAffineConstantExpr(*constSize);
}

// Dynamic value/dim: insert column to the constraint set and put it on the
// worklist. Return an affine expression that represents the newly inserted
// column in the constraint set.
return getPosExpr(insert(value, dim, /*isSymbol=*/true));
if (!valueDimToPosition.contains(valueDim))
(void)insert(value, dim);
int64_t pos = getPos(value, dim);
return pos < cstr.getNumDimVars()
? builder.getAffineDimExpr(pos)
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
}

AffineExpr ValueBoundsConstraintSet::getExpr(OpFoldResult ofr) {
Expand All @@ -167,7 +145,7 @@ AffineExpr ValueBoundsConstraintSet::getExpr(int64_t constant) {

int64_t ValueBoundsConstraintSet::insert(Value value,
std::optional<int64_t> dim,
bool isSymbol, bool addToWorklist) {
bool isSymbol) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG
Expand All @@ -182,12 +160,7 @@ int64_t ValueBoundsConstraintSet::insert(Value value,
if (positionToValueDim[i].has_value())
valueDimToPosition[*positionToValueDim[i]] = i;

if (addToWorklist) {
LLVM_DEBUG(llvm::dbgs() << "Push to worklist: " << value
<< " (dim: " << dim.value_or(kIndexValue) << ")\n");
worklist.push(pos);
}

worklist.push(pos);
return pos;
}

Expand Down Expand Up @@ -217,13 +190,6 @@ int64_t ValueBoundsConstraintSet::getPos(Value value,
return it->second;
}

AffineExpr ValueBoundsConstraintSet::getPosExpr(int64_t pos) {
assert(pos >= 0 && pos < cstr.getNumDimAndSymbolVars() && "invalid position");
return pos < cstr.getNumDimVars()
? builder.getAffineDimExpr(pos)
: builder.getAffineSymbolExpr(pos - cstr.getNumDimVars());
}

static Operation *getOwnerOfValue(Value value) {
if (auto bbArg = dyn_cast<BlockArgument>(value))
return bbArg.getOwner()->getParentOp();
Expand Down Expand Up @@ -526,16 +492,15 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(

// Default stop condition if none was specified: Keep adding constraints until
// a bound could be computed.
int64_t pos = 0;
int64_t pos;
auto defaultStopCondition = [&](Value v, std::optional<int64_t> dim,
ValueBoundsConstraintSet &cstr) {
return cstr.cstr.getConstantBound64(type, pos).has_value();
};

ValueBoundsConstraintSet cstr(
map.getContext(), stopCondition ? stopCondition : defaultStopCondition);
pos = cstr.populateConstraints(map, operands);
assert(pos == 0 && "expected `map` is the first column");
cstr.populateConstraintsSet(map, operands, &pos);

// Compute constant bound for `valueDim`.
int64_t ubAdjustment = closedUB ? 0 : 1;
Expand All @@ -544,28 +509,29 @@ FailureOr<int64_t> ValueBoundsConstraintSet::computeConstantBound(
return failure();
}

void ValueBoundsConstraintSet::populateConstraints(Value value,
std::optional<int64_t> dim) {
int64_t
ValueBoundsConstraintSet::populateConstraintsSet(Value value,
std::optional<int64_t> dim) {
#ifndef NDEBUG
assertValidValueDim(value, dim);
#endif // NDEBUG

// `getExpr` pushes the value/dim onto the worklist (unless it was already
// analyzed).
(void)getExpr(value, dim);
// Process all values/dims on the worklist. This may traverse and analyze
// additional IR, depending the current stop function.
processWorklist();
AffineMap map =
AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0,
Builder(value.getContext()).getAffineDimExpr(0));
return populateConstraintsSet(map, {{value, dim}});
}

int64_t ValueBoundsConstraintSet::populateConstraints(AffineMap map,
ValueDimList operands) {
int64_t ValueBoundsConstraintSet::populateConstraintsSet(AffineMap map,
ValueDimList operands,
int64_t *posOut) {
assert(map.getNumResults() == 1 && "expected affine map with one result");
int64_t pos = insert(/*isSymbol=*/false);
if (posOut)
*posOut = pos;

// Add map and operands to the constraint set. Dimensions are converted to
// symbols. All operands are added to the worklist (unless they were already
// processed).
// symbols. All operands are added to the worklist.
auto mapper = [&](std::pair<Value, std::optional<int64_t>> v) {
return getExpr(v.first, v.second);
};
Expand Down Expand Up @@ -600,55 +566,6 @@ ValueBoundsConstraintSet::computeConstantDelta(Value value1, Value value2,
{{value1, dim1}, {value2, dim2}});
}

bool ValueBoundsConstraintSet::compare(Value lhs, std::optional<int64_t> lhsDim,
ComparisonOperator cmp, Value rhs,
std::optional<int64_t> rhsDim) {
// This function returns "true" if "lhs CMP rhs" is proven to hold.
//
// Example for ComparisonOperator::LE and index-typed values: We would like to
// prove that lhs <= rhs. Proof by contradiction: add the inverse
// relation (lhs > rhs) to the constraint set and check if the resulting
// constraint set is "empty" (i.e. has no solution). In that case,
// lhs > rhs must be incorrect and we can deduce that lhs <= rhs holds.

// We cannot prove anything if the constraint set is already empty.
if (cstr.isEmpty()) {
LLVM_DEBUG(
llvm::dbgs()
<< "cannot compare value/dims: constraint system is already empty");
return false;
}

// EQ can be expressed as LE and GE.
if (cmp == EQ)
return compare(lhs, lhsDim, ComparisonOperator::LE, rhs, rhsDim) &&
compare(lhs, lhsDim, ComparisonOperator::GE, rhs, rhsDim);

// Construct inequality. For the above example: lhs > rhs.
// `IntegerRelation` inequalities are expressed in the "flattened" form and
// with ">= 0". I.e., lhs - rhs - 1 >= 0.
SmallVector<int64_t> eq(cstr.getNumDimAndSymbolVars() + 1, 0);
if (cmp == LT || cmp == LE) {
++eq[getPos(lhs, lhsDim)];
--eq[getPos(rhs, rhsDim)];
} else if (cmp == GT || cmp == GE) {
--eq[getPos(lhs, lhsDim)];
++eq[getPos(rhs, rhsDim)];
} else {
llvm_unreachable("unsupported comparison operator");
}
if (cmp == LE || cmp == GE)
eq[cstr.getNumDimAndSymbolVars()] -= 1;

// Add inequality to the constraint set and check if it made the constraint
// set empty.
int64_t ineqPos = cstr.getNumInequalities();
cstr.addInequality(eq);
bool isEmpty = cstr.isEmpty();
cstr.removeInequality(ineqPos);
return isEmpty;
}

FailureOr<bool>
ValueBoundsConstraintSet::areEqual(Value value1, Value value2,
std::optional<int64_t> dim1,
Expand Down

0 comments on commit 8487e05

Please sign in to comment.