From d8e12854098988d2162948c9a853081fcf00b73f Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 13:39:53 -0400
Subject: [PATCH] Always use SeedableRandomSampler (#2110)

* Fix tests fully

* Change comment

* Further comments

* Clean

* CPU specific

* Just use device

* Rewrite differently

* Rewrite

---
 src/accelerate/data_loader.py | 29 +++++++++++++++++------------
 1 file changed, 17 insertions(+), 12 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 8337d399a34..7bb9d31738a 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -476,7 +476,10 @@ def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
             self.iteration = epoch
-        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+        if hasattr(self.batch_sampler, "set_epoch"):
+            # Case: `SkipBatchSampler`
+            self.batch_sampler.set_epoch(epoch)
+        elif hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
             self.batch_sampler.sampler.set_epoch(epoch)
         # We support if a custom `Dataset` implementation has `set_epoch`
         # or in general HF datasets `Datasets`
@@ -836,17 +839,19 @@ def prepare_data_loader(
         sampler = getattr(dataloader.sampler, "sampler", None)
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
-    if isinstance(sampler, RandomSampler) and num_processes > 1:
-        # When iterating through the dataloader during distributed processes
-        # we want to ensure that on each process we are iterating through the same
-        # samples in the same order if a seed is set. This requires a tweak
-        # to the `torch.utils.data.RandomSampler` class (if used).
-        sampler = SeedableRandomSampler(
-            data_source=sampler.data_source,
-            replacement=sampler.replacement,
-            num_samples=sampler._num_samples,
-            generator=getattr(sampler, "generator", torch.Generator()),
-        )
+    if isinstance(sampler, RandomSampler):
+        # CPU's specifically do not require this workaround
+        if not ((num_processes == 1) and (device.type == "cpu")):
+            # When iterating through the dataloader we want to ensure that
+            # on each process we are iterating through the same
+            # samples in the same order if a seed is set. This requires a tweak
+            # to the `torch.utils.data.RandomSampler` class (if used).
+            sampler = SeedableRandomSampler(
+                data_source=sampler.data_source,
+                replacement=sampler.replacement,
+                num_samples=sampler._num_samples,
+                generator=getattr(sampler, "generator", torch.Generator()),
+            )

     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:
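
Below is a minimal sketch of the technique the second hunk's comment describes: reseed the sampler's generator from a fixed base seed plus the epoch number, so that every process draws the same permutation for a given epoch while the order still changes between epochs. It uses only plain PyTorch, and the names (`base_seed`, the toy dataset) are illustrative assumptions; it is not accelerate's `SeedableRandomSampler` implementation.

    import torch
    from torch.utils.data import DataLoader, RandomSampler, TensorDataset

    # Toy dataset so the shuffled order is easy to inspect.
    dataset = TensorDataset(torch.arange(8).float())
    generator = torch.Generator()
    sampler = RandomSampler(dataset, generator=generator)
    loader = DataLoader(dataset, sampler=sampler, batch_size=4)

    base_seed = 42  # assumed to be identical on every process
    for epoch in range(2):
        # Seeding with `base_seed + epoch` makes the permutation identical on
        # every rank for a given epoch, yet different from one epoch to the next.
        generator.manual_seed(base_seed + epoch)
        for (batch,) in loader:
            print(epoch, batch.tolist())

In the patch itself, the epoch number reaches the sampler through `set_epoch` (first hunk), which forwards it either to the batch sampler or to its underlying sampler; that is what triggers the per-epoch reseeding in practice.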