Skip to content

Commit

Permalink
[shape_polyO] Performance improvements for symbolic dimension manipul…
Browse files Browse the repository at this point in the history
…ations (step 2)

We make the following improvements:

  * Cache the state of the decision procedure after we process the explicit
    constraints, and reuse it for new decisions.
  * Rationalize the usage of add_implicit_constraints. We used to call it
    conservatively, too often. Now we call it only once for each explicit constraint,
    and once for each bounds decision we make. Then, in the add_implicit_constraints
    we call it recursively when we encounter new sub-expressions.
  * Eliminate some usage of __str__ for symbolic expressions in combine_and_add_constraints
    since we should only need it for reporting error messages.

This speeds up inequality reasoning:

Before:
```
In [1]:     from jax.experimental import export
   ...:     from jax import core
   ...:     a, b, c = export.symbolic_shape("a, b, c", constraints=["a >= 3", "a <= 5", "b >= 8"])

In [2]: %timeit a >= b
109 µs ± 637 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
442 µs ± 2.22 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```

After:
```
In [2]: %timeit a >= b
11.7 µs ± 27.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [3]: %timeit core.max_dim(a, c) >= a - c
34.8 µs ± 175 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```
  • Loading branch information
gnecula committed Feb 15, 2024
1 parent c55f187 commit eb9caf0
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 50 deletions.
10 changes: 8 additions & 2 deletions jax/experimental/export/_shape_poly.py
Expand Up @@ -666,7 +666,7 @@ def _one_monomial(mon, c):
# We print first the "larger" monomials, so that the constant is last.
res = " ".join(_one_monomial(mon, c)
for mon, c in self._monomials_sorted)
if res[0:2] == "+ ":
if res.startswith("+ "):
res = res[2:]
return res

