diff --git a/jax/_src/core.py b/jax/_src/core.py
index c351e6980bf7..8dfd4f04ec14 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1918,6 +1918,7 @@ def __init__(self, aval, buf):
   dtype = property(lambda self: self._aval.dtype)
   def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
   def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
+  def __repr__(self) -> str: return 'Mutable' + repr(self[...])
 pytype_aval_mappings[MutableArray] = lambda x: x._aval
 
 def mutable_array(init_val):
diff --git a/jax/_src/interpreters/partial_eval.py b/jax/_src/interpreters/partial_eval.py
index 44738b3df16f..2b25082349c9 100644
--- a/jax/_src/interpreters/partial_eval.py
+++ b/jax/_src/interpreters/partial_eval.py
@@ -1005,11 +1005,9 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:
 
 @weakref_lru_cache
 def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
-  """Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
+  """Move n invars to constvars. Like an inverse of convert_constvars_jaxpr."""
   if n == 0:
     return jaxpr.replace()  # 'return jaxpr' would create cache reference cycle
-  if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
-    raise NotImplementedError
   config.enable_checks.value and core.check_jaxpr(jaxpr)
   constvars, invars = split_list(jaxpr.invars, [n])
   dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 9fc30fb6530c..ae3347bdaf12 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -67,14 +67,12 @@
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
-    ArrayMapping, ArrayMappingOrAutoOrUnspecified,
-    AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
+    ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
+    UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
     is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
-    SingleDeviceSharding, GSPMDSharding
-)
-from jax._src.util import (safe_map, safe_zip, partition_list,
-                           wrap_name, tuple_update, tuple_delete,
-                           distributed_debug_log,
+    SingleDeviceSharding, GSPMDSharding)
+from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
+                           tuple_update, tuple_delete, distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)
 from jax._src.state.types import AbstractRef, RefEffect
 
@@ -1153,14 +1151,14 @@ class ExecuteReplicated:
   __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
                'has_unordered_effects', 'ordered_effects', 'keepalive',
                'has_host_callbacks', '_local_devices', 'kept_var_idx',
-               'out_mut', '__weakref__']
+               'mut', '__weakref__']
 
   def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
                out_handler: ResultsHandler,
                unordered_effects: list[core.Effect],
                ordered_effects: list[core.Effect], keepalive: Any,
                has_host_callbacks: bool, kept_var_idx: set[int],
-               out_mut: Sequence[int | None] | None):
+               mut: MutationData | None):
     self.xla_executable = xla_executable
     self.name = name
     self.backend = backend
@@ -1172,7 +1170,7 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
     self.keepalive = keepalive
     self.has_host_callbacks = has_host_callbacks
     self.kept_var_idx = kept_var_idx
-    self.out_mut = out_mut
+    self.mut = mut
 
   def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
@@ -1195,6 +1193,8 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
   @profiler.annotate_function
   def __call__(self, *args):
     args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
+    if self.mut:
+      args = [*args, *self.mut.in_mut]
     input_bufs = self.in_handler(args)
     if (self.ordered_effects or self.has_unordered_effects
         or self.has_host_callbacks):
@@ -1215,11 +1215,11 @@ def __call__(self, *args):
       out = self.out_handler(out_arrays)
     else:
       out = results.consume_with_handlers(self.out_handler.handlers)
-    if self.out_mut is None:
+    if self.mut is None:
       return out
     else:
       out_ = []
-      for i, o in zip(self.out_mut, out):
+      for i, o in zip(self.mut.out_mut, out):
         if i is not None:
           args[i]._buf = o
         else:
@@ -1781,19 +1781,38 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
   return (closed_jaxpr, global_in_avals, tuple(global_out_avals),
           donated_invars, kept_var_idx, name_stack)
 
+class MutationData(NamedTuple):
+  in_mut: list[core.MutableArray]
+  out_mut: list[int | None]
+
 @weakref_lru_cache
 def _discharge_refs(
     jaxpr: core.ClosedJaxpr
-) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
+) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]:
   from jax._src.state.discharge import discharge_state
+  jaxpr, in_mut = _move_mutable_consts(jaxpr)
   new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
   count = it.count(len(jaxpr.out_avals))  # new outputs are appended to the end
   inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
                if isinstance(a, AbstractRef)}
   outin_map = {j: i for i, j in inout_map.items()}
   inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
-  out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
-  return new_jaxpr, inout_aliases, out_mut
+  out_mut = list(map(outin_map.get, range(len(new_jaxpr.out_avals))))
+  return new_jaxpr, inout_aliases, MutationData(in_mut, out_mut)
+
+@weakref_lru_cache
+def _move_mutable_consts(
+    closed_jaxpr: core.ClosedJaxpr,
+) -> tuple[core.ClosedJaxpr, list[core.MutableArray]]:
+  jaxpr = closed_jaxpr.jaxpr
+  hoist = [isinstance(c, core.MutableArray) for c in closed_jaxpr.consts]
+  consts, in_mut = partition_list(hoist, closed_jaxpr.consts)
+  constvars, mutvars = partition_list(hoist, jaxpr.constvars)
+  invars = (*jaxpr.invars, *mutvars)
+  effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
+  jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
+                     effects, None)
+  return core.ClosedJaxpr(jaxpr, consts), in_mut
 
 
 @dataclasses.dataclass(frozen=True)
@@ -2012,16 +2031,20 @@ def lower_sharding_computation(
     in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)
 
   if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
-    closed_jaxpr, inout_aliases, out_mut = _discharge_refs(closed_jaxpr)
-    if out_mut:
-      out_layouts_ = iter(zip(out_shardings, out_layouts))
-      out_shardings, out_layouts = unzip2(
-          next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
-          for i in out_mut)
-      assert next(out_layouts_, None) is None
-    global_out_avals = closed_jaxpr.out_avals
+    closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
+    in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
+    in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
+    donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
+    out_layouts_ = iter(zip(out_shardings, out_layouts))
+    out_shardings, out_layouts = unzip2(
+        next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
+        for i in mut.out_mut)
+    assert next(out_layouts_, None) is None
+    # TODO(yashkatariya): remove global_in_avals / global_out_avals
+    global_in_avals = closed_jaxpr.in_avals
+    global_out_avals = closed_jaxpr.out_avals
   else:
-    inout_aliases = out_mut = None
+    inout_aliases = mut = None
 
   jaxpr = closed_jaxpr.jaxpr
   assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
@@ -2106,7 +2129,7 @@ def lower_sharding_computation(
       host_callbacks=host_callbacks,
       keepalive=keepalive,
      kept_var_idx=kept_var_idx,
-      out_mut=out_mut,
+      mut=mut,
       backend=backend,
       device_assignment=da_object,
       committed=committed,
@@ -2775,7 +2798,7 @@ class UnloadedMeshExecutable:
   keepalive: Sequence[Any]
   host_callbacks: Sequence[Any]
   kept_var_idx: set[int]
-  out_mut: Sequence[None | int] | None
+  mut: MutationData | None
   auto_spmd_lowering: bool
   in_layouts: Sequence[SpecifiedLayout | None]
   out_layouts: Sequence[SpecifiedLayout | None]
@@ -2795,7 +2818,7 @@ def build_unsafe_call(self):
     unsafe_call = ExecuteReplicated(  # type: ignore  # assignment
         self.xla_executable, self.name, self.backend, handle_args,
         handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
-        bool(self.host_callbacks), self.kept_var_idx, self.out_mut)
+        bool(self.host_callbacks), self.kept_var_idx, self.mut)
     return unsafe_call
 
   def load(self) -> MeshExecutable:
@@ -2829,7 +2852,7 @@ def from_hlo(name: str,
                in_layouts: MaybeLayout,
                out_layouts: MaybeLayout,
                pmap_nreps: int = 1,
-               out_mut: Sequence[None | int] | None = None,
+               mut: MutationData | None = None,
                shape_poly_state: mlir.ShapePolyLoweringState | None = None,
                all_default_mem_kind: bool = True,
                all_args_info: AllArgsInfo | None = None,
@@ -2922,7 +2945,7 @@ def from_hlo(name: str,
         keepalive=keepalive,
         host_callbacks=host_callbacks,
         kept_var_idx=kept_var_idx,
-        out_mut=out_mut,
+        mut=mut,
         auto_spmd_lowering=auto_spmd_lowering,
         in_layouts=in_layouts,  # type: ignore
         out_layouts=out_layouts,  # type: ignore
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index d166894e8613..a5631573707b 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1409,7 +1409,7 @@ def call_impl_cache_miss(*args_, **kwargs_):
         donated_invars=donated_invars, name=name, keep_unused=keep_unused,
         inline=inline)
     fastpath_data = _get_fastpath_data(
-        compiled, tree_structure(out_flat), args, out_flat, [], set())
+        compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects)
     return out_flat, fastpath_data
 
   f = _get_jaxpr_as_fun(
@@ -1561,6 +1561,14 @@ def pjit_staging_rule(trace, *args, **params):
                          params['jaxpr'].effects, source_info)
     trace.frame.add_eqn(eqn)
     return out_tracers
+  elif any(isinstance(c, core.MutableArray) for c in params['jaxpr'].consts):
+    jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
+    consts = map(trace.instantiate_const, consts)
+    in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
+    donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
+    new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
+                      donated_invars=donated_invars)
+    return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
   else:
     return trace.default_process_primitive(pjit_p, args, params)
 pe.custom_staging_rules[pjit_p] = pjit_staging_rule
diff --git a/jax/experimental/export/_export.py b/jax/experimental/export/_export.py
index 3c48768c0db4..358c8873ca3d 100644
--- a/jax/experimental/export/_export.py
+++ b/jax/experimental/export/_export.py
@@ -436,8 +436,8 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
     mlir_module = lowering.stablehlo()
 
     args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
-    if "out_mut" in lowering.compile_args:
-      if lowering.compile_args["out_mut"]: raise NotImplementedError
+    if "mut" in lowering.compile_args:
+      if lowering.compile_args["mut"]: raise NotImplementedError
     if "kept_var_idx" in lowering.compile_args:
       module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
     else:
@@ -747,7 +747,7 @@ def _check_lowering(lowering) -> None:
   allowed_compile_args = [
       "backend", "mesh", "global_in_avals",
       "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
-      "out_mut", "spmd_lowering", "auto_spmd_lowering",
+      "mut", "spmd_lowering", "auto_spmd_lowering",
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
diff --git a/tests/state_test.py b/tests/state_test.py
index 1f109536fc16..85c7b77d25e3 100644
--- a/tests/state_test.py
+++ b/tests/state_test.py
@@ -1544,6 +1544,51 @@ def f(x_mut, y, z_mut, w):
 
     self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
     self.assertAllClose(out2, y + w, check_dtypes=False)
 
+  @parameterized.parameters([True, False])
+  def test_closed_over_basic(self, jit):
+    x_mut = core.mutable_array(jnp.zeros(3))
+    def f():
+      x_mut[...] += 1.
+      x_mut[0] += 1
+      x_mut[1] += 5
+
+    if jit:
+      f = jax.jit(f)
+
+    f()
+
+    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
+                        check_dtypes=False)
+
+    jaxpr = jax.make_jaxpr(f)()
+    self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
+
+  @parameterized.parameters([True, False])
+  def test_closed_over_nested(self, jit):
+    x_mut = core.mutable_array(jnp.zeros(3))
+
+    @jax.jit
+    def f(y_mut, z):
+      x_mut[...] += 1.
+      x_mut[0] += 1
+      x_mut[1] += 5
+
+      y_mut[2] += 7
+      return z + 9
+
+    if jit:
+      f = jax.jit(f)
+
+    y_mut = core.mutable_array(np.zeros(3))
+
+    w = f(y_mut, 1)
+
+    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
+                        check_dtypes=False)
+    self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]),
+                        check_dtypes=False)
+    self.assertAllClose(w, 10, check_dtypes=False)
+
 
 if CAN_USE_HYPOTHESIS:
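
For reviewers, a minimal usage sketch (not part of the diff) of what this change enables, distilled from the new tests above: a jitted function may now close over a `core.mutable_array`. At trace time `pjit_staging_rule` uses `pxla._move_mutable_consts` to hoist the closed-over `MutableArray` constants to explicit invars, and `ExecuteReplicated` threads the updated buffers back through `MutationData`, so the in-place update is visible after the call. The function name `accumulate` is illustrative; the `jax._src.core` import is the same internal API the tests use.

    import jax
    import jax.numpy as jnp
    from jax._src import core  # internal API, as in tests/state_test.py

    x_mut = core.mutable_array(jnp.zeros(3))  # MutableArray holding [0., 0., 0.]

    @jax.jit
    def accumulate():
      # Closed-over mutable array: hoisted from the jaxpr's consts to an
      # invar, then discharged at lowering; no return value is needed.
      x_mut[...] += 1.

    accumulate()
    accumulate()
    # The updates persist across jit boundaries; the new MutableArray.__repr__
    # prints 'Mutable' + repr of the current value, roughly:
    print(x_mut)  # MutableArray([2., 2., 2.], dtype=float32)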