Make jit(f).lower(*args) go via lower_sharding_computation when `jax_array` is enabled.

PiperOrigin-RevId: 476148608
yashk2810 authored and jax authors committed Sep 22, 2022
1 parent 640e15f commit a157982
Showing 6 changed files with 63 additions and 28 deletions.
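As context for the diffs below, here is a minimal usage sketch of the path this commit changes, modeled on the test added to tests/api_test.py; the jax.experimental module paths and the `jax_array` flag are the ones in use at this commit:

import jax
from jax.experimental import pjit, sharding

jax.config.update('jax_array', True)  # the sharded lowering path is taken only when jax.Array is enabled

s = sharding.SingleDeviceSharding(jax.devices()[0])

def f(x):
  return pjit.with_sharding_constraint(x, s)

lowered = jax.jit(f).lower(1.)   # now lowers via pxla.lower_sharding_computation
out = lowered.compile()(4.)      # the compile/execute API is unchanged
assert out == 4.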
jax/_src/api.py (23 additions, 7 deletions)
@@ -669,10 +669,19 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,
   # all the other arguments stored as attributes.
 
   def arg_spec(x):
+    from jax.experimental.sharding import PmapSharding
     # like xla.arg_spec but duck-types on x.shape and x.dtype
     aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
-    device = getattr(x, '_device', None)
-    return aval, device
+    if jax.config.jax_array:
+      if hasattr(x, 'sharding'):
+        if isinstance(x.sharding, PmapSharding):
+          return aval, None
+        return aval, (x.sharding if x._committed else None)
+      else:
+        return aval, None
+    else:
+      device = getattr(x, '_device', None)
+      return aval, device
 
   @api_boundary
   def lower(*args, **kwargs) -> stages.Lowered:
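The branching added to arg_spec above can be read as: under `jax_array`, an argument's committed sharding (but not a PmapSharding) is what lowering specializes on; otherwise the legacy `_device` attribute is used. A standalone restatement of that rule, using the import path and duck-typed attributes the diff itself relies on:

from jax.experimental.sharding import PmapSharding  # import path as of this commit

def arg_spec_placement(x, jax_array_enabled: bool):
  # Mirrors the new arg_spec branching: return the sharding (Array path) or
  # device (legacy path) to specialize the lowering on, or None for neither.
  if jax_array_enabled:
    if hasattr(x, 'sharding'):
      if isinstance(x.sharding, PmapSharding):
        return None                                  # pmap-produced arrays don't pin the lowering
      return x.sharding if x._committed else None    # only committed shardings are honored
    return None
  return getattr(x, '_device', None)                 # pre-Array behavior: explicit device, if any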
@@ -699,11 +708,18 @@ def lower(*args, **kwargs) -> stages.Lowered:
       if abstracted_axes:
         raise ValueError("abstracted_axes must be used with --jax_dynamic_shapes")
     in_avals, _ = unzip2(arg_specs_and_devices)
-    computation = dispatch.lower_xla_callable(
-        flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
-        keep_unused, *arg_specs_and_devices)
-    return stages.Lowered.from_flat_info(
-        computation, in_tree, in_avals, donate_argnums, out_tree())
+    if jax.config.jax_array:
+      computation = dispatch.sharded_lowering(
+          flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
+          keep_unused, *arg_specs_and_devices)
+      return stages.Lowered.from_flat_info(
+          computation, in_tree, in_avals, donate_argnums, out_tree())
+    else:
+      computation = dispatch.lower_xla_callable(
+          flat_fun, device, backend, flat_fun.__name__, donated_invars, True,
+          keep_unused, *arg_specs_and_devices)
+      return stages.Lowered.from_flat_info(
+          computation, in_tree, in_avals, donate_argnums, out_tree())
 
   return lower
 
jax/_src/dispatch.py (7 additions, 6 deletions)
@@ -309,8 +309,8 @@ def not_none_device_or_backend_on_jit(backend, device, num_ins):
   return committed, da, in_shardings
 
 
-def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
-                     *arg_specs):
+def sharded_lowering(fun, device, backend, name, donated_invars, always_lower,
+                     keep_unused, *arg_specs):
   # TODO(yashkatariya): Remove the local imports from here when the functions
   # in pxla.py move to dispatch.py or a utils file.
   from jax.interpreters import pxla
@@ -353,15 +353,16 @@ def sharded_lowering(fun, device, backend, name, donated_invars, keep_unused,
       fun, 'jit', name, in_shardings, pjit._UNSPECIFIED,
       donated_invars, in_avals,
       in_is_global=(True,) * len(arg_specs), keep_unused=keep_unused,
-      committed=committed, inp_device_assignment=inp_device_assignment).compile(
-          _allow_propagation_to_outputs=True).unsafe_call
+      committed=committed, always_lower=always_lower,
+      inp_device_assignment=inp_device_assignment)
 
 
 def _xla_callable_uncached(fun: lu.WrappedFun, device, backend, name,
                            donated_invars, keep_unused, *arg_specs):
   if config.jax_array:
-    return sharded_lowering(fun, device, backend, name,
-                            donated_invars, keep_unused, *arg_specs)
+    computation = sharded_lowering(fun, device, backend, name, donated_invars,
+                                   False, keep_unused, *arg_specs)
+    return computation.compile(_allow_propagation_to_outputs=True).unsafe_call
   else:
     return lower_xla_callable(fun, device, backend, name, donated_invars, False,
                               keep_unused, *arg_specs).compile().unsafe_call
jax/experimental/jax2tf/jax2tf.py (4 additions, 1 deletion)
@@ -632,7 +632,10 @@ def _lower_native_and_run(fun_jax: Callable,
   logging.vlog(2, f"XlaCallModule {mhlo_module_text}")
 
   # Figure out the result types and shapes
-  out_avals = lowered.compile_args["out_avals"]
+  if config.jax_array:
+    out_avals = lowered.compile_args["global_out_avals"]
+  else:
+    out_avals = lowered.compile_args["out_avals"]
   # TODO(necula): handle d being InDBIdx
   out_shapes = tuple(
       tuple(d if type(d) is int else None
jax/experimental/pjit.py (1 addition, 1 deletion)
@@ -1031,7 +1031,7 @@ def _pjit_lower_cached(
   return pxla.lower_sharding_computation(
       fun, 'pjit', name, in_shardings, out_shardings, donated_invars,
       jaxpr.in_avals, in_is_global=in_is_global, keep_unused=True,
-      committed=True)
+      committed=True, always_lower=False)
 
 
 def _pjit_abstract_eval(*args, jaxpr, out_shardings, resource_env,
jax/interpreters/pxla.py (18 additions, 13 deletions)
@@ -2672,6 +2672,7 @@ def lower_sharding_computation(
     in_is_global: Sequence[bool],
     keep_unused: bool,
     committed: bool,
+    always_lower: bool,
     inp_device_assignment: Optional[Sequence[xc.Device]] = None):
   """Lowers a computation to XLA. It can take arbitrary shardings as input.
@@ -2742,7 +2743,7 @@
   # Computations that only produce constants and/or only rearrange their inputs,
   # which are often produced from partial evaluation, don't need compilation,
   # and don't need to evaluate their arguments.
-  if (not (jaxpr.effects or has_outfeed) and
+  if (not always_lower and not (jaxpr.effects or has_outfeed) and
       (not jaxpr.eqns and all(kept_outputs) or not jaxpr.outvars) and
       all(_is_unspecified(o) for o in out_shardings) and  # type: ignore
       not hasattr(backend, "compile_replicated")):  # this means 'not pathways'
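The new always_lower argument guards the trivial-computation fast path above: computations with no equations used to skip XLA entirely, but jit(f).lower(...) needs a real lowered computation even for trivial functions. A simplified, standalone restatement of the guard, with plain booleans standing in for the jaxpr, outfeed, and backend checks:

def skips_compilation(always_lower, has_effects_or_outfeed, is_trivial,
                      out_shardings_unspecified, backend_compiles_replicated):
  # True means: execute the trivial jaxpr directly instead of compiling it.
  return (not always_lower
          and not has_effects_or_outfeed
          and is_trivial
          and out_shardings_unspecified
          and not backend_compiles_replicated)

# jit(f).lower(...) now passes always_lower=True, so even an identity function
# is lowered to a real computation instead of taking the trivial path.
assert skips_compilation(False, False, True, True, False)
assert not skips_compilation(True, False, True, True, False)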
@@ -3135,19 +3136,21 @@ def _get_mesh_pspec_shardings_from_executable(xla_executable, mesh):
 
 
 class MeshExecutable(stages.XlaExecutable):
-  __slots__ = ['xla_executable', 'unsafe_call', '_input_avals',
-               '_in_shardings', '_out_shardings', '_auto_spmd_lowering']
+  __slots__ = ['xla_executable', 'unsafe_call', 'in_avals',
+               '_in_shardings', '_out_shardings', '_auto_spmd_lowering',
+               '_kept_var_idx']
 
-  def __init__(self, xla_executable, unsafe_call, input_avals,
-               in_shardings, out_shardings, auto_spmd_lowering):
+  def __init__(self, xla_executable, unsafe_call, in_avals,
+               in_shardings, out_shardings, auto_spmd_lowering, kept_var_idx):
     self.xla_executable = xla_executable
     self.unsafe_call = unsafe_call
-    # input_avals is a list of global and local avals. Aval is global if input
-    # is a GDA else local.
-    self._input_avals = input_avals
+    # in_avals is a list of global and local avals. Aval is global if input
+    # is a GDA or jax.Array else local.
+    self.in_avals = in_avals
     self._in_shardings = in_shardings
     self._out_shardings = out_shardings
     self._auto_spmd_lowering = auto_spmd_lowering
+    self._kept_var_idx = kept_var_idx
 
   @staticmethod
   def from_hlo(name: str,
@@ -3281,7 +3284,8 @@ def from_hlo(name: str,
                                     bool(host_callbacks), kept_var_idx)
 
     return MeshExecutable(xla_executable, unsafe_call, input_avals,
-                          in_shardings, out_shardings, auto_spmd_lowering)
+                          in_shardings, out_shardings, auto_spmd_lowering,
+                          kept_var_idx)
 
   @staticmethod
   def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
@@ -3307,19 +3311,20 @@ def from_trivial_jaxpr(jaxpr, consts, global_in_avals, global_out_avals,
     unsafe_call = partial(_execute_trivial, jaxpr, consts, handle_ins,
                           handle_outs, kept_var_idx)
     return MeshExecutable(None, unsafe_call, global_in_avals, in_shardings,
-                          out_shardings, False)
+                          out_shardings, False, kept_var_idx)
 
   # -- stages.XlaExecutable overrides
 
   def xla_extension_executable(self):
     return self.xla_executable
 
   def call(self, *args):
-    arg_avals = map(xla.abstractify, args)
-    ref_avals = self._input_avals
+    kept_args = [a for i, a in enumerate(args) if i in self._kept_var_idx]
+    arg_avals = map(xla.abstractify, kept_args)
+    ref_avals = self.in_avals
     dispatch.check_arg_avals_for_call(ref_avals, arg_avals)
     # Check the GDA sharding and the input sharding.
-    _check_gda_or_array_xla_sharding_match(args, self._in_shardings)
+    _check_gda_or_array_xla_sharding_match(kept_args, self._in_shardings)
     return self.unsafe_call(*args)
 
 
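MeshExecutable.call now narrows the incoming arguments to the ones the compiled executable actually kept before checking avals and shardings. A standalone illustration of that filtering (the values are made up for the example):

args = ('a', 'b', 'c')           # everything the caller passed
kept_var_idx = {0, 2}            # indices retained by the compiled executable
kept_args = [a for i, a in enumerate(args) if i in kept_var_idx]
assert kept_args == ['a', 'c']   # only these are checked against in_avals and _in_shardings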
tests/api_test.py (10 additions, 0 deletions)
@@ -55,6 +55,7 @@
 from jax.interpreters import partial_eval as pe
 from jax.interpreters.pxla import PartitionSpec as P
 from jax.experimental import array, sharding
+from jax.experimental import pjit
 from jax._src import config as jax_config
 from jax._src import custom_derivatives
 from jax._src import device_array
@@ -973,6 +974,15 @@ def f(x): return x
     out = self.jit(f).lower(1.).compile()(4.)
     self.assertAllClose(out, 4.)
 
+  def test_jit_lower_compile_sharding_computation(self):
+    if not config.jax_array:
+      self.skipTest('with_sharding_constraint only works with the Array path '
+                    'in jit.')
+    s = sharding.SingleDeviceSharding(jax.devices()[0])
+    def f(x): return pjit.with_sharding_constraint(x, s)
+    out = self.jit(f).lower(1.).compile()(4.)
+    self.assertAllClose(out, 4.)
+
   def test_jit_lower_compile_trivial_in_tree_mismatch(self):
     def f(x): return x
     f_exe = self.jit(f).lower(1.).compile()