Skip to content

Commit e977d1c

Browse files
Nicki Skafte (SkafteNicki)
and authored
Default value for ModelCheckpoint filepath (Lightning-AI#1548)
* allow determine of filepath at runtime * typing Co-authored-by: Nicki Skafte <nugginea@gmail.com>
1 parent 545b38e commit e977d1c

File tree

3 files changed

+50
-14
lines changed

3 files changed

+50
-14
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import re
1111

1212
import numpy as np
13+
from typing import Optional
1314

1415
from pytorch_lightning import _logger as log
1516
from pytorch_lightning.callbacks.base import Callback
@@ -37,6 +38,9 @@ class ModelCheckpoint(Callback):
3738
... filepath='my/path/{epoch}-{val_loss:.2f}-{other_metric:.2f}'
3839
... )
3940
41+
Can also be set to `None`, then it will be set to default location
42+
during trainer construction.
43+
4044
monitor: quantity to monitor.
4145
verbose: verbosity mode. Default: ``False``.
4246
save_top_k: if `save_top_k == k`,
@@ -78,7 +82,7 @@ class ModelCheckpoint(Callback):
7882
7983
"""
8084

81-
def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = False,
85+
def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', verbose: bool = False,
8286
save_top_k: int = 1, save_weights_only: bool = False,
8387
mode: str = 'auto', period: int = 1, prefix: str = ''):
8488
super().__init__()
@@ -90,12 +94,14 @@ def __init__(self, filepath: str, monitor: str = 'val_loss', verbose: bool = Fal
9094

9195
self.monitor = monitor
9296
self.verbose = verbose
93-
if os.path.isdir(filepath):
94-
self.dirpath, self.filename = filepath, '{epoch}'
97+
if filepath is None: # will be determined by trainer at runtime
98+
self.dirpath, self.filename = None, None
9599
else:
96-
self.dirpath, self.filename = os.path.split(filepath)
97-
98-
os.makedirs(self.dirpath, exist_ok=True)
100+
if os.path.isdir(filepath):
101+
self.dirpath, self.filename = filepath, '{epoch}'
102+
else:
103+
self.dirpath, self.filename = os.path.split(filepath)
104+
os.makedirs(self.dirpath, exist_ok=True)
99105
self.save_top_k = save_top_k
100106
self.save_weights_only = save_weights_only
101107
self.period = period

pytorch_lightning/trainer/callback_config.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def configure_checkpoint_callback(self):
3333
Otherwise use os.getcwd()
3434
"""
3535
ckpt_path = self.default_root_dir
36-
if self.checkpoint_callback is True:
36+
if self.checkpoint_callback:
3737
# init a default one
3838
if self.logger is not None:
3939
save_dir = (getattr(self.logger, 'save_dir', None) or
@@ -57,12 +57,18 @@ def configure_checkpoint_callback(self):
5757
train_step_only = not self.is_overriden('validation_step')
5858
monitor_key = 'loss' if train_step_only else 'val_loss'
5959

60-
self.ckpt_path = ckpt_path
61-
os.makedirs(ckpt_path, exist_ok=True)
62-
self.checkpoint_callback = ModelCheckpoint(
63-
filepath=ckpt_path,
64-
monitor=monitor_key
65-
)
60+
if self.checkpoint_callback is True:
61+
os.makedirs(ckpt_path, exist_ok=True)
62+
self.checkpoint_callback = ModelCheckpoint(
63+
filepath=ckpt_path,
64+
monitor=monitor_key
65+
)
66+
# If user specified None in filepath, override with runtime default
67+
elif isinstance(self.checkpoint_callback, ModelCheckpoint) \
68+
and self.checkpoint_callback.dirpath is None:
69+
self.checkpoint_callback.dirpath = ckpt_path
70+
self.checkpoint_callback.filename = '{epoch}'
71+
os.makedirs(self.checkpoint_callback.dirpath, exist_ok=True)
6672
elif self.checkpoint_callback is False:
6773
self.checkpoint_callback = None
6874

tests/trainer/test_callbacks.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import tests.base.utils as tutils
22
from pytorch_lightning import Callback
33
from pytorch_lightning import Trainer, LightningModule
4-
from pytorch_lightning.callbacks import EarlyStopping
4+
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
55
from tests.base import (
66
LightTrainDataloader,
77
LightTestMixin,
@@ -181,3 +181,27 @@ def training_step(self, *args, **kwargs):
181181

182182
assert result == 1, 'training failed to complete'
183183
assert trainer.current_epoch < trainer.max_epochs
184+
185+
186+
def test_model_checkpoint_with_non_string_input(tmpdir):
187+
""" Test that None in checkpoint callback is valid and that chkp_path is
188+
set correctly """
189+
tutils.reset_seed()
190+
191+
class CurrentTestModel(LightTrainDataloader, TestModelBase):
192+
pass
193+
194+
hparams = tutils.get_default_hparams()
195+
model = CurrentTestModel(hparams)
196+
197+
checkpoint = ModelCheckpoint(filepath=None, save_top_k=-1)
198+
199+
trainer = Trainer(default_root_dir=tmpdir,
200+
checkpoint_callback=checkpoint,
201+
overfit_pct=0.20,
202+
max_epochs=5
203+
)
204+
result = trainer.fit(model)
205+
206+
# These should be different if the dirpath has be overridden
207+
assert trainer.ckpt_path != trainer.default_root_dir

0 commit comments

Comments
 (0)