Skip to content

Commit

Permalink
[c++ jit] only set use_fastpath in cache_miss if all args are DeviceA…
Browse files Browse the repository at this point in the history
…rrays

fixes #12542

Co-authored-by: Peter Hawkins <phawkins@google.com>
Co-authored-by: Kuangyuan Chen <chky@google.com>
  • Loading branch information
3 people committed Sep 28, 2022
1 parent 933b6a2 commit b175e11
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/api.py
Expand Up @@ -537,6 +537,8 @@ def _device_array_use_fast_path(execute, out_pytree_def, args_flat, out_flat):
not execute.args[5] and not execute.args[6] and
# Has no host callbacks
not execute.args[8] and
# impl rule must have been called, i.e. top trace is an EvalTrace
isinstance(core.find_top_trace(args_flat), core.EvalTrace) and
# Not supported: ShardedDeviceArray
all(device_array.type_is_device_array(x) for x in out_flat) and
# Not supported: dynamic shapes
Expand Down
47 changes: 47 additions & 0 deletions tests/api_test.py
Expand Up @@ -3872,6 +3872,53 @@ def test_jit_negative_static_argnums(self):
g = jax.jit(lambda x, y: x * y, static_argnums=-1)
g(1, 2) # doesn't crash

def test_fastpath_cache_confusion(self):
# https://github.com/google/jax/issues/12542
@jax.jit
def a(x):
return ()

@jax.jit
def b(x):
return a(x)


@jax.jit
def g(x):
return x, x

@jax.jit
def h(x):
return g(x)

jaxpr = jax.make_jaxpr(h)(7)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)

b(8) # don't crash

def test_fastpath_cache_confusion2(self):
@jax.jit
def a(): # note nullary function, still staged out though
return ()

@jax.jit
def b(x):
return a()


@jax.jit
def g(x):
return x, x

@jax.jit
def h(x):
return g(x)

jaxpr = jax.make_jaxpr(h)(7)
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, 7)

b(8) # don't crash


@jtu.with_config(jax_experimental_subjaxpr_lowering_cache=True)
class SubcallTraceCacheTest(jtu.JaxTestCase):
Expand Down

0 comments on commit b175e11

Please sign in to comment.