
Commit 112dd5c

lgvaz and jeremyjordan authored
Adds the option of saving the last model on checkpoint (Lightning-AI#1908)
* saves model every epoch
* implement test for save_last
* Update CHANGELOG.md
* Update CHANGELOG.md
* changes test description

Co-authored-by: Jeremy Jordan <13970565+jeremyjordan@users.noreply.github.com>
1 parent a34eb9e commit 112dd5c
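
In practice the new flag is passed to ModelCheckpoint alongside the existing options. A minimal usage sketch (the Trainer wiring via checkpoint_callback reflects the Lightning API of this release and is assumed here, not shown in this commit):

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Keep the single best checkpoint by val_loss and, in addition, always write
# '<prefix>last.ckpt' at the end of every validation epoch.
checkpoint_callback = ModelCheckpoint(
    filepath='my/checkpoints',
    monitor='val_loss',
    save_top_k=1,
    save_last=True,
)
trainer = Trainer(checkpoint_callback=checkpoint_callback)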

File tree

3 files changed: +20 -9 lines changed


CHANGELOG.md

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Allow dataloaders without sampler field present ([#1907](https://github.com/PyTorchLightning/pytorch-lightning/pull/1907))
 
+- Added option `save_last` to save the model at the end of every epoch in `ModelCheckpoint` [(#1908)](https://github.com/PyTorchLightning/pytorch-lightning/pull/1908)
+
 ### Changed
 
 - Allow user to select individual TPU core to train on ([#1729](https://github.com/PyTorchLightning/pytorch-lightning/pull/1729))

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 1 deletion
@@ -43,6 +43,7 @@ class ModelCheckpoint(Callback):
 
         monitor: quantity to monitor.
         verbose: verbosity mode. Default: ``False``.
+        save_last: always saves the model at the end of the epoch. Default: ``False``.
         save_top_k: if `save_top_k == k`,
             the best k models according to
             the quantity monitored will be saved.
@@ -83,7 +84,7 @@ class ModelCheckpoint(Callback):
     """
 
     def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
-                 save_top_k: int = 1, save_weights_only: bool = False,
+                 save_last: bool = False, save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
         if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
@@ -103,6 +104,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
         else:
             self.dirpath, self.filename = os.path.split(filepath)
             os.makedirs(self.dirpath, exist_ok=True)
+        self.save_last = save_last
         self.save_top_k = save_top_k
         self.save_weights_only = save_weights_only
         self.period = period
@@ -217,6 +219,10 @@ def on_validation_end(self, trainer, pl_module):
 
         self.epoch_last_check = epoch
 
+        if self.save_last:
+            filepath = os.path.join(self.dirpath, self.prefix + 'last.ckpt')
+            self._save_model(filepath)
+
         filepath = self.format_checkpoint_name(epoch, metrics)
         version_cnt = 0
         while os.path.isfile(filepath):
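
The added branch in on_validation_end saves unconditionally, next to the top-k files, under the fixed name '<prefix>last.ckpt', so the file is overwritten at every validation epoch. A standalone sketch of the path construction (the helper name below is illustrative, not part of the callback):

import os

def last_checkpoint_path(dirpath, prefix=''):
    # Mirrors the new save_last branch: the target is always '<prefix>last.ckpt'
    # inside the checkpoint directory, regardless of epoch or monitored metric.
    return os.path.join(dirpath, prefix + 'last.ckpt')

assert last_checkpoint_path('/tmp/ckpts') == '/tmp/ckpts/last.ckpt'
assert last_checkpoint_path('/tmp/ckpts', 'test_prefix_') == '/tmp/ckpts/test_prefix_last.ckpt'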

tests/trainer/test_trainer.py

Lines changed: 11 additions & 8 deletions
@@ -229,19 +229,21 @@ def test_dp_output_reduce():
     assert reduced['b']['c'] == out['b']['c']
 
 
-@pytest.mark.parametrize(["save_top_k", "file_prefix", "expected_files"], [
-    pytest.param(-1, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'},
+@pytest.mark.parametrize(["save_top_k", "save_last", "file_prefix", "expected_files"], [
+    pytest.param(-1, False, '', {'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt', 'epoch=1.ckpt', 'epoch=0.ckpt'},
                  id="CASE K=-1 (all)"),
-    pytest.param(1, 'test_prefix_', {'test_prefix_epoch=4.ckpt'},
+    pytest.param(1, False, 'test_prefix_', {'test_prefix_epoch=4.ckpt'},
                  id="CASE K=1 (2.5, epoch 4)"),
-    pytest.param(2, '', {'epoch=4.ckpt', 'epoch=2.ckpt'},
+    pytest.param(2, False, '', {'epoch=4.ckpt', 'epoch=2.ckpt'},
                  id="CASE K=2 (2.5 epoch 4, 2.8 epoch 2)"),
-    pytest.param(4, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'},
+    pytest.param(4, False, '', {'epoch=1.ckpt', 'epoch=4.ckpt', 'epoch=3.ckpt', 'epoch=2.ckpt'},
                  id="CASE K=4 (save all 4 base)"),
-    pytest.param(3, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'},
+    pytest.param(3, False, '', {'epoch=2.ckpt', 'epoch=3.ckpt', 'epoch=4.ckpt'},
                  id="CASE K=3 (save the 2nd, 3rd, 4th model)"),
+    pytest.param(1, True, '', {'epoch=4.ckpt', 'last.ckpt'},
+                 id="CASE K=1 (save the 4th model and the last model)"),
 ])
-def test_model_checkpoint_options(tmpdir, save_top_k, file_prefix, expected_files):
+def test_model_checkpoint_options(tmpdir, save_top_k, save_last, file_prefix, expected_files):
     """Test ModelCheckpoint options."""
 
     def mock_save_function(filepath, *args):
@@ -250,7 +252,8 @@ def mock_save_function(filepath, *args):
     # simulated losses
     losses = [10, 9, 2.8, 5, 2.5]
 
-    checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, prefix=file_prefix, verbose=1)
+    checkpoint_callback = ModelCheckpoint(tmpdir, save_top_k=save_top_k, save_last=save_last,
+                                          prefix=file_prefix, verbose=1)
     checkpoint_callback.save_function = mock_save_function
     trainer = Trainer()
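
The diff truncates the body of mock_save_function and the final assertions. For context, a rough sketch of the pattern the test relies on (details here are illustrative, not the repository's exact code): the callback's save_function is swapped for a stub that only creates empty files, so the produced file names in tmpdir can be compared against expected_files, including 'last.ckpt' for the new save_last case.

import os

def mock_save_function(filepath, *args):
    # Stand-in for the real checkpoint save: just touch an empty file so the
    # test can later compare the contents of tmpdir against expected_files.
    open(filepath, 'a').close()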

0 commit comments
