
Commit 9df2b20

sets default ddp mode to spawn (Lightning-AI#2168)
* set ddp_spawn as default
* spawn message
* spawn message
* spawn message
* spawn message
* spawn message
* spawn message
* spawn message
* spawn message
1 parent bb32ae5 commit 9df2b20

File tree: 6 files changed (+37, −18 lines)

  pytorch_lightning/core/lightning.py
  pytorch_lightning/trainer/data_loading.py
  pytorch_lightning/trainer/distrib_data_parallel.py
  pytorch_lightning/trainer/trainer.py
  pytorch_lightning/utilities/__init__.py
  pytorch_lightning/utilities/distributed.py

pytorch_lightning/core/lightning.py

Lines changed: 4 additions & 4 deletions
@@ -942,18 +942,18 @@ def init_ddp_connection(
         self._init_slurm_connection()

         if 'MASTER_ADDR' not in os.environ:
-            log.warning("MASTER_ADDR environment variable is not defined. Set as localhost")
+            rank_zero_warn("MASTER_ADDR environment variable is not defined. Set as localhost")
             os.environ['MASTER_ADDR'] = '127.0.0.1'
         log.debug(f"MASTER_ADDR: {os.environ['MASTER_ADDR']}")

         if 'MASTER_PORT' not in os.environ:
-            log.warning("MASTER_PORT environment variable is not defined. Set as 12910")
+            rank_zero_warn("MASTER_PORT environment variable is not defined. Set as 12910")
             os.environ['MASTER_PORT'] = '12910'
         log.debug(f"MASTER_PORT: {os.environ['MASTER_PORT']}")

         if 'WORLD_SIZE' in os.environ and int(os.environ['WORLD_SIZE']) != world_size:
-            log.warning(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
-                        f"is not equal to the computed world size ({world_size}). Ignored.")
+            rank_zero_warn(f"WORLD_SIZE environment variable ({os.environ['WORLD_SIZE']}) "
+                           f"is not equal to the computed world size ({world_size}). Ignored.")

         torch_backend = "nccl" if self.trainer.on_gpu else "gloo"
         log.info(f"initializing ddp: LOCAL_RANK: {proc_rank}/{world_size - 1} WORLD_SIZE:{world_size}")
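
These hunks only change which helper emits the message; the localhost/12910 fallbacks are untouched. A minimal sketch of presetting the rendezvous variables in the launch script so neither warning fires (both values are illustrative):

    import os

    os.environ.setdefault('MASTER_ADDR', '192.168.1.10')  # address of the rank-0 node
    os.environ.setdefault('MASTER_PORT', '29500')         # any free TCP port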

pytorch_lightning/trainer/data_loading.py

Lines changed: 14 additions & 1 deletion
@@ -96,11 +96,24 @@ def _percent_range_check(self, name: str) -> None:
     def _worker_check(self, dataloader: DataLoader, name: str) -> None:
         on_windows = platform.system() == 'Windows'

-        if isinstance(dataloader, DataLoader) and dataloader.num_workers <= 2 and not on_windows:
+        # ddp_spawn + num_workers > 0 don't mix! tell the user
+        is_dataloader = isinstance(dataloader, DataLoader)
+        using_spawn = self.distributed_backend == 'ddp_spawn'
+        if is_dataloader and dataloader.num_workers > 0 and not on_windows and using_spawn:
+            rank_zero_warn('Dataloader(num_workers>0) and ddp_spawn do not mix well! '
+                           'Your performance might suffer dramatically. '
+                           'Please consider setting distributed_backend=ddp to use num_workers > 0 '
+                           '(this is a bottleneck of Python .spawn() and PyTorch')
+
+        elif is_dataloader and dataloader.num_workers <= 2 and not on_windows and not using_spawn:
             rank_zero_warn(f'The dataloader, {name}, does not have many workers which may be a bottleneck.'
                            ' Consider increasing the value of the `num_workers` argument`'
                            ' in the `DataLoader` init to improve performance.')

+        elif is_dataloader and dataloader.num_workers == 0 and not on_windows and using_spawn:
+            rank_zero_warn('You are using `distributed_backend=ddp_spawn` with num_workers=0. '
+                           'For much faster performance, switch to `distributed_backend=ddp` and set `num_workers>0`')
+
     def auto_add_sampler(self, dataloader: DataLoader, train: bool) -> DataLoader:

         # don't do anything if it's not a dataloader
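
Concretely, the three branches above map onto combinations like the ones below. This is a hedged sketch (the dataset is a stand-in), and the checks only run once the Trainer inspects the attached dataloaders, not at construction time:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    dataset = TensorDataset(torch.randn(64, 32), torch.randint(0, 2, (64,)))

    # With the new default backend 'ddp_spawn', num_workers > 0 hits the first
    # branch: worker processes plus .spawn() tend to hurt throughput badly.
    spawn_many_workers = DataLoader(dataset, batch_size=8, num_workers=4)

    # Still under 'ddp_spawn', num_workers == 0 hits the third branch, which
    # suggests switching to distributed_backend='ddp' to get fast workers back.
    spawn_no_workers = DataLoader(dataset, batch_size=8, num_workers=0)

    # Under an explicit 'ddp' backend, only the original "consider more workers"
    # advice remains, and only when num_workers <= 2.
    ddp_few_workers = DataLoader(dataset, batch_size=8, num_workers=2)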

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 8 additions & 8 deletions
@@ -128,7 +128,7 @@ def train_fx(trial_hparams, cluster_manager, _):
 from pytorch_lightning.callbacks import ModelCheckpoint
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info

 try:
     from apex import amp

@@ -220,9 +220,9 @@ def set_distributed_mode(self, distributed_backend):
         elif self.num_gpus > 1:
             rank_zero_warn('You requested multiple GPUs but did not specify a backend, e.g.'
                            ' Trainer(distributed_backend=dp) (or ddp, ddp2).'
-                           ' Setting distributed_backend=ddp for you.')
-            self.distributed_backend = 'ddp'
-            distributed_backend = 'ddp'
+                           ' Setting distributed_backend=ddp_spawn for you.')
+            self.distributed_backend = 'ddp_spawn'
+            distributed_backend = 'ddp_spawn'

         if distributed_backend == "dp":
             # do nothing if num_gpus == 0

@@ -264,7 +264,7 @@ def set_distributed_mode(self, distributed_backend):
                 'To silence this warning set distributed_backend=ddp or distributed_backend=ddp2'
             )

-        log.info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')
+        rank_zero_info(f'GPU available: {torch.cuda.is_available()}, used: {self.on_gpu}')

     def configure_slurm_ddp(self, num_gpu_nodes):
         self.is_slurm_managing_tasks = False

@@ -298,7 +298,7 @@ def configure_slurm_ddp(self, num_gpu_nodes):

         # notify user the that slurm is managing tasks
         if self.is_slurm_managing_tasks:
-            log.info('Multi-processing is handled by Slurm.')
+            rank_zero_info('Multi-processing is handled by Slurm.')

     def determine_ddp_node_rank(self):
         if self.is_slurm_managing_tasks:

@@ -316,7 +316,7 @@ def determine_ddp_node_rank(self):
             log.warning(f"Multiple environment variables ({node_ids}) defined for node rank. "
                         f"Using the first one.")
         k, rank = node_ids.pop()
-        log.info(f"Using environment variable {k} for node rank ({rank}).")
+        rank_zero_info(f"Using environment variable {k} for node rank ({rank}).")
         return int(rank)

     def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):

@@ -336,7 +336,7 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         os.environ["CUDA_VISIBLE_DEVICES"] = gpu_str

         # don't make this debug... this is good UX
-        log.info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
+        rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')

     def __set_random_port(self):
         """

pytorch_lightning/trainer/trainer.py

Lines changed: 4 additions & 4 deletions
@@ -31,7 +31,7 @@
 from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
 from pytorch_lightning.trainer.lr_finder import TrainerLRFinderMixin
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities import rank_zero_warn, parsing
+from pytorch_lightning.utilities import rank_zero_warn, parsing, rank_zero_info

 try:
     from apex import amp

@@ -362,8 +362,8 @@ def __init__(
         if self.fast_dev_run:
             self.num_sanity_val_steps = 0
             self.max_epochs = 1
-            log.info('Running in fast_dev_run mode: will run a full train,'
-                     ' val and test loop using a single batch')
+            rank_zero_info('Running in fast_dev_run mode: will run a full train,'
+                           ' val and test loop using a single batch')

         # set default save path if user didn't provide one
         self.default_root_dir = default_root_dir

@@ -838,7 +838,7 @@ def fit(
             self.single_gpu_train(model)

         elif self.use_tpu:  # pragma: no-cover
-            log.info(f'training on {self.tpu_cores} TPU cores')
+            rank_zero_info(f'training on {self.tpu_cores} TPU cores')

             # COLAB_GPU is an env var available by default in Colab environments.
             start_method = 'fork' if self.on_colab_kaggle else 'spawn'

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 """General utilities"""

-from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn
+from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_warn, rank_zero_info
 from pytorch_lightning.utilities.apply_func import move_data_to_device
 from pytorch_lightning.utilities.parsing import AttributeDict

pytorch_lightning/utilities/distributed.py

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 from functools import wraps
 import warnings
+from pytorch_lightning import _logger as log


 def rank_zero_only(fn):

@@ -23,4 +24,9 @@ def _warn(*args, **kwargs):
     warnings.warn(*args, **kwargs)


+def _info(*args, **kwargs):
+    log.info(*args, **kwargs)
+
+
+rank_zero_info = rank_zero_only(_info)
 rank_zero_warn = rank_zero_only(_warn)
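
With the export added in utilities/__init__.py, one-off messages go through the same rank guard as warnings. A rough usage sketch (the message and function name are illustrative; rank_zero_only simply suppresses the wrapped call on non-zero ranks):

    from pytorch_lightning.utilities import rank_zero_info, rank_zero_only

    rank_zero_info('training on 2 GPUs')  # logged once, by rank 0 only

    @rank_zero_only
    def write_run_summary():
        # any side effect that should happen exactly once per job
        ...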
