diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index adc1359a5e56..447adeb9beee 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -159,6 +159,12 @@ def parse_args(input_args=None): " training using `--resume_from_checkpoint`." ), ) + parser.add_argument( + "--generating_progress_steps", + type=int, + default=500, + help=("Save an image generated from the model every X steps."), + ) parser.add_argument( "--resume_from_checkpoint", type=str, @@ -416,6 +422,11 @@ def main(args): if args.seed is not None: set_seed(args.seed) + g = torch.Generator(device=accelerator.device.type) + g.manual_seed(args.seed) + else: + g = torch.Generator(device=accelerator.device.type) + g.manual_seed(42) if args.with_prior_preservation: class_images_dir = Path(args.class_data_dir) @@ -739,6 +750,19 @@ def main(args): accelerator.save_state(save_path) logger.info(f"Saved state to {save_path}") + if global_step % args.generating_progress_steps == 0: + if accelerator.is_main_process: + # generate and save the image + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + torch_dtype=weight_dtype, + ).to(accelerator.device.type) + + image = pipeline(args.instance_prompt, generator=g).images[0] + image.save(f"{args.output_dir}/logs/image_{global_step}.png") + logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]} progress_bar.set_postfix(**logs) accelerator.log(logs, step=global_step)