
Commit 3453bba

re-enabled naming metrics in ckpt name (Lightning-AI#3060)
1 parent cefc7f7 commit 3453bba

File tree

6 files changed: +69 −8 lines changed


pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 3 additions & 2 deletions
@@ -339,10 +339,11 @@ def on_validation_end(self, trainer, pl_module):
 
         self.epoch_last_check = epoch
 
-        filepath = self.format_checkpoint_name(epoch, metrics)
+        ckpt_name_metrics = trainer.logged_metrics
+        filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics)
         version_cnt = 0
         while gfile.exists(filepath):
-            filepath = self.format_checkpoint_name(epoch, metrics, ver=version_cnt)
+            filepath = self.format_checkpoint_name(epoch, ckpt_name_metrics, ver=version_cnt)
             # this epoch called before
             version_cnt += 1
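
For context on this hunk: `format_checkpoint_name` fills `{metric:format}` groups in the `filepath` template with values from the metrics dict it receives, so passing `trainer.logged_metrics` instead of the callback metrics is what restores metric names in the file name. The helper below is a hypothetical, simplified stand-in for that substitution, not the real implementation:

import re

def format_name(template: str, metrics: dict) -> str:
    # hypothetical stand-in for ModelCheckpoint.format_checkpoint_name:
    # every "{name}" or "{name:fmt}" group is filled from the metrics dict,
    # falling back to 0 when the metric was never logged
    names = re.findall(r'\{([a-zA-Z_][a-zA-Z0-9_]*)', template)
    values = {name: metrics.get(name, 0) for name in names}
    return template.format(**values)

# with trainer.logged_metrics available, the real value lands in the name
print(format_name('{epoch}-{val_loss:.2f}', {'epoch': 3, 'val_loss': 0.1234}))  # 3-0.12
# with an empty metrics dict the template degrades to zeros
print(format_name('{epoch}-{val_loss:.2f}', {}))  # 0-0.00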

pytorch_lightning/trainer/logging.py

Lines changed: 3 additions & 0 deletions
@@ -24,6 +24,7 @@ class TrainerLoggingMixin(ABC):
     default_root_dir: str
     slurm_job_id: int
     num_gpus: int
+    logged_metrics: ...
 
     def configure_logger(self, logger):
         if logger is True:
@@ -75,6 +76,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
             self.logger.agg_and_log_metrics(scalar_metrics, step=step)
             self.logger.save()
 
+        # track the logged metrics
+        self.logged_metrics = scalar_metrics
         self.dev_debugger.track_logged_metrics_history(scalar_metrics)
 
     def add_progress_bar_metrics(self, metrics):
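
In effect, every call to `log_metrics` now leaves the most recent scalar metrics on the trainer, where the checkpoint callback can pick them up at validation end. A toy sketch of that hand-off (the class below and its plain-dict handling are illustrative, not the real Trainer/Logger API):

class MiniTrainer:
    """Illustrative stand-in: only the bookkeeping added in this commit."""

    def __init__(self):
        # mirrors the new attribute initialised in trainer.py
        self.logged_metrics = {}

    def log_metrics(self, scalar_metrics: dict, step: int = None):
        # ... the real method forwards to self.logger.agg_and_log_metrics() ...
        # track the logged metrics so callbacks can reuse them later
        self.logged_metrics = scalar_metrics


trainer = MiniTrainer()
trainer.log_metrics({'val_loss': 0.42, 'epoch': 1})
# on_validation_end can now read trainer.logged_metrics for the filename
assert trainer.logged_metrics['val_loss'] == 0.42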

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -374,6 +374,7 @@ def __init__(
         self.batch_idx = 0
         self.progress_bar_metrics = {}
         self.callback_metrics = {}
+        self.logged_metrics = {}
         self.num_training_batches = 0
         self.num_val_batches = []
         self.num_test_batches = []

tests/callbacks/test_model_checkpoint.py

Lines changed: 56 additions & 0 deletions
@@ -1,4 +1,5 @@
 import os
+import re
 import pickle
 import platform
 from pathlib import Path
@@ -128,3 +129,58 @@ def test_model_checkpoint_save_last_checkpoint_contents(tmpdir):
     model_last = EvalModelTemplate.load_from_checkpoint(path_last)
     for w0, w1 in zip(model_last_epoch.parameters(), model_last.parameters()):
         assert w0.eq(w1).all()
+
+
+def test_ckpt_metric_names(tmpdir):
+    model = EvalModelTemplate()
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gradient_clip_val=1.0,
+        overfit_batches=0.20,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=0.01,
+        limit_val_batches=0.01,
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+    )
+
+    trainer.fit(model)
+
+    # make sure the checkpoint we saved has the metric in the name
+    ckpts = os.listdir(tmpdir)
+    ckpts = [x for x in ckpts if 'val_loss' in x]
+    assert len(ckpts) == 1
+    val = re.sub('[^0-9.]', '', ckpts[0])
+    assert len(val) > 3
+
+
+def test_ckpt_metric_names_results(tmpdir):
+    model = EvalModelTemplate()
+    model.training_step = model.training_step_result_obj
+    model.training_step_end = None
+    model.training_epoch_end = None
+
+    model.validation_step = model.validation_step_result_obj
+    model.validation_step_end = None
+    model.validation_epoch_end = None
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        max_epochs=1,
+        gradient_clip_val=1.0,
+        overfit_batches=0.20,
+        progress_bar_refresh_rate=0,
+        limit_train_batches=0.01,
+        limit_val_batches=0.01,
+        checkpoint_callback=ModelCheckpoint(filepath=tmpdir + '/{val_loss:.2f}')
+    )
+
+    trainer.fit(model)
+
+    # make sure the checkpoint we saved has the metric in the name
+    ckpts = os.listdir(tmpdir)
+    ckpts = [x for x in ckpts if 'val_loss' in x]
+    assert len(ckpts) == 1
+    val = re.sub('[^0-9.]', '', ckpts[0])
+    assert len(val) > 3
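
As a usage sketch (assuming the Trainer and ModelCheckpoint arguments exercised by the tests above), getting metric values into checkpoint filenames only requires a format string in `filepath`; any metric logged during validation, such as `val_loss`, can then appear in the saved file name:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# checkpoint names are built from the template, so the formatted val_loss
# value logged in the last validation run ends up in the filename
checkpoint_callback = ModelCheckpoint(filepath='checkpoints/{epoch}-{val_loss:.2f}')
trainer = Trainer(max_epochs=1, checkpoint_callback=checkpoint_callback)
# trainer.fit(model) would then write checkpoints whose names contain val_loss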

tests/trainer/test_eval_loop_dict_return.py

Lines changed: 4 additions & 4 deletions
@@ -136,7 +136,7 @@ def test_validation_step_dict_return(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 8
+    assert len(trainer.callback_metrics) == 7
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -211,7 +211,7 @@ def test_val_step_step_end(tmpdir):
         assert k in eval_results[1]
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 9
+    assert len(trainer.callback_metrics) == 8
 
     # make sure correct steps were called
     assert model.validation_step_called
@@ -254,7 +254,7 @@ def test_no_val_step_end(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 9
+    assert len(trainer.callback_metrics) == 8
 
     # make sure correct steps were called
    assert model.validation_step_called
@@ -297,7 +297,7 @@ def test_full_val_loop(tmpdir):
         assert k in eval_results
 
     # ensure all the keys ended up as candidates for callbacks
-    assert len(trainer.callback_metrics) == 10
+    assert len(trainer.callback_metrics) == 9
 
     # make sure correct steps were called
     assert model.validation_step_called

tests/trainer/test_trainer_steps_scalar_return.py

Lines changed: 2 additions & 2 deletions
@@ -108,7 +108,7 @@ def test_full_training_loop_scalar(tmpdir):
     assert model.training_epoch_end_called
 
     # assert epoch end metrics were added
-    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.callback_metrics) == 0
     assert len(trainer.progress_bar_metrics) == 0
 
     # make sure training outputs what is expected
@@ -151,7 +151,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
     assert model.training_epoch_end_called
 
     # assert epoch end metrics were added
-    assert 'epoch' in trainer.callback_metrics and len(trainer.callback_metrics) == 1
+    assert len(trainer.callback_metrics) == 0
     assert len(trainer.progress_bar_metrics) == 0
 
     # make sure training outputs what is expected
