Skip to content

Commit

Permalink
[FIRRTL][InferWidths] Fix back-prop, fix ref equality, fix upper boun…
Browse files Browse the repository at this point in the history
…d. (#5403)

Fix upperBoundSolution being set from bool indicating if cycle,
instead of the value solved for (.first instead of .second).

Tweak how/when upper bound constraint is solved.

Fixes #5002.
Fixes #5391.
  • Loading branch information
dtzSiFive committed Jun 14, 2023
1 parent c5aff9a commit 6c01320
Show file tree
Hide file tree
Showing 2 changed files with 88 additions and 26 deletions.
65 changes: 39 additions & 26 deletions lib/Dialect/FIRRTL/Transforms/InferWidths.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -887,9 +887,14 @@ static ExprSolution solveExpr(Expr *expr, SmallPtrSetImpl<Expr *> &seenVars,
.Case<VarExpr>([&](auto *expr) {
if (solvedExprs.contains(expr->constraint)) {
auto solution = solvedExprs[expr->constraint];
if (expr->upperBound)
expr->upperBoundSolution = solvedExprs[expr->upperBound].second;

// If we've solved the upper bound already, store the solution.
// This will be explicitly solved for later if not computed as
// part of the solving that resolved this constraint.
// This should only happen if somehow the constraint is
// solved before visiting this expression, so that our upperBound
// was not added to the worklist such that it was handled first.
if (expr->upperBound && solvedExprs.contains(expr->upperBound))
expr->upperBoundSolution = solvedExprs[expr->upperBound].first;
seenVars.erase(expr);
// Constrain variables >= 0.
if (solution.first && *solution.first < 0)
Expand Down Expand Up @@ -1051,7 +1056,8 @@ LogicalResult ConstraintSolver::solve() {
<< "- Solving " << *var << " >= " << *var->constraint << "\n");
seenVars.insert(var);
auto solution = solveExpr(var->constraint, seenVars, defaultWorklistSize);
if (var->upperBound)
// Compute the upperBound if there is one and haven't already.
if (var->upperBound && !var->upperBoundSolution)
var->upperBoundSolution =
solveExpr(var->upperBound, seenVars, defaultWorklistSize).first;
seenVars.clear();
Expand Down Expand Up @@ -1188,12 +1194,12 @@ class InferenceMapping {

/// Constrain the value "larger" to be greater than or equal to "smaller".
/// These may be aggregate values. This is used for regular connects.
void constrainTypes(Value larger, Value smaller);
void constrainTypes(Value larger, Value smaller, bool equal = false);

/// Constrain the expression "larger" to be greater than or equals to
/// the expression "smaller".
void constrainTypes(Expr *larger, Expr *smaller,
bool imposeUpperBounds = false);
bool imposeUpperBounds = false, bool equal = false);

/// Assign the constraint expressions of the fields in the `src` argument as
/// the expressions for the `dst` argument. Both fields must be of the given
Expand Down Expand Up @@ -1544,15 +1550,10 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
})
.Case<ConnectOp>(
[&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
.Case<RefDefineOp>(
[&](auto op) { constrainTypes(op.getDest(), op.getSrc()); })
// StrictConnect is an identify constraint
.Case<StrictConnectOp>([&](auto op) {
constrainTypes(op.getDest(), op.getSrc());
constrainTypes(op.getSrc(), op.getDest());
// unifyTypes(FieldRef(op.getDest(), 0),
// FieldRef(op.getSrc(), 0),
// op.getDest().getType().template cast<FIRRTLType>());
.Case<RefDefineOp, StrictConnectOp>([&](auto op) {
// Dest >= Src, but also check Src <= Dest for correctness
// (but don't solve to make this true, don't back-propagate)
constrainTypes(op.getDest(), op.getSrc(), true);
})
.Case<AttachOp>([&](auto op) {
// Attach connects multiple analog signals together. All signals must
Expand Down Expand Up @@ -1672,15 +1673,15 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {

.Case<RefSendOp>([&](auto op) {
declareVars(op.getResult(), op.getLoc());
constrainTypes(op.getResult(), op.getBase());
constrainTypes(op.getResult(), op.getBase(), true);
})
.Case<RefResolveOp>([&](auto op) {
declareVars(op.getResult(), op.getLoc());
constrainTypes(op.getResult(), op.getRef());
constrainTypes(op.getResult(), op.getRef(), true);
})
.Case<RefCastOp>([&](auto op) {
declareVars(op.getResult(), op.getLoc());
constrainTypes(op.getResult(), op.getInput());
constrainTypes(op.getResult(), op.getInput(), true);
})
.Case<mlir::UnrealizedConversionCastOp>([&](auto op) {
for (Value result : op.getResults()) {
Expand All @@ -1695,10 +1696,9 @@ LogicalResult InferenceMapping::mapOperation(Operation *op) {
});

// Forceable declarations should have the ref constrained to data result.
if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable()) {
declareVars(fop.getDataRef(), fop.getLoc());
constrainTypes(fop.getDataRef(), fop.getDataRaw());
}
if (auto fop = dyn_cast<Forceable>(op); fop && fop.isForceable())
unifyTypes(FieldRef(fop.getDataRef(), 0), FieldRef(fop.getDataRaw(), 0),
fop.getDataType());

return failure(mappingFailed);
}
Expand Down Expand Up @@ -1789,7 +1789,9 @@ void InferenceMapping::maximumOfTypes(Value result, Value rhs, Value lhs) {
/// of bit widths.
///
/// This function is used to apply regular connects.
void InferenceMapping::constrainTypes(Value larger, Value smaller) {
/// Set `equal` for constraining larger <= smaller for correctness but not
/// solving.
void InferenceMapping::constrainTypes(Value larger, Value smaller, bool equal) {
// Recurse to every leaf element and set larger >= smaller. Ignore foreign
// types as these do not participate in width inference.

Expand Down Expand Up @@ -1819,7 +1821,7 @@ void InferenceMapping::constrainTypes(Value larger, Value smaller) {
} else if (type.isGround()) {
// Leaf element, look up their expressions, and create the constraint.
constrainTypes(getExpr(FieldRef(larger, fieldID)),
getExpr(FieldRef(smaller, fieldID)));
getExpr(FieldRef(smaller, fieldID)), false, equal);
fieldID++;
} else {
llvm_unreachable("Unknown type inside a bundle!");
Expand All @@ -1833,7 +1835,7 @@ void InferenceMapping::constrainTypes(Value larger, Value smaller) {
/// Establishes constraints to ensure the sizes in the `larger` type are greater
/// than or equal to the sizes in the `smaller` type.
void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
bool imposeUpperBounds) {
bool imposeUpperBounds, bool equal) {
assert(larger && "Larger expression should be specified");
assert(smaller && "Smaller expression should be specified");

Expand All @@ -1860,6 +1862,17 @@ void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
LLVM_ATTRIBUTE_UNUSED auto *c = solver.addGeqConstraint(largerVar, smaller);
LLVM_DEBUG(llvm::dbgs()
<< "Constrained " << *largerVar << " >= " << *c << "\n");
// If we're constraining larger == smaller, add the LEQ contraint as well.
// Solve for GEQ but check that LEQ is true.
// Used for strictconnect, some reference operations, and anywhere the
// widths should be inferred strictly in one direction but are required to
// also be equal for correctness.
if (equal) {
LLVM_ATTRIBUTE_UNUSED auto *leq =
solver.addLeqConstraint(largerVar, smaller);
LLVM_DEBUG(llvm::dbgs()
<< "Constrained " << *largerVar << " <= " << *leq << "\n");
}
return;
}

Expand All @@ -1869,7 +1882,7 @@ void InferenceMapping::constrainTypes(Expr *larger, Expr *smaller,
// `>=` constraints, any `<=` constraints have no effect on the solution
// besides indicating that a width is unsatisfiable.
if (auto *smallerVar = dyn_cast<VarExpr>(smaller)) {
if (imposeUpperBounds) {
if (imposeUpperBounds || equal) {
LLVM_ATTRIBUTE_UNUSED auto *c =
solver.addLeqConstraint(smallerVar, larger);
LLVM_DEBUG(llvm::dbgs()
Expand Down
49 changes: 49 additions & 0 deletions test/Dialect/FIRRTL/infer-widths-errors.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,52 @@ firrtl.circuit "Foo" {
firrtl.connect %out, %0 : !firrtl.uint, !firrtl.uint
}
}

// -----
// https://github.com/llvm/circt/issues/5391
// Unclear if widthCast is/should be allowed to be CSE'd.
// This IR was generated from test on that issue before InferWidths.

// Don't back-propagate widths.
firrtl.circuit "Issue5391" {
firrtl.module @Issue5391(in %x: !firrtl.uint<1>,
in %y: !firrtl.uint<2>,
out %out1: !firrtl.uint,
out %out2: !firrtl.uint) {
// expected-error @below {{uninferred width: wire "w" cannot satisfy all width requirements}}
%w = firrtl.wire : !firrtl.uint
%0 = firrtl.widthCast %x : (!firrtl.uint<1>) -> !firrtl.uint
firrtl.strictconnect %w, %0 : !firrtl.uint
%1 = firrtl.widthCast %y : (!firrtl.uint<2>) -> !firrtl.uint
// expected-note @below {{width is constrained to be at most 1 here:}}
// expected-note @below {{width is constrained to be at least 2 here:}}
firrtl.strictconnect %w, %1 : !firrtl.uint
%2 = firrtl.widthCast %w : (!firrtl.uint) -> !firrtl.uint
firrtl.strictconnect %out1, %2 : !firrtl.uint
%wx = firrtl.wire : !firrtl.uint
firrtl.strictconnect %wx, %0 : !firrtl.uint
%3 = firrtl.widthCast %wx : (!firrtl.uint) -> !firrtl.uint
firrtl.strictconnect %out2, %3 : !firrtl.uint
}
}

// -----
// https://github.com/llvm/circt/issues/5002

firrtl.circuit "Issue5002" {
// expected-error @below {{uninferred width: port "ref" cannot satisfy all width requirements}}
firrtl.module private @InRef(in %ref : !firrtl.rwprobe<uint>) { }
firrtl.module @Issue5002(in %x : !firrtl.uint<1>, in %y : !firrtl.uint<2>) {
%w1, %w1_ref = firrtl.wire forceable : !firrtl.uint, !firrtl.rwprobe<uint>
%w2, %w2_ref = firrtl.wire forceable : !firrtl.uint, !firrtl.rwprobe<uint>
firrtl.connect %w1, %x : !firrtl.uint, !firrtl.uint<1>
firrtl.connect %w2, %y : !firrtl.uint, !firrtl.uint<2>

%inst1_ref = firrtl.instance inst1 @InRef(in ref: !firrtl.rwprobe<uint>)
%inst2_ref = firrtl.instance inst2 @InRef(in ref: !firrtl.rwprobe<uint>)
firrtl.ref.define %inst1_ref, %w1_ref : !firrtl.rwprobe<uint>
// expected-note @below {{width is constrained to be at most 1 here:}}
// expected-note @below {{width is constrained to be at least 2 here:}}
firrtl.ref.define %inst2_ref, %w2_ref : !firrtl.rwprobe<uint>
}
}

0 comments on commit 6c01320

Please sign in to comment.