Skip to content

Commit

Permalink
Add checkpointing support for Array similar to GDA.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 469271107
  • Loading branch information
yashk2810 authored and jax authors committed Aug 22, 2022
1 parent 384776f commit 7cdb7e1
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 34 deletions.
3 changes: 3 additions & 0 deletions jax/experimental/array.py
Expand Up @@ -264,6 +264,9 @@ def devices(self) -> List[Device]:
self._check_if_deleted()
return list(self.sharding.device_set)

def to_py(self) -> np.ndarray:
  """Returns the value held in `self._value`.

  Returns:
    The object stored in `self._value`, annotated as a NumPy array.
  """
  host_value = self._value
  return host_value

@pxla.maybe_cached_property
def addressable_shards(self) -> Sequence[Shard]:
self._check_if_deleted()
Expand Down
96 changes: 68 additions & 28 deletions jax/experimental/gda_serialization/serialization.py
Expand Up @@ -15,6 +15,7 @@

import abc
import asyncio
import enum
import itertools
from functools import partial
import re
Expand All @@ -25,6 +26,8 @@
import jax
from jax._src import distributed
from jax.experimental import global_device_array as gda
from jax.experimental import array
from jax.experimental import sharding
from jax.experimental.maps import Mesh
import jax.numpy as jnp
import numpy as np
Expand All @@ -37,6 +40,24 @@
_module_unique_count = itertools.count()


async def create_async_array_from_callback(
    global_shape: array.Shape,
    inp_sharding: sharding.XLACompatibleSharding,
    data_callback: Callable[[array.Index], asyncio.Future],
):
  """Builds an `array.Array` from per-device data produced asynchronously.

  Args:
    global_shape: Shape of the full array; passed to
      `inp_sharding.devices_indices_map` and used for the result's aval.
    inp_sharding: Sharding that supplies the device -> index mapping and the
      addressable devices the data is placed on.
    data_callback: Called once per addressable device with that device's
      index; must return an awaitable that resolves to the data for that index.

  Returns:
    A committed `array.Array` with sharding `inp_sharding`, whose dtype is
    taken from the first device buffer.
  """
  device_to_index_map = inp_sharding.devices_indices_map(global_shape)
  # Read once so the callbacks and the device_put calls below iterate over the
  # same device sequence.
  addressable_devices = inp_sharding._addressable_device_assignment
  future_arrays = [data_callback(device_to_index_map[d])  # type: ignore
                   for d in addressable_devices]
  # Pause here and resume once every future in `future_arrays` is ready.
  # device_put cannot happen with future_arrays.
  local_arrays = await asyncio.gather(*future_arrays)

  # The loop variable is deliberately not called `array`: that name is the
  # imported module used on the last line.
  dbs = [jax.device_put(local_array, device)
         for local_array, device in zip(local_arrays, addressable_devices)]
  aval = jax.ShapedArray(global_shape, dbs[0].dtype)
  return array.Array(aval, inp_sharding, dbs, committed=True)


