Skip to content

Commit

Permalink
Enable correct resumption from the end of an epoch (#700)
Browse files Browse the repository at this point in the history
* typo

* potensh

* tests

* tests

* Update streaming/base/partition/relaxed.py

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* Update streaming/base/partition/relaxed.py

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>

* ready

---------

Co-authored-by: Mihir Patel <mihir.v.patel7@gmail.com>
  • Loading branch information
snarayan21 and mvpatel2000 committed Jun 18, 2024
1 parent 2b53acb commit a5b9eea
Show file tree
Hide file tree
Showing 4 changed files with 86 additions and 19 deletions.
56 changes: 39 additions & 17 deletions streaming/base/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -922,14 +922,18 @@ def resample_streams(
sample_ids = np.concatenate(sample_ids).astype(np.int64)
return shuffle_units, sample_ids

def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]:
def _share_work(
self,
sample_ids: NDArray[np.int64],
) -> Tuple[SharedMemory, Optional[SharedMemory]]:
"""Put an epoch's sample ordering into shared memory.
Args:
sample_ids (NDArray[np.int64]): Sample IDs.
Returns:
Tuple[SharedMemory, SharedMemory]: Shared memory arrays containing shape and data.
Tuple[SharedMemory, Optional[SharedMemory]]: Shared memory arrays containing shape and
data, if present.
"""
ndim = 5

Expand All @@ -945,19 +949,26 @@ def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, Shar
shape_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
shape_shm.buf[:size] = np.array(sample_ids.shape, np.int64).tobytes()

# Save the generated epoch data to shared memory.
name = _get_path(self._shm_prefix_int, EPOCH_DATA)
size = sample_ids.size * np.int64().nbytes
data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
data_shm.buf[:size] = sample_ids.tobytes()
if sample_ids.size > 0:
# Save the generated epoch data to shared memory, but only if the sample partition is
# non-empty. Otherwise, the end of the epoch has been reached.
name = _get_path(self._shm_prefix_int, EPOCH_DATA)
size = sample_ids.size * np.int64().nbytes
data_shm = SharedMemory(name=name, create=True, size=size, auto_cleanup=False)
data_shm.buf[:size] = sample_ids.tobytes()

return shape_shm, data_shm
return shape_shm, data_shm

def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
else:

return shape_shm, None

def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]:
"""Get an epoch's sample ordering from shared memory.
Returns:
NDArray[np.int64]: Sample IDs.
Tuple[NDArray[np.int64], SharedMemory, Optional[SharedMemory]]: Sample IDs, shared
memory array for shape, and shared memory array for data, if present.
"""
ndim = 5

Expand All @@ -967,13 +978,22 @@ def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
shape = tuple(np.ndarray(5, buffer=shape_shm.buf, dtype=np.int64))

# Attach to the generated epoch data in shared memory.
name = _get_path(self._shm_prefix_int, EPOCH_DATA)
size = int(np.prod(shape)) * np.int64().nbytes
data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)
num_elements = int(np.prod(shape))

if num_elements > 0:
# Attach to the generated epoch data in shared memory, but only if the sample partition
# is non-empty. Otherwise, the end of the epoch has been reached.
name = _get_path(self._shm_prefix_int, EPOCH_DATA)
size = num_elements * np.int64().nbytes
data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)

return sample_ids, shape_shm, data_shm

else:

return sample_ids, shape_shm, data_shm
sample_ids = np.empty(shape=shape, dtype=np.int64)
return sample_ids, shape_shm, None

def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
"""Get this worker's partition of this epoch's sample space.
Expand Down Expand Up @@ -1025,7 +1045,9 @@ def _get_work(self, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:

# Now clean up after ourselves.
shape_shm.cleanup()
data_shm.cleanup()
# Can be None if the sample partition was empty.
if data_shm is not None:
data_shm.cleanup()

return worker_sample_ids

Expand Down
2 changes: 1 addition & 1 deletion streaming/base/partition/orig.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_partitions_orig(num_samples: int,
NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
batches per worker, batch size).
"""
if num_samples <= drop_first:
if num_samples < drop_first:
raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
f'({num_samples})')

