
Commit 7c19c37

ivannz and Borda authored

LearningRateLogger in multi-scheduler setting (Lightning-AI#1944)

* fixed undesired behaviour due to dict.fromkeys
* a test for log length consistency
* runtime-warn if no schedulers are configured
* chlog
* move

Co-authored-by: Jirka <jirka@pytorchlightning.ai>
1 parent 3af4994 commit 7c19c37

File tree

4 files changed (+118, -84 lines):

CHANGELOG.md
pytorch_lightning/callbacks/lr_logger.py
tests/callbacks/test_callbacks.py
tests/callbacks/test_lr.py

CHANGELOG.md
Lines changed: 3 additions & 1 deletion

@@ -42,7 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue with `Trainer.from_argparse_args` when passing in unknown Trainer args ([#1932](https://github.com/PyTorchLightning/pytorch-lightning/pull/1932))
 
-- Fix bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+- Fixed bug related to logger not being reset correctly for model after tuner algorithms ([#1933](https://github.com/PyTorchLightning/pytorch-lightning/pull/1933))
+
+- Fixed `LearningRateLogger` in multi-scheduler setting ([#1944](https://github.com/PyTorchLightning/pytorch-lightning/pull/1944))
 
 
 ## [0.7.6] - 2020-05-16

pytorch_lightning/callbacks/lr_logger.py
Lines changed: 10 additions & 7 deletions

@@ -10,6 +10,8 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 
+from pytorch_lightning.utilities import rank_zero_warn
+
 
 class LearningRateLogger(Callback):
     r"""
@@ -45,21 +47,22 @@ def on_train_start(self, trainer, pl_module):
         schedulers in the case of multiple of the same type or in
         the case of multiple parameter groups
         """
-        if trainer.lr_schedulers == []:
-            raise MisconfigurationException(
-                'Cannot use LearningRateLogger callback with models that have no'
-                ' learning rate schedulers. Please see documentation for'
-                ' `configure_optimizers` method.')
-
         if not trainer.logger:
             raise MisconfigurationException(
                'Cannot use LearningRateLogger callback with Trainer that has no logger.')
 
+        if not trainer.lr_schedulers:
+            rank_zero_warn(
+                'You are using LearningRateLogger callback with models that'
+                ' have no learning rate schedulers. Please see documentation'
+                ' for `configure_optimizers` method.', RuntimeWarning
+            )
+
         # Find names for schedulers
         names = self._find_names(trainer.lr_schedulers)
 
         # Initialize for storing values
-        self.lrs = dict.fromkeys(names, [])
+        self.lrs = {name: [] for name in names}
 
     def on_batch_start(self, trainer, pl_module):
         latest_stat = self._extract_lr(trainer, 'step')
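
The one-line change under "Initialize for storing values" fixes a classic aliasing pitfall: `dict.fromkeys(names, [])` reuses the *same* list object as the value for every key, so a learning rate appended for one scheduler showed up under all of them. The other change swaps the hard MisconfigurationException for a `rank_zero_warn(..., RuntimeWarning)` when no schedulers are configured, which the new `test_lr_logger_no_lr` test below checks via `pytest.warns`. A minimal standalone illustration of the dict pitfall (plain Python, not code from this repo):

    names = ['lr-Adam', 'lr-Adam-1']

    shared = dict.fromkeys(names, [])       # one list object shared by every key
    shared['lr-Adam'].append(0.1)
    print(shared)                           # {'lr-Adam': [0.1], 'lr-Adam-1': [0.1]} -- both keys affected

    fresh = {name: [] for name in names}    # the fix: a new list per key
    fresh['lr-Adam'].append(0.1)
    print(fresh)                            # {'lr-Adam': [0.1], 'lr-Adam-1': []}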

tests/callbacks/test_callbacks.py
Lines changed: 3 additions & 76 deletions

@@ -1,12 +1,13 @@
+from pathlib import Path
+
 import pytest
 
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
-from pytorch_lightning.callbacks import EarlyStopping, LearningRateLogger, ModelCheckpoint
+from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
-from pathlib import Path
 
 
 def test_trainer_callback_system(tmpdir):
@@ -281,77 +282,3 @@ def test_model_checkpoint_path(tmpdir, logger_version, expected):
 
     ckpt_version = Path(trainer.ckpt_path).parent.name
     assert ckpt_version == expected
-
-
-def test_lr_logger_single_lr(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__single_scheduler
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_multi_lrs(tmpdir):
-    """ Test that learning rates are extracted and logged for multi lr schedulers """
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=1,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert results == 1
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of lr schedulers'
-    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'
-
-
-def test_lr_logger_param_groups(tmpdir):
-    """ Test that learning rates are extracted and logged for single lr scheduler"""
-    tutils.reset_seed()
-
-    model = EvalModelTemplate()
-    model.configure_optimizers = model.configure_optimizers__param_groups
-
-    lr_logger = LearningRateLogger()
-    trainer = Trainer(
-        default_root_dir=tmpdir,
-        max_epochs=5,
-        val_percent_check=0.1,
-        train_percent_check=0.5,
-        callbacks=[lr_logger]
-    )
-    results = trainer.fit(model)
-
-    assert lr_logger.lrs, 'No learning rates logged'
-    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
-        'Number of learning rates logged does not match number of param groups'
-    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
-        'Names of learning rates not set correctly'

tests/callbacks/test_lr.py
Lines changed: 102 additions & 0 deletions

@@ -0,0 +1,102 @@
+import pytest
+
+import tests.base.utils as tutils
+from pytorch_lightning import Trainer
+from pytorch_lightning.callbacks import LearningRateLogger
+from tests.base import EvalModelTemplate
+
+
+def test_lr_logger_single_lr(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__single_scheduler
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+
+
+def test_lr_logger_no_lr(tmpdir):
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+
+    with pytest.warns(RuntimeWarning):
+        result = trainer.fit(model)
+        assert result
+
+
+def test_lr_logger_multi_lrs(tmpdir):
+    """ Test that learning rates are extracted and logged for multi lr schedulers. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__multiple_schedulers
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=10,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of lr schedulers'
+    assert all([k in ['lr-Adam', 'lr-Adam-1'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
+    assert all(len(lr) == trainer.max_epochs for k, lr in lr_logger.lrs.items()), \
+        'Length of logged learning rates exceeds the number of epochs'
+
+
+def test_lr_logger_param_groups(tmpdir):
+    """ Test that learning rates are extracted and logged for single lr scheduler. """
+    tutils.reset_seed()
+
+    model = EvalModelTemplate()
+    model.configure_optimizers = model.configure_optimizers__param_groups
+
+    lr_logger = LearningRateLogger()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=5,
+        val_percent_check=0.1,
+        train_percent_check=0.5,
+        callbacks=[lr_logger]
+    )
+    result = trainer.fit(model)
+    assert result
+
+    assert lr_logger.lrs, 'No learning rates logged'
+    assert len(lr_logger.lrs) == 2 * len(trainer.lr_schedulers), \
+        'Number of learning rates logged does not match number of param groups'
+    assert all([k in ['lr-Adam/pg1', 'lr-Adam/pg2'] for k in lr_logger.lrs.keys()]), \
+        'Names of learning rates not set correctly'
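
For context, the 'lr-Adam' / 'lr-Adam-1' keys asserted above are the names `LearningRateLogger` assigns when several schedulers sit on optimizers of the same type (duplicates get a numeric suffix), and 'lr-Adam/pg1' / 'lr-Adam/pg2' appear when one optimizer has multiple parameter groups. A hedged sketch of what a multi-scheduler `configure_optimizers` along the lines of `configure_optimizers__multiple_schedulers` might return (illustrative only, not the actual EvalModelTemplate helper):

    from torch import optim

    def configure_optimizers(self):
        # Illustrative sketch: two Adam optimizers, one scheduler each, so
        # LearningRateLogger would track them under 'lr-Adam' and 'lr-Adam-1'.
        opt1 = optim.Adam(self.parameters(), lr=1e-2)
        opt2 = optim.Adam(self.parameters(), lr=1e-3)
        sched1 = optim.lr_scheduler.StepLR(opt1, step_size=1, gamma=0.9)
        sched2 = optim.lr_scheduler.StepLR(opt2, step_size=1, gamma=0.8)
        return [opt1, opt2], [sched1, sched2]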
