@@ -75,6 +75,16 @@
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
     deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
@@ -561,11 +571,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
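Why this helper is needed: `AutoencoderKL.encode()` returns an output object whose `latent_dist` must be sampled (optionally with a seeded generator), while `AutoencoderTiny.encode()` returns ready-made latents under `.latents`. `retrieve_latents` lets every pipeline accept either VAE. A minimal sketch of both paths, assuming a diffusers version that ships `AutoencoderTiny` (0.20+); the tensor shape and tiny model configs are illustrative only:

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

# Same helper as in the diff above, repeated so the sketch is self-contained.
def retrieve_latents(encoder_output, generator):
    if hasattr(encoder_output, "latent_dist"):
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")

image = torch.randn(1, 3, 64, 64)
generator = torch.Generator().manual_seed(0)

kl_vae = AutoencoderKL(in_channels=3, out_channels=3, latent_channels=4)
tiny_vae = AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)

# AutoencoderKL output exposes .latent_dist, so the helper draws a sample.
kl_latents = retrieve_latents(kl_vae.encode(image), generator)
# AutoencoderTiny output exposes .latents, so the helper returns them as-is.
tiny_latents = retrieve_latents(tiny_vae.encode(image), generator)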
15 changes: 13 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -91,6 +91,16 @@
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def prepare_image(image):
     if isinstance(image, torch.Tensor):
         # Batch single image
@@ -733,11 +743,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
14 changes: 12 additions & 2 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -103,6 +103,16 @@
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.prepare_mask_and_masked_image
 def prepare_mask_and_masked_image(image, mask, height, width, return_image=False):
     """
@@ -949,12 +959,12 @@ def prepare_mask_latents(
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         image_latents = self.vae.config.scaling_factor * image_latents
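Note the calling convention `_encode_vae_image` preserves: `generator` may be a single `torch.Generator` or a list with one generator per batch element, in which case image `i` is encoded against its own RNG stream and stays reproducible regardless of how the batch is composed. A small sketch of the list form, with the pipeline call left commented since it assumes an already-loaded pipeline and inputs:

import torch

# One seeded generator per batch element; element i is always tied to seed i.
seeds = [0, 1, 2, 3]
generators = [torch.Generator("cpu").manual_seed(s) for s in seeds]

# pipe(...) stands for any diffusers pipeline that accepts a generator list, e.g.:
# images = pipe(prompt=prompts, image=init_images, generator=generators).images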
@@ -131,6 +131,16 @@
 """
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionXLControlNetImg2ImgPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin
 ):
@@ -806,11 +816,12 @@ def prepare_latents(
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)
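In the SDXL variants the encode step sits next to a `force_upcast` branch: the SDXL VAE is known to overflow in float16, so when `vae.config.force_upcast` is set the pipeline temporarily runs the VAE in float32 before restoring the working dtype. A rough sketch of that pattern, not the pipeline's literal code; it reuses `retrieve_latents` from this PR and assumes `vae` and `image` are already in scope:

import torch

def encode_with_upcast(vae, image, generator=None):
    # The fp16 SDXL VAE can overflow; run it in float32 when flagged.
    needs_upcasting = vae.dtype == torch.float16 and vae.config.force_upcast
    if needs_upcasting:
        vae = vae.to(torch.float32)
        image = image.float()
    latents = retrieve_latents(vae.encode(image), generator=generator)
    if needs_upcasting:
        vae = vae.to(torch.float16)
    return latents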
@@ -43,6 +43,16 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -426,11 +436,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
            ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
@@ -34,6 +34,16 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def prepare_mask_and_masked_image(image, mask):
     """
     Prepares a pair (image, mask) to be consumed by the Paint by Example pipeline. This means that those inputs will be
@@ -334,12 +344,12 @@ def prepare_mask_latents(
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         image_latents = self.vae.config.scaling_factor * image_latents
@@ -36,6 +36,16 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
 def preprocess(image):
     deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
@@ -466,11 +476,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
@@ -73,6 +73,15 @@
 """
 
 
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 def preprocess(image):
     deprecation_message = "The preprocess method is deprecated and will be removed in diffusers 1.0.0. Please use VaeImageProcessor.preprocess(...) instead"
     deprecate("preprocess", "1.0.0", deprecation_message, standard_warn=False)
@@ -555,11 +564,12 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         init_latents = self.vae.config.scaling_factor * init_latents
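This hunk is the canonical definition: it carries no `# Copied from` header because it lives in `pipeline_stable_diffusion_img2img.py`, the module every other `retrieve_latents` copy in this PR points back to. The `# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents` markers are what diffusers' `make fix-copies` tooling uses to keep the duplicates in sync with this one.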
@@ -159,6 +159,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionInpaintPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
 ):
@@ -654,12 +664,12 @@ def prepare_latents(
     def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         image_latents = self.vae.config.scaling_factor * image_latents
@@ -92,6 +92,16 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     return noise_cfg
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionXLImg2ImgPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, FromSingleFileMixin, StableDiffusionXLLoraLoaderMixin
 ):
@@ -604,11 +614,12 @@ def prepare_latents(
 
         elif isinstance(generator, list):
             init_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
+                for i in range(batch_size)
             ]
             init_latents = torch.cat(init_latents, dim=0)
         else:
-            init_latents = self.vae.encode(image).latent_dist.sample(generator)
+            init_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)
@@ -238,6 +238,16 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(encoder_output, generator):
+    if hasattr(encoder_output, "latent_dist"):
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class StableDiffusionXLInpaintPipeline(
     DiffusionPipeline, TextualInversionLoaderMixin, StableDiffusionXLLoraLoaderMixin, FromSingleFileMixin
 ):
@@ -750,12 +760,12 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
 
         if isinstance(generator, list):
             image_latents = [
-                self.vae.encode(image[i : i + 1]).latent_dist.sample(generator=generator[i])
+                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                 for i in range(image.shape[0])
            ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            image_latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)
 
         if self.vae.config.force_upcast:
             self.vae.to(dtype)
21 changes: 21 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -24,6 +24,7 @@
 
 from diffusers import (
     AutoencoderKL,
+    AutoencoderTiny,
     DDIMScheduler,
     DPMSolverMultistepScheduler,
     HeunDiscreteScheduler,
@@ -148,6 +149,9 @@ def get_dummy_components(self):
         }
         return components
 
+    def get_dummy_tiny_autoencoder(self):
+        return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
+
     def get_dummy_inputs(self, device, seed=0):
         image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
         image = image / 2 + 0.5
@@ -236,6 +240,23 @@ def test_stable_diffusion_img2img_k_lms(self):
 
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
+    def test_stable_diffusion_img2img_tiny_autoencoder(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionImg2ImgPipeline(**components)
+        sd_pipe.vae = self.get_dummy_tiny_autoencoder()
+        sd_pipe = sd_pipe.to(device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        image = sd_pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 32, 32, 3)
+        expected_slice = np.array([0.00669, 0.00669, 0.0, 0.00693, 0.00858, 0.0, 0.00567, 0.00515, 0.00125])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
+
     @skip_mps
     def test_save_load_local(self):
         return super().test_save_load_local()
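End to end, the change lets a distilled VAE be dropped into any of these img2img/inpaint pipelines. A usage sketch under the assumption of a CUDA machine; the checkpoint IDs (`runwayml/stable-diffusion-v1-5`, `madebyollin/taesd`) and the input image URL are the ones commonly used in diffusers examples, not anything mandated by this PR:

import torch
from diffusers import AutoencoderTiny, StableDiffusionImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# Swap in the Tiny AutoEncoder; prepare_latents now routes through retrieve_latents,
# which reads .latents instead of sampling .latent_dist.
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)

init_image = load_image(
    "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
).resize((768, 512))
image = pipe("a fantasy landscape, trending on artstation", image=init_image, strength=0.75).images[0]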