diff --git a/horovod/torch/elastic/sampler.py b/horovod/torch/elastic/sampler.py
index 8a10624066..3468d28200 100644
--- a/horovod/torch/elastic/sampler.py
+++ b/horovod/torch/elastic/sampler.py
@@ -32,7 +32,7 @@ class ElasticSampler(torch.utils.data.Sampler):
     In order to use this object successfully it is recommended that the user:
 
         1. Include this object in the `TorchState`.
-        2. Call `record_batch` or `record_indices` after processing a set of samples.
+        2. Call `record_batch` after processing a set of samples.
         3. Call `set_epoch` at the end of each epoch to clear the processed indices.
 
     Args:
@@ -54,6 +54,7 @@ def __init__(self, dataset, shuffle=True, seed=0):
         self.remaining_indices = []
         self.num_samples = 0
         self.total_size = 0
+        self.processed_num = 0
 
         self.reset()
 
@@ -71,33 +72,22 @@ def set_epoch(self, epoch):
             epoch: Epoch number.
         """
         self.epoch = epoch
-        self.processed_indices = set()
+        self.processed_num = 0
         self.reset()
 
     def record_batch(self, batch_idx, batch_size):
-        """Record indices at batch `batch_idx` with length `batch_size` as processed."""
-        indices = set(self.get_indices(batch_idx, batch_size))
-        self.record_indices(indices)
-
-    def record_indices(self, indices):
-        """Record set `indices` as processed."""
-        self.processed_indices.update(indices)
-
-    def get_indices(self, batch_idx, batch_size):
-        """Return list of indices at batch `batch_idx` with length `batch_size`."""
-        start_idx = batch_idx * batch_size
-        end_idx = min(start_idx + batch_size, len(self.indices))
-        return self.indices[start_idx:end_idx]
+        """Record the number of processed samples."""
+        self.processed_num += (batch_size * self.num_replicas)
 
     def load_state_dict(self, state_dict):
         self.epoch = state_dict['epoch']
-        self.processed_indices = state_dict['processed_indices']
+        self.processed_num = state_dict["processed_num"]
         self.reset()
 
     def state_dict(self):
         return dict(
             epoch=self.epoch,
-            processed_indices=self.processed_indices
+            processed_num=self.processed_num
         )
 
     def reset(self):
@@ -105,18 +95,18 @@ def reset(self):
         self.rank = rank()
 
         # Exclude any samples we have already processed this epoch
-        self.remaining_indices = [idx for idx in range(len(self.dataset))
-                                  if idx not in self.processed_indices]
+        all_indices = [idx for idx in range(len(self.dataset))]
+        if self.shuffle:
+            # Shuffle indices across workers deterministically in place
+            seed = self.seed + self.epoch
+            random.Random(seed).shuffle(all_indices)
+        self.remaining_indices = all_indices[self.processed_num:]
 
         self.num_samples = int(math.ceil(len(self.remaining_indices) * 1.0 / self.num_replicas))
         self.total_size = self.num_samples * self.num_replicas
 
     def __iter__(self):
         self.indices = self.remaining_indices[:]
-        if self.shuffle:
-            # Shuffle indices across workers deterministically in place
-            seed = self.seed + self.epoch
-            random.Random(seed).shuffle(self.indices)
 
         # add extra samples to make it evenly divisible
         self.indices += self.indices[:(self.total_size - len(self.indices))]
diff --git a/horovod/torch/elastic/state.py b/horovod/torch/elastic/state.py
index 9806bb7cd5..946d987b2a 100644
--- a/horovod/torch/elastic/state.py
+++ b/horovod/torch/elastic/state.py
@@ -128,12 +128,7 @@ def restore(self):
         self.value.load_state_dict(self._saved_sampler_state)
 
     def sync(self):
-        # Get the set of processed indices from all workers
-        world_processed_indices = _union(allgather_object(self.value.processed_indices))
-
-        # Replace local processed indices with global indices
         state_dict = self.value.state_dict()
-        state_dict['processed_indices'] = world_processed_indices
 
         # Broadcast and load the state to make sure we're all in sync
         self.value.load_state_dict(broadcast_object(state_dict))
diff --git a/test/single/test_torch_elastic.py b/test/single/test_torch_elastic.py
index c14463a79a..ed3f9e0e91 100644
--- a/test/single/test_torch_elastic.py
+++ b/test/single/test_torch_elastic.py
@@ -112,7 +112,7 @@ def __len__(self):
         state.sync()
 
         assert state.sampler.epoch == 0
-        assert len(state.sampler.processed_indices) == 0
+        assert state.sampler.processed_num == 0
 
         # Normal usage, no errors
         epochs = 2
@@ -120,12 +120,8 @@ def __len__(self):
         for epoch in range(epochs):
             sampler.set_epoch(epoch)
             for batch_idx, batch in enumerate(data_loader):
-                batch_indices = sampler.get_indices(batch_idx, batch_size)
-                batch_data = [dataset[idx] for idx in batch_indices]
-                assert batch_data == batch.numpy().tolist()
-
                 sampler.record_batch(batch_idx, batch_size)
-                assert len(sampler.processed_indices) == batch_size * (batch_idx + 1)
+                assert sampler.processed_num == batch_size * (batch_idx + 1)
 
                 total_batches += 1
         assert total_batches == (samples_per_worker / batch_size) * epochs
@@ -133,47 +129,44 @@ def __len__(self):
         # Do not reset epoch: processed samples are retained and data loader repeats
         total_batches = 0
         for _ in enumerate(data_loader):
-            assert len(sampler.processed_indices) == len(sampler)
+            assert sampler.processed_num == len(sampler)
             total_batches += 1
         assert total_batches == samples_per_worker / batch_size
 
         # Elastic: partial epoch + commit
         sampler.set_epoch(2)
-        assert len(sampler.processed_indices) == 0
+        assert sampler.processed_num == 0
 
         sampler.record_batch(0, batch_size)
         sampler.record_batch(1, batch_size)
-        assert len(sampler.processed_indices) == 2 * batch_size
+        assert sampler.processed_num == 2 * batch_size
 
-        committed_indices = copy.copy(sampler.processed_indices)
+        committed_num = copy.copy(sampler.processed_num)
         state.commit()
 
         # Elastic: partial epoch + restore
         sampler.record_batch(2, batch_size)
         sampler.record_batch(3, batch_size)
-        assert len(sampler.processed_indices) == 4 * batch_size
+        assert sampler.processed_num == 4 * batch_size
 
         state.restore()
 
-        assert len(sampler.processed_indices) == 2 * batch_size
-        assert sampler.processed_indices == committed_indices
+        assert sampler.processed_num == 2 * batch_size
+        assert sampler.processed_num == committed_num
 
         # Elastic: sync across workers and verify non-overlap of processed samples
         sampler.record_batch(2, batch_size)
-        assert len(sampler.processed_indices) == 3 * batch_size
+        assert sampler.processed_num == 3 * batch_size
 
         state.commit()
         state.sync()
 
-        assert len(sampler.processed_indices) == 3 * batch_size * hvd.size()
+        assert sampler.processed_num == 3 * batch_size * hvd.size()
 
         # After the sync, the remaining indices should be updated and repartitioned
         total_batches = 0
         assert len(sampler) == batch_size
         for batch_idx, batch in enumerate(data_loader):
-            batch_indices = sampler.get_indices(batch_idx, batch_size)
-            overlap_indices = set(batch_indices) & sampler.processed_indices
-            assert overlap_indices == set()
 
             total_batches += 1
         assert total_batches == 1
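
As a quick sanity check on the repartitioning scheme this patch moves to (a standalone sketch, not Horovod code): every worker shuffles the full index list with the shared seed + epoch, drops the first processed_num entries, and then keeps its rank-strided slice of what remains. The helper remaining_for_rank and the worker counts below are hypothetical, and the rank-strided subsample in __iter__ (as in DistributedSampler) is assumed rather than shown in the diff; the point is only that slicing a deterministically shuffled list by a global processed count keeps the per-rank shards disjoint.

import random

def remaining_for_rank(dataset_len, epoch, seed, processed_num, rank, num_replicas):
    # Mirrors the new ElasticSampler.reset(): shuffle all indices with the
    # shared seed, then drop the globally processed prefix.
    all_indices = list(range(dataset_len))
    random.Random(seed + epoch).shuffle(all_indices)
    remaining = all_indices[processed_num:]

    # Assumed __iter__ behavior: pad to a multiple of num_replicas, then
    # subsample by rank (hypothetical reconstruction, not from the diff).
    num_samples = -(-len(remaining) // num_replicas)  # ceil division
    total_size = num_samples * num_replicas
    remaining += remaining[:total_size - len(remaining)]
    return remaining[rank:total_size:num_replicas]

# Hypothetical scenario: 2 workers, a 20-sample epoch, 8 samples already processed globally.
shards = [remaining_for_rank(20, epoch=0, seed=0, processed_num=8,
                             rank=r, num_replicas=2) for r in range(2)]
assert all(len(s) == 6 for s in shards)      # 12 remaining samples split evenly
assert set(shards[0]).isdisjoint(shards[1])  # shards do not overlap after repartition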