Skip to content

Commit

Permalink
Add static_argnames to the _cpp_pjit path.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 499311688
  • Loading branch information
pschuh authored and jax authors committed Jan 3, 2023
1 parent 59d3257 commit 9674b06
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 10 deletions.
11 changes: 7 additions & 4 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,8 @@ def __init__(self):

_most_recent_pjit_call_executable = _MostRecentPjitCallExecutable()

def _cpp_pjit(fun: Callable, infer_params, static_argnums):

def _cpp_pjit(fun: Callable, infer_params, static_argnums, static_argnames):

def cache_miss(*args, **kwargs):
global _most_recent_pjit_call_executable
Expand Down Expand Up @@ -163,7 +164,9 @@ def cache_miss(*args, **kwargs):

cpp_pjit_f = xc._xla.pjit( # type: ignore
getattr(fun, "__name__", "<unnamed function>"), # type:ignore
cache_miss, static_argnums)
cache_miss,
static_argnums,
static_argnames)

return wraps(fun)(cpp_pjit_f)

Expand Down Expand Up @@ -476,8 +479,8 @@ def infer_params(*args, _global_avals=False, **kwargs):
return (args_flat, local_in_avals, params, in_tree, out_tree(),
donate_argnums)

if FLAGS.experimental_cpp_pjit and xla_extension_version >= 111:
wrapped = _cpp_pjit(fun, infer_params, static_argnums)
if FLAGS.experimental_cpp_pjit and xla_extension_version >= 115:
wrapped = _cpp_pjit(fun, infer_params, static_argnums, static_argnames)
else:
wrapped = _python_pjit(fun, infer_params)

Expand Down
4 changes: 2 additions & 2 deletions jax/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -3583,7 +3583,7 @@ def create_cpp_call(self, no_kwargs, in_tree, out_tree):
not self.unsafe_call.has_host_callbacks):
return None

if not flags.FLAGS.experimental_cpp_pjit or xla_extension_version < 111:
if not flags.FLAGS.experimental_cpp_pjit or xla_extension_version < 115:
return None

def aot_cache_miss(*args, **kwargs):
Expand All @@ -3603,7 +3603,7 @@ def aot_cache_miss(*args, **kwargs):
fastpath_data = None
return outs, fastpath_data

return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, []) # type: ignore
return xc._xla.pjit(self.unsafe_call.name, aot_cache_miss, [], []) # type: ignore


def _out_shardings_for_trivial(
Expand Down
36 changes: 32 additions & 4 deletions tests/pjit_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2542,8 +2542,8 @@ def test_multi_device_pjit_mul(self):

@jax_array(True)
def test_single_device_pjit_cpp_dispatch(self):
if xla_extension_version < 111:
self.skipTest('Does not work for xla_extension_version < 111')
if xla_extension_version < 115:
self.skipTest('Does not work for xla_extension_version < 115')

shape = (8, 2)
mesh = jtu.create_global_mesh((1,), ('x',))
Expand Down Expand Up @@ -2572,8 +2572,8 @@ def pjit_lower_and_count(*args, **kwargs):

@jax_array(True)
def test_single_device_add_single_compile(self):
if xla_extension_version < 111:
self.skipTest('Does not work for xla_extension_version < 111')
if xla_extension_version < 115:
self.skipTest('Does not work for xla_extension_version < 115')

f1 = pjit(lambda x, y: x + y)
a = jax.device_put(jnp.array([1, 2, 3], dtype=jnp.float32),
Expand Down Expand Up @@ -2739,6 +2739,34 @@ def f(x: str) -> int:
assert f_names('foo') == 1
assert f_names(x='foo') == 1

def test_pjit_with_static_argnames_cpp_dispatch(self):
if xla_extension_version < 115:
self.skipTest('Does not work for xla_extension_version < 115')

original_pjit_lower = pjit_lib._pjit_lower
count = 0

def pjit_lower_and_count(*args, **kwargs):
nonlocal count
count += 1
return original_pjit_lower(*args, **kwargs)

def f(y, **kwargs):
self.assertEqual(kwargs, {'x': 'foo'})
return y * y

try:
pjit_lib._pjit_lower = pjit_lower_and_count
y = jnp.arange(8.)

f_names = pjit(f, static_argnames='x')
f_names(y, x='foo')
f_names(y, x='foo')

self.assertEqual(count, 1)
finally:
pjit_lib._pjit_lower = original_pjit_lower

def test_new_static_argnum_on_keyword_arguments(self):
f = pjit(lambda x: x, static_argnums=0)
y = f(x=4)
Expand Down

0 comments on commit 9674b06

Please sign in to comment.