Commit ade3f36

Raise an error when lightning replaces an existing sampler (Lightning-AI#2020)
* Raise an error when lightning replaces an existing sampler

  Currently, the Trainer replaces an existing sampler with a DistributedSampler when running distributed training with `replace_sampler_ddp=True` (the default behaviour). If a user has configured their own sampler, this can lead to significantly different results between distributed and non-distributed runs. This PR fixes that by raising an error when a custom sampler is configured and `replace_sampler_ddp=True`. The recommended behaviour from now on is to either remove the custom sampler or set `replace_sampler_ddp=False`.

* Fix tests

* Simpler fix

* Fix tests

* Make inner method protected

* Apply suggestions from code review

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent e85a646 commit ade3f36
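
As an illustration of the behaviour this change enforces (a minimal sketch, not part of the commit; the dataset, weights, and Trainer arguments below are placeholders), a user who wants to keep a custom sampler under distributed training now opts out of the automatic replacement:

import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
from pytorch_lightning import Trainer

# Placeholder dataset and sampling weights, purely for illustration.
dataset = TensorDataset(torch.randn(1000, 32), torch.randint(0, 2, (1000,)))
weights = torch.rand(len(dataset))
loader = DataLoader(dataset, batch_size=64,
                    sampler=WeightedRandomSampler(weights, num_samples=len(dataset)))

# With this commit, passing `loader` to a distributed Trainer that keeps the default
# `replace_sampler_ddp=True` raises a MisconfigurationException instead of silently
# swapping in a DistributedSampler. Opting out keeps the custom sampler:
trainer = Trainer(gpus=2, distributed_backend='ddp', replace_sampler_ddp=False)
# trainer.fit(model, loader)  # `model` stands in for any LightningModule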

2 files changed: +38 −24 lines

pytorch_lightning/trainer/data_loading.py

Lines changed: 24 additions & 24 deletions
@@ -3,7 +3,7 @@
 from typing import Union, List, Tuple, Callable
 
 import torch.distributed as torch_distrib
-from torch.utils.data import DataLoader, RandomSampler
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 from pytorch_lightning.core import LightningModule
@@ -113,39 +113,39 @@ def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
         need_dist_sampler = (self.use_ddp or self.use_ddp2 or self.use_horovod or self.use_tpu)
 
         if self.replace_sampler_ddp and need_dist_sampler:
+            if not isinstance(dataloader.sampler, (SequentialSampler, RandomSampler)):
+                raise MisconfigurationException(
+                    'You seem to have configured a sampler in your DataLoader. This will be replaced '
+                    ' by `DistributedSampler` since `replace_sampler_ddp` is True and you are using'
+                    ' distributed training. Either remove the sampler from your DataLoader or set'
+                    ' `replace_sampler_ddp`=False if you want to use your custom sampler.')
+
             skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
 
             dl_args = {
                 k: v for k, v in dataloader.__dict__.items() if not k.startswith('_') and k not in skip_keys
             }
 
-            if self.use_tpu:
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=xm.xrt_world_size(),
-                    rank=xm.get_ordinal(),
-                )
-            elif self.use_horovod:
-                sampler = DistributedSampler(dataloader.dataset,
-                                             num_replicas=hvd.size(),
-                                             rank=hvd.rank())
-            else:
-                world_size = {
-                    'ddp': self.num_nodes * self.num_processes,
-                    'ddp2': self.num_nodes,
-                    'ddp_cpu': self.num_processes * self.num_nodes
-                }
-                sampler = DistributedSampler(
-                    dataloader.dataset,
-                    num_replicas=world_size[self.distributed_backend],
-                    rank=self.proc_rank,
-                )
-
-            dl_args['sampler'] = sampler
+            dl_args['sampler'] = self._get_distributed_sampler(dataloader)
             dataloader = type(dataloader)(**dl_args)
 
         return dataloader
 
+    def _get_distributed_sampler(self, dataloader):
+        if self.use_tpu:
+            kwargs = dict(num_replicas=xm.xrt_world_size(), rank=xm.get_ordinal())
+        elif self.use_horovod:
+            kwargs = dict(num_replicas=hvd.size(), rank=hvd.rank())
+        else:
+            world_size = {
+                'ddp': self.num_nodes * self.num_processes,
+                'ddp2': self.num_nodes,
+                'ddp_cpu': self.num_processes * self.num_nodes
+            }
+            kwargs = dict(num_replicas=world_size[self.distributed_backend], rank=self.proc_rank)
+        sampler = DistributedSampler(dataloader.dataset, **kwargs)
+        return sampler
+
     def reset_train_dataloader(self, model: LightningModule) -> None:
         """Resets the train dataloader and initialises required variables
         (number of batches, when to validate, etc.).
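
For context, a standalone sketch of the DataLoader rebuild pattern that `auto_add_sampler` uses above: copy the loader's public attributes, drop the sampler-related keys, and reconstruct the same DataLoader class around a `DistributedSampler`. The world size and rank are hard-coded here purely for illustration; in the Trainer they come from the active backend via `_get_distributed_sampler`:

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

dataset = TensorDataset(torch.arange(100).float())
loader = DataLoader(dataset, batch_size=10)

# Copy the public DataLoader attributes, skipping the keys that will be replaced.
skip_keys = ['sampler', 'batch_sampler', 'dataset_kind']
dl_args = {k: v for k, v in loader.__dict__.items()
           if not k.startswith('_') and k not in skip_keys}

# Hard-coded num_replicas/rank stand in for the backend-specific values
# (TPU, Horovod, or DDP) that the patch above computes.
dl_args['sampler'] = DistributedSampler(dataset, num_replicas=2, rank=0)
distributed_loader = type(loader)(**dl_args)

print(sum(1 for _ in distributed_loader))  # 5 batches: this rank sees ~half the data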

tests/trainer/test_dataloaders.py

Lines changed: 14 additions & 0 deletions
@@ -416,6 +416,20 @@ class CustomDummyObj:
     assert isinstance(result, CustomDataLoader)
     assert hasattr(result, 'dummy_kwarg')
 
+    # Shuffled DataLoader should also work
+    result = trainer.auto_add_sampler(CustomDataLoader(list(range(1000)), shuffle=True), train=True)
+    assert isinstance(result, torch.utils.data.DataLoader)
+    assert isinstance(result, CustomDataLoader)
+    assert hasattr(result, 'dummy_kwarg')
+
+    class CustomSampler(torch.utils.data.Sampler):
+        pass
+
+    # Should raise an error if existing sampler is being replaced
+    with pytest.raises(MisconfigurationException, match='DistributedSampler'):
+        trainer.auto_add_sampler(
+            CustomDataLoader(list(range(1000)), sampler=CustomSampler(list(range(1000)))), train=True)
+
 
 @pytest.mark.skipif(torch.cuda.device_count() < 3, reason='Test requires multiple GPUs')
 def test_batch_size_smaller_than_num_gpus():
