diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py index d9acf9daf2a6..343d2132cd83 100644 --- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py @@ -75,6 +75,16 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" @@ -561,11 +571,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 83eafc10407b..b692d936c5b7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -91,6 +91,16 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def prepare_image(image): if isinstance(image, torch.Tensor): # Batch single image @@ -733,11 +743,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index d03a8eea7f6d..3e0b07bf1694 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -103,6 +103,16 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image def prepare_mask_and_masked_image(image, mask, height, width, return_image=False): """ @@ -949,12 +959,12 @@ def prepare_mask_latents( def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = self.vae.config.scaling_factor * image_latents diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 78c7c1cd8dbf..06b4f6f55ec1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -131,6 +131,16 @@ """ +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class StableDiffusionXLControlNetImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin ): @@ -806,11 +816,12 @@ def prepare_latents( elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 9697ede0e1aa..d2764295b719 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -43,6 +43,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + EXAMPLE_DOC_STRING = """ Examples: ```py @@ -426,11 +436,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py index a782caa55efc..ffcbacc3b74e 100644 --- a/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +++ b/src/diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py @@ -34,6 +34,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def prepare_mask_and_masked_image(image, mask): """ Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be @@ -334,12 +344,12 @@ def prepare_mask_latents( def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = self.vae.config.scaling_factor * image_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py index 2acdc1c5296f..f36d7471084f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py @@ -36,6 +36,16 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" @@ -466,11 +476,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index a6c25987a7fd..82dad3b9a697 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -73,6 +73,15 @@ """ +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + def preprocess(image): deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead" deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False) @@ -555,11 +564,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) init_latents = self.vae.config.scaling_factor * init_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 62f1186eb05b..a8baf606b0df 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -159,6 +159,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool return mask, masked_image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class StableDiffusionInpaintPipeline( DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin ): @@ -654,12 +664,12 @@ def prepare_latents( def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) image_latents = self.vae.config.scaling_factor * image_latents diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 57d00af82106..bb2a26f162ee 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -92,6 +92,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): return noise_cfg +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class StableDiffusionXLImg2ImgPipeline( DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin ): @@ -604,11 +614,12 @@ def prepare_latents( elif isinstance(generator, list): init_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) + for i in range(batch_size) ] init_latents = torch.cat(init_latents, dim=0) else: - init_latents = self.vae.encode(image).latent_dist.sample(generator) + init_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 11ae0a0d85f0..cf6a4a37ac47 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -238,6 +238,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool return mask, masked_image +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents +def retrieve_latents(encoder_output, generator): + if hasattr(encoder_output, "latent_dist"): + return encoder_output.latent_dist.sample(generator) + elif hasattr(encoder_output, "latents"): + return encoder_output.latents + else: + raise AttributeError("Could not access latents of provided encoder_output") + + class StableDiffusionXLInpaintPipeline( DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin ): @@ -750,12 +760,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator): if isinstance(generator, list): image_latents = [ - self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i]) + retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) ] image_latents = torch.cat(image_latents, dim=0) else: - image_latents = self.vae.encode(image).latent_dist.sample(generator=generator) + image_latents = retrieve_latents(self.vae.encode(image), generator=generator) if self.vae.config.force_upcast: self.vae.to(dtype) diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py index be8f067b1b78..d136ec3e57bc 100644 --- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py +++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py @@ -24,6 +24,7 @@ from diffusers import ( AutoencoderKL, + AutoencoderTiny, DDIMScheduler, DPMSolverMultistepScheduler, HeunDiscreteScheduler, @@ -148,6 +149,9 @@ def get_dummy_components(self): } return components + def get_dummy_tiny_autoencoder(self): + return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4) + def get_dummy_inputs(self, device, seed=0): image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) image = image / 2 + 0.5 @@ -236,6 +240,23 @@ def test_stable_diffusion_img2img_k_lms(self): assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + def test_stable_diffusion_img2img_tiny_autoencoder(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionImg2ImgPipeline(**components) + sd_pipe.vae = self.get_dummy_tiny_autoencoder() + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1] + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.00669, 0.00669, 0.0, 0.00693, 0.00858, 0.0, 0.00567, 0.00515, 0.00125]) + + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3 + @skip_mps def test_save_load_local(self): return super().test_save_load_local() diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index 97c19108947f..651bf508b8c6 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -22,6 +22,7 @@ from diffusers import ( AutoencoderKL, + AutoencoderTiny, EulerDiscreteScheduler, StableDiffusionXLImg2ImgPipeline, UNet2DConditionModel, @@ -121,6 +122,9 @@ def get_dummy_components(self, skip_first_text_encoder=False): } return components + def get_dummy_tiny_autoencoder(self): + return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4) + def test_components_function(self): init_components = self.get_dummy_components() init_components.pop("requires_aesthetics_score") @@ -216,6 +220,23 @@ def test_stable_diffusion_xl_img2img_negative_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_xl_img2img_tiny_autoencoder(self): + device = "cpu" # ensure determinism for the device-dependent torch.Generator + components = self.get_dummy_components() + sd_pipe = StableDiffusionXLImg2ImgPipeline(**components) + sd_pipe.vae = self.get_dummy_tiny_autoencoder() + sd_pipe = sd_pipe.to(device) + sd_pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = sd_pipe(**inputs).images + image_slice = image[0, -3:, -3:, -1].flatten() + + assert image.shape == (1, 32, 32, 3) + expected_slice = np.array([0.0, 0.0, 0.0106, 0.0, 0.0, 0.0087, 0.0052, 0.0062, 0.0177]) + + assert np.allclose(image_slice, expected_slice, atol=1e-4, rtol=1e-4) + @require_torch_gpu def test_stable_diffusion_xl_offloads(self): pipes = []