diff --git a/dreambooth/train_dreambooth.py b/dreambooth/train_dreambooth.py index 509c3435..f984c212 100644 --- a/dreambooth/train_dreambooth.py +++ b/dreambooth/train_dreambooth.py @@ -860,6 +860,13 @@ def save_weights(step, save_model, save_img): progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step) + + progress_bar.update(1) + global_step += 1 + lifetime_step += 1 + shared.state.job_no = global_step + + training_complete = global_step >= args.max_train_steps or shared.state.interrupted if global_step > 0: save_img = not global_step % args.save_preview_every save_model = not global_step % args.save_embedding_every @@ -869,13 +876,6 @@ def save_weights(step, save_model, save_img): if save_img or save_model: save_weights(lifetime_step, save_model, save_img) - progress_bar.update(1) - global_step += 1 - lifetime_step += 1 - shared.state.job_no = global_step - - training_complete = global_step >= args.max_train_steps or shared.state.interrupted - shared.state.textinfo = f"Training, step {global_step}/{args.max_train_steps} current, {lifetime_step}/{args.max_train_steps + args.total_steps} lifetime"