Optimize host_local_array_to_global_array by caching the local-to-global conversion and the flattening of axis resources.

Also take a fast path for device_put that skips `abstractify` and calls canonicalize_dtype only once on the entire array (instead of once per shard).

This results in a 5x speedup!

Before:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array       3.03 ms         3.02 ms          220
```

After:

```
---------------------------------------------------------------------------
Benchmark                                 Time             CPU   Iterations
---------------------------------------------------------------------------
host_local_array_to_global_array      0.673 ms        0.671 ms          985
```

PiperOrigin-RevId: 489880547
yashk2810 authored and jax authors committed Nov 21, 2022
1 parent 6d11567 commit 928dee4
Showing 2 changed files with 51 additions and 14 deletions.
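
For readers skimming the diff below, here is a minimal sketch of the caching pattern the commit relies on (not the commit's actual helpers, which appear in multihost_utils.py further down): `functools.lru_cache` memoizes a pure function keyed on its arguments, so every argument must be hashable, which is why the pytree of pspecs is passed through `pjit_lib.hashable_pytree` as a thunk rather than as a raw list. The names in the sketch are illustrative only.

```
import functools

# Hypothetical stand-in for mesh._local_to_global(...); the real cached helpers
# take a ShapedArray, a Mesh, and a PartitionSpec, all of which are hashable.
@functools.lru_cache()
def cached_local_to_global(shape, dtype_name, mesh_axes, pspec):
  print("computing")  # executes only once per distinct argument tuple
  return ("global_aval", shape, dtype_name, mesh_axes, pspec)

cached_local_to_global((8, 2), "int32", ("x", "y"), ("x", "y"))  # computes
cached_local_to_global((8, 2), "int32", ("x", "y"), ("x", "y"))  # cache hit
```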
14 changes: 14 additions & 0 deletions benchmarks/api_benchmark.py
@@ -31,6 +31,7 @@
 from jax._src import sharding
 from jax.experimental import pjit as pjit_lib
 from jax.experimental import maps
+from jax.experimental import multihost_utils
 import jax.numpy as jnp
 import numpy as np

@@ -817,5 +818,18 @@ def pjit_aot_4000_device(state):
       use_aot=True)


+@google_benchmark.register
+@google_benchmark.option.unit(google_benchmark.kMillisecond)
+def host_local_array_to_global_array(state):
+  global_mesh = create_mesh((4, 2), ('x', 'y'), state)
+  input_shape = (8, 2)
+  input_data = np.arange(np.prod(input_shape)).reshape(input_shape)
+  in_pspec = pxla.PartitionSpec('x', 'y')
+
+  while state:
+    multihost_utils.host_local_array_to_global_array(
+        (input_data, input_data), global_mesh, (in_pspec, in_pspec))
+
+
 if __name__ == "__main__":
   google_benchmark.main()
51 changes: 37 additions & 14 deletions jax/experimental/multihost_utils.py
@@ -14,15 +14,17 @@
 """Utilities for synchronizing and communication across multiple hosts."""

 import functools
+import itertools as it
 from typing import Optional
 import zlib

 import jax
 from jax.tree_util import tree_flatten, tree_map, tree_unflatten
+from jax._src import dispatch
 from jax._src import array
 from jax._src import sharding
 from jax.tree_util import PyTreeDef
-from jax.interpreters import pxla
+from jax.interpreters import pxla, xla
 from jax.experimental import maps
 from jax.experimental import pjit as pjit_lib
 from jax.experimental.pjit import pjit, FROM_GDA
@@ -228,6 +230,27 @@ def should_save(step_id: int) -> bool:
   return sync_manager.reached_sync_point(step_id)


+@functools.lru_cache()
+def _flatten_pspecs(name, in_tree, pspecs_thunk):
+  return pjit_lib.flatten_axis_resources(
+      name, in_tree, pspecs_thunk(), tupled_args=True)
+
+@functools.lru_cache()
+def _local_to_global_aval(local_aval, mesh, pspec):
+  return mesh._local_to_global(pxla._get_array_mapping(pspec), local_aval)
+
+@functools.lru_cache()
+def _global_to_local_aval(global_aval, mesh, pspec):
+  return mesh._global_to_local(
+      pxla._get_array_mapping(pspec), global_aval)
+
+def _device_put(x, device):
+  try:
+    return dispatch.device_put_handlers[type(x)](x, device)
+  except KeyError as err:
+    raise TypeError(f"No device_put handler for type: {type(x)}") from err
+
+
 def host_local_array_to_global_array(local_inputs, global_mesh, pspecs):
   """Converts a host local value to a globally sharded `jax.Array`.
@@ -289,21 +312,22 @@ def _convert(arr, pspec):
             local_sharding._to_xla_op_sharding(arr.ndim))):
       arrays = arr._arrays
     else:
-      arrays = [
-          jax.device_put(arr[index], d)
+      arr = xla.canonicalize_dtype(arr)
+      arrays = list(it.chain.from_iterable(
+          _device_put(arr[index], d)
           for d, index in local_sharding.devices_indices_map(arr.shape).items()
-      ]
+      ))
+
+    global_aval = _local_to_global_aval(
+        jax.ShapedArray(arr.shape, arrays[0].dtype), global_mesh, pspec)

-    global_aval = global_mesh._local_to_global(
-        pxla._get_array_mapping(pspec),
-        jax.ShapedArray(arr.shape, arrays[0].dtype))
     return array.ArrayImpl(
         global_aval, jax.sharding.NamedSharding(global_mesh, pspec),
-        arrays, committed=True)
+        arrays, committed=True, _skip_checks=True)

   flattened_inps, in_tree = tree_flatten(local_inputs)
-  in_pspecs = pjit_lib.flatten_axis_resources(
-      'input pspecs', in_tree, pspecs, tupled_args=True)
+  in_pspecs = _flatten_pspecs('input pspecs', in_tree,
+                              pjit_lib.hashable_pytree(pspecs))
   out = tree_map(_convert, tuple(flattened_inps), in_pspecs)
   return tree_unflatten(in_tree, out)

@@ -354,14 +378,13 @@ def _convert(arr, pspec):
     # If the Array is already fully addressable i.e. host local, return it.
     if isinstance(arr, array.ArrayImpl) and arr.is_fully_addressable:
       return arr
-    local_aval = global_mesh._global_to_local(
-        pxla._get_array_mapping(pspec), arr.aval)
+    local_aval = _global_to_local_aval(arr.aval, global_mesh, pspec)
     return array.ArrayImpl(
         local_aval, jax.sharding.NamedSharding(global_mesh.local_mesh, pspec),
         arr._arrays, committed=True)

   flattened_inps, out_tree = tree_flatten(global_inputs)
-  out_pspecs = pjit_lib.flatten_axis_resources(
-      'output pspecs', out_tree, pspecs, tupled_args=True)
+  out_pspecs = _flatten_pspecs('output pspecs', out_tree,
+                               pjit_lib.hashable_pytree(pspecs))
   out = tree_map(_convert, tuple(flattened_inps), out_pspecs)
   return tree_unflatten(out_tree, out)
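
For context, a hedged usage sketch of the two entry points touched in this file (not part of the commit). It mirrors the benchmark above and assumes at least 8 local devices for the 4x2 mesh; on a real multi-host setup each process passes only its host-local shard of the data.

```
import jax
import numpy as np
from jax.experimental import maps
from jax.experimental import multihost_utils
from jax.interpreters import pxla

# Build a 4x2 mesh from the first 8 devices, matching the benchmark's layout.
mesh = maps.Mesh(np.asarray(jax.devices()[:8]).reshape(4, 2), ('x', 'y'))
pspec = pxla.PartitionSpec('x', 'y')

host_local = np.arange(16).reshape(8, 2)    # this process's slice of the data
global_arr = multihost_utils.host_local_array_to_global_array(
    host_local, mesh, pspec)                # globally sharded jax.Array
round_trip = multihost_utils.global_array_to_host_local_array(
    global_arr, mesh, pspec)                # back to a host-local array
```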
