
Commit 8f6b7a2

SkafteNicki (Nicki Skafte) and a co-author authored
Fix user warning produced by apex + scheduler combination (Lightning-AI#1873)
* fix user error produced by apex + scheduler combination
* add changelog
* added reinit to every configure_apex call
* fix styling

Co-authored-by: Nicki Skafte <nugginea@gmail.com>
1 parent d610f3b commit 8f6b7a2
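
For context (not part of the commit page itself): the warning being fixed is PyTorch's UserWarning that `optimizer.step()` has been overridden after learning rate scheduler initialization. It fires because apex's `amp.initialize` patches `optimizer.step` after the scheduler has already wrapped it. A minimal sketch of the triggering situation, assuming NVIDIA apex is installed and a CUDA device is available (the model, optimizer and scheduler below are arbitrary placeholders):

```python
# Minimal sketch of the situation this commit fixes (not part of the commit).
# Assumes NVIDIA apex is installed and a CUDA device is available.
import torch
from torch import nn, optim
from apex import amp

model = nn.Linear(2, 2).cuda()
optimizer = optim.SGD(model.parameters(), lr=0.1)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=10)

# amp.initialize patches optimizer.step *after* the scheduler has wrapped it,
# so the first scheduler.step() triggers PyTorch's UserWarning about
# optimizer.step() being overridden after scheduler initialization.
model, optimizer = amp.initialize(model, optimizer, opt_level='O2')
```

The commit addresses this by re-initializing each affected scheduler against the apex-patched optimizer immediately after every `configure_apex` call, as the diffs below show.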

File tree

4 files changed (+19 -0 lines changed)


CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed user warning when apex was used together with learning rate schedulers ([#1873](https://github.com/PyTorchLightning/pytorch-lightning/pull/1873))
+
 ## [0.7.6] - 2020-05-16
 
 ### Added

pytorch_lightning/trainer/distrib_data_parallel.py

Lines changed: 1 addition & 0 deletions

@@ -372,6 +372,7 @@ def ddp_train(self, process_idx, model):
         if self.use_amp and not self.use_native_amp:
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         # DDP2 uses all GPUs on the machine
         if self.distributed_backend == 'ddp':
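
For orientation (not part of the diff): each of these call sites invokes the `LightningModule.configure_apex` hook, whose default implementation in Lightning 0.7.x was essentially a thin wrapper around apex's `amp.initialize`. The sketch below is reconstructed from memory of that hook and may differ in detail; the point is that `amp.initialize` is the step that patches `optimizer.step`, which is why the scheduler has to be re-initialized right after it.

```python
import pytorch_lightning as pl

class MyModule(pl.LightningModule):
    # Approximate default behaviour of configure_apex in Lightning 0.7.x
    # (a sketch, not copied from this commit). amp.initialize returns the model
    # and optimizers with optimizer.step patched for mixed-precision loss scaling.
    def configure_apex(self, amp, model, optimizers, amp_level):
        model, optimizers = amp.initialize(model, optimizers, opt_level=amp_level)
        return model, optimizers
```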

pytorch_lightning/trainer/distrib_parts.py

Lines changed: 3 additions & 0 deletions

@@ -497,6 +497,7 @@ def single_gpu_train(self, model):
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         self.run_pretrain_routine(model)
 
@@ -559,6 +560,7 @@ def dp_train(self, model):
                     f' We recommend you switch to ddp if you want to use amp')
             else:
                 model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
+                self.reinit_scheduler_properties(optimizers, self.lr_schedulers)
 
         # create list of device ids
         device_ids = self.data_parallel_device_ids
 
@@ -599,6 +601,7 @@ def horovod_train(self, model):
             # An example
             model, optimizers = model.configure_apex(amp, model, self.optimizers, self.amp_level)
             self.optimizers = optimizers
+            self.reinit_scheduler_properties(self.optimizers, self.lr_schedulers)
 
         # Horovod: broadcast parameters & optimizer state to ensure consistent initialization
         hvd.broadcast_parameters(model.state_dict(), root_rank=0)

pytorch_lightning/trainer/optimizers.py

Lines changed: 13 additions & 0 deletions

@@ -108,6 +108,19 @@ def configure_schedulers(self, schedulers: list):
                              'is a invalid input.')
         return lr_schedulers
 
+    def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
+        # Reinitialize optimizer.step properties added by schedulers
+        for scheduler in schedulers:
+            for optimizer in optimizers:
+                scheduler = scheduler['scheduler']
+                # check that we dont mix users optimizers and schedulers
+                if scheduler.optimizer == optimizer:
+                    # Find the mro belonging to the base lr scheduler class
+                    for i, mro in enumerate(scheduler.__class__.__mro__):
+                        if mro == optim.lr_scheduler._LRScheduler:
+                            idx = i
+                    scheduler.__class__.__mro__[idx].__init__(scheduler, optimizer)
+
 
 class _MockOptimizer(Optimizer):
     """The `_MockOptimizer` will be used inplace of an optimizer in the event that `None`
