Commit
Convert in_shardings to physical shardings in cpp dispatch path because the same happens with prng arrays.

Also comment out the key reuse check in cpp dispatch since it's True for jax tests, which prevents prng keys from taking the C++ dispatch path.

PiperOrigin-RevId: 613289252
yashk2810 authored and jax authors committed Mar 6, 2024
1 parent fc8dc83 commit 1cb8d31
Showing 9 changed files with 109 additions and 47 deletions.
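For context (not part of the commit): a PRNG key array has a logical shape plus trailing physical key dimensions, so the sharding the C++ dispatch fastpath caches for a key input must be the physical one. A minimal sketch of the logical/physical split, runnable on a single device:

import jax

# Logical view: a vector of 8 PRNG keys.
keys = jax.random.split(jax.random.key(0), 8)
print(keys.shape)                       # (8,)    -- logical shape
print(jax.random.key_data(keys).shape)  # (8, 2)  -- physical uint32 shape
# The physical array carries an extra trailing key dim, so any sharding
# recorded for a key input must shard the physical shape (with the
# trailing dim replicated) -- which is what this commit makes the
# fastpath do via physical_sharding.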
21 changes: 18 additions & 3 deletions jax/_src/interpreters/pxla.py
@@ -2471,6 +2471,15 @@ def _register_out_sharding_handler(
   _orig_out_sharding_handlers[sharding_cls] = handler
 
 
+def _gspmd_to_named_sharding_via_mesh(
+    out_s: sharding_impls.GSPMDSharding,
+    mesh: Mesh) -> sharding_impls.NamedSharding:
+  parsed_pspec = sharding_impls.parse_flatten_op_sharding(
+      out_s._hlo_sharding, mesh)[0]
+  return create_mesh_pspec_sharding(
+      mesh, parsed_pspec.get_partition_spec(), parsed_pspec,
+      out_s.memory_kind)
+
 def _gspmd_to_named_sharding(
     out_s: sharding_impls.GSPMDSharding,
     orig_in_s: sharding_impls.NamedSharding) -> sharding_impls.NamedSharding:
@@ -2688,7 +2697,7 @@ def _maybe_get_and_check_in_shardings(
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
-        xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
+        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
       new_in_shardings.append(xla_s)
     else:
       # TODO(yashkatariya): Remove the if branch for abstract_token once
@@ -2726,7 +2735,7 @@ def _maybe_get_and_check_out_shardings(
     if is_unspecified(orig):
       if (aval is not core.abstract_token and
           dtypes.issubdtype(aval.dtype, dtypes.extended)):
-        xla_s = aval.dtype._rules.logical_op_sharding(aval, xla_s)
+        xla_s = aval.dtype._rules.logical_sharding(aval, xla_s)
       new_out_shardings.append(xla_s)
     else:
       xla_hlo_s = xla_s._to_xla_hlo_sharding(aval.ndim)  # type: ignore
@@ -3031,8 +3040,14 @@ def aot_cache_miss(*args, **kwargs):
       out_committed = [o._committed for o in out_flat]
       kept_var_bitvec = [i in self._kept_var_idx
                          for i in range(len(args_flat))]
+      in_shardings = [
+          a.dtype._rules.physical_sharding(a, s)
+          if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
+          else s
+          for s, a in zip(self._in_shardings, self.in_avals)
+      ]
       fastpath_data = MeshExecutableFastpathData(
-          self.xla_executable, out_tree_dispatch, self._in_shardings,
+          self.xla_executable, out_tree_dispatch, in_shardings,
           self._out_shardings, out_avals, out_committed, kept_var_bitvec,
           self.unsafe_call.in_handler.local_devices,
           self.unsafe_call.in_handler.input_indices)
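A standalone model of the conversion added above (hypothetical helper name; in the commit the logic lives inline in aot_cache_miss and _get_fastpath_data):

def to_physical_shardings(shardings, avals):
  # Extended dtypes (e.g. PRNG keys) know how to map a logical sharding
  # to the physical sharding of their underlying representation; all
  # other avals pass through unchanged. Sketch only: the real check also
  # excludes abstract tokens and tests dtypes.extended.
  return [
      a.dtype._rules.physical_sharding(a, s)
      if hasattr(a.dtype, '_rules') else s
      for s, a in zip(shardings, avals)
  ]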
6 changes: 5 additions & 1 deletion jax/_src/lax/lax.py
@@ -5116,9 +5116,13 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
     return hlo_sharding
 
   @staticmethod
-  def logical_op_sharding(aval, phys_sharding):
+  def logical_sharding(aval, phys_sharding):
     return phys_sharding
 
+  @staticmethod
+  def physical_sharding(aval, sharding):
+    return sharding
+
   @staticmethod
   def convert_from(bint_dtype, other_dtype) -> bool:
     return other_dtype in (np.dtype('int32'), np.dtype('int64'))
29 changes: 17 additions & 12 deletions jax/_src/pjit.py
@@ -186,20 +186,20 @@ def _get_fastpath_data(
   out_reflattened, out_tree = pxla.reflatten_outputs_for_dispatch(out_tree, out_flat)
 
   use_fastpath = (
-      executable is not None and
-      isinstance(executable, pxla.MeshExecutable) and
-      isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
+      executable is not None
+      and isinstance(executable, pxla.MeshExecutable)
+      and isinstance(executable.unsafe_call, pxla.ExecuteReplicated)
       # No effects in computation
-      not executable.unsafe_call.ordered_effects and
-      not executable.unsafe_call.has_unordered_effects and
-      not executable.unsafe_call.has_host_callbacks and
-      all(isinstance(x, xc.ArrayImpl) for x in out_reflattened) and
+      and not executable.unsafe_call.ordered_effects
+      and not executable.unsafe_call.has_unordered_effects
+      and not executable.unsafe_call.has_host_callbacks
+      and all(isinstance(x, xc.ArrayImpl) for x in out_reflattened)
       # no attr state effects
-      not attrs_tracked and
+      and not attrs_tracked
       # no ref state effects
-      not any(isinstance(e, RefEffect) for e in effects) and
+      and not any(isinstance(e, RefEffect) for e in effects)
       # no prng reuse checking
-      not (config.enable_key_reuse_checks.value and any(
+      and not (config.enable_key_reuse_checks.value and any(
         hasattr(arg, 'dtype') and dtypes.issubdtype(arg.dtype, dtypes.prng_key)
         for arg in (*args_flat, *out_flat)))
   )
@@ -209,8 +209,14 @@ def _get_fastpath_data(
     out_committed = [o._committed for o in out_reflattened]
     kept_var_bitvec = [i in executable._kept_var_idx
                        for i in range(len(args_flat))]
+    in_shardings = [
+        a.dtype._rules.physical_sharding(a, s)
+        if a is not core.abstract_token and dtypes.issubdtype(a.dtype, dtypes.extended)
+        else s
+        for s, a in zip(executable._in_shardings, executable.in_avals)
+    ]
     fastpath_data = pxla.MeshExecutableFastpathData(
-        executable.xla_executable, out_tree, executable._in_shardings,
+        executable.xla_executable, out_tree, in_shardings,
         executable._out_shardings, out_avals, out_committed, kept_var_bitvec,
         executable.unsafe_call.in_handler.local_devices,
         executable.unsafe_call.in_handler.input_indices)
@@ -2084,7 +2090,6 @@ def _pjit_pp_rule(eqn, context, settings):
 core.pp_eqn_rules[pjit_p] = _pjit_pp_rule
 
 
-
 def _pjit_state_discharge_rule(
     in_avals, out_avals, *args, jaxpr, in_shardings, out_shardings, **params):
   if not (all(map(is_unspecified, in_shardings)) and
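As a small illustration of the prng-key predicate used in the fastpath guard above (same dtypes calls as the diff, wrapped in a hypothetical helper):

import jax
import jax.numpy as jnp
from jax._src import dtypes

def any_prng_key(*vals):
  # Mirrors the guard: True if any value is a PRNG key array, in which
  # case C++ dispatch is skipped while key-reuse checking is enabled.
  return any(hasattr(v, 'dtype') and dtypes.issubdtype(v.dtype, dtypes.prng_key)
             for v in vals)

assert any_prng_key(jax.random.key(0))
assert not any_prng_key(jnp.ones(3))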
45 changes: 27 additions & 18 deletions jax/_src/prng.py
@@ -234,7 +234,7 @@ def global_shards(self) -> list[Shard]:
   @property
   def sharding(self):
     phys_sharding = self._base_array.sharding
-    return KeyTyRules.logical_op_sharding(self.aval, phys_sharding)
+    return KeyTyRules.logical_sharding(self.aval, phys_sharding)
 
   def _is_scalar(self):
     base_ndim = len(self._impl.key_shape)
@@ -345,6 +345,22 @@ def make_key_array_phys_sharding(aval, sharding):
         sharding._device_assignment,
         KeyTyRules.physical_hlo_sharding(aval, hlos))
 
+
+def get_logical_gspmd_sharding(aval, phys_sharding):
+  key_shape = aval.dtype._impl.key_shape
+  phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
+      aval.ndim + len(key_shape))
+  partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
+      phys_hlo_sharding)
+  suffix = [] if num_replicas == 1 else [num_replicas]
+  # Create logical sharding by cutting off the replicated trailing dims.
+  logical_op_sharding = phys_hlo_sharding.to_proto().clone()
+  tad = partitions[:-len(key_shape)] + suffix
+  logical_op_sharding.tile_assignment_dimensions = tad
+  return GSPMDSharding(phys_sharding._device_assignment,
+                       xc.HloSharding.from_proto(logical_op_sharding))
+
+
 class KeyTyRules:
 
   @staticmethod
Expand Down Expand Up @@ -378,7 +394,12 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding) -> xc.HloSharding:
return xc.HloSharding.from_proto(new_op_sharding)

@staticmethod
def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
def physical_sharding(
aval, sharding: XLACompatibleSharding) -> XLACompatibleSharding:
return make_key_array_phys_sharding(aval, sharding)

@staticmethod
def logical_sharding(aval, phys_sharding) -> XLACompatibleSharding:
# The trailing dims should always be replicated.
aval.dtype._rules.check_replicated_trailing_dims(phys_sharding, aval)

@@ -392,23 +413,11 @@ def logical_op_sharding(aval, phys_sharding) -> XLACompatibleSharding:
       return PmapSharding(devices=phys_sharding.devices,
                           sharding_spec=logical_sharding_spec)
     elif isinstance(phys_sharding, NamedSharding):
-      key_shape = aval.dtype._impl.key_shape
-      return pxla.create_mesh_pspec_sharding(
-          phys_sharding.mesh,
-          PartitionSpec(*phys_sharding.spec[:-len(key_shape)]))
+      logical_gs = get_logical_gspmd_sharding(aval, phys_sharding)
+      return pxla._gspmd_to_named_sharding_via_mesh(
+          logical_gs, phys_sharding.mesh)
     else:
-      key_shape = aval.dtype._impl.key_shape
-      phys_hlo_sharding = phys_sharding._to_xla_hlo_sharding(
-          aval.ndim + len(key_shape))
-      partitions, num_replicas = op_shardings.get_num_ways_dim_sharded(
-          phys_hlo_sharding)
-      suffix = [] if num_replicas == 1 else [num_replicas]
-      # Create logical sharding by cutting off the replicated trailing dims.
-      logical_op_sharding = phys_hlo_sharding.to_proto().clone()
-      tad = partitions[:-len(key_shape)] + suffix
-      logical_op_sharding.tile_assignment_dimensions = tad
-      return GSPMDSharding(phys_sharding._device_assignment,
-                           xc.HloSharding.from_proto(logical_op_sharding))
+      return get_logical_gspmd_sharding(aval, phys_sharding)
 
   @staticmethod
   def result_handler(sticky_device, aval):
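For intuition, the tile-assignment arithmetic in get_logical_gspmd_sharding reduces to the following (a standalone sketch with a hypothetical name, using plain lists in place of HloSharding protos):

def logical_tile_dims(partitions, num_replicas, key_ndim):
  # Drop the replicated trailing key dims from the physical tile
  # assignment; if the sharding also replicates across devices, keep
  # that replica count as the final dimension.
  suffix = [] if num_replicas == 1 else [num_replicas]
  return partitions[:-key_ndim] + suffix

# Physical key array sharded 2-way on each of two leading dims,
# trailing key dim unsharded:
assert logical_tile_dims([2, 2, 1], 1, key_ndim=1) == [2, 2]
# Same, but additionally 2-way replicated:
assert logical_tile_dims([2, 2, 1], 2, key_ndim=1) == [2, 2, 2]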
1 change: 0 additions & 1 deletion jax/_src/test_util.py
@@ -981,7 +981,6 @@ class JaxTestCase(parameterized.TestCase):
   """Base class for JAX tests including numerical checks and boilerplate."""
   _default_config = {
     'jax_enable_checks': True,
-    'jax_enable_key_reuse_checks': True,
     'jax_numpy_dtype_promotion': 'strict',
     'jax_numpy_rank_promotion': 'raise',
     'jax_traceback_filtering': 'off',
23 changes: 12 additions & 11 deletions tests/core_test.py
@@ -27,7 +27,7 @@
 from jax import numpy as jnp
 from jax import jvp, linearize, vjp, jit, make_jaxpr
 from jax.api_util import flatten_fun_nokwargs
-from jax import config
+from jax._src import config
 
 from jax._src import core
 from jax._src import linear_util as lu
@@ -750,16 +750,17 @@ def g(x): return x
     core.check_jaxpr(jaxpr)
 
   def test_check_jaxpr_key_reuse(self):
-    try:
-      from jax.experimental.key_reuse import KeyReuseError
-    except ImportError:
-      self.skipTest("Test requires jax.experimental.key_reuse")
-    def f(seed):
-      key = jax.random.key(seed)
-      return jax.random.uniform(key) + jax.random.normal(key)
-    with jax.enable_checks(True):
-      with self.assertRaises(KeyReuseError):
-        jax.jit(f)(0)
+    with config.enable_key_reuse_checks(True):
+      try:
+        from jax.experimental.key_reuse import KeyReuseError
+      except ImportError:
+        self.skipTest("Test requires jax.experimental.key_reuse")
+      def f(seed):
+        key = jax.random.key(seed)
+        return jax.random.uniform(key) + jax.random.normal(key)
+      with jax.enable_checks(True):
+        with self.assertRaises(KeyReuseError):
+          jax.jit(f)(0)
 
 
 if __name__ == '__main__':
Expand Down
1 change: 1 addition & 0 deletions tests/key_reuse_test.py
@@ -586,6 +586,7 @@ def f_good(x, key):
     self.check_key_reuse(jax.grad(f_good), x, key)
 
 
+@jtu.with_config(jax_enable_key_reuse_checks=True)
 class KeyReuseEager(jtu.JaxTestCase):
   jit_msg = "Previously-consumed key passed to jit-compiled function at index 0"
   eager_bits_msg = "Previously-consumed key passed to random_bits at index 0"
6 changes: 5 additions & 1 deletion tests/lax_test.py
@@ -2979,9 +2979,13 @@ def physical_hlo_sharding(aval, hlo_sharding: xc.HloSharding):
     return xc.HloSharding.from_proto(new_op_sharding)
 
   @staticmethod
-  def logical_op_sharding(aval, phys_sharding):
+  def logical_sharding(aval, phys_sharding):
     return phys_sharding
 
+  @staticmethod
+  def physical_sharding(aval, sharding):
+    return sharding
+
   @staticmethod
   def result_handler(sticky_device, aval):
     def handler(_, buf):
24 changes: 24 additions & 0 deletions tests/pjit_test.py
@@ -3897,6 +3897,30 @@ def f():
     lowered_text = make_keys.lower(seeds).as_text()
     self.assertIn('unspecified_dims=[0,1]', lowered_text)
 
+  def test_partial_sharded_prng_key_inp(self):
+    input_shape = (8, 2, 2)
+    mesh = jtu.create_global_mesh((2, 2, 2), ('x', 'y', 'z'))
+    spec = P('x', 'y', None)
+
+    seeds, _ = create_array(input_shape, mesh, spec, dtype=np.uint32)
+
+    @jax.jit
+    def make_keys(seeds):
+      make_key = partial(prng.random_seed, impl=prng.threefry_prng_impl)
+      key = make_key(seeds)
+      return key.T
+
+    make_keys(seeds)
+    out = make_keys(seeds)  # cpp dispatch
+    self.assertEqual(out.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
+
+    base_array = jax.random.key_data(out)
+    self.assertEqual(base_array.shape, (2, 2, 8, 2))
+    self.assertEqual(base_array.sharding, NamedSharding(mesh, P(None, 'y', 'x')))
+
+    lowered_text = make_keys.lower(seeds).as_text()
+    self.assertIn('unspecified_dims=[0,1,2]', lowered_text)
+
   def test_jit_partially_specified_shardings(self):
 
     mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))