From 852e39e2edb6e82a21ac07e352a65599881467ce Mon Sep 17 00:00:00 2001 From: yashkatariya Date: Thu, 3 Mar 2022 13:40:42 -0800 Subject: [PATCH] convert to doctest --- jax/experimental/global_device_array.py | 136 ++++++++++++++---------- 1 file changed, 82 insertions(+), 54 deletions(-) diff --git a/jax/experimental/global_device_array.py b/jax/experimental/global_device_array.py index 38fb944d6748..1c1c4b8bb995 100644 --- a/jax/experimental/global_device_array.py +++ b/jax/experimental/global_device_array.py @@ -219,40 +219,51 @@ class GlobalDeviceArray: is_fully_replicated : True if the full array value is present on all devices of the global mesh. - Example:: - - # Logical mesh is (hosts, devices) - assert global_mesh.shape == {'x': 4, 'y': 8} - - global_input_shape = (64, 32) - mesh_axes = P('x', 'y') - - # Dummy example data; in practice we wouldn't necessarily materialize global data - # in a single process. - global_input_data = np.arange( - np.prod(global_input_shape)).reshape(global_input_shape) - - def get_local_data_slice(index): - # index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4)) - # This method will be called per-local device from the GDA constructor. - return global_input_data[index] - - gda = GlobalDeviceArray.from_callback( - global_input_shape, global_mesh, mesh_axes, get_local_data_slice) - - f = pjit(lambda x: x @ x.T, out_axis_resources = P('y', 'x')) - + Example: + + >>> from jax.experimental.maps import Mesh + >>> from jax.experimental import PartitionSpec as P + >>> import numpy as np + + >>> assert jax.device_count() == 8 + >>> global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) + >>> # Logical mesh is (hosts, devices) + >>> assert global_mesh.shape == {'x': 4, 'y': 2} + + >>> global_input_shape = (8, 2) + >>> mesh_axes = P('x', 'y') + + >>> # Dummy example data; in practice we wouldn't necessarily materialize global data + >>> # in a single process. + >>> global_input_data = np.arange( + ... np.prod(global_input_shape)).reshape(global_input_shape) + + >>> def get_local_data_slice(index): + ... # index will be a tuple of slice objects, e.g. (slice(0, 16), slice(0, 4)) + ... # This method will be called per-local device from the GDA constructor. + ... return global_input_data[index] + + >>> gda = GlobalDeviceArray.from_callback( + ... global_input_shape, global_mesh, mesh_axes, get_local_data_slice) + + >>> print(gda.shape) + (8, 2) + >>> print(gda.local_shards[0].data) # Access the data on a single local device + [[0] + [2]] + >>> print(gda.local_shards[0].data.shape) + (2, 1) + >>> # Numpy-style index into the global array that this data shard corresponds to + >>> print(gda.local_shards[0].index) + (slice(0, 2, None), slice(0, 1, None)) + + # Allow pjit to output GDAs + jax.config.update('jax_parallel_functions_output_gda', True) + + f = pjit(lambda x: x @ x.T, in_axis_resources=P('x', 'y'), out_axis_resources = P('x', 'y')) with global_mesh: out = f(gda) - print(type(out)) # GlobalDeviceArray - print(out.shape) # global shape == (64, 64) - print(out.local_shards[0].data) # Access the data on a single local device, - # e.g. for checkpointing - print(out.local_shards[0].data.shape) # per-device shape == (8, 16) - print(out.local_shards[0].index) # Numpy-style index into the global array that - # this data shard corresponds to - # `out` can be passed to another pjit call, out.local_shards can be used to # export the data to non-jax systems (e.g. for checkpointing or logging), etc. """ @@ -394,11 +405,17 @@ def from_callback(cls, global_shape: Shape, global_mesh: pxla.Mesh, Example:: - global_input_shape = (8, 2) - global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) - def cb(index): - return global_input_data[index] - gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb) + >>> from jax.experimental.maps import Mesh + >>> import numpy as np + >>> global_input_shape = (8, 8) + >>> mesh_axes = ['x', 'y'] + >>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('x', 'y')) + >>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) + >>> def cb(index): + ... return global_input_data[index] + >>> gda = GlobalDeviceArray.from_callback(global_input_shape, global_mesh, mesh_axes, cb) + >>> gda.local_data(0).shape + (4, 2) Args: global_shape : The global shape of the array @@ -431,13 +448,18 @@ def from_batched_callback(cls, global_shape: Shape, Example:: - global_input_shape = (8, 2) - global_input_data = np.arange( - prod(global_input_shape)).reshape(global_input_shape) - def batched_cb(indices): - self.assertEqual(len(indices),len(global_mesh.local_devices)) - return [global_input_data[index] for index in indices] - gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb) + >>> from jax.experimental.maps import Mesh + >>> import numpy as np + >>> global_input_shape = (8, 2) + >>> mesh_axes = ['x'] + >>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) + >>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) + >>> def batched_cb(indices): + ... assert len(indices) == len(global_mesh.local_devices) + ... return [global_input_data[index] for index in indices] + >>> gda = GlobalDeviceArray.from_batched_callback(global_input_shape, global_mesh, mesh_axes, batched_cb) + >>> gda.local_data(0).shape + (2, 2) Args: global_shape : The global shape of the array @@ -469,17 +491,23 @@ def from_batched_callback_with_devices( Example:: - global_input_shape = (8, 2) - global_input_data = np.arange(prod(global_input_shape), dtype=np.float32).reshape(global_input_shape) - def cb(cb_inp): - self.assertLen(cb_inp, len(global_mesh.local_devices)) - dbs = [] - for inp in cb_inp: - index, devices = inp - array = global_input_data[index] - dbs.extend([jax.device_put(array, device) for device in devices]) - return dbs - gda = GlobalDeviceArray.from_batched_callback_with_devices(global_input_shape, global_mesh, mesh_axes, cb) + >>> from jax.experimental.maps import Mesh + >>> import numpy as np + >>> global_input_shape = (8, 2) + >>> mesh_axes = [('x', 'y')] + >>> global_mesh = global_mesh = Mesh(np.array(jax.devices()).reshape(4, 2), ('x', 'y')) + >>> global_input_data = np.arange(prod(global_input_shape)).reshape(global_input_shape) + >>> def cb(cb_inp): + ... dbs = [] + ... for inp in cb_inp: + ... index, devices = inp + ... array = global_input_data[index] + ... dbs.extend([jax.device_put(array, device) for device in devices]) + ... return dbs + >>> gda = GlobalDeviceArray.from_batched_callback_with_devices( + ... global_input_shape, global_mesh, mesh_axes, cb) + >>> gda.local_data(0).shape + (1, 2) Args: global_shape : The global shape of the array