Commit 42d5cfc

None check for filepath in ModelCheckpoint (Lightning-AI#1654)

Check whether the optional filepath is None before checking whether it exists.

Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>

Parent: 9b86aea

3 files changed, 10 insertions(+), 6 deletions(-)

CHANGELOG.md (5 additions, 3 deletions)

```diff
@@ -16,11 +16,13 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
+- Fixed ModelCheckpoint not None checking filepath ([1654](https://github.com/PyTorchLightning/pytorch-lightning/pull/1654))
+
 
 ## [0.7.5] - 2020-04-27
 
 ### Changed
-
+
 - Allow logging of metrics together with `hparams` ([#1630](https://github.com/PyTorchLightning/pytorch-lightning/pull/1630))
 - Allow metrics logged together with hparams ([#1630](https://github.com/PyTorchLightning/pytorch-lightning/pull/1630))
 
@@ -51,7 +53,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added `ddp_cpu` backend for testing ddp without GPUs ([#1158](https://github.com/PyTorchLightning/pytorch-lightning/pull/1158))
 - Added [Horovod](http://horovod.ai) support as a distributed backend `Trainer(distributed_backend='horovod')` ([#1529](https://github.com/PyTorchLightning/pytorch-lightning/pull/1529))
 - Added support for 8 core distributed training on Kaggle TPU's ([#1568](https://github.com/PyTorchLightning/pytorch-lightning/pull/1568))
-- Added support for native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), [#1580](https://github.com/PyTorchLightning/pytorch-lightning/pull/1580))
+- Added support for native AMP ([#1561](https://github.com/PyTorchLightning/pytorch-lightning/pull/1561), [#1580](https://github.com/PyTorchLightning/pytorch-lightning/pull/1580))
 
 ### Changed
 
@@ -78,7 +80,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed loggers - flushing last logged metrics even before continue, e.g. `trainer.test()` results ([#1459](https://github.com/PyTorchLightning/pytorch-lightning/pull/1459))
 - Fixed optimizer configuration when `configure_optimizers` returns dict without `lr_scheduler` ([#1443](https://github.com/PyTorchLightning/pytorch-lightning/pull/1443))
 - Fixed `LightningModule` - mixing hparams and arguments in `LightningModule.__init__()` crashes load_from_checkpoint() ([#1505](https://github.com/PyTorchLightning/pytorch-lightning/pull/1505))
-- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
+- Added a missing call to the `on_before_zero_grad` model hook ([#1493](https://github.com/PyTorchLightning/pytorch-lightning/pull/1493)).
 
 - Allow use of sweeps with `WandbLogger` ([#1512](https://github.com/PyTorchLightning/pytorch-lightning/pull/1512))
 - Fixed a bug that caused the `callbacks` Trainer argument to reference a global variable ([#1534](https://github.com/PyTorchLightning/pytorch-lightning/pull/1534)).
 - Fixed a bug that set all boolean CLI arguments from `Trainer.add_argparse_args` always to True ([#1571](https://github.com/PyTorchLightning/pytorch-lightning/pull/1571))
```

The three removed/re-added pairs with identical visible content appear to be whitespace-only changes (trailing whitespace stripped).

pytorch_lightning/callbacks/model_checkpoint.py (1 addition, 1 deletion)

```diff
@@ -86,7 +86,7 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
                  save_top_k: int = 1, save_weights_only: bool = False,
                  mode: str = 'auto', period: int = 1, prefix: str = ''):
         super().__init__()
-        if save_top_k > 0 and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
+        if save_top_k > 0 and filepath is not None and os.path.isdir(filepath) and len(os.listdir(filepath)) > 0:
             rank_zero_warn(
                 f"Checkpoint directory {filepath} exists and is not empty with save_top_k != 0."
                 "All files in this directory will be deleted when a checkpoint is saved!"
```
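The fix relies on `and` short-circuiting: `os.path.isdir(None)` raises a `TypeError` in Python 3, so the `filepath is not None` guard must come before the directory checks. A minimal sketch of the patched condition outside Lightning (the function name is illustrative, not part of the library):

```python
import os

def should_warn_nonempty_dir(filepath, save_top_k):
    # Mirrors the patched condition: `filepath is not None` short-circuits
    # before os.path.isdir(filepath), which would raise TypeError for None.
    return (save_top_k > 0
            and filepath is not None
            and os.path.isdir(filepath)
            and len(os.listdir(filepath)) > 0)

print(should_warn_nonempty_dir(None, 1))  # False — no TypeError anymore
```

With the guard in place, `filepath=None` (Lightning's default, meaning "pick a checkpoint directory automatically") falls through silently for every `save_top_k` value.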

tests/callbacks/test_callbacks.py (4 additions, 2 deletions)

```diff
@@ -1,3 +1,4 @@
+import pytest
 import tests.base.utils as tutils
 from pytorch_lightning import Callback
 from pytorch_lightning import Trainer, LightningModule
@@ -249,7 +250,8 @@ def test_pickling(tmpdir):
     pickle.dumps(early_stopping)
 
 
-def test_model_checkpoint_with_non_string_input(tmpdir):
+@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
+def test_model_checkpoint_with_non_string_input(tmpdir, save_top_k):
     """ Test that None in checkpoint callback is valid and that chkp_path is
     set correctly """
     tutils.reset_seed()
@@ -260,7 +262,7 @@ class CurrentTestModel(LightTrainDataloader, TestModelBase):
     hparams = tutils.get_default_hparams()
     model = CurrentTestModel(hparams)
 
-    checkpoint = ModelCheckpoint(filepath=None, save_top_k=-1)
+    checkpoint = ModelCheckpoint(filepath=None, save_top_k=save_top_k)
 
     trainer = Trainer(default_root_dir=tmpdir,
                       checkpoint_callback=checkpoint,
```
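The test change swaps a hard-coded `save_top_k=-1` for a `pytest.mark.parametrize` sweep, so one test body now covers all four `save_top_k` values against `filepath=None`. A standalone sketch of the same pattern (the test body is a stand-in for the Lightning test, not a copy of it):

```python
import pytest

@pytest.mark.parametrize('save_top_k', [-1, 0, 1, 2])
def test_none_filepath_is_accepted(save_top_k):
    # Stand-in assertion: with filepath=None the patched guard must
    # evaluate to False for every save_top_k value, never raise TypeError.
    filepath = None
    assert not (save_top_k > 0 and filepath is not None)
```

Before the fix, only the `save_top_k=1` and `save_top_k=2` cases could hit the broken `os.path.isdir(filepath)` branch, which is why the parametrized sweep is a stronger regression test than the single `-1` case.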
