Skip to content

Commit

Permalink
convert to doctest
Browse files Browse the repository at this point in the history
  • Loading branch information
yashk2810 committed Mar 3, 2022
1 parent f0df7b7 commit 852e39e
Showing 1 changed file with 82 additions and 54 deletions.
136 changes: 82 additions & 54 deletions jax/experimental/global_device_array.py
Expand Up @@ -219,40 +219,51 @@ class GlobalDeviceArray:
is_fully_replicated : True if the full array value is present on all devices
of the global mesh.
Example::
# Logical mesh is (hosts, devices)
assert global_mesh.shape == {'x': 4, 'y': 8}
global_input_shape = (64, 32)
mesh_axes = P('x', 'y')
# Dummy example data; in practice we wouldn't necessarily materialize global data
# in a single process.
global_input_data = np.arange(
np.prod(global_input_shape)).reshape(global_input_shape)
def get_local_data_slice(index):
# index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4))
# This method will be called per-local device from the GDA constructor.
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(
global_input_shape, global_mesh, mesh_axes, get_local_data_slice)
f = pjit(lambda x: x @ x.T, out_axis_resources = P('y', 'x'))
Example:
>>> from jax.experimental.maps import Mesh
>>> from jax.experimental import PartitionSpec as P
>>> import numpy as np
>>> assert jax.device_count() == 8
>>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> # Logical mesh is (hosts, devices)
>>> assert global_mesh.shape == {'x': 4, 'y': 2}
>>> global_input_shape = (8, 2)
>>> mesh_axes = P('x', 'y')
>>> # Dummy example data; in practice we wouldn't necessarily materialize global data
>>> # in a single process.
>>> global_input_data = np.arange(
... np.prod(global_input_shape)).reshape(global_input_shape)
>>> def get_local_data_slice(index):
... # index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4))
... # This method will be called per-local device from the GDA constructor.
... return global_input_data[index]
>>> gda = GlobalDeviceArray.from_callback(
... global_input_shape, global_mesh, mesh_axes, get_local_data_slice)
>>> print(gda.shape)
(8, 2)
>>> print(gda.local_shards[0].data) # Access the data on a single local device
[[0]
[2]]
>>> print(gda.local_shards[0].data.shape)
(2, 1)
>>> # Numpy-style index into the global array that this data shard corresponds to
>>> print(gda.local_shards[0].index)
(slice(0, 2, None), slice(0, 1, None))
# Allow pjit to output GDAs
jax.config.update('jax_parallel_functions_output_gda', True)
f = pjit(lambda x: x @ x.T, in_axis_resources=P('x', 'y'), out_axis_resources = P('x', 'y'))
with global_mesh:
out = f(gda)
print(type(out)) # GlobalDeviceArray
print(out.shape) # global shape == (64, 64)
print(out.local_shards[0].data) # Access the data on a single local device,
# e.g. for checkpointing
print(out.local_shards[0].data.shape) # per-device shape == (8, 16)
print(out.local_shards[0].index) # Numpy-style index into the global array that
# this data shard corresponds to
# `out` can be passed to another pjit call, out.local_shards can be used to
# export the data to non-jax systems (e.g. for checkpointing or logging), etc.
"""
Expand Down Expand Up @@ -394,11 +405,17 @@ def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh,
Example::
global_input_shape = (8, 2)
global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
def cb(index):
return global_input_data[index]
gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
>>> from jax.experimental.maps import Mesh
>>> import numpy as np
>>> global_input_shape = (8, 8)
>>> mesh_axes = ['x', 'y']
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> def cb(index):
... return global_input_data[index]
>>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb)
>>> gda.local_data(0).shape
(4, 2)
Args:
global_shape : The global shape of the array
Expand Down Expand Up @@ -431,13 +448,18 @@ def from_batched_callback(cls, global_shape: Shape,
Example::
global_input_shape = (8, 2)
global_input_data = np.arange(
prod(global_input_shape)).reshape(global_input_shape)
def batched_cb(indices):
self.assertEqual(len(indices),len(global_mesh.local_devices))
return [global_input_data[index] for index in indices]
gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb)
>>> from jax.experimental.maps import Mesh
>>> import numpy as np
>>> global_input_shape = (8, 2)
>>> mesh_axes = ['x']
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> def batched_cb(indices):
... assert len(indices) == len(global_mesh.local_devices)
... return [global_input_data[index] for index in indices]
>>> gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb)
>>> gda.local_data(0).shape
(2, 2)
Args:
global_shape : The global shape of the array
Expand Down Expand Up @@ -469,17 +491,23 @@ def from_batched_callback_with_devices(
Example::
global_input_shape = (8, 2)
global_input_data = np.arange(prod(global_input_shape), dtype=np.float32).reshape(global_input_shape)
def cb(cb_inp):
self.assertLen(cb_inp, len(global_mesh.local_devices))
dbs = []
for inp in cb_inp:
index, devices = inp
array = global_input_data[index]
dbs.extend([jax.device_put(array, device) for device in devices])
return dbs
gda = GlobalDeviceArray.from_batched_callback_with_devices(global_input_shape, global_mesh, mesh_axes, cb)
>>> from jax.experimental.maps import Mesh
>>> import numpy as np
>>> global_input_shape = (8, 2)
>>> mesh_axes = [('x', 'y')]
>>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y'))
>>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape)
>>> def cb(cb_inp):
... dbs = []
... for inp in cb_inp:
... index, devices = inp
... array = global_input_data[index]
... dbs.extend([jax.device_put(array, device) for device in devices])
... return dbs
>>> gda = GlobalDeviceArray.from_batched_callback_with_devices(
... global_input_shape, global_mesh, mesh_axes, cb)
>>> gda.local_data(0).shape
(1, 2)
Args:
global_shape : The global shape of the array
Expand Down

0 comments on commit 852e39e

Please sign in to comment.