From 0db89da4c559e7726dfdcd105ac4f06a8e3fc031 Mon Sep 17 00:00:00 2001 From: Patrick von Platen Date: Fri, 2 Dec 2022 16:10:56 +0000 Subject: [PATCH] [Upscaling] Fix batch size --- .../pipeline_stable_diffusion_upscale.py | 8 +-- .../test_stable_diffusion_upscale.py | 51 +++++++++++++++++++ 2 files changed, 56 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py index d3c30597f01a..53395fa9c3ca 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py @@ -459,8 +459,10 @@ def __call__( else: noise = torch.randn(image.shape, generator=generator, device=device, dtype=text_embeddings.dtype) image = self.low_res_scheduler.add_noise(image, noise, noise_level) - image = torch.cat([image] * 2) if do_classifier_free_guidance else image - noise_level = torch.cat([noise_level] * 2) if do_classifier_free_guidance else noise_level + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = torch.cat([image] * batch_multiplier * num_images_per_prompt) + noise_level = torch.cat([noise_level] * image.shape[0]) # 6. Prepare latent variables height, width = image.shape[2:] @@ -515,7 +517,7 @@ def __call__( latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample # call the callback, if provided - if (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0: + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) diff --git a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py index 2092e153eeb2..ddb9b3358a84 100644 --- a/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py +++ b/tests/pipelines/stable_diffusion_2/test_stable_diffusion_upscale.py @@ -161,6 +161,57 @@ def test_stable_diffusion_upscale(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2 + def test_stable_diffusion_upscale_batch(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + unet = self.dummy_cond_unet_upscale + low_res_scheduler = DDPMScheduler() + scheduler = DDIMScheduler(prediction_type="v_prediction") + vae = self.dummy_vae + text_encoder = self.dummy_text_encoder + tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip") + + image = self.dummy_image.cpu().permute(0, 2, 3, 1)[0] + low_res_image = Image.fromarray(np.uint8(image)).convert("RGB").resize((64, 64)) + + # make sure here that pndm scheduler skips prk + sd_pipe = StableDiffusionUpscalePipeline( + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + max_noise_level=350, + ) + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + prompt = "A painting of a squirrel eating a burger" + output = sd_pipe( + 2 * [prompt], + image=2 * [low_res_image], + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + ) + image = output.images + assert image.shape[0] == 2 + + generator = torch.Generator(device=device).manual_seed(0) + output = sd_pipe( + [prompt], + image=low_res_image, + generator=generator, + num_images_per_prompt=2, + guidance_scale=6.0, + noise_level=20, + num_inference_steps=2, + output_type="np", + ) + image = output.images + assert image.shape[0] == 2 + @unittest.skipIf(torch_device != "cuda", "This test requires a GPU") def test_stable_diffusion_upscale_fp16(self): """Test that stable diffusion upscale works with fp16"""