Commit 8b9b923

Fabio Natanael Kepler and Borda authored
Keep track of the best model's path saved by ModelCheckpoint (Lightning-AI#1799)
* Add an additional attribute to `ModelCheckpoint` to keep track of the best model's path. Currently, only the best metric value is directly tracked. This new attribute will help in use cases where the trained model needs to be used or tracked right after training.
* Add small description and usage example to docs
* Fix PEP8 issues
* Fix doctest example
* Fix expected output in doctest
* Apply suggestions from code review
* Show example as code block instead of doctest
* Apply suggestions from code review
* Update CHANGELOG.md
* Rename `ModelCheckpoint.best` to `ModelCheckpoint.best_model_score`. Also rename `ModelCheckpoint.best_model` (added in this PR) to `ModelCheckpoint.best_model_path`, for consistency, and `kth_best_model` to `kth_best_model_path`.
* Update pytorch_lightning/trainer/training_io.py
* Apply suggestions from code review
* Add warning when loading checkpoint from an old version

Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
1 parent 55fdfe3 commit 8b9b923
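A minimal sketch of the workflow this commit enables: retrieving and reloading the best checkpoint after training. `MyLightningModule` is a hypothetical `LightningModule` subclass standing in for any user model; it is not part of this commit.

```python
# Hedged sketch of the intended workflow; `MyLightningModule` is hypothetical.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

checkpoint_callback = ModelCheckpoint(filepath='my/path/')
trainer = Trainer(checkpoint_callback=checkpoint_callback)
trainer.fit(MyLightningModule())

# After fitting, the callback exposes the path and score of the best model.
print(checkpoint_callback.best_model_path)
print(checkpoint_callback.best_model_score)

# The best weights can then be reloaded directly from that path.
model = MyLightningModule.load_from_checkpoint(checkpoint_callback.best_model_path)
```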

3 files changed: +56 -16 lines

CHANGELOG.md

Lines changed: 6 additions & 0 deletions
```diff
@@ -18,6 +18,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Early stopping checks `on_validation_end` ([#1458](https://github.com/PyTorchLightning/pytorch-lightning/pull/1458))
 
+- Attribute `best_model_path` to `ModelCheckpoint` for storing and later retrieving the path to the best saved model file ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))
@@ -26,10 +28,14 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow passing model hyperparameters as complete kwarg list ([#1896](https://github.com/PyTorchLightning/pytorch-lightning/pull/1896))
 
+- Renamed `ModelCheckpoint`'s attributes `best` to `best_model_score` and `kth_best_model` to `kth_best_model_path` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+
 - Re-Enable Logger's `ImportError`s ([#1938](https://github.com/PyTorchLightning/pytorch-lightning/pull/1938))
 
 ### Deprecated
 
+- Deprecated `ModelCheckpoint`'s attributes `best` and `kth_best_model` ([#1799](https://github.com/PyTorchLightning/pytorch-lightning/pull/1799))
+
 - Dropped official support/testing for older PyTorch versions <1.3 ([#1917](https://github.com/PyTorchLightning/pytorch-lightning/pull/1917))
 
 ### Removed
```
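For callers migrating across the rename, a short sketch under the assumption of the post-rename API: old attribute reads keep working through the deprecation properties added below, but emit a `DeprecationWarning`.

```python
# Sketch of migrating to the renamed attributes (names per this commit).
from pytorch_lightning.callbacks import ModelCheckpoint

ckpt = ModelCheckpoint(filepath='my/path/')

score = ckpt.best_model_score   # new name (was `ckpt.best`)
kth = ckpt.kth_best_model_path  # new name (was `ckpt.kth_best_model`)

# The old names still resolve via deprecation properties, but warn:
score_old = ckpt.best           # emits DeprecationWarning
```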

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 35 additions & 11 deletions
```diff
@@ -20,7 +20,10 @@
 class ModelCheckpoint(Callback):
     r"""
-    Save the model after every epoch.
+    Save the model after every epoch if it improves.
+
+    After training finishes, use :attr:`best_model_path` to retrieve the path to the
+    best checkpoint file and :attr:`best_model_score` to retrieve its score.
 
     Args:
         filepath: path to save the model file.
@@ -81,6 +84,13 @@ class ModelCheckpoint(Callback):
         ...     filepath='my/path/sample-mnist_{epoch:02d}-{val_loss:.2f}'
         ... )
 
+        # retrieve the best checkpoint after training
+        checkpoint_callback = ModelCheckpoint(filepath='my/path/')
+        trainer = Trainer(checkpoint_callback=checkpoint_callback)
+        model = ...
+        trainer.fit(model)
+        checkpoint_callback.best_model_path
+
     """
 
     def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
@@ -112,8 +122,9 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
         self.prefix = prefix
         self.best_k_models = {}
         # {filename: monitor}
-        self.kth_best_model = ''
-        self.best = 0
+        self.kth_best_model_path = ''
+        self.best_model_score = 0
+        self.best_model_path = ''
         self.save_function = None
 
         torch_inf = torch.tensor(np.Inf)
@@ -131,6 +142,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
 
         self.kth_value, self.mode = mode_dict[mode]
 
+    @property
+    def best(self):
+        rank_zero_warn("Attribute `best` has been renamed to `best_model_score` since v0.8.0"
+                       " and will be removed in v0.10.0", DeprecationWarning)
+        return self.best_model_score
+
+    @property
+    def kth_best_model(self):
+        rank_zero_warn("Attribute `kth_best_model` has been renamed to `kth_best_model_path` since v0.8.0"
+                       " and will be removed in v0.10.0", DeprecationWarning)
+        return self.kth_best_model_path
+
     def _del_model(self, filepath):
         if os.path.isfile(filepath):
             os.remove(filepath)
@@ -162,7 +185,7 @@ def check_monitor_top_k(self, current):
             "max": torch.gt,
         }[self.mode]
 
-        return monitor_op(current, self.best_k_models[self.kth_best_model])
+        return monitor_op(current, self.best_k_models[self.kth_best_model_path])
 
     def format_checkpoint_name(self, epoch, metrics, ver=None):
         """Generate a filename according to the defined template.
@@ -258,25 +281,26 @@ def _do_check_save(self, filepath, current, epoch):
 
         del_list = []
         if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
-            delpath = self.kth_best_model
-            self.best_k_models.pop(self.kth_best_model)
+            delpath = self.kth_best_model_path
+            self.best_k_models.pop(self.kth_best_model_path)
             del_list.append(delpath)
 
         self.best_k_models[filepath] = current
         if len(self.best_k_models) == self.save_top_k:
             # monitor dict has reached k elements
             _op = max if self.mode == 'min' else min
-            self.kth_best_model = _op(self.best_k_models,
-                                      key=self.best_k_models.get)
-            self.kth_value = self.best_k_models[self.kth_best_model]
+            self.kth_best_model_path = _op(self.best_k_models,
+                                           key=self.best_k_models.get)
+            self.kth_value = self.best_k_models[self.kth_best_model_path]
 
         _op = min if self.mode == 'min' else max
-        self.best = _op(self.best_k_models.values())
+        self.best_model_path = _op(self.best_k_models, key=self.best_k_models.get)
+        self.best_model_score = self.best_k_models[self.best_model_path]
 
         if self.verbose > 0:
             log.info(
                 f'\nEpoch {epoch:05d}: {self.monitor} reached'
-                f' {current:0.5f} (best {self.best:0.5f}), saving model to'
+                f' {current:0.5f} (best {self.best_model_score:0.5f}), saving model to'
                 f' {filepath} as top {self.save_top_k}')
             self._save_model(filepath)
```
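To make the `_do_check_save` bookkeeping above concrete, here is a standalone sketch (plain Python, with illustrative filenames) of how `best_model_path`, `best_model_score`, and `kth_best_model_path` fall out of the `best_k_models` dict when monitoring a metric in 'min' mode:

```python
# Illustrative only: best_k_models maps checkpoint path -> monitored value.
best_k_models = {
    'epoch=01-val_loss=0.35.ckpt': 0.35,
    'epoch=04-val_loss=0.28.ckpt': 0.28,
    'epoch=07-val_loss=0.31.ckpt': 0.31,
}
mode = 'min'  # lower val_loss is better

# Best model: the key with the extreme (here: minimum) monitored value.
_op = min if mode == 'min' else max
best_model_path = _op(best_k_models, key=best_k_models.get)
best_model_score = best_k_models[best_model_path]

# k-th best model: the worst of the retained top k, used as the
# threshold for evicting checkpoints when a better one arrives.
_op = max if mode == 'min' else min
kth_best_model_path = _op(best_k_models, key=best_k_models.get)

print(best_model_path, best_model_score)  # epoch=04-val_loss=0.28.ckpt 0.28
print(kth_best_model_path)                # epoch=01-val_loss=0.35.ckpt
```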

pytorch_lightning/trainer/training_io.py

Lines changed: 15 additions & 5 deletions
```diff
@@ -330,7 +330,8 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
 
         if not weights_only:
             if self.checkpoint_callback:
-                checkpoint['checkpoint_callback_best'] = self.checkpoint_callback.best
+                checkpoint['checkpoint_callback_best_model_score'] = self.checkpoint_callback.best_model_score
+                checkpoint['checkpoint_callback_best_model_path'] = self.checkpoint_callback.best_model_path
 
             if self.early_stop_callback:
                 checkpoint['early_stop_callback_wait'] = self.early_stop_callback.wait
@@ -401,10 +402,19 @@ def restore_training_state(self, checkpoint):
                 ' This is probably due to `ModelCheckpoint.save_weights_only` being set to `True`.'
             )
 
-        if self.checkpoint_callback is not None and self.checkpoint_callback is not False:
-            self.checkpoint_callback.best = checkpoint['checkpoint_callback_best']
-
-        if self.early_stop_callback is not None and self.early_stop_callback is not False:
+        if self.checkpoint_callback:
+            if 'checkpoint_callback_best_model_score' in checkpoint:
+                self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best_model_score']
+            else:
+                # Old naming until version 0.7.6
+                rank_zero_warn(
+                    'Loading a checkpoint created with an old version of Lightning; '
+                    'this will not be supported in the future.'
+                )
+                self.checkpoint_callback.best_model_score = checkpoint['checkpoint_callback_best']
+            self.checkpoint_callback.best_model_path = checkpoint['checkpoint_callback_best_model_path']
+
+        if self.early_stop_callback:
             self.early_stop_callback.wait = checkpoint['early_stop_callback_wait']
             self.early_stop_callback.patience = checkpoint['early_stop_callback_patience']
```
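The restore path above stays backward compatible with checkpoints written before the rename. A minimal sketch of the same fallback applied to a raw checkpoint dict (the file path is hypothetical):

```python
import warnings
import torch

# Hypothetical path; any checkpoint written by an older Lightning works here.
checkpoint = torch.load('old_run.ckpt', map_location='cpu')

if 'checkpoint_callback_best_model_score' in checkpoint:
    best_score = checkpoint['checkpoint_callback_best_model_score']
else:
    # Fallback for the key name used up to v0.7.6.
    warnings.warn('Loading a checkpoint created with an old version of Lightning; '
                  'this will not be supported in the future.', DeprecationWarning)
    best_score = checkpoint['checkpoint_callback_best']

# Old checkpoints have no stored path; default to an empty string.
best_path = checkpoint.get('checkpoint_callback_best_model_path', '')
```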
