[Elastic Horovod] Fix the bug for ElasticSampler and hvd.elastic.state #3144

Merged · 5 commits · Oct 21, 2021
36 changes: 13 additions & 23 deletions horovod/torch/elastic/sampler.py
@@ -32,7 +32,7 @@ class ElasticSampler(torch.utils.data.Sampler):
     In order to use this object successfully it is recommended that the user:

     1. Include this object in the `TorchState`.
-    2. Call `record_batch` or `record_indices` after processing a set of samples.
+    2. Call `record_batch` after processing a set of samples.
     3. Call `set_epoch` at the end of each epoch to clear the processed indices.

     Args:
@@ -54,6 +54,7 @@ def __init__(self, dataset, shuffle=True, seed=0):
         self.remaining_indices = []
         self.num_samples = 0
         self.total_size = 0
+        self.processed_num = 0

         self.reset()

@@ -71,52 +72,41 @@ def set_epoch(self, epoch):
             epoch: Epoch number.
         """
         self.epoch = epoch
-        self.processed_indices = set()
+        self.processed_num = 0
         self.reset()

     def record_batch(self, batch_idx, batch_size):
-        """Record indices at batch `batch_idx` with length `batch_size` as processed."""
-        indices = set(self.get_indices(batch_idx, batch_size))
-        self.record_indices(indices)
-
-    def record_indices(self, indices):
-        """Record set `indices` as processed."""
-        self.processed_indices.update(indices)
-
-    def get_indices(self, batch_idx, batch_size):
-        """Return list of indices at batch `batch_idx` with length `batch_size`."""
-        start_idx = batch_idx * batch_size
-        end_idx = min(start_idx + batch_size, len(self.indices))
-        return self.indices[start_idx:end_idx]
+        """Record the number of processed samples."""
+        self.processed_num += (batch_size * self.num_replicas)

     def load_state_dict(self, state_dict):
         self.epoch = state_dict['epoch']
-        self.processed_indices = state_dict['processed_indices']
+        self.processed_num = state_dict["processed_num"]
         self.reset()

     def state_dict(self):
         return dict(
             epoch=self.epoch,
-            processed_indices=self.processed_indices
+            processed_num=self.processed_num
         )

     def reset(self):
         self.num_replicas = size()
         self.rank = rank()

         # Exclude any samples we have already processed this epoch
-        self.remaining_indices = [idx for idx in range(len(self.dataset))
-                                  if idx not in self.processed_indices]
+        all_indices = [idx for idx in range(len(self.dataset))]
+        if self.shuffle:
+            # Shuffle indices across workers deterministically in place
+            seed = self.seed + self.epoch
+            random.Random(seed).shuffle(all_indices)
+        self.remaining_indices = all_indices[self.processed_num:]

         self.num_samples = int(math.ceil(len(self.remaining_indices) * 1.0 / self.num_replicas))
         self.total_size = self.num_samples * self.num_replicas

     def __iter__(self):
         self.indices = self.remaining_indices[:]
-        if self.shuffle:
-            # Shuffle indices across workers deterministically in place
-            seed = self.seed + self.epoch
-            random.Random(seed).shuffle(self.indices)

         # add extra samples to make it evenly divisible
         self.indices += self.indices[:(self.total_size - len(self.indices))]
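The crux of the fix: the sampler now tracks how many samples were processed rather than which ones, and the deterministic shuffle moves from `__iter__` into `reset()`. Because every worker shuffles the full index list with the same `seed + epoch`, the permutation is identical everywhere, so slicing off the first `processed_num` entries recovers the same remaining set on every worker with no index exchange. A minimal, self-contained sketch of that bookkeeping (plain Python, no Horovod; the dataset size, worker count, and progress counter are made-up illustration values):

```python
import random

dataset_len = 16   # toy dataset size (illustration only)
num_replicas = 2   # pretend two workers remain after a rescale
seed, epoch = 0, 0
processed_num = 4  # global count: record_batch adds batch_size * num_replicas

# Same seed on every worker -> identical permutation everywhere.
all_indices = list(range(dataset_len))
random.Random(seed + epoch).shuffle(all_indices)

# Dropping the first `processed_num` entries therefore yields the same
# remaining set on all workers, without gathering per-worker index sets.
remaining_indices = all_indices[processed_num:]

for rank in range(num_replicas):
    # Strided split by rank, in the style of torch's DistributedSampler.
    print(rank, remaining_indices[rank::num_replicas])
```

This is also why `record_batch` multiplies by `num_replicas`: each worker consumes only its own stride of the shared permutation, but the counter must advance by the global number of samples covered per step.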
5 changes: 0 additions & 5 deletions horovod/torch/elastic/state.py
@@ -128,12 +128,7 @@ def restore(self):
         self.value.load_state_dict(self._saved_sampler_state)

     def sync(self):
-        # Get the set of processed indices from all workers
-        world_processed_indices = _union(allgather_object(self.value.processed_indices))
-
-        # Replace local processed indices with global indices
         state_dict = self.value.state_dict()
-        state_dict['processed_indices'] = world_processed_indices

         # Broadcast and load the state to make sure we're all in sync
         self.value.load_state_dict(broadcast_object(state_dict))
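Because `processed_num` is already a global count, `sync()` no longer needs to allgather and union per-worker index sets; broadcasting a single worker's sampler `state_dict` realigns everyone after a rescale. A usage sketch under the usual Horovod elastic conventions (the toy model, data, and hyperparameters are placeholders for illustration, not part of this PR):

```python
import torch
import horovod.torch as hvd

hvd.init()

# Toy stand-ins for a real training setup.
dataset = torch.utils.data.TensorDataset(torch.randn(64, 4), torch.randn(64, 1))
model = torch.nn.Linear(4, 1)
optimizer = hvd.DistributedOptimizer(
    torch.optim.SGD(model.parameters(), lr=0.01 * hvd.size()),
    named_parameters=model.named_parameters())

sampler = hvd.elastic.ElasticSampler(dataset, shuffle=True, seed=0)
loader = torch.utils.data.DataLoader(dataset, batch_size=8, sampler=sampler)

state = hvd.elastic.TorchState(model=model, optimizer=optimizer,
                               sampler=sampler, epoch=0)

@hvd.elastic.run
def train(state):
    # On (re)start, hvd.elastic.run syncs `state`: the sampler's epoch and
    # processed_num arrive via broadcast, and reset() trims the processed
    # prefix from the freshly reshuffled index list.
    for state.epoch in range(state.epoch, 2):
        state.sampler.set_epoch(state.epoch)
        for batch_idx, (x, y) in enumerate(loader):
            loss = torch.nn.functional.mse_loss(model(x), y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            state.sampler.record_batch(batch_idx, 8)
            state.commit()

train(state)
```

Committing after every batch keeps the sketch short; commits are synchronization points, so real loops usually commit every N batches and accept replaying the uncommitted tail after a failure.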
29 changes: 11 additions & 18 deletions test/single/test_torch_elastic.py
@@ -112,68 +112,61 @@ def __len__(self):
         state.sync()

         assert state.sampler.epoch == 0
-        assert len(state.sampler.processed_indices) == 0
+        assert state.sampler.processed_num == 0

         # Normal usage, no errors
         epochs = 2
         total_batches = 0
         for epoch in range(epochs):
             sampler.set_epoch(epoch)
             for batch_idx, batch in enumerate(data_loader):
-                batch_indices = sampler.get_indices(batch_idx, batch_size)
-                batch_data = [dataset[idx] for idx in batch_indices]
-                assert batch_data == batch.numpy().tolist()
-
                 sampler.record_batch(batch_idx, batch_size)
-                assert len(sampler.processed_indices) == batch_size * (batch_idx + 1)
+                assert sampler.processed_num == batch_size * (batch_idx + 1)

                 total_batches += 1
         assert total_batches == (samples_per_worker / batch_size) * epochs

         # Do not reset epoch: processed samples are retained and data loader repeats
         total_batches = 0
         for _ in enumerate(data_loader):
-            assert len(sampler.processed_indices) == len(sampler)
+            assert sampler.processed_num == len(sampler)
             total_batches += 1
         assert total_batches == samples_per_worker / batch_size

         # Elastic: partial epoch + commit
         sampler.set_epoch(2)
-        assert len(sampler.processed_indices) == 0
+        assert sampler.processed_num == 0

         sampler.record_batch(0, batch_size)
         sampler.record_batch(1, batch_size)
-        assert len(sampler.processed_indices) == 2 * batch_size
+        assert sampler.processed_num == 2 * batch_size

-        committed_indices = copy.copy(sampler.processed_indices)
+        committed_num = copy.copy(sampler.processed_num)
         state.commit()

         # Elastic: partial epoch + restore
         sampler.record_batch(2, batch_size)
         sampler.record_batch(3, batch_size)
-        assert len(sampler.processed_indices) == 4 * batch_size
+        assert sampler.processed_num == 4 * batch_size

         state.restore()

-        assert len(sampler.processed_indices) == 2 * batch_size
-        assert sampler.processed_indices == committed_indices
+        assert sampler.processed_num == 2 * batch_size
+        assert sampler.processed_num == committed_num

         # Elastic: sync across workers and verify non-overlap of processed samples
         sampler.record_batch(2, batch_size)
-        assert len(sampler.processed_indices) == 3 * batch_size
+        assert sampler.processed_num == 3 * batch_size

         state.commit()
         state.sync()

-        assert len(sampler.processed_indices) == 3 * batch_size * hvd.size()
+        assert sampler.processed_num == 3 * batch_size * hvd.size()

         # After the sync, the remaining indices should be updated and repartitioned
         total_batches = 0
         assert len(sampler) == batch_size
         for batch_idx, batch in enumerate(data_loader):
-            batch_indices = sampler.get_indices(batch_idx, batch_size)
-            overlap_indices = set(batch_indices) & sampler.processed_indices
-            assert overlap_indices == set()
             total_batches += 1
         assert total_batches == 1
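The closing assertions are plain accounting. Here is the arithmetic spelled out, with hypothetical constants chosen to be consistent with the asserts above (the test's real values are defined outside this hunk, and `hvd.size()` is 1 in this single-process test):

```python
import math

batch_size = 2                # hypothetical; defined outside the hunk
num_replicas = 1              # hvd.size() == 1 in this single-process test
dataset_len = 4 * batch_size  # implied by the final `len(sampler) == batch_size`

processed_num = 3 * batch_size * num_replicas      # three recorded batches
remaining = dataset_len - processed_num            # suffix kept by reset()
num_samples = math.ceil(remaining / num_replicas)  # what len(sampler) returns

assert num_samples == batch_size     # matches `assert len(sampler) == batch_size`
assert remaining // batch_size == 1  # one batch left, so total_batches == 1
```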