Skip to content

Commit

Permalink
Actually save the last checkpoint...
Browse files Browse the repository at this point in the history
  • Loading branch information
d8ahazard committed Nov 9, 2022
1 parent ef4c80a commit 6f81b3b
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions dreambooth/train_dreambooth.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,6 +860,13 @@ def save_weights(step, save_model, save_img):
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)


progress_bar.update(1)
global_step += 1
lifetime_step += 1
shared.state.job_no = global_step

training_complete = global_step >= args.max_train_steps or shared.state.interrupted
if global_step > 0:
save_img = not global_step % args.save_preview_every
save_model = not global_step % args.save_embedding_every
Expand All @@ -869,13 +876,6 @@ def save_weights(step, save_model, save_img):
if save_img or save_model:
save_weights(lifetime_step, save_model, save_img)

progress_bar.update(1)
global_step += 1
lifetime_step += 1
shared.state.job_no = global_step

training_complete = global_step >= args.max_train_steps or shared.state.interrupted

shared.state.textinfo = f"Training, step {global_step}/{args.max_train_steps} current, {lifetime_step}/{args.max_train_steps + args.total_steps} lifetime"


Expand Down

0 comments on commit 6f81b3b

Please sign in to comment.