diff --git a/jax/experimental/maps.py b/jax/experimental/maps.py
index 84d80f61ac99..914d6166b966 100644
--- a/jax/experimental/maps.py
+++ b/jax/experimental/maps.py
@@ -735,7 +735,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
         av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av)
         for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
     ]
-    in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
+    in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
+                    for ips, ia in safe_zip(in_positional_semantics, mesh_in_axes)]
     tiling_method: pxla.TilingMethod
     if config.experimental_xmap_spmd_lowering_manual:
       manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
@@ -746,7 +747,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
         f, 'xmap', name, mesh,
         mesh_in_axes, mesh_out_axes, donated_invars,
         use_spmd_lowering, global_in_avals,
-        tiling_method=tiling_method, in_is_gda=in_is_gda)
+        tiling_method=tiling_method, in_is_global=in_is_global)
   else:
     return dispatch.lower_xla_callable(
         f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
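
The maps.py change above widens the old `in_is_gda` check: an xmap input now counts as global either when its positional semantics are GLOBAL or when it has no mesh axes at all, since a fully replicated input looks the same locally and globally. A minimal standalone sketch of that predicate (the enum and the example values are stand-ins for illustration, not the real `maps` internals):

    from enum import Enum

    class _PositionalSemantics(Enum):  # stand-in for maps._PositionalSemantics
      LOCAL = 0
      GLOBAL = 1

    # mesh_in_axes maps each positional axis to mesh axes; an empty mapping
    # means the argument is fully replicated across the mesh.
    in_positional_semantics = [_PositionalSemantics.GLOBAL, _PositionalSemantics.LOCAL]
    mesh_in_axes = [{0: 'x'}, {}]

    in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
                    for ips, ia in zip(in_positional_semantics, mesh_in_axes)]
    assert in_is_global == [True, True]  # replicated LOCAL input is global too
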
diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index a9c132a8dcfe..c61e76a78b79 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -238,19 +238,21 @@ def infer_params(*args, **kwargs):
         maps._PositionalSemantics.GLOBAL if isinstance(a, GDA) else maps._positional_semantics.val
         for a in args_flat)
     out_positional_semantics = maps._positional_semantics.val
-    jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
-        flat_fun, mesh, local_in_avals, in_tree,
-        hashable_pytree(in_axis_resources),
-        HashableFunction(out_tree, closure=()),
-        hashable_pytree(out_axis_resources),
-        in_positional_semantics, out_positional_semantics,
-        tuple(isinstance(a, GDA) for a in args_flat))
-    in_axis_resources_flat = tree_map(_maybe_replace_from_gda_with_pspec,
-                                      in_axis_resources_flat, tuple(args_flat))
+
+    global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources(
+        mesh, local_in_avals, hashable_pytree(in_axis_resources), in_tree,
+        in_positional_semantics, tuple(isinstance(a, GDA) for a in args_flat))
+    jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
+        flat_fun, mesh, global_in_avals, HashableFunction(out_tree, closure=()),
+        hashable_pytree(out_axis_resources))
+    canonicalized_in_axis_resources_flat = tree_map(
+        _maybe_replace_from_gda_with_pspec,
+        canonicalized_in_axis_resources_flat, tuple(args_flat))
+
     params = dict(
         jaxpr=jaxpr,
-        in_axis_resources=in_axis_resources_flat,
-        out_axis_resources=out_axis_resources_flat,
+        in_axis_resources=canonicalized_in_axis_resources_flat,
+        out_axis_resources=canonicalized_out_axis_resources_flat,
         resource_env=resource_env,
         donated_invars=donated_invars,
         name=getattr(flat_fun, '__name__', ''),
@@ -270,11 +272,13 @@ def wrapped(*args, **kwargs):
   def lower(*args, **kwargs):
     (args_flat, flat_local_in_avals, params, in_tree, out_tree,
      donate_argnums) = infer_params(*args, **kwargs)
+    in_is_global = _calc_is_global_sequence(
+        params['in_positional_semantics'], params['in_axis_resources'])
     lowering = _pjit_lower(
         params['jaxpr'], params['in_axis_resources'],
         params['out_axis_resources'], params['resource_env'],
         params['donated_invars'], params['name'],
-        params['in_positional_semantics'], params['out_positional_semantics'])
+        in_is_global)
     args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
     local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
@@ -352,12 +356,9 @@ class PytreeLeaf:
   def __repr__(self): return "pytree leaf"


-@lu.cache
-def _pjit_jaxpr(fun, mesh, local_in_avals,
-                in_tree, in_axis_resources_thunk,
-                out_tree, out_axis_resources_thunk,
-                in_positional_semantics, out_positional_semantics, is_gda):
-  # TODO(yashkatariya): Make this work with FROM_GDA special value.
+@cache()
+def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk,
+                               in_tree, in_positional_semantics, is_gda):
   in_axis_resources_flat = flatten_axis_resources(
       "pjit in_axis_resources", in_tree,
       in_axis_resources_thunk(), tupled_args=True)
@@ -388,7 +389,11 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
   global_in_avals = local_to_global(in_positional_semantics, mesh,
                                     local_in_avals, canonicalized_in_axis_resources_flat)
+  return tuple(global_in_avals), canonicalized_in_axis_resources_flat
+
+@lu.cache
+def _pjit_jaxpr(fun, mesh, global_in_avals, out_tree, out_axis_resources_thunk):
   prev_positional_val = maps._positional_semantics.val
   try:
     maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL
@@ -407,8 +412,7 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
                              allow_uneven_sharding=False)
   canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
   # lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
-  return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
-                     canonicalized_out_axis_resources_flat])
+  return _ListWithW([jaxpr, canonicalized_out_axis_resources_flat])


 class SpecSync(IntEnum):
@@ -492,9 +496,6 @@ def __repr__(self):
             f"unsafe_user_spec={self.unsafe_user_spec}, "
             f"sync={self.sync})")

-REPLICATED = ParsedPartitionSpec(None, ())
-
-
 class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
   """ParsedPartitionSpecs that are canonicalized.
@@ -524,6 +525,9 @@ def __repr__(self):
             f"sync={self.sync})")


+REPLICATED = CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))
+
+
 def _prepare_axis_resources(axis_resources,
                             arg_name,
                             allow_unconstrained_dims=False):
@@ -595,10 +599,10 @@ def _pjit_call_impl(*args, jaxpr,
                     in_axis_resources, out_axis_resources, resource_env,
                     donated_invars, name,
                     in_positional_semantics, out_positional_semantics):
+  in_is_global = _calc_is_global_sequence(in_positional_semantics, in_axis_resources)
   compiled = _pjit_lower(
       jaxpr, in_axis_resources, out_axis_resources,
-      resource_env, donated_invars, name, in_positional_semantics,
-      out_positional_semantics).compile()
+      resource_env, donated_invars, name, in_is_global).compile()
   distributed_debug_log(("Running pjit'd function", name),
                         ("mesh", resource_env.physical_mesh))
   return compiled.unsafe_call(*args)
@@ -612,7 +616,7 @@ def _pjit_lower(
     resource_env,
     donated_invars,
     name: str,
-    in_positional_semantics, out_positional_semantics):
+    in_is_global: Sequence[bool]):
   # in_axis_resources and out_axis_resources are canonicalized to avoid
   # recompilation (since pjit_lower is cached) if its compiled with `None` but
   # in the next call `P(None)` is passed. Those are the same thing so should be
@@ -623,12 +627,10 @@ def _pjit_lower(
   f = core.jaxpr_as_fun(jaxpr)
   f.__name__ = name
   fun = lu.wrap_init(f)
-  in_is_gda = [ips == maps._PositionalSemantics.GLOBAL
-               for ips in in_positional_semantics]
   return pxla.lower_mesh_computation(
       fun, 'pjit', name, resource_env.physical_mesh,
       in_axes, out_axes, donated_invars,
-      True, jaxpr.in_avals, tiling_method=None, in_is_gda=in_is_gda)
+      True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)


 def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
@@ -782,8 +784,14 @@ def keep_where(l, should_keep):
       out_positional_semantics=out_positional_semantics)

   if num_residuals:
-    executable = _pjit_lower(**known_params).compile(
-        _allow_propagation_to_outputs=True, _allow_compile_replicated=False)
+    in_is_global = _calc_is_global_sequence(
+        known_params['in_positional_semantics'], known_params['in_axis_resources'])
+    executable = _pjit_lower(
+        known_params["jaxpr"], known_params["in_axis_resources"],
+        known_params["out_axis_resources"], known_params["resource_env"],
+        known_params["donated_invars"], known_params["name"],
+        in_is_global).compile(_allow_propagation_to_outputs=True,
+                              _allow_compile_replicated=False)
     output_op_sharding = \
         executable.xla_executable.hlo_modules()[0].spmd_output_sharding
     output_sharding_specs = parse_op_sharding(output_op_sharding, mesh)
@@ -1053,6 +1061,13 @@ def local_to_global(positional_semantics, mesh, avals, axes):
       for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
   ]

+
+def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
+  return tuple(
+      ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
+      for ips, p in safe_zip(in_positional_semantics, in_axis_resources))
+
+
 def _create_cpspec(x):
   return x if _is_from_gda(x) else CanonicalizedParsedPartitionSpec(x)
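
The new `_calc_is_global_sequence` helper is what lets the fully-replicated test further below hit the lowering cache: a fully replicated input (`partitions == ()`) is treated as global regardless of whether it arrived as a GDA or a plain host array, so both kinds of call produce the same `in_is_global` tuple and therefore the same `_pjit_lower` cache key. A self-contained sketch of that behaviour (the enum and the spec class are stand-ins for the real `maps._PositionalSemantics` and `CanonicalizedParsedPartitionSpec`):

    from collections import namedtuple
    from enum import Enum

    class Semantics(Enum):  # stand-in for maps._PositionalSemantics
      LOCAL = 0
      GLOBAL = 1

    # stand-in for CanonicalizedParsedPartitionSpec; only .partitions matters here
    Spec = namedtuple('Spec', ['partitions'])

    def calc_is_global_sequence(in_positional_semantics, in_axis_resources):
      return tuple(ips == Semantics.GLOBAL or p.partitions == ()
                   for ips, p in zip(in_positional_semantics, in_axis_resources))

    replicated = Spec(partitions=())
    # A GDA input (GLOBAL semantics) and a plain replicated array (LOCAL
    # semantics) now map to the same tuple, so _pjit_lower's cache can hit.
    assert (calc_is_global_sequence([Semantics.GLOBAL], [replicated]) ==
            calc_is_global_sequence([Semantics.LOCAL], [replicated]) == (True,))
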
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index fd9a2544063a..83cf78c1f148 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2140,7 +2140,7 @@ def lower_mesh_computation(
     spmd_lowering: bool,
     global_in_avals: Sequence[core.ShapedArray],
     tiling_method: Optional[TilingMethod],
-    in_is_gda: Sequence[bool]):
+    in_is_global: Sequence[bool]):
   assert not mesh.empty
   backend = xb.get_device_backend(mesh.devices.flat[0])
   name_stack = new_name_stack(wrap_name(fun_name, api_name))
@@ -2236,7 +2236,7 @@ def lower_mesh_computation(
   return MeshComputation(
       str(name_stack), module, donated_invars, mesh=mesh,
       global_in_avals=global_in_avals, global_out_avals=global_out_avals,
       in_axes=in_axes, out_axes=out_axes,
-      spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda)
+      spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_global=in_is_global)


 class MeshComputation:
@@ -2277,13 +2277,13 @@ def compile(self,
     return self._executable


-def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
+def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_global):
   input_specs, input_indices, input_avals = [], [], []
   num_local_devices = len(global_mesh.local_devices)
-  for gaval, axis, is_gda in safe_zip(global_in_avals, in_axes, in_is_gda):
+  for gaval, axis, is_global in safe_zip(global_in_avals, in_axes, in_is_global):
     # TODO(yashkatariya): Don't calculate input_indices and input_specs for GDA
     # as GDA doesn't need it.
-    if is_gda or not axis:
+    if is_global:
       aval = gaval
       mesh = global_mesh
     else:
@@ -2292,9 +2292,11 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
     spec = (mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axis)
             if aval is not core.abstract_unit else None)
-    # We special case this logic to support fully replicated non-GDA values
-    # with non-contiguous submeshes
-    if not axis and not is_gda:
+    # We special-case this logic to support fully replicated values: the mesh
+    # here is the global mesh, and the indices returned by `spec_to_indices`
+    # would cover every device in the global mesh, whereas we want indices
+    # only for the local devices of the global mesh.
+    if not axis:
       index = tuple((slice(None),) * aval.ndim for _ in range(num_local_devices))
     else:
       index = spec_to_indices(aval.shape, spec) if spec is not None else None
@@ -2323,7 +2325,7 @@ def from_hlo(name: str,
               in_axes: Sequence[ArrayMapping], out_axes: Sequence[ArrayMapping],
               spmd_lowering: bool, tuple_args: bool,
-              in_is_gda: Sequence[bool],
+              in_is_global: Sequence[bool],
               _allow_propagation_to_outputs: bool,
               _allow_compile_replicated: bool) -> 'MeshExecutable':
   assert not mesh.empty
@@ -2345,7 +2347,7 @@ def from_hlo(name: str,
         _allow_propagation_to_outputs

     input_specs, input_indices, input_avals = _get_input_metadata(
-        global_in_avals, mesh, in_axes, in_is_gda)
+        global_in_avals, mesh, in_axes, in_is_global)
     # Calculate local information here instead of calculating it in
     # `avals_to_results_handler` because pmap also uses this function.
     handle_outs = global_avals_to_results_handler(global_out_avals, out_axes, mesh)
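
The special case in `_get_input_metadata` exists because `spec_to_indices` enumerates an index for every device in the global mesh, while the runtime needs one entry per local device. For a fully replicated input, every local device receives the whole array, i.e. a tuple of all-`slice(None)` index tuples. A small illustration of what that computation produces (assuming a rank-2 aval and 4 local devices; the variable names are placeholders for `aval.ndim` and `len(global_mesh.local_devices)`):

    ndim = 2               # aval.ndim for a rank-2 input
    num_local_devices = 4  # len(global_mesh.local_devices)

    # One index per local device; each index selects the full array, which
    # is exactly what replication means.
    index = tuple((slice(None),) * ndim for _ in range(num_local_devices))
    assert index == ((slice(None), slice(None)),) * num_local_devices
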
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 5ecbdd49798b..63961b7fbaf1 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1078,20 +1078,32 @@ def cb(index):
     gda_obj = global_device_array.GlobalDeviceArray.from_callback(
         input_shape, global_mesh, mesh_axes, cb)

-    trace_counter = [0]
     @partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
     def f(x, y):
-      trace_counter[0] += 1
       return x @ y.T

+    before_lower_cache = pjit_lib._pjit_lower.cache_info()
+
     f(gda_obj, gda_obj)
-    self.assertListEqual(trace_counter, [1])
+    after_lower_cache1 = pjit_lib._pjit_lower.cache_info()
+    self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
+    self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)
+
     f(gda_obj, gda_obj)
-    self.assertListEqual(trace_counter, [1])
+    after_lower_cache2 = pjit_lib._pjit_lower.cache_info()
+    self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
+    self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)
+
     f(input_data, input_data)
-    self.assertListEqual(trace_counter, [2])
+    after_lower_cache3 = pjit_lib._pjit_lower.cache_info()
+    self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
+    self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)
+
     f(gda_obj, input_data)
-    self.assertListEqual(trace_counter, [3])
+    after_lower_cache4 = pjit_lib._pjit_lower.cache_info()
+    self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
+    self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)
+
   @jtu.with_mesh([('x', 4), ('y', 2)])
   def test_partition_spec_mismatch_semantically_equivalent(self):
@@ -1143,7 +1155,7 @@ def test_no_recompilation_due_to_in_axis_resources(self):
     def f(x):
       return x

-    with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
+    with global_mesh:
       out_gda = f(input_gda)
       self.assertEqual(out_gda.mesh_axes, ())
@@ -1151,7 +1163,28 @@ def f(x):
       f(out_gda)
       after_cache = pjit_lib._pjit_lower.cache_info()

-      self.assertNotEqual(id(before_cache), id(after_cache))
+      self.assertEqual(before_cache.hits + 1, after_cache.hits)
+      self.assertEqual(before_cache.misses, after_cache.misses)
+
+  def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
+    global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
+    global_input_shape = (8, 2)
+    mesh_axes = P(None)
+    global_data = np.arange(
+        prod(global_input_shape)).reshape(global_input_shape)
+
+    with jax._src.config.parallel_functions_output_gda(True):
+      f = pjit(lambda x: x, in_axis_resources=mesh_axes,
+               out_axis_resources=mesh_axes)
+
+      with global_mesh:
+        out_gda = f(global_data)
+        self.assertEqual(out_gda.mesh_axes, ())
+
+        before_cache = pjit_lib._pjit_lower.cache_info()
+        f(out_gda)
+        after_cache = pjit_lib._pjit_lower.cache_info()
+        self.assertEqual(before_cache.hits + 1, after_cache.hits)
+        self.assertEqual(before_cache.misses, after_cache.misses)
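
The updated tests assert on `_pjit_lower.cache_info()` instead of a trace counter, which checks the lowering cache directly rather than inferring it from retracing. The hits/misses pattern they rely on can be reproduced with any `functools.lru_cache`-style cache; a minimal standalone sketch (the `lower` function is hypothetical, not the pjit cache itself):

    import functools

    @functools.lru_cache()
    def lower(key):  # stands in for the cached _pjit_lower
      return object()

    before = lower.cache_info()
    lower('a')                       # first call: one new miss, no new hit
    after1 = lower.cache_info()
    assert after1.misses == before.misses + 1 and after1.hits == before.hits

    lower('a')                       # repeat call: one new hit, no new miss
    after2 = lower.cache_info()
    assert after2.hits == after1.hits + 1 and after2.misses == after1.misses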