diff --git a/examples/custom_diffusion/train_custom_diffusion.py b/examples/custom_diffusion/train_custom_diffusion.py index 3288fe3258ac..4773446a615b 100644 --- a/examples/custom_diffusion/train_custom_diffusion.py +++ b/examples/custom_diffusion/train_custom_diffusion.py @@ -207,7 +207,7 @@ def __init__( with open(concept["class_prompt"], "r") as f: class_prompt = f.read().splitlines() - class_img_path = [(x, y) for (x, y) in zip(class_images_path, class_prompt)] + class_img_path = list(zip(class_images_path, class_prompt)) self.class_images_path.extend(class_img_path[:num_class_images]) random.shuffle(self.instance_images_path) @@ -1214,50 +1214,52 @@ def main(args): if global_step >= args.max_train_steps: break - if accelerator.is_main_process: - images = [] + if accelerator.is_main_process: + images = [] - if args.validation_prompt is not None and global_step % args.validation_steps == 0: - logger.info( - f"Running validation... \n Generating {args.num_validation_images} images with prompt:" - f" {args.validation_prompt}." - ) - # create pipeline - pipeline = DiffusionPipeline.from_pretrained( - args.pretrained_model_name_or_path, - unet=accelerator.unwrap_model(unet), - text_encoder=accelerator.unwrap_model(text_encoder), - tokenizer=tokenizer, - revision=args.revision, - torch_dtype=weight_dtype, - ) - pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) - pipeline = pipeline.to(accelerator.device) - pipeline.set_progress_bar_config(disable=True) - - # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) - images = [ - pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[0] - for _ in range(args.num_validation_images) - ] - - for tracker in accelerator.trackers: - if tracker.name == "tensorboard": - np_images = np.stack([np.asarray(img) for img in images]) - tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") - if tracker.name == "wandb": - tracker.log( - { - "validation": [ - wandb.Image(image, caption=f"{i}: {args.validation_prompt}") - for i, image in enumerate(images) - ] - } - ) + if args.validation_prompt is not None and global_step % args.validation_steps == 0: + logger.info( + f"Running validation... \n Generating {args.num_validation_images} images with prompt:" + f" {args.validation_prompt}." + ) + # create pipeline + pipeline = DiffusionPipeline.from_pretrained( + args.pretrained_model_name_or_path, + unet=accelerator.unwrap_model(unet), + text_encoder=accelerator.unwrap_model(text_encoder), + tokenizer=tokenizer, + revision=args.revision, + torch_dtype=weight_dtype, + ) + pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config) + pipeline = pipeline.to(accelerator.device) + pipeline.set_progress_bar_config(disable=True) - del pipeline - torch.cuda.empty_cache() + # run inference + generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + images = [ + pipeline(args.validation_prompt, num_inference_steps=25, generator=generator, eta=1.0).images[ + 0 + ] + for _ in range(args.num_validation_images) + ] + + for tracker in accelerator.trackers: + if tracker.name == "tensorboard": + np_images = np.stack([np.asarray(img) for img in images]) + tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC") + if tracker.name == "wandb": + tracker.log( + { + "validation": [ + wandb.Image(image, caption=f"{i}: {args.validation_prompt}") + for i, image in enumerate(images) + ] + } + ) + + del pipeline + torch.cuda.empty_cache() # Save the custom diffusion layers accelerator.wait_for_everyone()