Merge pull request #8561 from mattjj:add-donated-invars-to-xlacomputation

PiperOrigin-RevId: 410368194
jax authors committed Nov 16, 2021
2 parents b7e3129 + 5d35b8a commit 6883571
Showing 6 changed files with 51 additions and 17 deletions.
21 changes: 15 additions & 6 deletions jax/_src/api.py
@@ -488,17 +488,21 @@ class Lowered:
   querying properties of lowered computations across JAX's various
   lowering paths (``jit``, ``pmap``, etc.).
   """
-  __slots__ = ['in_tree', 'out_tree', '_lowering', '_no_kwargs']
+  __slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_lowering',
+               '_no_kwargs']

   in_tree: PyTreeDef
   out_tree: PyTreeDef
+  donate_argnums: Tuple[int]
   _lowering: Union[xla.XlaComputation, pxla.MeshComputation]
   _no_kwargs: bool

-  def __init__(self, lowering, in_tree, out_tree, no_kwargs=False):
+  def __init__(self, lowering, in_tree, out_tree, donate_argnums,
+               no_kwargs=False):
     self._lowering = lowering
     self.in_tree = in_tree
     self.out_tree = out_tree
+    self.donate_argnums = donate_argnums
     self._no_kwargs = no_kwargs

   def _xla_computation(self):
@@ -508,7 +512,8 @@ def _xla_computation(self):

   def compile(self) -> 'Compiled':
     return Compiled(
-        self._lowering.compile(), self.in_tree, self.out_tree, self._no_kwargs)
+        self._lowering.compile(), self.in_tree, self.out_tree,
+        self.donate_argnums, self._no_kwargs)


 class Compiled:
@@ -519,17 +524,21 @@ class Compiled:
   common API for querying properties of compiled computations across
   JAX's various compilation paths and backends.
   """
-  __slots__ = ['in_tree', 'out_tree', '_executable', '_no_kwargs']
+  __slots__ = ['in_tree', 'out_tree', 'donate_argnums', '_executable',
+               '_no_kwargs']

   in_tree: PyTreeDef
   out_tree: PyTreeDef
+  donate_argnums: Tuple[int]
   _executable: Union[xla.XlaCompiledComputation, pxla.MeshExecutable]
   _no_kwargs: bool

-  def __init__(self, executable, in_tree, out_tree, no_kwargs=False):
+  def __init__(self, executable, in_tree, out_tree, donate_argnums,
+               no_kwargs=False):
     self._executable = executable
     self.in_tree = in_tree
     self.out_tree = out_tree
+    self.donate_argnums = donate_argnums
     self._no_kwargs = no_kwargs

   def _xla_executable(self):
@@ -589,7 +598,7 @@ def lower(*args, **kwargs) -> Lowered:
     arg_specs = unsafe_map(arg_spec, args_flat)
     computation = xla.lower_xla_callable(
         flat_fun, device, backend, name, donated_invars, *arg_specs)
-    return Lowered(computation, in_tree, out_tree())
+    return Lowered(computation, in_tree, out_tree(), donate_argnums)

   return lower
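For orientation, a minimal usage sketch of the surface this file gains. This is a sketch against this commit's API, not part of the diff; it should work on any backend, since only metadata is inspected:

import jax

def add(x, y):
  return x + y

# donate_argnums=(0,) marks the first argument's buffer as donatable.
lowered = jax.jit(add, donate_argnums=(0,)).lower(1., 2.)
compiled = lowered.compile()
# Both staged-out objects now carry the donation indices:
assert lowered.donate_argnums == compiled.donate_argnums == (0,)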
9 changes: 5 additions & 4 deletions jax/experimental/pjit.py
@@ -229,23 +229,24 @@ def infer_params(*args, **kwargs):
         donated_invars=donated_invars,
         name=flat_fun.__name__,
         positional_semantics=maps._positional_semantics)
-    return args_flat, params, in_tree, out_tree()
+    return args_flat, params, in_tree, out_tree(), donate_argnums

   @wraps(fun)
   def wrapped(*args, **kwargs):
     for arg in tree_leaves(args):
       _check_arg(arg)
-    args_flat, params, _, out_tree = infer_params(*args, **kwargs)
+    args_flat, params, _, out_tree, _ = infer_params(*args, **kwargs)
     out = pjit_p.bind(*args_flat, **params)
     return tree_unflatten(out_tree, out)

   def lower(*args, **kwargs):
-    args_flat, params, in_tree, out_tree = infer_params(*args, **kwargs)
+    args_flat, params, in_tree, out_tree, donate_argnums = \
+        infer_params(*args, **kwargs)
     lowering = _pjit_lower(
         params['jaxpr'], params['in_axis_resources'],
         params['out_axis_resources'], params['resource_env'],
         params['donated_invars'], params['name'], maps._positional_semantics)
-    return Lowered(lowering, in_tree, out_tree, no_kwargs=True)
+    return Lowered(lowering, in_tree, out_tree, donate_argnums, no_kwargs=True)

   wrapped.lower = lower
   return wrapped
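The same attribute becomes visible on the pjit path. A hedged sketch, with import paths and keyword names as of this commit's era, assuming at least one local device for a one-wide 'x' mesh axis:

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import PartitionSpec as P
from jax.experimental.maps import mesh
from jax.experimental.pjit import pjit

with mesh(np.array(jax.local_devices()[:1]), ('x',)):
  x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
  f_low = pjit(lambda a: a, donate_argnums=(0,),
               in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
  # Donation indices survive lowering and compilation:
  assert f_low.donate_argnums == f_low.compile().donate_argnums == (0,)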
5 changes: 3 additions & 2 deletions jax/interpreters/pxla.py
@@ -1723,15 +1723,16 @@ def lower_mesh_computation(

   built = c.Build(out_tuple)
   return MeshComputation(
-      built, mesh, local_in_untiled_avals,
+      built, donated_invars, mesh, local_in_untiled_avals,
       local_out_untiled_avals, (out_jaxpr_avals if spmd_lowering else None),
       in_axes, out_axes, spmd_lowering, tuple_args)


 class MeshComputation:
-  def __init__(self, hlo, *compile_args):
+  def __init__(self, hlo, donated_invars, *compile_args):
     self._executable = None
     self._hlo = hlo
+    self._donated_invars = donated_invars
     self.compile_args = compile_args

   def hlo(self):
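A note on the two spellings in play: donate_argnums (stored on Lowered and Compiled) holds user-level positional indices, while donated_invars (stored here and in xla.py below) is a boolean mask with one entry per flattened input. A rough correspondence, using a hypothetical flat-arguments-only helper; the real conversion (donation_vector in jax's api_util) also accounts for pytree flattening and kwargs:

def donation_mask(donate_argnums, num_flat_args):
  # Hypothetical helper: True at each donated position.
  return tuple(i in donate_argnums for i in range(num_flat_args))

assert donation_mask((0,), 3) == (True, False, False)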
11 changes: 7 additions & 4 deletions jax/interpreters/xla.py
@@ -800,7 +800,7 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
   # and don't need to evaluate their arguments.
   if not jaxpr.eqns:
     return XlaComputation(
-        name, None, True, jaxpr, consts, device, abstract_args, out_avals,
+        name, None, True, None, jaxpr, consts, device, abstract_args, out_avals,
         kept_var_idx)

   if not _on_exit:
@@ -850,8 +850,8 @@ def lower_xla_callable(fun: lu.WrappedFun, device, backend, name,
                 ", ".join(unused_donations)))
   built = c.build(output)
   return XlaComputation(
-      name, built, False, nreps, device, backend, tuple_args, abstract_args,
-      out_avals, kept_var_idx)
+      name, built, False, donated_invars, nreps, device, backend, tuple_args,
+      abstract_args, out_avals, kept_var_idx)


 def compile_or_get_cached(backend, computation, compile_options):
@@ -875,11 +875,14 @@ class XlaComputation:
   name: str
   _is_trivial: bool
   _executable: Optional['XlaCompiledComputation']
+  _donated_invars: Optional[Sequence[bool]]

-  def __init__(self, name: str, hlo, is_trivial: bool, *compile_args):
+  def __init__(self, name: str, hlo, is_trivial: bool,
+               donated_invars: Optional[Sequence[bool]], *compile_args):
     self.name = name
     self._hlo = hlo
     self._is_trivial = is_trivial
+    self._donated_invars = donated_invars
     self._executable = None
     self.compile_args = compile_args
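One asymmetry above is worth calling out: the trivial path (a jaxpr with no equations) threads None rather than a mask, which is why _donated_invars is typed Optional[Sequence[bool]]. A small defensive-consumption sketch with a hypothetical helper, not this file's API:

from typing import Optional, Sequence

def donated_count(donated_invars: Optional[Sequence[bool]]) -> int:
  # None means a trivial computation with no donation info recorded.
  return sum(donated_invars) if donated_invars is not None else 0

assert donated_count(None) == 0
assert donated_count([True, False, True]) == 2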
8 changes: 8 additions & 0 deletions tests/api_test.py
@@ -780,6 +780,14 @@ def f(*args):
     f_exe = self.jit(f).lower(1., 1.).compile()
     self.assertAllClose(f_exe(1., 1.), 1.)

+  def test_jit_lower_donate_argnums_available(self):
+    def f(*args):
+      x, *_ = args
+      return x
+    f_low = self.jit(f, donate_argnums=(0,)).lower(1., 1.)
+    f_com = f_low.compile()
+    assert f_low.donate_argnums == f_com.donate_argnums == (0,)
+

 class PythonJitTest(CPPJitTest):
14 changes: 13 additions & 1 deletion tests/pjit_test.py
@@ -385,7 +385,19 @@ def _test_rule(*args, **kwargs):
   def testLowerWithDuckTyping(self):
     x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
     # Make sure this doesn't crash
-    pjit(lambda x: x + 4, in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
+    pjit(lambda x: x + 4,
+         in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)

+  @jtu.with_mesh([('x', 2)])
+  def testLowerDonateArgnumsAvailable(self):
+    x = jax.ShapeDtypeStruct((2, 2), jnp.float32)
+    def f(*args):
+      x, *_ = args
+      return x
+    f_low = pjit(f, donate_argnums=(0,),
+                 in_axis_resources=P('x'), out_axis_resources=P('x')).lower(x)
+    f_com = f_low.compile()
+    assert f_low.donate_argnums == f_com.donate_argnums == (0,)
+
   def testInfeed(self):
     devices = np.array(jax.local_devices())
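For readers new to donation: the metadata exposed by this commit describes buffers the caller promises not to reuse. A hedged behavioral sketch; the effect is backend-dependent, and as of this era CPU warns that donation is unimplemented and ignores the hint, while e.g. TPU may alias the input buffer into the output:

import jax
import jax.numpy as jnp

inc = jax.jit(lambda x: x + 1, donate_argnums=(0,))
x = jnp.ones((2, 2))
y = inc(x)
# On backends that implement donation, x's buffer may now back y and is no
# longer safe to read; elsewhere the donation hint is ignored with a warning.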
