
Commit e586ed4

Borda and williamFalcon authored
hparams as dict [blocked by 1041] (Lightning-AI#1029)
* hparams as dict
* hparams as dict
* fixing
* fixing
* fixing
* fixing
* typing
* typing
* chnagelog
* update set hparams
* use setter
* simplify
* chnagelog
* imports
* pylint
* typing
* Update training_io.py
* Update training_io.py
* Update lightning.py
* Update test_trainer.py
* Update __init__.py
* Update base.py
* Update utils.py
* Update test_trainer.py
* Update training_io.py
* Update test_trainer.py
* Update test_trainer.py
* Update test_trainer.py
* Update test_trainer.py
* Update callback_config.py
* Update callback_config.py
* Update test_trainer.py

Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent 6a39573 commit e586ed4

File tree: 18 files changed (+168 −87 lines)


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Split callbacks in multiple files ([#849](https://github.com/PyTorchLightning/pytorch-lightning/pull/849))
 - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
+- Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))

 ### Changed

@@ -32,6 +33,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Changed `pytorch_lightning.logging` to `pytorch_lightning.loggers` ([#767](https://github.com/PyTorchLightning/pytorch-lightning/pull/767))
 - Moved the default `tqdm_dict` definition from Trainer to `LightningModule`, so it can be overridden by the user ([#749](https://github.com/PyTorchLightning/pytorch-lightning/pull/749))
 - Moved functionality of `LightningModule.load_from_metrics` into `LightningModule.load_from_checkpoint` ([#995](https://github.com/PyTorchLightning/pytorch-lightning/pull/995))
+- Changed Checkpoint path parameter from `filepath` to `dirpath` ([#1016](https://github.com/PyTorchLightning/pytorch-lightning/pull/1016))
+- Froze the model's `hparams` as a `Namespace` property ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))

 ### Deprecated

docs/source/weights_loading.rst

Lines changed: 1 addition & 1 deletion
@@ -62,7 +62,7 @@ The Lightning checkpoint also saves the hparams (hyperparams) passed into the Li
 from argparse import Namespace

 # usually these come from command line args
-args = Namespace(**{'learning_rate':0.001})
+args = Namespace(learning_rate=0.001)

 # define your module to have hparams as the first arg
 # this means your checkpoint will have everything that went into making
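The doc change above simply drops the redundant dict unpacking. A quick standalone check (not part of the diff) that the two spellings are equivalent, and how the Namespace relates to the plain dict this PR now also accepts:

from argparse import Namespace

old_style = Namespace(**{'learning_rate': 0.001})
new_style = Namespace(learning_rate=0.001)
assert old_style == new_style

# vars() recovers the underlying dict, which is the form the new hparams setter stores internally
assert vars(new_style) == {'learning_rate': 0.001}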

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 1 deletion
@@ -27,9 +27,11 @@ class ModelCheckpoint(Callback):

         # save epoch and val_loss in name
         ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
+
         # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
         # if model already exists, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt

+
         monitor: quantity to monitor.
         verbose: verbosity mode, False or True.
         save_top_k: if `save_top_k == k`,
@@ -135,7 +137,7 @@ def _save_model(self, filepath: str) -> None:
         if self.save_function is not None:
             self.save_function(filepath)
         else:
-            raise ValueError(".save_function() not set")
+            raise ValueError("Method `.save_function()` not set")

     def check_monitor_top_k(self, current: float) -> bool:
         less_than_k_models = len(self.best_k_models) < self.save_top_k
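For readers puzzling over the docstring example above: the braces in the filepath are standard Python format fields filled from logged metrics, while the exact filename assembly (the `epoch=02_val_loss=0.32` form and the `-v0` suffix on collisions) is internal to ModelCheckpoint. A rough illustration of the format-field part only, with made-up metric values:

# illustrative only; the real filename formatting lives inside ModelCheckpoint
metrics = {'epoch': 2, 'val_loss': 0.3199}
template = '{epoch:02d}-{val_loss:.2f}'
print(template.format(**metrics))  # -> 02-0.32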

pytorch_lightning/core/lightning.py

Lines changed: 17 additions & 2 deletions
@@ -5,7 +5,7 @@
 import warnings
 from abc import ABC, abstractmethod
 from argparse import Namespace
-from typing import Optional, Union, Dict, Callable
+from typing import Any, Callable, Dict, Optional, Union

 import torch
 import torch.distributed as dist
@@ -68,6 +68,20 @@ def __init__(self, *args, **kwargs):
         #: True if using amp
         self.use_amp = False

+    @property
+    def hparams(self) -> Namespace:
+        if not hasattr(self, '_hparams'):
+            return Namespace()
+        assert isinstance(self._hparams, dict)
+        return Namespace(**self._hparams)
+
+    @hparams.setter
+    def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+        """Set the model hyper-parameters."""
+        if isinstance(params, Namespace):
+            params = vars(params)
+        self._hparams = params
+
     def print(self, *args, **kwargs):
         r"""
         Prints only from process 0. Use this in any distributed mode to log only once
@@ -1201,7 +1215,8 @@ def _load_model_state(cls, checkpoint):

         if cls_takes_hparams:
             if ckpt_hparams is not None:
-                hparams = Namespace(**ckpt_hparams)
+                is_namespace = checkpoint.get('hparams_type') == 'namespace'
+                hparams = Namespace(**ckpt_hparams) if is_namespace else ckpt_hparams
             else:
                 warnings.warn(
                     f"Checkpoint does not contain hyperparameters but {cls.__name__}'s __init__ contains"

pytorch_lightning/core/saving.py

Lines changed: 2 additions & 4 deletions
@@ -36,16 +36,14 @@ def on_hpc_load(self, checkpoint):
     """


-def load_hparams_from_tags_csv(tags_csv):
+def load_hparams_from_tags_csv(tags_csv) -> Namespace:
     if not os.path.isfile(tags_csv):
         log.warning(f'Missing Tags: {tags_csv}.')
         return Namespace()

-    tags = {}
     with open(tags_csv) as f:
         csv_reader = csv.reader(f, delimiter=',')
-        for row in list(csv_reader)[1:]:
-            tags[row[0]] = convert(row[1])
+        tags = {row[0]: convert(row[1]) for row in list(csv_reader)[1:]}
     ns = Namespace(**tags)
     return ns
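The rewritten function above is just a dict comprehension over the CSV rows. A self-contained sketch of the same tags.csv → Namespace conversion using only the standard library; the real code calls Lightning's `convert` type helper, approximated here by a hypothetical `_convert`:

import csv
import tempfile
from argparse import Namespace
from pathlib import Path


def _convert(value: str):
    # hypothetical stand-in for Lightning's `convert` helper: try int, then float, else keep the string
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            pass
    return value


def load_hparams_from_tags_csv(tags_csv: str) -> Namespace:
    with open(tags_csv) as f:
        rows = list(csv.reader(f, delimiter=','))
    tags = {row[0]: _convert(row[1]) for row in rows[1:]}  # skip the header row
    return Namespace(**tags)


tags_file = Path(tempfile.mkdtemp()) / 'meta_tags.csv'
tags_file.write_text('key,value\nlearning_rate,0.001\nbatch_size,32\n')
print(load_hparams_from_tags_csv(str(tags_file)))  # Namespace(batch_size=32, learning_rate=0.001)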

pytorch_lightning/loggers/base.py

Lines changed: 17 additions & 10 deletions
@@ -1,5 +1,6 @@
 import argparse
 from abc import ABC, abstractmethod
+from argparse import Namespace
 from functools import wraps
 from typing import Union, Optional, Dict, Iterable, Any, Callable, List

@@ -41,6 +42,12 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
         """
         pass

+    def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
+        # in case converting from namespace
+        if isinstance(params, Namespace):
+            params = vars(params)
+        return params
+
     @abstractmethod
     def log_hyperparams(self, params: argparse.Namespace):
         """Record hyperparameters.
@@ -50,19 +57,19 @@ def log_hyperparams(self, params: argparse.Namespace):
         """
         pass

-    def save(self):
+    def save(self) -> None:
         """Save log data."""
         pass

-    def finalize(self, status: str):
+    def finalize(self, status: str) -> None:
         """Do any processing that is necessary to finalize an experiment.

         Args:
             status: Status that the experiment finished with (e.g. success, failed, aborted)
         """
         pass

-    def close(self):
+    def close(self) -> None:
         """Do any cleanup that is necessary to close an experiment."""
         pass

@@ -72,7 +79,7 @@ def rank(self) -> int:
         return self._rank

     @rank.setter
-    def rank(self, value: int):
+    def rank(self, value: int) -> None:
         """Set the process rank."""
         self._rank = value

@@ -107,23 +114,23 @@ def __getitem__(self, index: int) -> LightningLoggerBase:
     def experiment(self) -> List[Any]:
         return [logger.experiment for logger in self._logger_iterable]

-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
         [logger.log_metrics(metrics, step) for logger in self._logger_iterable]

-    def log_hyperparams(self, params: argparse.Namespace):
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         [logger.log_hyperparams(params) for logger in self._logger_iterable]

-    def save(self):
+    def save(self) -> None:
         [logger.save() for logger in self._logger_iterable]

-    def finalize(self, status: str):
+    def finalize(self, status: str) -> None:
         [logger.finalize(status) for logger in self._logger_iterable]

-    def close(self):
+    def close(self) -> None:
         [logger.close() for logger in self._logger_iterable]

     @LightningLoggerBase.rank.setter
-    def rank(self, value: int):
+    def rank(self, value: int) -> None:
         self._rank = value
         for logger in self._logger_iterable:
             logger.rank = value
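The new `_convert_params` helper is what lets every concrete logger accept either form of hyperparameters: it collapses a Namespace into a plain dict and passes dicts through unchanged. A minimal standalone sketch of that normalization (written as a module-level function here rather than the LightningLoggerBase method):

from argparse import Namespace
from typing import Any, Dict, Union


def convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
    # mirror of the helper above: Namespace -> dict, dict passes through
    if isinstance(params, Namespace):
        params = vars(params)
    return params


assert convert_params({'lr': 0.001}) == {'lr': 0.001}
assert convert_params(Namespace(lr=0.001)) == {'lr': 0.001}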

pytorch_lightning/loggers/comet.py

Lines changed: 8 additions & 7 deletions
@@ -5,9 +5,9 @@
 CometLogger
 -------------
 """
-import argparse
+from argparse import Namespace
 from logging import getLogger
-from typing import Optional, Dict, Union
+from typing import Optional, Dict, Union, Any

 try:
     from comet_ml import Experiment as CometExperiment
@@ -162,15 +162,16 @@ def experiment(self) -> CometBaseExperiment:
         return self._experiment

     @rank_zero_only
-    def log_hyperparams(self, params: argparse.Namespace):
-        self.experiment.log_parameters(vars(params))
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+        params = self._convert_params(params)
+        self.experiment.log_parameters(params)

     @rank_zero_only
     def log_metrics(
         self,
         metrics: Dict[str, Union[torch.Tensor, float]],
         step: Optional[int] = None
-    ):
+    ) -> None:
         # Comet.ml expects metrics to be a dictionary of detached tensors on CPU
         for key, val in metrics.items():
             if is_tensor(val):
@@ -182,7 +183,7 @@ def reset_experiment(self):
         self._experiment = None

     @rank_zero_only
-    def finalize(self, status: str):
+    def finalize(self, status: str) -> None:
         r"""
         When calling self.experiment.end(), that experiment won't log any more data to Comet. That's why, if you need
         to log any more data you need to create an ExistingCometExperiment. For example, to log data when testing your
@@ -199,7 +200,7 @@ def name(self) -> str:
         return self.experiment.project_name

     @name.setter
-    def name(self, value: str):
+    def name(self, value: str) -> None:
         self.experiment.set_name(value)

     @property

pytorch_lightning/loggers/mlflow.py

Lines changed: 7 additions & 6 deletions
@@ -23,10 +23,10 @@ def any_lightning_module_function_or_hook(...):
     self.logger.experiment.whatever_ml_flow_supports(...)

 """
-import argparse
+from argparse import Namespace
 from logging import getLogger
 from time import time
-from typing import Optional, Dict, Any
+from typing import Optional, Dict, Any, Union

 try:
     import mlflow
@@ -88,12 +88,13 @@ def run_id(self):
         return self._run_id

     @rank_zero_only
-    def log_hyperparams(self, params: argparse.Namespace):
-        for k, v in vars(params).items():
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+        params = self._convert_params(params)
+        for k, v in params.items():
             self.experiment.log_param(self.run_id, k, v)

     @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
         timestamp_ms = int(time() * 1000)
         for k, v in metrics.items():
             if isinstance(v, str):
@@ -105,7 +106,7 @@ def save(self):
         pass

     @rank_zero_only
-    def finalize(self, status: str = 'FINISHED'):
+    def finalize(self, status: str = 'FINISHED') -> None:
         if status == 'success':
             status = 'FINISHED'
         self.experiment.set_terminated(self.run_id, status)

pytorch_lightning/loggers/neptune.py

Lines changed: 12 additions & 11 deletions
@@ -6,7 +6,7 @@
 NeptuneLogger
 --------------
 """
-import argparse
+from argparse import Namespace
 from logging import getLogger
 from typing import Optional, List, Dict, Any, Union, Iterable

@@ -164,16 +164,17 @@ def experiment(self) -> Experiment:
         return self._experiment

     @rank_zero_only
-    def log_hyperparams(self, params: argparse.Namespace):
-        for key, val in vars(params).items():
+    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
+        params = self._convert_params(params)
+        for key, val in params.items():
             self.experiment.set_property(f'param__{key}', val)

     @rank_zero_only
     def log_metrics(
         self,
         metrics: Dict[str, Union[torch.Tensor, float]],
         step: Optional[int] = None
-    ):
+    ) -> None:
         """Log metrics (numeric values) in Neptune experiments

         Args:
@@ -184,7 +185,7 @@ def log_metrics(
         self.log_metric(key, val, step=step)

     @rank_zero_only
-    def finalize(self, status: str):
+    def finalize(self, status: str) -> None:
         self.experiment.stop()

     @property
@@ -207,7 +208,7 @@ def log_metric(
         metric_name: str,
         metric_value: Union[torch.Tensor, float, str],
         step: Optional[int] = None
-    ):
+    ) -> None:
         """Log metrics (numeric values) in Neptune experiments

         Args:
@@ -224,7 +225,7 @@ def log_metric(
         self.experiment.log_metric(metric_name, x=step, y=metric_value)

     @rank_zero_only
-    def log_text(self, log_name: str, text: str, step: Optional[int] = None):
+    def log_text(self, log_name: str, text: str, step: Optional[int] = None) -> None:
         """Log text data in Neptune experiment

         Args:
@@ -235,7 +236,7 @@ def log_text(self, log_name: str, text: str, step: Optional[int] = None):
         self.log_metric(log_name, text, step=step)

     @rank_zero_only
-    def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None):
+    def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] = None) -> None:
         """Log image data in Neptune experiment

         Args:
@@ -250,7 +251,7 @@ def log_image(self, log_name: str, image: Union[str, Any], step: Optional[int] =
         self.experiment.log_image(log_name, x=step, y=image)

     @rank_zero_only
-    def log_artifact(self, artifact: str, destination: Optional[str] = None):
+    def log_artifact(self, artifact: str, destination: Optional[str] = None) -> None:
         """Save an artifact (file) in Neptune experiment storage.

         Args:
@@ -261,7 +262,7 @@ def log_artifact(self, artifact: str, destination: Optional[str] = None):
         self.experiment.log_artifact(artifact, destination)

     @rank_zero_only
-    def set_property(self, key: str, value: Any):
+    def set_property(self, key: str, value: Any) -> None:
         """Set key-value pair as Neptune experiment property.

         Args:
@@ -271,7 +272,7 @@ def set_property(self, key: str, value: Any):
         self.experiment.set_property(key, value)

     @rank_zero_only
-    def append_tags(self, tags: Union[str, Iterable[str]]):
+    def append_tags(self, tags: Union[str, Iterable[str]]) -> None:
         """appends tags to neptune experiment

         Args:
