
Commit

Simplifies full to not instantiate an intermediate array with default sharding. This significantly reduces overhead when creating sharded arrays in eager mode, e.g. when using jnp.zeros_like(...).

PiperOrigin-RevId: 606765964
marksandler2 authored and jax authors committed Feb 13, 2024
1 parent 2adefe9 commit 2717dae
Showing 3 changed files with 110 additions and 19 deletions.
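
The motivating case is eager creation of a sharded array via jnp.zeros_like. Below is a minimal sketch, not part of the commit, assuming more than one visible device (e.g. CPU with XLA_FLAGS=--xla_force_host_platform_device_count=8): the zeros should pick up the input's sharding without first materializing a default-sharded intermediate array.

```python
# Minimal sketch of the motivating eager-mode case (not from the commit).
# Assumes multiple visible devices.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('i',))
sharding = NamedSharding(mesh, P('i'))

# x is committed and sharded across all devices along its leading axis.
x = jax.device_put(jnp.ones((len(jax.devices()) * 1024, 1024)), sharding)

# Eager zeros_like: with this change, full() broadcasts only a single
# per-device shard instead of first materializing the whole array with
# the default sharding, so y picks up x's sharding without an extra copy.
y = jnp.zeros_like(x)
print(y.sharding == x.sharding)  # expected: True
```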
51 changes: 32 additions & 19 deletions jax/_src/lax/lax.py
@@ -41,7 +41,6 @@
from jax._src import dispatch
from jax._src import dtypes
from jax._src import effects
from jax._src import shard_alike
from jax._src import linear_util as lu
from jax._src import pretty_printer as pp
from jax._src import source_info_util
@@ -63,7 +62,6 @@
standard_primitive)
from jax._src import xla_bridge
from jax._src.lib import xla_client
from jax._src.lib import xla_extension_version
from jax._src.lib.mlir import ir
from jax._src.lib.mlir.dialects import chlo
from jax._src.lib.mlir.dialects import hlo
@@ -1213,7 +1211,9 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
fill_value: the value to fill the new array with.
dtype: the type of the output array, or `None`. If not `None`, `fill_value`
will be cast to `dtype`.
sharding: an optional sharding specification for the resulting array.
sharding: an optional sharding specification for the resulting array.
Note that sharding is currently ignored in jitted mode; this might change
in the future.
"""
shape = canonicalize_shape(shape)
if np.shape(fill_value):
@@ -1224,10 +1224,19 @@ def full(shape: Shape, fill_value: ArrayLike, dtype: DTypeLike | None = None, *,
weak_type = dtype is None and dtypes.is_weakly_typed(fill_value)
dtype = dtypes.canonicalize_dtype(dtype or _dtype(fill_value))
fill_value = _convert_element_type(fill_value, dtype, weak_type)
out = broadcast(fill_value, shape)
if sharding is not None:
return array.make_array_from_callback(shape, sharding, lambda idx: out[idx])
return out
# In tracing mode we can't set the sharding explicitly and PmapSharding
# is not supported.
# NB: Consider using with_sharding_constraint in jitted computations
# if needed.
if (sharding is not None and not isinstance(sharding, PmapSharding) and
isinstance(fill_value, array.ArrayImpl)):

broadcast_shape = sharding.shard_shape(shape)
shard = broadcast(fill_value, broadcast_shape)
return array.make_array_from_callback(shape, sharding, lambda _: shard)

return broadcast(fill_value, shape)



def zeros_like_shaped_array(aval: ShapedArray) -> Array:
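
As an illustration of the new eager path above, here is a sketch (not code from the commit, assuming a multi-device setup): with an explicit non-Pmap sharding and a concrete fill value, full broadcasts only one per-device shard of shape sharding.shard_shape(shape) and assembles the global array with make_array_from_callback.

```python
# Illustrative sketch of the new eager path in lax.full (not from the commit).
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('i',))
sharding = NamedSharding(mesh, P('i'))

shape = (8 * len(jax.devices()), 4)
z = lax.full(shape, 0.0, dtype=jnp.float32, sharding=sharding)

print(z.sharding)                   # the requested NamedSharding (eager mode)
print(sharding.shard_shape(shape))  # per-device shard shape, e.g. (8, 4)
```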
@@ -1370,7 +1379,10 @@ def full_like(x: ArrayLike | DuckTypedArray,
shape: optional, a shape parameter for the output ndarray.
sharding: an optional sharding specification for the resulting array.
If not specified, the output will have the same sharding as the input,
so long as ``shape`` is also not specified.
with a few exceptions/limitations:
1. Sharding is not available during tracing, so this relies on jit.
2. If ``x`` is weakly typed or uncommitted, the default sharding is used.
3. If ``shape`` is not None and differs from ``x.shape``, the default sharding is used.
Returns:
An ndarray with the same shape as `x` with its entries set equal to
@@ -1381,19 +1393,20 @@ def full_like(x: ArrayLike | DuckTypedArray,
dtype = dtype or _dtype(x)
if dtypes.issubdtype(dtype, dtypes.extended):
return dtype._rules.full(fill_shape, fill_value, dtype) # type: ignore[union-attr]

use_x_sharding = (
sharding is None and
isinstance(x, array.ArrayImpl) and
not weak_type and x._committed and
# NB: consider reusing x.sharding for mismatched shapes
# if x is replicated or single device.
fill_shape == x.shape)
if use_x_sharding:
assert isinstance(x, array.ArrayImpl) # makes pytype happy.
# TODO(yashkatariya): Use shard_alike in tracing_mode once it is supported.
sharding = x.sharding
val = full(fill_shape, _convert_element_type(fill_value, dtype, weak_type),
sharding=sharding)
# TODO(yashkatariya): Use shard_like in tracing mode too i.e. remove the
# ArrayImpl check.
if shape is None and sharding is None and isinstance(x, array.ArrayImpl):
if xla_extension_version < 227:
sharding = x.sharding # type: ignore[union-attr]
if (not dispatch.is_single_device_sharding(sharding) and
not isinstance(sharding, PmapSharding)):
return array.make_array_from_callback(
type_cast(array.Shape, fill_shape), sharding, lambda idx: val[idx])
else:
return shard_alike.shard_alike(x, val)[1]
return val


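The NB comment above suggests with_sharding_constraint for jitted computations, since the sharding argument is not applied while tracing. A hypothetical sketch of that workaround follows (names and shapes are illustrative, not from the commit):

```python
# Hypothetical sketch of the jitted-mode caveat: under tracing the sharding
# argument is not applied, so an explicit with_sharding_constraint is one way
# to get a sharded output inside jit.
import jax
import jax.numpy as jnp
from jax import lax
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), axis_names=('i',))
sharding = NamedSharding(mesh, P('i'))

@jax.jit
def sharded_zeros_like(x):
  z = lax.full_like(x, 0.0)
  # Constrain the layout explicitly inside the traced computation.
  return jax.lax.with_sharding_constraint(z, sharding)

x = jax.device_put(jnp.ones((8 * len(jax.devices()),)), sharding)
print(sharded_zeros_like(x).sharding)
```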
18 changes: 18 additions & 0 deletions tests/lax_numpy_test.py
@@ -2795,6 +2795,24 @@ def testOnesWithInvalidShape(self):
with self.assertRaises(TypeError):
jnp.ones((-1, 1))

def test_full_like_commited(self):
x = jnp.array((1, 2, 3), dtype=np.int32)
self.assertFalse(x._committed)
self.assertFalse(lax.full_like(x, 1.1)._committed)
x = jax.device_put(x, jax.devices()[-1])
self.assertTrue(x._committed)
y = lax.full_like(x, 1.1)
self.assertTrue(y._committed)
self.assertEqual(x.sharding, y.sharding)

def test_zeros_like_with_explicit_device_and_jitted(self):
x = jnp.array((1, 2, 3), dtype=np.int32)
x = jax.device_put(x, jax.devices()[0])
zeros_like_with_device = partial(jnp.zeros_like, device=jax.devices()[0])
y = jax.jit(zeros_like_with_device)(x)
self.assertEqual(x.shape, y.shape)
self.assertEqual(y.sharding, SingleDeviceSharding(jax.devices()[0]))

@jtu.sample_product(
[dict(shape=shape, out_shape=out_shape, fill_value_shape=fill_value_shape)
for shape in array_shapes
60 changes: 60 additions & 0 deletions tests/multi_device_test.py
@@ -15,6 +15,7 @@

import os
from unittest import SkipTest
import tracemalloc as tm

from absl.testing import absltest

@@ -285,6 +286,65 @@ def test_lax_full_like_sharding(self):
y = lax.full_like(x, 1, sharding=sharding)
self.assertEqual(y.sharding, sharding)

def test_lax_full_like_same_device(self):
devices = self.get_devices()
x = jax.device_put(jnp.ones((100, 100)), devices[1])
y = lax.full_like(x, 1)
self.assertEqual(y.sharding, x.sharding)
self.assertEqual(y.sharding.device_set, {jax.devices()[1]})


def test_lax_full_like_custom_shape_sharded(self):
devices = [self.get_devices()]
mesh = Mesh(devices, axis_names=('i', 'j'))
sharding = NamedSharding(mesh, P('i', 'j'))
x = jnp.array(jnp.arange(8).reshape((1, 8)), dtype=jnp.int32)
x = jax.device_put(x, sharding)
y = lax.full_like(x, fill_value=1.0, shape=())
self.assertEqual(y.shape, ())

def test_lax_full_like_single_device(self):
devices = self.get_devices()
x = jax.device_put(jnp.ones((100, 100)), devices[1])
y = lax.full_like(x, fill_value=1.0, shape=())
self.assertEqual(y.shape, ())
# Currently, if shape is provided, the sharding reverts to the default.
# This might change in the future and this test might need to be updated.
self.assertEqual(
y.sharding,
jax.sharding.SingleDeviceSharding(jax.devices()[0]))


def test_lax_full_like_efficient(self):
devices = self.get_devices()
if len(devices) < 4:
self.skipTest("test requires 4 devices")
mem_stats = devices[0].memory_stats()
if mem_stats is None:
self.skipTest('Can only run this test on devices with memory stats')
mesh = Mesh(devices, axis_names=("i"))
sharding = NamedSharding(mesh, P('i'))
available_memory = mem_stats['bytes_reservable_limit']
array_size = available_memory // (6 * len(devices)) * len(devices)
# Set up tracemalloc to track memory usage.
tm.start()
x = lax.full([array_size], sharding=sharding, fill_value=1.0,
dtype=jnp.float32)
y = lax.full_like(x, fill_value=1.0, dtype=jnp.float32)

# Wait until the computation has finished to ensure we are measuring the
# correct thing.
y.block_until_ready()
unused_current, peak = tm.get_traced_memory()
# Verify that we don't create large CPU arrays.
self.assertLess(peak, array_size // len(devices))

# Important: make sure that all jax computation in this part has finished
# before we stop tracemalloc.
jax.effects_barrier()
tm.stop()
self.assertEqual(y.sharding, x.sharding)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())
