From 12d3b5a22e2f98cc3da8109ab69389bb11e2c84c Mon Sep 17 00:00:00 2001 From: Alex McKinney <44398246+vvvm23@users.noreply.github.com> Date: Mon, 5 Jun 2023 19:59:36 +0100 Subject: [PATCH] Fixes eval generator init in `train_text_to_image_lora.py` --- examples/text_to_image/train_text_to_image_lora.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 806637f04c53..0f607b88444a 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -825,7 +825,9 @@ def collate_fn(examples): pipeline.set_progress_bar_config(disable=True) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): images.append( @@ -881,7 +883,9 @@ def collate_fn(examples): pipeline.unet.load_attn_procs(args.output_dir) # run inference - generator = torch.Generator(device=accelerator.device).manual_seed(args.seed) + generator = torch.Generator(device=accelerator.device) + if args.seed is not None: + generator = generator.manual_seed(args.seed) images = [] for _ in range(args.num_validation_images): images.append(pipeline(args.validation_prompt, num_inference_steps=30, generator=generator).images[0])