Skip to content

Commit

Permalink
[remove-units] remove units from api_util.py
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 26, 2022
1 parent 5d68280 commit e7acb82
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 42 deletions.
8 changes: 4 additions & 4 deletions jax/_src/ad_util.py
Expand Up @@ -15,8 +15,8 @@
from typing import Any, Callable, Dict, Type

from jax import core
from jax.core import (lattice_join, Primitive, Unit, unit, AbstractUnit,
valid_jaxtype, raise_to_shaped, get_aval)
from jax.core import (lattice_join, Primitive, valid_jaxtype, raise_to_shaped,
get_aval)
from jax.tree_util import register_pytree_node
from jax._src.util import safe_map

Expand All @@ -28,7 +28,7 @@
map = safe_map

jaxval_adders: Dict[type, Callable] = {}
jaxval_adders[Unit] = lambda _, __: unit
jaxval_adders[core.Unit] = lambda _, __: core.unit

def add_jaxvals(x, y):
if core.get_aval(x) is core.abstract_unit is core.get_aval(y):
Expand All @@ -53,7 +53,7 @@ def zeros_like_aval(aval):
return aval_zeros_likers[type(aval)](aval)

aval_zeros_likers: Dict[Type[core.AbstractValue], Array] = {}
aval_zeros_likers[AbstractUnit] = lambda _: unit
aval_zeros_likers[core.AbstractUnit] = lambda _: core.unit

def zeros_like_jaxval(val):
return zeros_like_p.bind(val)
Expand Down
40 changes: 18 additions & 22 deletions jax/_src/api_util.py
Expand Up @@ -26,7 +26,6 @@
from jax._src.tree_util import _replace_nones
from jax import linear_util as lu
from jax._src.util import safe_map, WrapKwArgs, Hashable, Unhashable
from jax.core import unit

from jax._src import traceback_util
traceback_util.register_exclusion(__file__)
Expand Down Expand Up @@ -136,19 +135,19 @@ def __eq__(self, other):

def argnums_partial(f, dyn_argnums, args, require_static_args_hashable=True):
dyn_argnums = _ensure_index_tuple(dyn_argnums)
fixed_args = [unit] * len(args)
for i, arg in enumerate(args):
if i in dyn_argnums: continue
if require_static_args_hashable:
if require_static_args_hashable:
fixed_args = []
for i, arg in enumerate(args):
if i in dyn_argnums: continue
if not is_hashable(arg):
raise ValueError(
"Non-hashable static arguments are not supported, as this can lead "
f"to unexpected cache-misses. Static argument (index {i}) of type "
f"{type(arg)} for function {f.__name__} is non-hashable.")
fixed_args[i] = _HashableWithStrictTypeEquality(arg)
else:
fixed_args[i] = Unhashable(arg)

fixed_args.append(_HashableWithStrictTypeEquality(arg))
else:
fixed_args = [Unhashable(arg) for i, arg in enumerate(args)
if i not in dyn_argnums]
dyn_args = tuple(args[i] for i in dyn_argnums)
return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args

Expand All @@ -160,10 +159,9 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
dyn_argnums = tuple(i for i in range(len(args)) if i not in static_argnums)
dyn_args = tuple(args[i] for i in dyn_argnums)

fixed_args = [unit] * len(args) # type: ignore
fixed_args = []
for i in static_argnums:
# TODO(shoyer): set allow_invalid=True permanently after enabling
# static_argnames.
# TODO(shoyer): set allow_invalid=True permanently after static_argnames.
if allow_invalid and i >= len(args):
continue
static_arg = args[i]
Expand All @@ -175,16 +173,19 @@ def argnums_partial_except(f: lu.WrappedFun, static_argnums: Tuple[int, ...],
f"to unexpected cache-misses. Static argument (index {i}) of type "
f"{type(static_arg)} for function {f.__name__} is non-hashable.")
else:
fixed_args[i] = _HashableWithStrictTypeEquality(static_arg) # type: ignore
fixed_args.append(_HashableWithStrictTypeEquality(static_arg)) # type: ignore

return _argnums_partial(f, dyn_argnums, tuple(fixed_args)), dyn_args


@lu.transformation
def _argnums_partial(dyn_argnums, fixed_args, *dyn_args, **kwargs):
args = [None if arg is unit else arg.val for arg in fixed_args]
sentinel = object()
args = [sentinel] * (len(fixed_args) + len(dyn_args))
for i, arg in zip(dyn_argnums, dyn_args):
args[i] = arg
fixed_args_ = iter(fixed_args)
args = [next(fixed_args_).val if x is sentinel else x for x in args]
assert next(fixed_args_, sentinel) is sentinel
ans = yield args, kwargs
yield ans

Expand All @@ -197,9 +198,7 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],

fixed_kwargs: Dict[str, Any] = {}
for k, arg in kwargs.items():
if k in dyn_kwargs:
fixed_kwargs[k] = unit
else:
if k not in dyn_kwargs:
try:
hash(arg)
except TypeError:
Expand All @@ -212,12 +211,9 @@ def argnames_partial_except(f: lu.WrappedFun, static_argnames: Tuple[str, ...],

return _argnames_partial(f, WrapKwArgs(fixed_kwargs)), dyn_kwargs


@lu.transformation
def _argnames_partial(fixed_kwargs: WrapKwArgs, *args, **dyn_kwargs):
kwargs = {k: None if arg is unit else arg.val
for k, arg in fixed_kwargs.val.items()}
kwargs.update(dyn_kwargs)
kwargs = dict({k: v.val for k, v in fixed_kwargs.val.items()}, **dyn_kwargs)
ans = yield args, kwargs
yield ans

Expand Down
33 changes: 17 additions & 16 deletions jax/interpreters/partial_eval.py
Expand Up @@ -39,10 +39,10 @@
merge_lists, partition_list, OrderedSet,
as_hashable_function, weakref_lru_cache)
from jax.core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue,
unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
ConcreteArray, raise_to_shaped, Var, DropVar, Atom,
JaxprEqn, Primitive, ShapedArray, DShapedArray,
AbstractBInt, mapped_aval, unmapped_aval)
ClosedJaxpr, new_jaxpr_eqn, ConcreteArray,
raise_to_shaped, Var, DropVar, Atom, JaxprEqn, Primitive,
ShapedArray, DShapedArray, AbstractBInt, mapped_aval,
unmapped_aval)
from jax._src import source_info_util
from jax.config import config

