From 633b44a2518f7c80559ee7342768b27206b34267 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 15:31:46 +0000
Subject: [PATCH 1/8] Fix tests fully

---
 src/accelerate/data_loader.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 8337d399a34..1c3b02d73e8 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -478,6 +478,8 @@ def set_epoch(self, epoch: int):
             self.iteration = epoch
         if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
             self.batch_sampler.sampler.set_epoch(epoch)
+        elif hasattr(self.batch_sampler, "set_epoch"):
+            self.batch_sampler.set_epoch(epoch)
         # We support if a custom `Dataset` implementation has `set_epoch`
         # or in general HF datasets `Datasets`
         elif hasattr(self.dataset, "set_epoch"):
@@ -836,7 +838,7 @@ def prepare_data_loader(
         sampler = getattr(dataloader.sampler, "sampler", None)
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
-    if isinstance(sampler, RandomSampler) and num_processes > 1:
+    if isinstance(sampler, RandomSampler):
         # When iterating through the dataloader during distributed processes
         # we want to ensure that on each process we are iterating through the same
         # samples in the same order if a seed is set. This requires a tweak

From 6f56455c12796d260ca3738ae8663d55e05b1244 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 15:35:09 +0000
Subject: [PATCH 2/8] Change comment

---
 src/accelerate/data_loader.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 1c3b02d73e8..77902ad4e18 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -839,8 +839,8 @@ def prepare_data_loader(
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler):
-        # When iterating through the dataloader during distributed processes
-        # we want to ensure that on each process we are iterating through the same
+        # When iterating through the dataloader we want to ensure that 
+        # on each process we are iterating through the same
         # samples in the same order if a seed is set. This requires a tweak
         # to the `torch.utils.data.RandomSampler` class (if used).
         sampler = SeedableRandomSampler(

From 504528f4d23b592f476079c6b1f1c8a0e92ee34a Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 15:36:05 +0000
Subject: [PATCH 3/8] Further comments

---
 src/accelerate/data_loader.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 77902ad4e18..2335ada9351 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -476,10 +476,11 @@ def set_epoch(self, epoch: int):
         # In case it is manually passed in, the user can set it to what they like
         if self.iteration != epoch:
             self.iteration = epoch
-        if hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
-            self.batch_sampler.sampler.set_epoch(epoch)
-        elif hasattr(self.batch_sampler, "set_epoch"):
+        if hasattr(self.batch_sampler, "set_epoch"):
+            # Case: `SkipBatchSampler`
             self.batch_sampler.set_epoch(epoch)
+        elif hasattr(self.batch_sampler, "sampler") and hasattr(self.batch_sampler.sampler, "set_epoch"):
+            self.batch_sampler.sampler.set_epoch(epoch)
         # We support if a custom `Dataset` implementation has `set_epoch`
         # or in general HF datasets `Datasets`
         elif hasattr(self.dataset, "set_epoch"):

From bc8593f4184e973259c7ae976e5be879ad3e6104 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 15:39:32 +0000
Subject: [PATCH 4/8] Clean

---
 src/accelerate/data_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 2335ada9351..4de436cb55a 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -840,7 +840,7 @@ def prepare_data_loader(
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler):
-        # When iterating through the dataloader we want to ensure that 
+        # When iterating through the dataloader we want to ensure that
         # on each process we are iterating through the same
         # samples in the same order if a seed is set. This requires a tweak
         # to the `torch.utils.data.RandomSampler` class (if used).

From 2220ce8a983102377b38caa870ca00d7d6967786 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 15:56:53 +0000
Subject: [PATCH 5/8] CPU specific

---
 src/accelerate/data_loader.py | 24 ++++++++++++++----------
 1 file changed, 14 insertions(+), 10 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 4de436cb55a..e28bf71f67f 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -840,16 +840,20 @@ def prepare_data_loader(
     else:
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler):
-        # When iterating through the dataloader we want to ensure that
-        # on each process we are iterating through the same
-        # samples in the same order if a seed is set. This requires a tweak
-        # to the `torch.utils.data.RandomSampler` class (if used).
-        sampler = SeedableRandomSampler(
-            data_source=sampler.data_source,
-            replacement=sampler.replacement,
-            num_samples=sampler._num_samples,
-            generator=getattr(sampler, "generator", torch.Generator()),
-        )
+        # CPU's specifically do not require this workaround
+        if (AcceleratorState().distributed_type == DistributedType.NO) and (AcceleratorState().device.type == "cpu"):
+            pass
+        else:
+            # When iterating through the dataloader we want to ensure that
+            # on each process we are iterating through the same
+            # samples in the same order if a seed is set. This requires a tweak
+            # to the `torch.utils.data.RandomSampler` class (if used).
+            sampler = SeedableRandomSampler(
+                data_source=sampler.data_source,
+                replacement=sampler.replacement,
+                num_samples=sampler._num_samples,
+                generator=getattr(sampler, "generator", torch.Generator()),
+            )
 
     # No change if no multiprocess
     if (num_processes != 1 or state.distributed_type == DistributedType.MEGATRON_LM) and not dispatch_batches:

From 44b4593823ef33820fc3fb787700027478793153 Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 16:13:31 +0000
Subject: [PATCH 6/8] Just use device

---
 src/accelerate/data_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index e28bf71f67f..348c000619e 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -841,7 +841,7 @@ def prepare_data_loader(
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler):
         # CPU's specifically do not require this workaround
-        if (AcceleratorState().distributed_type == DistributedType.NO) and (AcceleratorState().device.type == "cpu"):
+        if (AcceleratorState().distributed_type == DistributedType.NO) and (device.type == "cpu"):
             pass
         else:
             # When iterating through the dataloader we want to ensure that

From 1ecdede07d6142a1bc1951d04f01b560b3f46d3f Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 16:14:17 +0000
Subject: [PATCH 7/8] Rewrite differently

---
 src/accelerate/data_loader.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 348c000619e..1c1ba2a8174 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -841,7 +841,7 @@ def prepare_data_loader(
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler):
         # CPU's specifically do not require this workaround
-        if (AcceleratorState().distributed_type == DistributedType.NO) and (device.type == "cpu"):
+        if num_processes == 1 and (device.type == "cpu"):
             pass
         else:
             # When iterating through the dataloader we want to ensure that

From d721b4eab96d31a9e3948964fbe20e6d4b42e8fe Mon Sep 17 00:00:00 2001
From: Zach Mueller
Date: Wed, 1 Nov 2023 16:15:03 +0000
Subject: [PATCH 8/8] Rewrite

---
 src/accelerate/data_loader.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/accelerate/data_loader.py b/src/accelerate/data_loader.py
index 1c1ba2a8174..7bb9d31738a 100644
--- a/src/accelerate/data_loader.py
+++ b/src/accelerate/data_loader.py
@@ -841,9 +841,7 @@ def prepare_data_loader(
         sampler = getattr(dataloader.batch_sampler, "sampler", None)
     if isinstance(sampler, RandomSampler):
         # CPU's specifically do not require this workaround
-        if num_processes == 1 and (device.type == "cpu"):
-            pass
-        else:
+        if not ((num_processes == 1) and (device.type == "cpu")):
             # When iterating through the dataloader we want to ensure that
             # on each process we are iterating through the same
             # samples in the same order if a seed is set. This requires a tweak
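
Taken together, the series leaves `prepare_data_loader` keeping a plain `RandomSampler` only for single-process CPU runs and swapping in `SeedableRandomSampler` everywhere else, so that every process iterates over the same shuffled order when a seed is set. A minimal sketch of that final check follows; the helper name is hypothetical, the import assumes `SeedableRandomSampler` is exposed by `accelerate.data_loader` (the module these patches touch), and `num_processes` and `device` are the same values `prepare_data_loader` already receives.

import torch
from torch.utils.data import RandomSampler

from accelerate.data_loader import SeedableRandomSampler  # assumed import path

def swap_in_seedable_sampler(sampler, num_processes, device):
    # Hypothetical helper mirroring the check from PATCH 8: single-process CPU
    # runs keep their original sampler, everything else gets the seedable
    # replacement built from the original sampler's own settings.
    if isinstance(sampler, RandomSampler) and not ((num_processes == 1) and (device.type == "cpu")):
        sampler = SeedableRandomSampler(
            data_source=sampler.data_source,
            replacement=sampler.replacement,
            num_samples=sampler._num_samples,
            generator=getattr(sampler, "generator", torch.Generator()),
        )
    return sampler

# Example: a 2-process GPU run gets the seedable sampler, a 1-process CPU run does not.
# sampler = swap_in_seedable_sampler(dataloader.sampler, num_processes=2, device=torch.device("cuda"))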