Convert everything in pjit to the Sharding interface. The following contains the things that have changed in this CL:

* All in_axis_resources and out_axis_resources are instances of `Sharding`. When `config.jax_array` is enabled, `in_shardings` is inferred from the inputs.
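  * A rough sketch of what this means in practice (the mesh shape, spec, and variable names below are made up, and exact import paths may differ at this revision):

```python
# Illustrative only: the (mesh, PartitionSpec) pair that in/out_axis_resources
# used to describe is now carried internally as a single Sharding instance.
import numpy as np
import jax
from jax.experimental import maps, sharding, PartitionSpec

mesh = maps.Mesh(np.asarray(jax.devices()).reshape(-1), ('x',))
spec = PartitionSpec('x')

# Internal representation after this CL:
s = sharding.MeshPspecSharding(mesh, spec)

# With config.jax_array enabled, in_shardings is instead inferred from the
# shardings already attached to the input Arrays, not passed explicitly.
```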

* `out_shardings` are still instances of `MeshPspecSharding` even when `Array` is used. In a follow-up CL, I will change `out_axis_resources` to accept `Sharding` instances.
  * This is also why you still need a mesh context manager when `config.jax_array` is enabled (see the sketch below).
  * cl/458267790 is WIP for this. It also adds a couple of checks in `MeshPspecSharding` when `AUTO` is used.
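  * To make the context-manager requirement concrete, a minimal sketch (the function, shapes, and mesh are illustrative, not code from this CL):

```python
# Illustrative: because out_shardings are still MeshPspecSharding, the mesh
# still has to be active as a context manager, even with config.jax_array on.
import numpy as np
import jax
from jax.experimental import maps, PartitionSpec
from jax.experimental.pjit import pjit

mesh = maps.Mesh(np.asarray(jax.devices()).reshape(-1), ('x',))
f = pjit(lambda a: a + 1, in_axis_resources=PartitionSpec('x'),
         out_axis_resources=PartitionSpec('x'))

x = np.arange(jax.device_count() * 2.0)
with mesh:   # still required so out_shardings can be built from the mesh
  y = f(x)
```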

* Checking a sharding against an `aval` now goes through a handler system that dispatches on the kind of sharding instance.
  * The reason for creating a `pjit`-specific system, rather than putting this check on the sharding instances themselves, is that each transformation checks sharding differently. The best example is `pjit` vs `xmap`: they express sharding differently, so each has its own way of checking whether an aval is sharded properly with respect to a given sharding (see the sketch below).
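  * One plausible shape for such a handler system (all names here are hypothetical; the real registration mechanism in pjit.py may look different):

```python
# Hypothetical sketch: dispatch the "is this aval compatible with this
# sharding?" check on the type of the Sharding instance, so each
# transformation can register its own rules.
from jax.experimental import sharding

_pjit_aval_sharding_checks = {}  # Sharding subclass -> check function

def register_check(sharding_type):
  def register(handler):
    _pjit_aval_sharding_checks[sharding_type] = handler
    return handler
  return register

@register_check(sharding.MeshPspecSharding)
def _check_mesh_pspec(aval, s):
  # e.g. every dimension of aval.shape that the spec partitions must be
  # divisible by the number of mesh devices assigned to that axis.
  ...

@register_check(sharding.SingleDeviceSharding)
def _check_single_device(aval, s):
  pass  # a single-device sharding is compatible with any aval

def check_aval_sharding(aval, s):
  _pjit_aval_sharding_checks[type(s)](aval, s)
```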

* `MeshPspecSharding` and `SingleDeviceSharding` have `__hash__` and `__eq__`. So now we don't have to pass around canonicalized pspecs in the new path to get cache hits. The `Sharding` instances should handle that for us.
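  * A small illustration of why this matters for caching (the cache function below is invented for the example; how the real lowering cache is keyed is an assumption):

```python
# Hypothetical illustration: two Sharding objects built independently from the
# same (mesh, pspec) compare and hash equal, so a cache keyed on the sharding
# hits without passing around a separately canonicalized pspec.
import functools
import numpy as np
import jax
from jax.experimental import maps, sharding, PartitionSpec

mesh = maps.Mesh(np.asarray(jax.devices()).reshape(-1), ('x',))

@functools.lru_cache()
def lower_cached(aval_str, s):   # stand-in for an internal lowering cache
  return object()

s1 = sharding.MeshPspecSharding(mesh, PartitionSpec('x'))
s2 = sharding.MeshPspecSharding(mesh, PartitionSpec('x'))  # distinct object
assert s1 == s2 and hash(s1) == hash(s2)
assert lower_cached('f32[8]', s1) is lower_cached('f32[8]', s2)
```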

* `_pjit_lower` still depends on the mesh, which is the major reason why I haven't removed `resource_env` from `params`. But in the interest of keeping this CL small (LOL), I'll make those changes in a follow-up CL.
  * Also, the private functions in pxla.py are used by pathways and automap, so I'll have to modify those too.
  * Also, it uses `pxla.resource_typecheck`, which I haven't figured out how to move to the sharding interface.

* `_to_xla_op_sharding` takes in `axis_ctx` as an extra **optional** parameter. This is required for `with_sharding_constraint`.
  * `with_sharding_constraint` uses the MLIR `ctx` here: cl/458042998
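  * Roughly the interface shape being described (the class below is a stand-in; only `axis_ctx` and the rank argument come from this CL, everything else is assumed):

```python
# Hypothetical stand-in for the real XLACompatibleSharding base class, just to
# show the optional axis_ctx parameter described above.
class XLACompatibleSharding:
  def _to_xla_op_sharding(self, num_dimensions: int, axis_ctx=None):
    # Most callers (e.g. jax2tf's _shard_value) only pass the rank;
    # with_sharding_constraint's lowering rule additionally passes the MLIR
    # axis context so the constraint can be lowered correctly.
    raise NotImplementedError
```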

* `pjit`'s batching handlers add an extra dimension to the axis_resources. How that extra dimension is added depends on the transformation, and how it is handled depends on the sharding instance, so I added a handler system for this too. Again, `xmap` and `pjit` differ a lot here, which is why I went with the handler approach (see the sketch below).
  * `MeshPspecSharding` handles this via `insert_axis_partitions` on the parsed partition spec. I have added more detailed comments in the place where this is done.
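  * A sketch of the batching-handler idea (helper names invented; the real code calls `insert_axis_partitions` on the parsed pspec, while this version splices a `None` into the raw `PartitionSpec` and assumes the sharding exposes `.mesh` and `.spec`):

```python
# Hypothetical sketch: per-sharding-type handlers that insert the extra batch
# dimension added by pjit's batching rule. The new dimension is left
# unpartitioned (None), i.e. replicated across the mesh.
from jax.experimental import sharding, PartitionSpec

_batch_dim_handlers = {}  # Sharding subclass -> handler

def _batch_mesh_pspec_sharding(s, dim):
  spec = list(s.spec) + [None] * max(0, dim - len(s.spec))
  spec.insert(dim, None)
  return sharding.MeshPspecSharding(s.mesh, PartitionSpec(*spec))

_batch_dim_handlers[sharding.MeshPspecSharding] = _batch_mesh_pspec_sharding

def insert_batch_dim(s, dim):
  return _batch_dim_handlers[type(s)](s, dim)
```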

PiperOrigin-RevId: 459548974
yashk2810 authored and jax authors committed Jul 7, 2022
1 parent 88c1e7d commit 2314951
Showing 4 changed files with 418 additions and 243 deletions.
27 changes: 14 additions & 13 deletions jax/experimental/jax2tf/jax2tf.py
@@ -17,7 +17,7 @@
import os
import re
import threading
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union, cast

import jax
from jax import lax
@@ -28,6 +28,7 @@
from jax import numpy as jnp
from jax.experimental import maps
from jax.experimental import pjit
from jax.experimental import sharding
from jax.interpreters import ad
from jax.interpreters import partial_eval
from jax.interpreters import pxla
@@ -2670,13 +2671,13 @@ def split_to_logical_devices(tensor: TfVal,
return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)


def _shard_value(mesh: maps.Mesh,
val: TfVal,
def _shard_value(val: TfVal,
aval: core.ShapedArray,
axis_resources: pjit.ParsedPartitionSpec) -> TfVal:
sd: sharding.XLACompatibleSharding) -> TfVal:
"""Apply sharding to a TfVal."""
sharding_proto: xla_client.OpSharding = pjit.get_aval_sharding_proto(
aval, axis_resources, mesh)
sharding_proto: xla_client.OpSharding = cast(
xla_client.OpSharding, sd._to_xla_op_sharding(aval.ndim))

# To use xla_sharding.py, we must have a xla_data_pb2.OpSharding.
xla_sharding_proto: xla_data_pb2.OpSharding = (
xla_data_pb2.OpSharding(
@@ -2691,8 +2692,8 @@

def _pjit(*args: TfVal,
jaxpr: core.ClosedJaxpr,
in_axis_resources: Sequence[pjit.ParsedPartitionSpec],
out_axis_resources: Sequence[pjit.ParsedPartitionSpec],
in_shardings: Sequence[sharding.XLACompatibleSharding],
out_shardings: Sequence[sharding.XLACompatibleSharding],
resource_env: maps.ResourceEnv,
donated_invars,
name: str,
@@ -2704,15 +2705,13 @@ def _pjit(*args: TfVal,
if resource_env.physical_mesh.is_multi_process:
raise NotImplementedError("jax2tf translation for pjit over multi-process "
"meshes is not supported yet")
# TODO: add `name` to the name stack
shard_value_for_mesh = partial(_shard_value, resource_env.physical_mesh)
# Apply sharding annotation to the arguments
sharded_args: Sequence[TfVal] = tuple(
map(shard_value_for_mesh, args, _in_avals, in_axis_resources))
map(_shard_value, args, _in_avals, in_shardings))
results = _interpret_jaxpr(jaxpr, *sharded_args,
extra_name_stack=util.wrap_name(name, "pjit"))
sharded_results: Sequence[TfVal] = tuple(
map(shard_value_for_mesh, results, _out_aval, out_axis_resources))
map(_shard_value, results, _out_aval, out_shardings))
return tuple(sharded_results)


@@ -2725,7 +2724,9 @@ def _pjit_sharding_constraint(arg: TfVal, *,
_in_avals: Sequence[core.ShapedArray],
_out_aval: core.ShapedArray,
**kwargs) -> TfVal:
return _shard_value(resource_env.physical_mesh, arg, _in_avals[0], axis_resources)
ms = sharding.MeshPspecSharding._from_parsed_pspec(
resource_env.physical_mesh, axis_resources)
return _shard_value(arg, _in_avals[0], ms)


tf_impl_with_avals[pjit.sharding_constraint_p] = _pjit_sharding_constraint