async def create_async_gda_from_callback(
global_shape: gda.Shape,
global_mesh: Mesh,
Expand All @@ -58,19 +79,22 @@ async def create_async_gda_from_callback(
gda._GdaFastPathArgs(global_idx_rid, local_devices))


def _get_metadata(arr):
  """Builds the TensorStore metadata dict describing `arr`.

  Args:
    arr: Either an `array.Array` or an object exposing `local_data(0)`
      (the GlobalDeviceArray code path). Must provide `dtype` and `shape`.

  Returns:
    A dict with the keys 'compressor', 'shape', 'chunks' and 'dtype'.
  """
  if arr.dtype == jnp.bfloat16:
    # Tensorstore uses 'bfloat16', not '<V2'.
    dtype = 'bfloat16'
  else:
    dtype = np.dtype(arr.dtype).str
  # The chunk shape is taken from the first local shard.
  if isinstance(arr, array.Array):
    local_shape = arr._arrays[0].shape
  else:
    local_shape = arr.local_data(0).shape
  return {
      'compressor': {
          'id': 'gzip'
      },
      'shape': arr.shape,
      # Every chunk dimension is clamped to at least 1 (presumably so that
      # zero-sized shards still yield a valid chunk shape -- confirm).
      'chunks': np.array(np.maximum(1, local_shape)),
      'dtype': dtype,
  }

Expand Down Expand Up @@ -121,12 +145,15 @@ async def release_bytes(self, requested_bytes):
self._cv.notify_all()


async def async_serialize(gda_inp: gda.GlobalDeviceArray, tensorstore_spec,
commit_future=None):
async def async_serialize(arr_inp, tensorstore_spec, commit_future=None):
if (isinstance(arr_inp, array.Array) and jax.process_count() > 1 and
arr_inp.is_fully_addressable()):
raise ValueError('Passing fully addressable Arrays to a multi-host '
'serialization is not allowed.')
# 'metadata' may not be present at the top level (for example, if we are using
# a 'cast' driver).
if not _spec_has_metadata(tensorstore_spec):
tensorstore_spec['metadata'] = _get_metadata(gda_inp)
tensorstore_spec['metadata'] = _get_metadata(arr_inp)

if jax.process_index() == 0:
open_future = ts.open(
Expand Down Expand Up @@ -156,14 +183,17 @@ async def _write_array(shard):
else:
await write_future.commit

future_write_state = jax.tree_util.tree_map(_write_array,
gda_inp.local_shards)
if isinstance(arr_inp, array.Array):
local_shards = arr_inp.addressable_shards
else:
local_shards = arr_inp.local_shards
future_write_state = jax.tree_util.tree_map(_write_array, local_shards)
return await asyncio.gather(*future_write_state)


def run_serialization(arrays, tensorstore_specs):
  """Serializes `arrays` to TensorStore and blocks until all writes return.

  Args:
    arrays: GlobalDeviceArrays or Arrays to write; mapped together with
      `tensorstore_specs` through `jax.tree_util.tree_map`.
    tensorstore_specs: TensorStore specs, one per entry of `arrays`.
  """
  async def _run_serializer():
    # tree_map only creates the coroutines; gather is what drives them.
    future_writer = jax.tree_util.tree_map(async_serialize, arrays, tensorstore_specs)
    return await asyncio.gather(*future_writer)
  asyncio.run(_run_serializer())

Expand All @@ -189,9 +219,15 @@ def estimate_read_memory_footprint(t: ts.TensorStore) -> int:
return num_bytes


@enum.unique
class ArrayFlavor(enum.Enum):
  """Selects which array type deserialization should return."""
  # Return a GlobalDeviceArray.
  GDA = 0
  # Return a jax.experimental.array.Array.
  Array = 1


async def async_deserialize(mesh, mesh_axes, tensorstore_spec,
global_shape=None, dtype=None,
byte_limiter: Optional[_LimitInFlightBytes] = None):
byte_limiter: Optional[_LimitInFlightBytes] = None,
return_arr_flavor: ArrayFlavor = ArrayFlavor.GDA):
t = await ts.open(ts.Spec(tensorstore_spec), open=True, context=TS_CONTEXT)
shape = t.shape if global_shape is None else global_shape
new_shard_shape = gda.get_shard_shape(tuple(shape), mesh, mesh_axes)
Expand Down Expand Up @@ -222,23 +258,29 @@ async def cb(index):
await byte_limiter.release_bytes(requested_bytes)
return out

return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)
if return_arr_flavor == ArrayFlavor.Array:
inp_sharding = sharding.MeshPspecSharding(mesh, mesh_axes)
return await create_async_array_from_callback(tuple(shape), inp_sharding, cb)
else:
return await create_async_gda_from_callback(tuple(shape), mesh, mesh_axes, cb)


def run_deserialization(global_meshes, mesh_axes, tensorstore_specs,
                        global_shapes=None, dtypes=None, concurrent_gb=32,
                        return_arr_flavor=ArrayFlavor.GDA):
  """Deserializes every TensorStore spec and blocks until all reads finish.

  Args:
    global_meshes: Meshes, one per spec, forwarded to `async_deserialize`.
    mesh_axes: Mesh axes (partition specs), one per spec.
    tensorstore_specs: TensorStore specs to read; must support `len()`.
    global_shapes: Optional per-spec global shapes. When None, a list of None
      with one entry per spec is passed instead.
    dtypes: Optional per-spec dtypes. When None, a list of None with one entry
      per spec is passed instead.
    concurrent_gb: Size, in units of 10**9 bytes, given to the shared
      `_LimitInFlightBytes` limiter.
    return_arr_flavor: `ArrayFlavor` forwarded to `async_deserialize`.

  Returns:
    The list of results gathered from the `async_deserialize` calls.
  """
  concurrent_bytes = concurrent_gb * 10**9
  no_overrides = [None] * len(tensorstore_specs)

  async def _run_deserializer():
    # Object should be created once per process.
    byte_limiter = _LimitInFlightBytes(concurrent_bytes)

    future_arrays = jax.tree_util.tree_map(
        partial(async_deserialize, byte_limiter=byte_limiter,
                return_arr_flavor=return_arr_flavor),
        global_meshes, mesh_axes, tensorstore_specs,
        no_overrides if global_shapes is None else global_shapes,
        no_overrides if dtypes is None else dtypes)
    return await asyncio.gather(*future_arrays)
  return asyncio.run(_run_deserializer())


