Description:
In the train_dreambooth.py script, the logging and progress bar updates appear to be executed on every training step, even when using gradient accumulation. This might lead to incorrect or redundant logging.
The relevant code is located around lines 1393-1395:
diffusers/examples/dreambooth/train_dreambooth.py, lines 1346 to 1395 at 093cd3f:

```python
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients:
    progress_bar.update(1)
    global_step += 1

    if accelerator.is_main_process:
        if global_step % args.checkpointing_steps == 0:
            # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
            if args.checkpoints_total_limit is not None:
                checkpoints = os.listdir(args.output_dir)
                checkpoints = [d for d in checkpoints if d.startswith("checkpoint")]
                checkpoints = sorted(checkpoints, key=lambda x: int(x.split("-")[1]))

                # before we save the new checkpoint, we need to have at _most_ `checkpoints_total_limit - 1` checkpoints
                if len(checkpoints) >= args.checkpoints_total_limit:
                    num_to_remove = len(checkpoints) - args.checkpoints_total_limit + 1
                    removing_checkpoints = checkpoints[0:num_to_remove]

                    logger.info(
                        f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
                    )
                    logger.info(f"removing checkpoints: {', '.join(removing_checkpoints)}")

                    for removing_checkpoint in removing_checkpoints:
                        removing_checkpoint = os.path.join(args.output_dir, removing_checkpoint)
                        shutil.rmtree(removing_checkpoint)

            save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
            accelerator.save_state(save_path)
            logger.info(f"Saved state to {save_path}")

        images = []
        if args.validation_prompt is not None and global_step % args.validation_steps == 0:
            images = log_validation(
                unwrap_model(text_encoder) if text_encoder is not None else text_encoder,
                tokenizer,
                unwrap_model(unet),
                vae,
                args,
                accelerator,
                weight_dtype,
                global_step,
                validation_prompt_encoder_hidden_states,
                validation_prompt_negative_prompt_embeds,
            )

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
progress_bar.set_postfix(**logs)
accelerator.log(logs, step=global_step)
```
The global_step counter is only incremented when accelerator.sync_gradients is True, but the logging calls (progress_bar.set_postfix and accelerator.log) sit outside this block. When gradient accumulation is used, these lines therefore run on every batch while the global_step value passed to accelerator.log does not change until an optimization step occurs, which can produce multiple log entries for the same global_step.
It seems more appropriate to move the logging logic inside the if accelerator.sync_gradients: block so that logging happens only once per optimization step.
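For illustration, here is a minimal, self-contained sketch (a toy model with gradient_accumulation_steps=4, not the DreamBooth script itself) showing how accelerator.sync_gradients behaves during accumulation and why statements placed outside the check fire once per batch rather than once per optimizer step:

```python
# Minimal sketch: with gradient_accumulation_steps=4, sync_gradients is True
# only on every 4th batch, so anything outside the `if accelerator.sync_gradients:`
# block runs 4 times per optimizer step with the same global_step value.
import torch
from accelerate import Accelerator

accelerator = Accelerator(gradient_accumulation_steps=4)
model = torch.nn.Linear(2, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
model, optimizer = accelerator.prepare(model, optimizer)

global_step = 0
for batch_idx in range(8):
    with accelerator.accumulate(model):
        loss = model(torch.randn(1, 2)).sum()
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

    if accelerator.sync_gradients:
        global_step += 1

    # Mirrors the current placement in train_dreambooth.py: this prints on
    # every batch, so the same global_step value appears several times.
    print(f"batch={batch_idx} sync={accelerator.sync_gradients} global_step={global_step}")
```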
Proposed Change:

```diff
 if accelerator.sync_gradients:
     progress_bar.update(1)
     global_step += 1

     if accelerator.is_main_process:
         # ... checkpointing and validation logic ...

-logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
-progress_bar.set_postfix(**logs)
-accelerator.log(logs, step=global_step)
+    logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
+    progress_bar.set_postfix(**logs)
+    accelerator.log(logs, step=global_step)
```

Could you please confirm whether this is the intended behavior or whether the indentation should be corrected? Thank you!