From e4d82f27b8ffaa82f85ddcbe3b386c4da49dabe8 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 11:06:27 -0700 Subject: [PATCH 01/20] Added explanation of 'strength' parameter --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 518a9a3e9781..f6f5a51c69ec 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -710,6 +710,13 @@ def __call__( The height in pixels of the generated image. width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. + strength (`float`, *optional*, defaults to 1.): + Conceptually, indicates how much to transform the masked portion of the reference `image`. + Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the + larger the `strength`. The number of denoising steps depends on the amount of noise initially + added. When `strength` is 1, added noise will be maximum and the denoising process will run for + the full number of iterations specified in `num_inference_steps`. A value of 1, therefore, + essentially ignores the masked portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. 
From a0a04b805286f8b42d2fe5d6770d6d85e490f300 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 11:09:54 -0700 Subject: [PATCH 02/20] Added get_timesteps function which relies on new strength parameter --- .../pipeline_stable_diffusion_inpaint.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index f6f5a51c69ec..d55e5e931b66 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -668,6 +668,16 @@ def prepare_mask_latents( # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + def get_timesteps(self, num_inference_steps, strength): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start @torch.no_grad() def __call__( @@ -844,8 +854,7 @@ def __call__( mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) # 5. set timesteps - self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps = self.scheduler.timesteps + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels From 43fce9364e29e0f6a030457fbe5069c85723a54b Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 11:10:12 -0700 Subject: [PATCH 03/20] Added `strength` parameter which defaults to 1. --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index d55e5e931b66..039e0d54d5b8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -687,6 +687,7 @@ def __call__( mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, height: Optional[int] = None, width: Optional[int] = None, + strength: float = 1., num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, From 417604718a09438ea6e0cc20acf4e273d2075ed8 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 11:18:30 -0700 Subject: [PATCH 04/20] Swapped ordering so `noise_timestep` can be calculated before masking the image this is required when you aren't applying 100% noise to the masked region, e.g. strength < 1. --- .../pipeline_stable_diffusion_inpaint.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 039e0d54d5b8..c31270898bc9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -851,11 +851,14 @@ def __call__( negative_prompt_embeds=negative_prompt_embeds, ) - # 4. 
Preprocess mask and image - resizes image and mask w.r.t height and width - mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width) - - # 5. set timesteps + # 4. set timesteps timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength) + # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) + noise_timestep = timesteps[0:1] + noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + + # 5. Preprocess mask and image + mask, masked_image = prepare_mask_and_masked_image(image, mask_image) # 6. Prepare latent variables num_channels_latents = self.vae.config.latent_channels From db47974bb7811cc713c80cfa119e3526af2fd5f6 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 11:34:01 -0700 Subject: [PATCH 05/20] Added strength to check_inputs, throws error if out of range --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index c31270898bc9..3599131359af 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -558,11 +558,15 @@ def check_inputs( prompt, height, width, + strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None, ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -820,6 +824,7 @@ def __call__( prompt, height, width, + strength, callback_steps, negative_prompt, prompt_embeds, From 81660d0845bc73fe4e543d58300994ca5018eb81 Mon Sep 17 00:00:00 
2001 From: Rupert Menneer Date: Fri, 12 May 2023 14:29:54 -0700 Subject: [PATCH 06/20] Changed `prepare_latents` to initialise latents w.r.t strength inspired from the stable diffusion img2img pipeline, init latents are initialised by converting the init image into a VAE latent and adding noise (based upon the strength parameter passed in), e.g. random when strength = 1, or the init image at strength = 0. --- .../pipeline_stable_diffusion_inpaint.py | 30 +++++++++++++++---- 1 file changed, 24 insertions(+), 6 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3599131359af..2b8afc93d5a7 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -146,7 +146,7 @@ def prepare_mask_and_masked_image(image, mask, height, width): masked_image = image * (mask < 0.5) - return mask, masked_image + return image, mask, masked_image class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -605,7 +605,7 @@ def check_inputs( ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -614,12 +614,28 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) 
+ # initialise latents as image + noise + + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = self.scheduler.add_noise(image_latents, noise, timestep) + # latents = noise else: latents = latents.to(device) + # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + return latents def prepare_mask_latents( @@ -857,17 +873,19 @@ def __call__( ) # 4. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) - noise_timestep = timesteps[0:1] - noise_timestep = noise_timestep.repeat(batch_size * num_images_per_prompt) + latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 5. Preprocess mask and image - mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + image, mask, masked_image = prepare_mask_and_masked_image(image, mask_image) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( + image, + latent_timestep, batch_size * num_images_per_prompt, num_channels_latents, height, From 73b2d2080147a8eb13f5fedf9fdffc11ab56526f Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 15:05:45 -0700 Subject: [PATCH 07/20] WIP: Added a unit test for the new strength parameter in the StableDiffusionInpaintingPipeline still need to add correct regression values --- .../test_stable_diffusion_inpaint.py | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index cdf138c4e178..27d055d3d2dd 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -325,6 +325,27 @@ def test_stable_diffusion_inpaint_pil_input_resolution_test(self): # verify that the returned image has the same height and width as the input height and width assert image.shape == (1, inputs["height"], inputs["width"], 3) +def test_stable_diffusion_inpaint_strength_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input strength + inputs["strength"] = 0.1 + inputs["height"] = 128 + inputs["width"] = 128 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, inputs["height"], inputs["width"], 3) + + image_slice = image[0, 253:256, 253:256, -1].flatten() + expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 
0.1812, 0.3256, 0.3311, 0.3272]) + assert np.abs(expected_slice - image_slice).max() < 3e-3 @nightly @require_torch_gpu From d900fb83e51d453d7e289a91997252e4236cdaa3 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 16:11:05 -0700 Subject: [PATCH 08/20] Created a is_strength_max to initialise from pure random noise --- .../pipeline_stable_diffusion_inpaint.py | 40 +++++++++++-------- 1 file changed, 23 insertions(+), 17 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 2b8afc93d5a7..514d699885b2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -146,7 +146,7 @@ def prepare_mask_and_masked_image(image, mask, height, width): masked_image = image * (mask < 0.5) - return image, mask, masked_image + return mask, masked_image, image class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -605,7 +605,7 @@ def check_inputs( ) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents - def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, is_strength_max, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -614,21 +614,24 @@ def prepare_latents(self, image, timestep, batch_size, num_channels_latents, hei ) if latents is None: - # initialise latents as image + noise - - image = image.to(device=device, dtype=dtype) - if 
isinstance(generator, list): - image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) - for i in range(batch_size) - ] - else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) - - image_latents = self.vae.config.scaling_factor * image_latents noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = self.scheduler.add_noise(image_latents, noise, timestep) - # latents = noise + if is_strength_max: + # if strength is 100% then simply initialise the latents to noise + latents = noise + else: + # otherwise initialise latents as init image + noise + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) @@ -877,9 +880,11 @@ def __call__( timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise + is_strength_max = strength == 1. # 5. Preprocess mask and image - image, mask, masked_image = prepare_mask_and_masked_image(image, mask_image) + mask, masked_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels @@ -893,6 +898,7 @@ def __call__( prompt_embeds.dtype, device, generator, + is_strength_max, latents, ) From 4fe9a2625a58058adc11ea33a2390221f831de63 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 16:29:05 -0700 Subject: [PATCH 09/20] Updated unit tests w.r.t new strength parameter + fixed new strength unit test --- .../test_stable_diffusion_inpaint.py | 86 +++++++++++-------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 27d055d3d2dd..344453575ece 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -325,27 +325,25 @@ def test_stable_diffusion_inpaint_pil_input_resolution_test(self): # verify that the returned image has the same height and width as the input height and width assert image.shape == (1, inputs["height"], inputs["width"], 3) -def test_stable_diffusion_inpaint_strength_test(self): - pipe = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-inpainting", safety_checker=None - ) - pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - inputs = self.get_inputs(torch_device) - # change input strength - inputs["strength"] = 0.1 - inputs["height"] = 128 - inputs["width"] = 128 - image = pipe(**inputs).images - # verify that the returned image has the same height and width as the input height and width - assert image.shape == (1, inputs["height"], inputs["width"], 3) - - image_slice = image[0, 253:256, 253:256, -1].flatten() - expected_slice = np.array([0.0425, 0.0273, 0.0344, 0.1694, 0.1727, 0.1812, 0.3256, 0.3311, 0.3272]) - assert np.abs(expected_slice - 
image_slice).max() < 3e-3 + def test_stable_diffusion_inpaint_strength_test(self): + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input strength + inputs["strength"] = 0.75 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, 253:256, 253:256, -1].flatten() + expected_slice = np.array([0.04033301, 0.20779768, 0.33687602, 0.11189868, 0.07587512, 0.16396358, 0.13988009, 0.01543523, 0.06719428]) + assert np.abs(expected_slice - image_slice).max() < 3e-3 @nightly @require_torch_gpu @@ -449,24 +447,30 @@ def test_pil_inputs(self): mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5 mask = Image.fromarray((mask * 255).astype(np.uint8)) - t_mask, t_masked = prepare_mask_and_masked_image(im, mask, height, width) + t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width) self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor)) + self.assertTrue(isinstance(t_image, torch.Tensor)) self.assertEqual(t_mask.ndim, 4) self.assertEqual(t_masked.ndim, 4) + self.assertEqual(t_image.ndim, 4) self.assertEqual(t_mask.shape, (1, 1, height, width)) self.assertEqual(t_masked.shape, (1, 3, height, width)) + self.assertEqual(t_image.shape, (1, 3, height, width)) self.assertTrue(t_mask.dtype == torch.float32) self.assertTrue(t_masked.dtype == torch.float32) + self.assertTrue(t_image.dtype == torch.float32) self.assertTrue(t_mask.min() >= 0.0) self.assertTrue(t_mask.max() <= 1.0) self.assertTrue(t_masked.min() >= -1.0) self.assertTrue(t_masked.min() <= 1.0) + 
self.assertTrue(t_image.min() >= -1.0) + self.assertTrue(t_image.min() >= -1.0) self.assertTrue(t_mask.sum() > 0.0) @@ -489,11 +493,12 @@ def test_np_inputs(self): ) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) - t_mask_pil, t_masked_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) + self.assertTrue((t_image_np == t_image_pil).all()) def test_torch_3D_2D_inputs(self): height, width = 32, 32 @@ -523,13 +528,14 @@ def test_torch_3D_2D_inputs(self): im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_3D_3D_inputs(self): height, width = 32, 32 @@ -560,13 +566,14 @@ def test_torch_3D_3D_inputs(self): im_np = im_tensor.numpy().transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + 
t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_2D_inputs(self): height, width = 32, 32 @@ -597,13 +604,14 @@ def test_torch_4D_2D_inputs(self): im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy() - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_3D_inputs(self): height, width = 32, 32 @@ -635,13 +643,14 @@ def test_torch_4D_3D_inputs(self): im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_4D_4D_inputs(self): height, width = 32, 32 @@ -674,13 +683,14 @@ def test_torch_4D_4D_inputs(self): im_np = im_tensor.numpy()[0].transpose(1, 2, 0) mask_np = mask_tensor.numpy()[0][0] - t_mask_tensor, t_masked_tensor = 
prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) - t_mask_np, t_masked_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_batch_4D_3D(self): height, width = 32, 32 @@ -713,15 +723,17 @@ def test_torch_batch_4D_3D(self): im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy() for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) + t_image_np = torch.cat([n[2] for n in nps]) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_torch_batch_4D_4D(self): height, width = 32, 32 @@ -755,15 +767,17 @@ def test_torch_batch_4D_4D(self): im_nps = [im.numpy().transpose(1, 2, 0) for im in im_tensor] mask_nps = [mask.numpy()[0] for mask in mask_tensor] - t_mask_tensor, t_masked_tensor = prepare_mask_and_masked_image( + t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width ) nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) + t_image_np = torch.cat([n[2] for n in nps]) 
self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) + self.assertTrue((t_image_tensor == t_image_np).all()) def test_shape_mismatch(self): height, width = 32, 32 From 8aa9489ae6a6def982e3e5d367de40ab8defbc62 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 16:29:26 -0700 Subject: [PATCH 10/20] renamed parameter to avoid confusion with variable of same name --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 514d699885b2..9d377eeb54d8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -884,12 +884,12 @@ def __call__( is_strength_max = strength == 1. # 5. Preprocess mask and image - mask, masked_image, image = prepare_mask_and_masked_image(image, mask_image, height, width) + mask, masked_image, init_image = prepare_mask_and_masked_image(image, mask_image, height, width) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( - image, + init_image, latent_timestep, batch_size * num_images_per_prompt, num_channels_latents, From cd3101bde1c18df8010ba17af515a51699e61d87 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 16:41:39 -0700 Subject: [PATCH 11/20] Updated regression values for new strength test - now passes --- .../pipelines/stable_diffusion/test_stable_diffusion_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 344453575ece..673d953e2c78 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -342,7 +342,7 @@ def test_stable_diffusion_inpaint_strength_test(self): assert image.shape == (1, 512, 512, 3) image_slice = image[0, 253:256, 253:256, -1].flatten() - expected_slice = np.array([0.04033301, 0.20779768, 0.33687602, 0.11189868, 0.07587512, 0.16396358, 0.13988009, 0.01543523, 0.06719428]) + expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) assert np.abs(expected_slice - image_slice).max() < 3e-3 @nightly From aca884f9ffbaa3ca9ee01ec74e51e667109114a3 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Fri, 12 May 2023 16:46:02 -0700 Subject: [PATCH 12/20] removed 'copied from' comment as this method is now different and divergent from the copy --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 9d377eeb54d8..11ffef62cae1 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ 
b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -604,7 +604,6 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, is_strength_max, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: From f245d6e6517ffe17352500e1601a75d296054fe5 Mon Sep 17 00:00:00 2001 From: Rupert Menneer <71332436+rupertmenneer@users.noreply.github.com> Date: Tue, 16 May 2023 15:11:16 -0700 Subject: [PATCH 13/20] Update src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py Co-authored-by: Patrick von Platen --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 11ffef62cae1..87d4becbb172 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -691,7 +691,7 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if_img2img.IFImg2ImgPipeline.get_timesteps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stabe_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) From 
60c1a357da3d2e73d68dba5df7929d6c1e2a326d Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Tue, 16 May 2023 15:18:42 -0700 Subject: [PATCH 14/20] Ensure backwards compatibility for prepare_mask_and_masked_image created a return_image boolean and initialised to false --- .../pipeline_stable_diffusion_inpaint.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 87d4becbb172..af822357bb21 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -36,7 +36,7 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -def prepare_mask_and_masked_image(image, mask, height, width): +def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool = False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -146,7 +146,11 @@ def prepare_mask_and_masked_image(image, mask, height, width): masked_image = image * (mask < 0.5) - return mask, masked_image, image + # n.b. ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + + return mask, masked_image class StableDiffusionInpaintPipeline(DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin): @@ -883,7 +887,7 @@ def __call__( is_strength_max = strength == 1. # 5. Preprocess mask and image - mask, masked_image, init_image = prepare_mask_and_masked_image(image, mask_image, height, width) + mask, masked_image, init_image = prepare_mask_and_masked_image(image, mask_image, height, width, return_image=True) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels From b0f874b6a6d468f5b998a5cfdfbb4f4598070973 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Tue, 16 May 2023 15:35:55 -0700 Subject: [PATCH 15/20] Ensure backwards compatibility for prepare_latents --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index af822357bb21..7d1f64ab7fc8 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -608,13 +608,19 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def prepare_latents(self, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator, is_strength_max, latents=None): + def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, image=None, timestep=None, is_strength_max=True, latents=None): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + f"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + f"However, either the image or the noise timestep has not been provided." 
+ ) if latents is None: noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) From 46583cc2fff0d68df0c9a51afc2124200e75eaa9 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Tue, 16 May 2023 15:36:09 -0700 Subject: [PATCH 16/20] Fixed copy check typo --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 7d1f64ab7fc8..1aecbfae6761 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -701,7 +701,7 @@ def prepare_mask_latents( masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stabe_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) From c14ecc69ca6cc6bd83c12889b4342fa797f820a8 Mon Sep 17 00:00:00 2001 From: Rupert Menneer Date: Tue, 16 May 2023 16:07:14 -0700 Subject: [PATCH 17/20] Fixes w.r.t backward compibility changes --- .../pipeline_stable_diffusion_inpaint.py | 15 +++---- .../test_stable_diffusion_inpaint.py | 44 ++++++++++++------- 2 files changed, 34 insertions(+), 25 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 1aecbfae6761..3d219d441b6f 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -556,7 +556,6 @@ def decode_latents(self, latents): image = image.cpu().permute(0, 2, 3, 1).float().numpy() return image - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs def check_inputs( self, prompt, @@ -702,12 +701,12 @@ def prepare_mask_latents( return mask, masked_image_latents # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps - def get_timesteps(self, num_inference_steps, strength): + def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] return timesteps, num_inference_steps - t_start @@ -886,7 +885,7 @@ def __call__( # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength, device=device) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise @@ -898,8 +897,6 @@ def __call__( # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels latents = self.prepare_latents( - init_image, - latent_timestep, batch_size * num_images_per_prompt, num_channels_latents, height, @@ -907,8 +904,10 @@ def __call__( prompt_embeds.dtype, device, generator, - is_strength_max, - latents, + image=init_image, + timestep=latent_timestep, + is_strength_max=is_strength_max, + latents=latents, ) # 7. Prepare mask latent variables diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index 673d953e2c78..cc034c213b5e 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -447,7 +447,7 @@ def test_pil_inputs(self): mask = np.random.randint(0, 255, (height, width), dtype=np.uint8) > 127.5 mask = Image.fromarray((mask * 255).astype(np.uint8)) - t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width) + t_mask, t_masked, t_image = prepare_mask_and_masked_image(im, mask, height, width, return_image=True) self.assertTrue(isinstance(t_mask, torch.Tensor)) self.assertTrue(isinstance(t_masked, torch.Tensor)) @@ -493,8 +493,8 @@ def test_np_inputs(self): ) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) - t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) + t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width, return_image=True) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) @@ -529,9 +529,9 @@ def test_torch_3D_2D_inputs(self): mask_np = mask_tensor.numpy() 
t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -567,9 +567,9 @@ def test_torch_3D_3D_inputs(self): mask_np = mask_tensor.numpy()[0] t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -605,9 +605,9 @@ def test_torch_4D_2D_inputs(self): mask_np = mask_tensor.numpy() t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -644,9 +644,9 @@ def test_torch_4D_3D_inputs(self): mask_np = mask_tensor.numpy()[0] t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + 
im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -684,9 +684,9 @@ def test_torch_4D_4D_inputs(self): mask_np = mask_tensor.numpy()[0][0] t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -724,9 +724,9 @@ def test_torch_batch_4D_3D(self): mask_nps = [mask.numpy() for mask in mask_tensor] t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - nps = [prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] + nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) t_image_np = torch.cat([n[2] for n in nps]) @@ -768,9 +768,9 @@ def test_torch_batch_4D_4D(self): mask_nps = [mask.numpy()[0] for mask in mask_tensor] t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( - im_tensor / 127.5 - 1, mask_tensor, height, width + im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - nps = 
[prepare_mask_and_masked_image(i, m, height, width) for i, m in zip(im_nps, mask_nps)] + nps = [prepare_mask_and_masked_image(i, m, height, width, return_image=True) for i, m in zip(im_nps, mask_nps)] t_mask_np = torch.cat([n[0] for n in nps]) t_masked_np = torch.cat([n[1] for n in nps]) t_image_np = torch.cat([n[2] for n in nps]) @@ -793,6 +793,7 @@ def test_shape_mismatch(self): torch.randn(64, 64), height, width, + return_image=True ) # test batch dim with self.assertRaises(AssertionError): @@ -806,6 +807,7 @@ def test_shape_mismatch(self): torch.randn(4, 64, 64), height, width, + return_image=True ) # test batch dim with self.assertRaises(AssertionError): @@ -819,6 +821,7 @@ def test_shape_mismatch(self): torch.randn(4, 1, 64, 64), height, width, + return_image=True ) def test_type_mismatch(self): @@ -839,6 +842,7 @@ def test_type_mismatch(self): ).numpy(), height, width, + return_image=True ) # test tensors-only with self.assertRaises(TypeError): @@ -855,6 +859,7 @@ def test_type_mismatch(self): ), height, width, + return_image=True ) def test_channels_first(self): @@ -871,6 +876,7 @@ def test_channels_first(self): ), height, width, + return_image=True ) def test_tensor_range(self): @@ -891,6 +897,7 @@ def test_tensor_range(self): ), height, width, + return_image=True ) # test im >= -1 with self.assertRaises(ValueError): @@ -907,6 +914,7 @@ def test_tensor_range(self): ), height, width, + return_image=True ) # test mask <= 1 with self.assertRaises(ValueError): @@ -923,6 +931,7 @@ def test_tensor_range(self): * 2, height, width, + return_image=True ) # test mask >= 0 with self.assertRaises(ValueError): @@ -939,4 +948,5 @@ def test_tensor_range(self): * -1, height, width, + return_image=True ) From dc65be6bb85525e48827f4fdebb2f981af15a765 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 16 May 2023 17:13:35 -0700 Subject: [PATCH 18/20] make style --- .../pipeline_stable_diffusion_inpaint.py | 52 ++++++++---- .../test_stable_diffusion_inpaint.py | 85 
+++++++++++-------- 2 files changed, 84 insertions(+), 53 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 3d219d441b6f..87fbe77f1e74 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -149,7 +149,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool # n.b. ensure backwards compatibility as old function does not return image if return_image: return mask, masked_image, image - + return mask, masked_image @@ -607,18 +607,31 @@ def check_inputs( f" {negative_prompt_embeds.shape}." ) - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, image=None, timestep=None, is_strength_max=True, latents=None): + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + image=None, + timestep=None, + is_strength_max=True, + latents=None, + ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) - + if (image is None or timestep is None) and not is_strength_max: raise ValueError( - f"Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." - f"However, either the image or the noise timestep has not been provided." + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." 
) if latents is None: @@ -638,12 +651,11 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) image_latents = self.vae.config.scaling_factor * image_latents - + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) - # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma @@ -699,7 +711,7 @@ def prepare_mask_latents( # aligning device to prevent device errors when concating it with the latent model input masked_image_latents = masked_image_latents.to(device=device, dtype=dtype) return mask, masked_image_latents - + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps def get_timesteps(self, num_inference_steps, strength, device): # get the original timestep using init_timestep @@ -718,7 +730,7 @@ def __call__( mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None, height: Optional[int] = None, width: Optional[int] = None, - strength: float = 1., + strength: float = 1.0, num_inference_steps: int = 50, guidance_scale: float = 7.5, negative_prompt: Optional[Union[str, List[str]]] = None, @@ -753,12 +765,12 @@ def __call__( width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): The width in pixels of the generated image. strength (`float`, *optional*, defaults to 1.): - Conceptually, indicates how much to transform the masked portion of the reference `image`. - Must be between 0 and 1. `image` will be used as a starting point, adding more noise to it the - larger the `strength`. The number of denoising steps depends on the amount of noise initially - added. When `strength` is 1, added noise will be maximum and the denoising process will run for - the full number of iterations specified in `num_inference_steps`. 
A value of 1, therefore, - essentially ignores the masked portion of the reference `image`. + Conceptually, indicates how much to transform the masked portion of the reference `image`. Must be + between 0 and 1. `image` will be used as a starting point, adding more noise to it the larger the + `strength`. The number of denoising steps depends on the amount of noise initially added. When + `strength` is 1, added noise will be maximum and the denoising process will run for the full number of + iterations specified in `num_inference_steps`. A value of 1, therefore, essentially ignores the masked + portion of the reference `image`. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the expense of slower inference. @@ -885,14 +897,18 @@ def __call__( # 4. set timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps=num_inference_steps, strength=strength, device=device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps=num_inference_steps, strength=strength, device=device + ) # at which timestep to set the initial noise (n.b. 50% if strength is 0.5) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise - is_strength_max = strength == 1. + is_strength_max = strength == 1.0 # 5. Preprocess mask and image - mask, masked_image, init_image = prepare_mask_and_masked_image(image, mask_image, height, width, return_image=True) + mask, masked_image, init_image = prepare_mask_and_masked_image( + image, mask_image, height, width, return_image=True + ) # 6. 
Prepare latent variables num_channels_latents = self.vae.config.latent_channels diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py index cc034c213b5e..9e46270a1543 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py @@ -326,24 +326,25 @@ def test_stable_diffusion_inpaint_pil_input_resolution_test(self): assert image.shape == (1, inputs["height"], inputs["width"], 3) def test_stable_diffusion_inpaint_strength_test(self): - pipe = StableDiffusionInpaintPipeline.from_pretrained( - "runwayml/stable-diffusion-inpainting", safety_checker=None - ) - pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) - pipe.to(torch_device) - pipe.set_progress_bar_config(disable=None) - pipe.enable_attention_slicing() - - inputs = self.get_inputs(torch_device) - # change input strength - inputs["strength"] = 0.75 - image = pipe(**inputs).images - # verify that the returned image has the same height and width as the input height and width - assert image.shape == (1, 512, 512, 3) - - image_slice = image[0, 253:256, 253:256, -1].flatten() - expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) - assert np.abs(expected_slice - image_slice).max() < 3e-3 + pipe = StableDiffusionInpaintPipeline.from_pretrained( + "runwayml/stable-diffusion-inpainting", safety_checker=None + ) + pipe.scheduler = LMSDiscreteScheduler.from_config(pipe.scheduler.config) + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + pipe.enable_attention_slicing() + + inputs = self.get_inputs(torch_device) + # change input strength + inputs["strength"] = 0.75 + image = pipe(**inputs).images + # verify that the returned image has the same height and width as the input height and width + assert image.shape == (1, 512, 512, 3) + + image_slice = image[0, 
253:256, 253:256, -1].flatten() + expected_slice = np.array([0.0021, 0.2350, 0.3712, 0.0575, 0.2485, 0.3451, 0.1857, 0.3156, 0.3943]) + assert np.abs(expected_slice - image_slice).max() < 3e-3 + @nightly @require_torch_gpu @@ -493,8 +494,12 @@ def test_np_inputs(self): ) mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8)) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) - t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image(im_pil, mask_pil, height, width, return_image=True) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) + t_mask_pil, t_masked_pil, t_image_pil = prepare_mask_and_masked_image( + im_pil, mask_pil, height, width, return_image=True + ) self.assertTrue((t_mask_np == t_mask_pil).all()) self.assertTrue((t_masked_np == t_masked_pil).all()) @@ -531,7 +536,9 @@ def test_torch_3D_2D_inputs(self): t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -569,7 +576,9 @@ def test_torch_3D_3D_inputs(self): t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) 
self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -607,7 +616,9 @@ def test_torch_4D_2D_inputs(self): t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -646,7 +657,9 @@ def test_torch_4D_3D_inputs(self): t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -686,7 +699,9 @@ def test_torch_4D_4D_inputs(self): t_mask_tensor, t_masked_tensor, t_image_tensor = prepare_mask_and_masked_image( im_tensor / 127.5 - 1, mask_tensor, height, width, return_image=True ) - t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image(im_np, mask_np, height, width, return_image=True) + t_mask_np, t_masked_np, t_image_np = prepare_mask_and_masked_image( + im_np, mask_np, height, width, return_image=True + ) self.assertTrue((t_mask_tensor == t_mask_np).all()) self.assertTrue((t_masked_tensor == t_masked_np).all()) @@ -793,7 +808,7 @@ def test_shape_mismatch(self): torch.randn(64, 64), height, width, - return_image=True + return_image=True, ) # test batch dim with self.assertRaises(AssertionError): @@ -807,7 +822,7 @@ def test_shape_mismatch(self): torch.randn(4, 64, 64), height, 
width, - return_image=True + return_image=True, ) # test batch dim with self.assertRaises(AssertionError): @@ -821,7 +836,7 @@ def test_shape_mismatch(self): torch.randn(4, 1, 64, 64), height, width, - return_image=True + return_image=True, ) def test_type_mismatch(self): @@ -842,7 +857,7 @@ def test_type_mismatch(self): ).numpy(), height, width, - return_image=True + return_image=True, ) # test tensors-only with self.assertRaises(TypeError): @@ -859,7 +874,7 @@ def test_type_mismatch(self): ), height, width, - return_image=True + return_image=True, ) def test_channels_first(self): @@ -876,7 +891,7 @@ def test_channels_first(self): ), height, width, - return_image=True + return_image=True, ) def test_tensor_range(self): @@ -897,7 +912,7 @@ def test_tensor_range(self): ), height, width, - return_image=True + return_image=True, ) # test im >= -1 with self.assertRaises(ValueError): @@ -914,7 +929,7 @@ def test_tensor_range(self): ), height, width, - return_image=True + return_image=True, ) # test mask <= 1 with self.assertRaises(ValueError): @@ -931,7 +946,7 @@ def test_tensor_range(self): * 2, height, width, - return_image=True + return_image=True, ) # test mask >= 0 with self.assertRaises(ValueError): @@ -948,5 +963,5 @@ def test_tensor_range(self): * -1, height, width, - return_image=True + return_image=True, ) From 26f0c2e6fb6dc5d9a85098a5bc3f5689579e0889 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 16 May 2023 17:34:31 -0700 Subject: [PATCH 19/20] keep function argument ordering same for backwards compatibility in callees with copied from statements --- .../stable_diffusion/pipeline_stable_diffusion_inpaint.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 87fbe77f1e74..78ef11587b4d 100644 --- 
a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -616,10 +616,10 @@ def prepare_latents( dtype, device, generator, + latents=None, image=None, timestep=None, is_strength_max=True, - latents=None, ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: @@ -920,10 +920,10 @@ def __call__( prompt_embeds.dtype, device, generator, + latents, image=init_image, timestep=latent_timestep, is_strength_max=is_strength_max, - latents=latents, ) # 7. Prepare mask latent variables From 934974abc79336889cf33b208b90692460450d12 Mon Sep 17 00:00:00 2001 From: William Berman Date: Tue, 16 May 2023 17:34:41 -0700 Subject: [PATCH 20/20] make fix-copies --- .../controlnet/pipeline_controlnet_inpaint.py | 47 +++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index a146a1cc2908..27475dc5ef8b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -99,7 +99,7 @@ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image -def prepare_mask_and_masked_image(image, mask, height, width): +def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): """ Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the @@ -209,6 +209,10 @@ def prepare_mask_and_masked_image(image, mask, height, width): masked_image = image * (mask < 0.5) + # n.b. 
ensure backwards compatibility as old function does not return image + if return_image: + return mask, masked_image, image + return mask, masked_image @@ -795,7 +799,20 @@ def prepare_control_image( return image # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None): + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + image=None, + timestep=None, + is_strength_max=True, + ): shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -803,13 +820,37 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + if (image is None or timestep is None) and not is_strength_max: + raise ValueError( + "Since strength < 1. initial latents are to be initialised as a combination of Image + Noise." + "However, either the image or the noise timestep has not been provided." 
+ ) + if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + if is_strength_max: + # if strength is 100% then simply initialise the latents to noise + latents = noise + else: + # otherwise initialise latents as init image + noise + image = image.to(device=device, dtype=dtype) + if isinstance(generator, list): + image_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + for i in range(batch_size) + ] + else: + image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + + image_latents = self.vae.config.scaling_factor * image_latents + + latents = self.scheduler.add_noise(image_latents, noise, timestep) else: latents = latents.to(device) # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma + return latents def _default_height_width(self, height, width, image):