Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Convert everything in pjit to the
Sharding
interface. The following…
… contains the things that have changed in this CL: * All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs. * `out_shardings` are still instances of `MeshPspecSharding` even if `Array` are used. In a follow up CL, I will change out_axis_resources to accept `Sharding` instances. * This is also a reason why you still need a mesh context manager when `config.jax_array` is enabled. * cl/458267790 is WIP for this. It adds a couple of checks in MeshPspecSharding too when `AUTO` is used. * Checking of sharding with `aval` has a handler system to deal with sharding instances. * The reason for creating a `pjit` specific system rather than putting this check on the sharding instances is because each transformation has a different way of checking the sharding. The best example for this is `pjit` and `xmap`. They both have different way to check if an aval is sharded properly with respect to the given sharding because `pjit` and `xmap` has different ways to express sharding. * `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us. * _pjit_lower still depends on mesh which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keep this CL small (LOL), I'll make those changes in a follow up CL. * Also the private functions in pxla.py are used by pathways and automap so I'll have to modify those too. * Also it has `pxla.resource_typecheck` which I haven't figured out how to move it to sharding interface. * `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`. * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998 * `pjit`'s batching handlers add an extra dimension to the axis_resources. Since this is dependent on how each transformation adds the extra dimension and it also differs on how each sharding instance will handle it, I added a handler system for this too. Again `xmap` and `pjit` differ a lot here. This is why I went with the handler approach. * MeshPspecSharding handles this `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done. PiperOrigin-RevId: 459548974
- Loading branch information