
Commit 405a231

Implement pjit fast path in cpp for jax.Array inputs
PiperOrigin-RevId: 475988677
cky9301 authored and jax authors committed Sep 22, 2022
1 parent 52476d1 commit 405a231
Showing 5 changed files with 206 additions and 15 deletions.
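
Before the per-file diffs, a minimal sketch of how the new fast path is meant to be exercised end to end. It mirrors the benchmark and test code below; the mesh shape, array contents, and the lambda are illustrative only, and it assumes a jaxlib new enough for the C++ path (xc._version >= 95) plus the jax.Array mode that the fast path requires:

import numpy as np
from jax.config import config
from jax._src import test_util as jtu
from jax.experimental import array
from jax.experimental import sharding
from jax.experimental import pjit as pjit_lib

# Opt in to the experimental behavior this change targets.
config.update('jax_array', True)               # jax.Array inputs/outputs
config.update('experimental_cpp_pjit', True)   # the new C++ dispatch path

mesh = jtu.create_global_mesh((1,), ('x',))
s = sharding.MeshPspecSharding(mesh, pjit_lib.PartitionSpec('x'))

inp = np.arange(8, dtype=np.float32)
x = array.make_array_from_callback(inp.shape, s, lambda idx: inp[idx])

f = pjit_lib.pjit(lambda a: a + 1, in_axis_resources=s, out_axis_resources=s)
f(x)   # first call traces, compiles, and records the fast-path data
f(x)   # later calls with jax.Array inputs are eligible for C++ dispatch

The first call still goes through the Python path (and compiles); only subsequent calls with jax.Array inputs can bypass it, which is why the benchmark below does one warm-up call before the timed loop.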
69 changes: 69 additions & 0 deletions benchmarks/api_benchmark.py
@@ -26,6 +26,7 @@
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src.lib import xla_client as xc
from jax.interpreters import pxla
from jax.experimental import array
from jax.experimental import sharding
from jax.experimental import pjit as pjit_lib
import jax.numpy as jnp
@@ -628,5 +629,73 @@ def bench_slicing_compilation2(state):
jax.jit(lambda x: (x[:1], x[1:2], x[2:3])).lower(x).compile()


def pjit_simple_benchmark(state, num_devices, num_args, cpp_jit):
spec = pjit_lib.PartitionSpec('x')
mesh = jtu.create_global_mesh((num_devices,), ('x',))
s = sharding.MeshPspecSharding(mesh, spec)
inp_data = np.arange(num_devices).astype(np.float32)
x = array.make_array_from_callback(inp_data.shape, s, lambda idx: inp_data[idx])

x = [x for _ in range(num_args)]

prev_state = jax_config.FLAGS.experimental_cpp_pjit
jax_config.FLAGS.experimental_cpp_pjit = cpp_jit

in_axis_resources = sharding.MeshPspecSharding(mesh, spec)
out_axis_resources = sharding.MeshPspecSharding(mesh, spec)

f = pjit_lib.pjit(
lambda x: jax.tree_map(lambda x: x + 1, x),
in_axis_resources=in_axis_resources,
out_axis_resources=out_axis_resources)

x = f(x)

while state:
x = f(x)

jax_config.FLAGS.experimental_cpp_pjit = prev_state


@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_1_device(state):
pjit_simple_benchmark(
state, num_devices=1, num_args=state.range(0), cpp_jit=state.range(1))

@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4_device(state):
pjit_simple_benchmark(
state, num_devices=4, num_args=state.range(0), cpp_jit=state.range(1))

@google_benchmark.register
@google_benchmark.option.arg_names(['num_args', 'cpp_pjit'])
@google_benchmark.option.args([1, False])
@google_benchmark.option.args([1, True])
@google_benchmark.option.args([10, False])
@google_benchmark.option.args([10, True])
@google_benchmark.option.args([100, False])
@google_benchmark.option.args([100, True])
@jax_config.jax_array(True)
def pjit_simple_4000_device(state):
pjit_simple_benchmark(
state, num_devices=4000, num_args=state.range(0), cpp_jit=state.range(1))


if __name__ == "__main__":
google_benchmark.main()
5 changes: 5 additions & 0 deletions jax/_src/api.py
@@ -123,6 +123,11 @@
"A flag enabling the C++ jax.pmap fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")
flags.DEFINE_bool(
"experimental_cpp_pjit", bool_env("JAX_CPP_PJIT", False),
"A flag enabling the C++ pjit fast path. Until the default "
"is switched to True, the feature is not supported and possibly broken "
"(e.g. it may use unreleased code from jaxlib.")


def _nan_check_posthook(fun, args, kwargs, output):
93 changes: 81 additions & 12 deletions jax/experimental/pjit.py
@@ -16,9 +16,10 @@
from enum import IntEnum
import numpy as np
from collections import OrderedDict, Counter
from typing import Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable
from typing import Any, Callable, Sequence, Tuple, Union, cast, List, Optional, Iterable, NamedTuple
import itertools as it
from functools import partial, lru_cache
import threading

from jax.experimental import maps
from jax.experimental.global_device_array import GlobalDeviceArray as GDA
@@ -28,7 +29,7 @@
from jax import core
from jax import linear_util as lu
from jax import stages
from jax._src.api import _check_callable, _check_arg, local_devices
from jax._src.api import _check_callable, _check_arg, local_devices, FLAGS
from jax._src.config import config
from jax._src import dispatch
from jax._src import source_info_util
@@ -122,6 +123,73 @@ def _check_all_or_none_unspecified(axis_resources, name):
'`pjit._UNSPECIFIED`.')
return unspecified

def _python_pjit_helper(infer_params, *args, **kwargs):
args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
for arg in args_flat:
_check_arg(arg)
out_flat = pjit_p.bind(*args_flat, **params)
outs = tree_unflatten(out_tree, out_flat)
return outs, out_flat, out_tree

def _python_pjit(fun: Callable, infer_params):

@wraps(fun)
def wrapped(*args, **kwargs):
return _python_pjit_helper(infer_params, *args, **kwargs)[0]

return wrapped

class _PjitFastpathData(NamedTuple):
xla_executable: xla.XlaExecutable
out_pytree_def: Any
in_shardings: Sequence[Any]
out_shardings: Sequence[Any]
out_avals: Sequence[Any]
out_committed: Sequence[bool]

class _MostRecentPjitCallExecutable(threading.local):
def __init__(self):
self.value = None

_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()

def _cpp_pjit(fun: Callable, infer_params, static_argnums):

def cache_miss(*args, **kwargs):
global _most_recent_pjit_call_executable

outs, out_flat, out_tree = _python_pjit_helper(infer_params, *args, **kwargs)

executable = _most_recent_pjit_call_executable.value
_most_recent_pjit_call_executable.value = None

use_fastpath = (
executable is not None and
isinstance(executable, pxla.MeshExecutable) and
isinstance(executable.unsafe_call, pxla.ExecuteReplicated) and
not executable.unsafe_call.has_unordered_effects and
not executable.unsafe_call.has_host_callbacks and
all(isinstance(x, xc.Array) for x in out_flat)
)

if use_fastpath:
out_avals = [o.aval for o in out_flat]
out_committed = [o._committed for o in out_flat]
fastpath_data = _PjitFastpathData(executable.xla_executable,
out_tree,
executable._in_shardings,
executable._out_shardings, out_avals,
out_committed)
else:
fastpath_data = None


return outs, fastpath_data

cpp_pjit_f = xc._xla.pjit(fun, cache_miss, static_argnums)

return wraps(fun)(cpp_pjit_f)


# TODO(yashkatariya): Add pjit microbenchmarks.
# in_axis_resources and out_axis_resources can't be None as the default value
@@ -359,13 +427,10 @@ def infer_params(*args, _global_avals=False, **kwargs):
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)

@wraps(fun)
def wrapped(*args, **kwargs):
args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
for arg in args_flat:
_check_arg(arg)
out = pjit_p.bind(*args_flat, **params)
return tree_unflatten(out_tree, out)
if FLAGS.experimental_cpp_pjit and xc._version >= 95:
wrapped = _cpp_pjit(fun, infer_params, static_argnums)
else:
wrapped = _python_pjit(fun, infer_params)

def lower(*args, _global_avals=False, **kwargs):
(_, flat_local_in_avals, params, in_tree, out_tree,
@@ -838,6 +903,9 @@ def _pjit_call_impl(*args, jaxpr,
in_shardings, out_shardings, resource_env,
donated_invars, name,
in_positional_semantics, out_positional_semantics):

global _most_recent_pjit_call_executable

if config.jax_array:
in_shardings = _resolve_in_shardings(args, in_shardings, out_shardings,
resource_env.physical_mesh)
@@ -851,6 +919,7 @@
jaxpr, in_shardings, out_shardings, resource_env,
donated_invars, name, in_is_global).compile(
_allow_propagation_to_outputs=_allow_propagation_to_outputs)
_most_recent_pjit_call_executable.value = compiled
# This check is expensive so only do it if enable_checks is on.
if compiled._auto_spmd_lowering and config.jax_enable_checks:
pxla._check_gda_or_array_xla_sharding_match(args, compiled._in_shardings)
@@ -880,7 +949,7 @@ class SameDeviceAssignmentTuple:
device_assignment: Optional[XLADeviceAssignment]

def __hash__(self):
shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s
shardings_hash = tuple(s._op_sharding_hash if isinstance(s, OpShardingSharding) else s # type: ignore
for s in self.shardings)
if self.device_assignment is None:
return hash(shardings_hash)
@@ -935,14 +1004,14 @@ def _pjit_lower_cached(
in_shardings: Tuple[MeshShardingMinusUnspecified, ...] = cast( # type:ignore[no-redef]
Tuple[MeshShardingMinusUnspecified, ...], tuple(
MeshPspecSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0])
mesh, parse_flatten_op_sharding(i._op_sharding, mesh)[0]) # type: ignore
if isinstance(i, OpShardingSharding) else i
for i in in_shardings
))
out_shardings: Tuple[MeshSharding, ...] = cast( # type: ignore[no-redef]
Tuple[MeshSharding, ...], tuple(
MeshPspecSharding._from_parsed_pspec(
mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0])
mesh, parse_flatten_op_sharding(o._op_sharding, mesh)[0]) # type: ignore
if isinstance(o, OpShardingSharding) else o
for o in out_shardings
))
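
The handshake that makes the hunks above work may not be obvious from the diff alone: _cpp_pjit hands xc._xla.pjit a cache_miss callback that runs the ordinary Python path, picks up the executable that _pjit_call_impl just stashed in the thread-local _most_recent_pjit_call_executable, and returns (outs, fastpath_data); a non-None _PjitFastpathData tells the C++ dispatcher it can replay the cached executable on later calls. The toy class below is a pure-Python illustration of that contract only, not of what xc._xla.pjit actually does internally:

from typing import Any, Callable, Optional


class ToyCppPjit:
  """Pure-Python stand-in for the C++ dispatcher wired up by xc._xla.pjit.

  It models only the contract visible in this diff: cache_miss returns
  (outs, fastpath_data); a non-None fastpath_data is cached and would let
  later calls bypass Python entirely.
  """

  def __init__(self, cache_miss: Callable[..., Any]):
    self._cache_miss = cache_miss
    self._fastpath: Optional[Any] = None  # would hold a _PjitFastpathData
    self.fastpath_hits = 0                # for illustration only

  def __call__(self, *args, **kwargs):
    if self._fastpath is not None:
      # The real dispatcher would run self._fastpath.xla_executable here and
      # rebuild outputs from out_pytree_def / out_avals / out_shardings /
      # out_committed without re-entering Python. This sketch only records
      # the hit and falls through so it stays runnable.
      self.fastpath_hits += 1
    outs, fastpath_data = self._cache_miss(*args, **kwargs)
    # cache_miss hands back None when the call is not eligible (unordered
    # effects, host callbacks, non-jax.Array outputs); nothing is cached then.
    if fastpath_data is not None:
      self._fastpath = fastpath_data
    return outs


# Toy usage: a "cache_miss" that never yields fastpath data.
toy = ToyCppPjit(lambda x: (x + 1, None))
assert toy(41) == 42

Recording the executable in a thread-local rather than threading it through pjit_p.bind keeps the primitive's signature unchanged, and, as the new test_concurrent_cpp_pjit exercises, keeps concurrent callers from observing each other's executables.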
13 changes: 10 additions & 3 deletions jax/experimental/sharding.py
@@ -157,27 +157,30 @@ def device_replica_id_map(sharding, global_shape: Shape) -> Mapping[Device, int]
return out


@pxla.use_cpp_class(xc.MeshPspecSharding if xc._version >= 95 else None)
class MeshPspecSharding(XLACompatibleSharding):

@pxla.use_cpp_method
def __init__(
self, mesh: pxla.Mesh, spec: pxla.PartitionSpec, _parsed_pspec = None):

self.mesh = mesh
self.spec = spec
self._parsed_pspec = _parsed_pspec
self._preprocess()

def _preprocess(self):
# This split exists because you can pass `_parsed_pspec` that has been
# modified from the original. For example: Adding extra dimension to
# axis_resources for vmap handlers. In such cases you need to preserve the
# `sync` attribute of parsed pspecs.
# PartitionSpec is inferred from the parsed pspec in this case.
# TODO(yaskatariya): Remove this and replace this with a normalized
# representation of Parsed Pspec
if _parsed_pspec is None:
if self._parsed_pspec is None:
from jax.experimental import pjit
self._parsed_pspec, _, _, _ = pjit._prepare_axis_resources(
self.spec, "MeshPspecSharding spec")
else:
self._parsed_pspec = _parsed_pspec

_check_mesh_resource_axis(self.mesh, self._parsed_pspec)

@@ -256,8 +259,10 @@ def _get_replicated_op_sharding():
return proto


@pxla.use_cpp_class(xc.SingleDeviceSharding if xc._version >= 95 else None)
class SingleDeviceSharding(XLACompatibleSharding):

@pxla.use_cpp_method
def __init__(self, device: Device):
self._device = device

@@ -349,8 +354,10 @@ def _hash_op_sharding(op: xc.OpSharding):
op.type, op.replicate_on_last_tile_dim, tuple(op.last_tile_dims)))


@pxla.use_cpp_class(xc.OpShardingSharding if xc._version >= 95 else None)
class OpShardingSharding(XLACompatibleSharding):

@pxla.use_cpp_method
def __init__(self, devices: Sequence[Device], op_sharding: xc.OpSharding):
self._devices = tuple(devices)
self._op_sharding = op_sharding
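
The use_cpp_class / use_cpp_method pair applied in this file lets jaxlib's C++ sharding types (xc.MeshPspecSharding, xc.SingleDeviceSharding, xc.OpShardingSharding) stand in for the Python classes when xc._version >= 95, while older jaxlibs keep the pure-Python definitions. The decorators below are a deliberately simplified, hypothetical illustration of that substitution pattern; the real implementations live in pxla and are more involved:

def use_cpp_class(cpp_cls):
  """Simplified illustration: if jaxlib ships a C++ implementation
  (cpp_cls is not None), expose it in place of the decorated Python class;
  otherwise keep the Python definition as the fallback."""
  def wrapper(py_cls):
    return py_cls if cpp_cls is None else cpp_cls
  return wrapper


def use_cpp_method(fn):
  """Companion marker: the decorated method has a C++ counterpart, so its
  Python body only matters for the pure-Python fallback class."""
  return fn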
41 changes: 41 additions & 0 deletions tests/pjit_test.py
@@ -24,6 +24,8 @@
from absl.testing import parameterized
import numpy as np

import concurrent.futures

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
@@ -2290,6 +2292,45 @@ def test_out_sharding_indices_id_cache_hit(self):
self.assertEqual(cache_info3.hits, cache_info2.hits + 1)


class ArrayCppPjitTest(ArrayPjitTest):

def setUp(self):
super().setUp()
self.jax_array = config.jax_array
self.cpp_pjit = config.FLAGS.experimental_cpp_pjit
config.update('experimental_cpp_pjit', True)
config.update('jax_array', True)

def tearDown(self):
config.update('experimental_cpp_pjit', self.cpp_pjit)
config.update('jax_array', self.jax_array)
super().tearDown()

def test_concurrent_cpp_pjit(self):
global_mesh = jtu.create_global_mesh((1,), ('x',))
sharding = MeshPspecSharding(global_mesh, P('x',))
n = 10
with global_mesh:
fs = [pjit(lambda x, i: x + i, static_argnums=1) for _ in range(n)]

def _invoke_with_mesh_twice(arg_tuple):
f, x, i = arg_tuple
with global_mesh:
f(x, i)
return f(x, i)

xs = [
array.make_array_from_callback(
(i,), sharding, lambda idx: np.arange(i, dtype=np.float32))
for i in range(n)
]
with concurrent.futures.ThreadPoolExecutor() as executor:
ys = executor.map(_invoke_with_mesh_twice,
[(fs[i], x, i) for i, x in enumerate(xs)])
for i, x, y in zip(range(n), xs, ys):
self.assertAllClose(x + i, y)


class TempSharding(Sharding):

def __init__(self, devices):
