Skip to content

Commit

Permalink
[remove-units] remove units from partial_eval.py
Browse files Browse the repository at this point in the history
After last week's changes, units are no longer traced or introduced into jaxprs
in any way, so we don't need to use them in partial evaluation.

(Also there are some unrelated removals of dead code in maps.py.)
  • Loading branch information
mattjj committed May 2, 2022
1 parent 44006c7 commit 11ad045
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 30 deletions.
2 changes: 1 addition & 1 deletion jax/_src/api.py
Expand Up @@ -1512,7 +1512,7 @@ def _get_axis_size(name: str, shape: Tuple[int, ...], axis: int):
tree, leaf = treedef_children(tree)
assert treedef_is_leaf(leaf)
# TODO(mattjj,phawkins): add a way to inspect pytree kind more directly
if tree == tree_flatten((core.unit,) * tree.num_leaves)[1]:
if tree == tree_flatten((0,) * tree.num_leaves)[1]:
lines1 = [f"arg {i} has shape {np.shape(x)} and axis {d} is to be mapped"
for i, (x, d) in enumerate(zip(vals, dims))]
sizes = collections.defaultdict(list)
Expand Down
19 changes: 0 additions & 19 deletions jax/experimental/maps.py
Expand Up @@ -1106,25 +1106,6 @@ def out_local_named_shapes(local_axes, *args, **kwargs):
ans_axes = [frozenset(a.aval.named_shape) & local_axes for a in ans]
yield ans, ans_axes

@lu.transformation_with_aux
def hide_units(unit_args, *args, **kwargs):
ans = yield restore_units(unit_args, args), kwargs
yield filter_units(ans)

def filter_units(vals):
vals_no_units = [v for v in vals if v is not core.unit]
vals_is_unit = [v is core.unit for v in vals]
return vals_no_units, vals_is_unit

def restore_units(is_unit, vals):
vals_it = iter(vals)
vals_with_units = [core.unit if u else next(vals_it) for u in is_unit]
try:
next(vals_it)
raise RuntimeError("Expected the iterator to be exhausted")
except StopIteration:
return vals_with_units


def _jaxpr_trace_process_xmap(self, primitive, f: lu.WrappedFun, tracers, params):
assert primitive is xmap_p
Expand Down
20 changes: 10 additions & 10 deletions jax/interpreters/partial_eval.py
Expand Up @@ -62,20 +62,20 @@ class PartialVal(tuple):
"""Partial value: either a known value or an unknown (abstract) value.
Represented as a pair `(aval_opt, const)` of one of two kinds:
* `(None, <Constant>)` indicates a known value, either a Python regular
value, or a Tracer.
* `(<AbstractValue>, *)` indicates an unknown value characterized by an
* `(None, <Constant>)` indicates a known value, where the constant is either a
Tracer or satisfies `core.valid_jaxtype(const)`;
* `(<AbstractValue>, None)` indicates an unknown value characterized by an
abstract value.
"""
def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]):
pv, const = xs
if config.jax_enable_checks:
# type checks
assert isinstance(pv, (AbstractValue, type(None))), xs
assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs
assert (const is None or isinstance(const, core.Tracer) or
core.valid_jaxtype(const)), const
# invariant checks
if isinstance(pv, AbstractValue):
assert get_aval(const) == core.abstract_unit, xs
assert (pv is None) ^ (const is None)
return tuple.__new__(cls, xs)

@classmethod
Expand All @@ -84,7 +84,7 @@ def known(cls, const: core.Value) -> PartialVal:

@classmethod
def unknown(cls, aval: AbstractValue) -> PartialVal:
return PartialVal((aval, core.unit))
return PartialVal((aval, None))

def is_known(self) -> bool:
return self[0] is None
Expand Down Expand Up @@ -120,7 +120,7 @@ def sublift(self, val) -> JaxprTracer:
def new_const(self, val) -> JaxprTracer:
if isinstance(val, Tracer) and val._trace.level == self.level:
raise Exception
return JaxprTracer(self, PartialVal.known(val), core.unit)
return JaxprTracer(self, PartialVal.known(val), None)

def new_instantiated_literal(self, val) -> JaxprTracer:
aval = get_aval(val)
Expand Down Expand Up @@ -742,8 +742,8 @@ def getconstvar(c):
consts[v] = recipe.val
elif isinstance(recipe, Literal):
t_to_var[id(t)] = recipe
elif recipe is core.unit:
t_to_var[id(t)] = core.unitvar
elif recipe is None:
assert False
else:
raise TypeError(recipe)

Expand Down

0 comments on commit 11ad045

Please sign in to comment.