Skip to content

Commit 9b31272

Browse files
Ir1dBordawilliamFalcon
authored
feat: save checkpoint before deleting old ones (Lightning-AI#1453)
* feat: save checkpoint before deleting old ones * fix: make sure that the new model is not deleted * changelog Co-authored-by: J. Borovec <jirka.borovec@seznam.cz> Co-authored-by: William Falcon <waf2107@columbia.edu>
1 parent 2ab2f7d commit 9b31272

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3434

3535
### Fixed
3636

37+
38+
- Fixed saving checkpoint before deleting old ones ([#1453](https://github.com/PyTorchLightning/pytorch-lightning/pull/1453))
39+
3740
- Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
3841

3942
- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,10 +216,12 @@ def on_validation_end(self, trainer, pl_module):
216216

217217
def _do_check_save(self, filepath, current, epoch):
218218
# remove kth
219+
220+
del_list = []
219221
if len(self.best_k_models) == self.save_top_k and self.save_top_k > 0:
220222
delpath = self.kth_best_model
221223
self.best_k_models.pop(self.kth_best_model)
222-
self._del_model(delpath)
224+
del_list.append(delpath)
223225

224226
self.best_k_models[filepath] = current
225227
if len(self.best_k_models) == self.save_top_k:
@@ -238,3 +240,7 @@ def _do_check_save(self, filepath, current, epoch):
238240
f' {current:0.5f} (best {self.best:0.5f}), saving model to'
239241
f' {filepath} as top {self.save_top_k}')
240242
self._save_model(filepath)
243+
244+
for cur_path in del_list:
245+
if cur_path != filepath:
246+
self._del_model(cur_path)

0 commit comments

Comments
 (0)