Skip to content

Commit

Permalink
Add inspect_array_sharding, enabling looking at shardings in pjit-ted…
Browse files Browse the repository at this point in the history
… functions

PiperOrigin-RevId: 476237731
  • Loading branch information
sharadmv authored and jax authors committed Sep 23, 2022
1 parent 11a6fd9 commit 805073f
Show file tree
Hide file tree
Showing 5 changed files with 274 additions and 6 deletions.
207 changes: 202 additions & 5 deletions jax/_src/debugging.py
Expand Up @@ -16,24 +16,33 @@
import functools
import string
import sys
import weakref

from typing import Any, Dict, Callable, Sequence, Set, Tuple, Union

from jax import core
from jax import tree_util
from jax import lax
from jax._src import ad_checkpoint
from jax._src import custom_derivatives
from jax._src import lib as jaxlib
from jax._src import util
from jax import linear_util as lu
from jax.config import config
from jax.experimental.sharding import Sharding
from jax.experimental.sharding import OpShardingSharding
from jax.experimental import pjit
from jax.interpreters import ad
from jax.interpreters import batching
from jax.interpreters import mlir
from jax.interpreters import partial_eval as pe
from jax.interpreters import pxla
from jax.interpreters import xla
from jax._src import ad_checkpoint
from jax._src import custom_derivatives
from jax._src import lib as jaxlib
from jax._src import source_info_util
from jax._src import util
from jax._src.lax import control_flow as lcf
from jax._src.lib import xla_client as xc
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import mhlo
import jax.numpy as jnp

# pytype: disable=import-error
Expand Down Expand Up @@ -246,6 +255,139 @@ def debug_print(fmt: str, *args, ordered: bool = False, **kwargs) -> None:
debug_callback(functools.partial(_format_print_callback, fmt), *args,
**kwargs, ordered=ordered)


# Sharding visualization

inspect_sharding_p = core.Primitive("inspect_sharding")
inspect_sharding_p.multiple_results = True

def _inspect_sharding_impl(value, *, callback):
if not config.jax_array:
raise NotImplementedError("`inspect_sharding` not implemented.")
callback(value.sharding)
return []
inspect_sharding_p.def_impl(_inspect_sharding_impl)

def _inspect_sharding_abstract_eval(aval, **_):
del aval
# Effectful abstract avoids DCE
return [], {DebugEffect.PRINT}
inspect_sharding_p.def_effectful_abstract_eval(_inspect_sharding_abstract_eval)

def _inspect_sharding_batching_rule(args, _, *, callback):
value, = args
inspect_sharding_p.bind(value, callback=callback)
return [], []
batching.primitive_batchers[inspect_sharding_p] = (
_inspect_sharding_batching_rule)

def _inspect_sharding_jvp_rule(primals, _, **params):
return inspect_sharding_p.bind(*primals, **params)
ad.primitive_jvps[inspect_sharding_p] = _inspect_sharding_jvp_rule

sharding_callbacks = weakref.WeakValueDictionary() # type: ignore
_INSPECT_SHARDING_CALL_NAME = "InspectSharding"

class ShardingCallbackInfo:
def __init__(self, callback, module_context):
self.callback = callback
self.module_context = module_context

def _inspect_sharding_lowering_rule(ctx: mlir.LoweringRuleContext, value, *,
callback):

mesh = pxla.thread_resources.env.physical_mesh
axis_context = ctx.module_context.axis_context

if isinstance(axis_context, mlir.ShardingContext):
devices = axis_context.device_assignment
elif isinstance(axis_context, mlir.SPMDAxisContext):
devices = list(axis_context.mesh.devices.flat)
else:
raise NotImplementedError(type(axis_context))

def _op_sharding_callback(op_sharding: xc.OpSharding):
if mesh.empty:
return callback(OpShardingSharding(
devices, op_sharding))
pspec = pjit.parse_flatten_op_sharding(
op_sharding, mesh)[0].get_partition_spec()
return callback(pjit.MeshPspecSharding(mesh, pspec))

if len(devices) == 1:
# If we only have one device in our computation, we can construct a trivial
# OpSharding and call it right now.
trivial_sharding = xc.OpSharding()
trivial_sharding.type = xc.OpSharding.Type.REPLICATED
_op_sharding_callback(trivial_sharding)
return []

# If we have a nontrivial parallel computation, we need to wait until the SPMD
# partitioner calls back with the `HloSharding.
def _hlo_sharding_callback(hlo_sharding):
op_sharding = hlo_sharding.to_proto()
return _op_sharding_callback(op_sharding)

