Skip to content

Commit

Permalink
[shape_poly] Improve the lexicographic ordering of symbolic expressions
Browse files Browse the repository at this point in the history
In preparation for upcoming changes in the reasoning about
inequalities, we change the lexicographic ordering to
ensure that a symbolic expressions is strictly larger than
any constituent subexpressions. We add a `_size` attribute
that computes (and caches) the syntactic size of the expression.
  • Loading branch information
gnecula committed Jan 9, 2024
1 parent 7ad7890 commit b7f82e8
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 15 deletions.
29 changes: 18 additions & 11 deletions jax/experimental/export/shape_poly.py
Expand Up @@ -133,7 +133,13 @@ def __init__(self, *operands: _DimExpr,
self.var = var
self.operation = operation
self.operands = operands
# Precompute the hash (used extensively because these are kept in
# dictionaries) and the size (used for sorting, which is important for
# some of the reasoning). It is important for the size of _DimAtom,
# _DimMon, or _DimExpr, to be strictly larger than the size of any
# constituent sub-item.
self._hash = hash((self.var, self.operation, *self.operands))
self._size = 1 if var is not None else 1 + sum(o._size for o in operands)

@classmethod
def from_var(cls, v: str) -> _DimAtom:
Expand Down Expand Up @@ -172,12 +178,9 @@ def _syntactic_cmp(self, other: _DimAtom) -> int:
The comparison is done lexicographically (syntactic), to be used for sorting.
The result is not related to the semantic value.
"""
if c := cmp_comparable(self._size, other._size): return c
if self.var is not None:
if other.var is not None:
return cmp_comparable(self.var, other.var)
else:
return -1
if other.var is not None: return 1
return cmp_comparable(self.var, other.var)
if c := cmp_comparable(self.operation, other.operation): return c # type: ignore
return cmp_sequence(self.operands, other.operands,
lambda s_o, o_o: s_o._syntactic_cmp(o_o))
Expand Down Expand Up @@ -288,12 +291,13 @@ class _DimMon(dict):
The exponents are integers >= 1.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self._hash = hash(frozenset(self.items()))
self._size = sum((1 + a._size) for a, a_exp in self.items())

def __hash__(self):
h = getattr(self, "_hash", None)
if h is not None: return h
h = hash(frozenset(self.items()))
self._hash = h
return h
return self._hash

def __str__(self):
return "*".join(f"{key}^{exponent}" if exponent != 1 else str(key)
Expand Down Expand Up @@ -344,6 +348,7 @@ def _syntactic_cmp(self, other: _DimMon) -> int:
The comparison is done lexicographically (syntactic), to be used for sorting.
The result is not related to the semantic value.
"""
if c := cmp_comparable(self._size, other._size): return c
if c := cmp_comparable(self.degree, other.degree): return c
def cmp_atom(s_a: tuple[_DimAtom, int], o_a: tuple[_DimAtom, int]) -> int:
if c := s_a[0]._syntactic_cmp(o_a[0]): return c
Expand Down Expand Up @@ -431,6 +436,8 @@ def __init__(self, coeffs: dict[_DimMon, int]):
self._coeffs = coeffs or {_DimMon(): 0}
self._monomials_sorted = tuple(sorted(self._coeffs.items(), reverse=True))
self._hash = hash(self._monomials_sorted)
self._size = sum((1 + m._size)
for m, m_count in self._monomials_sorted)

def monomials(self) -> Iterable[tuple[_DimMon, int]]:
"""The monomials in sorted reverse lexicographic order.
Expand Down Expand Up @@ -655,7 +662,7 @@ def _one_monomial(mon, c):
# We print first the "larger" monomials, so that the constant is last.
res = " + ".join(_one_monomial(mon, c)
for mon, c in self._monomials_sorted)
res = res.replace(" + -", " - ")
res = res.replace(" + -", " - ").replace(" - 1*", " - ")
return res

def __repr__(self):
Expand Down
12 changes: 8 additions & 4 deletions tests/shape_poly_test.py
Expand Up @@ -162,6 +162,8 @@ def test_parse_dim(self, dim_spec, dim_poly):
dict(dim_spec=dim_spec)
for dim_spec in [
"b + a",
"b - a",
"b + 3*a",
"a*b + a^2 + b + a",
"mod(a, 4) + floordiv(a, 4) + a",
"2*a^2 - 3*a - 1",
Expand Down Expand Up @@ -285,6 +287,8 @@ def test_monomial_ordering(self):
self.assertTrue(b.to_monomial() >= a.to_monomial())
self.assertTrue(b.to_monomial() > a.to_monomial())

self.assertTrue(((3 * b) // a).to_monomial() >= ((2 * b) // a).to_monomial())
self.assertTrue(((3 * b) // a).to_monomial() >= ((4 * a) // b).to_monomial())
self.assertTrue(a.to_monomial() < (a * a).to_monomial())
self.assertTrue(b.to_monomial() < (a * a).to_monomial())
self.assertTrue((a * a * b).to_monomial() < (a * b * b).to_monomial())
Expand All @@ -295,12 +299,12 @@ def test_monomial_ordering(self):
self.assertSequenceEqual(sorted_e1,
[a * b * b, a * a * b, a * b, a * a, b, a, 2])

e2 = a * (a // 4) + (a // 4) + b * (a // 4) + b * (a % 4) + a * a + b
e2 = a * (a // 4) + (a // 4) + b * (a // 4) + b * (a % 4) + a * a + b + 15
sorted_e2 = [shape_poly._DimExpr.from_monomial(m, m_count)
for m, m_count in e2.monomials()]
self.assertSequenceEqual(sorted_e2,
[b * (a % 4), b * (a // 4), a * (a // 4), a * a,
a // 4, b])
[b * (a % 4), b * (a // 4), a * (a // 4), a // 4,
a * a, b, 15])

# This failed with a previous implementation of atom equality
self.assertNotEqual(shape_poly._DimMon.from_operation(shape_poly._DimAtom.NON_NEGATIVE,
Expand Down Expand Up @@ -581,7 +585,7 @@ def test_poly_int_results(self):
(a * a - b * b, a + b, a - b, 0),
(a, b, "floordiv(a, b)", "mod(a, b)"),
(3 * a, 2, "floordiv(3*a, 2)", "mod(3*a, 2)"),
(2 * a * b + b * b, a + b, "floordiv(b^2 + 2*a*b, b + a)", "mod(b^2 + 2*a*b, b + a)"),
(2 * a * b + b * b, a + b, "floordiv(2*a*b + b^2, b + a)", "mod(2*a*b + b^2, b + a)"),
(3, a, "floordiv(3, a)", "mod(3, a)"),
]])
def test_poly_divmod(self, *, dividend, quotient, divisor, remainder):
Expand Down

0 comments on commit b7f82e8

Please sign in to comment.