Make core.Token a non-trivial class which wraps a jax.Array. Currently, we use a singleton, empty `core.token` object everywhere. After this change, tokens can be created and threaded in and out of computations to build up dependencies.

Also update ordered side effects to use the new `core.Token` class (no functional change for this part; it just unifies token usage).
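For illustration, a minimal sketch of the new usage, mirroring the updated test in tests/api_test.py further down in this diff (the `from jax import core` import path is an assumption for this sketch; the test itself uses JAX's internal core module):

import jax
import jax.numpy as jnp
from jax import core  # for core.Token; assumed re-export, the test uses the internal module

@jax.jit
def noop(arr, token):
  # Returning the token threads it out of the computation, expressing a
  # dependency on this call.
  return arr, token

arr = jnp.ones(10)
token = jax.lax.create_token()   # a fresh core.Token wrapping a jax.Array
_, out_token = noop(arr, token)

assert isinstance(out_token, core.Token)
out_token.block_until_ready()    # waits on the token's underlying buffer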

PiperOrigin-RevId: 626091210
yueshengys authored and jax authors committed Apr 18, 2024
1 parent 9c9e805 commit c2d4373
Showing 8 changed files with 123 additions and 89 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -24,6 +24,10 @@ Remember to align the itemized text with the first line of an item within a list
  to non-parallel computations, as we already do async dispatch for parallel
  computations. You can recover the old behavior by setting
  `jax.config.update('jax_cpu_enable_async_dispatch', False)`.
* `core.Token` is now a non-trivial class which wraps a `jax.Array`. It can
  be created and threaded in and out of computations to build up dependencies.
  The singleton object `core.token` has been removed; users should now create
  and use fresh `core.Token` objects instead.

* Deprecations & Removals
* Pallas now exclusively uses XLA for compiling kernels on GPU. The old
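For callers that previously reached for the removed singleton, one possible migration, modeled on the test update in tests/api_test.py later in this diff (which swaps `core.token` for `jax.lax.create_token()`):

import jax

# Before this change: tok = jax.core.token   (shared, empty placeholder object)
# After this change: create a fresh token value instead.
tok = jax.lax.create_token()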
2 changes: 2 additions & 0 deletions jax/_src/api.py
@@ -2454,6 +2454,8 @@ def _infer_src_sharding(src, x) -> Sharding | None:
def _check_sharding(x, s):
  if isinstance(s, Sharding):
    aval = shaped_abstractify(x)
    if isinstance(aval, core.AbstractToken):
      aval = core.token_shaped_array
    if isinstance(s, XLACompatibleSharding) and not isinstance(s, PmapSharding):
      pjit.pjit_check_aval_sharding(
          (s,), (aval,), None, "device_put args", allow_uneven_sharding=False)
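Mapping `AbstractToken` to `core.token_shaped_array` lets the existing per-aval check run on token arguments to `device_put` (the "device_put args" path guarded above). A hedged sketch of what this presumably allows; this example is not taken from the diff, and whether `device_put` on tokens is meant as a user-facing entry point is not stated here:

import jax
from jax.sharding import SingleDeviceSharding

token = jax.lax.create_token()
s = SingleDeviceSharding(jax.devices()[0])
# For the check, the token's aval stands in as a (0,)-shaped bool array.
placed = jax.device_put(token, s)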
23 changes: 22 additions & 1 deletion jax/_src/array.py
@@ -952,6 +952,13 @@ def _array_shard_arg(x, sharding):
pxla.shard_arg_handlers[ArrayImpl] = _array_shard_arg


def _token_shard_arg(x, sharding):
  return _array_shard_arg(x._buf, sharding)


pxla.shard_arg_handlers[core.Token] = _token_shard_arg


def _array_global_result_handler(global_aval, out_sharding, committed):
  if global_aval.dtype == dtypes.float0:
    return lambda _: np.zeros(global_aval.shape, dtypes.float0)  # type: ignore
@@ -963,7 +970,21 @@ def _array_global_result_handler(global_aval, out_sharding, committed):
  )
pxla.global_result_handlers[core.ShapedArray] = _array_global_result_handler
pxla.global_result_handlers[core.ConcreteArray] = _array_global_result_handler
pxla.global_result_handlers[core.AbstractToken] = lambda *_: lambda *_: core.token


def _token_global_result_handler(global_aval, out_sharding, committed):
  array_handler = _array_global_result_handler(
      core.token_shaped_array, out_sharding, committed
  )

  def wrapper(*args, **kwargs):
    out_buf = array_handler(*args, **kwargs)
    return core.Token(out_buf)

  return wrapper


pxla.global_result_handlers[core.AbstractToken] = _token_global_result_handler


# Only used for Arrays that come out of pmap.
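Both new handlers follow the same boundary pattern: unwrap to the underlying `jax.Array` on the way in (via `x._buf`), reuse the existing array path, and re-wrap the result in `core.Token` on the way out. A schematic of the wrapping half (a sketch only; `make_token_handler` is a hypothetical name, not part of this diff):

from jax._src import core

def make_token_handler(array_handler):
  # Adapt a handler that produces a jax.Array into one that produces core.Token.
  def handler(*args, **kwargs):
    buf = array_handler(*args, **kwargs)   # plain array result
    return core.Token(buf)                 # re-wrap at the boundary
  return handler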
142 changes: 75 additions & 67 deletions jax/_src/core.py
@@ -1635,6 +1635,70 @@ def shape(self):
"UnshapedArray instances to ever be produced.")
raise TypeError(msg)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
  # Dimensions are most commonly integral (by far), so we check that first.
  try:
    return operator.index(dim)
  except TypeError as e:
    type_error = e
  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
                               or isinstance(dim.dtype, bint))):
      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
    return dim
  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
        type(dim._aval.dtype) is bint and not dim._aval.shape):
    return dim
  elif is_dim(dim):
    return dim
  else:
    raise type_error

def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of canonical dimension values.
  """
  try:
    return tuple(unsafe_map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  raise _invalid_shape_error(shape, context)

def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
  """Canonicalizes and checks for errors in a user-provided shape dimension value.

  Args:
    d: a Python value that represents a dimension.

  Returns:
    A canonical dimension value.
  """
  return canonicalize_shape((d,), context)[0]

def _invalid_shape_error(shape: Shape, context: str=""):
  if config.dynamic_shapes.value:
    msg = ("Shapes must be 1D sequences of integer scalars, "
           f"got {shape}")
  else:
    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
           f"got {shape}.")
  if context:
    msg += f" {context}."
  if not config.dynamic_shapes.value and any(
      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
      and not isinstance(get_aval(x), ConcreteArray) for x in shape):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  for x in shape:
    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
      msg += x._origin_msg()

  return TypeError(msg)
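These canonicalization helpers appear to be moved earlier in the file without modification (the matching removal is further down in this diff). For orientation only, a small illustration of their behavior; they are internal jax._src.core APIs:

import numpy as np
from jax._src import core

core.canonicalize_shape((2, np.int64(3)))   # -> (2, 3), plain Python ints
core.canonicalize_dim(np.int32(5))          # -> 5

try:
  core.canonicalize_shape((2.5, 3))         # floats are not valid dimensions
except TypeError as err:
  print(err)  # "Shapes must be 1D sequences of concrete values of integer type, ..."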

class ShapedArray(UnshapedArray):
  __slots__ = ['shape', 'named_shape']
@@ -1960,9 +2024,18 @@ def str_short(self, short_dtypes=False): return 'Tok'
  def at_least_vspace(self): return self
abstract_token: AbstractToken = AbstractToken()

# Singleton shaped array used by all abstract tokens when shape/dtype is needed.
token_shaped_array: ShapedArray = ShapedArray((0,), np.dtype(np.bool_))

# Concrete token object
class Token: pass
token: Token = Token()
class Token:
  # The underlying data wrapped by the token; it can be threaded in and out of
  # computations to build up data dependencies.
  _buf: Array
  def __init__(self, buf):
    self._buf = buf
  def block_until_ready(self):
    self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
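A minimal sketch of the `core.Token` surface defined above: it wraps an array in `_buf`, exposes `block_until_ready`, and maps to `abstract_token` through `pytype_aval_mappings`. Constructing one by hand is purely illustrative; in practice tokens come from `jax.lax.create_token()` or as outputs of computations:

import jax.numpy as jnp
from jax._src import core

buf = jnp.zeros((0,), jnp.bool_)   # same shape/dtype as core.token_shaped_array
tok = core.Token(buf)              # wrap an existing array into a token
tok.block_until_ready()            # delegates to the wrapped buffer
assert core.pytype_aval_mappings[core.Token](tok) is core.abstract_token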


@@ -2121,71 +2194,6 @@ def dimension_as_value(d: DimSize):
  if hasattr(d, "dimension_as_value"): return d.dimension_as_value()
  return operator.index(d)

def _canonicalize_dimension(dim: DimSize) -> DimSize:
  # Dimensions are most commonly integral (by far), so we check that first.
  try:
    return operator.index(dim)
  except TypeError as e:
    type_error = e
  if isinstance(dim, Tracer) and config.dynamic_shapes.value:
    if not (dim.ndim == 0 and (dtypes.issubdtype(dim.dtype, np.integer)
                               or isinstance(dim.dtype, bint))):
      raise TypeError(f"Dimensions must be integer scalars; got {dim.ndim=} {dim.dtype=}")
    return dim
  elif (config.dynamic_shapes.value and isinstance(dim, DArray) and
        type(dim._aval.dtype) is bint and not dim._aval.shape):
    return dim
  elif is_dim(dim):
    return dim
  else:
    raise type_error

def canonicalize_shape(shape: Shape, context: str="") -> tuple[Any, ...]:
  """Canonicalizes and checks for errors in a user-provided shape value.

  Args:
    shape: a Python value that represents a shape.

  Returns:
    A tuple of canonical dimension values.
  """
  try:
    return tuple(unsafe_map(_canonicalize_dimension, shape))
  except TypeError:
    pass
  raise _invalid_shape_error(shape, context)

def canonicalize_dim(d: DimSize, context: str="") -> DimSize:
  """Canonicalizes and checks for errors in a user-provided shape dimension value.

  Args:
    d: a Python value that represents a dimension.

  Returns:
    A canonical dimension value.
  """
  return canonicalize_shape((d,), context)[0]

def _invalid_shape_error(shape: Shape, context: str=""):
  if config.dynamic_shapes.value:
    msg = ("Shapes must be 1D sequences of integer scalars, "
           f"got {shape}")
  else:
    msg = ("Shapes must be 1D sequences of concrete values of integer type, "
           f"got {shape}.")
  if context:
    msg += f" {context}."
  if not config.dynamic_shapes.value and any(
      isinstance(x, Tracer) and isinstance(get_aval(x), ShapedArray)
      and not isinstance(get_aval(x), ConcreteArray) for x in shape):
    msg += ("\nIf using `jit`, try using `static_argnums` or applying `jit` to "
            "smaller subfunctions.")
  for x in shape:
    if isinstance(x, Tracer) and hasattr(x, "_origin_msg"):
      msg += x._origin_msg()

  return TypeError(msg)

class SomeTracer:
  __slots__ = ()
  def __repr__(self): return "[dynamic]"
13 changes: 7 additions & 6 deletions jax/_src/dispatch.py
@@ -107,7 +107,7 @@ class RuntimeTokenSet(threading.local):

  # For each ordered effect, the token returned by the last dispatched
  # computation, sharded over the devices in that computation.
  current_tokens: dict[core.Effect, jax.Array]
  current_tokens: dict[core.Effect, core.Token]

  # For each device, the runtime token returned by the last dispatched
  # computation on that device.
@@ -117,11 +117,12 @@ def __init__(self):
    self.current_tokens = {}
    self.output_runtime_tokens = {}

  def get_token_input(self, eff: core.Effect,
                      devices: list[Device]) -> jax.Array:
  def get_token_input(
      self, eff: core.Effect, devices: list[Device]
  ) -> core.Token:
    tok = self.current_tokens.get(eff, np.zeros(0, np.bool_))

    if isinstance(tok, jax.Array):
    if isinstance(tok, core.Token):
      # The order of devices may change, so we need to reshard if necessary.
      # TODO(yueshengys): This might still be buggy in a multi-process SPMD
      # scenario. Revise the logic later. A distributed shutdown barrier inside
@@ -131,11 +132,11 @@ def get_token_input(self, eff: core.Effect,
    # We only use replicated sharding for the first time when the token for the
    # ordered effect hasn't been created.
    s = jax.sharding.GSPMDSharding.get_replicated(devices)
    sharded_tok = pxla.shard_args([s], [tok])[0]
    sharded_tok = core.Token(pxla.shard_args([s], [tok])[0])
    self.current_tokens[eff] = sharded_tok
    return sharded_tok

  def set_token_result(self, eff: core.Effect, token: jax.Array):
  def set_token_result(self, eff: core.Effect, token: core.Token):
    self.current_tokens[eff] = token

  def set_output_runtime_token(self, device: Device, token: RuntimeToken):
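A simplified, standalone sketch of the caching behavior get_token_input implements, with the JAX-internal calls abstracted behind hypothetical make_replicated_token / reshard_token callables (the real code uses jax.sharding.GSPMDSharding.get_replicated and pxla.shard_args):

from jax._src import core

class TokenCache:
  """Per-effect token cache: build a replicated token on first use, and
  reshard the cached one if the device set changes later."""

  def __init__(self, make_replicated_token, reshard_token):
    self._current = {}                   # effect -> core.Token
    self._make = make_replicated_token   # devices -> core.Token
    self._reshard = reshard_token        # (core.Token, devices) -> core.Token

  def get_token_input(self, eff, devices):
    tok = self._current.get(eff)
    if isinstance(tok, core.Token):
      tok = self._reshard(tok, devices)  # device order may have changed
    else:
      tok = self._make(devices)          # first use: replicate over devices
    self._current[eff] = tok
    return tok

  def set_token_result(self, eff, token):
    self._current[eff] = token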
18 changes: 7 additions & 11 deletions jax/_src/interpreters/pxla.py
@@ -131,13 +131,6 @@ def get_addressable_devices_for_shard_arg(
def _get_replicated_slices(num_addressable_devices: int):
  return ((slice(None),),) * num_addressable_devices

def _shard_token(x, sharding):
  devices = get_addressable_devices_for_shard_arg(sharding)
  indices = _get_replicated_slices(len(devices))
  zeros = np.zeros((), dtype=np.dtype(np.bool_))
  aval = api_util.shaped_abstractify(zeros)
  return batched_device_put(aval, sharding, [zeros for _ in indices], devices)
shard_arg_handlers[core.Token] = _shard_token

def _masked_array_error(x, sharding):
  raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
@@ -1148,8 +1141,9 @@ def __init__(self, xla_executable, name, backend, in_handler: InputsHandler,
  def _add_tokens_to_inputs(self, input_bufs):
    if self.ordered_effects:
      tokens = [
          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)
          for eff in self.ordered_effects]
          dispatch.runtime_tokens.get_token_input(eff, self._local_devices)._buf
          for eff in self.ordered_effects
      ]
      input_bufs = [*tokens, *input_bufs]
    return input_bufs

@@ -1163,7 +1157,7 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
    for eff, token_buf in zip(self.ordered_effects, token_bufs):
      assert len(token_buf) > 0
      if len(token_buf) == 1:
        dispatch.runtime_tokens.set_token_result(eff, token_buf[0])
        dispatch.runtime_tokens.set_token_result(eff, core.Token(token_buf[0]))
      else:
        token_devices = []
        for token in token_buf:
@@ -1173,7 +1167,9 @@ def _handle_token_bufs(self, token_bufs, sharded_token):
        global_token_array = jax.make_array_from_single_device_arrays(
            (0,), s, token_buf
        )
        dispatch.runtime_tokens.set_token_result(eff, global_token_array)
        dispatch.runtime_tokens.set_token_result(
            eff, core.Token(global_token_array)
        )

  @profiler.annotate_function
  def __call__(self, *args):
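The execution path therefore unwraps tokens to raw buffers before the call and re-wraps the returned buffers afterwards. A schematic of that round trip for the single-device case (a sketch with a hypothetical runtime_tokens argument standing in for dispatch.runtime_tokens; the multi-device case additionally goes through jax.make_array_from_single_device_arrays, as above):

from jax._src import core

def add_tokens_to_inputs(ordered_effects, runtime_tokens, devices, input_bufs):
  # Prepend one raw token buffer per ordered effect; the compiled executable
  # expects them ahead of the regular arguments.
  token_bufs = [runtime_tokens.get_token_input(eff, devices)._buf
                for eff in ordered_effects]
  return [*token_bufs, *input_bufs]

def handle_token_outputs(ordered_effects, runtime_tokens, token_bufs):
  # Wrap each returned buffer back into a core.Token and record it as the
  # latest token for its effect.
  for eff, buf in zip(ordered_effects, token_bufs):
    runtime_tokens.set_token_result(eff, core.Token(buf))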
1 change: 0 additions & 1 deletion jax/core.py
@@ -148,7 +148,6 @@
  subst_axis_names_var as subst_axis_names_var,
  substitute_vars_in_output_ty as substitute_vars_in_output_ty,
  thread_local_state as thread_local_state,
  token as token,
  trace_state_clean as trace_state_clean,
  traverse_jaxpr_params as traverse_jaxpr_params,
  typecheck as typecheck,
9 changes: 6 additions & 3 deletions tests/api_test.py
@@ -673,8 +673,12 @@ def noop(arr, token):

    arr = jnp.ones(10)
    token = jax.lax.create_token()
    _, out_token = noop(arr, token)

    self.assertEqual(token, noop(arr, token)[1])
    self.assertIsInstance(token, core.Token)
    self.assertIsInstance(out_token, core.Token)
    # Different token objects.
    self.assertIsNot(token, out_token)

  def test_jit_bad_input(self):
    def f(x):
@@ -1226,7 +1230,6 @@ def f(x, y, *args, **kwargs):
    for s in ("\"x\"", "y['hi']", "args[0]", "args[1]", "kwargs['z']", "kwargs['w']"):
      self.assertNotIn(s, hlo_str)


  @parameterized.parameters([0, 2, [(0, 2)]])
  def test_jit_lower_arg_info_static_argnums(self, static_argnums):
    def f(x, y, *args, **kwargs):
@@ -3732,7 +3735,7 @@ def test_jit_returning_token(self):
    self.assertIsInstance(x, core.Token)

  def test_jit_capturing_token(self):
    tok = core.token
    tok = jax.lax.create_token()
    _, y = jax.jit(lambda x: (x + 2, tok))(7)
    self.assertIsInstance(y, core.Token)

