
Commit c0903b8

Borda and omry authored
past checkpoints (Lightning-AI#2160)
* past checkpoints
* omegaConf save
* enforce type
* resolve=True
* test omegaconf
* tests
* test past

Co-authored-by: Omry Yadan <omry@fb.com>
1 parent c826a5f commit c0903b8

File tree: 5 files changed, +176 / -134 lines


pytorch_lightning/core/saving.py

Lines changed: 24 additions & 9 deletions
@@ -21,10 +21,17 @@
 else:
     ALLOWED_CONFIG_TYPES = ALLOWED_CONFIG_TYPES + (Container, )
 
+# the older shall be on the top
+CHECKPOINT_PAST_HPARAMS_KEYS = (
+    'hparams',
+    'module_arguments',  # used in 0.7.6
+)
+
 
 class ModelIO(object):
-    CHECKPOINT_KEY_HYPER_PARAMS = 'hyper_parameters'
-    CHECKPOINT_NAME_HYPER_PARAMS = 'hparams_name'
+    CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'
+    CHECKPOINT_HYPER_PARAMS_NAME = 'hparams_name'
+    CHECKPOINT_HYPER_PARAMS_TYPE = 'hparams_type'
 
     @classmethod
     def load_from_metrics(cls, weights_path, tags_csv, map_location=None):
@@ -153,21 +160,29 @@ def load_from_checkpoint(
                 hparams['on_gpu'] = False
 
             # overwrite hparams by the given file
-            checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS] = hparams
+            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = hparams
 
-        # override the module_arguments with values that were passed in
-        checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS].update(kwargs)
+        # for past checkpoint need to add the new key
+        if cls.CHECKPOINT_HYPER_PARAMS_KEY not in checkpoint:
+            checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY] = {}
+        # override the hparams with values that were passed in
+        checkpoint[cls.CHECKPOINT_HYPER_PARAMS_KEY].update(kwargs)
 
         model = cls._load_model_state(checkpoint, *args, **kwargs)
         return model
 
     @classmethod
     def _load_model_state(cls, checkpoint: Dict[str, Any], *args, **kwargs):
         # pass in the values we saved automatically
-        if cls.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
-            # todo add some back compatibility
-            model_args = checkpoint[cls.CHECKPOINT_KEY_HYPER_PARAMS]
-            args_name = checkpoint.get(cls.CHECKPOINT_NAME_HYPER_PARAMS)
+        if cls.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+            model_args = {}
+            # add some back compatibility, the actual one shall be last
+            for hparam_key in CHECKPOINT_PAST_HPARAMS_KEYS + (cls.CHECKPOINT_HYPER_PARAMS_KEY,):
+                if hparam_key in checkpoint:
+                    model_args.update(checkpoint[hparam_key])
+            if cls.CHECKPOINT_HYPER_PARAMS_TYPE in checkpoint:
+                model_args = checkpoint[cls.CHECKPOINT_HYPER_PARAMS_TYPE](model_args)
+            args_name = checkpoint.get(cls.CHECKPOINT_HYPER_PARAMS_NAME)
             init_args_name = inspect.signature(cls).parameters.keys()
             if args_name == 'kwargs':
                 cls_kwargs = {k: v for k, v in model_args.items() if k in init_args_name}
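
For reference, a minimal standalone sketch of the merge order this hunk introduces. The checkpoint dict below is hypothetical (a real checkpoint also carries 'state_dict', 'epoch', etc.); the point is that legacy keys are applied first and the current 'hyper_parameters' key last, so newer values win.

# Standalone sketch of the backward-compatible hparams merge above.
CHECKPOINT_PAST_HPARAMS_KEYS = ('hparams', 'module_arguments')  # oldest first
CHECKPOINT_HYPER_PARAMS_KEY = 'hyper_parameters'

def collect_hparams(checkpoint: dict) -> dict:
    """Merge hparams from legacy keys; the newest key is applied last and wins."""
    model_args = {}
    for key in CHECKPOINT_PAST_HPARAMS_KEYS + (CHECKPOINT_HYPER_PARAMS_KEY,):
        if key in checkpoint:
            model_args.update(checkpoint[key])
    return model_args

# hypothetical 0.7.6-style checkpoint that also carries the new key
ckpt = {'module_arguments': {'lr': 0.01, 'hidden_dim': 64},
        'hyper_parameters': {'lr': 0.001}}
print(collect_hparams(ckpt))  # {'lr': 0.001, 'hidden_dim': 64}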

pytorch_lightning/loggers/tensorboard.py

Lines changed: 13 additions & 2 deletions
@@ -4,7 +4,6 @@
 """
 
 import os
-import yaml
 from argparse import Namespace
 from typing import Optional, Dict, Union, Any
 from warnings import warn
@@ -18,6 +17,11 @@
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from pytorch_lightning.utilities import rank_zero_only
 
+try:
+    from omegaconf import Container
+except ImportError:
+    Container = None
+
 
 class TensorBoardLogger(LightningLoggerBase):
     r"""
@@ -152,7 +156,14 @@ def save(self) -> None:
         hparams_file = os.path.join(dir_path, self.NAME_HPARAMS_FILE)
 
         # save the metatags file
-        save_hparams_to_yaml(hparams_file, self.hparams)
+        if Container is not None:
+            if isinstance(self.hparams, Container):
+                from omegaconf import OmegaConf
+                OmegaConf.save(self.hparams, hparams_file, resolve=True)
+            else:
+                save_hparams_to_yaml(hparams_file, self.hparams)
+        else:
+            save_hparams_to_yaml(hparams_file, self.hparams)
 
     @rank_zero_only
     def finalize(self, status: str) -> None:
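
A small illustration of why the logger switches to OmegaConf.save with resolve=True when hparams is an OmegaConf Container: interpolations are written out as concrete values rather than ${...} references. This assumes omegaconf is installed; the file name is arbitrary.

# Sketch: OmegaConf.save(..., resolve=True) writes resolved values to YAML.
from omegaconf import OmegaConf

hparams = OmegaConf.create({'lr': 0.001, 'scheduler': {'initial_lr': '${lr}'}})
OmegaConf.save(hparams, 'hparams.yaml', resolve=True)
# hparams.yaml now contains `initial_lr: 0.001`; with resolve=False it would keep `${lr}`.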

pytorch_lightning/trainer/training_io.py

Lines changed: 17 additions & 9 deletions
@@ -87,20 +87,20 @@
 import re
 import signal
 from abc import ABC
-from argparse import Namespace
 from subprocess import call
 
 import torch
 import torch.distributed as torch_distrib
 
+import pytorch_lightning
 from pytorch_lightning import _logger as log
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.overrides.data_parallel import (
     LightningDistributedDataParallel,
     LightningDataParallel,
 )
-from pytorch_lightning.utilities import rank_zero_warn, parsing
+from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.io import load as pl_load
 
 try:
@@ -119,6 +119,11 @@
 else:
     HOROVOD_AVAILABLE = True
 
+try:
+    from omegaconf import Container
+except ImportError:
+    Container = None
+
 
 class TrainerIOMixin(ABC):
 
@@ -267,8 +272,8 @@ def save_checkpoint(self, filepath, weights_only: bool = False):
         try:
             self._atomic_save(checkpoint, filepath)
         except AttributeError as err:
-            if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
-                del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
+            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
             rank_zero_warn('Warning, `module_arguments` dropped from checkpoint.'
                            f' An attribute is not picklable {err}')
             self._atomic_save(checkpoint, filepath)
@@ -320,6 +325,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
         checkpoint = {
             'epoch': self.current_epoch + 1,
             'global_step': self.global_step + 1,
+            'pytorch-ligthning_version': pytorch_lightning.__version__,
         }
 
         if not weights_only:
@@ -356,10 +362,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         if model.hparams:
             if hasattr(model, '_hparams_name'):
-                checkpoint[LightningModule.CHECKPOINT_NAME_HYPER_PARAMS] = model._hparams_name
+                checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
             # add arguments to the checkpoint
-            # todo: add some recursion in case of OmegaConf
-            checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS] = dict(model.hparams)
+            checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY] = model.hparams
+            if Container is not None:
+                if isinstance(model.hparams, Container):
+                    checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_TYPE] = type(model.hparams)
 
         # give the model a chance to add a few things
         model.on_save_checkpoint(checkpoint)
@@ -473,8 +481,8 @@ def hpc_save(self, folderpath: str, logger):
         try:
             self._atomic_save(checkpoint, filepath)
         except AttributeError as err:
-            if LightningModule.CHECKPOINT_KEY_HYPER_PARAMS in checkpoint:
-                del checkpoint[LightningModule.CHECKPOINT_KEY_HYPER_PARAMS]
+            if LightningModule.CHECKPOINT_HYPER_PARAMS_KEY in checkpoint:
+                del checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_KEY]
             rank_zero_warn('warning, `module_arguments` dropped from checkpoint.'
                            f' An attribute is not picklable {err}')
             self._atomic_save(checkpoint, filepath)
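
For context, a sketch of the extra metadata a checkpoint written by this version carries. The checkpoint path is hypothetical; the keys match the diff above, including the version key spelled 'pytorch-ligthning_version' exactly as in dump_checkpoint.

# Sketch: inspect the new checkpoint metadata added by this commit.
import torch

ckpt = torch.load('example.ckpt', map_location='cpu')  # hypothetical path
print(ckpt.get('pytorch-ligthning_version'))  # Lightning version that wrote the file
print(ckpt.get('hparams_name'))               # attribute name the hparams were assigned to
print(ckpt.get('hparams_type'))               # e.g. an omegaconf container class, if one was used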