# Here we store information in a container that we store globally so the
# custom partitioning code can access it.
sharding_callback_info = ShardingCallbackInfo(_hlo_sharding_callback,
ctx.module_context)
key = str(id(sharding_callback_info))
sharding_callbacks[key] = sharding_callback_info
# We need to make sure `sharding_callback_info` is still alive when the SPMD
# partitioner runs so we keep it alive by attaching it to the executable.
ctx.module_context.add_keepalive(sharding_callback_info)

mhlo.CustomCallOp([ir.TupleType.get_tuple([])], [value],
call_target_name=ir.StringAttr.get(
_INSPECT_SHARDING_CALL_NAME),
has_side_effect=ir.BoolAttr.get(True),
api_version=mlir.i32_attr(1),
called_computations=ir.ArrayAttr.get([]),
backend_config=ir.StringAttr.get(key),
operand_layouts=None,
result_layouts=None)
return []
mlir.register_lowering(inspect_sharding_p, _inspect_sharding_lowering_rule)

def inspect_sharding_prop_user_sharding(sharding, backend_string):
del sharding, backend_string
return []

def inspect_sharding_partition(shapes, arg_shardings, result_shape,
result_sharding, backend_string):
del result_shape, result_sharding
sharding_callback_info = sharding_callbacks[backend_string]
sharding_callback = sharding_callback_info.callback
module_context = sharding_callback_info.module_context

# Execute callback
hlo_sharding, = arg_shardings
sharding_callback(hlo_sharding)

tiled_args = [p.tile(s) for s, p in zip(shapes, arg_shardings)]
in_avals = [core.ShapedArray(arg.dimensions(), arg.numpy_dtype())
for arg in tiled_args]
fun = lu.wrap_init(lambda *args: [])
jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
trivial_comp = mlir.build_xla_computation_helper(closed_jaxpr,
name="tmp_xla_computation", platform=module_context.platform,
backend_or_name=module_context.backend_or_name,
axis_context=module_context.axis_context)
return trivial_comp, arg_shardings, arg_shardings[0]

def inspect_sharding_infer_sharding_from_operands(arg_shapes, arg_shardings,
shape, backend_string):
del arg_shapes, shape, backend_string
return arg_shardings[0]

if jaxlib.xla_extension_version >= 94:
xc.register_custom_call_partitioner( # pytype: disable=module-attr
_INSPECT_SHARDING_CALL_NAME, inspect_sharding_prop_user_sharding,
inspect_sharding_partition, inspect_sharding_infer_sharding_from_operands,
True)

def _slice_to_chunk_idx(size: int, slc: slice) -> int:
if slc.stop == slc.start == None:
return 0
Expand Down Expand Up @@ -339,8 +481,63 @@ def visualize_sharding(shape: Sequence[int], sharding: Sharding, *,
table.add_row(*col)
console.print(table, end='\n\n')

def inspect_array_sharding(value, *, callback: Callable[[Sharding], None]):
"""Enables inspecting array sharding inside JIT-ted functions.
This function, when provided with a Pytree of arrays, calls back with each of
their shardings and works in ``pjit``-ted computations, enabling inspecting
the chosen intermediate shardings.
The policy for when ``callback`` is called is *as early as possible* when the
sharding information is available. This means if ``inspect_array_callback`` is
called without any transformations, The callback will happen immediately
since we have the array and its sharding readily available. Inside of a
``jax.jit``, the callback will happen at lowering time, meaning you can
trigger the callback using the AOT API( ``jit(f).lower(...)``). When inside of
a ``pjit``, the callback happens **at compile time** since the sharding is
determined by XLA. You can trigger the callback by using JAX's AOT API
(``pjit(f).lower(...).compile()``). In all cases, the callback will be
triggered by running the function, since running a function entails lowering
and compiling it first. However, once the function is compiled and cached,
the callback will no longer occur.
This function is experimental and its behavior may change in the future.
Args:
value: A Pytree of JAX arrays.
callback: A callable that takes in a `Sharding` and doesn't return a value.
In the following example, we print out the sharding of an intermediate value
in a ``pjit``-ted computation:
>>> import jax
>>> import jax.numpy as jnp
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental.pjit import PartitionSpec, pjit
>>>
>>> x = jnp.arange(8, dtype=jnp.float32)
>>> def f_(x):
... x = jnp.sin(x)
... jax.debug.inspect_array_sharding(x, callback=print)
... return jnp.square(x)
>>> f = pjit(f_, in_axis_resources=PartitionSpec('dev'),
... out_axis_resources=PartitionSpec('dev'))
>>> with Mesh(jax.devices(), ('dev',)):
... f.lower(x).compile() # doctest: +SKIP
...
MeshPspecSharding(mesh={'dev': 8}, partition_spec=PartitionSpec(('dev',),))
"""
if jaxlib.xla_extension_version < 94:
raise NotImplementedError("`inspect_array_sharding` not implemented. "
"Please upgrade `jaxlib` to the latest version.")
def _inspect(val):
inspect_sharding_p.bind(val, callback=callback)
tree_util.tree_map(_inspect, value)

def visualize_array_sharding(arr, **kwargs):
"""Visualizes an array's sharding."""
if not config.jax_array:
raise NotImplementedError("`visualize_array_sharding` not implemented.")
return visualize_sharding(arr.shape, arr.sharding, **kwargs)
def _visualize(sharding):
return visualize_sharding(arr.shape, sharding, **kwargs)
inspect_array_sharding(arr, callback=_visualize)
1 change: 1 addition & 0 deletions jax/debug.py
Expand Up @@ -15,5 +15,6 @@
from jax._src.debugging import debug_print as print
from jax._src.debugging import DebugEffect
from jax._src.debugging import visualize_array_sharding
from jax._src.debugging import inspect_array_sharding
from jax._src.debugging import visualize_sharding
from jax._src.debugger import breakpoint
1 change: 1 addition & 0 deletions jax/experimental/jax2tf/jax2tf.py
Expand Up @@ -1208,6 +1208,7 @@ def _unexpected_primitive(p: core.Primitive, *args, **kwargs):
"full_to_shard",
"shard_to_full",
"pure_callback",
"inspect_sharding",

# Not high priority?
"after_all",
Expand Down
15 changes: 15 additions & 0 deletions jax/interpreters/mlir.py
Expand Up @@ -1641,6 +1641,21 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function
token, *results = results
return results, token, keepalive

def build_xla_computation_helper(
closed_jaxpr: core.ClosedJaxpr, *, name: str, platform: str,
backend_or_name: str, axis_context: AxisContext) -> xc.XlaComputation:
"""Helper to generate pmap-style XLA computations for custom partitioners."""
if closed_jaxpr.effects:
raise NotImplementedError
lowering_result = lower_jaxpr_to_module(name, closed_jaxpr,
backend_or_name=backend_or_name, unordered_effects=[], ordered_effects=[],
name_stack=source_info_util.NameStack(),
donated_args=[False] * len(closed_jaxpr.jaxpr.invars),
axis_context=axis_context, platform=platform)
return xc._xla.mlir.mlir_module_to_xla_computation(
module_to_string(lowering_result.module), use_tuple_args=False,
return_tuple=False)

# Lax ops missing MLIR lowerings.
# # TODO(b/203775215): these are missing from the cHLO dialect. Either add
# # them or port them to Python.
Expand Down
56 changes: 55 additions & 1 deletion tests/debugging_primitives_test.py
Expand Up @@ -694,7 +694,7 @@ def f(x):
else:
spec = pjit.PartitionSpec('dev')
out_spec = pjit.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=spec)
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
with mesh:
with jtu.capture_stdout() as output:
f(np.arange(8, dtype=jnp.int32))
Expand Down Expand Up @@ -1049,6 +1049,60 @@ def test_visualize_pmap_sharding(self):
""")
self.assertEqual(output(), expected)

class InspectShardingTest(jtu.JaxTestCase):

def test_inspect_sharding_is_called_in_pjit(self):

if jaxlib.xla_extension_version < 94:
raise unittest.SkipTest("Inspect sharding not supported.")

is_called = False
def _cb(sd):
nonlocal is_called
is_called = True
self.assertIsInstance(sd, sharding.Sharding)
self.assertLen(sd.device_set, len(jax.devices()))

def f(x):
debugging.inspect_array_sharding(x, callback=_cb)
return jnp.square(x)

mesh = maps.Mesh(np.array(jax.devices()), ['dev'])
if config.jax_array:
spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec('dev'))
out_spec = sharding.MeshPspecSharding(mesh, pjit.PartitionSpec())
else:
spec = pjit.PartitionSpec('dev')
out_spec = pjit.PartitionSpec()
f = pjit.pjit(f, in_axis_resources=spec, out_axis_resources=out_spec)
with mesh:
f(np.arange(8, dtype=jnp.int32))
self.assertTrue(is_called)

def test_inspect_sharding_is_called_in_jit(self):

if jaxlib.xla_extension_version < 94:
raise unittest.SkipTest("Inspect sharding not supported.")

if not config.jax_array:
raise unittest.SkipTest("jax_array to work inside of `jit`.")

is_called = False
def _cb(sd):
nonlocal is_called
is_called = True
self.assertIsInstance(sd, sharding.Sharding)
self.assertLen(sd.device_set, 1)

def f(x):
debugging.inspect_array_sharding(x, callback=_cb)
return jnp.square(x)

f = jax.jit(f)
f(np.arange(8, dtype=jnp.int32))
self.assertTrue(is_called)


if not rich:
del VisualizeShardingTest

Expand Down

0 comments on commit 805073f

Please sign in to comment.