Set in_positional_semantics to GLOBAL for fully replicated values to avoid recompilation.

Split _pjit_jaxpr into two functions so that `is_gda` no longer needs to be passed as an argument to _pjit_jaxpr, which was leading to cache invalidation.

PiperOrigin-RevId: 434825926
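
The cache problem can be reduced to a minimal sketch. Here functools.lru_cache stands in for JAX's internal lu.cache, and the helper names are hypothetical; the point is only that a per-call flag passed into a cached tracing function becomes part of its cache key:

# A minimal sketch, not JAX's implementation: functools.lru_cache stands in
# for lu.cache, and the helper names below are hypothetical.
from functools import lru_cache

@lru_cache(maxsize=None)
def trace_with_flag(fun_name, global_avals, is_gda):
    # is_gda is part of the cache key: calling with a GDA and then a plain
    # array that has identical avals produces two misses, i.e. two traces.
    return ('jaxpr', fun_name, global_avals)

@lru_cache(maxsize=None)
def trace(fun_name, global_avals):
    # After the split, the key depends only on what tracing actually needs.
    return ('jaxpr', fun_name, global_avals)

def infer_params(fun_name, global_avals, is_gda):
    # is_gda is consumed outside the cached function, so flipping it
    # between calls no longer invalidates the trace cache.
    return trace(fun_name, global_avals)

With the flagged version, calling first with is_gda=True and then with is_gda=False on identical avals misses twice and traces twice; with the split version the second call is a cache hit.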
yashk2810 authored and jax authors committed Mar 15, 2022
1 parent 4848c75 commit 846c480
Showing 4 changed files with 102 additions and 51 deletions.
5 changes: 3 additions & 2 deletions jax/experimental/maps.py
@@ -735,7 +735,8 @@ def make_xmap_callable(fun: lu.WrappedFun,
av if ips == _PositionalSemantics.GLOBAL else mesh._local_to_global(ax, av)
for ax, av, ips in safe_zip(mesh_in_axes, in_avals, in_positional_semantics)
]
in_is_gda = [ips == _PositionalSemantics.GLOBAL for ips in in_positional_semantics]
in_is_global = [ips == _PositionalSemantics.GLOBAL or not ia
for ips, ia in safe_zip(in_positional_semantics, mesh_in_axes)]
tiling_method: pxla.TilingMethod
if config.experimental_xmap_spmd_lowering_manual:
manual_mesh_axes = frozenset(it.chain.from_iterable(plan.physical_axis_resources.values()))
@@ -746,7 +747,7 @@ def make_xmap_callable(fun: lu.WrappedFun,
f, 'xmap', name, mesh,
mesh_in_axes, mesh_out_axes, donated_invars,
use_spmd_lowering, global_in_avals,
tiling_method=tiling_method, in_is_gda=in_is_gda)
tiling_method=tiling_method, in_is_global=in_is_global)
else:
return dispatch.lower_xla_callable(
f, None, backend, name, donated_invars, *((a, None) for a in in_avals))
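
The replacement predicate above treats an input as already global when its positional semantics are GLOBAL or when it has no mesh axes at all (i.e. it is fully replicated), so a fully replicated plain value and a fully replicated GDA take the same path. A simplified, self-contained sketch of that logic, assuming dict-like axis mappings where an empty mapping means fully replicated:

# Sketch of the in_is_global computation with simplified stand-ins; in JAX,
# mesh_in_axes entries are ArrayMapping dicts and ips is a
# _PositionalSemantics enum value.
from enum import Enum

class PositionalSemantics(Enum):
    LOCAL = 0
    GLOBAL = 1

def compute_in_is_global(in_positional_semantics, mesh_in_axes):
    # "not axes" is true for an empty axis mapping, i.e. a fully
    # replicated value whose local view already equals its global view.
    return [ips is PositionalSemantics.GLOBAL or not axes
            for ips, axes in zip(in_positional_semantics, mesh_in_axes)]

# A fully replicated LOCAL input counts as global alongside a GLOBAL input:
print(compute_in_is_global(
    [PositionalSemantics.LOCAL, PositionalSemantics.GLOBAL],
    [{}, {'x': 0}]))
# [True, True]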
77 changes: 46 additions & 31 deletions jax/experimental/pjit.py
@@ -238,19 +238,21 @@ def infer_params(*args, **kwargs):
maps._PositionalSemantics.GLOBAL if isinstance(a, GDA) else maps._positional_semantics.val
for a in args_flat)
out_positional_semantics = maps._positional_semantics.val
jaxpr, in_axis_resources_flat, out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, local_in_avals, in_tree,
hashable_pytree(in_axis_resources),
HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources),
in_positional_semantics, out_positional_semantics,
tuple(isinstance(a, GDA) for a in args_flat))
in_axis_resources_flat = tree_map(_maybe_replace_from_gda_with_pspec,
in_axis_resources_flat, tuple(args_flat))

global_in_avals, canonicalized_in_axis_resources_flat = _process_in_axis_resources(
mesh, local_in_avals, hashable_pytree(in_axis_resources), in_tree,
in_positional_semantics, tuple(isinstance(a, GDA) for a in args_flat))
jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
flat_fun, mesh, global_in_avals, HashableFunction(out_tree, closure=()),
hashable_pytree(out_axis_resources))
canonicalized_in_axis_resources_flat = tree_map(
_maybe_replace_from_gda_with_pspec,
canonicalized_in_axis_resources_flat, tuple(args_flat))

params = dict(
jaxpr=jaxpr,
in_axis_resources=in_axis_resources_flat,
out_axis_resources=out_axis_resources_flat,
in_axis_resources=canonicalized_in_axis_resources_flat,
out_axis_resources=canonicalized_out_axis_resources_flat,
resource_env=resource_env,
donated_invars=donated_invars,
name=getattr(flat_fun, '__name__', '<unnamed function>'),
@@ -270,11 +272,13 @@ def wrapped(*args, **kwargs):
def lower(*args, **kwargs):
(args_flat, flat_local_in_avals, params, in_tree, out_tree,
donate_argnums) = infer_params(*args, **kwargs)
in_is_global = _calc_is_global_sequence(
params['in_positional_semantics'], params['in_axis_resources'])
lowering = _pjit_lower(
params['jaxpr'], params['in_axis_resources'],
params['out_axis_resources'], params['resource_env'],
params['donated_invars'], params['name'],
params['in_positional_semantics'], params['out_positional_semantics'])
in_is_global)

args_kwargs_in_tree = treedef_tuple([in_tree, tree_flatten({})[1]])
local_in_avals = args_kwargs_in_tree.unflatten(flat_local_in_avals)
@@ -352,12 +356,9 @@ class PytreeLeaf:
def __repr__(self): return "pytree leaf"


@lu.cache
def _pjit_jaxpr(fun, mesh, local_in_avals,
in_tree, in_axis_resources_thunk,
out_tree, out_axis_resources_thunk,
in_positional_semantics, out_positional_semantics, is_gda):
# TODO(yashkatariya): Make this work with FROM_GDA special value.
@cache()
def _process_in_axis_resources(mesh, local_in_avals, in_axis_resources_thunk,
in_tree, in_positional_semantics, is_gda):
in_axis_resources_flat = flatten_axis_resources(
"pjit in_axis_resources", in_tree,
in_axis_resources_thunk(), tupled_args=True)
@@ -388,7 +389,11 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,

global_in_avals = local_to_global(in_positional_semantics, mesh,
local_in_avals, canonicalized_in_axis_resources_flat)
return tuple(global_in_avals), canonicalized_in_axis_resources_flat


@lu.cache
def _pjit_jaxpr(fun, mesh, global_in_avals, out_tree, out_axis_resources_thunk):
prev_positional_val = maps._positional_semantics.val
try:
maps._positional_semantics.val = maps._PositionalSemantics.GLOBAL
@@ -407,8 +412,7 @@ def _pjit_jaxpr(fun, mesh, local_in_avals,
allow_uneven_sharding=False)
canonicalized_out_axis_resources_flat = tree_map(_create_cpspec, out_axis_resources_flat)
# lu.cache needs to be able to create weakrefs to outputs, so we can't return a plain tuple
return _ListWithW([jaxpr, canonicalized_in_axis_resources_flat,
canonicalized_out_axis_resources_flat])
return _ListWithW([jaxpr, canonicalized_out_axis_resources_flat])


class SpecSync(IntEnum):
@@ -492,9 +496,6 @@ def __repr__(self):
f"unsafe_user_spec={self.unsafe_user_spec}, "
f"sync={self.sync})")

REPLICATED = ParsedPartitionSpec(None, ())


class CanonicalizedParsedPartitionSpec(ParsedPartitionSpec):
"""ParsedPartitionSpecs that are canonicalized.
@@ -524,6 +525,9 @@ def __repr__(self):
f"sync={self.sync})")


REPLICATED = CanonicalizedParsedPartitionSpec(ParsedPartitionSpec(None, ()))


def _prepare_axis_resources(axis_resources,
arg_name,
allow_unconstrained_dims=False):
@@ -595,10 +599,10 @@ def _pjit_call_impl(*args, jaxpr,
in_axis_resources, out_axis_resources,
resource_env, donated_invars, name,
in_positional_semantics, out_positional_semantics):
in_is_global = _calc_is_global_sequence(in_positional_semantics, in_axis_resources)
compiled = _pjit_lower(
jaxpr, in_axis_resources, out_axis_resources,
resource_env, donated_invars, name, in_positional_semantics,
out_positional_semantics).compile()
resource_env, donated_invars, name, in_is_global).compile()
distributed_debug_log(("Running pjit'd function", name),
("mesh", resource_env.physical_mesh))
return compiled.unsafe_call(*args)
@@ -612,7 +616,7 @@ def _pjit_lower(
resource_env,
donated_invars,
name: str,
in_positional_semantics, out_positional_semantics):
in_is_global: Sequence[bool]):
# in_axis_resources and out_axis_resources are canonicalized to avoid
# recompilation (since pjit_lower is cached) if it's compiled with `None` but
# in the next call `P(None)` is passed. Those are the same thing so should be
@@ -623,12 +627,10 @@
f = core.jaxpr_as_fun(jaxpr)
f.__name__ = name
fun = lu.wrap_init(f)
in_is_gda = [ips == maps._PositionalSemantics.GLOBAL
for ips in in_positional_semantics]
return pxla.lower_mesh_computation(
fun, 'pjit', name, resource_env.physical_mesh,
in_axes, out_axes, donated_invars,
True, jaxpr.in_avals, tiling_method=None, in_is_gda=in_is_gda)
True, jaxpr.in_avals, tiling_method=None, in_is_global=in_is_global)


def _pjit_abstract_eval(*args, jaxpr, out_axis_resources, resource_env,
@@ -782,8 +784,14 @@ def keep_where(l, should_keep):
out_positional_semantics=out_positional_semantics)

if num_residuals:
executable = _pjit_lower(**known_params).compile(
_allow_propagation_to_outputs=True, _allow_compile_replicated=False)
in_is_global = _calc_is_global_sequence(
known_params['in_positional_semantics'], known_params['in_axis_resources'])
executable = _pjit_lower(
known_params["jaxpr"], known_params["in_axis_resources"],
known_params["out_axis_resources"], known_params["resource_env"],
known_params["donated_invars"], known_params["name"],
in_is_global).compile(_allow_propagation_to_outputs=True,
_allow_compile_replicated=False)
output_op_sharding = \
executable.xla_executable.hlo_modules()[0].spmd_output_sharding
output_sharding_specs = parse_op_sharding(output_op_sharding, mesh)
@@ -1053,6 +1061,13 @@ def local_to_global(positional_semantics, mesh, avals, axes):
for aval, aval_axes, ps in safe_zip(avals, axes, positional_semantics)
]


def _calc_is_global_sequence(in_positional_semantics, in_axis_resources):
return tuple(
ips == maps._PositionalSemantics.GLOBAL or p.partitions == ()
for ips, p in safe_zip(in_positional_semantics, in_axis_resources))


def _create_cpspec(x):
return x if _is_from_gda(x) else CanonicalizedParsedPartitionSpec(x)

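
Taken together, the pjit.py changes split the old cached _pjit_jaxpr into two stages: _process_in_axis_resources, which still receives is_gda (FROM_GDA handling differs per input kind) and lifts local avals to global ones, and a slimmer _pjit_jaxpr whose lu.cache key mentions only the function, mesh, global avals, and output resources. REPLICATED is likewise rebuilt as a CanonicalizedParsedPartitionSpec, presumably so that it compares equal to the other canonicalized specs that appear in cache keys. A simplified sketch of the shape of the split, with hypothetical signatures and lru_cache standing in for the real caches:

# A simplified sketch, not JAX's real signatures: the GDA-dependent work
# happens in one cached stage, and the expensive trace is cached purely
# on its global inputs.
from functools import lru_cache

@lru_cache(maxsize=None)
def process_in_axis_resources(local_avals, in_axis_resources, is_gda):
    # Stage 1: canonicalize in_axis_resources and lift local avals to
    # global avals; is_gda lives here, outside the trace cache.
    global_avals = local_avals  # placeholder for local_to_global(...)
    return global_avals, in_axis_resources

@lru_cache(maxsize=None)
def pjit_jaxpr(fun_name, global_avals, out_axis_resources):
    # Stage 2: trace to a jaxpr. A GDA and a plain array with identical
    # global avals now share a single cache entry.
    return ('jaxpr', fun_name, global_avals, out_axis_resources)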
22 changes: 12 additions & 10 deletions jax/interpreters/pxla.py
@@ -2140,7 +2140,7 @@ def lower_mesh_computation(
spmd_lowering: bool,
global_in_avals: Sequence[core.ShapedArray],
tiling_method: Optional[TilingMethod],
in_is_gda: Sequence[bool]):
in_is_global: Sequence[bool]):
assert not mesh.empty
backend = xb.get_device_backend(mesh.devices.flat[0])
name_stack = new_name_stack(wrap_name(fun_name, api_name))
@@ -2236,7 +2236,7 @@ def lower_mesh_computation(
return MeshComputation(
str(name_stack), module, donated_invars, mesh=mesh, global_in_avals=global_in_avals,
global_out_avals=global_out_avals, in_axes=in_axes, out_axes=out_axes,
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_gda=in_is_gda)
spmd_lowering=spmd_lowering, tuple_args=tuple_args, in_is_global=in_is_global)


class MeshComputation:
@@ -2277,13 +2277,13 @@ def compile(self,
return self._executable


def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):
def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_global):
input_specs, input_indices, input_avals = [], [], []
num_local_devices = len(global_mesh.local_devices)
for gaval, axis, is_gda in safe_zip(global_in_avals, in_axes, in_is_gda):
for gaval, axis, is_global in safe_zip(global_in_avals, in_axes, in_is_global):
# TODO(yashkatariya): Don't calculate input_indices and input_specs for GDA
# as GDA doesn't need it.
if is_gda or not axis:
if is_global:
aval = gaval
mesh = global_mesh
else:
@@ -2292,9 +2292,11 @@ def _get_input_metadata(global_in_avals, global_mesh, in_axes, in_is_gda):

spec = (mesh_sharding_specs(mesh.shape, mesh.axis_names)(aval, axis)
if aval is not core.abstract_unit else None)
# We special case this logic to support fully replicated non-GDA values
# with non-contiguous submeshes
if not axis and not is_gda:
# We special case this logic to support fully replicated values because
# the mesh is the global mesh, and the indices returned by `spec_to_indices`
# represent an index for each device in the global mesh. But here we want
# indices only for the local devices of the global mesh.
if not axis:
index = tuple((slice(None),) * aval.ndim for _ in range(num_local_devices))
else:
index = spec_to_indices(aval.shape, spec) if spec is not None else None
@@ -2323,7 +2325,7 @@ def from_hlo(name: str,
in_axes: Sequence[ArrayMapping],
out_axes: Sequence[ArrayMapping],
spmd_lowering: bool, tuple_args: bool,
in_is_gda: Sequence[bool],
in_is_global: Sequence[bool],
_allow_propagation_to_outputs: bool,
_allow_compile_replicated: bool) -> 'MeshExecutable':
assert not mesh.empty
@@ -2345,7 +2347,7 @@
_allow_propagation_to_outputs

input_specs, input_indices, input_avals = _get_input_metadata(
global_in_avals, mesh, in_axes, in_is_gda)
global_in_avals, mesh, in_axes, in_is_global)
# Calculate local information here instead of calculating it in
# `avals_to_results_handler` because pmap also uses this function.
handle_outs = global_avals_to_results_handler(global_out_avals, out_axes, mesh)
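
The _get_input_metadata change matters because spec_to_indices yields one index per device of the mesh it is given, and for a global mesh that includes non-addressable devices; a fully replicated input instead gets one whole-array index per local device. A standalone sketch of that special case (simplified; the real code works with ShapedArray avals, mesh_sharding_specs, and spec_to_indices):

def replicated_input_indices(aval_ndim, num_local_devices):
    # Each local device receives the entire (replicated) array, so every
    # per-device index is a tuple of full slices, one per array dimension.
    return tuple((slice(None),) * aval_ndim
                 for _ in range(num_local_devices))

print(replicated_input_indices(2, 2))
# -> one whole-array index (a pair of full slices) per local device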
49 changes: 41 additions & 8 deletions tests/pjit_test.py
@@ -1078,20 +1078,32 @@ def cb(index):
gda_obj = global_device_array.GlobalDeviceArray.from_callback(
input_shape, global_mesh, mesh_axes, cb)

trace_counter = [0]
@partial(pjit, in_axis_resources=mesh_axes, out_axis_resources=P('x', 'y'))
def f(x, y):
trace_counter[0] += 1
return x @ y.T

before_lower_cache = pjit_lib._pjit_lower.cache_info()

f(gda_obj, gda_obj)
self.assertListEqual(trace_counter, [1])
after_lower_cache1 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)

f(gda_obj, gda_obj)
self.assertListEqual(trace_counter, [1])
after_lower_cache2 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)

f(input_data, input_data)
self.assertListEqual(trace_counter, [2])
after_lower_cache3 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)

f(gda_obj, input_data)
self.assertListEqual(trace_counter, [3])
after_lower_cache4 = pjit_lib._pjit_lower.cache_info()
self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)


@jtu.with_mesh([('x', 4), ('y', 2)])
def test_partition_spec_mismatch_semantically_equivalent(self):
@@ -1143,15 +1155,36 @@ def test_no_recompilation_due_to_in_axis_resources(self):
def f(x):
return x

with maps.Mesh(global_mesh.devices, global_mesh.axis_names):
with global_mesh:
out_gda = f(input_gda)
self.assertEqual(out_gda.mesh_axes, ())

before_cache = pjit_lib._pjit_lower.cache_info()
f(out_gda)
after_cache = pjit_lib._pjit_lower.cache_info()

self.assertNotEqual(id(before_cache), id(after_cache))
self.assertEqual(before_cache.hits + 1, after_cache.hits)
self.assertEqual(before_cache.misses, after_cache.misses)

def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
global_mesh = jtu.create_global_mesh((1, 2), ('x', 'y'))
global_input_shape = (8, 2)
mesh_axes = P(None)
global_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)

with jax._src.config.parallel_functions_output_gda(True):
f = pjit(lambda x: x, in_axis_resources=mesh_axes,
out_axis_resources=mesh_axes)

with global_mesh:
out_gda = f(global_data)
self.assertEqual(out_gda.mesh_axes, ())

before_cache = pjit_lib._pjit_lower.cache_info()
f(out_gda)
after_cache = pjit_lib._pjit_lower.cache_info()

self.assertEqual(before_cache.hits + 1, after_cache.hits)
self.assertEqual(before_cache.misses, after_cache.misses)

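
Both tests detect recompilation through hit/miss deltas on _pjit_lower.cache_info() rather than timing. The same pattern works against any lru_cache-style cache; a minimal sketch:

from functools import lru_cache

@lru_cache(maxsize=None)
def lower(key):
    return ('compiled', key)

before = lower.cache_info()
lower('replicated')   # first call: one new miss
lower('replicated')   # second call: one new hit, no new miss
after = lower.cache_info()
assert after.misses == before.misses + 1
assert after.hits == before.hits + 1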
