Skip to content

Commit a91b06e

Browse files
jeremyjordan and Borda authored
fix worker warning (Lightning-AI#2504)
* fix worker warning * improve tests * suggestion Co-authored-by: Jirka <jirka@pytorchlightning.ai>
1 parent 96b32be commit a91b06e

File tree

2 files changed

+32
-22
lines changed

2 files changed

+32
-22
lines changed

pytorch_lightning/trainer/data_loading.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -124,22 +124,24 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
124124
# ddp_spawn + num_workers > 0 don't mix! tell the user
125125
is_dataloader = isinstance(dataloader, DataLoader)
126126
using_spawn = self.distributed_backend == 'ddp_spawn'
127-
if is_dataloader and dataloader.num_workers > 0 and not on_windows and using_spawn:
128-
rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well! '
129-
'Your performance might suffer dramatically. '
130-
'Please consider setting distributed_backend=ddp to use num_workers > 0 '
131-
'(this is a bottleneck of Python .spawn() and PyTorch')
132-
133-
elif is_dataloader and dataloader.num_workers <= 2 and not on_windows and not using_spawn:
134-
num_cpus = multiprocessing.cpu_count()
135-
rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
136-
' Consider increasing the value of the `num_workers` argument` '
137-
f'(try {num_cpus} which is the number of cpus on this machine)'
138-
' in the `DataLoader` init to improve performance.')
139-
140-
elif is_dataloader and dataloader.num_workers == 0 and not on_windows and using_spawn:
141-
rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0. '
142-
'For much faster performance, switch to `distributed_backend=ddp` and set `num_workers>0`')
127+
if is_dataloader and not on_windows:
128+
if dataloader.num_workers > 0 and using_spawn:
129+
rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well!'
130+
' Your performance might suffer dramatically.'
131+
' Please consider setting distributed_backend=ddp to use num_workers > 0'
132+
' (this is a bottleneck of Python .spawn() and PyTorch')
133+
134+
elif dataloader.num_workers == 0 and using_spawn:
135+
rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0.'
136+
' For much faster performance, switch to `distributed_backend=ddp`'
137+
' and set `num_workers>0`')
138+
139+
elif dataloader.num_workers <= 2 and multiprocessing.cpu_count() > 2 and not using_spawn:
140+
num_cpus = multiprocessing.cpu_count()
141+
rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
142+
' Consider increasing the value of the `num_workers` argument`'
143+
f' (try {num_cpus} which is the number of cpus on this machine)'
144+
' in the `DataLoader` init to improve performance.')
143145

144146
def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:
145147

tests/trainer/test_dataloaders.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,15 @@
11
import platform
2+
from unittest.mock import patch
23

34
import pytest
45
import torch
56
from packaging.version import parse
67
from torch.utils.data.dataloader import DataLoader
7-
from torch.utils.data.dataset import Subset, IterableDataset
8+
from torch.utils.data.dataset import IterableDataset, Subset
89

910
import tests.base.develop_pipelines as tpipes
1011
from pytorch_lightning import Trainer
11-
from pytorch_lightning.trainer.data_loading import _has_len, _has_iterable_dataset
12+
from pytorch_lightning.trainer.data_loading import _has_iterable_dataset, _has_len
1213
from pytorch_lightning.utilities.exceptions import MisconfigurationException
1314
from tests.base import EvalModelTemplate
1415

@@ -449,7 +450,8 @@ def test_error_on_zero_len_dataloader(tmpdir):
449450

450451
@pytest.mark.skipif(platform.system() == 'Windows', reason='Does not apply to Windows platform.')
451452
@pytest.mark.parametrize('ckpt_path', [None, 'best', 'specific'])
452-
def test_warning_with_few_workers(tmpdir, ckpt_path):
453+
@patch('pytorch_lightning.trainer.data_loading.multiprocessing.cpu_count', return_value=4)
454+
def test_warning_with_few_workers(mock, tmpdir, ckpt_path):
453455
""" Test that error is raised if dataloader with only a few workers is used """
454456

455457
model = EvalModelTemplate()
@@ -476,16 +478,22 @@ def test_warning_with_few_workers(tmpdir, ckpt_path):
476478
trainer = Trainer(**trainer_options)
477479

478480
# fit model
479-
with pytest.warns(UserWarning, match='train'):
481+
with pytest.warns(
482+
UserWarning, match='The dataloader, train dataloader, does not have many workers which may be a bottleneck.'
483+
):
480484
trainer.fit(model, **fit_options)
481485

482-
with pytest.warns(UserWarning, match='val'):
486+
with pytest.warns(
487+
UserWarning, match='The dataloader, val dataloader 0, does not have many workers which may be a bottleneck.'
488+
):
483489
trainer.fit(model, **fit_options)
484490

485491
if ckpt_path == 'specific':
486492
ckpt_path = trainer.checkpoint_callback.best_model_path
487493
test_options = dict(test_dataloaders=train_dl, ckpt_path=ckpt_path)
488-
with pytest.warns(UserWarning, match='test'):
494+
with pytest.warns(
495+
UserWarning, match='The dataloader, test dataloader 0, does not have many workers which may be a bottleneck.'
496+
):
489497
trainer.test(**test_options)
490498

491499

0 commit comments

Comments (0)