Commit 11069c8

williamFalcon, Borda, and jeremyjordan authored
Fix ddp tests + .test() (Lightning-AI#2512)
* added base tests for tpu (this bullet repeated many times across the squashed commits)
* fix deprecation warnings
* Update pytorch_lightning/trainer/trainer.py

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
1 parent fb85d49 commit 11069c8

26 files changed: +468 −227 lines

pytorch_lightning/core/decorators.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ def data_loader(fn):
     Warnings:
         This decorator deprecated in v0.7.0 and it will be removed v0.9.0.
     """
-    rank_zero_warn('`data_loader` decorator deprecated in v0.7.0. Will be removed v0.9.0', DeprecationWarning)
+    rank_zero_warn("`data_loader` decorator deprecated in v0.7.0. It will be removed in v0.9.0", DeprecationWarning)
 
     def inner_fx(self):
         return fn(self)
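For context, a minimal sketch of the deprecated-decorator pattern this file uses: warn once when the decorator is applied, then delegate to the wrapped function unchanged. This sketch substitutes the stdlib `warnings` module for Lightning's `rank_zero_warn`, so it is an approximation, not the library's exact code.

import warnings
from functools import wraps


def data_loader(fn):
    """Sketch: a deprecated decorator that still delegates to the wrapped function."""
    warnings.warn(
        '`data_loader` decorator deprecated in v0.7.0. It will be removed in v0.9.0',
        DeprecationWarning,
    )

    @wraps(fn)
    def inner_fx(self):
        return fn(self)

    return inner_fx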

pytorch_lightning/loggers/tensorboard.py

Lines changed: 4 additions & 0 deletions
@@ -106,6 +106,10 @@ def experiment(self) -> SummaryWriter:
         self._experiment = SummaryWriter(log_dir=self.log_dir, **self._kwargs)
         return self._experiment
 
+    @experiment.setter
+    def experiment(self, exp):
+        self._experiment = exp
+
     @rank_zero_only
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
                         metrics: Optional[Dict[str, Any]] = None) -> None:
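The new setter means the lazily created writer can be replaced from outside, e.g. by tests or by code running in spawned processes. A toy sketch of the pattern, assuming only the lazy property shown above (`LoggerSketch` is a hypothetical stand-in, not the real `TensorBoardLogger`):

class LoggerSketch:
    """Toy stand-in illustrating the lazy property plus the added setter."""

    def __init__(self):
        self._experiment = None

    @property
    def experiment(self):
        # lazily create the writer on first access (stands in for SummaryWriter(...))
        if self._experiment is None:
            self._experiment = object()
        return self._experiment

    @experiment.setter
    def experiment(self, exp):
        # the added setter lets callers inject their own writer
        self._experiment = exp


logger = LoggerSketch()
logger.experiment = 'injected writer'  # possible only with the setter
assert logger.experiment == 'injected writer'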

pytorch_lightning/trainer/data_loading.py

Lines changed: 3 additions & 4 deletions
@@ -221,7 +221,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
             self.num_training_batches = len(self.train_dataloader)
             self.num_training_batches = int(self.num_training_batches * self.limit_train_batches)
         else:
-            self.num_training_batches = self.limit_train_batches
+            self.num_training_batches = min(len(self.train_dataloader), self.limit_train_batches)
 
         # determine when to check validation
         # if int passed in, val checks that often
@@ -313,7 +313,7 @@ def _reset_eval_dataloader(
                 if isinstance(limit_eval_batches, float):
                     num_batches = int(num_batches * limit_eval_batches)
                 else:
-                    num_batches = limit_eval_batches
+                    num_batches = min(len(dataloader), limit_eval_batches)
 
             elif limit_eval_batches not in (0.0, 1.0):
                 raise MisconfigurationException(
@@ -340,8 +340,7 @@ def reset_val_dataloader(self, model: LightningModule) -> None:
             model: The current `LightningModule`
         """
         if self.is_overridden('validation_step'):
-            self.num_val_batches, self.val_dataloaders = \
-                self._reset_eval_dataloader(model, 'val')
+            self.num_val_batches, self.val_dataloaders = self._reset_eval_dataloader(model, 'val')
 
     def reset_test_dataloader(self, model) -> None:
         """Resets the validation dataloader and determines the number of batches.

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 44 additions & 16 deletions
@@ -122,6 +122,8 @@ def train_fx(trial_hparams, cluster_manager, _):
 from time import sleep
 import numpy as np
 from os.path import abspath
+from torch import distributed as dist
+import queue
 
 import torch
 from pytorch_lightning import _logger as log
@@ -163,6 +165,10 @@ def train_fx(trial_hparams, cluster_manager, _):
 else:
     XLA_AVAILABLE = True
 
+pid = os.getpid()
+rng1 = np.random.RandomState(pid)
+RANDOM_PORTS = rng1.randint(10000, 19999, 100)
+
 
 class TrainerDDPMixin(ABC):
 
@@ -178,6 +184,7 @@ class TrainerDDPMixin(ABC):
     use_tpu: bool
     default_root_dir: str
     progress_bar_callback: ...
+    checkpoint_callback: ...
     num_processes: int
     num_nodes: int
     node_rank: int
@@ -377,17 +384,19 @@ def set_nvidia_flags(self, is_slurm_managing_tasks, data_parallel_device_ids):
         # don't make this debug... this is good UX
         rank_zero_info(f'CUDA_VISIBLE_DEVICES: [{os.environ["CUDA_VISIBLE_DEVICES"]}]')
 
-    def set_random_port(self):
+    def set_random_port(self, force=False):
         """
         When running DDP NOT managed by SLURM, the ports might collide
         """
-        try:
-            default_port = os.environ['MASTER_PORT']
-        except Exception:
-            # use the process id as a seed to a generator for port only
-            pid = os.getpid()
-            rng1 = np.random.RandomState(pid)
-            default_port = rng1.randint(10000, 19999, 1)[0]
+        # pick a random port first
+        assert self.num_nodes == 1, 'random port can only be called from single node training'
+        global RANDOM_PORTS
+        default_port = RANDOM_PORTS[-1]
+        RANDOM_PORTS = RANDOM_PORTS[:-1]
+
+        # when not forced, use the user port
+        if not force:
+            default_port = os.environ.get('MASTER_PORT', default_port)
 
         os.environ['MASTER_PORT'] = str(default_port)
 
@@ -446,15 +455,24 @@ def spawn_ddp_children(self, model):
             sleep(delay)
 
         local_rank = 0
-        self.ddp_train(local_rank, model, is_master=True)
+        results = self.ddp_train(local_rank, q=None, model=model, is_master=True)
+        del os.environ['WORLD_SIZE']
 
-    def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
+        return results
+
+    def ddp_train(self, process_idx, q, model, is_master=False, proc_offset=0):
         """
-        Entry point into a DP thread
-        :param gpu_idx:
-        :param model:
-        :param cluster_obj:
-        :return:
+        Entry point for ddp
+
+        Args:
+            process_idx:
+            q:
+            model:
+            is_master:
+            proc_offset:
+
+        Returns:
+
         """
         # offset the process id if requested
         process_idx = process_idx + proc_offset
@@ -535,7 +553,17 @@ def ddp_train(self, process_idx, model, is_master=False, proc_offset=0):
         model = model.configure_ddp(model, device_ids)
 
         # continue training routine
-        self.run_pretrain_routine(model)
+        results = self.run_pretrain_routine(model)
+
+        # clean up memory
+        torch.cuda.empty_cache()
+
+        if self.global_rank == 0 and q is not None:
+            q.put(self.checkpoint_callback.best_model_path)
+            q.put(results)
+
+        if self.global_rank == 0 and self.distributed_backend != 'ddp_spawn':
+            return results
 
     def save_spawn_weights(self, model):
         """

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 11 additions & 6 deletions
@@ -21,6 +21,7 @@
 from pytorch_lightning.utilities import move_data_to_device, NATIVE_AMP_AVALAIBLE
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.distributed import rank_zero_only
+from pytorch_lightning.utilities import rank_zero_warn
 
 try:
     from apex import amp
@@ -182,7 +183,8 @@ def single_gpu_train(self, model):
         self.optimizers = optimizers
         self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
-        self.run_pretrain_routine(model)
+        results = self.run_pretrain_routine(model)
+        return results
 
     def tpu_train(self, tpu_core_idx, model):
         # call setup after the ddp process has connected
@@ -221,6 +223,7 @@ def tpu_train(self, tpu_core_idx, model):
 
         # when training ends on these platforms dump weights to get out of the main process
         if self.on_colab_kaggle:
+            rank_zero_warn('cleaning up... please do not interrupt')
            self.save_spawn_weights(model)
 
     def dp_train(self, model):
@@ -229,12 +232,12 @@ def dp_train(self, model):
         if self.is_function_implemented('setup', model):
             model.setup('fit')
 
+        model.cuda(self.root_gpu)
+
         # CHOOSE OPTIMIZER
         # allow for lr schedulers as well
         self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
 
-        model.cuda(self.root_gpu)
-
         # hack forward to do autocast for the user
         model_autocast_original_forward = model.forward
         if self.use_amp and NATIVE_AMP_AVALAIBLE:
@@ -264,10 +267,11 @@ def dp_train(self, model):
 
         model = LightningDataParallel(model, device_ids=device_ids)
 
-        self.run_pretrain_routine(model)
-
+        result = self.run_pretrain_routine(model)
         model.forward = model_autocast_original_forward
 
+        return result
+
     def horovod_train(self, model):
         # call setup after the ddp process has connected
         self.setup('fit')
@@ -325,10 +329,11 @@ def filter_named_parameters(model, optimizer):
             # Synchronization will be performed explicitly following backward()
             stack.enter_context(optimizer.skip_synchronize())
 
-            self.run_pretrain_routine(model)
+            result = self.run_pretrain_routine(model)
 
         # Make sure all workers have finished training before returning to the user
         hvd.join()
+        return result
 
 
 def _normalize_parse_gpu_string_input(s: Union[int, str, List[int]]) -> Union[int, List[int]]:
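Besides making every accelerator entry point return the result of `run_pretrain_routine`, the `dp_train` hunk reorders device placement to happen before optimizer creation. The motivation is not stated in the commit, but the usual convention is that the optimizer should be constructed over the parameters the model will actually train with. A minimal sketch under that assumption:

import torch
from torch import nn

# Sketch of the reordering in dp_train: move the model to its device *before*
# building the optimizer.
model = nn.Linear(4, 2)

if torch.cuda.is_available():
    model.cuda(0)  # patched order: device placement first

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)  # then the optimizer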

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 5 additions & 2 deletions
@@ -325,7 +325,7 @@ def _evaluate(
             if self.is_overridden('test_end', model=model):
                 # TODO: remove in v1.0.0
                 eval_results = model.test_end(outputs)
-                rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed v1.0.'
+                rank_zero_warn('Method `test_end` was deprecated in v0.7 and will be removed in v1.0.'
                                ' Use `test_epoch_end` instead.', DeprecationWarning)
 
             elif self.is_overridden('test_epoch_end', model=model):
@@ -335,7 +335,7 @@ def _evaluate(
             if self.is_overridden('validation_end', model=model):
                 # TODO: remove in v1.0.0
                 eval_results = model.validation_end(outputs)
-                rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed v1.0.'
+                rank_zero_warn('Method `validation_end` was deprecated in v0.7 and will be removed in v1.0.'
                                ' Use `validation_epoch_end` instead.', DeprecationWarning)
 
             elif self.is_overridden('validation_epoch_end', model=model):
@@ -391,6 +391,7 @@ def run_evaluation(self, test_mode: bool = False):
         eval_results = self._evaluate(self.model, dataloaders, max_batches, test_mode)
 
         # enable no returns
+        callback_metrics = {}
         if eval_results is not None and len(eval_results) > 0:
             _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.process_output(eval_results)
 
@@ -428,6 +429,8 @@ def run_evaluation(self, test_mode: bool = False):
         else:
             self.on_validation_end()
 
+        return callback_metrics
+
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional
         args = [batch, batch_idx]
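This is the piece that lets `.test()` hand results back to the caller: `callback_metrics` now defaults to an empty dict, so `run_evaluation` always returns a dict even when the user's eval hooks return nothing. A sketch of the patched control flow (`run_evaluation_sketch` and its `process_output` argument are hypothetical stand-ins for the real trainer methods):

def run_evaluation_sketch(eval_results, process_output):
    """Sketch: callback_metrics defaults to {}, so there is always a return
    value ("enable no returns")."""
    callback_metrics = {}
    if eval_results is not None and len(eval_results) > 0:
        _, _, _, callback_metrics, _ = process_output(eval_results)
    return callback_metrics


# with no eval results, callers (e.g. trainer.test()) still receive a dict
assert run_evaluation_sketch(None, lambda r: (None, None, None, {'acc': 1.0}, None)) == {}
assert run_evaluation_sketch(['r'], lambda r: (None, None, None, {'acc': 1.0}, None)) == {'acc': 1.0}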
