Skip to content

Commit

Permalink
fix bug in djax type signature inference logic
Browse files Browse the repository at this point in the history
Co-authored-by: Sharad Vikram <sharad.vikram@gmail.com>
  • Loading branch information
mattjj and sharadmv committed Sep 27, 2022
1 parent 82636b0 commit 1e7ca8f
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 10 deletions.
14 changes: 8 additions & 6 deletions jax/interpreters/partial_eval.py
Expand Up @@ -2095,7 +2095,9 @@ def infer_lambda_input_type(
idxs, implicit_types = _collect_implicit(args, specs)
implicit_sig = [(ty, False) for ty in implicit_types]
explicit_sig = [(_arg_type(idxs, x, s), True) for x, s in zip(args, specs)]
return (*implicit_sig, *explicit_sig)
input_type = (*implicit_sig, *explicit_sig)
lu._check_input_type(input_type)
return input_type

def _canonicalize_specs(
ndims: Sequence[int], specs: Optional[Sequence[AbstractedAxesSpec]]
Expand Down Expand Up @@ -2143,6 +2145,7 @@ def _complete_specs(
for x, spec in zip(args, specs))
return specs


def _collect_implicit(
args: Sequence[Any], specs: List[Dict[int, AbstractedAxisName]]
) -> Tuple[Dict[AbstractedAxisName, DBIdx], List[AbstractValue]]:
Expand All @@ -2153,24 +2156,23 @@ def _collect_implicit(
idxs: Dict[AbstractedAxisName, DBIdx] = {}
implicit_types: List[AbstractValue] = []
explicit_tracers: Dict[TracerId, int] = {}
counter = (DBIdx(i) for i in it.count())
counter = it.count()

# Add implicit arguments to idxs.

for explicit_idx, (x, spec) in enumerate(zip(args, specs)):
for i, name in spec.items():
if name not in idxs and id(x.shape[i]) not in explicit_tracers:
idxs[name] = next(counter)
idxs[name] = DBIdx(next(counter))
implicit_types.append(raise_to_shaped(get_aval(x.shape[i])))
if isinstance(x, Tracer):
explicit_tracers[id(x)] = explicit_idx
explicit_tracers.setdefault(id(x), explicit_idx) # use the first

# Now that we know the implicit args, add explicit args to idxs.
offset = len(implicit_types)
for x, spec in zip(args, specs):
for i, name in spec.items():
if id(x.shape[i]) in explicit_tracers:
idxs[name] = DBIdx(offset + explicit_tracers[id(x.shape[i])])
idxs.setdefault(name, DBIdx(offset + explicit_tracers[id(x.shape[i])]))

return idxs, implicit_types

Expand Down
19 changes: 15 additions & 4 deletions jax/linear_util.py
Expand Up @@ -236,25 +236,36 @@ def wrap_init(f, params=None) -> WrappedFun:
params = () if params is None else tuple(sorted(params.items()))
return WrappedFun(f, (), (), params, None)

def annotate(f: WrappedFun,
in_type: Optional[Tuple[Tuple[core.AbstractValue, bool], ...]]
) -> WrappedFun:

def annotate(f: WrappedFun, in_type: core.InputType) -> WrappedFun:
assert f.in_type is None
if in_type is None:
return f
_check_input_type(in_type)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)

def _check_input_type(in_type: core.InputType) -> None:
# Check that in_type is syntactically well-formed
assert (type(in_type) is tuple and all(type(e) is tuple for e in in_type) and
all(isinstance(a, core.AbstractValue) and type(b) is bool
and not isinstance(a, core.ConcreteArray) for a, b in in_type) and
all(isinstance(d, (int, core.BInt, core.DBIdx)) for a, _ in in_type
if type(a) is core.DShapedArray for d in a.shape))

# Check that all DBIdx point to positions to the left of the input on which
# they appear.
assert all(d.val < i for i, (aval, _) in enumerate(in_type)
if isinstance(aval, core.DShapedArray) for d in aval.shape
if isinstance(d, core.DBIdx))

# Check that all implicit arguments have at least one DBIdx pointing to them.
provided = [e for _, e in in_type]
for aval, _ in in_type:
if type(aval) is core.DShapedArray:
for d in aval.shape:
if isinstance(d, core.DBIdx):
provided[d.val] = True
assert all(provided)
return WrappedFun(f.f, f.transforms, f.stores, f.params, in_type)


class _CacheLocalContext(threading.local):
Expand Down
6 changes: 6 additions & 0 deletions tests/dynamic_api_test.py
Expand Up @@ -1354,6 +1354,12 @@ def f(x):
d, = jaxpr.eqns[0].outvars
self.assertEqual(d.aval.shape, (a, a))

def test_inferring_valid_subjaxpr_type_add(self):
def f(x):
return x + x.shape[0]

jax.make_jaxpr(f, abstracted_axes=('n',))(jnp.arange(3)) # doesn't crash


if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit 1e7ca8f

Please sign in to comment.