Expand Down Expand Up @@ -125,7 +125,7 @@ def sublift(self, val) -> JaxprTracer:
def new_const(self, val) -> JaxprTracer:
if isinstance(val, Tracer) and val._trace.level == self.level:
raise Exception
return JaxprTracer(self, PartialVal.known(val), unit)
return JaxprTracer(self, PartialVal.known(val), core.unit)

def new_instantiated_literal(self, val) -> JaxprTracer:
aval = get_aval(val)
Expand Down Expand Up @@ -721,7 +721,7 @@ def tracers_to_jaxpr(
def getvar(t: JaxprTracer) -> Atom:
var = t_to_var.get(id(t))
if var is None:
aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit
aval = t.pval.get_aval() if not t.pval.is_known() else core.abstract_unit
var = t_to_var[id(t)] = newvar(aval)
return var
sorted_tracers = toposort(out_tracers)
Expand Down Expand Up @@ -754,8 +754,8 @@ def getconstvar(c):
consts[v] = recipe.val
elif isinstance(recipe, Literal):
t_to_var[id(t)] = recipe
elif recipe is unit:
t_to_var[id(t)] = unitvar
elif recipe is core.unit:
t_to_var[id(t)] = core.unitvar
else:
raise TypeError(recipe)

Expand Down Expand Up @@ -790,7 +790,7 @@ def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int) -> Jaxpr:


def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]:
return (abstract_unit, aval) if unknown else (aval, abstract_unit)
return (core.abstract_unit, aval) if unknown else (aval, core.abstract_unit)


def partial_eval_jaxpr_nounits(
Expand Down Expand Up @@ -952,7 +952,7 @@ def _remat_partial_eval(trace, _, f, tracers, params):
out_consts = core.eval_jaxpr(jaxpr_known, (), *in_consts)
out_consts_ = iter(out_consts)
# reconstruct known outs, inserting units
outs1 = [pval.get_known() if x.aval is abstract_unit else next(out_consts_)
outs1 = [pval.get_known() if x.aval is core.abstract_unit else next(out_consts_)
for uk, pval, x in zip(out_unknowns, eval_out_pvals, jaxpr.outvars)
if not uk]
# form known outputs and collect residual tracers
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom:
return x

known_eqns, staged_eqns = [], []
write(False, True, unitvar)
write(False, True, core.unitvar)
map(write, in_unknowns, [True] * len(in_unknowns), jaxpr.invars)
for eqn in jaxpr.eqns:
unks_in, inst_in = unzip2(map(read, eqn.invars))
Expand All @@ -1081,7 +1081,7 @@ def ensure_instantiated(inst: bool, x: Atom) -> Atom:

ins_known, _ = partition_list(in_unknowns, jaxpr.invars)
outs_known_, _ = partition_list(out_unknowns, jaxpr.outvars)
outs_known = [x for x in outs_known_ if x.aval is not abstract_unit]
outs_known = [x for x in outs_known_ if x.aval is not core.abstract_unit]
known_effects = core.join_effects(*(eqn.effects for eqn in known_eqns))
jaxpr_known = Jaxpr((), ins_known, [*outs_known, *residuals], known_eqns,
known_effects)
Expand Down Expand Up @@ -1254,7 +1254,7 @@ def _dce_open_jaxpr(jaxpr: Jaxpr, outputs: Tuple[bool, ...], drop_outputs=False)
if drop_outputs:
new_outvars = [var for var, output in zip(jaxpr.outvars, outputs) if output]
else:
new_outvars = [var if output else unitvar
new_outvars = [var if output else core.unitvar
for var, output in zip(jaxpr.outvars, outputs)]

needed_vars = {v for v in new_outvars if type(v) is not Literal}
Expand Down Expand Up @@ -2139,7 +2139,7 @@ def read(x):
def write(v, val) -> None:
env[v] = val

write(unitvar, unit)
write(core.unitvar, core.unit)
map(write, jaxpr.constvars, consts)
map(write, jaxpr.invars, args)
for eqn in jaxpr.eqns:
Expand Down Expand Up @@ -2219,7 +2219,8 @@ def fun(*vals):

# For jaxpr_known we pass core.unit for the unknown inputs, and known
# PartialVal for the known inputs.
in_avals = [abstract_unit if uk else a for a, uk in zip(jaxpr.in_avals, unknowns)]
in_avals = [core.abstract_unit if uk else a
for a, uk in zip(jaxpr.in_avals, unknowns)]
jaxpr_1, out_avals, consts_1 = trace_to_jaxpr_dynamic(lu.wrap_init(fun), in_avals)
(out_pvs_2, jaxpr_2, num_res), = cell
assert len(jaxpr_2.constvars) == num_res
Expand All @@ -2232,7 +2233,7 @@ def fun(*vals):
jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res]
for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns):
if not unknown:
var.aval = abstract_unit
var.aval = core.abstract_unit

uk_out = [pv is not None for pv in out_pvs_2]

Expand Down

0 comments on commit e7acb82

Please sign in to comment.