diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index 0239c8128171..6444ec7c8506 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -36,6 +36,7 @@
 from ...utils import (
     PIL_INTERPOLATION,
     BaseOutput,
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     logging,
@@ -721,23 +722,31 @@ def prepare_image_latents(self, image, batch_size, dtype, device, generator=None
             )
 
         if isinstance(generator, list):
-            init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
-            ]
-            init_latents = torch.cat(init_latents, dim=0)
+            latents = [self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)]
+            latents = torch.cat(latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
-
-        init_latents = self.vae.config.scaling_factor * init_latents
-
-        if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0:
-            raise ValueError(
-                f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts."
-            )
+            latents = self.vae.encode(image).latent_dist.sample(generator)
+
+        latents = self.vae.config.scaling_factor * latents
+
+        if batch_size != latents.shape[0]:
+            if batch_size % latents.shape[0] == 0:
+                # expand image_latents to match the number of text prompts
+                deprecation_message = (
+                    f"You have passed {batch_size} text prompts (`prompt`), but only {latents.shape[0]} initial"
+                    " images (`image`). Initial images are now being duplicated to match the number of text prompts."
+                    " Note that this behavior is deprecated and will be removed in version 1.0.0. Please make sure to"
+                    " update your script to pass as many initial images as text prompts to suppress this warning."
+                )
+                deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False)
+                additional_latents_per_image = batch_size // latents.shape[0]
+                latents = torch.cat([latents] * additional_latents_per_image, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {latents.shape[0]} to {batch_size} text prompts."
+                )
         else:
-            init_latents = torch.cat([init_latents], dim=0)
-
-        latents = init_latents
+            latents = torch.cat([latents], dim=0)
 
         return latents
 
@@ -759,23 +768,18 @@ def get_epsilon(self, model_output: torch.Tensor, sample: torch.Tensor, timestep
         )
 
     def auto_corr_loss(self, hidden_states, generator=None):
-        batch_size, channel, height, width = hidden_states.shape
-        if batch_size > 1:
-            raise ValueError("Only batch_size 1 is supported for now")
-
-        hidden_states = hidden_states.squeeze(0)
-        # hidden_states must be shape [C,H,W] now
         reg_loss = 0.0
         for i in range(hidden_states.shape[0]):
-            noise = hidden_states[i][None, None, :, :]
-            while True:
-                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
-                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
-
-                if noise.shape[2] <= 8:
-                    break
-                noise = F.avg_pool2d(noise, kernel_size=2)
+            for j in range(hidden_states.shape[1]):
+                noise = hidden_states[i : i + 1, j : j + 1, :, :]
+                while True:
+                    roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
+                    reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
+
+                    if noise.shape[2] <= 8:
+                        break
+                    noise = F.avg_pool2d(noise, kernel_size=2)
         return reg_loss
 
     def kl_divergence(self, hidden_states):
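The `auto_corr_loss` rewrite above drops the old `batch_size == 1` restriction by iterating over the batch and channel dimensions explicitly and keeping each map in `[1, 1, H, W]` form. Below is a minimal standalone sketch of the same pyramid auto-correlation regularizer, for illustration only; the pipeline method is the source of truth, and the `2 x 4 x 64 x 64` demo tensor is an assumed latent shape, not taken from the diff:

```python
import torch
import torch.nn.functional as F


def auto_corr_loss(hidden_states: torch.Tensor, generator=None) -> torch.Tensor:
    """Pyramid auto-correlation regularizer over every (sample, channel) map."""
    reg_loss = 0.0
    for i in range(hidden_states.shape[0]):  # batch dimension (no longer limited to 1)
        for j in range(hidden_states.shape[1]):  # channel dimension
            noise = hidden_states[i : i + 1, j : j + 1, :, :]  # keep a [1, 1, H, W] view
            while True:
                # penalize correlation between the map and a randomly rolled
                # copy of itself, along height (dims=2) and width (dims=3)
                roll_amount = torch.randint(noise.shape[2] // 2, (1,), generator=generator).item()
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=2)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=roll_amount, dims=3)).mean() ** 2
                if noise.shape[2] <= 8:  # stop once the pyramid level is small enough
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)  # move to the next, coarser level
    return reg_loss


# A batch of two 4-channel maps now works; the old implementation raised
# ValueError("Only batch_size 1 is supported for now") for this input.
loss = auto_corr_loss(torch.randn(2, 4, 64, 64), generator=torch.manual_seed(0))
print(loss)
```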
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py
index 0809a91041ce..661926daaa3e 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_pix2pix_zero.py
@@ -14,6 +14,8 @@
 # limitations under the License.
 
 import gc
+import random
+import tempfile
 import unittest
 
 import numpy as np
@@ -30,7 +32,7 @@
     StableDiffusionPix2PixZeroPipeline,
     UNet2DConditionModel,
 )
-from diffusers.utils import load_numpy, slow, torch_device
+from diffusers.utils import floats_tensor, load_numpy, slow, torch_device
 from diffusers.utils.testing_utils import load_image, load_pt, require_torch_gpu, skip_mps
 
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
@@ -69,6 +71,7 @@ def get_dummy_components(self):
             cross_attention_dim=32,
         )
         scheduler = DDIMScheduler()
+        inverse_scheduler = DDIMInverseScheduler()
         torch.manual_seed(0)
         vae = AutoencoderKL(
             block_out_channels=[32, 64],
@@ -101,7 +104,7 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "safety_checker": None,
             "feature_extractor": None,
-            "inverse_scheduler": None,
+            "inverse_scheduler": inverse_scheduler,
             "caption_generator": None,
             "caption_processor": None,
         }
@@ -122,6 +125,90 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def get_dummy_inversion_inputs(self, device, seed=0):
+        dummy_image = floats_tensor((2, 3, 32, 32), rng=random.Random(seed)).to(torch_device)
+        generator = torch.manual_seed(seed)
+
+        inputs = {
+            "prompt": [
+                "A painting of a squirrel eating a burger",
+                "A painting of a burger eating a squirrel",
+            ],
+            "image": dummy_image.cpu(),
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "generator": generator,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_save_load_optional_components(self):
+        if not hasattr(self.pipeline_class, "_optional_components"):
+            return
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        # set all optional components to None and update pipeline config accordingly
+        for optional_component in pipe._optional_components:
+            setattr(pipe, optional_component, None)
+        pipe.register_modules(**{optional_component: None for optional_component in pipe._optional_components})
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output = pipe(**inputs)[0]
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            pipe.save_pretrained(tmpdir)
+            pipe_loaded = self.pipeline_class.from_pretrained(tmpdir)
+            pipe_loaded.to(torch_device)
+            pipe_loaded.set_progress_bar_config(disable=None)
+
+        for optional_component in pipe._optional_components:
+            self.assertTrue(
+                getattr(pipe_loaded, optional_component) is None,
+                f"`{optional_component}` did not stay set to None after loading.",
+            )
+
+        inputs = self.get_dummy_inputs(torch_device)
+        output_loaded = pipe_loaded(**inputs)[0]
+
+        max_diff = np.abs(output - output_loaded).max()
+        self.assertLess(max_diff, 1e-4)
+
+    def test_stable_diffusion_pix2pix_zero_inversion(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inversion_inputs(device)
+        inputs["image"] = inputs["image"][:1]
+        inputs["prompt"] = inputs["prompt"][:1]
+        image = sd_pipe.invert(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.4833, 0.4696, 0.5574, 0.5194, 0.5248, 0.5638, 0.5040, 0.5423, 0.5072])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
+    def test_stable_diffusion_pix2pix_zero_inversion_batch(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionPix2PixZeroPipeline(**components)
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inversion_inputs(device)
+        image = sd_pipe.invert(**inputs).images
+        image_slice = image[1, -3:, -3:, -1]
+        assert image.shape == (2, 32, 32, 3)
+        expected_slice = np.array([0.6672, 0.5203, 0.4908, 0.4376, 0.4517, 0.5544, 0.4605, 0.4826, 0.5007])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     def test_stable_diffusion_pix2pix_zero_default_case(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
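With the new `prepare_image_latents` branch, `invert` can broadcast a single input image across several prompts (emitting the deprecation warning) instead of failing outright, as the batch tests above exercise at the dummy-component level. A hedged end-to-end sketch of that behavior follows; the checkpoint id is a placeholder, the blank input image is a stand-in for a real photo, and the scheduler wiring follows the pattern from the pix2pix-zero documentation:

```python
import numpy as np
import torch
from PIL import Image

from diffusers import DDIMInverseScheduler, DDIMScheduler, StableDiffusionPix2PixZeroPipeline

# placeholder checkpoint; any Stable Diffusion v1-style checkpoint should work
pipe = StableDiffusionPix2PixZeroPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    safety_checker=None,
    caption_generator=None,
    caption_processor=None,
)
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
pipe.inverse_scheduler = DDIMInverseScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# stand-in input image; in practice this would be a real photo
image = Image.fromarray(np.zeros((512, 512, 3), dtype=np.uint8))

# two prompts but only one image: prepare_image_latents now duplicates the
# image latents to match len(prompt) and emits the deprecation warning above
inv = pipe.invert(
    prompt=["a photo of a cat", "a photo of a dog"],
    image=image,
    num_inference_steps=50,
    generator=torch.manual_seed(0),
    output_type="numpy",
)
print(inv.images.shape)  # (2, 512, 512, 3): one inverted image per prompt
```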