Replace apply_primitive internals with jax.jit.
This allows a lot of code to be deleted and yields a ~40% speedup in eager (op-by-op) execution.
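
In outline: each primitive's eager impl used to hand-roll avals, shardings, jaxpr tracing, lowering, and its own C++ fast path; it now simply wraps prim.bind in a cached jax.jit. A simplified sketch, condensed from the dispatch.py diff below (the real code also handles disable_jit and caches xla_primitive_callable with util.cache()):

```
import jax

def xla_primitive_callable(prim, **params):
  # Build a jit-compiled callable that just binds the primitive.
  def prim_fun(*args):
    return prim.bind(*args, **params)
  prim_fun.__name__ = prim.name
  prim_fun.__qualname__ = prim.name
  return jax.jit(prim_fun)

def apply_primitive(prim, *args, **params):
  """Impl rule: compile and run a single primitive by reusing the jit path."""
  return xla_primitive_callable(prim, **params)(*args)
```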

Benchmarks:

```
name                                                      old time/op          new time/op          delta
eager_unary_dispatch                                      31.3µs ± 1%          19.4µs ± 6%  -37.91%    (p=0.016 n=4+5)
eager_unary                                               32.1µs ± 0%          19.8µs ± 4%  -38.26%    (p=0.016 n=4+5)
eager_binary_dispatch                                     35.9µs ± 1%          20.5µs ± 4%  -42.93%    (p=0.016 n=4+5)
eager_binary                                              36.6µs ± 1%          21.1µs ± 4%  -42.29%    (p=0.016 n=4+5)
jit_trivial_dispatch                                      3.87µs ± 2%          4.12µs ±25%     ~       (p=1.000 n=5+5)
jit_trivial                                               4.75µs ± 2%          4.82µs ±11%     ~       (p=0.690 n=5+5)
jit_simple_dispatch                                       2.95µs ± 2%          2.97µs ± 7%     ~       (p=1.000 n=5+5)
jit_simple                                                3.52µs ± 6%          3.51µs ± 5%     ~       (p=0.841 n=5+5)
jit_simple_dispatch_array                                 2.95µs ± 2%          2.96µs ± 6%     ~       (p=1.000 n=5+5)
jit_simple_array                                          3.46µs ± 2%          3.51µs ± 5%     ~       (p=0.690 n=5+5)
jit_small_matmul                                          3.01µs ± 1%          3.00µs ± 4%     ~       (p=0.548 n=5+5)
jit_big_matmul                                            34.0µs ±18%          35.5µs ±17%     ~       (p=0.310 n=5+5)
jit_simple_many_args_dispatch/num_args:10                 6.93µs ± 6%          6.80µs ± 6%     ~     (p=0.481 n=10+10)
jit_simple_many_args_dispatch/num_args:100                47.7µs ± 7%          45.4µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args_dispatch/num_args:1000                545µs ± 8%           516µs ± 2%     ~      (p=0.101 n=10+8)
jit_simple_many_args_dispatch/num_args:2000               1.12ms ± 7%          1.07ms ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:10                          7.42µs ± 5%          7.23µs ± 2%     ~      (p=0.173 n=10+8)
jit_simple_many_args/num_args:100                         48.4µs ± 7%          45.6µs ± 2%     ~      (p=0.237 n=10+8)
jit_simple_many_args/num_args:1000                         542µs ± 6%           524µs ± 8%     ~     (p=0.089 n=10+10)
jit_simple_many_args/num_args:2000                        1.12ms ± 7%          1.08ms ± 1%     ~      (p=0.068 n=10+8)
jit_simple_pruned_args_dispatch_10                        4.79µs ± 8%          4.98µs ±10%     ~       (p=0.421 n=5+5)
jit_simple_pruned_args_10                                 5.32µs ± 6%          5.30µs ± 4%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_100                       24.7µs ± 6%          23.8µs ± 8%     ~       (p=0.548 n=5+5)
jit_simple_pruned_args_100                                25.2µs ± 6%          24.4µs ± 8%     ~       (p=0.690 n=5+5)
jit_simple_pruned_args_dispatch_1000                       238µs ± 7%           232µs ± 8%     ~       (p=0.841 n=5+5)
jit_simple_pruned_args_1000                                240µs ± 7%           234µs ± 8%     ~       (p=1.000 n=5+5)
jit_simple_pruned_args_dispatch_2000                       516µs ± 6%           497µs ± 1%     ~       (p=0.413 n=5+4)
jit_simple_pruned_args_2000                                517µs ± 6%           505µs ± 7%     ~       (p=0.690 n=5+5)
jit_dispatch_without_transfer                              719µs ± 9%           751µs ± 8%     ~       (p=0.222 n=5+5)
jit_dispatch_with_transfer                                 799µs ±14%           793µs ± 9%     ~       (p=1.000 n=5+5)
pmap_trivial_2_devices                                    49.9µs ±40%          48.2µs ±42%     ~       (p=0.841 n=5+5)
pmap_trivial_dispatch_8_devices                           74.5µs ±24%          78.9µs ±29%     ~       (p=0.421 n=5+5)
pmap_trivial_8_devices                                    79.3µs ± 6%          82.7µs ±20%     ~       (p=0.841 n=5+5)
pmap_simple_2_devices                                     47.1µs ±17%          49.1µs ±20%     ~       (p=0.548 n=5+5)
pmap_simple_dispatch_8_devices                            73.4µs ±16%          76.8µs ±21%     ~       (p=0.690 n=5+5)
pmap_simple_8_devices                                     76.0µs ±10%          80.6µs ±29%     ~       (p=1.000 n=5+5)
pmap_simple_dispatch_8_devices_100_args                   1.12ms ±22%          1.08ms ±42%     ~       (p=0.841 n=5+5)
pmap_simple_8_devices_100_args                            12.5ms ± 8%          12.8ms ±10%     ~       (p=1.000 n=5+5)
sda_index_1                                                413µs ± 1%           686µs ± 4%  +66.08%    (p=0.008 n=5+5)
sda_index_2                                                850µs ± 1%          1378µs ± 4%  +62.02%    (p=0.008 n=5+5)
sda_index_8                                               3.60ms ± 1%          5.69ms ± 4%  +58.00%    (p=0.008 n=5+5)
bench_shaped_abstractify                                   300µs ± 1%           305µs ± 3%     ~       (p=0.056 n=5+5)
bench_xla_abstractify_scalar_int                          6.45µs ± 1%          6.50µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_scalar_float                        3.73µs ± 1%          3.73µs ± 3%     ~       (p=0.690 n=5+5)
bench_xla_abstractify_scalar_numpy_int32                  4.97µs ± 1%          4.83µs ± 3%     ~       (p=0.095 n=5+5)
bench_xla_abstractify_scalar_numpy_uint32                 4.91µs ± 1%          4.75µs ± 0%   -3.30%    (p=0.016 n=5+4)
bench_xla_abstractify_numpy_random                        4.34µs ± 2%          4.31µs ± 3%     ~       (p=0.310 n=5+5)
bench_xla_abstractify_numpy_arange_100_float32            3.94µs ± 1%          3.93µs ± 3%     ~       (p=0.548 n=5+5)
bench_xla_abstractify_enum                                6.85µs ± 1%          7.06µs ± 7%   +3.07%    (p=0.032 n=5+5)
bench_are_op_shardings_equal                              26.9µs ± 2%          27.0µs ± 3%     ~       (p=0.841 n=5+5)
bench_pjit_check_aval_sharding                             691µs ± 2%           711µs ±13%     ~       (p=0.841 n=5+5)
bench_addressable_shards_index                             656ns ± 4%           688ns ± 9%     ~       (p=0.095 n=5+5)
bench_remat_eager_retracing_overheads                     12.7ms ± 4%          10.7ms ± 1%  -15.48%    (p=0.016 n=5+4)
bench_remat_eager_retracing_overheads_static_argnums      13.0ms ± 2%          11.3ms ± 6%  -13.71%    (p=0.008 n=5+5)
bench_slicing_compilation                                 12.1ms ± 1%          12.3ms ± 4%     ~       (p=0.690 n=5+5)
bench_slicing_compilation2                                11.3ms ± 0%          11.5ms ± 6%     ~       (p=0.690 n=5+5)
bench_repeated_static_indexing                            62.5ms ± 2%          40.8ms ± 8%  -34.77%    (p=0.008 n=5+5)
bench_repeated_static_slicing                             46.7ms ± 1%          31.4ms ± 2%  -32.76%    (p=0.008 n=5+5)
pjit_simple_1_device/num_args:1                           2.72µs ± 2%          2.68µs ± 5%     ~       (p=0.151 n=5+5)
pjit_simple_1_device/num_args:10                          12.6µs ± 7%          12.3µs ± 3%     ~       (p=0.310 n=5+5)
pjit_simple_1_device/num_args:100                          109µs ± 3%           108µs ± 4%     ~       (p=0.548 n=5+5)
pjit_simple_4_device/num_args:1                           38.0µs ±26%          36.8µs ±19%     ~       (p=0.690 n=5+5)
pjit_simple_4_device/num_args:10                          93.3µs ±19%          96.6µs ±23%     ~       (p=0.841 n=5+5)
pjit_simple_4_device/num_args:100                          730µs ±16%           698µs ±48%     ~       (p=0.841 n=5+5)
pjit_aot_1_device/num_args:1                              3.29µs ± 2%          3.12µs ± 4%   -5.24%    (p=0.016 n=4+5)
pjit_aot_1_device/num_args:10                             13.0µs ± 1%          12.7µs ± 2%     ~       (p=0.063 n=4+5)
pjit_aot_1_device/num_args:100                             111µs ± 5%           110µs ±11%     ~       (p=0.421 n=5+5)
pjit_aot_4_device/num_args:1                              38.4µs ±19%          38.9µs ±24%     ~       (p=1.000 n=5+5)
pjit_aot_4_device/num_args:10                             91.3µs ±15%          96.9µs ±29%     ~       (p=0.548 n=5+5)
pjit_aot_4_device/num_args:100                             676µs ±20%           689µs ±41%     ~       (p=0.841 n=5+5)
host_local_array_to_global_array                           196µs ± 6%           194µs ± 4%     ~       (p=0.548 n=5+5)
device_put                                                50.8µs ± 1%          50.7µs ± 4%     ~       (p=0.413 n=4+5)
device_put_sharded                                         176µs ± 0%           177µs ± 4%     ~       (p=0.190 n=4+5)
device_get_8_devices                                      3.96ms ± 4%          4.03ms ± 7%     ~       (p=0.413 n=4+5)
np_asarray_8_devices                                      3.34ms ±18%          3.30ms ±10%     ~       (p=0.548 n=5+5)
jax_array_arrays_8_devices                                5.01ms ±10%          5.09ms ±21%     ~       (p=0.421 n=5+5)
batch_inplace_while_scatter                                440µs ± 1%           439µs ± 1%     ~       (p=0.421 n=5+5)
batch_inplace_while_dynamic_update_slice                   454µs ± 0%           457µs ± 1%     ~       (p=0.905 n=4+5)
serial_dot_products                                       4.51µs ± 3%          4.41µs ± 2%     ~       (p=0.151 n=5+5)
bench_make_array_from_callback_fully_replicated_sharding  26.6µs ± 1%          27.0µs ± 2%     ~       (p=0.056 n=5+5)
```
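
The table is in benchstat format: old vs. new time per op with relative variation over n runs; '~' marks deltas that are not statistically significant. As a rough illustration (names and structure assumed, not taken from JAX's benchmark suite), an eager-dispatch microbenchmark like eager_binary above can be written with google_benchmark along these lines:

```
import google_benchmark as benchmark
import jax.numpy as jnp

@benchmark.register
def eager_binary(state):
  a = jnp.array(1.0)
  b = jnp.array(2.0)
  jnp.add(a, b).block_until_ready()  # warm up dispatch caches outside the loop
  while state:
    jnp.add(a, b).block_until_ready()

if __name__ == "__main__":
  benchmark.main()
```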

PiperOrigin-RevId: 586505950
yashk2810 authored and jax authors committed Nov 30, 2023
1 parent b6c73f8 commit e624610
Showing 8 changed files with 54 additions and 162 deletions.
1 change: 1 addition & 0 deletions jax/_src/core.py
@@ -1845,6 +1845,7 @@ def __init__(self, aval, data):
self._data = data
shape = property(lambda self: self._aval.shape)
dtype = property(lambda self: self._aval.dtype)
aval = property(lambda self: self._aval)
def __repr__(self) -> str:
if not self.shape and type(self.dtype) is bint:
# special-case scalar bints
136 changes: 22 additions & 114 deletions jax/_src/dispatch.py
@@ -18,7 +18,6 @@
import atexit
from collections.abc import Iterator, Sequence
import contextlib
import dataclasses
from functools import partial
import itertools
import time
@@ -32,10 +31,8 @@
from jax._src import basearray
from jax._src import config
from jax._src import core
from jax._src import api
from jax._src import dtypes
from jax._src import linear_util as lu
from jax._src import api_util
from jax._src import tree_util
from jax._src import source_info_util
from jax._src import traceback_util
from jax._src import util
@@ -45,14 +42,15 @@
from jax._src.interpreters import mlir
from jax._src.interpreters import xla
from jax._src.interpreters import pxla
from jax._src.interpreters import partial_eval as pe
from jax._src import lib
from jax._src.lib import xla_client as xc
from jax._src.lib import xla_extension_version
from jax._src.monitoring import record_event_duration_secs
from jax._src.partition_spec import PartitionSpec
from jax._src.sharding import Sharding
from jax._src.sharding_impls import (
PmapSharding, SingleDeviceSharding, NamedSharding, XLACompatibleSharding,
UNSPECIFIED, GSPMDSharding, TransferToMemoryKind)
GSPMDSharding, TransferToMemoryKind)


JAXPR_TRACE_EVENT = "/jax/core/compile/jaxpr_trace_duration"
@@ -78,119 +76,29 @@

### op-by-op execution

class _ArgSpec(NamedTuple):
aval: core.AbstractValue
sharding: XLACompatibleSharding | None


def _arg_spec(x: Any) -> _ArgSpec:
from jax._src import pjit

aval = xla.abstractify(x)
try:
if isinstance(x.sharding, PmapSharding):
return _ArgSpec(aval, None)
return _ArgSpec(aval, (pjit.to_gspmd_sharding(x.sharding, x.ndim) # type: ignore
if x._committed else None))
except:
return _ArgSpec(aval, None)


@dataclasses.dataclass(frozen=True)
class OrigShardings:
shardings: Sequence[GSPMDSharding | None]

def __hash__(self):
return hash(tuple(s for s in self.shardings))

def __eq__(self, other):
if not isinstance(other, OrigShardings):
return False
return all(getattr(s, "_original_sharding", s) == getattr(o, "_original_sharding", o)
for s, o in zip(self.shardings, other.shardings))


def apply_primitive(prim, *args, **params):
"""Impl rule that compiles and runs a single primitive 'prim' using XLA."""
from jax._src import pjit

try:
in_avals, in_shardings = util.unzip2([_arg_spec(a) for a in args])
in_tree = tree_util.tree_structure(args)
compiled_fun = xla_primitive_callable(
prim, in_avals, in_tree, OrigShardings(in_shardings), **params)
except pxla.DeviceAssignmentMismatchError as e:
fails, = e.args
# TODO(yashkatariya): Thread through a signature_fun via every primitive
# using apply_primitive so that the error message has the right argument
# name instead of `args[0]`, etc.
arg_names = api_util._arg_names(prim.impl, args, {}, (), ())
msg = pjit._device_assignment_mismatch_error(
prim.name, fails, args, 'jit', arg_names)
raise ValueError(msg) from None

return compiled_fun(*args)


# No need to cache here because there is a cache on xla_primitive_callable.
# If that cache is broken, a new function will be created which will always
# break the cache on this function.
def _trace_to_jaxpr(fun, in_avals, api_name, fun_name):
with log_elapsed_time(
"Finished tracing + transforming {fun_name} in {elapsed_time} sec",
fun_name=util.wrap_name(fun_name, api_name), event=JAXPR_TRACE_EVENT):
jaxpr, out_avals, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
return core.ClosedJaxpr(jaxpr, consts), tuple(out_avals)

fun = xla_primitive_callable(prim, **params)
# TODO(yashkatariya): Investigate adding is_primitive to jit and never
# triggering the disable jit path instead of messing around with it here.
if xla_extension_version >= 218:
prev = lib.jax_jit.swap_thread_local_state_disable_jit(False)
try:
outs = fun(*args)
finally:
lib.jax_jit.swap_thread_local_state_disable_jit(prev)
else:
with config.disable_jit(False):
outs = fun(*args)
return outs

@util.cache()
def xla_primitive_callable(
prim: core.Primitive, in_avals: tuple[core.AbstractValue, ...], in_tree,
orig_in_shardings: OrigShardings, **params,
) -> Callable:
def xla_primitive_callable(prim: core.Primitive, **params):
def prim_fun(*args):
out = prim.bind(*args, **params)
if prim.multiple_results:
return out
else:
return out,

donated_invars = (False,) * len(in_avals)
wrapped_fun = lu.wrap_init(prim_fun)
flat_fun, out_tree = api_util.flatten_fun_nokwargs(wrapped_fun, in_tree)
closed_jaxpr, out_avals = _trace_to_jaxpr(flat_fun, in_avals, 'jit', prim.name)
computation = sharded_lowering(
closed_jaxpr, prim.name, donated_invars, keep_unused=False,
inline=True, in_avals=in_avals, out_avals=out_avals,
in_shardings=orig_in_shardings.shardings,
lowering_parameters=mlir.LoweringParameters())
compiled = computation.compile()
if config.disable_jit.value:
call = compiled.unsafe_call
else:
call = compiled.create_cpp_call_for_apply_primitive(out_tree())
if call is None:
call = compiled.unsafe_call
if not prim.multiple_results:
return lambda *args, **kw: call(*args, **kw)[0]
else:
return call


def sharded_lowering(
closed_jaxpr: core.ClosedJaxpr, name: str, donated_invars: Sequence[bool],
keep_unused: bool, inline: bool, in_avals: tuple[core.AbstractValue, ...],
out_avals: tuple[core.AbstractValue, ...],
in_shardings: Sequence[Sharding | None],
lowering_parameters: mlir.LoweringParameters
) -> pxla.MeshComputation:
in_shardings_unspec = [UNSPECIFIED if i is None else i for i in in_shardings]
return pxla.lower_sharding_computation(
closed_jaxpr, 'jit', name, in_shardings_unspec,
(UNSPECIFIED,) * len(out_avals), donated_invars,
in_avals, keep_unused=keep_unused, inline=inline,
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))
return prim.bind(*args, **params)
prim_fun.__name__ = prim.name
prim_fun.__qualname__ = prim.name
return api.jit(prim_fun)


def simple_impl(prim):
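
A usage sketch of what the new dispatch path means in practice (illustrative only; jax._src.lax.lax is an internal module path and may change between versions):

```
import jax
import jax.numpy as jnp
from jax._src.lax import lax  # internal module exposing the sin_p primitive

x = jnp.ones((4,))

# Eager op-by-op execution: jnp.sin(x) reaches apply_primitive via sin_p's impl rule.
y = jnp.sin(x)

# With this change, that dispatch is conceptually a cached jit around prim.bind:
y2 = jax.jit(lambda a: lax.sin_p.bind(a))(x)
assert jnp.allclose(y, y2)
```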
29 changes: 0 additions & 29 deletions jax/_src/interpreters/pxla.py
@@ -2834,35 +2834,6 @@ def aot_cache_miss(*args, **kwargs):
return xc._xla.pjit(self.unsafe_call.name, None, aot_cache_miss, [], [], [],
tree_util.dispatch_registry)

def create_cpp_call_for_apply_primitive(self, out_tree):
# unsafe_call can be different than ExecuteReplicated for pathways.
if not (isinstance(self.unsafe_call, ExecuteReplicated) and
not self.unsafe_call.has_unordered_effects and
not self.unsafe_call.has_host_callbacks):
return None

def apply_primitive_cache_miss(*args):
out_flat = self.unsafe_call(*args)
outs = tree_util.tree_unflatten(out_tree, out_flat)
out_flat, out_tree_dispatch = reflatten_outputs_for_dispatch(
out_tree, out_flat)
use_fastpath = (all(isinstance(x, xc.ArrayImpl) for x in out_flat))

if use_fastpath:
out_avals = [o.aval for o in out_flat]
out_committed = [o._committed for o in out_flat]
kept_var_bitvec = [i in self._kept_var_idx
for i in range(len(args))]
fastpath_data = MeshExecutableFastpathData(
self.xla_executable, out_tree_dispatch, self._in_shardings,
self._out_shardings, out_avals, out_committed, kept_var_bitvec)
else:
fastpath_data = None
return outs, fastpath_data

return xc._xla.pjit(self.unsafe_call.name, None, apply_primitive_cache_miss,
[], [], [], tree_util.dispatch_registry)


def check_arg_avals_for_call(ref_avals, arg_avals,
jaxpr_debug_info: core.JaxprDebugInfo | None = None):
13 changes: 8 additions & 5 deletions jax/_src/maps.py
@@ -55,7 +55,7 @@
GSPMDSharding)
from jax._src.sharding_impls import (
ArrayMapping, NamedSharding, ParsedPartitionSpec,
array_mapping_to_axis_resources)
array_mapping_to_axis_resources, UNSPECIFIED)
from jax._src.tree_util import (tree_flatten, tree_unflatten, all_leaves,
tree_map, treedef_tuple)
from jax._src.util import (safe_map, safe_zip, HashableFunction, unzip2, unzip3,
@@ -702,10 +702,13 @@ def make_xmap_callable(fun: lu.WrappedFun,
tiling_method=tiling_method,
lowering_parameters=lowering_parameters)
else:
closed_jaxpr, out_avals = dispatch._trace_to_jaxpr(f, in_avals, 'jit', name)
return dispatch.sharded_lowering(
closed_jaxpr, name, donated_invars, True, False, in_avals, out_avals,
(None,) * len(in_avals), lowering_parameters=lowering_parameters)
jaxpr, out_avals, consts = pe.trace_to_jaxpr_final(f, in_avals)
return pxla.lower_sharding_computation(
core.ClosedJaxpr(jaxpr, consts), 'jit', name,
(UNSPECIFIED,) * len(in_avals), (UNSPECIFIED,) * len(out_avals),
donated_invars, in_avals, keep_unused=True, inline=False,
devices_from_context=None, lowering_parameters=lowering_parameters,
in_layouts=(None,) * len(in_avals), out_layouts=(None,) * len(out_avals))


class EvaluationPlan(NamedTuple):
8 changes: 4 additions & 4 deletions jax/_src/pjit.py
@@ -1176,15 +1176,15 @@ def _pjit_call_impl_python(
("fingerprint", fingerprint))
try:
return compiled.unsafe_call(*args), compiled
except FloatingPointError:
except FloatingPointError as e:
assert config.debug_nans.value or config.debug_infs.value # compiled_fun can only raise in this case

_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return
if len(jaxpr.eqns) > 1:
_ = core.jaxpr_as_fun(jaxpr)(*args) # may raise, not return

# If control reaches this line, we got a NaN on the output of `compiled`
# but not `fun.call_wrapped` on the same arguments. Let's tell the user.
msg = ("An invalid value was encountered in the output of the "
f"`jit`-decorated function {name}. Because "
msg = (f"{str(e)}. Because "
"jax_config.debug_nans.value and/or config.jax_debug_infs is set, the "
"de-optimized function (i.e., the function as if the `jit` "
"decorator were removed) was called in an attempt to get a more "
2 changes: 1 addition & 1 deletion tests/api_test.py
@@ -5541,7 +5541,7 @@ def test_vjp_caching_static_argnums(self):
identity = jax.remat(lambda x, y: jax.jit(lambda x: 2 * x if y else x)(x),
static_argnums=(1,))
_, f_vjp = jax.vjp(identity, 1., True)
with jtu.count_pjit_cpp_cache_miss() as count: # noqa: F841
with jtu.count_jit_and_pmap_compiles() as count: # noqa: F841
for _ in range(20):
f_vjp(1.)[0].block_until_ready()
self.assertEqual(count[0], 2) # fwd execute_trivial, backward_pass on bwd
13 changes: 11 additions & 2 deletions tests/debug_nans_test.py
@@ -32,11 +32,13 @@
class DebugNaNsTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self.cfg = config._read("jax_debug_nans")
config.update("jax_debug_nans", True)

def tearDown(self):
config.update("jax_debug_nans", self.cfg)
super().tearDown()

def testSinc(self):
# Regression test for #6936
@@ -190,22 +192,29 @@ def f(x, y):
return x / y

with self.assertRaisesRegex(
FloatingPointError, r"invalid value \(nan\) encountered in jit\(div\)"):
FloatingPointError,
r"invalid value \(nan\) encountered in jit\(true_divide\)"):
f(inp, inp)

# TODO(yashkatariya): Fix this and make true_divide appear in the name again.
# Instead of `f` showing up in the error, the name should be of the
# primitive (true_divide) in this case.
with self.assertRaisesRegex(
FloatingPointError, r"invalid value \(nan\) encountered in jit\(div\)"):
FloatingPointError,
r"invalid value \(nan\) encountered in jit\(f\)"):
jax.jit(f)(inp, inp)


class DebugInfsTest(jtu.JaxTestCase):

def setUp(self):
super().setUp()
self.cfg = config._read("jax_debug_infs")
config.update("jax_debug_infs", True)

def tearDown(self):
config.update("jax_debug_infs", self.cfg)
super().tearDown()

def testSingleResultPrimitiveNoInf(self):
A = jnp.array([[1., 2.], [2., 3.]])
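
For context, a minimal reproduction of the behavior these tests assert (assuming jax_debug_nans is enabled; the exact message wording depends on the jax version):

```
import jax
import jax.numpy as jnp

jax.config.update("jax_debug_nans", True)

x = jnp.zeros(())
try:
  jnp.divide(x, x)  # 0/0 produces a NaN, which debug_nans turns into an error
except FloatingPointError as e:
  print(e)  # e.g. "invalid value (nan) encountered in jit(true_divide)"
```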
14 changes: 7 additions & 7 deletions tests/pjit_test.py
@@ -3039,27 +3039,28 @@ def f(x):
def test_pjit_no_global_cache_hit_axis_resources(self):
mesh = jtu.create_global_mesh((1,), ('x',))
s = NamedSharding(mesh, P('x'))
inp = jnp.arange(8.0)

with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(jnp.arange(8.0))
pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)(inp)
self.assertEqual(count[0], 10)

with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pjit(lambda x: x * 2, device=jax.devices()[0])(jnp.arange(8.))
pjit(lambda x: x * 2, device=jax.devices()[0])(inp)
self.assertEqual(count[0], 10)

pf = pjit(lambda x: x * 2, in_shardings=s, out_shardings=s)
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf(jnp.arange(8.))
pf(inp)
self.assertEqual(count[0], 1)

pf1 = pjit(lambda x: x * 2, device=jax.devices()[0])
with jtu.count_pjit_cpp_cache_miss() as count:
for _ in range(10):
pf1(jnp.arange(8.))
pf1(inp)
self.assertEqual(count[0], 1)

def test_with_sharding_constraint_spmd_axis_name(self):
Expand Down Expand Up @@ -3176,9 +3177,8 @@ def test_device_assignment_mismatch_apply_primitive(self):
arr2 = jax.device_put(np.arange(8), jax.devices()[1])
with self.assertRaisesRegex(
ValueError,
"Received incompatible devices for jitted computation. Got argument "
r"args\[0\] of concatenate with shape int.*\[8\].*and argument "
r"args\[1\].*"):
"Received incompatible devices for jitted computation. Got argument.*"
r"of concatenate with shape int.*\[8\].*and argument.*"):
jnp.concatenate([arr, arr2])

def test_device_put_grad(self):
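
The device-assignment-mismatch test above corresponds to this kind of failure (illustrative sketch; requires at least two devices, e.g. via XLA_FLAGS=--xla_force_host_platform_device_count=2):

```
import jax
import jax.numpy as jnp
import numpy as np

arr = jax.device_put(np.arange(8), jax.devices()[0])
arr2 = jax.device_put(np.arange(8), jax.devices()[1])

# Raises ValueError: "Received incompatible devices for jitted computation.
# Got argument ... of concatenate ... and argument ..."
jnp.concatenate([arr, arr2])
```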
