
Commit

[dynamic shapes] revive iree
mattjj committed Jul 6, 2022
1 parent 95e7933 commit 6bb90fd
Showing 8 changed files with 266 additions and 176 deletions.
40 changes: 12 additions & 28 deletions jax/_src/dispatch.py
@@ -192,7 +192,10 @@ def _device_from_arg_devices(devices: Sequence[Optional[Device]]) -> Optional[De
 def _xla_call_impl(fun: lu.WrappedFun, *args, device, backend, name,
                    donated_invars, inline, keep_unused: bool):
   del inline  # Only used at tracing time
-  arg_specs = unsafe_map(arg_spec, args)
+  if fun.in_type is None:
+    arg_specs = unsafe_map(arg_spec, args)
+  else:
+    arg_specs = [(None, getattr(x, '_device', None)) for x in args]
   compiled_fun = xla_callable(fun, device, backend, name, donated_invars,
                               keep_unused, *arg_specs)
   try:
@@ -283,25 +286,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
     in_type = tuple(unsafe_zip(abstract_args, itertools.repeat(True)))
     fun = lu.annotate(fun, in_type)
   else:
-    # Check that the provided abstract_args are consistent with in_type by first
-    # collecting values of axis size arguments, then substituting them in for
-    # DBIdx occurrences.
-    axis_sizes: Dict[core.DBIdx, int] = {}
-    abstract_args_iter = iter(abstract_args)
-    for expected_type, explicit in fun.in_type:
-      if explicit:
-        aval = next(abstract_args_iter)
-        if isinstance(expected_type, core.DShapedArray):
-          # Check the value for any DBIdx variables is consistent.
-          assert all(axis_sizes.setdefault(d1, d2) == d2
-                     for d1, d2 in zip(expected_type.shape, aval.shape)
-                     if type(d1) is core.DBIdx)
-          # Check the type matches after substitution.
-          expected_shape = [axis_sizes.get(d, d) for d in expected_type.shape]  # type: ignore
-          expected_aval = core.ShapedArray(
-              shape=tuple(expected_shape), dtype=expected_type.dtype,
-              weak_type=expected_type.weak_type)
-          assert core.typematch(expected_aval, aval)
+    assert abstract_args == (None,) * len(abstract_args)
+    abstract_args = [aval for aval, _ in fun.in_type]
   with log_elapsed_time(f"Finished tracing + transforming {fun.__name__} "
                         "for jit in {elapsed_time} sec"):
     jaxpr, out_type, consts = pe.trace_to_jaxpr_final2(
@@ -326,7 +312,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                       if i in kept_var_idx]
     del kept_const_idx
   else:
-    kept_var_idx = set(range(len(abstract_args)))
+    kept_var_idx = set(range(len(fun.in_type)))

   nreps = jaxpr_replicas(jaxpr)
   device = _xla_callable_device(nreps, backend, device, arg_devices)
@@ -430,12 +416,10 @@ def jaxpr_has_pmap(jaxpr):
   return False

 def jaxpr_has_bints(jaxpr: core.Jaxpr) -> bool:
-  return (any(type(d) is core.Var for v in jaxpr.invars
-              if type(v.aval) is core.DShapedArray for d in v.aval.shape) or
-          any(type(d) is core.Var
+  return (any(type(v.aval) is core.AbstractBInt for v in jaxpr.invars) or
+          any(type(v.aval) is core.AbstractBInt
               for j in itertools.chain([jaxpr], core.subjaxprs(jaxpr))
-              for e in j.eqns for v in itertools.chain(e.invars, e.outvars)
-              if type(v.aval) is core.DShapedArray for d in v.aval.shape))
+              for e in j.eqns for v in e.outvars))

 def _prune_unused_inputs(
     jaxpr: core.Jaxpr) -> Tuple[core.Jaxpr, Set[int], Set[int]]:
@@ -545,7 +529,7 @@ def _input_handler(backend: Backend,
   in_avals, which_explicit = util.unzip2(in_type)
   # Check whether we actually need an input_handler.
   needs_implicit = which_explicit and not all(which_explicit)
-  needs_out_handling = any(type(d) is core.InDBIdx for a in out_type or []
+  needs_out_handling = any(type(d) is core.InDBIdx for a, _ in out_type or []
                            if type(a) is core.DShapedArray for d in a.shape)

   if not needs_implicit and not needs_out_handling:
@@ -565,7 +549,7 @@ def _input_handler(backend: Backend,

   # Precompute which input values are needed for output types.
   inputs_needed_for_out_types = out_type and [
-      d.val for aval in out_type if type(aval) is core.DShapedArray  # type: ignore
+      d.val for aval, _ in out_type if type(aval) is core.DShapedArray  # type: ignore
       for d in aval.shape if type(d) is core.InDBIdx]

   def elaborate(explicit_args: Sequence[Any]) -> Tuple[Tuple, Optional[Tuple]]:
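A minimal standalone sketch of the dispatch-time branch in the first hunk, assuming hypothetical stand-in names (make_arg_specs, FakeArray) rather than JAX's internal APIs: when the wrapped function carries an input-type annotation (fun.in_type), argument avals are taken from that annotation later, so dispatch only records each argument's device.

# Hypothetical sketch; not JAX's actual dispatch internals.
from typing import Any, Optional, Sequence, Tuple

def make_arg_specs(args: Sequence[Any],
                   in_type: Optional[Tuple[Any, ...]]):
  if in_type is None:
    # No input-type annotation: derive an (aval, device) spec from each value.
    return [(getattr(x, 'aval', None), getattr(x, '_device', None)) for x in args]
  else:
    # Annotated function: avals come from in_type, so only record devices.
    return [(None, getattr(x, '_device', None)) for x in args]

class FakeArray:
  def __init__(self, aval, device):
    self.aval = aval
    self._device = device

args = [FakeArray('f32[3]', 'cpu:0')]
print(make_arg_specs(args, in_type=None))                 # [('f32[3]', 'cpu:0')]
print(make_arg_specs(args, in_type=(('f32[n]', True),)))  # [(None, 'cpu:0')]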
