From 1e7ca8f77a739b93b156506bd6eb7b13aa504781 Mon Sep 17 00:00:00 2001 From: Matthew Johnson Date: Fri, 23 Sep 2022 14:21:18 -0700 Subject: [PATCH] fix bug in djax type signature inference logic Co-authored-by: Sharad Vikram --- jax/interpreters/partial_eval.py | 14 ++++++++------ jax/linear_util.py | 19 +++++++++++++++---- tests/dynamic_api_test.py | 6 ++++++ 3 files changed, 29 insertions(+), 10 deletions(-) diff --git a/jax/interpreters/partial_eval.py b/jax/interpreters/partial_eval.py index 793de35384ae..bfacfe16bd43 100644 --- a/jax/interpreters/partial_eval.py +++ b/jax/interpreters/partial_eval.py @@ -2095,7 +2095,9 @@ def infer_lambda_input_type( idxs, implicit_types = _collect_implicit(args, specs) implicit_sig = [(ty, False) for ty in implicit_types] explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)] - return (*implicit_sig, *explicit_sig) + input_type = (*implicit_sig, *explicit_sig) + lu._check_input_type(input_type) + return input_type def _canonicalize_specs( ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]] @@ -2143,6 +2145,7 @@ def _complete_specs( for x, spec in zip(args, specs)) return specs + def _collect_implicit( args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]] ) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]: @@ -2153,24 +2156,23 @@ def _collect_implicit( idxs: Dict[AbstractedAxisName, DBIdx] = {} implicit_types: List[AbstractValue] = [] explicit_tracers: Dict[TracerId, int] = {} - counter = (DBIdx(i) for i in it.count()) + counter = it.count() # Add implicit arguments to idxs. - for explicit_idx, (x, spec) in enumerate(zip(args, specs)): for i, name in spec.items(): if name not in idxs and id(x.shape[i]) not in explicit_tracers: - idxs[name] = next(counter) + idxs[name] = DBIdx(next(counter)) implicit_types.append(raise_to_shaped(get_aval(x.shape[i]))) if isinstance(x, Tracer): - explicit_tracers[id(x)] = explicit_idx + explicit_tracers.setdefault(id(x), explicit_idx) # use the first # Now that we know the implicit args, add explicit args to idxs. offset = len(implicit_types) for x, spec in zip(args, specs): for i, name in spec.items(): if id(x.shape[i]) in explicit_tracers: - idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])]) + idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])])) return idxs, implicit_types diff --git a/jax/linear_util.py b/jax/linear_util.py index 1d4599a38216..316b6f54aaa4 100644 --- a/jax/linear_util.py +++ b/jax/linear_util.py @@ -236,17 +236,29 @@ def wrap_init(f, params=None) -> WrappedFun: params = () if params is None else tuple(sorted(params.items())) return WrappedFun(f, (), (), params, None) -def annotate(f: WrappedFun, - in_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]] - ) -> WrappedFun: + +def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun: assert f.in_type is None if in_type is None: return f + _check_input_type(in_type) + return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type) + +def _check_input_type(in_type: core.InputType) -> None: + # Check that in_type is syntactically well-formed assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and all(isinstance(a, core.AbstractValue) and type(b) is bool and not isinstance(a, core.ConcreteArray) for a, b in in_type) and all(isinstance(d, (int, core.BInt, core.DBIdx)) for a, _ in in_type if type(a) is core.DShapedArray for d in a.shape)) + + # Check that all DBIdx point to positions to the left of the input on which + # they appear. + assert all(d.val < i for i, (aval, _) in enumerate(in_type) + if isinstance(aval, core.DShapedArray) for d in aval.shape + if isinstance(d, core.DBIdx)) + + # Check that all implicit arguments have at least one DBIdx pointing to them. provided = [e for _, e in in_type] for aval, _ in in_type: if type(aval) is core.DShapedArray: @@ -254,7 +266,6 @@ def annotate(f: WrappedFun, if isinstance(d, core.DBIdx): provided[d.val] = True assert all(provided) - return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type) class _CacheLocalContext(threading.local): diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index fffc86790b4b..1fe68748944b 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -1354,6 +1354,12 @@ def f(x): d, = jaxpr.eqns[0].outvars self.assertEqual(d.aval.shape, (a, a)) + def test_inferring_valid_subjaxpr_type_add(self): + def f(x): + return x + x.shape[0] + + jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash + if __name__ == '__main__': absltest.main(testLoader=jtu.JaxTestLoader())