Merged
186 changes: 122 additions & 64 deletions examples/community/stable_diffusion_controlnet_img2img.py
@@ -1,7 +1,7 @@
# Inspired by: https://github.com/haofanwang/ControlNet-for-Diffusers/

import inspect
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import PIL.Image
@@ -10,6 +10,7 @@

from diffusers import AutoencoderKL, ControlNetModel, DiffusionPipeline, UNet2DConditionModel, logging
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.utils import (
PIL_INTERPOLATION,
@@ -86,7 +87,14 @@ def prepare_image(image):


def prepare_controlnet_conditioning_image(
controlnet_conditioning_image, width, height, batch_size, num_images_per_prompt, device, dtype
controlnet_conditioning_image,
width,
height,
batch_size,
num_images_per_prompt,
device,
dtype,
do_classifier_free_guidance,
):
if not isinstance(controlnet_conditioning_image, torch.Tensor):
if isinstance(controlnet_conditioning_image, PIL.Image.Image):
@@ -116,6 +124,9 @@ def prepare_controlnet_conditioning_image(

controlnet_conditioning_image = controlnet_conditioning_image.to(device=device, dtype=dtype)

if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

return controlnet_conditioning_image
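
Note: this hunk moves classifier-free-guidance batching out of `__call__` (the corresponding `torch.cat` is removed further down) and into `prepare_controlnet_conditioning_image` itself. A minimal sketch of the effect, with a dummy tensor standing in for a prepared conditioning image:

```python
import torch

# Dummy stand-in for a prepared conditioning image: batch 2, 3 x 512 x 512.
cond = torch.randn(2, 3, 512, 512)

do_classifier_free_guidance = True  # i.e. guidance_scale > 1.0 in the pipeline
if do_classifier_free_guidance:
    # One copy for the unconditional pass and one for the conditional pass,
    # matching the doubled latent batch.
    cond = torch.cat([cond] * 2)

print(cond.shape)  # torch.Size([4, 3, 512, 512])
```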


@@ -132,7 +143,7 @@ def __init__(
text_encoder: CLIPTextModel,
tokenizer: CLIPTokenizer,
unet: UNet2DConditionModel,
controlnet: ControlNetModel,
controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel],
scheduler: KarrasDiffusionSchedulers,
safety_checker: StableDiffusionSafetyChecker,
feature_extractor: CLIPImageProcessor,
@@ -156,6 +167,9 @@ def __init__(
" checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
)

if isinstance(controlnet, (list, tuple)):
controlnet = MultiControlNetModel(controlnet)

self.register_modules(
vae=vae,
text_encoder=text_encoder,
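
Note: since `__init__` now wraps a list or tuple of ControlNets in a `MultiControlNetModel`, the pipeline can be constructed with several nets at once. A usage sketch with illustrative model IDs and the usual community-pipeline loading path (neither is part of this diff):

```python
import torch
from diffusers import ControlNetModel, DiffusionPipeline

canny = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pose = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)

# The new isinstance check converts this list into a MultiControlNetModel
# before register_modules runs.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    controlnet=[canny, pose],
    custom_pipeline="stable_diffusion_controlnet_img2img",
    torch_dtype=torch.float16,
).to("cuda")
```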
@@ -425,6 +439,42 @@ def prepare_extra_step_kwargs(self, generator, eta):
extra_step_kwargs["generator"] = generator
return extra_step_kwargs

def check_controlnet_conditioning_image(self, image, prompt, prompt_embeds):
image_is_pil = isinstance(image, PIL.Image.Image)
image_is_tensor = isinstance(image, torch.Tensor)
image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)

if not image_is_pil and not image_is_tensor and not image_is_pil_list and not image_is_tensor_list:
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)

if image_is_pil:
image_batch_size = 1
elif image_is_tensor:
image_batch_size = image.shape[0]
elif image_is_pil_list:
image_batch_size = len(image)
elif image_is_tensor_list:
image_batch_size = len(image)
else:
raise ValueError("controlnet condition image is not valid")

if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
else:
raise ValueError("prompt or prompt_embeds are not valid")

if image_batch_size != 1 and image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
)

def check_inputs(
self,
prompt,
@@ -439,6 +489,7 @@ def check_inputs(
strength=None,
controlnet_guidance_start=None,
controlnet_guidance_end=None,
controlnet_conditioning_scale=None,
):
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -477,58 +528,51 @@ def check_inputs(
f" {negative_prompt_embeds.shape}."
)

controlnet_cond_image_is_pil = isinstance(controlnet_conditioning_image, PIL.Image.Image)
controlnet_cond_image_is_tensor = isinstance(controlnet_conditioning_image, torch.Tensor)
controlnet_cond_image_is_pil_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], PIL.Image.Image
)
controlnet_cond_image_is_tensor_list = isinstance(controlnet_conditioning_image, list) and isinstance(
controlnet_conditioning_image[0], torch.Tensor
)
# check controlnet condition image

if (
not controlnet_cond_image_is_pil
and not controlnet_cond_image_is_tensor
and not controlnet_cond_image_is_pil_list
and not controlnet_cond_image_is_tensor_list
):
raise TypeError(
"image must be passed and be one of PIL image, torch tensor, list of PIL images, or list of torch tensors"
)
if isinstance(self.controlnet, ControlNetModel):
self.check_controlnet_conditioning_image(controlnet_conditioning_image, prompt, prompt_embeds)
elif isinstance(self.controlnet, MultiControlNetModel):
if not isinstance(controlnet_conditioning_image, list):
raise TypeError("For multiple controlnets: `image` must be type `list`")

if controlnet_cond_image_is_pil:
controlnet_cond_image_batch_size = 1
elif controlnet_cond_image_is_tensor:
controlnet_cond_image_batch_size = controlnet_conditioning_image.shape[0]
elif controlnet_cond_image_is_pil_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
elif controlnet_cond_image_is_tensor_list:
controlnet_cond_image_batch_size = len(controlnet_conditioning_image)
if len(controlnet_conditioning_image) != len(self.controlnet.nets):
raise ValueError(
"For multiple controlnets: `image` must have the same length as the number of controlnets."
)

if prompt is not None and isinstance(prompt, str):
prompt_batch_size = 1
elif prompt is not None and isinstance(prompt, list):
prompt_batch_size = len(prompt)
elif prompt_embeds is not None:
prompt_batch_size = prompt_embeds.shape[0]
for image_ in controlnet_conditioning_image:
self.check_controlnet_conditioning_image(image_, prompt, prompt_embeds)
else:
assert False

if controlnet_cond_image_batch_size != 1 and controlnet_cond_image_batch_size != prompt_batch_size:
raise ValueError(
f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {controlnet_cond_image_batch_size}, prompt batch size: {prompt_batch_size}"
)
# Check `controlnet_conditioning_scale`

if isinstance(self.controlnet, ControlNetModel):
if not isinstance(controlnet_conditioning_scale, float):
raise TypeError("For single controlnet: `controlnet_conditioning_scale` must be type `float`.")
elif isinstance(self.controlnet, MultiControlNetModel):
if isinstance(controlnet_conditioning_scale, list) and len(controlnet_conditioning_scale) != len(
self.controlnet.nets
):
raise ValueError(
"For multiple controlnets: When `controlnet_conditioning_scale` is specified as `list`, it must have"
" the same length as the number of controlnets"
)
else:
assert False

if isinstance(image, torch.Tensor):
if image.ndim != 3 and image.ndim != 4:
raise ValueError("`image` must have 3 or 4 dimensions")

# if mask_image.ndim != 2 and mask_image.ndim != 3 and mask_image.ndim != 4:
# raise ValueError("`mask_image` must have 2, 3, or 4 dimensions")

if image.ndim == 3:
image_batch_size = 1
image_channels, image_height, image_width = image.shape
elif image.ndim == 4:
image_batch_size, image_channels, image_height, image_width = image.shape
else:
assert False

if image_channels != 3:
raise ValueError("`image` must have 3 channels")
@@ -660,7 +704,7 @@ def __call__(
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: float = 1.0,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
controlnet_guidance_start: float = 0.0,
controlnet_guidance_end: float = 1.0,
):
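
Note: `controlnet_conditioning_scale` now mirrors `controlnet`: a bare float for a single net (enforced in `check_inputs`), or one entry per net for a `MultiControlNetModel`; a lone float is broadcast to every net in step 2 of `__call__` below. A sketch of the two multi-net call styles, reusing the hypothetical `pipe` from the earlier sketch plus user-supplied `init_image`, `canny_image`, and `pose_image` (PIL images prepared by the user):

```python
# Two ControlNets: one conditioning image per net, one scale per net.
result = pipe(
    prompt="a photorealistic cat",
    image=init_image,  # img2img source image
    controlnet_conditioning_image=[canny_image, pose_image],
    strength=0.8,
    controlnet_conditioning_scale=[1.0, 0.8],
).images[0]

# A bare float is also accepted and is broadcast per net, here to [0.7, 0.7].
result = pipe(
    prompt="a photorealistic cat",
    image=init_image,
    controlnet_conditioning_image=[canny_image, pose_image],
    controlnet_conditioning_scale=0.7,
).images[0]
```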
@@ -761,7 +805,6 @@ def __call__(
self.check_inputs(
prompt,
image,
# mask_image,
controlnet_conditioning_image,
height,
width,
@@ -772,6 +815,7 @@ def __call__(
strength,
controlnet_guidance_start,
controlnet_guidance_end,
controlnet_conditioning_scale,
)

# 2. Define call parameters
@@ -788,6 +832,9 @@ def __call__(
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0

if isinstance(self.controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(self.controlnet.nets)

# 3. Encode input prompt
prompt_embeds = self._encode_prompt(
prompt,
@@ -799,22 +846,41 @@ def __call__(
negative_prompt_embeds=negative_prompt_embeds,
)

# 4. Prepare mask, image, and controlnet_conditioning_image
# 4. Prepare image, and controlnet_conditioning_image
image = prepare_image(image)

# mask_image = prepare_mask_image(mask_image)
# condition image(s)
if isinstance(self.controlnet, ControlNetModel):
controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=controlnet_conditioning_image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)
elif isinstance(self.controlnet, MultiControlNetModel):
controlnet_conditioning_images = []

for image_ in controlnet_conditioning_image:
image_ = prepare_controlnet_conditioning_image(
controlnet_conditioning_image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=self.controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
)

controlnet_conditioning_image = prepare_controlnet_conditioning_image(
controlnet_conditioning_image,
width,
height,
batch_size * num_images_per_prompt,
num_images_per_prompt,
device,
self.controlnet.dtype,
)
controlnet_conditioning_images.append(image_)

# masked_image = image * (mask_image < 0.5)
controlnet_conditioning_image = controlnet_conditioning_images
else:
assert False

# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
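
Note: for a `MultiControlNetModel`, step 4 now builds one conditioning tensor per net and hands the list to the ControlNet call unchanged. A conceptual sketch (not the library source) of how a multi-ControlNet forward pass consumes that list, running each net on its own image and scale and summing the residuals, which is why one image and one scale per net are required:

```python
import torch

def multi_controlnet_residuals(nets, sample, t, prompt_embeds, images, scales):
    """Conceptual: run each ControlNet on its own image/scale and sum the residuals."""
    down_res, mid_res = None, None
    for net, image, scale in zip(nets, images, scales):
        down, mid = net(
            sample,
            t,
            encoder_hidden_states=prompt_embeds,
            controlnet_cond=image,
            conditioning_scale=scale,
            return_dict=False,
        )
        if down_res is None:
            down_res, mid_res = down, mid
        else:
            down_res = [d_prev + d_new for d_prev, d_new in zip(down_res, down)]
            mid_res = mid_res + mid
    return down_res, mid_res
```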
@@ -832,9 +898,6 @@ def __call__(
generator,
)

if do_classifier_free_guidance:
controlnet_conditioning_image = torch.cat([controlnet_conditioning_image] * 2)

# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -864,15 +927,10 @@ def __call__(
t,
encoder_hidden_states=prompt_embeds,
controlnet_cond=controlnet_conditioning_image,
conditioning_scale=controlnet_conditioning_scale,
return_dict=False,
)

down_block_res_samples = [
down_block_res_sample * controlnet_conditioning_scale
for down_block_res_sample in down_block_res_samples
]
mid_block_res_sample *= controlnet_conditioning_scale

# predict the noise residual
noise_pred = self.unet(
latent_model_input,
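
Note: the final hunk drops the manual post-scaling of the ControlNet residuals in favor of the `conditioning_scale` keyword, which the ControlNet forward pass applies to its outputs internally; with a `MultiControlNetModel` this also gives each net its own scale. A minimal sketch of the equivalence on dummy tensors:

```python
import torch

scale = 0.7
down_block_res_samples = [torch.randn(1, 320, 64, 64), torch.randn(1, 640, 32, 32)]
mid_block_res_sample = torch.randn(1, 1280, 8, 8)

# What the deleted lines computed explicitly, and what passing
# conditioning_scale=scale into the ControlNet call now does internally:
down_block_res_samples = [s * scale for s in down_block_res_samples]
mid_block_res_sample = mid_block_res_sample * scale
```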