Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Preserve shardings on the output of pjit that were provided on the ar…
…guments. Following are the changes: * Make _pjit_lower_cached depend on exact sharding equality if `_original_sharding` exists. This top level cache should fill up eventually if users are passing different shardings into the pjit function. * Split lower_sharding_computation into 3 caches: * _trace_to_jaxpr_and_dce cache -- This will return a closed jaxpr which is DCE'd * _cached_lowering_to_hlo cache -- This will cache the generation of MHLO. This cache is dependent on the semantic equality of shardings i.e. if 2 shardings lower to the same OpSharding, then there will be a cache hit * _cached_compilation cache -- This caches the compilation so that we don't recompile if the shardings are semantically equal. The way this works is the out_handlers are created again if we pass in different shardings to pjit (but there is no recompilation). This allows us to maintain the shardings passed by the user. For ops like `jnp.squeeze` where we infer the sharding from the executable, we try to recreate a NamedSharding (right now, more support will be added in following CLs) from the GSPMDSharding since it will be available on the input. PiperOrigin-RevId: 522991145
- Loading branch information