Skip to content

Commit b620d86

Browse files
williamFalcon and Borda authored
disable val and test shuffling (Lightning-AI#1600)
* disable val and test shuffling * disable val and test shuffling * disable val and test shuffling * disable val and test shuffling * log * condition * shuffle * refactor Co-authored-by: J. Borovec <jirka.borovec@seznam.cz>
1 parent 791ba91 commit b620d86

File tree

6 files changed

+23
-9
lines changed

6 files changed

+23
-9
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
4646

4747
- Updated semantic segmentation example with custom u-net and logging ([#1371](https://github.com/PyTorchLightning/pytorch-lightning/pull/1371))
4848

49+
- Disabled val and test shuffling ([#1600](https://github.com/PyTorchLightning/pytorch-lightning/pull/1600))
50+
4951

5052
### Deprecated
5153

pytorch_lightning/core/lightning.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1349,7 +1349,7 @@ def test_dataloader(self):
13491349
loader = torch.utils.data.DataLoader(
13501350
dataset=dataset,
13511351
batch_size=self.hparams.batch_size,
1352-
shuffle=True
1352+
shuffle=False
13531353
)
13541354
13551355
return loader
@@ -1394,7 +1394,7 @@ def val_dataloader(self):
13941394
loader = torch.utils.data.DataLoader(
13951395
dataset=dataset,
13961396
batch_size=self.hparams.batch_size,
1397-
shuffle=True
1397+
shuffle=False
13981398
)
13991399
14001400
return loader

pytorch_lightning/trainer/data_loading.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from typing import Union, List, Tuple, Callable
44

55
import torch.distributed as torch_distrib
6-
from torch.utils.data import DataLoader
6+
from torch.utils.data import DataLoader, RandomSampler
77
from torch.utils.data.distributed import DistributedSampler
88

99
from pytorch_lightning.core import LightningModule
@@ -195,8 +195,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
195195
self.val_check_batch = int(self.num_training_batches * self.val_check_interval)
196196
self.val_check_batch = max(1, self.val_check_batch)
197197

198-
def _reset_eval_dataloader(self, model: LightningModule,
199-
mode: str) -> Tuple[int, List[DataLoader]]:
198+
def _reset_eval_dataloader(self, model: LightningModule, mode: str) -> Tuple[int, List[DataLoader]]:
200199
"""Generic method to reset a dataloader for evaluation.
201200
202201
Args:
@@ -211,6 +210,13 @@ def _reset_eval_dataloader(self, model: LightningModule,
211210
if not isinstance(dataloaders, list):
212211
dataloaders = [dataloaders]
213212

213+
# shuffling in val and test set is bad practice
214+
for loader in dataloaders:
215+
if mode in ('val', 'test') and hasattr(loader, 'sampler') and isinstance(loader.sampler, RandomSampler):
216+
raise MisconfigurationException(
217+
f'Your {mode}_dataloader has shuffle=True, it is best practice to turn'
218+
' this off for validation and test dataloaders.')
219+
214220
# add samplers
215221
dataloaders = [self.auto_add_sampler(dl, train=False) for dl in dataloaders if dl]
216222

tests/base/eval_model_template.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,14 @@
1010
from tests.base.eval_model_test_steps import TestStepVariations
1111
from tests.base.eval_model_train_dataloaders import TrainDataloaderVariations
1212
from tests.base.eval_model_train_steps import TrainingStepVariations
13-
from tests.base.eval_model_utils import ModelTemplateUtils
13+
from tests.base.eval_model_utils import ModelTemplateUtils, ModelTemplateData
1414
from tests.base.eval_model_valid_dataloaders import ValDataloaderVariations
1515
from tests.base.eval_model_valid_epoch_ends import ValidationEpochEndVariations
1616
from tests.base.eval_model_valid_steps import ValidationStepVariations
1717

1818

1919
class EvalModelTemplate(
20+
ModelTemplateData,
2021
ModelTemplateUtils,
2122
TrainingStepVariations,
2223
ValidationStepVariations,

tests/base/eval_model_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,23 @@
33
from tests.base.datasets import TrialMNIST
44

55

6-
class ModelTemplateUtils:
6+
class ModelTemplateData:
7+
hparams: ...
78

89
def dataloader(self, train):
910
dataset = TrialMNIST(root=self.hparams.data_root, train=train, download=True)
1011

1112
loader = DataLoader(
1213
dataset=dataset,
1314
batch_size=self.hparams.batch_size,
14-
shuffle=True
15+
# test and valid shall not be shuffled
16+
shuffle=train,
1517
)
1618
return loader
1719

20+
21+
class ModelTemplateUtils:
22+
1823
def get_output_metric(self, output, name):
1924
if isinstance(output, dict):
2025
val = output[name]

tests/base/models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def _dataloader(self, train):
149149
loader = DataLoader(
150150
dataset=dataset,
151151
batch_size=batch_size,
152-
shuffle=True
152+
shuffle=train
153153
)
154154

155155
return loader

0 commit comments

Comments
 (0)