Expand Down Expand Up @@ -953,7 +953,7 @@ def __init__(self,
raise ValueError(
"The symbolic constraints should be a sequence of strings. "
f"Got {repr(constraints_str)}")

self._initialized = False
self._location_frame = source_info_util.user_frame(source_info_util.current())
# Keep the explicit constraints in the order in which they were added
self._explicit_constraints: list[_SymbolicConstraint] = []
Expand All @@ -964,6 +964,11 @@ def __init__(self,
# bounds precision with which we computed the cached result.
self._bounds_cache: dict[_DimExpr,
tuple[float, float, BoundsPrecision]] = {}

# We store here a decision procedure state initialized with all the
# _explicit_constraints.
self._decision_initial_state: Any | None = None

# We turn the equality constraints into normalization rules.
# For an explicit constraint `t*tk == e`, we keep
# `_normalization_rules[t] = (e, tk)`.
Expand All @@ -974,6 +979,7 @@ def __init__(self,
for c_str in constraints_str:
self._parse_and_process_explicit_constraint(c_str)
self._bounds_cache.clear()
self._initialized = True

def __str__(self) -> str:
extras = []
Expand Down
125 changes: 83 additions & 42 deletions jax/experimental/export/_shape_poly_decision.py
Expand Up @@ -17,11 +17,9 @@

from __future__ import annotations

import collections
from collections.abc import Sequence
import itertools
import math
from typing import Callable

import numpy as np

Expand All @@ -42,8 +40,8 @@ def bounds_decision(e: DimSize,
prec: BoundsPrecision) -> tuple[float, float]:
if not isinstance(e, _DimExpr):
return (int(e), int(e))
decision = _DecisionByElimination(e.scope)
return decision.bounds(e, prec)
decision = _DecisionByElimination.build(e.scope)
return decision.bounds(e, prec, add_implicit_constraints=True)

_shape_poly._bounds_decision = bounds_decision

Expand All @@ -65,6 +63,7 @@ class _DecisionByElimination:
then `abs(m_c)*e <= e0`, hence, `UB(e) <= floor(UB(e0) / abs(m_c))`,
See the implementation in self.combine_term_with_existing.
Do not use the constructor directly, use the `build` static method.
"""
def __init__(self, scope: SymbolicScope):
self.scope = scope
Expand All @@ -76,28 +75,52 @@ def __init__(self, scope: SymbolicScope):
# just simple terms. The set is represented as a mapping from a
# term "t" to tuples (cmp, k, c) where "c >= 0" (if cmp is GEQ else "c == 0")
# represents a constraint that has "t" as the leading term with coefficient "k".
self._expr_constraints: dict[_DimMon, set[tuple[Comparator, int, _DimExpr]]] = collections.defaultdict(set)
self._expr_constraints: dict[_DimMon, set[tuple[Comparator, int, _DimExpr]]] = {}

def initialize(self) -> _DecisionByElimination:
# Process the explicit constraints in the order in which the user specifies
# them. This is because the heuristics depend on the order in which the
# constraints are processed, and this way we give the user a way to control
# the result (albeit, for now, without a good feedback loop to understand
# how the order matters for inequalities).
for constr in scope._explicit_constraints:
for constr in self.scope._explicit_constraints:
self.add_implicit_constraints_expr(constr.diff)
# The equality constraints are not needed for inequality decisions,
# because the LHS should always be rewritten in terms of the RHS.
# In fact, adding them may break the assumption that if we eliminate
# the leading term we end up with only smaller terms, because the LHS
# may appear in the rest and may be rewritten to something larger.
# However, we want to add the implicit constraints within.
if constr.cmp == Comparator.GEQ:
self.combine_and_add_constraint(constr.cmp, constr.diff, 0, constr.debug_str)
else:
# The equality constraints are not needed for inequality decisions,
# because the LHS should always be rewritten in terms of the RHS.
# In fact, adding them may break the assumption that if we eliminate
# the leading term we end up with only smaller terms, because the LHS
# may appear in the rest and may be rewritten to something larger.
# However, we want to add the implicit constraints within.
for m, _ in constr.diff.monomials():
if m.degree == 0: continue
self.add_implicit_constraints(m)
self.combine_and_add_constraint(constr.cmp, constr.diff, 0,
constr.debug_str)


# Clear the cache, since we have added constraints.
self.scope._bounds_cache.clear()
return self

@staticmethod
def build(scope: SymbolicScope) -> _DecisionByElimination:
"""Builds an initialized DecisionByElimination for a scope.
Caches the initial state of the decision procedure in the scope.
"""
if not scope._initialized or not scope._explicit_constraints:
# We do not cache until the scope is fully initialized.
return _DecisionByElimination(scope).initialize()

if not scope._decision_initial_state:
scope._decision_initial_state = _DecisionByElimination(scope).initialize()
d = scope._decision_initial_state
# Return a copy, because the decision procedure state is mutable
c = _DecisionByElimination(scope)
c._processed_for_internal_constraints = d._processed_for_internal_constraints.copy()
c._term_bounds = d._term_bounds.copy()
c._expr_constraints = {
lead_t: lead_t_constraints.copy()
for lead_t, lead_t_constraints in d._expr_constraints.items()}
return c

def combine_and_add_constraint(self,
cmp: Comparator,
Expand All @@ -114,11 +137,9 @@ def combine_and_add_constraint(self,
assert e2 == np.floor(e2)
e2 = int(e2)
e = e1 - e2
if debug_str is None:
debug_str = f"{e1} >= {e2}"
if (const := _DimExpr.to_constant(e)) is not None:
if const < 0:
raise ValueError(f"Unsatisfiable constraint: {debug_str}")
raise ValueError(f"Unsatisfiable constraint: {debug_str or str(e1) + ' >= ' + str(e2)}")
return
assert isinstance(e, _DimExpr)
# TODO: we only really need to add the implicit constraints now, else
Expand All @@ -127,17 +148,14 @@ def combine_and_add_constraint(self,
self.add_to_state(cmp, e, debug_str)
geq_combinations = self.combine_constraint_with_existing(cmp, e, debug_str)
for cmp, a in geq_combinations:
self.add_to_state(cmp, a, f"{a} >= 0")
self.add_to_state(cmp, a, None)

def add_to_state(self,
cmp: Comparator,
e: _DimExpr,
debug_str: str):
debug_str: str | None):
"""Updates the internal state to reflect "e >= 0". """
assert _DimExpr.to_constant(e) is None
for m, m_c in e.monomials():
if m.degree == 0: continue
self.add_implicit_constraints(m)

if (mon_factors := e.to_single_term()) is not None:
n, mon_c, mon = mon_factors # n + mon * mon_c [== | >=] 0
Expand All @@ -161,7 +179,11 @@ def add_to_state(self,
return

lead_t, lead_t_k = e.leading_term
self._expr_constraints[lead_t].add((cmp, lead_t_k, e))
lead_t_constraints = self._expr_constraints.get(lead_t)
if lead_t_constraints is None:
lead_t_constraints = set()
self._expr_constraints[lead_t] = lead_t_constraints
lead_t_constraints.add((cmp, lead_t_k, e))

def combine_term_with_existing(self, t: _DimMon, t_k: int, *,
scope: _shape_poly.SymbolicScope,
Expand Down Expand Up @@ -197,7 +219,8 @@ def combine_term_with_existing(self, t: _DimMon, t_k: int, *,
acc.append((Comparator.GEQ, _DimExpr(((t, -1),), scope) + int(t_ub),
abs(t_k), sgn(t_k)))

for prev_constraint in ([self._expr_constraints[t]] if only_smaller_than_t
prev_constraint: set[tuple[Comparator, int, _DimExpr]]
for prev_constraint in ([self._expr_constraints.get(t, set())] if only_smaller_than_t
else self._expr_constraints.values()):
for c_eq, _, c in prev_constraint:
# TODO: optimize this dict()
Expand All @@ -212,7 +235,7 @@ def combine_term_with_existing(self, t: _DimMon, t_k: int, *,
def combine_constraint_with_existing(self,
eq: Comparator,
e: _DimExpr,
debug_str: str) -> set[tuple[Comparator, _DimExpr]]:
debug_str: str | None) -> set[tuple[Comparator, _DimExpr]]:
combinations: set[tuple[Comparator, _DimExpr]] = set()
for t, t_k in e._monomials_sorted:
if t.degree == 0: continue
Expand All @@ -226,17 +249,22 @@ def combine_constraint_with_existing(self,
if (const := _DimExpr.to_constant(new_e)) is not None:
if ((new_eq == Comparator.GEQ and const < 0) or
(new_eq == Comparator.EQ and const != 0)):
raise ValueError(f"Unsatisfiable constraints: {debug_str}")
raise ValueError(f"Unsatisfiable constraints: {debug_str or str(e) + ' >= 0'}")
else:
combinations.add((new_eq, new_e)) # type: ignore
return combinations

def bounds(self, e: DimSize,
prec: BoundsPrecision
prec: BoundsPrecision,
add_implicit_constraints: bool = False
) -> tuple[float, float]:
"""Returns the lower and upper bounds, or -+inf.
See more details in `_shape_poly.bounds_decision`.
Args:
e: the expression for which to compute the bounds.
prec: the desired precision. See comments in `BoundsPrecision`.
add_implicit_constraints: if True, then before computing the bounds
add the implicit constraints for the terms inside `e`.
"""
if (const := _DimExpr.to_constant(e)) is not None:
return (const, const)
Expand All @@ -251,6 +279,9 @@ def bounds(self, e: DimSize,
lb, ub, prev_prec = res
if prec._bounds_are_sufficient(lb, ub): return (lb, ub)
if prev_prec.value >= prec.value: return (lb, ub)

if add_implicit_constraints:
self.add_implicit_constraints_expr(e)
lb, ub = self._bounds_for_sorted_terms(e.scope, e._monomials_sorted, 0, prec)
lb, ub = (int(lb) if lb > -np.inf else lb,
int(ub) if ub < np.inf else ub)
Expand All @@ -274,7 +305,6 @@ def _bounds_for_sorted_terms(self,
assert i == len(e) - 1 # Must be last
return (t_k, t_k)

self.add_implicit_constraints(t)
lb = -np.inf
ub = np.inf

Expand Down Expand Up @@ -327,8 +357,13 @@ def _bounds_for_sorted_terms(self,

return lb, ub

def add_implicit_constraints(self: _DecisionByElimination, m: _DimMon):
"""Adds the internal constraints for the monomial `m`."""
def add_implicit_constraints_expr(self, e: _DimExpr):
"""Adds the implicit constraints for the expression `e`"""
for m, _ in e.monomials():
if m.degree == 0: continue
self.add_implicit_constraints_term(m)

def add_implicit_constraints_term(self, m: _DimMon):
if m in self._processed_for_internal_constraints: return
self._processed_for_internal_constraints.add(m)
m_e = _DimExpr.from_monomial(m, 1, self.scope) # m as a _DimExpr
Expand All @@ -337,12 +372,13 @@ def add_implicit_constraints(self: _DecisionByElimination, m: _DimMon):
# This is a multiplication of atoms. Try to compute bounds based on
# the bounds of the atoms.
bounds = []
for a, exp in m.items():
a_l, a_u = self.bounds(_DimExpr.from_monomial(_DimMon.from_atom(a, 1),
1, self.scope),
BoundsPrecision.BEST)
assert a_l <= a_u
bounds.append((a_l ** exp, a_u ** exp))
for a1, a1_exp in m.items():
a1_t = _DimMon.from_atom(a1, 1)
a1_e = _DimExpr.from_monomial(a1_t, 1, self.scope)
self.add_implicit_constraints_term(a1_t)
a1_l, a1_u = self.bounds(a1_e, BoundsPrecision.BEST)
assert a1_l <= a1_u
bounds.append((a1_l ** a1_exp, a1_u ** a1_exp))

candidate_bounds = [math.prod(atom_bounds)
for atom_bounds in itertools.product(*bounds)]
Expand All @@ -354,10 +390,12 @@ def add_implicit_constraints(self: _DecisionByElimination, m: _DimMon):

# It is an atom, is it a variable?
if (v := a.to_var()) is not None:
self.combine_and_add_constraint(Comparator.GEQ, m_e, 1,
debug_str=f"{v} >= 1") # v >= 1
self.combine_and_add_constraint(Comparator.GEQ, m_e, 1) # v >= 1
return

for oper in a.operands:
self.add_implicit_constraints_expr(oper)

if a.operation == _DimAtom.MOD:
op1, op2 = a.operands
op2_b_l, op2_b_u = self.bounds(op2, BoundsPrecision.FOR_GEQ0_OR_LT0)
Expand Down Expand Up @@ -405,6 +443,7 @@ def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
self.combine_and_add_constraint(Comparator.GEQ, m_e, 0)
mod_e = _DimExpr.from_operation(_DimAtom.MOD, op1, op2,
scope=self.scope)
self.add_implicit_constraints_expr(mod_e)
combined = op2 * m_e + mod_e
self.combine_and_add_constraint(Comparator.EQ, op1, combined)
return
Expand All @@ -417,6 +456,7 @@ def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
self.combine_and_add_constraint(Comparator.GEQ, max(op1_b_u, op2_b_u), m_e)
self.combine_and_add_constraint(Comparator.GEQ, m_e, op1)
self.combine_and_add_constraint(Comparator.GEQ, m_e, op2)
return

if a.operation == _DimAtom.MIN:
op1, op2 = a.operands
Expand All @@ -426,3 +466,4 @@ def math_floor_with_inf(a: float, b: float): # math.floor, but aware of inf
self.combine_and_add_constraint(Comparator.GEQ, min(op1_b_u, op2_b_u), m_e)
self.combine_and_add_constraint(Comparator.GEQ, op1, m_e)
self.combine_and_add_constraint(Comparator.GEQ, op2, m_e)
return
10 changes: 4 additions & 6 deletions tests/shape_poly_test.py
Expand Up @@ -79,8 +79,7 @@ def _bounds(e: shape_poly.DimSize) -> tuple[float, float]:
scope = e.scope
else:
scope = shape_poly.SymbolicScope()
decision = shape_poly_decision._DecisionByElimination(scope)
return decision.bounds(e, shape_poly.BoundsPrecision.BEST)
return shape_poly._bounds_decision(e, shape_poly.BoundsPrecision.BEST)

def _assert_equal_bounds(tst: jtu.JaxTestCase,
e: shape_poly.DimSize,
Expand Down Expand Up @@ -725,7 +724,7 @@ def test_unit_combine_term_with_constraints(self):
def _m(e: shape_poly._DimExpr) -> shape_poly._DimMon:
return e.to_monomial()
Comparator = shape_poly.Comparator
decision = shape_poly_decision._DecisionByElimination(scope)
decision = shape_poly_decision._DecisionByElimination(scope).initialize()

self.assertSetEqual(
set(),
Expand Down Expand Up @@ -1085,8 +1084,7 @@ def test_constraints_a_minus_4d_eq(self):
scope1 = shape_poly.SymbolicScope(assumptions1)
a1, d1, m1 = shape_poly.symbolic_shape("a1, d1, m1", scope=scope1)
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a1 - 4*d1),
_expect(best=(1, 3), current=(1, 3))) # a - 4d = m >= 1
self.assertEqual(_bounds(a1 - 4*d1), (1, 3)) # a - 4d = m >= 1
self.assertEqual(_bounds(a1 - 2*d1), (3, np.inf)) # a - 2d = m + 2d >= 3
# TODO: The incompleteness is due to the way we combine external constraints
self.assertEqual(_bounds(a1),
Expand Down Expand Up @@ -1611,7 +1609,7 @@ def test_constraints_for_profile(self):
# performance
def f(x): # x: i32[a, b]
acc = 0
for start in range(0, 10):
for start in range(0, 50):
slice = x[start::2] # exercises floordiv and min
acc += jnp.sum(slice, axis=0)

Expand Down

0 comments on commit eb9caf0

Please sign in to comment.