Commit bcb45d9

williamFalcon and Borda authored
proper checkpoint implementation (Lightning-AI#1043)
* enabled early stopping/checkpoint even without val step * name formatting * version * testing * add test * fix test * Update model_checkpoint.py * doctests * pylint * tests * debug * fix MNIST download (Lightning-AI#1044) * simple * rebased 1041 Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent 165b9fb commit bcb45d9

File tree

14 files changed: +208, -194 lines


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Support for user defined callbacks ([#889](https://github.com/PyTorchLightning/pytorch-lightning/pull/889) and [#950](https://github.com/PyTorchLightning/pytorch-lightning/pull/950))
 - Added support for multiple loggers to be passed to `Trainer` as an iterable (e.g. list, tuple, etc.) ([#903](https://github.com/PyTorchLightning/pytorch-lightning/pull/903))
 - Added support for logging hparams as dict ([#1029](https://github.com/PyTorchLightning/pytorch-lightning/pull/1029))
+- Checkpoint and early stopping now work without val step ([#1041](https://github.com/PyTorchLightning/pytorch-lightning/pull/1041))
 
 ### Changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 97 additions & 75 deletions

@@ -1,40 +1,36 @@
-r"""
-Model Checkpoint
-==============
-Save the model as often as requested.
-
-"""
-
 import os
-import glob
+import shutil
 import logging as log
 import warnings
+import re
 
 
 import numpy as np
 
-from .base import Callback
+from pytorch_lightning.callbacks.base import Callback
 
 
 class ModelCheckpoint(Callback):
     r"""
     Save the model after every epoch.
 
     Args:
-        dirpath: path to save the model file.
+        filepath: path to save the model file.
             Can contain named formatting options to be auto-filled.
 
         Example::
 
-            # save epoch and val_loss in name
-            ModelCheckpoint(filepath='{epoch:02d}-{val_loss:.2f}.hdf5')
+            # no path
+            ModelCheckpoint()
+            # saves like /my/path/epoch_0.ckpt
 
-            # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
-            # if model already exits, the file will be: /my/path/here/sample-mnist-v0_epoch=02_val_loss=0.32.ckpt
+            # save any arbitrary metrics like and val_loss, etc in name
+            ModelCheckpoint(filepath='/my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}')
+            # saves file like: /my/path/epoch=2-val_loss=0.2_other_metric=0.3.ckpt
 
 
-        monitor: quantity to monitor.
-        verbose: verbosity mode, False or True.
-        save_top_k: if `save_top_k == k`,
+        monitor (str): quantity to monitor.
+        verbose (bool): verbosity mode, False or True.
+        save_top_k (int): if `save_top_k == k`,
             the best k models according to
             the quantity monitored will be saved.
             if ``save_top_k == 0``, no models are saved.
@@ -43,54 +39,51 @@ class ModelCheckpoint(Callback):
            if ``save_top_k >= 2`` and the callback is called multiple
            times inside an epoch, the name of the saved file will be
            appended with a version count starting with `v0`.
-        mode: one of {auto, min, max}.
+        mode (str): one of {auto, min, max}.
            If ``save_top_k != 0``, the decision
            to overwrite the current save file is made
            based on either the maximization or the
            minimization of the monitored quantity. For `val_acc`,
            this should be `max`, for `val_loss` this should
            be `min`, etc. In `auto` mode, the direction is
            automatically inferred from the name of the monitored quantity.
-        save_weights_only: if True, then only the model's weights will be
+        save_weights_only (bool): if True, then only the model's weights will be
            saved (`model.save_weights(filepath)`), else the full model
            is saved (`model.save(filepath)`).
-        period: Interval (number of epochs) between checkpoints.
-        prefix: String name for particular model
+        period (int): Interval (number of epochs) between checkpoints.
 
-    Example:
+    Example::
 
         from pytorch_lightning import Trainer
         from pytorch_lightning.callbacks import ModelCheckpoint
 
         # saves checkpoints to my_path whenever 'val_loss' has a new min
-        checkpoint_callback = ModelCheckpoint('my_path')
+        checkpoint_callback = ModelCheckpoint(filepath='my_path')
         Trainer(checkpoint_callback=checkpoint_callback)
+
+        # save epoch and val_loss in name
+        ModelCheckpoint(filepath='/my/path/here/sample-mnist_{epoch:02d}-{val_loss:.2f}')
+        # saves file like: /my/path/here/sample-mnist_epoch=02_val_loss=0.32.ckpt
     """
-    #: checkpoint extension
-    EXTENSION = '.ckpt'
-
-    def __init__(
-        self,
-        dirpath: str,
-        monitor: str = 'val_loss',
-        verbose: bool = False,
-        save_top_k: int = 1,
-        save_weights_only: bool = False,
-        mode: str = 'auto',
-        period: int = 1,
-        prefix: str = ''
-    ):
+
+    def __init__(self, filepath, monitor: str = 'val_loss', verbose: bool = False,
+                 save_top_k: int = 1, save_weights_only: bool = False,
+                 mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k and os.path.isdir(dirpath) and len(os.listdir(dirpath)) > 0:
+        if save_top_k and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
             warnings.warn(
-                f"Checkpoint directory {dirpath} exists and is not empty with save_top_k != 0."
+                f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                "All files in this directory will be deleted when a checkpoint is saved!"
            )
 
        self.monitor = monitor
        self.verbose = verbose
-        self.dirpath = dirpath
-        os.makedirs(dirpath, exist_ok=True)
+        if os.path.isdir(filepath):
+            self.dirpath, self.filename = filepath, '{epoch}'
+        else:
+            self.dirpath, self.filename = os.path.split(filepath)
+
+        os.makedirs(self.dirpath, exist_ok=True)
        self.save_top_k = save_top_k
        self.save_weights_only = save_weights_only
        self.period = period
@@ -102,14 +95,6 @@ def __init__(
         self.best = 0
         self.save_function = None
 
-        # this create unique prefix if the give already exists
-        existing_checkpoints = sorted(glob.glob(os.path.join(self.dirpath, '*' + self.EXTENSION)))
-        existing_names = set(os.path.basename(ckpt).split('_epoch=')[0] for ckpt in existing_checkpoints)
-        version_cnt = 0
-        while self.prefix in existing_names:
-            self.prefix = f'{prefix}-v{version_cnt}'
-            version_cnt += 1
-
         mode_dict = {
             'min': (np.less, np.Inf, 'min'),
             'max': (np.greater, -np.Inf, 'max'),
@@ -125,39 +110,65 @@ def __init__(
 
         self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
 
-    def _del_model(self, filepath: str) -> None:
-        # shutil.rmtree(filepath)
+    def _del_model(self, filepath):
         os.remove(filepath)
 
-    def _save_model(self, filepath: str) -> None:
+    def _save_model(self, filepath):
         # make paths
-        os.makedirs(self.dirpath, exist_ok=True)
+        os.makedirs(os.path.dirname(filepath), exist_ok=True)
 
         # delegate the saving to the model
         if self.save_function is not None:
             self.save_function(filepath)
         else:
-            raise ValueError("Method `.save_function()` not set")
+            raise ValueError(".save_function() not set")
 
-    def check_monitor_top_k(self, current: float) -> bool:
+    def check_monitor_top_k(self, current):
         less_than_k_models = len(self.best_k_models) < self.save_top_k
         if less_than_k_models:
             return True
         return self.monitor_op(current, self.best_k_models[self.kth_best_model])
 
-    def _get_available_filepath(self, current: float, epoch: int) -> str:
-        current_str = f'{current:.2f}' if current else 'NaN'
-        fname = f'{self.prefix}_epoch={epoch}_{self.monitor}={current_str}'
-        filepath = os.path.join(self.dirpath, fname + self.EXTENSION)
-        assert not os.path.isfile(filepath)
+    def format_checkpoint_name(self, epoch, metrics, ver=None):
+        """Generate a filename according define template.
+
+        Examples
+        --------
+        >>> tmpdir = os.path.dirname(__file__)
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+        'epoch=0.ckpt'
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch:03d}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(5, {}))
+        'epoch=005.ckpt'
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{epoch}-{val_loss:.2f}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(2, dict(val_loss=0.123456)))
+        'epoch=2-val_loss=0.12.ckpt'
+        >>> ckpt = ModelCheckpoint(os.path.join(tmpdir, '{missing:d}'))
+        >>> os.path.basename(ckpt.format_checkpoint_name(0, {}))
+        'missing=0.ckpt'
+        """
+        # check if user passed in keys to the string
+        groups = re.findall(r'(\{.*?)[:\}]', self.filename)
+
+        if len(groups) == 0:
+            # default name
+            filename = f'{self.prefix}_ckpt_epoch_{epoch}'
+        else:
+            metrics['epoch'] = epoch
+            filename = self.filename
+            for tmp in groups:
+                name = tmp[1:]
+                filename = filename.replace(tmp, name + '={' + name)
+                if name not in metrics:
+                    metrics[name] = 0
+            filename = filename.format(**metrics)
+        str_ver = f'_v{ver}' if ver is not None else ''
+        filepath = os.path.join(self.dirpath, self.prefix + filename + str_ver + '.ckpt')
         return filepath
 
-    def on_validation_end(self, trainer, pl_module) -> None:
-        # only run on main process
-        if trainer.proc_rank != 0:
-            return
-
-        logs = trainer.callback_metrics
+    def on_validation_end(self, trainer, pl_module):
+        metrics = trainer.callback_metrics
         epoch = trainer.current_epoch
         self.epochs_since_last_check += 1
 
@@ -166,27 +177,36 @@ def on_validation_end(self, trainer, pl_module) -> None:
             return
         if self.epochs_since_last_check >= self.period:
             self.epochs_since_last_check = 0
-            current = logs.get(self.monitor)
-            filepath = self._get_available_filepath(current, epoch)
+
+            filepath = self.format_checkpoint_name(epoch, metrics)
+            version_cnt = 0
+            while os.path.isfile(filepath):
+                filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
+                # this epoch called before
+                version_cnt += 1
 
             if self.save_top_k != -1:
+                current = metrics.get(self.monitor)
 
                 if current is None:
-                    warnings.warn(f'Can save best model only with {self.monitor} available,'
-                                  ' skipping.', RuntimeWarning)
+                    warnings.warn(
+                        f'Can save best model only with {self.monitor} available,'
+                        ' skipping.', RuntimeWarning)
                else:
                    if self.check_monitor_top_k(current):
                        self._do_check_save(filepath, current, epoch)
                    else:
                        if self.verbose > 0:
-                            log.info('Epoch %05d: %s was not in top %i', epoch, self.monitor, self.save_top_k)
+                            log.info(
+                                f'\nEpoch {epoch:05d}: {self.monitor}'
+                                f' was not in top {self.save_top_k}')
 
            else:
                if self.verbose > 0:
-                    log.info('Epoch %05d: saving model to %s', epoch, filepath)
+                    log.info(f'\nEpoch {epoch:05d}: saving model to {filepath}')
                self._save_model(filepath)
 
-    def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
+    def _do_check_save(self, filepath, current, epoch):
        # remove kth
        if len(self.best_k_models) == self.save_top_k:
            delpath = self.kth_best_model
@@ -205,6 +225,8 @@ def _do_check_save(self, filepath: str, current: float, epoch: int) -> None:
            self.best = _op(self.best_k_models.values())
 
        if self.verbose > 0:
-            log.info('Epoch {epoch:05d}: %s reached %0.5f (best %0.5f), saving model to %s as top %i',
-                     epoch, self.monitor, current, self.best, filepath, self.save_top_k)
+            log.info(
+                f'\nEpoch {epoch:05d}: {self.monitor} reached'
+                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+                f' {filepath} as top {self.save_top_k}')
        self._save_model(filepath)
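For readers skimming the diff: the core of the new naming scheme is the template expansion in format_checkpoint_name. Below is a small standalone sketch of that expansion (the function name expand_checkpoint_name is made up for illustration; the logic mirrors the added lines above), so the resulting filenames can be checked without running a Trainer:

import re

def expand_checkpoint_name(template, epoch, metrics, prefix='', ver=None):
    # '{epoch}-{val_loss:.2f}' with epoch=2, val_loss=0.123456 -> 'epoch=2-val_loss=0.12.ckpt'
    groups = re.findall(r'(\{.*?)[:\}]', template)
    if len(groups) == 0:
        # no '{...}' keys in the template -> fall back to the default name
        filename = f'{prefix}_ckpt_epoch_{epoch}'
    else:
        metrics = dict(metrics, epoch=epoch)
        filename = template
        for tmp in groups:
            name = tmp[1:]
            filename = filename.replace(tmp, name + '={' + name)
            metrics.setdefault(name, 0)  # missing metrics are filled with 0
        filename = filename.format(**metrics)
    str_ver = f'_v{ver}' if ver is not None else ''
    return prefix + filename + str_ver + '.ckpt'

print(expand_checkpoint_name('{epoch}-{val_loss:.2f}', 2, {'val_loss': 0.123456}))
# epoch=2-val_loss=0.12.ckpt

The ver argument corresponds to the new version counter in on_validation_end: when the callback fires more than once for the same epoch and the file already exists, the name gets a _v0, _v1, ... suffix instead of overwriting.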

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -68,19 +68,7 @@ def __init__(self, *args, **kwargs):
         #: True if using amp
         self.use_amp = False
 
-    @property
-    def hparams(self) -> Namespace:
-        if not hasattr(self, '_hparams'):
-            return Namespace()
-        assert isinstance(self._hparams, dict)
-        return Namespace(**self._hparams)
-
-    @hparams.setter
-    def hparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
-        """Set the model hyper-parameters."""
-        if isinstance(params, Namespace):
-            params = vars(params)
-        self._hparams = params
+        self.hparams = None
 
     def print(self, *args, **kwargs):
         r"""

pytorch_lightning/loggers/base.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -46,6 +46,10 @@ def _convert_params(self, params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
         # in case converting from namespace
         if isinstance(params, Namespace):
             params = vars(params)
+
+        if params is None:
+            params = {}
+
         return params
 
     @abstractmethod
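The None check matters because hparams now defaults to None in LightningModule (see the core/lightning.py change above), and loggers pass that value through _convert_params. A standalone sketch of the same conversion (a copy of the logic for illustration, not the library call itself):

from argparse import Namespace

def convert_params(params):
    # Namespace -> dict, None -> empty dict, a dict passes through unchanged
    if isinstance(params, Namespace):
        params = vars(params)
    if params is None:
        params = {}
    return params

assert convert_params(None) == {}
assert convert_params(Namespace(lr=0.01)) == {'lr': 0.01}
assert convert_params({'batch_size': 32}) == {'batch_size': 32}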

pytorch_lightning/trainer/callback_config.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,9 +48,15 @@ def configure_checkpoint_callback(self):
             else:
                 ckpt_path = os.path.join(self.default_save_path, "checkpoints")
 
+            # when no val step is defined, use 'loss' otherwise 'val_loss'
+            train_step_only = not self.is_overriden('validation_step')
+            monitor_key = 'loss' if train_step_only else 'val_loss'
+
             self.ckpt_path = ckpt_path
+            os.makedirs(ckpt_path, exist_ok=True)
             self.checkpoint_callback = ModelCheckpoint(
-                dirpath=ckpt_path
+                filepath=ckpt_path,
+                monitor=monitor_key
             )
         elif self.checkpoint_callback is False:
             self.checkpoint_callback = None
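This monitor-key switch is what makes the default checkpoint callback usable for models that define no validation_step: they are checkpointed on the training 'loss' instead of 'val_loss'. A tiny sketch of the decision (helper name is illustrative, not from the codebase):

def default_monitor_key(has_validation_step):
    # mirrors: monitor_key = 'loss' if train_step_only else 'val_loss'
    return 'val_loss' if has_validation_step else 'loss'

assert default_monitor_key(True) == 'val_loss'   # model defines a validation loop
assert default_monitor_key(False) == 'loss'      # train-only model is still checkpointed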

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 0 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -165,7 +165,6 @@ class TrainerEvaluationLoopMixin(ABC):
     process_output: ...
     training_tqdm_dict: ...
     proc_rank: int
-    checkpoint_callback: ...
     current_epoch: int
     callback_metrics: ...
     test_dataloaders: DataLoader
@@ -377,11 +376,6 @@ def run_evaluation(self, test_mode: bool = False):
         # Validation/Test end callbacks
         if test_mode:
             self.on_test_end()
-        else:
-            # model checkpointing
-            if self.checkpoint_callback is not None:
-                self.checkpoint_callback.on_validation_end(self, self.get_model())
-            self.on_validation_end()
 
     def evaluation_forward(self, model, batch, batch_idx, dataloader_idx, test_mode: bool = False):
         # make dataloader_idx arg in validation_step optional

pytorch_lightning/trainer/trainer.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1132,9 +1132,6 @@ def run_pretrain_routine(self, model: LightningModule):
             # wait for all processes to catch up
             torch_xla.core.xla_model.rendezvous("pl.Trainer.run_pretrain_routine")
 
-        # set up checkpoint callback
-        self.configure_checkpoint_callback()
-
         # register auto-resubmit when on SLURM
         self.register_slurm_signal_handlers()
 
@@ -1151,6 +1148,9 @@ def run_pretrain_routine(self, model: LightningModule):
         # if cluster resets state, the model will update with the saved weights
         self.model = model
 
+        # set up checkpoint callback
+        self.configure_checkpoint_callback()
+
         # restore training and model before hpc call
         self.restore_weights(model)
 
