
Commit 8baec1a

rohitgr7 and awaelchli authored
Fix shuffle for distributed sampler (Lightning-AI#2789)
* Fix shuffle for distributed sampler
* add test
* test
* chlog
* update test
* update test
* update test
* assertions via callback
* define callback outside for pickling
* skip ddp test on windows

Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
1 parent 38fce2e commit 8baec1a
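
In short, the patch threads the existing `train` flag of `auto_add_sampler` through to the `DistributedSampler` it builds, so the training loader shuffles while the val/test loaders do not (previously the kwarg was never set, so the sampler's default `shuffle=True` applied to every split). Below is a standalone sketch of that behaviour, not Lightning's actual code; the helper name `add_distributed_sampler` and the explicit `num_replicas`/`rank` arguments are illustrative:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def add_distributed_sampler(dataloader: DataLoader, train: bool,
                            num_replicas: int, rank: int) -> DataLoader:
    # Same idea as the patch: shuffle only when the loader feeds training.
    # Before the fix the shuffle kwarg was never passed, so DistributedSampler's
    # default (shuffle=True) applied to val/test loaders as well.
    sampler = DistributedSampler(dataloader.dataset,
                                 num_replicas=num_replicas,
                                 rank=rank,
                                 shuffle=train)
    return DataLoader(dataloader.dataset,
                      batch_size=dataloader.batch_size,
                      sampler=sampler)


dataset = TensorDataset(torch.arange(8).float())
val_loader = add_distributed_sampler(DataLoader(dataset, batch_size=2),
                                     train=False, num_replicas=2, rank=0)
assert isinstance(val_loader.sampler, DistributedSampler)
assert val_loader.sampler.shuffle is False  # validation data is not reshuffled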

File tree: 3 files changed, 44 additions and 3 deletions

* CHANGELOG.md
* pytorch_lightning/trainer/data_loading.py
* tests/trainer/test_dataloaders.py

CHANGELOG.md
Lines changed: 2 additions & 0 deletions

@@ -49,6 +49,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Fixed data transfer to device when using `torchtext.data.Field` and `include_lengths is True` ([#2689](https://github.com/PyTorchLightning/pytorch-lightning/pull/2689))

+- Fixed shuffle argument for distributed sampler ([#2789](https://github.com/PyTorchLightning/pytorch-lightning/pull/2789))
+
 ## [0.8.5] - 2020-07-09

 ### Added

pytorch_lightning/trainer/data_loading.py
Lines changed: 4 additions & 2 deletions

@@ -163,7 +163,7 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
                     ' `replace_sampler_ddp`=False if you want to use your custom sampler.')

             # replace with distributed sampler
-            sampler = self._get_distributed_sampler(dataloader)
+            sampler = self._get_distributed_sampler(dataloader, train)
             dataloader = self.replace_sampler(dataloader, sampler)

         return dataloader

@@ -179,7 +179,7 @@ def replace_sampler(self, dataloader, sampler):
         dataloader = type(dataloader)(**dl_args)
         return dataloader

-    def _get_distributed_sampler(self, dataloader):
+    def _get_distributed_sampler(self, dataloader, train):
         if self.use_tpu:
             kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
         elif self.use_horovod:

@@ -193,6 +193,8 @@ def _get_distributed_sampler(self, dataloader):
             }
             assert self.distributed_backend is not None
             kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.global_rank)
+
+        kwargs['shuffle'] = train
         sampler = DistributedSampler(dataloader.dataset, **kwargs)
         return sampler
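
For context, the only thing `shuffle` changes is which indices each replica iterates. A small self-contained illustration (not from the repository; passing `num_replicas`/`rank` explicitly means no process group is required):

import torch
from torch.utils.data import TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(10))

# shuffle=False: each rank reads a fixed, strided slice of the dataset.
plain = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=False)
print(list(plain))      # [0, 2, 4, 6, 8]

# shuffle=True: indices are permuted; set_epoch() reseeds the permutation,
# which callers are expected to do once per training epoch.
shuffled = DistributedSampler(dataset, num_replicas=2, rank=0, shuffle=True)
shuffled.set_epoch(0)
print(list(shuffled))   # a permuted subset of the indices, e.g. [4, 5, 0, 1, 7]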

tests/trainer/test_dataloaders.py
Lines changed: 38 additions & 1 deletion

@@ -7,9 +7,10 @@
 from packaging.version import parse
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import IterableDataset, Subset
+from torch.utils.data.distributed import DistributedSampler

 import tests.base.develop_pipelines as tpipes
-from pytorch_lightning import Trainer
+from pytorch_lightning import Trainer, Callback
 from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base import EvalModelTemplate

@@ -640,6 +641,42 @@ class CustomSampler(torch.utils.data.Sampler):
         CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)


+class DistribSamplerCallback(Callback):
+
+    def on_train_start(self, trainer, pl_module):
+        train_sampler = trainer.train_dataloader.sampler
+        assert isinstance(train_sampler, DistributedSampler)
+        assert train_sampler.shuffle
+
+    def on_validation_start(self, trainer, pl_module):
+        val_sampler = trainer.val_dataloaders[0].sampler
+        assert isinstance(val_sampler, DistributedSampler)
+        assert not val_sampler.shuffle
+
+    def on_test_start(self, trainer, pl_module):
+        test_sampler = trainer.test_dataloaders[0].sampler
+        assert isinstance(test_sampler, DistributedSampler)
+        assert not test_sampler.shuffle
+
+
+@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason='Test requires multiple GPUs')
+def test_dataloader_distributed_sampler(tmpdir):
+    """ Test DistributedSampler and its arguments for DDP backend """
+
+    model = EvalModelTemplate()
+    trainer = Trainer(
+        gpus=[0, 1],
+        num_nodes=1,
+        distributed_backend='ddp_spawn',
+        default_root_dir=tmpdir,
+        max_steps=1,
+        callbacks=[DistribSamplerCallback()]
+    )
+    trainer.fit(model)
+    trainer.test(ckpt_path=None)
+
+
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus(tmpdir):
     # we need at least 3 gpus for this test
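
One detail from the commit message is defining the callback at module scope "for pickling": with the `ddp_spawn` backend the spawned worker processes receive their arguments via pickle, and a class defined inside the test function cannot be pickled by reference. A minimal sketch of that constraint, independent of Lightning (the class and function names here are illustrative):

import pickle


class ModuleLevelCallback:
    # Defined at module scope: picklable by reference (module + qualified name).
    def on_train_start(self, trainer, pl_module):
        pass


def make_local_callback():
    class LocalCallback:
        # Defined inside a function: its qualified name contains '<locals>',
        # so pickle cannot look it up again and refuses to serialize instances.
        pass
    return LocalCallback()


pickle.dumps(ModuleLevelCallback())  # works

try:
    pickle.dumps(make_local_callback())
except (pickle.PicklingError, AttributeError) as err:
    print(f"cannot pickle locally defined class: {err}")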