Expand Down Expand Up @@ -299,10 +341,7 @@ def wait_until_finished(self):
"""Blocks until serialization has finished."""

@abc.abstractmethod
def serialize(self, arrays, tensorstore_specs, *,
              on_commit_callback: Callable[[], None]):
  """Serializes GlobalDeviceArrays or Arrays to TensorStore.

  Args:
    arrays: GlobalDeviceArrays or Arrays that should be serialized.
    tensorstore_specs: TensorStore specs that are used to serialize `arrays`.
    on_commit_callback: Zero-argument callable. Presumably invoked by concrete
      managers once the write has been committed -- confirm against the
      implementation.
  """

Expand Down Expand Up @@ -396,8 +435,8 @@ def _add_futures(self, futures: Sequence[asyncio.Future]):
class GlobalAsyncCheckpointManager(AsyncManager, GlobalAsyncCheckpointManagerBase):
"""Responsible for serializing GDAs via TensorStore."""

def serialize(self, gdas, tensorstore_specs, *, on_commit_callback):
"""Serializes GlobalDeviceArrays via TensorStore asynchronously.
def serialize(self, arrays, tensorstore_specs, *, on_commit_callback):
"""Serializes GlobalDeviceArrays or Arrays via TensorStore asynchronously.
TensorStore writes to a storage layer in 2 steps:
* Reading/copying from the source after which the source can be modified.
Expand All @@ -409,8 +448,9 @@ def serialize(self, gdas, tensorstore_specs, *, on_commit_callback):
finish in a separate thread allowing other computation to proceed.
Args:
gdas: GlobalDeviceArrays that should be serialized.
tensorstore_specs: TensorStore specs that are used to serialize GDAs.
arrays: GlobalDeviceArrays or Arrays that should be serialized.
tensorstore_specs: TensorStore specs that are used to serialize GDAs or
Arrays.
temp_checkpoint_dir: Temporary checkpoint directory where the checkpoints
will be written.
final_checkpoint_dir: Final checkpoint directory where the checkpoints
Expand All @@ -423,7 +463,7 @@ def serialize(self, gdas, tensorstore_specs, *, on_commit_callback):

async def _run_serializer():
future_writer = jax.tree_util.tree_map(
async_serialize, gdas, tensorstore_specs, commit_futures)
async_serialize, arrays, tensorstore_specs, commit_futures)
return await asyncio.gather(*future_writer)

asyncio.run(_run_serializer())
Expand Down
62 changes: 61 additions & 1 deletion jax/experimental/gda_serialization/serialization_test.py
Expand Up @@ -20,6 +20,8 @@
from jax._src import test_util as jtu
from jax._src import util
from jax.config import config
from jax.experimental import array
from jax.experimental.sharding import MeshPspecSharding
from jax.experimental import PartitionSpec as P
from jax.experimental.global_device_array import GlobalDeviceArray
from jax.experimental.gda_serialization import serialization
Expand Down Expand Up @@ -53,7 +55,7 @@ def cb2(index):
ckpt_dir2 = pathlib.Path(self.create_tempdir('second').full_path)

# Third GDA
def cb3(index):
def cb3(_):
return np.array([])
global_mesh1d = jtu.create_global_mesh((8,), ('x',))
gda3 = GlobalDeviceArray.from_callback((0,), global_mesh1d, P(None), cb3)
Expand Down Expand Up @@ -89,6 +91,64 @@ def cb3(index):
self.assertArraysEqual(s.data.to_py(), np.array([]))
self.assertEqual(m3.dtype, np.float32)

