Skip to content

Commit

Permalink
Convert shardings in jit path to OpShardingSharding to avoid recomp…
Browse files Browse the repository at this point in the history
…ilation when semantically similar shardings are used in `jit`.

PiperOrigin-RevId: 477626548
  • Loading branch information
yashk2810 authored and jax authors committed Sep 29, 2022
1 parent 500f8b7 commit 163b7e2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion jax/_src/api.py
Expand Up @@ -670,13 +670,15 @@ def _jit_lower(fun, static_argnums, static_argnames, device, backend,

def arg_spec(x):
from jax._src.sharding import PmapSharding
from jax.experimental import pjit
# like xla.arg_spec but duck-types on x.shape and x.dtype
aval = None if jax.config.jax_dynamic_shapes else shaped_abstractify(x)
if jax.config.jax_array:
if hasattr(x, 'sharding'):
if isinstance(x.sharding, PmapSharding):
return aval, None
return aval, (x.sharding if x._committed else None)
return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim)
if x._committed else None)
else:
return aval, None
else:
Expand Down
4 changes: 3 additions & 1 deletion jax/_src/dispatch.py
Expand Up @@ -91,13 +91,15 @@

def arg_spec(x: Any) -> ArgSpec:
from jax._src.sharding import PmapSharding
from jax.experimental import pjit

aval = xla.abstractify(x)
try:
if config.jax_array:
if isinstance(x.sharding, PmapSharding):
return aval, None
return aval, (x.sharding if x._committed else None)
return aval, (pjit.to_op_sharding_sharding(x.sharding, x.ndim) # type: ignore
if x._committed else None)
else:
return aval, x._device
except:
Expand Down

0 comments on commit 163b7e2

Please sign in to comment.