11import platform
2+ from unittest .mock import patch
23
34import pytest
45import torch
56from packaging .version import parse
67from torch .utils .data .dataloader import DataLoader
7- from torch .utils .data .dataset import Subset , IterableDataset
8+ from torch .utils .data .dataset import IterableDataset , Subset
89
910import tests .base .develop_pipelines as tpipes
1011from pytorch_lightning import Trainer
11- from pytorch_lightning .trainer .data_loading import _has_len , _has_iterable_dataset
12+ from pytorch_lightning .trainer .data_loading import _has_iterable_dataset , _has_len
1213from pytorch_lightning .utilities .exceptions import MisconfigurationException
1314from tests .base import EvalModelTemplate
1415
@@ -449,7 +450,8 @@ def test_error_on_zero_len_dataloader(tmpdir):
449450
450451@pytest .mark .skipif (platform .system () == 'Windows' , reason = 'Does not apply to Windows platform.' )
451452@pytest .mark .parametrize ('ckpt_path' , [None , 'best' , 'specific' ])
452- def test_warning_with_few_workers (tmpdir , ckpt_path ):
453+ @patch ('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count' , return_value = 4 )
454+ def test_warning_with_few_workers (mock , tmpdir , ckpt_path ):
453455 """ Test that error is raised if dataloader with only a few workers is used """
454456
455457 model = EvalModelTemplate ()
@@ -476,16 +478,22 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
476478 trainer = Trainer (** trainer_options )
477479
478480 # fit model
479- with pytest .warns (UserWarning , match = 'train' ):
481+ with pytest .warns (
482+ UserWarning , match = 'The dataloader, train dataloader, does not have many workers which may be a bottleneck.'
483+ ):
480484 trainer .fit (model , ** fit_options )
481485
482- with pytest .warns (UserWarning , match = 'val' ):
486+ with pytest .warns (
487+ UserWarning , match = 'The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.'
488+ ):
483489 trainer .fit (model , ** fit_options )
484490
485491 if ckpt_path == 'specific' :
486492 ckpt_path = trainer .checkpoint_callback .best_model_path
487493 test_options = dict (test_dataloaders = train_dl , ckpt_path = ckpt_path )
488- with pytest .warns (UserWarning , match = 'test' ):
494+ with pytest .warns (
495+ UserWarning , match = 'The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.'
496+ ):
489497 trainer .test (** test_options )
490498
491499
0 commit comments