[mutable-arrays] support closed-over mutable arrays in jit
mattjj committed Mar 13, 2024
1 parent e7eb207 commit 649cd50
Showing 6 changed files with 111 additions and 36 deletions.
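Note: distilled from the tests added in tests/state_test.py at the bottom of this diff, here is what the commit enables: a mutable array created outside a jitted function can now be closed over and updated in place. `mutable_array` is internal, experimental API from `jax._src.core`, exactly as the tests use it.

```python
import jax
import jax.numpy as jnp
from jax._src import core  # internal, experimental API, as used in the tests

x_mut = core.mutable_array(jnp.zeros(3))  # lives outside the jitted function

@jax.jit
def f():
  # In-place updates to a closed-over mutable array, not a traced argument:
  x_mut[...] += 1.
  x_mut[0] += 1
  x_mut[1] += 5

f()
print(x_mut[...])  # [2. 6. 1.]
```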
1 change: 1 addition & 0 deletions jax/_src/core.py
@@ -1918,6 +1918,7 @@ def __init__(self, aval, buf):
   dtype = property(lambda self: self._aval.dtype)
   def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
   def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
+  def __repr__(self) -> str: return 'Mutable' + repr(self[...])
 pytype_aval_mappings[MutableArray] = lambda x: x._aval

 def mutable_array(init_val):
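Note: the added `__repr__` just prefixes 'Mutable' to the repr of the array's current contents. Since a jax.numpy array reprs as `Array(...)`, the result should read as below; the expected output is an inference, not shown in the diff.

```python
import jax.numpy as jnp
from jax._src import core  # internal, experimental API

x = core.mutable_array(jnp.arange(3))
print(repr(x))  # expected: MutableArray([0, 1, 2], dtype=int32)
```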
4 changes: 1 addition & 3 deletions jax/_src/interpreters/partial_eval.py
@@ -1005,11 +1005,9 @@ def convert_constvars_jaxpr(jaxpr: Jaxpr) -> Jaxpr:

 @weakref_lru_cache
 def convert_invars_to_constvars(jaxpr: Jaxpr, n: int) -> Jaxpr:
-  """Move n invars to constvars. Like an inverse of convert_constvars_Jaxpr."""
+  """Move n invars to constvars. Like an inverse of convert_constvars_jaxpr."""
   if n == 0:
     return jaxpr.replace()  # 'return jaxpr' would create cache reference cycle
-  if any(isinstance(eff, effects.JaxprInputEffect) for eff in jaxpr.effects):
-    raise NotImplementedError
   config.enable_checks.value and core.check_jaxpr(jaxpr)
   constvars, invars = split_list(jaxpr.invars, [n])
   dbg = jaxpr.debug_info and jaxpr.debug_info._replace(
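Note: two things happen in this hunk: a docstring typo fix, and removal of the guard that raised NotImplementedError for jaxprs carrying JaxprInputEffects. Ref (mutable-array) effects are input effects, so this path must now accept them for closed-over mutable arrays to reach the compiler. For reference, a minimal sketch of the `split_list` call that performs the invar/constvar split, assuming only that it splits a list after its first n elements:

```python
from jax._src.util import split_list

# Move the first n=2 invars to constvars, keep the rest as invars:
constvars, invars = split_list(['a', 'b', 'c', 'd'], [2])
assert constvars == ['a', 'b'] and invars == ['c', 'd']
```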
81 changes: 52 additions & 29 deletions jax/_src/interpreters/pxla.py
@@ -67,14 +67,12 @@
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
 from jax._src.sharding_impls import (
-    ArrayMapping, ArrayMappingOrAutoOrUnspecified,
-    AUTO, UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
+    ArrayMapping, ArrayMappingOrAutoOrUnspecified, AUTO, UNSPECIFIED,
+    UnspecifiedValue, get_array_mapping as _get_array_mapping, is_auto,
     is_unspecified, is_unspecified_or_auto, array_mapping_to_axis_resources,
-    SingleDeviceSharding, GSPMDSharding
-)
-from jax._src.util import (safe_map, safe_zip, partition_list,
-                           wrap_name, tuple_update, tuple_delete,
-                           distributed_debug_log,
+    SingleDeviceSharding, GSPMDSharding)
+from jax._src.util import (safe_map, safe_zip, partition_list, wrap_name,
+                           tuple_update, tuple_delete, distributed_debug_log,
                            unzip2, HashableFunction, weakref_lru_cache)
 from jax._src.state.types import AbstractRef, RefEffect

@@ -1153,14 +1151,14 @@ class ExecuteReplicated:
   __slots__ = ['xla_executable', 'name', 'backend', 'in_handler', 'out_handler',
                'has_unordered_effects', 'ordered_effects', 'keepalive',
                'has_host_callbacks', '_local_devices', 'kept_var_idx',
-               'out_mut', '__weakref__']
+               'mut', '__weakref__']

   def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
                out_handler: ResultsHandler,
                unordered_effects: list[core.Effect],
                ordered_effects: list[core.Effect], keepalive: Any,
                has_host_callbacks: bool, kept_var_idx: set[int],
-               out_mut: Sequence[int | None] | None):
+               mut: MutationData | None):
     self.xla_executable = xla_executable
     self.name = name
     self.backend = backend
@@ -1172,7 +1170,7 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
     self.keepalive = keepalive
     self.has_host_callbacks = has_host_callbacks
     self.kept_var_idx = kept_var_idx
-    self.out_mut = out_mut
+    self.mut = mut

   def _add_tokens_to_inputs(self, input_bufs):
     if self.ordered_effects:
@@ -1195,6 +1193,8 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
   @profiler.annotate_function
   def __call__(self, *args):
     args = [x for i, x in enumerate(args) if i in self.kept_var_idx]
+    if self.mut:
+      args = [*args, *self.mut.in_mut]
     input_bufs = self.in_handler(args)
     if (self.ordered_effects or self.has_unordered_effects
         or self.has_host_callbacks):
@@ -1215,11 +1215,11 @@ def __call__(self, *args):
       out = self.out_handler(out_arrays)
     else:
       out = results.consume_with_handlers(self.out_handler.handlers)
-    if self.out_mut is None:
+    if self.mut is None:
       return out
     else:
       out_ = []
-      for i, o in zip(self.out_mut, out):
+      for i, o in zip(self.mut.out_mut, out):
         if i is not None:
           args[i]._buf = o
         else:
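Note: taken together, the two `__call__` changes give the runtime flow below. This is a standalone sketch with hypothetical stand-ins (`run` for the compiled executable call, `mut` for the MutationData defined in the next hunk), not the actual implementation.

```python
def execute_sketch(args, mut, run):
  if mut is not None:
    args = [*args, *mut.in_mut]   # closed-over mutable arrays ride along
  outs = run(args)                # compiled executable call
  if mut is None:
    return outs
  results = []
  for i, o in zip(mut.out_mut, outs):
    if i is not None:
      args[i]._buf = o            # write the aliased output back in place
    else:
      results.append(o)           # ordinary result
  return results
```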
@@ -1781,19 +1781,38 @@ def _dce_jaxpr(closed_jaxpr, global_in_avals, api_name, fun_name,
   return (closed_jaxpr, global_in_avals, tuple(global_out_avals), donated_invars,
           kept_var_idx, name_stack)

+class MutationData(NamedTuple):
+  in_mut: list[core.MutableArray]
+  out_mut: list[int | None]
+
 @weakref_lru_cache
 def _discharge_refs(
     jaxpr: core.ClosedJaxpr
-) -> tuple[core.ClosedJaxpr, Sequence[int | None], Sequence[int | None]]:
+) -> tuple[core.ClosedJaxpr, Sequence[int | None], MutationData]:
   from jax._src.state.discharge import discharge_state
+  jaxpr, in_mut = _move_mutable_consts(jaxpr)
   new_jaxpr = core.ClosedJaxpr(*discharge_state(jaxpr.jaxpr, jaxpr.consts))
   count = it.count(len(jaxpr.out_avals))  # new outputs are appended to the end
   inout_map = {i: next(count) for i, a in enumerate(jaxpr.in_avals)
                if isinstance(a, AbstractRef)}
   outin_map = {j: i for i, j in inout_map.items()}
   inout_aliases = tuple(map(inout_map.get, range(len(new_jaxpr.in_avals))))
-  out_mut = tuple(map(outin_map.get, range(len(new_jaxpr.out_avals))))
-  return new_jaxpr, inout_aliases, out_mut
+  out_mut = list(map(outin_map.get, range(len(new_jaxpr.out_avals))))
+  return new_jaxpr, inout_aliases, MutationData(in_mut, out_mut)
+
+@weakref_lru_cache
+def _move_mutable_consts(
+    closed_jaxpr: core.ClosedJaxpr,
+) -> tuple[core.ClosedJaxpr, list[core.MutableArray]]:
+  jaxpr = closed_jaxpr.jaxpr
+  hoist = [isinstance(c, core.MutableArray) for c in closed_jaxpr.consts]
+  consts, in_mut = partition_list(hoist, closed_jaxpr.consts)
+  constvars, mutvars = partition_list(hoist, jaxpr.constvars)
+  invars = (*jaxpr.invars, *mutvars)
+  effects = pe.make_jaxpr_effects(constvars, invars, jaxpr.outvars, jaxpr.eqns)
+  jaxpr = core.Jaxpr(constvars, invars, jaxpr.outvars, jaxpr.eqns,
+                     effects, None)
+  return core.ClosedJaxpr(jaxpr, consts), in_mut


 @dataclasses.dataclass(frozen=True)
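Note: on hypothetical numbers, here is how the maps above fit together. Discharging turns each Ref input into a value input and appends one fresh output per mutated Ref; `inout_map` sends Ref input positions to those appended output positions, and `out_mut` inverts it so the executor knows which outputs are write-backs.

```python
import itertools as it

is_ref = [True, False, False, True]  # hypothetical: Refs at input positions 0 and 3
count = it.count(2)                  # appended outputs start after the 2 originals
inout_map = {i: next(count) for i, r in enumerate(is_ref) if r}  # {0: 2, 3: 3}
outin_map = {j: i for i, j in inout_map.items()}                 # {2: 0, 3: 3}
out_mut = list(map(outin_map.get, range(4)))
assert out_mut == [None, None, 0, 3]  # outputs 2, 3 write back to inputs 0, 3
```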
@@ -2012,16 +2031,20 @@ def lower_sharding_computation(
   in_layouts = tuple(l for i, l in enumerate(in_layouts) if i in kept_var_idx)

   if any(isinstance(e, RefEffect) for e in closed_jaxpr.effects):
-    closed_jaxpr, inout_aliases, out_mut = _discharge_refs(closed_jaxpr)
-    if out_mut:
-      out_layouts_ = iter(zip(out_shardings, out_layouts))
-      out_shardings, out_layouts = unzip2(
-          next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
-          for i in out_mut)
-      assert next(out_layouts_, None) is None
-    global_out_avals = closed_jaxpr.out_avals
+    closed_jaxpr, inout_aliases, mut = _discharge_refs(closed_jaxpr)
+    in_shardings = (*in_shardings,) + (UNSPECIFIED,) * len(mut.in_mut)
+    in_layouts = (*in_layouts,) + (None,) * len(mut.in_mut)
+    donated_invars = (*donated_invars,) + (False,) * len(mut.in_mut)
+    out_layouts_ = iter(zip(out_shardings, out_layouts))
+    out_shardings, out_layouts = unzip2(
+        next(out_layouts_) if i is None else (in_shardings[i], in_layouts[i])
+        for i in mut.out_mut)
+    assert next(out_layouts_, None) is None
+    # TODO(yashkatariya): remove global_in_avals / global_out_avals
+    global_in_avals = closed_jaxpr.in_avals
+    global_out_avals = closed_jaxpr.out_avals
   else:
-    inout_aliases = out_mut = None
+    inout_aliases = mut = None

   jaxpr = closed_jaxpr.jaxpr
   assert len(out_shardings) == len(out_layouts) == len(global_out_avals), (
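Note: the bookkeeping above, on toy values. Hoisted mutable inputs get UNSPECIFIED shardings, no layout constraint, and are never donated, while each appended output inherits the sharding and layout of the input it aliases. Hypothetical strings stand in for sharding objects.

```python
out_mut = [None, 0]          # output 1 writes back into input 0
in_shardings = ['s_in0']
out_shardings = ['s_out0']   # shardings for non-aliased outputs, in order
it_ = iter(out_shardings)
new_out = [next(it_) if i is None else in_shardings[i] for i in out_mut]
assert new_out == ['s_out0', 's_in0']
```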
@@ -2106,7 +2129,7 @@ def lower_sharding_computation(
       host_callbacks=host_callbacks,
       keepalive=keepalive,
       kept_var_idx=kept_var_idx,
-      out_mut=out_mut,
+      mut=mut,
       backend=backend,
       device_assignment=da_object,
       committed=committed,
@@ -2775,7 +2798,7 @@ class UnloadedMeshExecutable:
   keepalive: Sequence[Any]
   host_callbacks: Sequence[Any]
   kept_var_idx: set[int]
-  out_mut: Sequence[None | int] | None
+  mut: MutationData | None
   auto_spmd_lowering: bool
   in_layouts: Sequence[SpecifiedLayout | None]
   out_layouts: Sequence[SpecifiedLayout | None]
@@ -2795,7 +2818,7 @@ def build_unsafe_call(self):
     unsafe_call = ExecuteReplicated(  # type: ignore  # assignment
         self.xla_executable, self.name, self.backend, handle_args,
         handle_outs, self.unordered_effects, self.ordered_effects, self.keepalive,
-        bool(self.host_callbacks), self.kept_var_idx, self.out_mut)
+        bool(self.host_callbacks), self.kept_var_idx, self.mut)
     return unsafe_call

   def load(self) -> MeshExecutable:
@@ -2829,7 +2852,7 @@ def from_hlo(name: str,
                in_layouts: MaybeLayout,
                out_layouts: MaybeLayout,
                pmap_nreps: int = 1,
-               out_mut: Sequence[None | int] | None = None,
+               mut: MutationData | None = None,
                shape_poly_state: mlir.ShapePolyLoweringState | None = None,
                all_default_mem_kind: bool = True,
                all_args_info: AllArgsInfo | None = None,
@@ -2922,7 +2945,7 @@ def from_hlo(name: str,
         keepalive=keepalive,
         host_callbacks=host_callbacks,
         kept_var_idx=kept_var_idx,
-        out_mut=out_mut,
+        mut=mut,
         auto_spmd_lowering=auto_spmd_lowering,
         in_layouts=in_layouts,  # type: ignore
         out_layouts=out_layouts,  # type: ignore
10 changes: 9 additions & 1 deletion jax/_src/pjit.py
@@ -1409,7 +1409,7 @@ def call_impl_cache_miss(*args_, **kwargs_):
         donated_invars=donated_invars, name=name, keep_unused=keep_unused,
         inline=inline)
     fastpath_data = _get_fastpath_data(
-        compiled, tree_structure(out_flat), args, out_flat, [], set())
+        compiled, tree_structure(out_flat), args, out_flat, [], jaxpr.effects)
     return out_flat, fastpath_data

   f = _get_jaxpr_as_fun(
@@ -1561,6 +1561,14 @@ def pjit_staging_rule(trace, *args, **params):
                              params['jaxpr'].effects, source_info)
     trace.frame.add_eqn(eqn)
     return out_tracers
+  elif any(isinstance(c, core.MutableArray) for c in params['jaxpr'].consts):
+    jaxpr, consts = pxla._move_mutable_consts(params['jaxpr'])
+    consts = map(trace.instantiate_const, consts)
+    in_shardings = (*params['in_shardings'],) + (UNSPECIFIED,) * len(consts)
+    donated_invars = (*params['donated_invars'],) + (False,) * len(consts)
+    new_params = dict(params, jaxpr=jaxpr, in_shardings=in_shardings,
+                      donated_invars=donated_invars)
+    return trace.default_process_primitive(pjit_p, (*args, *consts), new_params)
   else:
     return trace.default_process_primitive(pjit_p, args, params)
 pe.custom_staging_rules[pjit_p] = pjit_staging_rule
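Note: the new staging branch hoists closed-over MutableArray constants to explicit trailing arguments before staging the pjit call. The observable effect, as checked in test_closed_over_basic below, is that the jaxpr of such a function carries a ref effect:

```python
import jax
import jax.numpy as jnp
from jax._src import core
from jax._src.state.types import RefEffect  # internal import, as in the tests

x_mut = core.mutable_array(jnp.zeros(3))

@jax.jit
def f():
  x_mut[...] += 1.

jaxpr = jax.make_jaxpr(f)()
assert any(isinstance(e, RefEffect) for e in jaxpr.effects)
```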
6 changes: 3 additions & 3 deletions jax/experimental/export/_export.py
@@ -436,8 +436,8 @@ def do_export(*args_specs, **kwargs_specs) -> Exported:
     mlir_module = lowering.stablehlo()

     args_avals_flat, _ = tree_util.tree_flatten(lowered.in_avals)
-    if "out_mut" in lowering.compile_args:
-      if lowering.compile_args["out_mut"]: raise NotImplementedError
+    if "mut" in lowering.compile_args:
+      if lowering.compile_args["mut"]: raise NotImplementedError
     if "kept_var_idx" in lowering.compile_args:
       module_kept_var_idx = tuple(sorted(lowering.compile_args["kept_var_idx"]))
     else:
@@ -747,7 +747,7 @@ def _check_lowering(lowering) -> None:
   allowed_compile_args = [
       "backend", "mesh", "global_in_avals",
       "global_out_avals", "in_shardings", "out_shardings", "kept_var_idx",
-      "out_mut", "spmd_lowering", "auto_spmd_lowering",
+      "mut", "spmd_lowering", "auto_spmd_lowering",
       "tuple_args", "ordered_effects", "unordered_effects",
       "keepalive", "host_callbacks", "pmap_nreps", "committed",
       "device_assignment", "jaxpr_debug_info", "shape_poly_state",
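Note: export now recognizes "mut" in compile_args but rejects anything truthy, so exporting a function that mutates a closed-over array is expected to raise NotImplementedError. The call pattern below is an assumption about the jax.experimental.export API of this era, not part of the diff:

```python
import jax
import jax.numpy as jnp
from jax._src import core
from jax.experimental import export  # API shape assumed for this JAX version

counter = core.mutable_array(jnp.zeros(()))

@jax.jit
def g(x):
  counter[...] += 1.
  return x + counter[...]

try:
  export.export(g)(jax.ShapeDtypeStruct((3,), jnp.float32))  # assumed usage
except NotImplementedError:
  print("mutating a closed-over array is not exportable yet")
```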
45 changes: 45 additions & 0 deletions tests/state_test.py
@@ -1544,6 +1544,51 @@ def f(x_mut, y, z_mut, w):
     self.assertAllClose(out1, 4 * jnp.ones((2, 3)), check_dtypes=False)
     self.assertAllClose(out2, y + w, check_dtypes=False)

+  @parameterized.parameters([True, False])
+  def test_closed_over_basic(self, jit):
+    x_mut = core.mutable_array(jnp.zeros(3))
+    def f():
+      x_mut[...] += 1.
+      x_mut[0] += 1
+      x_mut[1] += 5
+
+    if jit:
+      f = jax.jit(f)
+
+    f()
+
+    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
+                        check_dtypes=False)
+
+    jaxpr = jax.make_jaxpr(f)()
+    self.assertTrue(any(isinstance(e, RefEffect) for e in jaxpr.effects))
+
+  @parameterized.parameters([True, False])
+  def test_closed_over_nested(self, jit):
+    x_mut = core.mutable_array(jnp.zeros(3))
+
+    @jax.jit
+    def f(y_mut, z):
+      x_mut[...] += 1.
+      x_mut[0] += 1
+      x_mut[1] += 5
+
+      y_mut[2] += 7
+      return z + 9
+
+    if jit:
+      f = jax.jit(f)
+
+    y_mut = core.mutable_array(np.zeros(3))
+
+    w = f(y_mut, 1)
+
+    self.assertAllClose(x_mut[...], jnp.array([2., 6., 1.]),
+                        check_dtypes=False)
+    self.assertAllClose(y_mut[...], jnp.array([0., 0., 7.]),
+                        check_dtypes=False)
+    self.assertAllClose(w, 10, check_dtypes=False)
+
+
 if CAN_USE_HYPOTHESIS:
