@@ -256,6 +256,69 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
256256 f'Multiple `test_dataloaders` not initiated properly, got { trainer .test_dataloaders } '
257257
258258
259+ @pytest .mark .parametrize (
260+ ['limit_train_batches' , 'limit_val_batches' , 'limit_test_batches' ],
261+ [
262+ pytest .param (0.0 , 0.0 , 0.0 ),
263+ pytest .param (1.0 , 1.0 , 1.0 ),
264+ ]
265+ )
266+ def test_inf_dataloaders_with_limit_percent_batches (tmpdir , limit_train_batches ,
267+ limit_val_batches , limit_test_batches ):
268+ """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit in percent"""
269+ model = EvalModelTemplate ()
270+ model .train_dataloader = model .train_dataloader__infinite
271+ model .val_dataloader = model .val_dataloader__infinite
272+ model .test_dataloader = model .test_dataloader__infinite
273+
274+ trainer = Trainer (
275+ default_root_dir = tmpdir ,
276+ max_epochs = 1 ,
277+ limit_train_batches = limit_train_batches ,
278+ limit_val_batches = limit_val_batches ,
279+ limit_test_batches = limit_test_batches ,
280+ )
281+
282+ results = trainer .fit (model )
283+ assert results == 1
284+ assert trainer .num_training_batches == 0 if limit_train_batches == 0.0 else float ('inf' )
285+ assert trainer .num_val_batches [0 ] == 0 if limit_val_batches == 0.0 else float ('inf' )
286+
287+ trainer .test (ckpt_path = None )
288+ assert trainer .num_test_batches [0 ] == 0 if limit_test_batches == 0.0 else float ('inf' )
289+
290+
291+ @pytest .mark .parametrize (
292+ ['limit_train_batches' , 'limit_val_batches' , 'limit_test_batches' ],
293+ [
294+ pytest .param (0 , 0 , 0 ),
295+ pytest .param (10 , 10 , 10 ),
296+ ]
297+ )
298+ def test_inf_dataloaders_with_limit_num_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
299+ """Verify inf train, val & test dataloaders (e.g. IterableDataset) passed with batch limit as number"""
300+ model = EvalModelTemplate ()
301+ model .train_dataloader = model .train_dataloader__infinite
302+ model .val_dataloader = model .val_dataloader__infinite
303+ model .test_dataloader = model .test_dataloader__infinite
304+
305+ trainer = Trainer (
306+ default_root_dir = tmpdir ,
307+ max_epochs = 1 ,
308+ limit_train_batches = limit_train_batches ,
309+ limit_val_batches = limit_val_batches ,
310+ limit_test_batches = limit_test_batches ,
311+ )
312+
313+ results = trainer .fit (model )
314+ assert results
315+ assert trainer .num_training_batches == limit_train_batches
316+ assert trainer .num_val_batches [0 ] == limit_val_batches
317+
318+ trainer .test (ckpt_path = None )
319+ assert trainer .num_test_batches [0 ] == limit_test_batches
320+
321+
259322@pytest .mark .parametrize (
260323 ['limit_train_batches' , 'limit_val_batches' , 'limit_test_batches' ],
261324 [
@@ -266,7 +329,7 @@ def test_multiple_dataloaders_passed_to_fit(tmpdir, ckpt_path):
266329 ]
267330)
268331def test_dataloaders_with_limit_percent_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
269- """Verify num_batches for val & test dataloaders passed with batch limit in percent"""
332+ """Verify num_batches for train, val & test dataloaders passed with batch limit in percent"""
270333 model = EvalModelTemplate ()
271334 model .val_dataloader = model .val_dataloader__multiple_mixed_length
272335 model .test_dataloader = model .test_dataloader__multiple_mixed_length
@@ -307,7 +370,7 @@ def test_dataloaders_with_limit_percent_batches(tmpdir, limit_train_batches, lim
307370 ]
308371)
309372def test_dataloaders_with_limit_num_batches (tmpdir , limit_train_batches , limit_val_batches , limit_test_batches ):
310- """Verify num_batches for val & test dataloaders passed with batch limit as number"""
373+ """Verify num_batches for train, val & test dataloaders passed with batch limit as number"""
311374 os .environ ['PL_DEV_DEBUG' ] = '1'
312375
313376 model = EvalModelTemplate ()
@@ -436,7 +499,7 @@ def test_train_inf_dataloader_error(tmpdir):
436499
437500 trainer = Trainer (default_root_dir = tmpdir , max_epochs = 1 , val_check_interval = 0.5 )
438501
439- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
502+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
440503 trainer .fit (model )
441504
442505
@@ -447,7 +510,7 @@ def test_val_inf_dataloader_error(tmpdir):
447510
448511 trainer = Trainer (default_root_dir = tmpdir , max_epochs = 1 , limit_val_batches = 0.5 )
449512
450- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
513+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
451514 trainer .fit (model )
452515
453516
@@ -458,7 +521,7 @@ def test_test_inf_dataloader_error(tmpdir):
458521
459522 trainer = Trainer (default_root_dir = tmpdir , max_epochs = 1 , limit_test_batches = 0.5 )
460523
461- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
524+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
462525 trainer .test (model )
463526
464527
@@ -774,7 +837,7 @@ def test_train_dataloader_not_implemented_error_failed(tmpdir):
774837
775838 trainer = Trainer (default_root_dir = tmpdir , max_steps = 5 , max_epochs = 1 , val_check_interval = 0.5 )
776839
777- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
840+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
778841 trainer .fit (model )
779842
780843
@@ -785,7 +848,7 @@ def test_val_dataloader_not_implemented_error_failed(tmpdir):
785848
786849 trainer = Trainer (default_root_dir = tmpdir , max_steps = 5 , max_epochs = 1 , limit_val_batches = 0.5 )
787850
788- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
851+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
789852 trainer .fit (model )
790853
791854
@@ -796,5 +859,5 @@ def test_test_dataloader_not_implemented_error_failed(tmpdir):
796859
797860 trainer = Trainer (default_root_dir = tmpdir , max_steps = 5 , max_epochs = 1 , limit_test_batches = 0.5 )
798861
799- with pytest .raises (MisconfigurationException , match = 'infinite DataLoader ' ):
862+ with pytest .raises (MisconfigurationException , match = 'using an IterableDataset ' ):
800863 trainer .test (model )
0 commit comments