Skip to content

Commit

Permalink
improve-constant-propagation
Browse files Browse the repository at this point in the history
  • Loading branch information
JonathanStarup committed Jun 20, 2024
1 parent d8ed722 commit 1e80b7b
Showing 1 changed file with 33 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import scala.collection.mutable
/// - We move a single variable to the lhs.
/// - See the definition of `Equation.mk` for details.
/// - Progress in the following order:
/// 1. Propagate ground terms (univ/empty/element/constant) in a fixpoint. // TODO also fx x ∪ y ∪ z ~ empty
/// 1. Propagate ground terms (univ/empty/element/constant) in a fixpoint.
/// 2. Propagate variables (i.e. resolving constraints of the form x = y).
/// 3. Perform trivial assignments where the left-hand variables does not occur in the RHS.
/// 4. Eliminate (now) trivial and redundant constraints.
Expand Down Expand Up @@ -352,9 +352,11 @@ object FastSetUnification {
case (Term.Univ, Term.Empty) => throw ConflictException(t1, t2, loc)
case (Term.Univ, Term.ElemSet(_)) => throw ConflictException(t1, t2, loc)
case (Term.Univ, Term.Cst(_)) => throw ConflictException(t1, t2, loc)
case (Term.Univ, inter: Term.Inter) if inter.trivialNonUniv => throw ConflictException(t1, t2, loc)
case (Term.Empty, Term.Univ) => throw ConflictException(t1, t2, loc)
case (Term.Empty, Term.ElemSet(_)) => throw ConflictException(t1, t2, loc)
case (Term.Empty, Term.Cst(_)) => throw ConflictException(t1, t2, loc)
case (Term.Empty, union: Term.Union) if union.trivialNonEmpty => throw ConflictException(t1, t2, loc)
case (Term.ElemSet(_), Term.Univ) => throw ConflictException(t1, t2, loc)
case (Term.ElemSet(_), Term.Empty) => throw ConflictException(t1, t2, loc)
case (Term.ElemSet(i1), Term.ElemSet(i2)) if i1 != i2 => throw ConflictException(t1, t2, loc)
Expand All @@ -363,6 +365,9 @@ object FastSetUnification {
case (Term.Cst(_), Term.Empty) => throw ConflictException(t1, t2, loc)
case (Term.Cst(_), Term.ElemSet(_)) => throw ConflictException(t1, t2, loc)
case (Term.Cst(c1), Term.Cst(c2)) if c1 != c2 => throw ConflictException(t1, t2, loc)
case (inter: Term.Inter, Term.Univ) if inter.trivialNonUniv => throw ConflictException(t1, t2, loc)
case (union: Term.Union, Term.Empty) if union.trivialNonEmpty => throw ConflictException(t1, t2, loc)
// TODO can check trivially impossible `element ~ intersection` and `constant ~ intersection`

// Non-trivial and non-conflicted equation: keep it.
case _ => Equation(t1, t2, loc) :: checkAndSimplify(es)
Expand All @@ -374,13 +379,13 @@ object FastSetUnification {
*
* The implementation saturates the system, i.e. it computes a fixpoint.
*
* The implementation uses three rewrite rules:
* The implementation uses five rewrite rules:
*
* - `x ~ univ` becomes `[x -> univ]`.
* - `x ~ c` becomes `[x -> c]`.
* - `x ~ e` becomes `[x -> e]`.
* - `x ∩ y ∩ !z ∩ ... = univ` becomes `[x -> univ, y -> univ, z -> empty, ...]`.
* - `x ∪ y ∪ !z ∪ ... = empty` becomes `[x -> empty, y -> empty, z -> univ, ...]`.
* - `x ∩ y ∩ !z ∩ ... ∩ rest ~ univ` becomes `[x -> univ, y -> univ, z -> empty, ...]` and `rest ~ univ`.
* - `x ∪ y ∪ !z ∪ ... ∪ rest ~ empty` becomes `[x -> empty, y -> empty, z -> univ, ...]` and `rest ~ empty`.
*
* For example, if the equation system is:
*
Expand Down Expand Up @@ -441,7 +446,7 @@ object FastSetUnification {
while (changed) {
changed = false

var rest: List[Equation] = Nil
var remaining: List[Equation] = Nil
// OBS: subst.extended checks for conflicting mappings
for (e <- pending) {
e match {
Expand All @@ -460,9 +465,9 @@ object FastSetUnification {
subst = subst.extended(x, e, loc)
changed = true

// Case 4: x ∩ y ∩ !z ∩ ... ~ univ
// Case 4: x ∩ y ∩ !z ∩ ... ∩ rest ~ univ
case Equation(Term.Inter(None, posCsts, posVars, negElem, negCsts, negVars, rest), Term.Univ, loc) if
posCsts.isEmpty && negElem.isEmpty && negCsts.isEmpty && rest.isEmpty =>
posCsts.isEmpty && negElem.isEmpty && negCsts.isEmpty =>
{
for (Term.Var(x) <- posVars) {
subst = subst.extended(x, Term.Univ, loc)
Expand All @@ -472,11 +477,15 @@ object FastSetUnification {
subst = subst.extended(x, Term.Empty, loc)
changed = true
}
if (rest.nonEmpty) {
remaining = rest.map(Equation.mk(_, Term.Univ, loc)) ++ remaining
changed = true
}
}

// Case 5: x ∪ y ∪ !z ∪ ... ~ empty
// Case 5: x ∪ y ∪ !z ∪ ... ∪ rest ~ empty
case Equation(Term.Union(posElem, posCsts, posVars, negElem, negCsts, negVars, rest), Term.Empty, loc) if
posElem.isEmpty && posCsts.isEmpty && negElem.isEmpty && negCsts.isEmpty && rest.isEmpty =>
posElem.isEmpty && posCsts.isEmpty && negElem.isEmpty && negCsts.isEmpty =>
{
for (Term.Var(x) <- posVars) {
subst = subst.extended(x, Term.Empty, loc)
Expand All @@ -486,14 +495,18 @@ object FastSetUnification {
subst = subst.extended(x, Term.Univ, loc)
changed = true
}
if (rest.nonEmpty) {
remaining = rest.map(Equation.mk(_, Term.Empty, loc)) ++ remaining
changed = true
}
}

case _ =>
rest = e :: rest
remaining = e :: remaining
}
}
// INVARIANT: We apply the current substitution to all remaining equations.
pending = subst(rest)
pending = subst(remaining)
}

// Reverse the unsolved equations to ensure they are returned in the original order.
Expand Down Expand Up @@ -1205,9 +1218,10 @@ object FastSetUnification {
def mk(t1: Term, t2: Term, loc: SourceLocation): Equation = (t1, t2) match {
case (Term.Cst(c1), Term.Cst(c2)) => if (c1 <= c2) Equation(t1, t2, loc) else Equation(t2, t1, loc)
case (Term.Var(x1), Term.Var(x2)) => if (x1 <= x2) Equation(t1, t2, loc) else Equation(t2, t1, loc)
case (Term.Univ, _) => Equation(t2, Term.Univ, loc)
case (Term.Empty, _) => Equation(t2, Term.Empty, loc)
case (Term.Univ, _) => Equation(t2, t1, loc)
case (Term.Empty, _) => Equation(t2, t1, loc)
case (Term.ElemSet(_), _) => Equation(t2, t1, loc)
case (Term.Cst(_), _) => Equation(t2, t1, loc)
case (_, Term.Var(_)) => Equation(t2, t1, loc)
case _ => Equation(t1, t2, loc)
}
Expand Down Expand Up @@ -1406,7 +1420,9 @@ object FastSetUnification {
*
* `None, Set(c1), Set(x4, x7), Set(), Set(), Set(x2), List(e1 ∪ x9)`.
*/
case class Inter(posElem: Option[Term.ElemSet], posCsts: Set[Term.Cst], posVars: Set[Term.Var], negElem: Option[Term.ElemSet], negCsts: Set[Term.Cst], negVars: Set[Term.Var], rest: List[Term]) extends Term
case class Inter(posElem: Option[Term.ElemSet], posCsts: Set[Term.Cst], posVars: Set[Term.Var], negElem: Option[Term.ElemSet], negCsts: Set[Term.Cst], negVars: Set[Term.Var], rest: List[Term]) extends Term {
def trivialNonUniv: Boolean = posElem.isDefined || posCsts.nonEmpty || negElem.isDefined || negCsts.nonEmpty
}

/**
* A union of the terms `ts` (`∪`). An empty union is empty.
Expand All @@ -1415,7 +1431,9 @@ object FastSetUnification {
*
* Represented similarly to [[Inter]].
*/
case class Union(posElem: Option[Term.ElemSet], posCsts: Set[Term.Cst], posVars: Set[Term.Var], negElem: Option[Term.ElemSet], negCsts: Set[Term.Cst], negVars: Set[Term.Var], rest: List[Term]) extends Term
case class Union(posElem: Option[Term.ElemSet], posCsts: Set[Term.Cst], posVars: Set[Term.Var], negElem: Option[Term.ElemSet], negCsts: Set[Term.Cst], negVars: Set[Term.Var], rest: List[Term]) extends Term {
def trivialNonEmpty: Boolean = posElem.isDefined || posCsts.nonEmpty || negElem.isDefined || negCsts.nonEmpty
}

final def mkElemSet(i: Int): Term = {
Term.ElemSet(SortedSet(i))
Expand Down

0 comments on commit 1e80b7b

Please sign in to comment.