def test_checkpointing_with_array(self):
  """Serializes three sharded Arrays and checks the deserialized shards."""
  global_mesh = jtu.create_global_mesh((4, 2), ('x', 'y'))
  inp_shape = (8, 2)
  pspec = P('x', 'y')
  num = util.prod(inp_shape)

  def _make_sharded(source):
    # Builds one Array over the 2D mesh; each shard is sliced out of `source`.
    return array.make_array_from_callback(
        inp_shape, MeshPspecSharding(global_mesh, pspec),
        lambda idx: source[idx])

  # First Array: values 0 .. num - 1.
  a1 = _make_sharded(np.arange(num).reshape(inp_shape))
  # Second Array: values num .. 2 * num - 1.
  a2 = _make_sharded(np.arange(num, num + num).reshape(inp_shape))
  # Third Array: zero-sized, laid out with P(None) over a 1D mesh.
  global_mesh1d = jtu.create_global_mesh((8,), ('x',))
  a3 = array.make_array_from_callback(
      (0,), MeshPspecSharding(global_mesh1d, P(None)),
      lambda _: np.array([]))

  # One temporary checkpoint directory per array.
  ckpt_paths = [
      str(pathlib.Path(self.create_tempdir(name).full_path))
      for name in ('first', 'second', 'third')
  ]
  tspecs = jax.tree_util.tree_map(serialization.get_tensorstore_spec, ckpt_paths)

  serialization.run_serialization([a1, a2, a3], tspecs)

  # The second array is read back with a different partitioning (P('x')).
  m1, m2, m3 = serialization.run_deserialization(
      [global_mesh, global_mesh, global_mesh1d],
      [pspec, P('x'), P(None)],
      tspecs, return_arr_flavor=serialization.ArrayFlavor.Array)

  m1_shard0 = m1.addressable_shards[0]
  m1_shard1 = m1.addressable_shards[1]
  self.assertArraysEqual(m1_shard0.data.to_py(), np.array([[0], [2]]))
  self.assertArraysEqual(m1_shard1.data.to_py(), np.array([[1], [3]]))
  self.assertEqual(m1_shard0.data.shape, (2, 1))
  self.assertEqual(m1.dtype, np.int32)

  # With P('x') the first two shards hold the same rows.
  for shard_idx in (0, 1):
    self.assertArraysEqual(m2.addressable_shards[shard_idx].data.to_py(),
                           np.array([[16, 17], [18, 19]]))
  self.assertEqual(m2.addressable_shards[0].data.shape, (2, 2))
  self.assertEqual(m2.dtype, np.int32)

  for replica, shard in enumerate(m3.addressable_shards):
    self.assertEqual(shard.index, (slice(None),))
    self.assertEqual(shard.replica_id, replica)
    self.assertArraysEqual(shard.data.to_py(), np.array([]))
  self.assertEqual(m3.dtype, np.float32)

def test_checkpointing_with_bigger_shape(self):
global_mesh = jtu.create_global_mesh((2, 2), ('x', 'y'))
global_input_shape = (8, 2)
Expand Down
10 changes: 5 additions & 5 deletions jax/experimental/sharding.py
Expand Up @@ -167,7 +167,7 @@ def device_set(self) -> Set[Device]:
return set(self.mesh.devices.flat)

def devices_indices_map(
self, global_shape: Shape) -> Mapping[Device, Optional[Index]]:
self, global_shape: Shape) -> Mapping[Device, Index]:
# TODO(yashkatariya): Remove this when utilities are moved to pxla.py.
from jax.experimental import global_device_array

Expand All @@ -186,7 +186,7 @@ def _to_xla_op_sharding(
self,
num_dimensions: int,
axis_ctx: Optional[Union[mlir.SPMDAxisContext, mlir.ShardingContext]] = None
) -> Optional[xc.OpSharding]:
) -> xc.OpSharding:
from jax.experimental.pjit import get_array_mapping

array_mapping = get_array_mapping(self._parsed_pspec)
Expand Down Expand Up @@ -233,7 +233,7 @@ def device_set(self) -> Set[Device]:
return {self._device}

def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Index]:
  """Maps this sharding's single device to an index covering the whole array.

  Args:
    global_shape: Shape of the global array; only its length (rank) is used.

  Returns:
    A one-entry dict from `self._device` to a tuple holding one
    `slice(None)` per dimension of `global_shape`.
  """
  return {self._device: (slice(None),) * len(global_shape)}

def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
Expand All @@ -243,7 +243,7 @@ def device_replica_id_map(self, global_shape: Shape) -> Mapping[Device, int]:
def _device_assignment(self) -> XLADeviceAssignment:
return [self._device]

def _to_xla_op_sharding(self, num_dimensions: int) -> xc.OpSharding:
  """Returns the op sharding built by `_get_replicated_op_sharding()`.

  Args:
    num_dimensions: Unused by this implementation.
  """
  return _get_replicated_op_sharding()


Expand Down Expand Up @@ -324,7 +324,7 @@ def device_set(self) -> Set[Device]:

@functools.lru_cache(maxsize=4096)
def devices_indices_map(
    self, global_shape: Shape) -> Mapping[Device, Index]:
  """Maps each device of this sharding to its index into the global array.

  Results are memoized per `(self, global_shape)` by the `lru_cache`
  decorator, so both must be hashable.

  Args:
    global_shape: Shape of the global array being sharded.

  Returns:
    A dict pairing each entry of `self._devices` with the index computed for
    it by `pxla.op_sharding_to_indices`.
  """
  indices = pxla.op_sharding_to_indices(self._op_sharding, global_shape,
                                        len(self._devices))
  return dict(safe_zip(self._devices, indices))
Expand Down

0 comments on commit 7cdb7e1

Please sign in to comment.