Skip to content

Commit

Permalink
Internal change.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 636574625
  • Loading branch information
cpgaffney1 authored and jax authors committed May 23, 2024
1 parent 63a13f5 commit 8f6fc11
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion jax/experimental/array_serialization/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,21 @@ async def async_serialize(
primary_host: Optional[int] = 0,
replica_id: int = 0,
):
"""Serialize an array using TensorStore.
Args:
arr_inp: The array to serialize.
tensorstore_spec: The tensorstore spec to use.
commit_future: A list of futures that will be appended to. The futures can
be awaited asynchronously. If None, the futures will be awaited
synchronously by this method.
context: ts.Context instance.
primary_host: Primary host, which indicates the host that will be treated as
the "leader". If None, all hosts are treated as the primary. DO NOT USE
unless you are sure you know what you are doing.
replica_id: Allows overriding the shard replica id that will be saved.
DO NOT USE unless you are sure you know what you are doing.
"""
if (isinstance(arr_inp, array.ArrayImpl) and jax.process_count() > 1 and
arr_inp.is_fully_addressable):
raise ValueError(
Expand All @@ -202,7 +217,9 @@ async def async_serialize(
f'the path "{tensorstore_spec["kvstore"]["path"]}".')

if primary_host is None and is_remote_storage(tensorstore_spec):
raise ValueError(
# Not strictly an error because users may manually split directories into
# per-process subdirectories.
logging.warning(
'When primary_host is set to None and remote storage is used,'
' serialization is not allowed, as this may lead to a race condition'
' between processes.'
Expand Down

0 comments on commit 8f6fc11

Please sign in to comment.