From 07efa23338f07a27241fb0d80d81c88bfebce97d Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Sat, 28 Sep 2024 13:46:03 -0400 Subject: [PATCH 1/3] image checking for controlnet flux --- .../flux/pipeline_flux_controlnet.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 6c072c482020..e1a4b1c6d20c 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -16,6 +16,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np +import PIL import torch from transformers import ( CLIPTextModel, @@ -389,10 +390,49 @@ def encode_prompt( return prompt_embeds, pooled_prompt_embeds, text_ids + # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image + def check_image(self, image, prompt, prompt_embeds): + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, 
list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + def check_inputs( self, prompt, prompt_2, + image, height, width, prompt_embeds=None, @@ -429,6 +469,30 @@ def check_inputs( elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)): raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}") + if ( + isinstance(self.controlnet, FluxControlNetModel) + ): + self.check_image(image, prompt, prompt_embeds) + elif ( + isinstance(self.controlnet, FluxMultiControlNetModel) + ): + if not isinstance(image, list): + raise TypeError("For multiple controlnets: `image` must be type `list`") + + # When `image` is a nested list: + # (e.g. [[canny_image_1, pose_image_1], [canny_image_2, pose_image_2]]) + elif any(isinstance(i, list) for i in image): + raise ValueError("A single batch of multiple conditionings are supported at the moment.") + elif len(image) != len(self.controlnet.nets): + raise ValueError( + f"For multiple controlnets: `image` must have the same length as the number of controlnets, but got {len(image)} images and {len(self.controlnet.nets)} ControlNets." + ) + + for image_ in image: + self.check_image(image_, prompt, prompt_embeds) + else: + assert False + if prompt_embeds is not None and pooled_prompt_embeds is None: raise ValueError( "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`." 
@@ -678,6 +742,7 @@ def __call__( self.check_inputs( prompt, prompt_2, + control_image, height, width, prompt_embeds=prompt_embeds, From 1c8f02d6c1295f403f6aeb05c4dbfe0e090e2ed0 Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Wed, 2 Oct 2024 17:23:07 -0400 Subject: [PATCH 2/3] remove image type checks since it's redundant, but keep the prompt batch size check --- .../pipelines/flux/pipeline_flux_controlnet.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index e1a4b1c6d20c..784b568dcb9b 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -393,24 +393,6 @@ def encode_prompt( # Copied from diffusers.pipelines.controlnet.pipeline_controlnet.StableDiffusionControlNetPipeline.check_image def check_image(self, image, prompt, prompt_embeds): image_is_pil = isinstance(image, PIL.Image.Image) - image_is_tensor = isinstance(image, torch.Tensor) - image_is_np = isinstance(image, np.ndarray) - image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) - image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) - image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) - - if ( - not image_is_pil - and not image_is_tensor - and not image_is_np - and not image_is_pil_list - and not image_is_tensor_list - and not image_is_np_list - ): - raise TypeError( - f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" - ) - if image_is_pil: image_batch_size = 1 else: From 1bc52d2b700ebf5b3fe169379318281ef14714bf Mon Sep 17 00:00:00 2001 From: Christopher Beckham Date: Fri, 1 Nov 2024 16:24:19 -0400 Subject: [PATCH 3/3] add a just-in-case ValueError check inside prepare_image, 
also remove if statement bypassing preprocess for torch tensor type --- .../flux/pipeline_flux_controlnet.py | 20 ++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py index 784b568dcb9b..cee6c6af1273 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py @@ -569,18 +569,20 @@ def prepare_image( do_classifier_free_guidance=False, guess_mode=False, ): - if isinstance(image, torch.Tensor): - pass - else: - image = self.image_processor.preprocess(image, height=height, width=width) - + + image = self.image_processor.preprocess(image, height=height, width=width) image_batch_size = image.shape[0] if image_batch_size == 1: - repeat_by = batch_size - else: + repeat_by = batch_size*num_images_per_prompt + elif image_batch_size == batch_size: # image batch size is the same as prompt batch size repeat_by = num_images_per_prompt + else: + raise ValueError( + "`image_batch_size` must be either 1 or equal to the prompt " + \ + f"batch size, which is {batch_size}." + ) image = image.repeat_interleave(repeat_by, dim=0) @@ -773,7 +775,7 @@ def __call__( image=control_image, width=width, height=height, - batch_size=batch_size * num_images_per_prompt, + batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.vae.dtype, @@ -809,7 +811,7 @@ def __call__( image=control_image_, width=width, height=height, - batch_size=batch_size * num_images_per_prompt, + batch_size=batch_size, num_images_per_prompt=num_images_per_prompt, device=device, dtype=self.vae.dtype,