diff --git a/docs/source/en/api/pipelines/kolors.md b/docs/source/en/api/pipelines/kolors.md
index 1c083e4285bc..a35a29d8a061 100644
--- a/docs/source/en/api/pipelines/kolors.md
+++ b/docs/source/en/api/pipelines/kolors.md
@@ -41,6 +41,64 @@ image = pipe(
image.save("kolors_sample.png")
```
+### IP Adapter
+
+Kolors needs a dedicated IP Adapter to work, and it uses [Openai-CLIP-336](https://huggingface.co/openai/clip-vit-large-patch14-336) as its image encoder.
+
+<Tip>
+
+Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it on consumer GPUs, we recommend calling [`~DiffusionPipeline.enable_model_cpu_offload`].
+
+</Tip>
+
+<Tip>
+
+While Kolors is integrated in Diffusers, you need to load the image encoder from a specific revision (`refs/pr/4`, as shown below) to use the safetensors files. You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.
+
+</Tip>
+
+```python
+import torch
+from transformers import CLIPVisionModelWithProjection
+
+from diffusers import DPMSolverMultistepScheduler, KolorsPipeline
+from diffusers.utils import load_image
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+ "Kwai-Kolors/Kolors-IP-Adapter-Plus",
+ subfolder="image_encoder",
+ low_cpu_mem_usage=True,
+ torch_dtype=torch.float16,
+ revision="refs/pr/4",
+)
+
+pipe = KolorsPipeline.from_pretrained(
+ "Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16"
+).to("cuda")
+pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
+
+pipe.load_ip_adapter(
+ "Kwai-Kolors/Kolors-IP-Adapter-Plus",
+ subfolder="",
+ weight_name="ip_adapter_plus_general.safetensors",
+ revision="refs/pr/4",
+ image_encoder_folder=None,
+)
+pipe.enable_model_cpu_offload()
+
+ipa_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/cat_square.png")
+
+image = pipe(
+ prompt="best quality, high quality",
+ negative_prompt="",
+ guidance_scale=6.5,
+ num_inference_steps=25,
+ ip_adapter_image=ipa_image,
+).images[0]
+
+image.save("kolors_ipa_sample.png")
+```
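+
+After loading, the adapter's influence can be tuned with [`~loaders.IPAdapterMixin.set_ip_adapter_scale`], and it can be removed again with [`~loaders.IPAdapterMixin.unload_ip_adapter`], which for Kolors also restores the pipeline's original text projection. Below is a minimal sketch reusing `pipe` and `ipa_image` from the example above; the `0.5` scale is only an illustrative value to tune per use case.
+
+```python
+# A scale of 1.0 applies the adapter fully; lower values favor the text prompt.
+pipe.set_ip_adapter_scale(0.5)
+
+image = pipe(
+    prompt="best quality, high quality",
+    negative_prompt="",
+    guidance_scale=6.5,
+    num_inference_steps=25,
+    ip_adapter_image=ipa_image,
+).images[0]
+image.save("kolors_ipa_scaled_sample.png")
+
+# Remove the IP Adapter; for Kolors this also restores the original `encoder_hid_proj`.
+pipe.unload_ip_adapter()
+```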
+
## KolorsPipeline
[[autodoc]] KolorsPipeline
diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index ce9ed23caa94..44c8c0a5181c 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -222,7 +222,8 @@ def load_ip_adapter(
# create feature extractor if it has not been registered to the pipeline yet
if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None:
- feature_extractor = CLIPImageProcessor()
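+ # size the feature extractor from the image encoder config so non-default encoders (e.g. Kolors' 336px CLIP) get correctly sized inputs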
+ clip_image_size = self.image_encoder.config.image_size
+ feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size)
self.register_modules(feature_extractor=feature_extractor)
# load ip-adapter into unet
@@ -319,7 +320,13 @@ def unload_ip_adapter(self):
# remove hidden encoder
self.unet.encoder_hid_proj = None
- self.config.encoder_hid_dim_type = None
+ self.unet.config.encoder_hid_dim_type = None
+
+ # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj`
+ if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None:
+ self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj
+ self.unet.text_encoder_hid_proj = None
+ self.unet.config.encoder_hid_dim_type = "text_proj"
# restore original Unet attention processors layers
attn_procs = {}
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index d6df03ad34f6..32ace77b6224 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -823,6 +823,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False):
if not isinstance(state_dicts, list):
state_dicts = [state_dicts]
+
+ # Kolors Unet already has an `encoder_hid_proj`
+ if (
+ self.encoder_hid_proj is not None
+ and self.config.encoder_hid_dim_type == "text_proj"
+ and not hasattr(self, "text_encoder_hid_proj")
+ ):
+ self.text_encoder_hid_proj = self.encoder_hid_proj
+
# Set encoder_hid_proj after loading ip_adapter weights,
# because `IPAdapterPlusImageProjection` also has `attn_processors`.
self.encoder_hid_proj = None
diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py
index 611ac6087e4a..9a168bd22c93 100644
--- a/src/diffusers/models/unets/unet_2d_condition.py
+++ b/src/diffusers/models/unets/unet_2d_condition.py
@@ -1027,6 +1027,10 @@ def process_encoder_hidden_states(
raise ValueError(
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
)
+
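+ # For Kolors, the UNet's original text projection is stashed in `text_encoder_hid_proj` when an IP Adapter is loaded; apply it to the text states first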
+ if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
+ encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
+
image_embeds = added_cond_kwargs.get("image_embeds")
image_embeds = self.encoder_hid_proj(image_embeds)
encoder_hidden_states = (encoder_hidden_states, image_embeds)
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py
index 2214c9ea2c58..b682429e9744 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py
@@ -15,11 +15,12 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -120,7 +121,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
+class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
r"""
Pipeline for text-to-image generation using Kolors.
@@ -130,6 +131,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -148,7 +150,11 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL
`Kwai-Kolors/Kolors-diffusers`.
"""
- model_cpu_offload_seq = "text_encoder->unet->vae"
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = [
+ "image_encoder",
+ "feature_extractor",
+ ]
_callback_tensor_inputs = [
"latents",
"prompt_embeds",
@@ -166,11 +172,21 @@ def __init__(
tokenizer: ChatGLMTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = False,
):
super().__init__()
- self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -343,6 +359,77 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -364,6 +451,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
def check_inputs(
self,
prompt,
+ num_inference_steps,
height,
width,
negative_prompt=None,
@@ -371,9 +459,17 @@ def check_inputs(
pooled_prompt_embeds=None,
negative_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
+ if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -420,6 +516,21 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
if max_sequence_length is not None and max_sequence_length > 256:
raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")
@@ -563,6 +674,8 @@ def __call__(
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -649,6 +762,12 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the
+ number of IP Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`.
+ It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -719,6 +838,7 @@ def __call__(
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
+ num_inference_steps,
height,
width,
negative_prompt,
@@ -726,6 +846,8 @@ def __call__(
pooled_prompt_embeds,
negative_prompt_embeds,
negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -815,6 +937,15 @@ def __call__(
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
# 8. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -856,6 +987,9 @@ def __call__(
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
noise_pred = self.unet(
latent_model_input,
t,
diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
index f2c73665e723..81abdff0e9cc 100644
--- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
+++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py
@@ -16,11 +16,12 @@
import PIL.Image
import torch
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import StableDiffusionXLLoraLoaderMixin
-from ...models import AutoencoderKL, UNet2DConditionModel
+from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin
+from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel
from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
from ...schedulers import KarrasDiffusionSchedulers
from ...utils import is_torch_xla_available, logging, replace_example_docstring
@@ -139,7 +140,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps
-class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin):
+class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin):
r"""
Pipeline for image-to-image generation using Kolors.
@@ -149,6 +150,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
The pipeline also inherits the following loading methods:
- [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
- [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
Args:
vae ([`AutoencoderKL`]):
@@ -167,10 +169,10 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
`Kwai-Kolors/Kolors-diffusers`.
"""
- model_cpu_offload_seq = "text_encoder->unet->vae"
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
_optional_components = [
- "tokenizer",
- "text_encoder",
+ "image_encoder",
+ "feature_extractor",
]
_callback_tensor_inputs = [
"latents",
@@ -189,11 +191,21 @@ def __init__(
tokenizer: ChatGLMTokenizer,
unet: UNet2DConditionModel,
scheduler: KarrasDiffusionSchedulers,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ feature_extractor: CLIPImageProcessor = None,
force_zeros_for_empty_prompt: bool = False,
):
super().__init__()
- self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ image_encoder=image_encoder,
+ feature_extractor=feature_extractor,
+ )
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = (
2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -367,6 +379,77 @@ def encode_prompt(
return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ image_embeds = []
+ if do_classifier_free_guidance:
+ negative_image_embeds = []
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+
+ image_embeds.append(single_image_embeds[None, :])
+ if do_classifier_free_guidance:
+ negative_image_embeds.append(single_negative_image_embeds[None, :])
+ else:
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ negative_image_embeds.append(single_negative_image_embeds)
+ image_embeds.append(single_image_embeds)
+
+ ip_adapter_image_embeds = []
+ for i, single_image_embeds in enumerate(image_embeds):
+ single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+ if do_classifier_free_guidance:
+ single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
+
+ single_image_embeds = single_image_embeds.to(device=device)
+ ip_adapter_image_embeds.append(single_image_embeds)
+
+ return ip_adapter_image_embeds
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
@@ -389,6 +472,7 @@ def check_inputs(
self,
prompt,
strength,
+ num_inference_steps,
height,
width,
negative_prompt=None,
@@ -396,12 +480,20 @@ def check_inputs(
pooled_prompt_embeds=None,
negative_prompt_embeds=None,
negative_pooled_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
callback_on_step_end_tensor_inputs=None,
max_sequence_length=None,
):
if strength < 0 or strength > 1:
raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
+ if not isinstance(num_inference_steps, int) or num_inference_steps <= 0:
+ raise ValueError(
+ f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type"
+ f" {type(num_inference_steps)}."
+ )
+
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
@@ -448,6 +540,21 @@ def check_inputs(
"If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
)
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
if max_sequence_length is not None and max_sequence_length > 256:
raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}")
@@ -699,6 +806,8 @@ def __call__(
pooled_prompt_embeds: Optional[torch.Tensor] = None,
negative_prompt_embeds: Optional[torch.Tensor] = None,
negative_pooled_prompt_embeds: Optional[torch.Tensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -801,6 +910,12 @@ def __call__(
Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
input argument.
+ ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list with the same length as the
+ number of IP Adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`.
+ It should contain the negative image embedding if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -872,6 +987,7 @@ def __call__(
self.check_inputs(
prompt,
strength,
+ num_inference_steps,
height,
width,
negative_prompt,
@@ -879,6 +995,8 @@ def __call__(
pooled_prompt_embeds,
negative_prompt_embeds,
negative_pooled_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
callback_on_step_end_tensor_inputs,
max_sequence_length=max_sequence_length,
)
@@ -990,6 +1108,15 @@ def denoising_value_valid(dnv):
add_text_embeds = add_text_embeds.to(device)
add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
# 9. Denoising loop
num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
@@ -1037,6 +1164,9 @@ def denoising_value_valid(dnv):
# predict the noise residual
added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ added_cond_kwargs["image_embeds"] = image_embeds
+
noise_pred = self.unet(
latent_model_input,
t,
diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py
index ba2156e4e8ac..3f7fcaf59575 100644
--- a/tests/pipelines/kolors/test_kolors.py
+++ b/tests/pipelines/kolors/test_kolors.py
@@ -96,6 +96,8 @@ def get_dummy_components(self, time_cond_proj_dim=None):
"vae": vae,
"text_encoder": text_encoder,
"tokenizer": tokenizer,
+ "image_encoder": None,
+ "feature_extractor": None,
}
return components
@@ -132,8 +134,10 @@ def test_inference(self):
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
self.assertLessEqual(max_diff, 1e-3)
- # should skip it but pipe._optional_components = [] so it doesn't
+ # throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter
+ # not sure if it is worth fixing before ChatGLMTokenizer is integrated into transformers
def test_save_load_optional_components(self):
+ # TODO (Alvaro) need to fix later
pass
# throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter