Skip to content

Commit

Permalink
Fixup save_best_model (#137)
Browse files Browse the repository at this point in the history
* Fixup

* Fixup lint
  • Loading branch information
erogol committed Dec 12, 2023
1 parent 7de3bc9 commit c77173d
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions trainer/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,14 +176,23 @@ def save_best_model(
epoch,
out_path,
keep_all_best=False,
keep_after=10000,
keep_after=0,
save_func=None,
**kwargs,
):
use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
if (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
):
if isinstance(current_loss, dict):
use_eval_loss = current_loss["eval_loss"] is not None and best_loss["eval_loss"] is not None
is_save_model = (use_eval_loss and current_loss["eval_loss"] < best_loss["eval_loss"]) or (
not use_eval_loss and current_loss["train_loss"] < best_loss["train_loss"]
)
else:
is_save_model = current_loss < best_loss

if isinstance(keep_after, (int, float)):
keep_after = int(keep_after)
is_save_model = is_save_model and current_step > keep_after

if is_save_model:
best_model_name = f"best_model_{current_step}.pth"
checkpoint_path = os.path.join(out_path, best_model_name)
logger.info(" > BEST MODEL : %s", checkpoint_path)
Expand Down

0 comments on commit c77173d

Please sign in to comment.