Create a SameDeviceAssignmentTuple type to cache the op shardings and device assignment.

But the device_assignment is only cached once because `pjit` checks if all device assignments are equal or not right at the start.

PiperOrigin-RevId: 467051286
yashk2810 authored and jax authors committed Aug 11, 2022
1 parent af18235 commit 3d2026a
Showing 4 changed files with 127 additions and 40 deletions.
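The key mechanism in this change is that the lowering cache key now bundles the op shardings together with a single device assignment, so lowering is redone when the same shardings show up on a different set of devices. Below is a minimal, self-contained sketch of that caching pattern, not JAX's actual implementation: `CacheKey` and `lower` are hypothetical stand-ins, with strings and integer tuples playing the role of op shardings and device ids.

import dataclasses
from functools import lru_cache
from typing import Optional, Tuple


@dataclasses.dataclass(frozen=True)
class CacheKey:
  # Stand-ins: strings for op shardings, a tuple of ints for device ids.
  shardings: Tuple[str, ...]
  device_assignment: Optional[Tuple[int, ...]]

  def __hash__(self):
    # Mirror the commit's scheme: hash only the shardings when the device
    # assignment is absent, otherwise fold the devices into the hash.
    if self.device_assignment is None:
      return hash(self.shardings)
    return hash((self.shardings, *self.device_assignment))

  def __eq__(self, other):
    return (isinstance(other, CacheKey) and
            self.shardings == other.shardings and
            self.device_assignment == other.device_assignment)


@lru_cache(maxsize=None)
def lower(key: CacheKey) -> str:
  # Pretend "lowering" result; in pjit this would be the compiled computation.
  return f'lowered for devices {key.device_assignment}'


lower(CacheKey(('replicated',), (0,)))
lower(CacheKey(('replicated',), (0,)))  # hit: same shardings, same devices
lower(CacheKey(('replicated',), (1,)))  # miss: same shardings, new devices
print(lower.cache_info())               # CacheInfo(hits=1, misses=2, ...)

In the real change below, `_pjit_lower` computes the device assignment once (from the first sharding that is neither `AUTO` nor unspecified), wraps the shardings and that assignment into a `SameDeviceAssignmentTuple`, and passes them to the `weakref_lru_cache`-decorated `_pjit_lower_cached`; the new `test_pjit_different_device_recompilation` test checks that running the same computation on a different device produces a cache miss.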
jax/experimental/pjit.py (78 additions, 22 deletions)
@@ -12,17 +12,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import dataclasses
 from enum import IntEnum
 import numpy as np
 from collections import OrderedDict, Counter
-from typing import Callable, Sequence, Tuple, Union, cast, List
+from typing import Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable
 import itertools as it
 from functools import partial, lru_cache
 
 from jax.experimental import maps
 from jax.experimental.global_device_array import GlobalDeviceArray as GDA
 from jax.experimental.sharding import (
-    MeshPspecSharding, Sharding, XLACompatibleSharding, OpShardingSharding)
+    MeshPspecSharding, Sharding, XLACompatibleSharding, OpShardingSharding,
+    XLADeviceAssignment)
 from jax import core
 from jax import linear_util as lu
 from jax import stages
@@ -98,6 +100,12 @@ def _is_unspecified_or_from_gda_or_auto(x):
   return _is_from_gda(x) or _is_auto(x) or _is_unspecified(x)
 
 
+PjitSharding = Union[OpShardingSharding, _UnspecifiedValue, _AUTOAxisResource]
+PjitShardingMinusUnspecified = Union[OpShardingSharding, _AUTOAxisResource]
+MeshSharding = Union[MeshPspecSharding, _UnspecifiedValue, _AUTOAxisResource]
+MeshShardingMinusUnspecified = Union[MeshPspecSharding, _AUTOAxisResource]
+
+
 def _check_all_or_none_unspecified(axis_resources, name):
   if not axis_resources:
     return False
@@ -822,15 +830,57 @@ def _pjit_call_impl(*args, jaxpr,
   return compiled.unsafe_call(*args)
 pjit_p.def_impl(_pjit_call_impl)
 
-@weakref_lru_cache
+
+@dataclasses.dataclass(frozen=True)
+class SameDeviceAssignmentTuple:
+  shardings: Tuple[PjitSharding, ...]
+  # device_assignment is Optional because shardings can contain `AUTO` and in
+  # that case `mesh` is compulsory to be used. So in that case
+  # `_pjit_lower_cached` cache, resource_env will check against the devices.
+  device_assignment: Optional[XLADeviceAssignment]
+
+  def __hash__(self):
+    shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s
+                           for s in self.shardings)
+    if self.device_assignment is None:
+      return hash(shardings_hash)
+    else:
+      return hash((shardings_hash, *self.device_assignment))
+
+  def __eq__(self, other):
+    if not isinstance(other, SameDeviceAssignmentTuple):
+      return False
+    return (all(pxla.are_op_shardings_equal(s._op_sharding, o._op_sharding)
+                if isinstance(s, OpShardingSharding) and isinstance(o, OpShardingSharding)
+                else s == o
+                for s, o in safe_zip(self.shardings, other.shardings)) and
+            self.device_assignment == other.device_assignment)
+
+
 def _pjit_lower(
     jaxpr: core.ClosedJaxpr,
     in_shardings,
     out_shardings,
+    *args, **kwargs):
+  da = _get_device_assignment(it.chain(in_shardings, out_shardings))
+  in_shardings = SameDeviceAssignmentTuple(in_shardings, da)
+  out_shardings = SameDeviceAssignmentTuple(out_shardings, da)
+  return _pjit_lower_cached(jaxpr, in_shardings, out_shardings, *args, **kwargs)
+
+
+@weakref_lru_cache
+def _pjit_lower_cached(
+    jaxpr: core.ClosedJaxpr,
+    sdat_in_shardings: SameDeviceAssignmentTuple,
+    sdat_out_shardings: SameDeviceAssignmentTuple,
     resource_env,
     donated_invars,
     name: str,
     in_is_global: Sequence[bool]):
+  in_shardings: Tuple[PjitShardingMinusUnspecified, ...] = cast(
+      Tuple[PjitShardingMinusUnspecified, ...], sdat_in_shardings.shardings)
+  out_shardings: Tuple[PjitSharding, ...] = sdat_out_shardings.shardings
+
   pxla.resource_typecheck(jaxpr, resource_env, {}, lambda: "pjit")
   f = core.jaxpr_as_fun(jaxpr)
   f.__name__ = name
@@ -841,18 +891,20 @@ def _pjit_lower(
   # MeshPspecSharding is required for host-local inputs too.
   if not config.jax_array:
     mesh = resource_env.physical_mesh
-    in_shardings = tuple(
-        i if _is_auto(i) or isinstance(i, MeshPspecSharding) else
-        MeshPspecSharding._from_parsed_pspec(
-            mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
-        for i in in_shardings
-    )
-    out_shardings = tuple(
-        o if _is_auto(o) or _is_unspecified(o) or isinstance(o, MeshPspecSharding) else
-        MeshPspecSharding._from_parsed_pspec(
-            mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
-        for o in out_shardings
-    )
+    in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast(  # type:ignore[no-redef]
+        Tuple[MeshShardingMinusUnspecified, ...], tuple(
+            MeshPspecSharding._from_parsed_pspec(
+                mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
+            if isinstance(i, OpShardingSharding) else i
+            for i in in_shardings
+        ))
+    out_shardings: Tuple[MeshSharding, ...] = cast(  # type: ignore[no-redef]
+        Tuple[MeshSharding, ...], tuple(
+            MeshPspecSharding._from_parsed_pspec(
+                mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
+            if isinstance(o, OpShardingSharding) else o
+            for o in out_shardings
+        ))
 
   # For `pjit(xmap)` cases, it needs to take the `lower_mesh_computation` path
   # because `xmap` only supports SPMDAxisContext right now.
@@ -1004,13 +1056,7 @@ def _pjit_partial_eval(trace, *in_tracers,
                        jaxpr, in_shardings, out_shardings,
                        resource_env, donated_invars, name, in_positional_semantics,
                        out_positional_semantics):
-  da = None
-  for i in it.chain(in_shardings, out_shardings):
-    if _is_auto(i) or _is_unspecified(i):
-      continue
-    da = i._device_assignment
-    break
-
+  da = _get_device_assignment(it.chain(in_shardings, out_shardings))
   in_pvals = [t.pval for t in in_tracers]
 
   known_ins = tuple(pv.is_known() for pv in in_pvals)
@@ -1379,6 +1425,16 @@ def _get_in_positional_semantics(arg) -> maps._PositionalSemantics:
   return maps._positional_semantics.val
 
 
+def _get_device_assignment(shardings: Iterable[PjitSharding]) -> Optional[XLADeviceAssignment]:
+  da = None
+  for i in shardings:
+    if _is_auto(i) or _is_unspecified(i):
+      continue
+    da = i._device_assignment  # type: ignore
+    break
+  return da
+
+
 def _maybe_replace_from_gda_with_pspec(
     in_shardings_flat, args_flat) -> Sequence[XLACompatibleSharding]:
 
jax/experimental/sharding.py (11 additions, 6 deletions)
@@ -293,20 +293,25 @@ def _hash_op_sharding(op: xc.OpSharding):
 class OpShardingSharding(XLACompatibleSharding):
 
   def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
-    self._devices = devices
+    self._devices = tuple(devices)
     self._op_sharding = op_sharding
 
+  @pxla.maybe_cached_property
+  def _op_sharding_hash(self):
+    if xla_extension_version >= 81:
+      return hash(xc.HloSharding.from_proto(self._op_sharding))
+    else:
+      return _hash_op_sharding(self._op_sharding)
+
   def __eq__(self, other):
     if not isinstance(other, OpShardingSharding):
       return False
-    return pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding)
+    return (pxla.are_op_shardings_equal(self._op_sharding, other._op_sharding) and
+            self._devices == other._devices)
 
   def __hash__(self):
     if not hasattr(self, '_hash'):
-      if xla_extension_version >= 81:
-        self._hash = hash(xc.HloSharding.from_proto(self._op_sharding))
-      else:
-        self._hash = _hash_op_sharding(self._op_sharding)
+      self._hash = hash((self._devices, self._op_sharding_hash))
     return self._hash
 
   def __repr__(self):
jax/interpreters/pxla.py (3 additions, 1 deletion)
@@ -2296,7 +2296,9 @@ class TileManual:
 TilingMethod = Union[TileVectorize, TileManual]
 
 
-def _check_if_any_auto(shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource]]) -> bool:
+def _check_if_any_auto(
+    shardings: Iterable[Union[XLACompatibleSharding, _AUTOAxisResource,
+                              _UnspecifiedValue]]) -> bool:
   for s in shardings:
     if _is_auto(s):
       return True
tests/pjit_test.py (35 additions, 11 deletions)
@@ -1183,25 +1183,25 @@ def cb(index):
     def f(x, y):
       return x @ y.T
 
-    before_lower_cache = pjit_lib._pjit_lower.cache_info()
+    before_lower_cache = pjit_lib._pjit_lower_cached.cache_info()
 
     f(gda_obj, gda_obj)
-    after_lower_cache1 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache1 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(before_lower_cache.hits, after_lower_cache1.hits)
     self.assertEqual(before_lower_cache.misses + 1, after_lower_cache1.misses)
 
     f(gda_obj, gda_obj)
-    after_lower_cache2 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache2 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(after_lower_cache1.hits + 1, after_lower_cache2.hits)
     self.assertEqual(after_lower_cache1.misses, after_lower_cache2.misses)
 
     f(input_data, input_data)
-    after_lower_cache3 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache3 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(after_lower_cache2.hits, after_lower_cache3.hits)
     self.assertEqual(after_lower_cache2.misses + 1, after_lower_cache3.misses)
 
     f(gda_obj, input_data)
-    after_lower_cache4 = pjit_lib._pjit_lower.cache_info()
+    after_lower_cache4 = pjit_lib._pjit_lower_cached.cache_info()
     self.assertEqual(after_lower_cache3.hits, after_lower_cache4.hits)
     self.assertEqual(after_lower_cache3.misses + 1, after_lower_cache4.misses)
 
@@ -1260,9 +1260,9 @@ def f(x):
     out_gda = f(input_gda)
     self.assertEqual(out_gda.mesh_axes, ())
 
-    before_cache = pjit_lib._pjit_lower.cache_info()
+    before_cache = pjit_lib._pjit_lower_cached.cache_info()
     f(out_gda)
-    after_cache = pjit_lib._pjit_lower.cache_info()
+    after_cache = pjit_lib._pjit_lower_cached.cache_info()
 
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)
@@ -1282,9 +1282,9 @@ def test_no_recompilation_due_to_fully_replicated_and_gda_inputs(self):
     out_gda = f(global_data)
     self.assertEqual(out_gda.mesh_axes, ())
 
-    before_cache = pjit_lib._pjit_lower.cache_info()
+    before_cache = pjit_lib._pjit_lower_cached.cache_info()
     f(out_gda)
-    after_cache = pjit_lib._pjit_lower.cache_info()
+    after_cache = pjit_lib._pjit_lower_cached.cache_info()
 
     self.assertEqual(before_cache.hits + 1, after_cache.hits)
     self.assertEqual(before_cache.misses, after_cache.misses)
@@ -1818,14 +1818,38 @@ def test_pjit_single_device_sharding_cache(self):
     f = pjit(lambda x: x)
 
     out = f(a)
-    cache_info1 = pjit_lib._pjit_lower.cache_info()
+    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
 
     _ = f(out)
-    cache_info2 = pjit_lib._pjit_lower.cache_info()
+    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
 
     self.assertEqual(cache_info2.hits, cache_info1.hits + 1)
     self.assertEqual(cache_info2.misses, cache_info1.misses)
 
+  @jax._src.config.jax_array(True)
+  def test_pjit_different_device_recompilation(self):
+    if jax.device_count() < 2:
+      raise unittest.SkipTest('Requires 2 or more devices.')
+
+    val1 = jnp.array([1, 2, 3], dtype=jnp.float32)
+    a = jax.device_put(val1, jax.devices()[0])
+
+    val2 = jnp.array([4, 5, 6], dtype=jnp.float32)
+    b = jax.device_put(val2, jax.devices()[1])
+
+    f = pjit(lambda x: x)
+
+    out1 = f(a)
+    cache_info1 = pjit_lib._pjit_lower_cached.cache_info()
+
+    out2 = f(b)
+    cache_info2 = pjit_lib._pjit_lower_cached.cache_info()
+
+    self.assertEqual(cache_info2.hits, cache_info1.hits)
+    self.assertEqual(cache_info2.misses, cache_info1.misses + 1)
+    self.assertArraysEqual(out1, val1)
+    self.assertArraysEqual(out2, val2)
+
 
 def spec_regex(s):
   return str(s).replace(r"(", r"\(").replace(r")", r"\)")
