Skip to content

Commit

Permalink
[remove-units] don't use abstract_unit for dropvar avals
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Apr 29, 2022
1 parent cdd1167 commit 477dfa6
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 23 deletions.
40 changes: 17 additions & 23 deletions jax/interpreters/partial_eval.py
Expand Up @@ -554,7 +554,7 @@ def aval(self) -> AbstractValue:
@property
def parents(self) -> Sequence[JaxprTracer]:
if isinstance(self.recipe, JaxprEqnRecipe):
return self.recipe.invars
return self.recipe.in_tracers
else:
return []

Expand Down Expand Up @@ -655,53 +655,47 @@ def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer):
ConstVar = namedtuple('ConstVar', ['val'])
LambdaBinding = namedtuple('LambdaBinding', [])
class JaxprEqnRecipe(NamedTuple):
eqn_id: object
invars: Sequence[JaxprTracer]
outvars: Sequence[ref[JaxprTracer]]
eqn_id: Any
in_tracers: Sequence[JaxprTracer]
out_tracer_refs: Sequence[ref[JaxprTracer]]
out_avals: Sequence[core.AbstractValue]
primitive: Primitive
params: Dict[str, Any]
effects: core.Effects
source_info: source_info_util.SourceInfo

def new_eqn_recipe(invars: Sequence[JaxprTracer],
outvars: Sequence[JaxprTracer],
def new_eqn_recipe(in_tracers: Sequence[JaxprTracer],
out_tracers: Sequence[JaxprTracer],
primitive: Primitive,
params: Dict[str, Any],
effects: core.Effects,
source_info: source_info_util.SourceInfo
) -> JaxprEqnRecipe:
"""Constructs a new JaxEqnRecipe.
Params:
invars: the tracers for the primitive inputs.
outvars: the tracers for the primitive outputs.
primitive: the primitive.
params: the primitive params
"""
# TODO(necula): move these checks to core.check_jaxpr, and call in more places
if primitive.call_primitive or primitive.map_primitive:
assert "call_jaxpr" in params
# assert len(invars) == len(params["call_jaxpr"].invars) # TODO constvars?
assert len(outvars) == len(params["call_jaxpr"].outvars)
assert len(out_tracers) == len(params["call_jaxpr"].outvars)
assert ("donated_invars" not in params or
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
if primitive.map_primitive:
assert ("in_axes" in params and
len(params["in_axes"]) == len(params["call_jaxpr"].invars))
assert ("donated_invars" in params and
len(params["donated_invars"]) == len(params["call_jaxpr"].invars))
return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive,
params, effects, source_info)
out_avals = [core.raise_to_shaped(t.aval) for t in out_tracers]
return JaxprEqnRecipe(object(), tuple(in_tracers), map(ref, out_tracers),
out_avals, primitive, params, effects, source_info)


def recipe_to_eqn(getvar: Callable[[JaxprTracer], Atom],
recipe: JaxprEqnRecipe) -> core.JaxprEqn:
_, in_tracers, out_tracer_refs, primitive, params, effects, source_info = recipe
out_tracers = [t_ref() for t_ref in out_tracer_refs]
(_, in_tracers, out_tracer_refs, out_avals, prim, params, eff, src) = recipe
invars = [getvar(t) for t in in_tracers]
outvars = [DropVar(core.abstract_unit) if t is None
else cast(Var, getvar(t)) for t in out_tracers]
return new_jaxpr_eqn(invars, outvars, primitive, params, effects, source_info)
out_tracers = [t_ref() for t_ref in out_tracer_refs]
outvars = [DropVar(a) if t is None else getvar(t) # type: ignore
for a, t in zip(out_avals, out_tracers)]
return new_jaxpr_eqn(invars, outvars, prim, params, eff, src)

def tracers_to_jaxpr(
in_tracers: Sequence[JaxprTracer],
Expand Down Expand Up @@ -750,7 +744,7 @@ def getconstvar(c):
t, "Tracer not among input tracers {}".format(t))
assert in_tracers, "Lambda binding with no args"
elif isinstance(recipe, FreeVar):
env[cast(Var, getvar(t))] = recipe.val
env[getvar(t)] = recipe.val # type: ignore
elif isinstance(recipe, ConstVar):
v = t_to_var[id(t)] = getconstvar(recipe.val)
consts[v] = recipe.val
Expand Down
13 changes: 13 additions & 0 deletions tests/core_test.py
Expand Up @@ -343,6 +343,19 @@ def test_concrete_array_string_representation(self):
np.array([1], dtype=np.int32))),
'ConcreteArray([1], dtype=int32)')

def test_dropvar_avals(self):
def f(x):
def body(c, _):
return c, None
(x1, x2), _ = jax.lax.scan(body, (x, x), None, length=1)
return [x2]

aval = core.ShapedArray((), jnp.dtype('int32'))
pval = pe.PartialVal.unknown(aval)
jaxpr, _, _ = pe.trace_to_jaxpr_nounits(lu.wrap_init(f), [pval], False)
dropvar, b = jaxpr.eqns[0].outvars
self.assertEqual(dropvar.aval, aval)


class JaxprTypeChecks(jtu.JaxTestCase):

Expand Down

0 comments on commit 477dfa6

Please sign in to comment.