
Commit de9c9f0

Authored by rohitgr7, with Borda, awaelchli, mergify[bot] and ethanwharris
Support limit_mode_batches (int) for infinite dataloader (Lightning-AI#2787)
* Support limit_mode_batches(int) for infinite dataloader
* flake8
* revert and update
* add and update tests
* pep8
* chlog
* Update CHANGELOG.md (Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>)
* Add suggestions by @awaelchli
* docs
* Apply suggestions from code review (Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>)
* Apply suggestions from code review
* fix
* max
* check

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
Co-authored-by: Adrian Wälchli <aedu.waelchli@gmail.com>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
Co-authored-by: Ethan Harris <ewah1g13@soton.ac.uk>
Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
1 parent b2a7d75 commit de9c9f0

File tree: 7 files changed, +112 −41 lines

- CHANGELOG.md
- docs/source/sequences.rst
- pytorch_lightning/core/lightning.py
- pytorch_lightning/trainer/data_loading.py
- pytorch_lightning/trainer/training_tricks.py
- tests/models/test_onnx_save.py
- tests/trainer/test_dataloaders.py

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -31,6 +31,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added remaining `sklearn` metrics: `AveragePrecision`, `BalancedAccuracy`, `CohenKappaScore`, `DCG`, `Hamming`, `Hinge`, `Jaccard`, `MeanAbsoluteError`, `MeanSquaredError`, `MeanSquaredLogError`, `MedianAbsoluteError`, `R2Score`, `MeanPoissonDeviance`, `MeanGammaDeviance`, `MeanTweedieDeviance`, `ExplainedVariance` ([#2562](https://github.com/PyTorchLightning/pytorch-lightning/pull/2562))
 
+- Added support for `limit_{mode}_batches (int)` to work with infinite dataloader (IterableDataset) ([#2787](https://github.com/PyTorchLightning/pytorch-lightning/pull/2787))
+
 ### Changed
 
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))
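To make the new entry concrete — a minimal hedged sketch (the three `limit_*` flags are real Trainer arguments; the rest of the run configuration is left out):

```python
from pytorch_lightning import Trainer

# With an IterableDataset that defines no __len__, an int cap is now accepted
# in addition to 0.0 and 1.0; fractional floats in between still raise, since
# there is no length to take a fraction of.
trainer = Trainer(limit_train_batches=10, limit_val_batches=5, limit_test_batches=5)
```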

docs/source/sequences.rst

Lines changed: 15 additions & 6 deletions
@@ -49,8 +49,8 @@ Lightning can handle TBTT automatically via this flag.
 .. note:: If you need to modify how the batch is split,
    override :meth:`pytorch_lightning.core.LightningModule.tbptt_split_batch`.
 
-.. note:: Using this feature requires updating your LightningModule's :meth:`pytorch_lightning.core.LightningModule.training_step` to include
-   a `hiddens` arg.
+.. note:: Using this feature requires updating your LightningModule's
+   :meth:`pytorch_lightning.core.LightningModule.training_step` to include a `hiddens` arg.
 
 ----------
 
@@ -59,10 +59,13 @@ Iterable Datasets
 Lightning supports using IterableDatasets as well as map-style Datasets. IterableDatasets provide a more natural
 option when using sequential data.
 
-.. note:: When using an IterableDataset you must set the val_check_interval to 1.0 (the default) or to an int
-   (specifying the number of training batches to run before validation) when initializing the Trainer.
-   This is due to the fact that the IterableDataset does not have a __len__ and Lightning requires this to calculate
-   the validation interval when val_check_interval is less than one.
+.. note:: When using an IterableDataset you must set the ``val_check_interval`` to 1.0 (the default) or an int
+   (specifying the number of training batches to run before validation) when initializing the Trainer. This is
+   because the IterableDataset does not have a ``__len__`` and Lightning requires this to calculate the validation
+   interval when ``val_check_interval`` is less than one. Similarly, you can set ``limit_{mode}_batches`` to a float or
+   an int. If it is set to 0.0 or 0 it will set ``num_{mode}_batches`` to 0, if it is an int it will set ``num_{mode}_batches``
+   to ``limit_{mode}_batches``, if it is set to 1.0 it will run for the whole dataset, otherwise it will throw an exception.
+   Here mode can be train/val/test.
 
 .. testcode::
 
@@ -87,3 +90,9 @@ option when using sequential data.
 
     # Set val_check_interval
     trainer = Trainer(val_check_interval=100)
+
+    # Set limit_val_batches to 0.0 or 0
+    trainer = Trainer(limit_val_batches=0.0)
+
+    # Set limit_val_batches as an int
+    trainer = Trainer(limit_val_batches=100)
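For context, a hedged sketch of the kind of length-less loader these notes are about (names are illustrative, not from this commit): calling `len()` on it fails, so Lightning treats it as infinite and restricts `limit_{mode}_batches` to 0.0, 1.0 or an int.

```python
import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset


class InfiniteDataset(IterableDataset):
    """Yields forever and deliberately defines no __len__."""

    def __iter__(self):
        for i in itertools.count():
            yield torch.tensor([float(i)])


loader = DataLoader(InfiniteDataset(), batch_size=4)
# len(loader)  # would raise TypeError: the dataset has no __len__,
#              # so the loader counts as "infinite"
```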

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -1754,7 +1754,7 @@ def to_onnx(self, file_path: str, input_sample: Optional[Tensor] = None, **kwargs
         elif self.example_input_array is not None:
             input_data = self.example_input_array
         else:
-            raise ValueError(f'input_sample and example_input_array tensors are both missing.')
+            raise ValueError('`input_sample` and `example_input_array` tensors are both missing.')
 
         if 'example_outputs' not in kwargs:
             self.eval()
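Since the reworded message points at both inputs, a short hedged sketch (the module is a placeholder, for illustration only) of the two ways to satisfy `to_onnx` and avoid this ValueError:

```python
import torch
from pytorch_lightning import LightningModule


class LitModel(LightningModule):  # hypothetical module, not from this commit
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)
        # option 1: set example_input_array so to_onnx can export without arguments
        self.example_input_array = torch.rand(1, 28 * 28)

    def forward(self, x):
        return self.layer(x)


model = LitModel()
model.to_onnx('model.onnx')                          # uses example_input_array
model.to_onnx('model.onnx', torch.rand(1, 28 * 28))  # option 2: pass input_sample
```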

pytorch_lightning/trainer/data_loading.py

Lines changed: 21 additions & 24 deletions
@@ -212,18 +212,19 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
         # automatically add samplers
         self.train_dataloader = self.auto_add_sampler(self.train_dataloader, train=True)
 
+        self.num_training_batches = len(self.train_dataloader) if _has_len(self.train_dataloader) else float('inf')
         self._worker_check(self.train_dataloader, 'train dataloader')
         self._check_batch_limits('limit_train_batches')
 
-        if not _has_len(self.train_dataloader):
-            self.num_training_batches = float('inf')
-        else:
-            # try getting the length
-            if isinstance(self.limit_train_batches, float):
-                self.num_training_batches = len(self.train_dataloader)
-                self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
-            else:
-                self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
+        if isinstance(self.limit_train_batches, int) or self.limit_train_batches == 0.0:
+            self.num_training_batches = min(self.num_training_batches, int(self.limit_train_batches))
+        elif self.num_training_batches != float('inf'):
+            self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
+        elif self.limit_train_batches != 1.0:
+            raise MisconfigurationException(
+                'When using an IterableDataset for `limit_train_batches`,'
+                ' `Trainer(limit_train_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
+                ' `num_training_batches` to use.')
 
         # determine when to check validation
         # if int passed in, val checks that often
@@ -241,8 +242,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
                 self.val_check_batch = float('inf')
             else:
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset'
-                    ' or when DataLoader does not implement `__len__`) for `train_dataloader`,'
+                    'When using an IterableDataset for `train_dataloader`,'
                     ' `Trainer(val_check_interval)` must be `1.0` or an int. An int k specifies'
                     ' checking validation every k training batches.')
         else:
@@ -304,24 +304,21 @@ def _reset_eval_dataloader(
         for i, dataloader in enumerate(dataloaders):
             num_batches = len(dataloader) if _has_len(dataloader) else float('inf')
             self._worker_check(dataloader, f'{mode} dataloader {i}')
+            self._check_batch_limits(f'limit_{mode}_batches')
 
             # percent or num_steps
             limit_eval_batches = getattr(self, f'limit_{mode}_batches')
 
-            if num_batches != float('inf'):
-                self._check_batch_limits(f'limit_{mode}_batches')
-
-                # limit num batches either as a percent or num steps
-                if isinstance(limit_eval_batches, float):
-                    num_batches = int(num_batches * limit_eval_batches)
-                else:
-                    num_batches = min(len(dataloader), limit_eval_batches)
-
-            elif limit_eval_batches not in (0.0, 1.0):
+            # limit num batches either as a percent or num steps
+            if isinstance(limit_eval_batches, int) or limit_eval_batches == 0.0:
+                num_batches = min(num_batches, int(limit_eval_batches))
+            elif num_batches != float('inf'):
+                num_batches = int(num_batches * limit_eval_batches)
+            elif limit_eval_batches != 1.0:
                 raise MisconfigurationException(
-                    'When using an infinite DataLoader (e.g. with an IterableDataset'
-                    f' or when DataLoader does not implement `__len__`) for `limit_{mode}_batches`,'
-                    f' `Trainer(limit_{mode}_batches)` must be `0.0` or `1.0`.')
+                    'When using an IterableDataset for `limit_{mode}_batches`,'
+                    f' `Trainer(limit_{mode}_batches)` must be `0.0`, `1.0` or an int. An int k specifies'
+                    f' `num_{mode}_batches` to use.')
 
             if num_batches == 0 and limit_eval_batches > 0.0 and isinstance(limit_eval_batches, float):
                 min_pct = 1.0 / len(dataloader)
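Pulled out of the Trainer for readability, a hedged sketch (the function name and the ValueError are mine; the branch order mirrors the diff above) of how the new code resolves a batch limit when the loader may be infinite:

```python
def resolve_num_batches(dataloader_len, limit_batches):
    """dataloader_len is len(dataloader), or float('inf') when there is no __len__."""
    if isinstance(limit_batches, int) or limit_batches == 0.0:
        # an int cap (or an explicit 0.0) needs no length, so it works for inf loaders
        return min(dataloader_len, int(limit_batches))
    elif dataloader_len != float('inf'):
        # a fractional float needs a known length to scale
        return int(dataloader_len * limit_batches)
    elif limit_batches != 1.0:
        # inf loader + fractional float: nothing to scale, so fail loudly
        raise ValueError('With an infinite dataloader, use 0.0, 1.0 or an int.')
    return dataloader_len  # 1.0 on an inf loader: run on everything


assert resolve_num_batches(float('inf'), 10) == 10
assert resolve_num_batches(100, 0.5) == 50
assert resolve_num_batches(float('inf'), 1.0) == float('inf')
```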

pytorch_lightning/trainer/training_tricks.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def _adjust_batch_size(trainer,
     if hasattr(model, batch_arg_name):
         setattr(model, batch_arg_name, value)
     else:
-        setattr(model.hparams, batch_arg_name, value)
+        setattr(model.hparams, batch_arg_name, value)
     new_size = value
     if desc:
         log.info(f'Batch size {batch_size} {desc}, trying batch size {new_size}')

(The removed and added lines differ only in whitespace: an indentation fix.)

tests/models/test_onnx_save.py

Lines changed: 1 addition & 1 deletion
@@ -84,7 +84,7 @@ def test_error_if_no_input(tmpdir):
     model = EvalModelTemplate()
     model.example_input_array = None
     file_path = os.path.join(tmpdir, "model.onxx")
-    with pytest.raises(ValueError, match=r'input_sample and example_input_array tensors are both missing'):
+    with pytest.raises(ValueError, match=r'`input_sample` and `example_input_array` tensors are both missing'):
         model.to_onnx(file_path)
tests/trainer/test_dataloaders.py

Lines changed: 71 additions & 8 deletions
@@ -256,6 +256,69 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
         f'Multiple `test_dataloaders` not initiated properly, got {trainer.test_dataloaders}'
 
 
+@pytest.mark.parametrize(
+    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
+    [
+        pytest.param(0.0, 0.0, 0.0),
+        pytest.param(1.0, 1.0, 1.0),
+    ]
+)
+def test_inf_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches,
+                                                    limit_val_batches, limit_test_batches):
+    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
+    model = EvalModelTemplate()
+    model.train_dataloader = model.train_dataloader__infinite
+    model.val_dataloader = model.val_dataloader__infinite
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+
+    results = trainer.fit(model)
+    assert results == 1
+    assert trainer.num_training_batches == (0 if limit_train_batches == 0.0 else float('inf'))
+    assert trainer.num_val_batches[0] == (0 if limit_val_batches == 0.0 else float('inf'))
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches[0] == (0 if limit_test_batches == 0.0 else float('inf'))
+
+
+@pytest.mark.parametrize(
+    ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
+    [
+        pytest.param(0, 0, 0),
+        pytest.param(10, 10, 10),
+    ]
+)
+def test_inf_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
+    """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
+    model = EvalModelTemplate()
+    model.train_dataloader = model.train_dataloader__infinite
+    model.val_dataloader = model.val_dataloader__infinite
+    model.test_dataloader = model.test_dataloader__infinite
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        limit_train_batches=limit_train_batches,
+        limit_val_batches=limit_val_batches,
+        limit_test_batches=limit_test_batches,
+    )
+
+    results = trainer.fit(model)
+    assert results
+    assert trainer.num_training_batches == limit_train_batches
+    assert trainer.num_val_batches[0] == limit_val_batches
+
+    trainer.test(ckpt_path=None)
+    assert trainer.num_test_batches[0] == limit_test_batches
+
+
 @pytest.mark.parametrize(
     ['limit_train_batches', 'limit_val_batches', 'limit_test_batches'],
     [
@@ -266,7 +329,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
     ]
 )
 def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
+    """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
     model = EvalModelTemplate()
     model.val_dataloader = model.val_dataloader__multiple_mixed_length
     model.test_dataloader = model.test_dataloader__multiple_mixed_length
@@ -307,7 +370,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
     ]
 )
 def test_dataloaders_with_limit_num_batches(tmpdir, limit_train_batches, limit_val_batches, limit_test_batches):
-    """Verify num_batches for val & test dataloaders passed with batch limit as number"""
+    """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
     os.environ['PL_DEV_DEBUG'] = '1'
 
     model = EvalModelTemplate()
@@ -436,7 +499,7 @@ def test_train_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
@@ -447,7 +510,7 @@ def test_val_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
@@ -458,7 +521,7 @@ def test_test_inf_dataloader_error(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.test(model)
@@ -774,7 +837,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, val_check_interval=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
@@ -785,7 +848,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_val_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.fit(model)
@@ -796,5 +859,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
 
     trainer = Trainer(default_root_dir=tmpdir, max_steps=5, max_epochs=1, limit_test_batches=0.5)
 
-    with pytest.raises(MisconfigurationException, match='infinite DataLoader'):
+    with pytest.raises(MisconfigurationException, match='using an IterableDataset'):
         trainer.test(model)

(Note: the two new asserts of the form `a == (0 if cond else float('inf'))` are shown with parentheses; without them Python's precedence would make the assertion vacuous for the non-zero case.)
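The new tests lean on EvalModelTemplate's `*_dataloader__infinite` hooks. A hedged sketch (assumption: the real helpers under tests/base/ differ in detail) of how such a hook can wrap a finite dataset into a length-less one:

```python
from torch.utils.data import DataLoader, IterableDataset


class CustomInfDataset(IterableDataset):
    """Repeats a finite dataset forever; deliberately defines no __len__."""

    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        while True:
            for sample in self.dataset:
                yield sample


def infinite_dataloader(dataset, batch_size=32):
    # hypothetical helper: same data as a regular loader, made length-less so
    # Lightning sees num_batches == float('inf') until a limit caps it
    return DataLoader(CustomInfDataset(dataset), batch_size=batch_size)
```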