Expand Down
2 changes: 1 addition & 1 deletion streaming/base/partition/relaxed.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def get_partitions_relaxed(num_samples: int,
NDArray[np.int64]: Partitions of shape (physical nodes, ranks per node, workers per rank,
batches per worker, batch size).
"""
if num_samples <= drop_first:
if num_samples < drop_first:
raise ValueError(f'Resuming further into the dataset ({drop_first}) than it has samples ' +
f'({num_samples})')

Expand Down
45 changes: 45 additions & 0 deletions tests/test_partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,51 @@ def test_partition_walk(partition_algo: str):
assert x.shape == (22, 8, 8, 1, 10)


@pytest.mark.parametrize('num_samples', [400, 1000])
@pytest.mark.parametrize('num_canonical_nodes', [1, 4])
@pytest.mark.parametrize('num_physical_nodes', [1, 4])
@pytest.mark.parametrize('ranks_per_node', [1, 8])
@pytest.mark.parametrize('workers_per_rank', [1, 8])
@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
def test_partition_drop_all(num_samples: int, num_canonical_nodes: int, num_physical_nodes: int,
                            ranks_per_node: int, workers_per_rank: int, batch_size: int,
                            partition_algo: str):
    """Dropping every sample must yield an empty partition of the correct shape."""
    initial_physical_nodes = None
    use_resized_relaxed = (partition_algo == 'relaxed' and num_canonical_nodes == 4 and
                           ranks_per_node == 8)
    if use_resized_relaxed:
        # Exercise the relaxed algorithm's path where the physical node count differs
        # from the initial one, scaling the workload accordingly.
        initial_physical_nodes = 3
        num_canonical_nodes = 3
        batch_size *= 3
        num_samples *= 3

    # Resume exactly at the end of the epoch: every sample has already been seen.
    drop_first = num_samples

    result = get_partitions(partition_algo, num_samples, num_canonical_nodes,
                            num_physical_nodes, ranks_per_node, workers_per_rank, batch_size,
                            drop_first, initial_physical_nodes)

    # The partition keeps its (nodes, ranks, workers, batches, batch size) layout,
    # but the batches-per-worker axis collapses to zero, so no samples remain.
    expected_shape = (num_physical_nodes, ranks_per_node, workers_per_rank, 0, batch_size)
    assert result.shape == expected_shape
    assert result.size == 0


@pytest.mark.parametrize('num_samples', [400, 1000])
@pytest.mark.parametrize('drop_additional', [1, 400])
@pytest.mark.parametrize('num_canonical_nodes', [4])
@pytest.mark.parametrize('num_physical_nodes', [4])
@pytest.mark.parametrize('ranks_per_node', [8])
@pytest.mark.parametrize('workers_per_rank', [8])
@pytest.mark.parametrize('batch_size', [4])
@pytest.mark.parametrize('partition_algo', ['orig', 'relaxed'])
def test_partition_invalid_drop_first(num_samples: int, drop_additional: int,
                                      num_canonical_nodes: int, num_physical_nodes: int,
                                      ranks_per_node: int, workers_per_rank: int, batch_size: int,
                                      partition_algo: str):
    """Resuming past the end of the dataset must raise a ValueError."""
    # drop_first == num_samples is valid (empty epoch remainder); strictly greater is not.
    drop_first = num_samples + drop_additional
    # Plain string pattern: the previous `f'...dataset*'` was an f-string with no
    # placeholders, and the trailing `*` was a regex quantifier ("zero or more 't'"),
    # not a literal. pytest.raises matches via re.search, so a literal prefix suffices.
    with pytest.raises(ValueError, match='Resuming further into the dataset'):
        _ = get_partitions(partition_algo, num_samples, num_canonical_nodes, num_physical_nodes,
                           ranks_per_node, workers_per_rank, batch_size, drop_first)


@pytest.mark.parametrize('num_samples', [1, 4])
@pytest.mark.parametrize('num_canonical_nodes', [1, 4])
@pytest.mark.parametrize('num_physical_nodes', [1, 4])
Expand Down

0 comments on commit a5b9eea

Please sign in to comment.