
Commit b39f479

ananthsub and Borda authored
Add support to Tensorboard logger for OmegaConf hparams (Lightning-AI#2846)
* Add support to Tensorboard logger for OmegaConf hparams

  Addresses Lightning-AI#2844. We check whether omegaconf can be imported and whether the hparams are OmegaConf instances; if so, we use OmegaConf.merge to preserve the typing, so that saving the hparams to yaml actually triggers the OmegaConf branch.

* avalaible

* chlog

* test

Co-authored-by: Jirka Borovec <jirka@pytorchlightning.ai>
1 parent: 91b0d46 · commit: b39f479
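For orientation, here is a minimal sketch of the workflow this commit enables. The `save_dir` and hyperparameter values are made up, and it assumes the logger's `save()` call flushes `hparams.yaml` through `save_hparams_to_yaml`:

```python
# Hypothetical end-to-end usage of what this commit enables; values are illustrative.
from omegaconf import OmegaConf
from pytorch_lightning.loggers import TensorBoardLogger

# Build the hparams as an OmegaConf container rather than a plain dict.
hparams = OmegaConf.create({"lr": 1e-3, "model": {"hidden_dim": 128}})

logger = TensorBoardLogger(save_dir="./logs")
# With this change, log_hyperparams merges OmegaConf params via OmegaConf.merge,
# so the logger's stored hparams stay an OmegaConf Container.
logger.log_hyperparams(hparams)
# Writing the hparams file then routes through the OmegaConf branch of
# save_hparams_to_yaml, i.e. OmegaConf.save(..., resolve=True).
logger.save()
```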

File tree

5 files changed (+58, -24 lines):

- CHANGELOG.md
- pytorch_lightning/core/saving.py
- pytorch_lightning/loggers/tensorboard.py
- pytorch_lightning/trainer/training_io.py
- tests/loggers/test_tensorboard.py

CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -39,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added support returning python scalars in DP ([#1935](https://github.com/PyTorchLightning/pytorch-lightning/pull/1935))
 
+- Added support to Tensorboard logger for OmegaConf `hparams` ([#2846](https://github.com/PyTorchLightning/pytorch-lightning/pull/2846))
+
 ### Changed
 
 - Truncated long version numbers in progress bar ([#2594](https://github.com/PyTorchLightning/pytorch-lightning/pull/2594))

pytorch_lightning/core/saving.py

Lines changed: 4 additions & 2 deletions

@@ -17,7 +17,9 @@
 try:
     from omegaconf import Container
 except ImportError:
-    Container = None
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
 
 # the older shall be on the top
 CHECKPOINT_PAST_HPARAMS_KEYS = (
@@ -327,7 +329,7 @@ def save_hparams_to_yaml(config_yaml, hparams: Union[dict, Namespace]) -> None:
     if not os.path.isdir(os.path.dirname(config_yaml)):
         raise RuntimeError(f'Missing folder: {os.path.dirname(config_yaml)}.')
 
-    if Container is not None and isinstance(hparams, Container):
+    if OMEGACONF_AVAILABLE and isinstance(hparams, Container):
         from omegaconf import OmegaConf
         OmegaConf.save(hparams, config_yaml, resolve=True)
         return
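Because this branch changes what lands on disk, a small round-trip sketch may help; the temporary path and keys are illustrative, not from the PR. Since the Container branch saves with `resolve=True`, OmegaConf interpolations are resolved in the written file:

```python
# Sketch of the new save_hparams_to_yaml branch for OmegaConf containers.
import os
import tempfile

from omegaconf import OmegaConf
from pytorch_lightning.core.saving import save_hparams_to_yaml

# "lr" is an interpolation that points at "base_lr".
hparams = OmegaConf.create({"base_lr": 0.01, "lr": "${base_lr}"})

out_dir = tempfile.mkdtemp()
yaml_path = os.path.join(out_dir, "hparams.yaml")
save_hparams_to_yaml(yaml_path, hparams)

reloaded = OmegaConf.load(yaml_path)
assert reloaded.lr == 0.01  # interpolation was resolved before saving
```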

pytorch_lightning/loggers/tensorboard.py

Lines changed: 11 additions & 1 deletion

@@ -17,6 +17,13 @@
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities import rank_zero_only
 
+try:
+    from omegaconf import Container, OmegaConf
+except ImportError:
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
+
 
 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -112,7 +119,10 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace],
         params = self._convert_params(params)
 
         # store params to output
-        self.hparams.update(params)
+        if OMEGACONF_AVAILABLE and isinstance(params, Container):
+            self.hparams = OmegaConf.merge(self.hparams, params)
+        else:
+            self.hparams.update(params)
 
         # format params into the suitable for tensorboard
         params = self._flatten_dict(params)
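The key detail here is the typing point from the commit message: `OmegaConf.merge` keeps the stored hparams an OmegaConf `Container`, while `dict.update` would leave them a plain dict and the yaml writer would never take the OmegaConf branch. A small illustration, not code from the PR:

```python
from omegaconf import Container, OmegaConf

stored = {}                                 # mirrors the logger's initial self.hparams
incoming = OmegaConf.create({"lr": 1e-3})   # what a user passes to log_hyperparams

# merge accepts plain dicts and containers alike and returns an OmegaConf container,
# so the typing is preserved and the yaml path can later call OmegaConf.save.
merged = OmegaConf.merge(stored, incoming)
assert isinstance(merged, Container)

# The old behaviour: dict.update keeps a plain dict and the typing is lost.
stored.update(incoming)
assert not isinstance(stored, Container)
```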

pytorch_lightning/trainer/training_io.py

Lines changed: 4 additions & 2 deletions

@@ -132,7 +132,9 @@
 try:
     from omegaconf import Container
 except ImportError:
-    Container = None
+    OMEGACONF_AVAILABLE = False
+else:
+    OMEGACONF_AVAILABLE = True
 
 
 class TrainerIOMixin(ABC):
@@ -390,7 +392,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # add arguments to the checkpoint
             checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
-            if Container is not None:
+            if OMEGACONF_AVAILABLE:
                 if isinstance(model.hparams, Container):
                     checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
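For completeness, a hedged sketch of what this buys on the checkpoint side; the `.ckpt` filename is hypothetical. Storing the container type under `CHECKPOINT_HYPER_PARAMS_TYPE` records that the hparams were an OmegaConf `Container`, so code that reloads the checkpoint can tell the two cases apart:

```python
# Inspect a saved checkpoint after training with OmegaConf hparams (path is made up).
import torch
from omegaconf import Container
from pytorch_lightning import LightningModule

checkpoint = torch.load("example.ckpt", map_location="cpu")

hparams = checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
# Only present when OmegaConf was importable and hparams was a Container:
hparams_type = checkpoint.get(LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE)
if hparams_type is not None and issubclass(hparams_type, Container):
    print("checkpoint preserves the OmegaConf container type:", hparams_type)
```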
tests/loggers/test_tensorboard.py

Lines changed: 37 additions & 19 deletions

@@ -4,36 +4,36 @@
 import pytest
 import torch
 import yaml
+from omegaconf import OmegaConf
 from packaging import version
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import TensorBoardLogger
 from tests.base import EvalModelTemplate
 
 
-@pytest.mark.skipif(version.parse(torch.__version__) < version.parse('1.5.0'),
-                    reason='Minimal PT version is set to 1.5')
+@pytest.mark.skipif(
+    version.parse(torch.__version__) < version.parse("1.5.0"),
+    reason="Minimal PT version is set to 1.5",
+)
 def test_tensorboard_hparams_reload(tmpdir):
     model = EvalModelTemplate()
 
-    trainer = Trainer(
-        max_epochs=1,
-        default_root_dir=tmpdir,
-    )
+    trainer = Trainer(max_epochs=1, default_root_dir=tmpdir)
     trainer.fit(model)
 
     folder_path = trainer.logger.log_dir
 
     # make sure yaml is there
-    with open(os.path.join(folder_path, 'hparams.yaml')) as file:
+    with open(os.path.join(folder_path, "hparams.yaml")) as file:
         # The FullLoader parameter handles the conversion from YAML
         # scalar values to Python the dictionary format
         yaml_params = yaml.safe_load(file)
-        assert yaml_params['b1'] == 0.5
+        assert yaml_params["b1"] == 0.5
         assert len(yaml_params.keys()) == 10
 
     # verify artifacts
-    assert len(os.listdir(os.path.join(folder_path, 'checkpoints'))) == 1
+    assert len(os.listdir(os.path.join(folder_path, "checkpoints"))) == 1
     #
     # # verify tb logs
     # event_acc = EventAccumulator(folder_path)

@@ -88,13 +88,13 @@ def test_tensorboard_named_version(tmpdir):
     assert os.listdir(tmpdir / name / expected_version)
 
 
-@pytest.mark.parametrize("name", ['', None])
+@pytest.mark.parametrize("name", ["", None])
 def test_tensorboard_no_name(tmpdir, name):
     """Verify that None or empty name works"""
     logger = TensorBoardLogger(save_dir=tmpdir, name=name)
     logger.log_hyperparams({"a": 1, "b": 2})  # Force data to be written
     assert logger.root_dir == tmpdir
-    assert os.listdir(tmpdir / 'version_0')
+    assert os.listdir(tmpdir / "version_0")
 
 
 @pytest.mark.parametrize("step_idx", [10, None])

@@ -104,7 +104,7 @@ def test_tensorboard_log_metrics(tmpdir, step_idx):
         "float": 0.3,
         "int": 1,
         "FloatTensor": torch.tensor(0.1),
-        "IntTensor": torch.tensor(1)
+        "IntTensor": torch.tensor(1),
     }
     logger.log_metrics(metrics, step_idx)

@@ -116,10 +116,10 @@ def test_tensorboard_log_hyperparams(tmpdir):
         "int": 1,
         "string": "abc",
         "bool": True,
-        "dict": {'a': {'b': 'c'}},
+        "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        "namespace": Namespace(foo=Namespace(bar='buzz')),
-        "layer": torch.nn.BatchNorm1d
+        "namespace": Namespace(foo=Namespace(bar="buzz")),
+        "layer": torch.nn.BatchNorm1d,
     }
     logger.log_hyperparams(hparams)

@@ -131,10 +131,28 @@ def test_tensorboard_log_hparams_and_metrics(tmpdir):
         "int": 1,
         "string": "abc",
         "bool": True,
-        "dict": {'a': {'b': 'c'}},
+        "dict": {"a": {"b": "c"}},
         "list": [1, 2, 3],
-        "namespace": Namespace(foo=Namespace(bar='buzz')),
-        "layer": torch.nn.BatchNorm1d
+        "namespace": Namespace(foo=Namespace(bar="buzz")),
+        "layer": torch.nn.BatchNorm1d,
     }
-    metrics = {'abc': torch.tensor([0.54])}
+    metrics = {"abc": torch.tensor([0.54])}
+    logger.log_hyperparams(hparams, metrics)
+
+
+def test_tensorboard_log_omegaconf_hparams_and_metrics(tmpdir):
+    logger = TensorBoardLogger(tmpdir)
+    hparams = {
+        "float": 0.3,
+        "int": 1,
+        "string": "abc",
+        "bool": True,
+        "dict": {"a": {"b": "c"}},
+        "list": [1, 2, 3],
+        # "namespace": Namespace(foo=Namespace(bar="buzz")),
+        # "layer": torch.nn.BatchNorm1d,
+    }
+    hparams = OmegaConf.create(hparams)
+
+    metrics = {"abc": torch.tensor([0.54])}
     logger.log_hyperparams(hparams, metrics)
