
Preserve shardings on the output of pjit that were provided on the arguments.

Following are the changes:

* Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top-level cache will fill up eventually if users pass different shardings into the pjit function.
* Split lower_sharding_computation into 3 caches:
  * _trace_to_jaxpr_and_dce cache -- This returns a closed jaxpr that has been DCE'd.
  * _cached_lowering_to_hlo cache -- This caches the generation of MHLO. It is keyed on the semantic equality of shardings, i.e., if two shardings lower to the same OpSharding, there will be a cache hit.
  * _cached_compilation cache -- This caches the compilation so that we don't recompile when the shardings are semantically equal.
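A minimal sketch of this layered-cache idea (all names here are illustrative stand-ins, not JAX internals): the compilation cache is keyed on a semantic fingerprint of the sharding, while the per-call output handling keeps the user's original sharding, so semantically-equal shardings reuse one compiled artifact without losing their user-facing identity.

```python
from dataclasses import dataclass
from functools import lru_cache

@dataclass(frozen=True)
class ToySharding:
    """Stand-in for a sharding: `label` is the user-facing form (e.g. a
    NamedSharding's mesh/axis names); `op_sharding` is the lowered,
    semantic representation used for cache keys."""
    label: str
    op_sharding: str  # equal op_sharding => semantically equal shardings

compile_count = 0

@lru_cache(maxsize=None)
def _cached_compilation(fun_name: str, op_sharding: str) -> str:
    """Keyed on the *semantic* sharding, so semantically-equal shardings
    with different user-facing labels hit the same entry."""
    global compile_count
    compile_count += 1
    return f"executable<{fun_name}|{op_sharding}>"

def pjit_call(fun_name: str, sharding: ToySharding) -> dict:
    executable = _cached_compilation(fun_name, sharding.op_sharding)
    # Output handling is rebuilt per sharding, so the user's sharding is
    # preserved on the output even when compilation was reused.
    return {"executable": executable, "out_sharding": sharding.label}

a = pjit_call("f", ToySharding("NamedSharding(x)", "devices=[2]0,1"))
b = pjit_call("f", ToySharding("GSPMDSharding", "devices=[2]0,1"))
assert a["executable"] == b["executable"]      # single compilation reused
assert a["out_sharding"] != b["out_sharding"]  # user shardings preserved
assert compile_count == 1
```

The design choice this illustrates: exact equality at the outermost layer keeps user-visible shardings intact, while semantic equality at the inner layers avoids redundant lowering and compilation.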

The way this works is that the out_handlers are recreated when different shardings are passed to pjit (but there is no recompilation). This allows us to preserve the shardings passed by the user.

For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding from the GSPMDSharding (for now; more support will be added in following CLs), since the mesh will be available on the input.

PiperOrigin-RevId: 522991145
yashk2810 authored and chrisflesher committed Jun 3, 2023
1 parent 003d39e commit a9192d8
Showing 4 changed files with 669 additions and 265 deletions.
2 changes: 1 addition & 1 deletion jax/_src/dispatch.py
@@ -218,7 +218,7 @@ def sharded_lowering(fun, name, donated_invars, keep_unused,
   # apply it to all out_avals.
   return pxla.lower_sharding_computation(
       fun, 'jit', name, in_shardings, pxla._UNSPECIFIED, donated_invars,
-      in_avals, keep_unused=keep_unused, always_lower=False,
+      tuple(in_avals), keep_unused=keep_unused, always_lower=False,
       devices_from_context=None, lowering_platform=lowering_platform)
