From 3d2026a1d07aef5255f6539613bf0145df67bdcf Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Thu, 11 Aug 2022 14:35:28 -0700
Subject: [PATCH] Create a `SameDeviceAssignmentTuple` type to cache the op
 shardings and device assignment. The device_assignment is stored only once
 per tuple because `pjit` checks right at the start that all device
 assignments are equal.

PiperOrigin-RevId: 467051286
---
 jax/experimental/pjit.py     | 100 +++++++++++++++++++++++++++--------
 jax/experimental/sharding.py |  17 +++---
 jax/interpreters/pxla.py     |   4 +-
 tests/pjit_test.py           |  46 ++++++++++++----
 4 files changed, 127 insertions(+), 40 deletions(-)

diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py
index 566891ad3534..9c483eaafc56 100644
--- a/jax/experimental/pjit.py
+++ b/jax/experimental/pjit.py
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses
 from enum import IntEnum
 import numpy as np
 from collections import OrderedDict, Counter
-from typing import Callable, Sequence, Tuple, Union, cast, List
+from typing import Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable
 import itertools as it
 from functools import partial, lru_cache
 
 from jax.experimental import maps
 from jax.experimental.global_device_array import GlobalDeviceArray as GDA
 from jax.experimental.sharding import (
-    MeshPspecSharding, Sharding, XLACompatibleSharding, OpShardingSharding)
+    MeshPspecSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
+    XLADeviceAssignment)
 from jax import core
 from jax import linear_util as lu
 from jax import stages
@@ -98,6 +100,12 @@ def _is_unspecified_or_from_gda_or_auto(x):
   return _is_from_gda(x) or _is_auto(x) or _is_unspecified(x)
 
 
+PjitSharding = Union[OpShardingSharding, _UnspecifiedValue, _AUTOAxisResource]
+PjitShardingMinusUnspecified = Union[OpShardingSharding, _AUTOAxisResource]
+MeshSharding = Union[MeshPspecSharding, _UnspecifiedValue, _AUTOAxisResource]
+MeshShardingMinusUnspecified = Union[MeshPspecSharding, _AUTOAxisResource]
+
+
 def _check_all_or_none_unspecified(axis_resources, name):
   if not axis_resources:
     return False
@@ -822,15 +830,57 @@ def _pjit_call_impl(*args, jaxpr,
   return compiled.unsafe_call(*args)
 pjit_p.def_impl(_pjit_call_impl)
 
-@weakref_lru_cache
+
+@dataclasses.dataclass(frozen=True)
+class SameDeviceAssignmentTuple:
+  shardings: Tuple[PjitSharding, ...]
+  # device_assignment is Optional because shardings can contain `AUTO`, in
+  # which case the mesh must be used instead and the `_pjit_lower_cached`
+  # cache relies on resource_env to check against the devices.
+  device_assignment: Optional[XLADeviceAssignment]
+
+  def __hash__(self):
+    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s
+                           for s in self.shardings)
+    if self.device_assignment is None:
+      return hash(shardings_hash)
+    else:
+      return hash((shardings_hash, *self.device_assignment))
+
+  def __eq__(self, other):
+    if not isinstance(other, SameDeviceAssignmentTuple):
+      return False
+    return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)
+                if isinstance(s, OpShardingSharding) and isinstance(o, OpShardingSharding)
+                else s == o
+                for s, o in safe_zip(self.shardings, other.shardings)) and
+            self.device_assignment == other.device_assignment)
+
+
 def _pjit_lower(
     jaxpr: core.ClosedJaxpr,
     in_shardings,
     out_shardings,
+    *args, **kwargs):
+  da = _get_device_assignment(it.chain(in_shardings, out_shardings))
+  in_shardings = SameDeviceAssignmentTuple(in_shardings, da)
+  out_shardings = SameDeviceAssignmentTuple(out_shardings, da)
+  return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs)
+
+
+@weakref_lru_cache
+def _pjit_lower_cached(
+    jaxpr: core.ClosedJaxpr,
+    sdat_in_shardings: SameDeviceAssignmentTuple,
+    sdat_out_shardings: SameDeviceAssignmentTuple,
     resource_env,
     donated_invars,
     name: str,
     in_is_global: Sequence[bool]):
+  in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
+      Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
+  out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
+
   pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
   f = core.jaxpr_as_fun(jaxpr)
   f.__name__ = name
@@ -841,18 +891,20 @@ def _pjit_lower(
   # MeshPspecSharding is required for host-local inputs too.
   if not config.jax_array:
     mesh = resource_env.physical_mesh
-    in_shardings = tuple(
-        i if _is_auto(i) or isinstance(i, MeshPspecSharding) else
-        MeshPspecSharding._from_parsed_pspec(
-            mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
-        for i in in_shardings
-    )
-    out_shardings = tuple(
-        o if _is_auto(o) or _is_unspecified(o) or isinstance(o, MeshPspecSharding) else
-        MeshPspecSharding._from_parsed_pspec(
-            mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
-        for o in out_shardings
-    )
+    in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast(  # type:ignore[no-redef]
+        Tuple[MeshShardingMinusUnspecified, ...], tuple(
+            MeshPspecSharding._from_parsed_pspec(
+                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
+            if isinstance(i, OpShardingSharding) else i
+            for i in in_shardings
+        ))
+    out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
+        Tuple[MeshSharding, ...], tuple(
+            MeshPspecSharding._from_parsed_pspec(
+                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
+            if isinstance(o, OpShardingSharding) else o
+            for o in out_shardings
+        ))
 
   # For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
   # because `xmap` only supports SPMDAxisContext right now.
@@ -1004,13 +1056,7 @@ def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        resource_env, donated_invars, name, in_positional_semantics,
                        out_positional_semantics):
-  da = None
-  for i in it.chain(in_shardings, out_shardings):
-    if _is_auto(i) or _is_unspecified(i):
-      continue
-    da = i._device_assignment
-    break
-
+  da = _get_device_assignment(it.chain(in_shardings, out_shardings))
   in_pvals = [t.pval for t in in_tracers]
 
   known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1379,6 +1425,16 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:
   return maps._positional_semantics.val
 
 
+def _get_device_assignment(shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
+  da = None
+  for i in shardings:
+    if _is_auto(i) or _is_unspecified(i):
+      continue
+    da = i._device_assignment  # type: ignore
+    break
+  return da
+
+
 def _maybe_replace_from_gda_with_pspec(
     in_shardings_flat, args_flat) -> Sequence[XLACompatibleSharding]:
diff --git a/jax/experimental/sharding.py b/jax/experimental/sharding.py
index cff367b7700e..66c7aed8c58d 100644
--- a/jax/experimental/sharding.py
+++ b/jax/experimental/sharding.py
@@ -293,20 +293,25 @@ def _hash_op_sharding(op: xc.OpSharding):
 class OpShardingSharding(XLACompatibleSharding):
 
   def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
-    self._devices = devices
+    self._devices = tuple(devices)
     self._op_sharding = op_sharding
 
+  @pxla.maybe_cached_property
+  def _op_sharding_hash(self):
+    if xla_extension_version >= 81:
+      return hash(xc.HloSharding.from_proto(self._op_sharding))
+    else:
+      return _hash_op_sharding(self._op_sharding)
+
   def __eq__(self, other):
     if not isinstance(other, OpShardingSharding):
       return False
-    return pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding)
+    return (pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding) and
+            self._devices == other._devices)
 
   def __hash__(self):
     if not hasattr(self, '_hash'):
-      if xla_extension_version >= 81:
-        self._hash = hash(xc.HloSharding.from_proto(self._op_sharding))
-      else:
-        self._hash = _hash_op_sharding(self._op_sharding)
+      self._hash = hash((self._devices, self._op_sharding_hash))
     return self._hash
 
   def __repr__(self):
diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py
index 536079149811..2118ff46f362 100644
--- a/jax/interpreters/pxla.py
+++ b/jax/interpreters/pxla.py
@@ -2296,7 +2296,9 @@ class TileManual:
 
 TilingMethod = Union[TileVectorize, TileManual]
 
-def _check_if_any_auto(shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource]]) -> bool:
+def _check_if_any_auto(
+    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
+                              _UnspecifiedValue]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
diff --git a/tests/pjit_test.py b/tests/pjit_test.py
index 2c7f0520f56f..1b66b9fca794 100644
--- a/tests/pjit_test.py
+++ b/tests/pjit_test.py
@@ -1183,25 +1183,25 @@ def cb(index):
     def f(x, y):
       return x @ y.T
 
-    before_lower_cache = pjit_lib._pjit_lower.cache_info()
+    before_lower_cache = pjit_lib._pjit_lower_cached.cache_info()
 
     f(gda_obj, gda_obj)
-    after_lower_cache1 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache1 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
     self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)
 
     f(gda_obj, gda_obj)
-    after_lower_cache2 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache2 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
     self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)
 
     f(input_data, input_data)
-    after_lower_cache3 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache3 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
     self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)
 
     f(gda_obj, input_data)
-    after_lower_cache4 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache4 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
     self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)
 
@@ -1260,9 +1260,9 @@ def f(x):
     out_gda = f(input_gda)
     self.assertEqual(out_gda.mesh_axes, ())
 
-    before_cache = pjit_lib._pjit_lower.cache_info()
+    before_cache = pjit_lib._pjit_lower_cached.cache_info()
     f(out_gda)
-    after_cache = pjit_lib._pjit_lower.cache_info()
+    after_cache = pjit_lib._pjit_lower_cached.cache_info()
 
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)
@@ -1282,9 +1282,9 @@ def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
     out_gda = f(global_data)
     self.assertEqual(out_gda.mesh_axes, ())
 
-    before_cache = pjit_lib._pjit_lower.cache_info()
+    before_cache = pjit_lib._pjit_lower_cached.cache_info()
     f(out_gda)
-    after_cache = pjit_lib._pjit_lower.cache_info()
+    after_cache = pjit_lib._pjit_lower_cached.cache_info()
 
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)
@@ -1818,14 +1818,38 @@ def test_pjit_single_device_sharding_cache(self):
     f = pjit(lambda x: x)
 
     out = f(a)
-    cache_info1 = pjit_lib._pjit_lower.cache_info()
+    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
 
     _ = f(out)
-    cache_info2 = pjit_lib._pjit_lower.cache_info()
+    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
 
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
     self.assertEqual(cache_info2.misses, cache_info1.misses)
 
+  @jax._src.config.jax_array(True)
+  def test_pjit_different_device_recompilation(self):
+    if jax.device_count() < 2:
+      raise unittest.SkipTest('Requires 2 or more devices.')
+
+    val1 = jnp.array([1, 2, 3], dtype=jnp.float32)
+    a = jax.device_put(val1, jax.devices()[0])
+
+    val2 = jnp.array([4, 5, 6], dtype=jnp.float32)
+    b = jax.device_put(val2, jax.devices()[1])
+
+    f = pjit(lambda x: x)
+
+    out1 = f(a)
+    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
+
+    out2 = f(b)
+    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
+
+    self.assertEqual(cache_info2.hits, cache_info1.hits)
+    self.assertEqual(cache_info2.misses, cache_info1.misses + 1)
+    self.assertArraysEqual(out1, val1)
+    self.assertArraysEqual(out2, val2)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
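Note, separate from the patch above: the sketch below is a minimal, self-contained illustration of the caching idea that `SameDeviceAssignmentTuple` enables. The lowering cache is keyed on a frozen tuple of shardings plus one shared device assignment, so two calls whose shardings compare equal on the same devices reuse the same cache entry. `FakeOpSharding`, `ShardingsWithSameDeviceAssignment`, and `lower` are hypothetical stand-ins written for this note, not JAX APIs.

# Minimal sketch of the cache-key idea; plain Python, no JAX imports.
# FakeOpSharding and lower() are hypothetical stand-ins, not JAX internals.
import dataclasses
from functools import lru_cache
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class FakeOpSharding:
  # Stand-in for xc.OpSharding: identified only by a tile-assignment tuple.
  tiles: Tuple[int, ...]


@dataclasses.dataclass(frozen=True)
class ShardingsWithSameDeviceAssignment:
  shardings: Tuple[FakeOpSharding, ...]
  # None models the AUTO case, where no concrete devices are attached.
  device_assignment: Optional[Tuple[int, ...]]

  def __hash__(self):
    da = () if self.device_assignment is None else self.device_assignment
    return hash((self.shardings, *da))

  def __eq__(self, other):
    return (isinstance(other, ShardingsWithSameDeviceAssignment) and
            self.shardings == other.shardings and
            self.device_assignment == other.device_assignment)


@lru_cache()
def lower(key: ShardingsWithSameDeviceAssignment) -> str:
  # Stand-in for _pjit_lower_cached: pretend to lower and return a token.
  return f"lowered {len(key.shardings)} sharding(s) on {key.device_assignment}"


if __name__ == "__main__":
  key1 = ShardingsWithSameDeviceAssignment((FakeOpSharding((2, 1)),), (0, 1))
  key2 = ShardingsWithSameDeviceAssignment((FakeOpSharding((2, 1)),), (0, 1))
  lower(key1)
  lower(key2)                 # equal key, so this is a cache hit
  print(lower.cache_info())   # CacheInfo(hits=1, misses=1, ...)

Storing the device assignment once per key mirrors what the change relies on: `pjit` verifies up front that all shardings share a single device assignment, so the key does not need to repeat it for every sharding.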