@@ -116,18 +116,18 @@ def __init__(self, filepath: Optional[str] = None, monitor: str = 'val_loss', ve
116116
117117 torch_inf = torch .tensor (np .Inf )
118118 mode_dict = {
119- 'min': (torch.lt, torch_inf, 'min'),
120- 'max': (torch.gt, -torch_inf, 'max'),
121- 'auto': (torch.gt, -torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
122- else (torch.lt, torch_inf, 'min'),
119+ 'min': (torch_inf, 'min'),
120+ 'max': (-torch_inf, 'max'),
121+ 'auto': (-torch_inf, 'max') if 'acc' in self.monitor or self.monitor.startswith('fmeasure')
122+ else (torch_inf, 'min'),
123123 }
124124
125125 if mode not in mode_dict :
126126 rank_zero_warn(f'ModelCheckpoint mode {mode} is unknown, '
127127 f'fallback to auto mode.', RuntimeWarning)
128128 mode = 'auto'
129129
130- self.monitor_op, self.kth_value, self.mode = mode_dict[mode]
130+ self.kth_value, self.mode = mode_dict[mode]
131131
132132 def _del_model(self, filepath):
133133 if os.path.isfile(filepath):
@@ -151,7 +151,12 @@ def check_monitor_top_k(self, current):
151151 if not isinstance(current, torch.Tensor):
152152 current = torch.tensor(current)
153153
154- return self.monitor_op(current, self.best_k_models[self.kth_best_model])
154+ monitor_op = {
155+ "min": torch.lt,
156+ "max": torch.gt,
157+ }[self.mode]
158+
159+ return monitor_op(current, self.best_k_models[self.kth_best_model])
155160
156161 def format_checkpoint_name (self , epoch , metrics , ver = None ):
157162 """Generate a filename according to the defined template.
0 commit comments