Skip to content

Commit e085e93

Browse files
awaelchli, rohitgr7, and Borda
authored
Add missing test for "multiple dataloader + percent_check fix" (Lightning-AI#2226)
* Init fix num_batches * Fix num_batches in case of multiple dataloaders * Apply suggestions from code review * Changes based on suggestions * Flake8 * Add test to check num_batches * generalize dataloader percent check test * fix formatting * remove hparams * tests * CHANGELOG * Update CHANGELOG.md * max_batches can be int * conflict and rebase * add back the test fix fix message 0.0 works Revert "fix message" This reverts commit 839cacf8b8610f4e697e654ef6f3d2501bf23984. * update changelog * Update CHANGELOG.md * Fix num batches in case of multiple dataloaders and percent_check (Lightning-AI#1920) * git conflict Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com> Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com> * missing union * doc update suggestion by @rohitgr7 * extend test * changelog * docs add note about multiple loaders * update changelog * remove unused variable Co-authored-by: rohitgr7 <rohitgr1998@gmail.com> Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent 44385bb commit e085e93

13 files changed

+119
-20
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
1818

1919
### Fixed
2020

21+
- Fixed number batches in case of multiple dataloaders and `limit_{*}_batches` ([#1920](https://github.com/PyTorchLightning/pytorch-lightning/pull/1920), [#2226](https://github.com/PyTorchLightning/pytorch-lightning/pull/2226))
22+
2123
- Fixed an issue with forward hooks not being removed after model summary ([#2298](https://github.com/PyTorchLightning/pytorch-lightning/pull/2298))
2224

2325
- Fixed ROC metric for CUDA tensors ([#2304](https://github.com/PyTorchLightning/pytorch-lightning/pull/2304))

pytorch_lightning/trainer/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -456,6 +456,8 @@ def on_train_end(self, trainer, pl_module):
456456
# run for only 10 batches
457457
trainer = Trainer(limit_test_batches=10)
458458
459+
In the case of multiple test dataloaders, the limit applies to each dataloader individually.
460+
459461
limit_val_batches
460462
^^^^^^^^^^^^^^^^^
461463
@@ -473,6 +475,8 @@ def on_train_end(self, trainer, pl_module):
473475
# run for only 10 batches
474476
trainer = Trainer(limit_val_batches=10)
475477
478+
In the case of multiple validation dataloaders, the limit applies to each dataloader individually.
479+
476480
log_gpu_memory
477481
^^^^^^^^^^^^^^
478482
Options:

pytorch_lightning/trainer/data_loading.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,8 +287,6 @@ def _reset_eval_dataloader(
287287
for i, dataloader in enumerate(dataloaders):
288288
num_batches = 0
289289
self._worker_check(dataloader, f'{mode} dataloader {i}')
290-
if not _has_len(dataloader):
291-
num_batches = float('inf')
292290

293291
# percent or num_steps
294292
limit_eval_batches = getattr(self, f'limit_{mode}_batches')

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@
124124

125125
from abc import ABC, abstractmethod
126126
from pprint import pprint
127-
from typing import Callable, Optional, List
127+
from typing import Callable, Optional, List, Union
128128

129129
import torch
130130
from torch.utils.data import DataLoader
@@ -222,13 +222,20 @@ def reset_test_dataloader(self, *args):
222222
def reset_val_dataloader(self, *args):
223223
"""Warning: this is just empty shell for code implemented in other class."""
224224

225-
def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int], test_mode: bool = False):
225+
def _evaluate(
226+
self,
227+
model: LightningModule,
228+
dataloaders: List[DataLoader],
229+
max_batches: Union[int, List[int]],
230+
test_mode: bool = False
231+
):
226232
"""Run evaluation code.
227233
228234
Args:
229-
model: PT model
230-
dataloaders: list of PT dataloaders
231-
max_batches: List of scalars
235+
model: The model to evaluate.
236+
dataloaders: A list of PyTorch dataloaders.
237+
max_batches: An integer or list of integers with length of the number of dataloaders. Each
238+
entry is the number of batches to process in the corresponding dataloader.
232239
test_mode:
233240
"""
234241
# enable eval mode
@@ -244,6 +251,10 @@ def _evaluate(self, model: LightningModule, dataloaders, max_batches: List[int],
244251
# bookkeeping
245252
outputs = []
246253

254+
# convert max_batches to list
255+
if isinstance(max_batches, int):
256+
max_batches = [max_batches] * len(dataloaders)
257+
247258
# run validation
248259
for dataloader_idx, dataloader in enumerate(dataloaders):
249260
dl_outputs = []

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __init__(
223223
224224
min_steps: Force training for at least these number of steps. Disabled by default (None).
225225
226-
limit_train_batches: How much of training dataset to check.
226+
limit_train_batches: How much of training dataset to check (floats = percent, int = num_batches)
227227
228228
limit_val_batches: How much of validation dataset to check (floats = percent, int = num_batches)
229229

pytorch_lightning/trainer/training_loop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,6 @@ class TrainerTrainLoopMixin(ABC):
208208
check_val_every_n_epoch: ...
209209
num_training_batches: int
210210
val_check_batch: ...
211-
num_val_batches: int
212211
disable_validation: bool
213212
fast_dev_run: ...
214213
accumulation_scheduler: ...

tests/base/model_test_dataloaders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
class TestDataloaderVariations(ABC):
88

99
@abstractmethod
10-
def dataloader(self, train: bool):
10+
def dataloader(self, *args, **kwargs):
1111
"""placeholder"""
1212

1313
def test_dataloader(self):
@@ -19,6 +19,11 @@ def test_dataloader__infinite(self):
1919
def test_dataloader__not_implemented_error(self):
2020
return CustomNotImplementedErrorDataloader(self.dataloader(train=False))
2121

22+
def test_dataloader__multiple_mixed_length(self):
23+
lengths = [50, 30, 40]
24+
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
25+
return dataloaders
26+
2227
def test_dataloader__empty(self):
2328
return None
2429

tests/base/model_test_epoch_ends.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ class TestEpochEndVariations(ABC):
77

88
def test_epoch_end(self, outputs):
99
"""
10-
Called at the end of validation to aggregate outputs
10+
Called at the end of test epoch to aggregate outputs
1111
:param outputs: list of individual outputs of each validation step
1212
:return:
1313
"""
@@ -40,7 +40,7 @@ def test_epoch_end(self, outputs):
4040

4141
def test_epoch_end__multiple_dataloaders(self, outputs):
4242
"""
43-
Called at the end of validation to aggregate outputs
43+
Called at the end of test epoch to aggregate outputs
4444
:param outputs: list of individual outputs of each validation step
4545
:return:
4646
"""

tests/base/model_utilities.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
class ModelTemplateData:
77
hparams: ...
88

9-
def dataloader(self, train):
10-
dataset = TrialMNIST(root=self.data_root, train=train, download=True)
9+
def dataloader(self, train: bool, num_samples: int = 100):
10+
dataset = TrialMNIST(root=self.data_root, train=train, num_samples=num_samples, download=True)
1111

1212
loader = DataLoader(
1313
dataset=dataset,

tests/base/model_valid_dataloaders.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,17 @@
77
class ValDataloaderVariations(ABC):
88

99
@abstractmethod
10-
def dataloader(self, train: bool):
10+
def dataloader(self, *args, **kwargs):
1111
"""placeholder"""
1212

1313
def val_dataloader(self):
1414
return self.dataloader(train=False)
1515

16+
def val_dataloader__multiple_mixed_length(self):
17+
lengths = [100, 30]
18+
dataloaders = [self.dataloader(train=False, num_samples=n) for n in lengths]
19+
return dataloaders
20+
1621
def val_dataloader__multiple(self):
1722
return [self.dataloader(train=False),
1823
self.dataloader(train=False)]

0 commit comments

Comments (0)