Revert "Always use SeedableRandomSampler (#2110)"
This reverts commit d8e1285.
muellerzr committed Nov 1, 2023
1 parent 55088a2 · commit bd72a5f
Showing 1 changed file with 12 additions and 17 deletions.
29 changes: 12 additions & 17 deletions in src/accelerate/data_loader.py
@@ -476,10 +476,7 @@ def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
             self.iteration = epoch
-        if hasattr(self.batch_sampler, "set_epoch"):
-            # Case: `SkipBatchSampler`
-            self.batch_sampler.set_epoch(epoch)
-        elif hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
             self.batch_sampler.sampler.set_epoch(epoch)
         # We support if a custom `Dataset` implementation has `set_epoch`
         # or in general HF datasets `Datasets`
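For context, this first hunk restores the pre-#2110 behaviour: `set_epoch` is forwarded only to the sampler nested inside the batch sampler, and the `SkipBatchSampler` special case introduced by the reverted commit is dropped. Below is a minimal sketch of that forwarding pattern; `ToyLoaderWrapper` is a hypothetical stand-in, not the actual shard class in this file:

import torch
from torch.utils.data import BatchSampler, TensorDataset
from torch.utils.data.distributed import DistributedSampler

class ToyLoaderWrapper:
    # Illustrative wrapper that tracks the current epoch, mirroring the restored branch above.
    def __init__(self, batch_sampler: BatchSampler):
        self.batch_sampler = batch_sampler
        self.iteration = 0

    def set_epoch(self, epoch: int):
        # In case it is manually passed in, the user can set it to what they like
        if self.iteration != epoch:
            self.iteration = epoch
        # Forward only to the sampler wrapped by the batch sampler.
        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
            self.batch_sampler.sampler.set_epoch(epoch)

dataset = TensorDataset(torch.arange(8))
inner = DistributedSampler(dataset, num_replicas=2, rank=0)  # has a set_epoch method
wrapper = ToyLoaderWrapper(BatchSampler(inner, batch_size=2, drop_last=False))
wrapper.set_epoch(3)
assert inner.epoch == 3  # the epoch reached the nested sampler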
@@ -839,19 +836,17 @@ def prepare_data_loader(
         sampler = getattr(dataloader.sampler, "sampler", None)
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
-    if isinstance(sampler, RandomSampler):
-        # CPU's specifically do not require this workaround
-        if not ((num_processes == 1) and (device.type == "cpu")):
-            # When iterating through the dataloader we want to ensure that
-            # on each process we are iterating through the same
-            # samples in the same order if a seed is set. This requires a tweak
-            # to the `torch.utils.data.RandomSampler` class (if used).
-            sampler = SeedableRandomSampler(
-                data_source=sampler.data_source,
-                replacement=sampler.replacement,
-                num_samples=sampler._num_samples,
-                generator=getattr(sampler, "generator", torch.Generator()),
-            )
+    if isinstance(sampler, RandomSampler) and num_processes > 1:
+        # When iterating through the dataloader during distributed processes
+        # we want to ensure that on each process we are iterating through the same
+        # samples in the same order if a seed is set. This requires a tweak
+        # to the `torch.utils.data.RandomSampler` class (if used).
+        sampler = SeedableRandomSampler(
+            data_source=sampler.data_source,
+            replacement=sampler.replacement,
+            num_samples=sampler._num_samples,
+            generator=getattr(sampler, "generator", torch.Generator()),
+        )
 
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
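Net effect of this second hunk: after the revert, `prepare_data_loader` swaps a stock `RandomSampler` for `SeedableRandomSampler` only when `num_processes > 1`, instead of doing it for every run that is not single-process CPU. The swap matters in the multi-process case because each rank must re-seed its shuffle from the same epoch-dependent seed, so that all ranks visit the samples in the same order. A simplified sketch of that idea, using a hypothetical `EpochSeededRandomSampler` rather than the library's actual `SeedableRandomSampler`:

import torch
from torch.utils.data import RandomSampler

class EpochSeededRandomSampler(RandomSampler):
    # Reshuffles identically on every process for a given epoch (illustrative only).
    def __init__(self, data_source, replacement=False, num_samples=None, generator=None, initial_seed=0):
        super().__init__(data_source, replacement=replacement, num_samples=num_samples, generator=generator)
        self.initial_seed = initial_seed
        self.epoch = 0

    def set_epoch(self, epoch: int):
        self.epoch = epoch

    def __iter__(self):
        if self.generator is None:
            self.generator = torch.Generator()
        # Same seed on every rank for a given epoch, hence an identical permutation everywhere.
        self.generator.manual_seed(self.initial_seed + self.epoch)
        yield from super().__iter__()

sampler = EpochSeededRandomSampler(range(10))
sampler.set_epoch(1)
first = list(sampler)
second = list(sampler)  # same epoch, same seed, same order
assert first == second

With the restored branch, a single-process run keeps the plain `torch.utils.data.RandomSampler`, since there is no cross-process ordering to synchronize.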
