Skip to content

Commit 735520b

Browse files
authored
Fix ModelCheckpoint not being fixable.
1 parent acfb054 commit 735520b

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -116,18 +116,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
116116

117117
torch_inf = torch.tensor(np.Inf)
118118
mode_dict = {
119-
'min': (torch.lt, torch_inf, 'min'),
120-
'max': (torch.gt, -torch_inf, 'max'),
121-
'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
122-
else (torch.lt, torch_inf, 'min'),
119+
'min': (torch_inf, 'min'),
120+
'max': (-torch_inf, 'max'),
121+
'auto': (-torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
122+
else (torch_inf, 'min'),
123123
}
124124

125125
if mode not in mode_dict:
126126
rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
127127
f'fallback to auto mode.', RuntimeWarning)
128128
mode = 'auto'
129129

130-
self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
130+
self.kth_value, self.mode = mode_dict[mode]
131131

132132
def _del_model(self, filepath):
133133
if os.path.isfile(filepath):
@@ -151,7 +151,12 @@ def check_monitor_top_k(self, current):
151151
if not isinstance(current, torch.Tensor):
152152
current = torch.tensor(current)
153153

154-
return self.monitor_op(current, self.best_k_models[self.kth_best_model])
154+
monitor_op = {
155+
"min": torch.lt,
156+
"max": torch.gt,
157+
}[self.mode]
158+
159+
return monitor_op(current, self.best_k_models[self.kth_best_model])
155160

156161
def format_checkpoint_name(self, epoch, metrics, ver=None):
157162
"""Generate a filename according to the defined template.

0 commit comments

Comments
 (0)