From 4bc752b1d66257f3d1c5ce1ee1445d20d67d2ebe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Thu, 18 Jul 2024 20:19:54 -0400 Subject: [PATCH 1/6] initial draft --- src/diffusers/loaders/ip_adapter.py | 3 +- src/diffusers/loaders/unet.py | 5 + .../models/unets/unet_2d_condition.py | 4 + .../pipelines/kolors/pipeline_kolors.py | 135 +++++++++++++++++- 4 files changed, 141 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index ce9ed23caa94..dddbef43b36d 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -222,7 +222,8 @@ def load_ip_adapter( # create feature extractor if it has not been registered to the pipeline yet if hasattr(self, "feature_extractor") and getattr(self, "feature_extractor", None) is None: - feature_extractor = CLIPImageProcessor() + clip_image_size = self.image_encoder.config.image_size + feature_extractor = CLIPImageProcessor(size=clip_image_size, crop_size=clip_image_size) self.register_modules(feature_extractor=feature_extractor) # load ip-adapter into unet diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 0e002b2ba8a3..7f3bc7735299 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -1017,6 +1017,11 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): if not isinstance(state_dicts, list): state_dicts = [state_dicts] + + # Kolors + if self.encoder_hid_proj is not None and not hasattr(self, "text_encoder_hid_proj"): + self.text_encoder_hid_proj = self.encoder_hid_proj + # Set encoder_hid_proj after loading ip_adapter weights, # because `IPAdapterPlusImageProjection` also has `attn_processors`. self.encoder_hid_proj = None diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 2b9122799bf3..f44921dd593a 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -1024,6 +1024,10 @@ def process_encoder_hidden_states( raise ValueError( f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" ) + + if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None: + encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states) + image_embeds = added_cond_kwargs.get("image_embeds") image_embeds = self.encoder_hid_proj(image_embeds) encoder_hidden_states = (encoder_hidden_states, image_embeds) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py index 2214c9ea2c58..63a99797b809 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -15,11 +15,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback -from ...image_processor import VaeImageProcessor -from ...loaders import StableDiffusionXLLoraLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -120,7 +121,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin): +class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin): r""" Pipeline for text-to-image generation using Kolors. @@ -149,6 +150,10 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL """ model_cpu_offload_seq = "text_encoder->unet->vae" + _optional_components = [ + "image_encoder", + "feature_extractor", + ] _callback_tensor_inputs = [ "latents", "prompt_embeds", @@ -166,11 +171,21 @@ def __init__( tokenizer: ChatGLMTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = False, ): super().__init__() - self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -343,6 +358,77 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -371,6 +457,8 @@ def check_inputs( pooled_prompt_embeds=None, negative_prompt_embeds=None, negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): @@ -420,6 +508,21 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + if max_sequence_length is not None and max_sequence_length > 256: raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") @@ -563,6 +666,8 @@ def __call__( pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -649,6 +754,12 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -726,6 +837,8 @@ def __call__( pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -815,6 +928,15 @@ def __call__( add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + # 8. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -856,6 +978,9 @@ def __call__( # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( latent_model_input, t, From ee31c9ef936c4cde60decc35bebff1d41fe3f8ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Sat, 20 Jul 2024 01:35:33 -0400 Subject: [PATCH 2/6] apply suggestions --- src/diffusers/loaders/ip_adapter.py | 8 +++++++- src/diffusers/loaders/unet.py | 8 ++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index dddbef43b36d..44c8c0a5181c 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -320,7 +320,13 @@ def unload_ip_adapter(self): # remove hidden encoder self.unet.encoder_hid_proj = None - self.config.encoder_hid_dim_type = None + self.unet.config.encoder_hid_dim_type = None + + # Kolors: restore `encoder_hid_proj` with `text_encoder_hid_proj` + if hasattr(self.unet, "text_encoder_hid_proj") and self.unet.text_encoder_hid_proj is not None: + self.unet.encoder_hid_proj = self.unet.text_encoder_hid_proj + self.unet.text_encoder_hid_proj = None + self.unet.config.encoder_hid_dim_type = "text_proj" # restore original Unet attention processors layers attn_procs = {} diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 7f3bc7735299..4ed5157c6e4b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -1018,8 +1018,12 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): if not isinstance(state_dicts, list): state_dicts = [state_dicts] - # Kolors - if self.encoder_hid_proj is not None and not hasattr(self, "text_encoder_hid_proj"): + # Kolors Unet already has a `encoder_hid_proj` + if ( + self.encoder_hid_proj is not None + and self.config.encoder_hid_dim_type == "text_proj" + and not hasattr(self, "text_encoder_hid_proj") + ): self.text_encoder_hid_proj = self.encoder_hid_proj # Set encoder_hid_proj after loading ip_adapter weights, From 75f66a5bff6fd2abc7e0c77c14a437f04b10c1e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Sat, 20 Jul 2024 02:02:01 -0400 Subject: [PATCH 3/6] fix failing test --- tests/pipelines/kolors/test_kolors.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/kolors/test_kolors.py b/tests/pipelines/kolors/test_kolors.py index ba2156e4e8ac..3f7fcaf59575 100644 --- a/tests/pipelines/kolors/test_kolors.py +++ b/tests/pipelines/kolors/test_kolors.py @@ -96,6 +96,8 @@ def get_dummy_components(self, time_cond_proj_dim=None): "vae": vae, "text_encoder": text_encoder, "tokenizer": tokenizer, + "image_encoder": None, + "feature_extractor": None, } return components @@ -132,8 +134,10 @@ def test_inference(self): max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) - # should skip it but pipe._optional_components = [] so it doesn't + # throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter + # not sure if it is worth to fix it before integrating it to transformers def test_save_load_optional_components(self): + # TODO (Alvaro) need to fix later pass # throws AttributeError: property 'eos_token' of 'ChatGLMTokenizer' object has no setter From a9d2b10bcf771d2523dcc21d7e94f2d3c868ebae Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Sat, 20 Jul 2024 03:04:37 -0400 Subject: [PATCH 4/6] added ipa to img2img --- .../pipelines/kolors/pipeline_kolors.py | 11 +- .../kolors/pipeline_kolors_img2img.py | 144 +++++++++++++++++- 2 files changed, 147 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors.py b/src/diffusers/pipelines/kolors/pipeline_kolors.py index 63a99797b809..b682429e9744 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors.py @@ -131,6 +131,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL The pipeline also inherits the following loading methods: - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -149,7 +150,7 @@ class KolorsPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLL `Kwai-Kolors/Kolors-diffusers`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" _optional_components = [ "image_encoder", "feature_extractor", @@ -450,6 +451,7 @@ def prepare_extra_step_kwargs(self, generator, eta): def check_inputs( self, prompt, + num_inference_steps, height, width, negative_prompt=None, @@ -462,6 +464,12 @@ def check_inputs( callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -830,6 +838,7 @@ def __call__( # 1. Check inputs. Raise error if not correct self.check_inputs( prompt, + num_inference_steps, height, width, negative_prompt, diff --git a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py index f2c73665e723..81abdff0e9cc 100644 --- a/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py +++ b/src/diffusers/pipelines/kolors/pipeline_kolors_img2img.py @@ -16,11 +16,12 @@ import PIL.Image import torch +from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from ...callbacks import MultiPipelineCallbacks, PipelineCallback from ...image_processor import PipelineImageInput, VaeImageProcessor -from ...loaders import StableDiffusionXLLoraLoaderMixin -from ...models import AutoencoderKL, UNet2DConditionModel +from ...loaders import IPAdapterMixin, StableDiffusionXLLoraLoaderMixin +from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor from ...schedulers import KarrasDiffusionSchedulers from ...utils import is_torch_xla_available, logging, replace_example_docstring @@ -139,7 +140,7 @@ def retrieve_timesteps( return timesteps, num_inference_steps -class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin): +class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffusionXLLoraLoaderMixin, IPAdapterMixin): r""" Pipeline for text-to-image generation using Kolors. @@ -149,6 +150,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu The pipeline also inherits the following loading methods: - [`~loaders.StableDiffusionXLLoraLoaderMixin.load_lora_weights`] for loading LoRA weights - [`~loaders.StableDiffusionXLLoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters Args: vae ([`AutoencoderKL`]): @@ -167,10 +169,10 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu `Kwai-Kolors/Kolors-diffusers`. """ - model_cpu_offload_seq = "text_encoder->unet->vae" + model_cpu_offload_seq = "text_encoder->image_encoder-unet->vae" _optional_components = [ - "tokenizer", - "text_encoder", + "image_encoder", + "feature_extractor", ] _callback_tensor_inputs = [ "latents", @@ -189,11 +191,21 @@ def __init__( tokenizer: ChatGLMTokenizer, unet: UNet2DConditionModel, scheduler: KarrasDiffusionSchedulers, + image_encoder: CLIPVisionModelWithProjection = None, + feature_extractor: CLIPImageProcessor = None, force_zeros_for_empty_prompt: bool = False, ): super().__init__() - self.register_modules(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer, unet=unet, scheduler=scheduler) + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + image_encoder=image_encoder, + feature_extractor=feature_extractor, + ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = ( 2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8 @@ -367,6 +379,77 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + image_embeds = [] + if do_classifier_free_guidance: + negative_image_embeds = [] + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + + image_embeds.append(single_image_embeds[None, :]) + if do_classifier_free_guidance: + negative_image_embeds.append(single_negative_image_embeds[None, :]) + else: + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + negative_image_embeds.append(single_negative_image_embeds) + image_embeds.append(single_image_embeds) + + ip_adapter_image_embeds = [] + for i, single_image_embeds in enumerate(image_embeds): + single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0) + if do_classifier_free_guidance: + single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0) + + single_image_embeds = single_image_embeds.to(device=device) + ip_adapter_image_embeds.append(single_image_embeds) + + return ip_adapter_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs def prepare_extra_step_kwargs(self, generator, eta): # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature @@ -389,6 +472,7 @@ def check_inputs( self, prompt, strength, + num_inference_steps, height, width, negative_prompt=None, @@ -396,12 +480,20 @@ def check_inputs( pooled_prompt_embeds=None, negative_prompt_embeds=None, negative_pooled_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, callback_on_step_end_tensor_inputs=None, max_sequence_length=None, ): if strength < 0 or strength > 1: raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + if not isinstance(num_inference_steps, int) or num_inference_steps <= 0: + raise ValueError( + f"`num_inference_steps` has to be a positive integer but is {num_inference_steps} of type" + f" {type(num_inference_steps)}." + ) + if height % 8 != 0 or width % 8 != 0: raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") @@ -448,6 +540,21 @@ def check_inputs( "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`." ) + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." + ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + if max_sequence_length is not None and max_sequence_length > 256: raise ValueError(f"`max_sequence_length` cannot be greater than 256 but is {max_sequence_length}") @@ -699,6 +806,8 @@ def __call__( pooled_prompt_embeds: Optional[torch.Tensor] = None, negative_prompt_embeds: Optional[torch.Tensor] = None, negative_pooled_prompt_embeds: Optional[torch.Tensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -801,6 +910,12 @@ def __call__( Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of + IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should + contain the negative image embedding if `do_classifier_free_guidance` is set to `True`. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. @@ -872,6 +987,7 @@ def __call__( self.check_inputs( prompt, strength, + num_inference_steps, height, width, negative_prompt, @@ -879,6 +995,8 @@ def __call__( pooled_prompt_embeds, negative_prompt_embeds, negative_pooled_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, callback_on_step_end_tensor_inputs, max_sequence_length=max_sequence_length, ) @@ -990,6 +1108,15 @@ def denoising_value_valid(dnv): add_text_embeds = add_text_embeds.to(device) add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + # 9. Denoising loop num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0) @@ -1037,6 +1164,9 @@ def denoising_value_valid(dnv): # predict the noise residual added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + added_cond_kwargs["image_embeds"] = image_embeds + noise_pred = self.unet( latent_model_input, t, From 57453e41acc7648b988a49d10d05becfda54c304 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Fri, 26 Jul 2024 06:58:56 -0400 Subject: [PATCH 5/6] add docs --- docs/source/en/api/pipelines/kolors.md | 58 ++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/docs/source/en/api/pipelines/kolors.md b/docs/source/en/api/pipelines/kolors.md index 1c083e4285bc..cf3818cc2a56 100644 --- a/docs/source/en/api/pipelines/kolors.md +++ b/docs/source/en/api/pipelines/kolors.md @@ -41,6 +41,64 @@ image = pipe( image.save("kolors_sample.png") ``` +## IP Adapter + +Kolors needs a different IP Adapter to be able to work, also it uses Openai-CLIP-336 as an image encoder. + + + +Using an IP Adapter with Kolors needs more than 24GB of VRAM, it is recommended to use `enable_model_cpu_offload()` in consumer GPUs to be able to use it. + + + + + +While the PR is merged, we need to give the revision and load the image encoder separatedly to use the safetensors format. You can still use the main branch of the original repository if you're comfortable loading `pickle` checkpoints. + + + +```python +import torch +from transformers import CLIPVisionModelWithProjection + +from diffusers import DPMSolverMultistepScheduler, KolorsPipeline +from diffusers.utils import load_image + +image_encoder = CLIPVisionModelWithProjection.from_pretrained( + "Kwai-Kolors/Kolors-IP-Adapter-Plus", + subfolder="image_encoder", + low_cpu_mem_usage=True, + torch_dtype=torch.float16, + revision="refs/pr/4", +) + +pipe = KolorsPipeline.from_pretrained( + "Kwai-Kolors/Kolors-diffusers", image_encoder=image_encoder, torch_dtype=torch.float16, variant="fp16" +).to("cuda") +pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True) + +pipe.load_ip_adapter( + "Kwai-Kolors/Kolors-IP-Adapter-Plus", + subfolder="", + weight_name="ip_adapter_plus_general.safetensors", + revision="refs/pr/4", + image_encoder_folder=None, +) +pipe.enable_model_cpu_offload() + +ipa_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/kolors/cat_square.png") + +image = pipe( + prompt="best quality, high quality", + negative_prompt="", + guidance_scale=6.5, + num_inference_steps=25, + ip_adapter_image=ipa_image, +).images[0] + +image.save("kolors_ipa_sample.png") +``` + ## KolorsPipeline [[autodoc]] KolorsPipeline From 5f43f1e9961a6060b759f1b64ec49be5809b8608 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=81lvaro=20Somoza?= Date: Fri, 26 Jul 2024 13:40:10 -0400 Subject: [PATCH 6/6] apply suggestions --- docs/source/en/api/pipelines/kolors.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/en/api/pipelines/kolors.md b/docs/source/en/api/pipelines/kolors.md index cf3818cc2a56..a35a29d8a061 100644 --- a/docs/source/en/api/pipelines/kolors.md +++ b/docs/source/en/api/pipelines/kolors.md @@ -41,19 +41,19 @@ image = pipe( image.save("kolors_sample.png") ``` -## IP Adapter +### IP Adapter -Kolors needs a different IP Adapter to be able to work, also it uses Openai-CLIP-336 as an image encoder. +Kolors needs a different IP Adapter to work, and it uses [Openai-CLIP-336](https://huggingface.co/openai/clip-vit-large-patch14-336) as an image encoder. -Using an IP Adapter with Kolors needs more than 24GB of VRAM, it is recommended to use `enable_model_cpu_offload()` in consumer GPUs to be able to use it. +Using an IP Adapter with Kolors requires more than 24GB of VRAM. To use it, we recommend using [`~DiffusionPipeline.enable_model_cpu_offload`] on consumer GPUs. -While the PR is merged, we need to give the revision and load the image encoder separatedly to use the safetensors format. You can still use the main branch of the original repository if you're comfortable loading `pickle` checkpoints. +While Kolors is integrated in Diffusers, you need to load the image encoder from a revision to use the safetensor files. You can still use the main branch of the original repository if you're comfortable loading pickle checkpoints.