-
Notifications
You must be signed in to change notification settings - Fork 160
Description
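Describe the bug
`EarlyStopper.early_stop_choice` calls `Metrics._early_stop_choice` with only eight positional arguments, but the method takes nine (`wait, min_score, metric_score, max_score, model, dump_dir, fold, patience, epoch`). Because `max_score` is omitted, every later argument shifts one slot to the left, and `epoch` is reported as missing. This happens whenever early stopping is judged on a metric other than the loss.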
def _early_stop_choice(
    self,
    wait,
    min_score,
    metric_score,
    max_score,
    model,
    dump_dir,
    fold,
    patience,
    epoch,
):
    """
    Judge early stopping on the first metric contained in ``metric_score``.

    The registry entry ``METRICS_REGISTER[self.task][<metric name>][1]``
    picks the direction. When it is truthy, the score is compared against
    ``max_score`` through ``self._judge_early_stop_increase``. Otherwise it
    is compared against ``min_score`` through
    ``self._judge_early_stop_decrease``. The running bounds and the ``wait``
    counter are passed in as arguments and handed back to the caller rather
    than stored on ``self``.

    :param wait: Current patience counter, forwarded to the judge helper.
    :param min_score: Running best value for metrics judged by decrease.
    :param metric_score: Mapping of metric name -> score; only the first entry is used.
    :param max_score: Running best value for metrics judged by increase.
    :param model: The model being trained, forwarded to the judge helper.
    :param dump_dir: Output directory, forwarded to the judge helper.
    :param fold: Fold index, forwarded to the judge helper.
    :param patience: Patience limit, forwarded to the judge helper.
    :param epoch: Current epoch number, forwarded to the judge helper.
    :return: Tuple ``(is_early_stop, min_score, wait, max_score)``. Only the
        bound that matches the metric's direction is updated; the other one
        is returned unchanged.
    """
    # Only the first (name, value) pair decides; any later metrics are ignored.
    judge_metric, score = list(metric_score.items())[0]
    higher_is_better = METRICS_REGISTER[self.task][judge_metric][1]
    # Trailing arguments are identical for both judge helpers.
    tail = (model, dump_dir, fold, patience, epoch)
    if not higher_is_better:
        is_early_stop, min_score, wait = self._judge_early_stop_decrease(
            wait, score, min_score, *tail
        )
    else:
        is_early_stop, max_score, wait = self._judge_early_stop_increase(
            wait, score, max_score, *tail
        )
    return is_early_stop, min_score, wait, max_score
def early_stop_choice(self, model, epoch, loss, metric_score=None):
    """
    Determine whether early stopping criteria are met.

    The decision is based on either loss improvement or a custom metric score.
    When ``self.metrics_str`` is not a string, or is ``'loss'``, ``'none'`` or
    ``''``, the decision is delegated to ``self._judge_early_stop_loss``.
    Otherwise ``self.metrics._early_stop_choice`` judges the metric. Its
    updated counters (``wait``, best minimum and best maximum score) are
    stored back on this object so the next epoch continues from them.

    :param model: The model being trained.
    :param epoch: The current epoch number.
    :param loss: The current loss value.
    :param metric_score: The current metric score; a mapping whose first
        entry is the metric used for the decision.
    :return: A boolean indicating whether early stopping should occur.
    """
    if not isinstance(self.metrics_str, str) or self.metrics_str in [
        'loss',
        'none',
        '',
    ]:
        return self._judge_early_stop_loss(loss, model, epoch)

    # Metrics._early_stop_choice takes nine positional arguments, with the
    # running best score for "higher is better" metrics in the fourth slot.
    # Leaving it out shifted every later argument and dropped ``epoch``.
    # NOTE(review): assumes the tracker attribute is ``max_loss`` (mirroring
    # ``min_loss``). It falls back to -inf, presumably so the first score
    # counts as an improvement. Confirm against EarlyStopper.__init__.
    max_score = getattr(self, 'max_loss', float('-inf'))
    is_early_stop, min_score, wait, max_score = self.metrics._early_stop_choice(
        self.wait,
        self.min_loss,
        metric_score,
        max_score,
        model,
        self.dump_dir,
        self.fold,
        self.patience,
        epoch,
    )
    # The metrics helper receives the counters as arguments and returns the
    # updated values instead of storing them, so persist them here.
    # Returning the raw 4-tuple would always be truthy.
    self.min_loss = min_score
    self.max_loss = max_score
    self.wait = wait
    return is_early_stop
File ~/Uni-Mol/unimol_tools/unimol_tools/tasks/trainer.py:755, in EarlyStopper.early_stop_choice(self, model, epoch, loss, metric_score)
753 return self._judge_early_stop_loss(loss, model, epoch)
754 else:
--> 755 return self.metrics._early_stop_choice(
756 self.wait,
757 self.min_loss,
758 metric_score,
759 model,
760 self.dump_dir,
761 self.fold,
762 self.patience,
763 epoch,
764 )
TypeError: Metrics._early_stop_choice() missing 1 required positional argument: 'epoch'
Uni-Mol Version
Uni-Mol Tools
Expected behavior
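Training with a non-loss early-stopping metric should not raise. `early_stop_choice` should pass the running maximum score as the fourth argument to `Metrics._early_stop_choice`. It should then store the returned `min_score`, `wait` and `max_score`, and return only the boolean `is_early_stop`, as its docstring promises.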
To Reproduce
No response
Environment
No response
Additional Context
No response