From 7d7fd830cbf59bfcc16e7b84bd12825d65766d08 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 2 Mar 2024 16:12:26 +0100 Subject: [PATCH 01/29] Switch to peft and multi proj layers --- examples/community/ip_adapter_face_id.py | 401 +++-------------------- 1 file changed, 39 insertions(+), 362 deletions(-) diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py index b4d2446b5ce9..3bb1f5d6b566 100644 --- a/examples/community/ip_adapter_face_id.py +++ b/examples/community/ip_adapter_face_id.py @@ -26,7 +26,14 @@ from diffusers.image_processor import VaeImageProcessor from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.lora import LoRALinearLayer, adjust_lora_scale_text_encoder +from diffusers.models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) +from diffusers.models.embeddings import MultiIPAdapterImageProjection +from diffusers.models.lora import adjust_lora_scale_text_encoder from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker @@ -45,300 +52,6 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name -class LoRAIPAdapterAttnProcessor(nn.Module): - r""" - Attention processor for IP-Adapater. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - rank (`int`, defaults to 4): - The dimension of the LoRA update matrices. - network_alpha (`int`, *optional*): - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. - lora_scale (`float`, defaults to 1.0): - the weight scale of LoRA. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. 
- """ - - def __init__( - self, - hidden_size, - cross_attention_dim=None, - rank=4, - network_alpha=None, - lora_scale=1.0, - scale=1.0, - num_tokens=4, - ): - super().__init__() - - self.rank = rank - self.lora_scale = lora_scale - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - # separate ip_hidden_states from encoder_hidden_states - if encoder_hidden_states is not None: - if isinstance(encoder_hidden_states, tuple): - encoder_hidden_states, ip_hidden_states = encoder_hidden_states - else: - deprecation_message = ( - "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." - ) - deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) - end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - [encoder_hidden_states[:, end_pos:, :]], - ) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) - - query = attn.head_to_batch_dim(query) - key = attn.head_to_batch_dim(key) - value = attn.head_to_batch_dim(value) - - attention_probs = attn.get_attention_scores(query, key, attention_mask) - hidden_states = torch.bmm(attention_probs, value) - hidden_states = attn.batch_to_head_dim(hidden_states) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = attn.head_to_batch_dim(ip_key) - ip_value = attn.head_to_batch_dim(ip_value) - - ip_attention_probs = attn.get_attention_scores(query, ip_key, None) - ip_hidden_states 
= torch.bmm(ip_attention_probs, ip_value) - ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return hidden_states - - -class LoRAIPAdapterAttnProcessor2_0(nn.Module): - r""" - Attention processor for IP-Adapater for PyTorch 2.0. - Args: - hidden_size (`int`): - The hidden size of the attention layer. - cross_attention_dim (`int`): - The number of channels in the `encoder_hidden_states`. - rank (`int`, defaults to 4): - The dimension of the LoRA update matrices. - network_alpha (`int`, *optional*): - Equivalent to `alpha` but it's usage is specific to Kohya (A1111) style LoRAs. - lora_scale (`float`, defaults to 1.0): - the weight scale of LoRA. - scale (`float`, defaults to 1.0): - the weight scale of image prompt. - num_tokens (`int`, defaults to 4 when do ip_adapter_plus it should be 16): - The context length of the image features. - """ - - def __init__( - self, - hidden_size, - cross_attention_dim=None, - rank=4, - network_alpha=None, - lora_scale=1.0, - scale=1.0, - num_tokens=4, - ): - super().__init__() - - self.rank = rank - self.lora_scale = lora_scale - - self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha) - self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha) - - self.hidden_size = hidden_size - self.cross_attention_dim = cross_attention_dim - self.scale = scale - self.num_tokens = num_tokens - - self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) - - def __call__( - self, - attn, - hidden_states, - encoder_hidden_states=None, - attention_mask=None, - temb=None, - ): - residual = hidden_states - - # separate ip_hidden_states from encoder_hidden_states - if encoder_hidden_states is not None: - if isinstance(encoder_hidden_states, tuple): - encoder_hidden_states, ip_hidden_states = encoder_hidden_states - else: - deprecation_message = ( - "You have passed a tensor as `encoder_hidden_states`.This is deprecated and will be removed in a future release." - " Please make sure to update your script to pass `encoder_hidden_states` as a tuple to supress this warning." 
- ) - deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False) - end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0] - encoder_hidden_states, ip_hidden_states = ( - encoder_hidden_states[:, :end_pos, :], - [encoder_hidden_states[:, end_pos:, :]], - ) - - if attn.spatial_norm is not None: - hidden_states = attn.spatial_norm(hidden_states, temb) - - input_ndim = hidden_states.ndim - - if input_ndim == 4: - batch_size, channel, height, width = hidden_states.shape - hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) - - batch_size, sequence_length, _ = ( - hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape - ) - - if attention_mask is not None: - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) - # scaled_dot_product_attention expects attention_mask shape to be - # (batch, heads, source_length, target_length) - attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) - - if attn.group_norm is not None: - hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) - - query = attn.to_q(hidden_states) + self.lora_scale * self.to_q_lora(hidden_states) - - if encoder_hidden_states is None: - encoder_hidden_states = hidden_states - elif attn.norm_cross: - encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states) - - key = attn.to_k(encoder_hidden_states) + self.lora_scale * self.to_k_lora(encoder_hidden_states) - value = attn.to_v(encoder_hidden_states) + self.lora_scale * self.to_v_lora(encoder_hidden_states) - - inner_dim = key.shape[-1] - head_dim = inner_dim // attn.heads - - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - hidden_states = F.scaled_dot_product_attention( - query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False - ) - - hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - hidden_states = hidden_states.to(query.dtype) - - # for ip-adapter - ip_key = self.to_k_ip(ip_hidden_states) - ip_value = self.to_v_ip(ip_hidden_states) - - ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - # the output of sdp = (batch, num_heads, seq_len, head_dim) - # TODO: add support for attn.scale when we move to Torch 2.1 - ip_hidden_states = F.scaled_dot_product_attention( - query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False - ) - - ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) - ip_hidden_states = ip_hidden_states.to(query.dtype) - - hidden_states = hidden_states + self.scale * ip_hidden_states - - # linear proj - hidden_states = attn.to_out[0](hidden_states) + self.lora_scale * self.to_out_lora(hidden_states) - # dropout - hidden_states = attn.to_out[1](hidden_states) - - if input_ndim == 4: - hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width) - - if attn.residual_connection: - hidden_states = hidden_states + residual - - hidden_states = hidden_states / attn.rescale_output_factor - - return 
hidden_states - - class IPAdapterFullImageProjection(nn.Module): def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() @@ -615,10 +328,6 @@ def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): return image_projection def _load_ip_adapter_weights(self, state_dict): - from diffusers.models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - ) num_image_text_embeds = 4 @@ -626,6 +335,7 @@ def _load_ip_adapter_weights(self, state_dict): # set ip-adapter cross-attention processors & load state_dict attn_procs = {} + lora_dict = {} key_id = 0 for name in self.unet.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim @@ -642,94 +352,61 @@ def _load_ip_adapter_weights(self, state_dict): AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() - rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] - attn_module = self.unet - for n in name.split(".")[:-1]: - attn_module = getattr(attn_module, n) - # Set the `lora_layer` attribute of the attention-related matrices. - attn_module.to_q.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_q.in_features, - out_features=attn_module.to_q.out_features, - rank=rank, - ) - ) - attn_module.to_k.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_k.in_features, - out_features=attn_module.to_k.out_features, - rank=rank, - ) - ) - attn_module.to_v.set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_v.in_features, - out_features=attn_module.to_v.out_features, - rank=rank, - ) - ) - attn_module.to_out[0].set_lora_layer( - LoRALinearLayer( - in_features=attn_module.to_out[0].in_features, - out_features=attn_module.to_out[0].out_features, - rank=rank, - ) - ) - value_dict = {} - for k, module in attn_module.named_children(): - index = "." - if not hasattr(module, "set_lora_layer"): - index = ".0." 
- module = module[0] - lora_layer = getattr(module, "lora_layer") - for lora_name, w in lora_layer.state_dict().items(): - value_dict.update( - { - f"{k}{index}lora_layer.{lora_name}": state_dict["ip_adapter"][ - f"{key_id}.{k}_lora.{lora_name}" - ] - } - ) - - attn_module.load_state_dict(value_dict, strict=False) - attn_module.to(dtype=self.dtype, device=self.device) + lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) key_id += 1 else: - rank = state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"].shape[0] attn_processor_class = ( - LoRAIPAdapterAttnProcessor2_0 + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") - else LoRAIPAdapterAttnProcessor + else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, - rank=rank, num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) - value_dict = {} - for k, w in attn_procs[name].state_dict().items(): - value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]}) + lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) + value_dict = {} + value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) + value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) attn_procs[name].load_state_dict(value_dict) key_id += 1 self.unet.set_attn_processor(attn_procs) + self.load_lora_weights(lora_dict, adapter_name="faceid") + self.set_adapters(["faceid"], adapter_weights=[1.0]) + # convert IP-Adapter Image Projection layers to diffusers image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) + image_projection_layers = 
[image_projection.to(device=self.device, dtype=self.dtype)] - self.unet.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype) + self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) self.unet.config.encoder_hid_dim_type = "ip_image_proj" def set_ip_adapter_scale(self, scale): unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet for attn_processor in unet.attn_processors.values(): - if isinstance(attn_processor, (LoRAIPAdapterAttnProcessor, LoRAIPAdapterAttnProcessor2_0)): - attn_processor.scale = scale + if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): + attn_processor.scale = [scale] def _encode_prompt( self, @@ -1298,7 +975,7 @@ def __call__( negative_image_embeds = torch.zeros_like(image_embeds) if self.do_classifier_free_guidance: image_embeds = torch.cat([negative_image_embeds, image_embeds]) - + image_embeds = [image_embeds] # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -1319,7 +996,7 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else None + added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {} # 6.2 Optionally get Guidance Scale Embedding timestep_cond = None From 4e56997ebc7bb160e9c0b1298c93bcdc2e14108b Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sun, 3 Mar 2024 20:35:18 +0100 Subject: [PATCH 02/29] Move Face ID loading and inference to core --- src/diffusers/loaders/ip_adapter.py | 6 +- src/diffusers/loaders/unet.py | 60 +++++++++++++++---- src/diffusers/models/embeddings.py | 13 +++- .../pipeline_stable_diffusion.py | 45 ++++++++------ 4 files changed, 90 insertions(+), 34 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 2b70ed84d7ed..ba6bd4202cad 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -225,7 +225,11 @@ def load_ip_adapter( # load ip-adapter into unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + extra_lora = unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + + if extra_lora != {}: + self.load_lora_weights(extra_lora, adapter_name="faceid") + self.set_adapters(["faceid"], adapter_weights=[1.0]) def set_ip_adapter_scale(self, scale): """ diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 9d8e2666c518..c8b2efd8eac0 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -727,14 +727,21 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us diffusers_name = key.replace("proj", "image_embeds") updated_state_dict[diffusers_name] = value - elif "proj.3.weight" in state_dict: - # IP-Adapter Full - clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] - cross_attention_dim = state_dict["proj.3.weight"].shape[0] + elif "proj.0.weight" in state_dict: + # IP-Adapter Full and Face ID + clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in + norm_layer = "norm.weight" + cross_attention_dim = 
state_dict[norm_layer].shape[0] + num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim with init_context(): image_projection = IPAdapterFullImageProjection( - cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + cross_attention_dim=cross_attention_dim, + image_embed_dim=clip_embeddings_dim_in, + mult=multiplier, + num_tokens=num_tokens, ) for key, value in state_dict.items(): @@ -816,7 +823,11 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F # set ip-adapter cross-attention processors & load state_dict attn_procs = {} + lora_dict = {} key_id = 1 + for state_dict in state_dicts: + if "0.to_k_lora.down.weight" in state_dict["ip_adapter"]: + key_id = 0 init_context = init_empty_weights if low_cpu_mem_usage else nullcontext for name in self.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim @@ -834,6 +845,19 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor ) attn_procs[name] = attn_processor_class() + + for state_dict in state_dicts: + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) + key_id += 1 + break else: attn_processor_class = ( IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor @@ -843,9 +867,12 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F if "proj.weight" in state_dict["image_proj"]: # IP-Adapter num_image_text_embeds += [4] - elif "proj.3.weight" in state_dict["image_proj"]: - # IP-Adapter Full Face - num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "proj.0.weight" in state_dict["image_proj"]: + # IP-Adapter Full Face and Face ID + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + num_image_text_embeds += [4] + else: + num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token else: # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] @@ -862,6 +889,15 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F for i, state_dict in enumerate(state_dicts): value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + lora_dict.update({f"unet.{name}.to_k_lora.down.weight": 
state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) + lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) + lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) if not low_cpu_mem_usage: attn_procs[name].load_state_dict(value_dict) @@ -870,9 +906,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F dtype = next(iter(value_dict.values())).dtype load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) - key_id += 2 + key_id += 2 if "0.to_k_lora.down.weight" not in state_dict["ip_adapter"] else 1 - return attn_procs + return attn_procs, lora_dict def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): if not isinstance(state_dicts, list): @@ -881,7 +917,7 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): # because `IPAdapterPlusImageProjection` also has `attn_processors`. self.encoder_hid_proj = None - attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + attn_procs, lora_dict = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) self.set_attn_processor(attn_procs) # convert IP-Adapter Image Projection layers to diffusers @@ -896,3 +932,5 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): self.config.encoder_hid_dim_type = "ip_image_proj" self.to(dtype=self.dtype, device=self.device) + + return lora_dict diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 49f385d5f493..53fe77a6674d 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -463,15 +463,22 @@ def forward(self, image_embeds: torch.FloatTensor): class IPAdapterFullImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() from .attention import FeedForward - self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") self.norm = nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds: torch.FloatTensor): - return self.norm(self.ff(image_embeds)) + if self.num_tokens == 4: + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + else: + return self.norm(self.ff(image_embeds)) class CombinedTimestepLabelEmbeddings(nn.Module): diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 
9e4e6c186ffa..c8f977b4014c 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -466,27 +466,34 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + # TODO: add checks + dtype = next(self.unet.parameters()).dtype + image_embeds = image.to(device=device, dtype=dtype) uncond_image_embeds = torch.zeros_like(image_embeds) - return image_embeds, uncond_image_embeds def prepare_ip_adapter_image_embeds( From 98a1aa4e925e8853ae905c69105cd3b1279885c5 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 9 Mar 2024 18:07:30 +0100 Subject: [PATCH 03/29] Add support for Face ID XL --- .../pipeline_stable_diffusion_xl.py | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 776696e9d486..4c87f6001c6d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -494,27 +494,34 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, 
torch.Tensor):
-            image = self.feature_extractor(image, return_tensors="pt").pixel_values
-
-        image = image.to(device=device, dtype=dtype)
-        if output_hidden_states:
-            image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
-            image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
-            uncond_image_enc_hidden_states = self.image_encoder(
-                torch.zeros_like(image), output_hidden_states=True
-            ).hidden_states[-2]
-            uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
-                num_images_per_prompt, dim=0
-            )
-            return image_enc_hidden_states, uncond_image_enc_hidden_states
+        if self.image_encoder is not None:
+            dtype = next(self.image_encoder.parameters()).dtype
+
+            if not isinstance(image, torch.Tensor):
+                image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+            image = image.to(device=device, dtype=dtype)
+            if output_hidden_states:
+                image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+                image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+                uncond_image_enc_hidden_states = self.image_encoder(
+                    torch.zeros_like(image), output_hidden_states=True
+                ).hidden_states[-2]
+                uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+                    num_images_per_prompt, dim=0
+                )
+                return image_enc_hidden_states, uncond_image_enc_hidden_states
+            else:
+                image_embeds = self.image_encoder(image).image_embeds
+                image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+                uncond_image_embeds = torch.zeros_like(image_embeds)
+
+                return image_embeds, uncond_image_embeds
         else:
-            image_embeds = self.image_encoder(image).image_embeds
-            image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+            # TODO: add checks
+            dtype = next(self.unet.parameters()).dtype
+            image_embeds = image.to(device=device, dtype=dtype)
             uncond_image_embeds = torch.zeros_like(image_embeds)
-
             return image_embeds, uncond_image_embeds

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds

From 03d84fac801d14e5a539c88c2ea68e62af4dbf9d Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Sun, 10 Mar 2024 15:41:56 +0100
Subject: [PATCH 04/29] Add checks

---
 src/diffusers/loaders/ip_adapter.py               | 11 +++++++----
 .../stable_diffusion/pipeline_stable_diffusion.py |  8 +++++++-
 .../pipeline_stable_diffusion_xl.py               |  8 +++++++-
 3 files changed, 21 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index ba6bd4202cad..9eb117561fa2 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -214,8 +214,10 @@ def load_ip_adapter(
             )
         else:
             logger.warning(
-                    "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter."
-                    "Use `ip_adapter_image_embedding` to pass pre-geneated image embedding instead."
+                    "image_encoder is not loaded since `image_encoder_folder=None` was passed. `ip_adapter_image` is allowed only if you are loading an IP-Adapter Face ID model."
+                    " If you don't load an IP-Adapter Face ID model, always use `ip_adapter_image_embedding` to pass a pre-generated image embedding instead."
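+                    # NOTE: Face ID checkpoints ship no CLIP image encoder; `ip_adapter_image` is then a
+                    # precomputed insightface embedding tensor, which `encode_image` moves to the right
+                    # device/dtype instead of encoding it.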
) # create feature extractor if it has not been registered to the pipeline yet @@ -228,8 +228,11 @@ def load_ip_adapter( extra_lora = unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) if extra_lora != {}: - self.load_lora_weights(extra_lora, adapter_name="faceid") - self.set_adapters(["faceid"], adapter_weights=[1.0]) + # apply the IP Adapter Face ID LoRA weights + peft_config = getattr(unet, "peft_config", {}) + if "faceid" not in peft_config: + self.load_lora_weights(extra_lora, adapter_name="faceid") + self.set_adapters(["faceid"], adapter_weights=[1.0]) def set_ip_adapter_scale(self, scale): """ diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index c8f977b4014c..38bbde58ae44 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -490,8 +490,14 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds else: - # TODO: add checks dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + + if image.ndim < 2: + image = image.unsqueeze(0) + image_embeds = image.to(device=device, dtype=dtype) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 4c87f6001c6d..5069359004f6 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -518,8 +518,14 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds else: - # TODO: add checks dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + + if image.ndim < 2: + image = image.unsqueeze(0) + image_embeds = image.to(device=device, dtype=dtype) uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds From fe35a4264db9bd01b77275db0217cb49bfd7f542 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sun, 10 Mar 2024 15:54:58 +0100 Subject: [PATCH 05/29] Add test --- .../test_ip_adapter_stable_diffusion.py | 25 +++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index 6289ee887d13..db3d1cf69e14 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -37,6 +37,7 @@ from diffusers.utils.testing_utils import ( enable_full_determinism, is_flaky, + load_pt, numpy_cosine_similarity_distance, require_torch_gpu, slow, @@ -299,6 +300,30 @@ def test_multi(self): max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 + def test_text_to_image_face_id(self): + pipeline = StableDiffusionPipeline.from_pretrained( + "runwayml/stable-diffusion-v1-5", safety_checker=None, torch_dtype=self.dtype + ) + pipeline.to(torch_device) + 
pipeline.load_ip_adapter( + "h94/IP-Adapter-FaceID", + subfolder=None, + weight_name="ip-adapter-faceid_sd15.bin", + image_encoder_folder=None, + ) + pipeline.set_ip_adapter_scale(0.7) + + inputs = self.get_dummy_inputs() + inputs["ip_adapter_image"] = load_pt( + "https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt" + ) + images = pipeline(**inputs).images + image_slice = images[0, :3, :3, -1].flatten() + expected_slice = np.array([0.1665, 0.1626, 0.2187, 0.1882, 0.1702, 0.2144, 0.1624, 0.2012, 0.2173]) + + max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) + assert max_diff < 5e-4 + @slow @require_torch_gpu From bc016f43e3fb07ddc85c5333ec1e1b70c848982a Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sun, 10 Mar 2024 15:55:08 +0100 Subject: [PATCH 06/29] Fix style --- src/diffusers/loaders/unet.py | 132 +++++++++++++++++++++++++++++----- 1 file changed, 115 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index c8b2efd8eac0..1aea0775edcd 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -848,14 +848,62 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F for state_dict in state_dicts: if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: - lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) + lora_dict.update( + { + f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.up.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.up.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.up.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.up.weight" + ] + } + ) key_id += 1 break else: @@ -890,14 +938,62 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F 
value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: - lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) + lora_dict.update( + { + f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.up.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.up.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.up.weight" + ] + } + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.up.weight" + ] + } + ) if not low_cpu_mem_usage: attn_procs[name].load_state_dict(value_dict) @@ -917,7 +1013,9 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): # because `IPAdapterPlusImageProjection` also has `attn_processors`. 
self.encoder_hid_proj = None - attn_procs, lora_dict = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + attn_procs, lora_dict = self._convert_ip_adapter_attn_to_diffusers( + state_dicts, low_cpu_mem_usage=low_cpu_mem_usage + ) self.set_attn_processor(attn_procs) # convert IP-Adapter Image Projection layers to diffusers From 92bfe5073cedf2ea420614b0b0ff6658b06bb14c Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 12 Mar 2024 18:22:18 +0100 Subject: [PATCH 07/29] Remove old pipeline --- examples/community/README.md | 62 -- examples/community/ip_adapter_face_id.py | 1083 ---------------------- 2 files changed, 1145 deletions(-) delete mode 100644 examples/community/ip_adapter_face_id.py diff --git a/examples/community/README.md b/examples/community/README.md index cf1e4cdd579d..6bc23719d008 100755 --- a/examples/community/README.md +++ b/examples/community/README.md @@ -60,7 +60,6 @@ If a community doesn't work as expected, please open an issue and ping the autho | Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender-A-Video) | - | [Yifan Zhou](https://github.com/SingleZombie) | | StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | | AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) | -| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) | | InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) | | UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) | | Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) | @@ -3582,67 +3581,6 @@ frames = output.frames[0] export_to_gif(frames, "animation.gif") ``` -### IP Adapter Face ID - -IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded. -You need to install `insightface` and all its requirements to use this model. -You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`. 
-You have to disable PEFT BACKEND in order to load weights. -You can find more results [here](https://github.com/huggingface/diffusers/pull/6276). - -```py -import diffusers -diffusers.utils.USE_PEFT_BACKEND = False -import torch -from diffusers.utils import load_image -import cv2 -import numpy as np -from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler -from insightface.app import FaceAnalysis - - -noise_scheduler = DDIMScheduler( - num_train_timesteps=1000, - beta_start=0.00085, - beta_end=0.012, - beta_schedule="scaled_linear", - clip_sample=False, - set_alpha_to_one=False, - steps_offset=1, -) -vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16) -pipeline = DiffusionPipeline.from_pretrained( - "SG161222/Realistic_Vision_V4.0_noVAE", - torch_dtype=torch.float16, - scheduler=noise_scheduler, - vae=vae, - custom_pipeline="ip_adapter_face_id" -) -pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin") -pipeline.to("cuda") - -generator = torch.Generator(device="cpu").manual_seed(42) -num_images=2 - -image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png") - -app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) -app.prepare(ctx_id=0, det_size=(640, 640)) -image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB) -faces = app.get(image) -image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0) -images = pipeline( - prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower", - image_embeds=image, - negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality", - num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704, - generator=generator -).images - -for i in range(num_images): - images[i].save(f"c{i}.png") -``` - ### InstantID Pipeline InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving generation with only single image, supporting various downstream tasks. For any usgae question, please refer to the [official implementation](https://github.com/InstantID/InstantID). diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py deleted file mode 100644 index 3bb1f5d6b566..000000000000 --- a/examples/community/ip_adapter_face_id.py +++ /dev/null @@ -1,1083 +0,0 @@ -# Copyright 2024 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import inspect -from typing import Any, Callable, Dict, List, Optional, Union - -import torch -import torch.nn as nn -import torch.nn.functional as F -from packaging import version -from safetensors import safe_open -from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection - -from diffusers.configuration_utils import FrozenDict -from diffusers.image_processor import VaeImageProcessor -from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin -from diffusers.models import AutoencoderKL, UNet2DConditionModel -from diffusers.models.attention_processor import ( - AttnProcessor, - AttnProcessor2_0, - IPAdapterAttnProcessor, - IPAdapterAttnProcessor2_0, -) -from diffusers.models.embeddings import MultiIPAdapterImageProjection -from diffusers.models.lora import adjust_lora_scale_text_encoder -from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin -from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput -from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker -from diffusers.schedulers import KarrasDiffusionSchedulers -from diffusers.utils import ( - USE_PEFT_BACKEND, - _get_model_file, - deprecate, - logging, - scale_lora_layers, - unscale_lora_layers, -) -from diffusers.utils.torch_utils import randn_tensor - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -class IPAdapterFullImageProjection(nn.Module): - def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): - super().__init__() - from diffusers.models.attention import FeedForward - - self.num_tokens = num_tokens - self.cross_attention_dim = cross_attention_dim - self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") - self.norm = nn.LayerNorm(cross_attention_dim) - - def forward(self, image_embeds: torch.FloatTensor): - x = self.ff(image_embeds) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - return self.norm(x) - - -def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): - """ - Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and - Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 - """ - std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) - std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) - # rescale the results from guidance (fixes overexposure) - noise_pred_rescaled = noise_cfg * (std_text / std_cfg) - # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images - noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg - return noise_cfg - - -def retrieve_timesteps( - scheduler, - num_inference_steps: Optional[int] = None, - device: Optional[Union[str, torch.device]] = None, - timesteps: Optional[List[int]] = None, - **kwargs, -): - """ - Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles - custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. - - Args: - scheduler (`SchedulerMixin`): - The scheduler to get timesteps from. - num_inference_steps (`int`): - The number of diffusion steps used when generating samples with a pre-trained model. If used, - `timesteps` must be `None`. 
- device (`str` or `torch.device`, *optional*): - The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - timesteps (`List[int]`, *optional*): - Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default - timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` - must be `None`. - - Returns: - `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the - second element is the number of inference steps. - """ - if timesteps is not None: - accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) - if not accepts_timesteps: - raise ValueError( - f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" - f" timestep schedules. Please check whether you are using the correct scheduler." - ) - scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) - timesteps = scheduler.timesteps - num_inference_steps = len(timesteps) - else: - scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) - timesteps = scheduler.timesteps - return timesteps, num_inference_steps - - -class IPAdapterFaceIDStableDiffusionPipeline( - DiffusionPipeline, - StableDiffusionMixin, - TextualInversionLoaderMixin, - LoraLoaderMixin, - IPAdapterMixin, - FromSingleFileMixin, -): - r""" - Pipeline for text-to-image generation using Stable Diffusion. - - This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods - implemented for all pipelines (downloading, saving, running on a particular device, etc.). - - The pipeline also inherits the following loading methods: - - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings - - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights - - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights - - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files - - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters - - Args: - vae ([`AutoencoderKL`]): - Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. - text_encoder ([`~transformers.CLIPTextModel`]): - Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). - tokenizer ([`~transformers.CLIPTokenizer`]): - A `CLIPTokenizer` to tokenize text. - unet ([`UNet2DConditionModel`]): - A `UNet2DConditionModel` to denoise the encoded image latents. - scheduler ([`SchedulerMixin`]): - A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of - [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. - safety_checker ([`StableDiffusionSafetyChecker`]): - Classification module that estimates whether generated images could be considered offensive or harmful. - Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details - about a model's potential harms. - feature_extractor ([`~transformers.CLIPImageProcessor`]): - A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. 
- """ - - model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" - _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] - _exclude_from_cpu_offload = ["safety_checker"] - _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] - - def __init__( - self, - vae: AutoencoderKL, - text_encoder: CLIPTextModel, - tokenizer: CLIPTokenizer, - unet: UNet2DConditionModel, - scheduler: KarrasDiffusionSchedulers, - safety_checker: StableDiffusionSafetyChecker, - feature_extractor: CLIPImageProcessor, - image_encoder: CLIPVisionModelWithProjection = None, - requires_safety_checker: bool = True, - ): - super().__init__() - - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" - f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " - "to update the config accordingly as leaving `steps_offset` might led to incorrect results" - " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," - " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" - " file" - ) - deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["steps_offset"] = 1 - scheduler._internal_dict = FrozenDict(new_config) - - if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: - deprecation_message = ( - f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." - " `clip_sample` should be set to False in the configuration file. Please make sure to update the" - " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" - " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" - " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" - ) - deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(scheduler.config) - new_config["clip_sample"] = False - scheduler._internal_dict = FrozenDict(new_config) - - if safety_checker is None and requires_safety_checker: - logger.warning( - f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" - " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" - " results in services or applications open to the public. Both the diffusers team and Hugging Face" - " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" - " it only for use-cases that involve analyzing network behavior or auditing its results. For more" - " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." - ) - - if safety_checker is not None and feature_extractor is None: - raise ValueError( - "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" - " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." 
- ) - - is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( - version.parse(unet.config._diffusers_version).base_version - ) < version.parse("0.9.0.dev0") - is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 - if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: - deprecation_message = ( - "The configuration file of the unet has set the default `sample_size` to smaller than" - " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" - " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" - " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" - " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" - " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" - " in the config might lead to incorrect results in future versions. If you have downloaded this" - " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" - " the `unet/config.json` file" - ) - deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) - new_config = dict(unet.config) - new_config["sample_size"] = 64 - unet._internal_dict = FrozenDict(new_config) - - self.register_modules( - vae=vae, - text_encoder=text_encoder, - tokenizer=tokenizer, - unet=unet, - scheduler=scheduler, - safety_checker=safety_checker, - feature_extractor=feature_extractor, - image_encoder=image_encoder, - ) - self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - self.register_to_config(requires_safety_checker=requires_safety_checker) - - def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs): - cache_dir = kwargs.pop("cache_dir", None) - force_download = kwargs.pop("force_download", False) - resume_download = kwargs.pop("resume_download", False) - proxies = kwargs.pop("proxies", None) - local_files_only = kwargs.pop("local_files_only", None) - token = kwargs.pop("token", None) - revision = kwargs.pop("revision", None) - subfolder = kwargs.pop("subfolder", None) - - user_agent = { - "file_type": "attn_procs_weights", - "framework": "pytorch", - } - model_file = _get_model_file( - pretrained_model_name_or_path_or_dict, - weights_name=weight_name, - cache_dir=cache_dir, - force_download=force_download, - resume_download=resume_download, - proxies=proxies, - local_files_only=local_files_only, - token=token, - revision=revision, - subfolder=subfolder, - user_agent=user_agent, - ) - if weight_name.endswith(".safetensors"): - state_dict = {"image_proj": {}, "ip_adapter": {}} - with safe_open(model_file, framework="pt", device="cpu") as f: - for key in f.keys(): - if key.startswith("image_proj."): - state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) - elif key.startswith("ip_adapter."): - state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) - else: - state_dict = torch.load(model_file, map_location="cpu") - self._load_ip_adapter_weights(state_dict) - - def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): - updated_state_dict = {} - clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] - clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] - multiplier = clip_embeddings_dim_out // 
clip_embeddings_dim_in - norm_layer = "norm.weight" - cross_attention_dim = state_dict[norm_layer].shape[0] - num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim - - image_projection = IPAdapterFullImageProjection( - cross_attention_dim=cross_attention_dim, - image_embed_dim=clip_embeddings_dim_in, - mult=multiplier, - num_tokens=num_tokens, - ) - - for key, value in state_dict.items(): - diffusers_name = key.replace("proj.0", "ff.net.0.proj") - diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") - updated_state_dict[diffusers_name] = value - - image_projection.load_state_dict(updated_state_dict) - return image_projection - - def _load_ip_adapter_weights(self, state_dict): - - num_image_text_embeds = 4 - - self.unet.encoder_hid_proj = None - - # set ip-adapter cross-attention processors & load state_dict - attn_procs = {} - lora_dict = {} - key_id = 0 - for name in self.unet.attn_processors.keys(): - cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim - if name.startswith("mid_block"): - hidden_size = self.unet.config.block_out_channels[-1] - elif name.startswith("up_blocks"): - block_id = int(name[len("up_blocks.")]) - hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id] - elif name.startswith("down_blocks"): - block_id = int(name[len("down_blocks.")]) - hidden_size = self.unet.config.block_out_channels[block_id] - if cross_attention_dim is None or "motion_modules" in name: - attn_processor_class = ( - AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor - ) - attn_procs[name] = attn_processor_class() - - lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) - key_id += 1 - else: - attn_processor_class = ( - IPAdapterAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else IPAdapterAttnProcessor - ) - attn_procs[name] = attn_processor_class( - hidden_size=hidden_size, - cross_attention_dim=cross_attention_dim, - scale=1.0, - num_tokens=num_image_text_embeds, - ).to(dtype=self.dtype, device=self.device) - - lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_k_lora.up.weight": 
state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) - - value_dict = {} - value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) - value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) - attn_procs[name].load_state_dict(value_dict) - key_id += 1 - - self.unet.set_attn_processor(attn_procs) - - self.load_lora_weights(lora_dict, adapter_name="faceid") - self.set_adapters(["faceid"], adapter_weights=[1.0]) - - # convert IP-Adapter Image Projection layers to diffusers - image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"]) - image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)] - - self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) - self.unet.config.encoder_hid_dim_type = "ip_image_proj" - - def set_ip_adapter_scale(self, scale): - unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - for attn_processor in unet.attn_processors.values(): - if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)): - attn_processor.scale = [scale] - - def _encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - **kwargs, - ): - deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." - deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) - - prompt_embeds_tuple = self.encode_prompt( - prompt=prompt, - device=device, - num_images_per_prompt=num_images_per_prompt, - do_classifier_free_guidance=do_classifier_free_guidance, - negative_prompt=negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - **kwargs, - ) - - # concatenate for backwards comp - prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) - - return prompt_embeds - - def encode_prompt( - self, - prompt, - device, - num_images_per_prompt, - do_classifier_free_guidance, - negative_prompt=None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - lora_scale: Optional[float] = None, - clip_skip: Optional[int] = None, - ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - do_classifier_free_guidance (`bool`): - whether to use classifier free guidance or not - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts not to guide the image generation. If not defined, one has to pass - `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is - less than `1`). - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - negative_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt - weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input - argument. - lora_scale (`float`, *optional*): - A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - clip_skip (`int`, *optional*): - Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that - the output of the pre-final layer will be used for computing the prompt embeddings. - """ - # set lora scale so that monkey patched LoRA - # function of text encoder can correctly access it - if lora_scale is not None and isinstance(self, LoraLoaderMixin): - self._lora_scale = lora_scale - - # dynamically adjust the LoRA scale - if not USE_PEFT_BACKEND: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - if prompt_embeds is None: - # textual inversion: process multi-vector tokens if necessary - if isinstance(self, TextualInversionLoaderMixin): - prompt = self.maybe_convert_prompt(prompt, self.tokenizer) - - text_inputs = self.tokenizer( - prompt, - padding="max_length", - max_length=self.tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - text_input_ids = text_inputs.input_ids - untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids - - if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( - text_input_ids, untruncated_ids - ): - removed_text = self.tokenizer.batch_decode( - untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] - ) - logger.warning( - "The following part of your input was truncated because CLIP can only handle sequences up to" - f" {self.tokenizer.model_max_length} tokens: {removed_text}" - ) - - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: - attention_mask = text_inputs.attention_mask.to(device) - else: - attention_mask = None - - if clip_skip is None: - prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) - prompt_embeds = prompt_embeds[0] - else: - prompt_embeds = self.text_encoder( - text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True - ) - # Access the `hidden_states` first, that contains a tuple of - # all the hidden states from the encoder layers. Then index into - # the tuple to access the hidden states from the desired layer. - prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] - # We also need to apply the final LayerNorm here to not mess with the - # representations. The `last_hidden_states` that we typically use for - # obtaining the final prompt representations passes through the LayerNorm - # layer. 
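                # For example, clip_skip=1 selects hidden_states[-2] (the pre-final
                # encoder layer) via the -(clip_skip + 1) index above, and clip_skip=2
                # selects hidden_states[-3]; the final LayerNorm applied below keeps
                # that intermediate representation in distribution.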
-                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
-
-        if self.text_encoder is not None:
-            prompt_embeds_dtype = self.text_encoder.dtype
-        elif self.unet is not None:
-            prompt_embeds_dtype = self.unet.dtype
-        else:
-            prompt_embeds_dtype = prompt_embeds.dtype
-
-        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-        bs_embed, seq_len, _ = prompt_embeds.shape
-        # duplicate text embeddings for each generation per prompt, using mps friendly method
-        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
-        # get unconditional embeddings for classifier free guidance
-        if do_classifier_free_guidance and negative_prompt_embeds is None:
-            uncond_tokens: List[str]
-            if negative_prompt is None:
-                uncond_tokens = [""] * batch_size
-            elif prompt is not None and type(prompt) is not type(negative_prompt):
-                raise TypeError(
-                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
-                    f" {type(prompt)}."
-                )
-            elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
-            elif batch_size != len(negative_prompt):
-                raise ValueError(
-                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
-                    f" {prompt} has batch size {batch_size}. Please make sure that the passed `negative_prompt` matches"
-                    " the batch size of `prompt`."
-                )
-            else:
-                uncond_tokens = negative_prompt
-
-            # textual inversion: process multi-vector tokens if necessary
-            if isinstance(self, TextualInversionLoaderMixin):
-                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
-
-            max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
-                uncond_tokens,
-                padding="max_length",
-                max_length=max_length,
-                truncation=True,
-                return_tensors="pt",
-            )
-
-            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
-                attention_mask = uncond_input.attention_mask.to(device)
-            else:
-                attention_mask = None
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device),
-                attention_mask=attention_mask,
-            )
-            negative_prompt_embeds = negative_prompt_embeds[0]
-
-        if do_classifier_free_guidance:
-            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-            seq_len = negative_prompt_embeds.shape[1]
-
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
-
-            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
-
-        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder, lora_scale)
-
-        return prompt_embeds, negative_prompt_embeds
-
-    def run_safety_checker(self, image, device, dtype):
-        if self.safety_checker is None:
-            has_nsfw_concept = None
-        else:
-            if torch.is_tensor(image):
-                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
-            else:
-                feature_extractor_input = self.image_processor.numpy_to_pil(image)
-            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
-            image, has_nsfw_concept = self.safety_checker(
-                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
-            )
-        return image, has_nsfw_concept
-
-    def 
decode_latents(self, latents): - deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" - deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) - - latents = 1 / self.vae.config.scaling_factor * latents - image = self.vae.decode(latents, return_dict=False)[0] - image = (image / 2 + 0.5).clamp(0, 1) - # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 - image = image.cpu().permute(0, 2, 3, 1).float().numpy() - return image - - def prepare_extra_step_kwargs(self, generator, eta): - # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature - # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. - # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 - # and should be between [0, 1] - - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) - extra_step_kwargs = {} - if accepts_eta: - extra_step_kwargs["eta"] = eta - - # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) - if accepts_generator: - extra_step_kwargs["generator"] = generator - return extra_step_kwargs - - def check_inputs( - self, - prompt, - height, - width, - callback_steps, - negative_prompt=None, - prompt_embeds=None, - negative_prompt_embeds=None, - callback_on_step_end_tensor_inputs=None, - ): - if height % 8 != 0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") - - if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): - raise ValueError( - f"`callback_steps` has to be a positive integer but is {callback_steps} of type" - f" {type(callback_steps)}." - ) - if callback_on_step_end_tensor_inputs is not None and not all( - k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs - ): - raise ValueError( - f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" - ) - - if prompt is not None and prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" - " only forward one of the two." - ) - elif prompt is None and prompt_embeds is None: - raise ValueError( - "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." - ) - elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") - - if negative_prompt is not None and negative_prompt_embeds is not None: - raise ValueError( - f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" - f" {negative_prompt_embeds}. Please make sure to only forward one of the two." - ) - - if prompt_embeds is not None and negative_prompt_embeds is not None: - if prompt_embeds.shape != negative_prompt_embeds.shape: - raise ValueError( - "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" - f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" - f" {negative_prompt_embeds.shape}." 
-                )
-
-    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
-        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device)
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
-        return latents
-
-    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
-    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
-        """
-        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
-
-        Args:
-            w (`torch.Tensor`):
-                guidance scale values at which to generate the embedding vectors
-            embedding_dim (`int`, *optional*, defaults to 512):
-                dimension of the embeddings to generate
-            dtype:
-                data type of the generated embeddings
-
-        Returns:
-            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
-        """
-        assert len(w.shape) == 1
-        w = w * 1000.0
-
-        half_dim = embedding_dim // 2
-        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
-        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
-        emb = w.to(dtype)[:, None] * emb[None, :]
-        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
-        if embedding_dim % 2 == 1:  # zero pad
-            emb = torch.nn.functional.pad(emb, (0, 1))
-        assert emb.shape == (w.shape[0], embedding_dim)
-        return emb
-
-    @property
-    def guidance_scale(self):
-        return self._guidance_scale
-
-    @property
-    def guidance_rescale(self):
-        return self._guidance_rescale
-
-    @property
-    def clip_skip(self):
-        return self._clip_skip
-
-    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
-    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-    # corresponds to doing no classifier free guidance.
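For reference, the arithmetic that `guidance_scale` and `guidance_rescale` drive in the denoising loop further below is small enough to sketch standalone. This is a minimal sketch: the tensors and their shapes are stand-ins, and the rescale step mirrors `rescale_noise_cfg` defined earlier in this file.

    import torch

    # stand-in noise predictions; shapes are illustrative only
    noise_pred_uncond = torch.randn(2, 4, 64, 64)
    noise_pred_text = torch.randn(2, 4, 64, 64)
    guidance_scale, guidance_rescale = 7.5, 0.7

    # classifier-free guidance: guidance_scale = 1 would leave the
    # unconditional prediction unchanged
    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    # rescale toward the std of the text-conditioned prediction (Sec. 3.4 of
    # https://arxiv.org/pdf/2305.08891.pdf) to counteract overexposure
    std_text = noise_pred_text.std(dim=[1, 2, 3], keepdim=True)
    std_cfg = noise_pred.std(dim=[1, 2, 3], keepdim=True)
    noise_pred = guidance_rescale * (noise_pred * (std_text / std_cfg)) + (1 - guidance_rescale) * noise_pred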
- @property - def do_classifier_free_guidance(self): - return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None - - @property - def cross_attention_kwargs(self): - return self._cross_attention_kwargs - - @property - def num_timesteps(self): - return self._num_timesteps - - @property - def interrupt(self): - return self._interrupt - - @torch.no_grad() - def __call__( - self, - prompt: Union[str, List[str]] = None, - height: Optional[int] = None, - width: Optional[int] = None, - num_inference_steps: int = 50, - timesteps: List[int] = None, - guidance_scale: float = 7.5, - negative_prompt: Optional[Union[str, List[str]]] = None, - num_images_per_prompt: Optional[int] = 1, - eta: float = 0.0, - generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, - latents: Optional[torch.FloatTensor] = None, - prompt_embeds: Optional[torch.FloatTensor] = None, - negative_prompt_embeds: Optional[torch.FloatTensor] = None, - image_embeds: Optional[torch.FloatTensor] = None, - output_type: Optional[str] = "pil", - return_dict: bool = True, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - guidance_rescale: float = 0.0, - clip_skip: Optional[int] = None, - callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, - callback_on_step_end_tensor_inputs: List[str] = ["latents"], - **kwargs, - ): - r""" - The call function to the pipeline for generation. - - Args: - prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. - height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The height in pixels of the generated image. - width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): - The width in pixels of the generated image. - num_inference_steps (`int`, *optional*, defaults to 50): - The number of denoising steps. More denoising steps usually lead to a higher quality image at the - expense of slower inference. - timesteps (`List[int]`, *optional*): - Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument - in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is - passed will be used. Must be in descending order. - guidance_scale (`float`, *optional*, defaults to 7.5): - A higher guidance scale value encourages the model to generate images closely linked to the text - `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. - negative_prompt (`str` or `List[str]`, *optional*): - The prompt or prompts to guide what to not include in image generation. If not defined, you need to - pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). - num_images_per_prompt (`int`, *optional*, defaults to 1): - The number of images to generate per prompt. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies - to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. - generator (`torch.Generator` or `List[torch.Generator]`, *optional*): - A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make - generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image - generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents
-                tensor is generated by sampling using the supplied random `generator`.
-            prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
-                provided, text embeddings are generated from the `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
-                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
-            image_embeds (`torch.FloatTensor`, *optional*):
-                Pre-generated image embeddings.
-            output_type (`str`, *optional*, defaults to `"pil"`):
-                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-                plain tuple.
-            cross_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            guidance_rescale (`float`, *optional*, defaults to 0.0):
-                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
-                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
-                using zero terminal SNR.
-            clip_skip (`int`, *optional*):
-                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
-                the output of the pre-final layer will be used for computing the prompt embeddings.
-            callback_on_step_end (`Callable`, *optional*):
-                A function that is called at the end of each denoising step during inference. The function is called
-                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
-                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
-                `callback_on_step_end_tensor_inputs`.
-            callback_on_step_end_tensor_inputs (`List`, *optional*):
-                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
-                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
-                `._callback_tensor_inputs` attribute of your pipeline class.
-
-        Examples:
-
-        Returns:
-            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
-                otherwise a `tuple` is returned where the first element is a list with the generated images and the
-                second element is a list of `bool`s indicating whether the corresponding generated image contains
-                "not-safe-for-work" (nsfw) content.
-        """
-
-        callback = kwargs.pop("callback", None)
-        callback_steps = kwargs.pop("callback_steps", None)
-
-        if callback is not None:
-            deprecate(
-                "callback",
-                "1.0.0",
-                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
-            )
-        if callback_steps is not None:
-            deprecate(
-                "callback_steps",
-                "1.0.0",
-                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
-            )
-
-        # 0. 
Default height and width to unet - height = height or self.unet.config.sample_size * self.vae_scale_factor - width = width or self.unet.config.sample_size * self.vae_scale_factor - # to deal with lora scaling and other possible forward hooks - - # 1. Check inputs. Raise error if not correct - self.check_inputs( - prompt, - height, - width, - callback_steps, - negative_prompt, - prompt_embeds, - negative_prompt_embeds, - callback_on_step_end_tensor_inputs, - ) - - self._guidance_scale = guidance_scale - self._guidance_rescale = guidance_rescale - self._clip_skip = clip_skip - self._cross_attention_kwargs = cross_attention_kwargs - self._interrupt = False - - # 2. Define call parameters - if prompt is not None and isinstance(prompt, str): - batch_size = 1 - elif prompt is not None and isinstance(prompt, list): - batch_size = len(prompt) - else: - batch_size = prompt_embeds.shape[0] - - device = self._execution_device - - # 3. Encode input prompt - lora_scale = ( - self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None - ) - - prompt_embeds, negative_prompt_embeds = self.encode_prompt( - prompt, - device, - num_images_per_prompt, - self.do_classifier_free_guidance, - negative_prompt, - prompt_embeds=prompt_embeds, - negative_prompt_embeds=negative_prompt_embeds, - lora_scale=lora_scale, - clip_skip=self.clip_skip, - ) - - # For classifier free guidance, we need to do two forward passes. - # Here we concatenate the unconditional and text embeddings into a single batch - # to avoid doing two forward passes - if self.do_classifier_free_guidance: - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - - if image_embeds is not None: - image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to( - device=device, dtype=prompt_embeds.dtype - ) - negative_image_embeds = torch.zeros_like(image_embeds) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) - image_embeds = [image_embeds] - # 4. Prepare timesteps - timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) - - # 5. Prepare latent variables - num_channels_latents = self.unet.config.in_channels - latents = self.prepare_latents( - batch_size * num_images_per_prompt, - num_channels_latents, - height, - width, - prompt_embeds.dtype, - device, - generator, - latents, - ) - - # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline - extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) - - # 6.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {} - - # 6.2 Optionally get Guidance Scale Embedding - timestep_cond = None - if self.unet.config.time_cond_proj_dim is not None: - guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) - timestep_cond = self.get_guidance_scale_embedding( - guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim - ).to(device=device, dtype=latents.dtype) - - # 7. 
Denoising loop - num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order - self._num_timesteps = len(timesteps) - with self.progress_bar(total=num_inference_steps) as progress_bar: - for i, t in enumerate(timesteps): - if self.interrupt: - continue - - # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) - - # predict the noise residual - noise_pred = self.unet( - latent_model_input, - t, - encoder_hidden_states=prompt_embeds, - timestep_cond=timestep_cond, - cross_attention_kwargs=self.cross_attention_kwargs, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] - - # perform guidance - if self.do_classifier_free_guidance: - noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) - - if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: - # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf - noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) - - # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] - - if callback_on_step_end is not None: - callback_kwargs = {} - for k in callback_on_step_end_tensor_inputs: - callback_kwargs[k] = locals()[k] - callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) - - latents = callback_outputs.pop("latents", latents) - prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) - negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) - - # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): - progress_bar.update() - if callback is not None and i % callback_steps == 0: - step_idx = i // getattr(self.scheduler, "order", 1) - callback(step_idx, t, latents) - - if not output_type == "latent": - image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ - 0 - ] - image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) - else: - image = latents - has_nsfw_concept = None - - if has_nsfw_concept is None: - do_denormalize = [True] * image.shape[0] - else: - do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] - - image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) - - # Offload all models - self.maybe_free_model_hooks() - - if not return_dict: - return (image, has_nsfw_concept) - - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From b8eed8db14148e8eb8384ac9db79981c3d793016 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 12 Mar 2024 20:30:08 +0100 Subject: [PATCH 08/29] Fix copies --- .../animatediff/pipeline_animatediff.py | 51 ++++++++++++------- .../pipeline_animatediff_video2video.py | 51 ++++++++++++------- .../controlnet/pipeline_controlnet.py | 49 +++++++++++------- .../controlnet/pipeline_controlnet_img2img.py | 49 +++++++++++------- .../controlnet/pipeline_controlnet_inpaint.py | 47 ++++++++++------- .../pipeline_controlnet_inpaint_sd_xl.py | 47 ++++++++++------- .../controlnet/pipeline_controlnet_sd_xl.py | 49 +++++++++++------- 
.../pipeline_controlnet_sd_xl_img2img.py | 49 +++++++++++------- .../pipeline_latent_consistency_img2img.py | 49 +++++++++++------- .../pipeline_latent_consistency_text2img.py | 51 ++++++++++++------- src/diffusers/pipelines/pia/pipeline_pia.py | 51 ++++++++++++------- .../pipeline_stable_diffusion_img2img.py | 49 +++++++++++------- .../pipeline_stable_diffusion_inpaint.py | 51 ++++++++++++------- ...eline_stable_diffusion_instruct_pix2pix.py | 49 +++++++++++------- .../pipeline_stable_diffusion_ldm3d.py | 51 ++++++++++++------- .../pipeline_stable_diffusion_panorama.py | 51 ++++++++++++------- .../pipeline_stable_diffusion_safe.py | 51 ++++++++++++------- .../pipeline_stable_diffusion_sag.py | 51 ++++++++++++------- .../pipeline_stable_diffusion_xl_img2img.py | 49 +++++++++++------- .../pipeline_stable_diffusion_xl_inpaint.py | 51 ++++++++++++------- .../pipeline_stable_diffusion_xl_adapter.py | 51 ++++++++++++------- 21 files changed, 660 insertions(+), 387 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index cd7f0a283b63..a93003c1bd62 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -345,27 +345,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = 
next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index cb6b71351faf..eaad150f2fc6 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -423,27 +423,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 8f31dfc2678a..039d50a104a5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -455,27 +455,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 9d2c76fd7483..ecc8cb0554a7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -448,27 +448,40 @@ def encode_prompt( # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index c4f1bff5efcd..67cf76d0eb7a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -573,27 +573,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if not 
isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 52ffe5a3f356..6e2f55a1d1c7 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -482,27 +482,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - 
torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index eca81083be7b..114e3bd056f5 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -460,27 +460,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = 
image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 86a0e2c570d8..e71acc3ad412 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -512,27 +512,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = 
uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index f64854ea982b..411caec36097 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -398,27 +398,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = 
torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index e9bacaa89ba5..349f5f58d9a4 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -382,27 +382,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return 
image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 507088991a5e..688b826351c2 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -470,27 +470,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index b43e0eb2abcd..2b305afefea9 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -509,27 +509,40 @@ 
def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 221d5c2cfd3f..93df3c1362ba 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -581,27 +581,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = 
image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 89d4278937fe..06efcca6bb9e 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -632,27 +632,40 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + 
image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index c7c05feaf013..2dc97079416d 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -448,27 +448,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = 
image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index feda710e0049..5f98e36cb9d9 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -359,27 +359,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + 
).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index ae74e09678e3..309aa2ff168b 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -476,27 +476,40 @@ def perform_safety_guidance( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + 
return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds @torch.no_grad() diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py index 96aa006d2ab3..4b09279f6f45 100644 --- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py @@ -378,27 +378,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds 
= torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index fd4c412f48cb..1ac018b8f668 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -727,27 +727,40 @@ def prepare_latents( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 
c25628c22c7b..11692fae62b7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -431,27 +431,40 @@ def __init__( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index 4e0cc61f5c1d..e09639613096 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -508,27 +508,40 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, 
output_hidden_states=None): - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states + if self.image_encoder is not None: + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) + dtype = next(self.unet.parameters()).dtype + + if not isinstance(image, torch.Tensor): + raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + if image.ndim < 2: + image = image.unsqueeze(0) + + image_embeds = image.to(device=device, dtype=dtype) + uncond_image_embeds = torch.zeros_like(image_embeds) return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds From ec539d0c71fe459ab98437ff3dbeeda4315e15ee Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Tue, 12 Mar 2024 21:37:39 +0100 Subject: [PATCH 09/29] Fix loading for full face --- src/diffusers/loaders/unet.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 1aea0775edcd..ec55298d1679 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -732,7 +732,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in - norm_layer = "norm.weight" + norm_layer = "norm.weight" if "norm.weight" in state_dict else "proj.3.weight" cross_attention_dim = state_dict[norm_layer].shape[0] num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim From 
9521f89050164448621c11ac46c9de4ae9ba9fc9 Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Wed, 13 Mar 2024 18:26:22 +0100
Subject: [PATCH 10/29] Revert community pipeline delete

---
 examples/community/README.md | 62 ++
 examples/community/ip_adapter_face_id.py | 1083 ++++++++++++++++++++++
 2 files changed, 1145 insertions(+)
 create mode 100644 examples/community/ip_adapter_face_id.py

diff --git a/examples/community/README.md b/examples/community/README.md
index 6bc23719d008..cf1e4cdd579d 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -60,6 +60,7 @@ If a community doesn't work as expected, please open an issue and ping the autho
| Rerender A Video Pipeline | Implementation of [[SIGGRAPH Asia 2023] Rerender A Video: Zero-Shot Text-Guided Video-to-Video Translation](https://arxiv.org/abs/2306.07954) | [Rerender A Video Pipeline](#Rerender-A-Video) | - | [Yifan Zhou](https://github.com/SingleZombie) |
| StyleAligned Pipeline | Implementation of [Style Aligned Image Generation via Shared Attention](https://arxiv.org/abs/2312.02133) | [StyleAligned Pipeline](#stylealigned-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/15X2E0jFPTajUIjS0FzX50OaHsCbP2lQ0/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
| AnimateDiff Image-To-Video Pipeline | Experimental Image-To-Video support for AnimateDiff (open to improvements) | [AnimateDiff Image To Video Pipeline](#animatediff-image-to-video-pipeline) | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://drive.google.com/file/d/1TvzCDPHhfFtdcJZe4RLloAwyoLKuttWK/view?usp=sharing) | [Aryan V S](https://github.com/a-r-r-o-w) |
+| IP Adapter FaceID Stable Diffusion | Stable Diffusion Pipeline that supports IP Adapter Face ID | [IP Adapter Face ID](#ip-adapter-face-id) | - | [Fabio Rigano](https://github.com/fabiorigano) |
| InstantID Pipeline | Stable Diffusion XL Pipeline that supports InstantID | [InstantID Pipeline](#instantid-pipeline) | [![Hugging Face Space](https://img.shields.io/badge/🤗%20Hugging%20Face-Space-yellow)](https://huggingface.co/spaces/InstantX/InstantID) | [Haofan Wang](https://github.com/haofanwang) |
| UFOGen Scheduler | Scheduler for UFOGen Model (compatible with Stable Diffusion pipelines) | [UFOGen Scheduler](#ufogen-scheduler) | - | [dg845](https://github.com/dg845) |
| Stable Diffusion XL IPEX Pipeline | Accelerate Stable Diffusion XL inference pipeline with BF16/FP32 precision on Intel Xeon CPUs with [IPEX](https://github.com/intel/intel-extension-for-pytorch) | [Stable Diffusion XL on IPEX](#stable-diffusion-xl-on-ipex) | - | [Dan Li](https://github.com/ustcuna/) |
@@ -3581,6 +3582,67 @@ frames = output.frames[0]
export_to_gif(frames, "animation.gif")
```
+### IP Adapter Face ID
+
+IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
+You need to install `insightface` and all of its requirements to use this model.
+You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
+You have to disable the PEFT backend (`diffusers.utils.USE_PEFT_BACKEND = False`, as in the example below) in order to load the weights.
+You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).
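+
+Under the hood, when no image encoder is loaded, the pipeline's `encode_image` treats the tensor you pass as the precomputed embedding itself and pairs it with a zero tensor for the unconditional branch of classifier-free guidance. A minimal sketch of that fallback (mirroring the `encode_image` changes in this patch series; pipeline internals shown only to explain the data flow, not an API you call directly):
+
+```py
+# Sketch of the fallback inside encode_image when self.image_encoder is None:
+if not isinstance(image, torch.Tensor):
+    raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor")
+image_embeds = image.to(device=device, dtype=dtype)   # precomputed face embedding
+uncond_image_embeds = torch.zeros_like(image_embeds)  # zeros for the unconditional pass
+```
+
+The complete example below computes the face embedding with `insightface` and passes it to the pipeline: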
+
+```py
+import diffusers
+diffusers.utils.USE_PEFT_BACKEND = False
+import torch
+from diffusers.utils import load_image
+import cv2
+import numpy as np
+from diffusers import DiffusionPipeline, AutoencoderKL, DDIMScheduler
+from insightface.app import FaceAnalysis
+
+
+noise_scheduler = DDIMScheduler(
+    num_train_timesteps=1000,
+    beta_start=0.00085,
+    beta_end=0.012,
+    beta_schedule="scaled_linear",
+    clip_sample=False,
+    set_alpha_to_one=False,
+    steps_offset=1,
+)
+vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float16)
+pipeline = DiffusionPipeline.from_pretrained(
+    "SG161222/Realistic_Vision_V4.0_noVAE",
+    torch_dtype=torch.float16,
+    scheduler=noise_scheduler,
+    vae=vae,
+    custom_pipeline="ip_adapter_face_id"
+)
+pipeline.load_ip_adapter_face_id("h94/IP-Adapter-FaceID", "ip-adapter-faceid_sd15.bin")
+pipeline.to("cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+num_images = 2
+
+image = load_image("https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/ai_face2.png")
+
+app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
+image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB)  # PIL gives RGB; swap channels to the cv2-style BGR input insightface expects
+faces = app.get(image)
+image = torch.from_numpy(faces[0].normed_embedding).unsqueeze(0)  # embedding of the first detected face
+images = pipeline(
+    prompt="A photo of a girl wearing a black dress, holding red roses in hand, upper body, behind is the Eiffel Tower",
+    image_embeds=image,
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20, num_images_per_prompt=num_images, width=512, height=704,
+    generator=generator
+).images
+
+for i in range(num_images):
+    images[i].save(f"c{i}.png")
+```
+
### InstantID Pipeline

InstantID is a new state-of-the-art tuning-free method to achieve ID-Preserving generation with only a single image, supporting various downstream tasks. For any usage questions, please refer to the [official implementation](https://github.com/InstantID/InstantID).

diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py
new file mode 100644
index 000000000000..3bb1f5d6b566
--- /dev/null
+++ b/examples/community/ip_adapter_face_id.py
@@ -0,0 +1,1083 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
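+
+# Community pipeline: IP-Adapter FaceID for Stable Diffusion. Instead of encoding
+# a reference image with a CLIP image encoder, the pipeline consumes precomputed
+# face embeddings (e.g. from insightface) passed via `image_embeds`; see the
+# "IP Adapter Face ID" section of examples/community/README.md for usage.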
+ +import inspect +from typing import Any, Callable, Dict, List, Optional, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from packaging import version +from safetensors import safe_open +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.configuration_utils import FrozenDict +from diffusers.image_processor import VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.models.attention_processor import ( + AttnProcessor, + AttnProcessor2_0, + IPAdapterAttnProcessor, + IPAdapterAttnProcessor2_0, +) +from diffusers.models.embeddings import MultiIPAdapterImageProjection +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + _get_model_file, + deprecate, + logging, + scale_lora_layers, + unscale_lora_layers, +) +from diffusers.utils.torch_utils import randn_tensor + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): + super().__init__() + from diffusers.models.attention import FeedForward + + self.num_tokens = num_tokens + self.cross_attention_dim = cross_attention_dim + self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. 
+        device (`str` or `torch.device`, *optional*):
+            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
+        timesteps (`List[int]`, *optional*):
+            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+            must be `None`.
+
+    Returns:
+        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+        second element is the number of inference steps.
+    """
+    if timesteps is not None:
+        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accepts_timesteps:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" timestep schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
+    else:
+        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+    return timesteps, num_inference_steps
+
+
+class IPAdapterFaceIDStableDiffusionPipeline(
+    DiffusionPipeline,
+    StableDiffusionMixin,
+    TextualInversionLoaderMixin,
+    LoraLoaderMixin,
+    IPAdapterMixin,
+    FromSingleFileMixin,
+):
+    r"""
+    Pipeline for text-to-image generation using Stable Diffusion.
+
+    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+    implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+    The pipeline also inherits the following loading methods:
+        - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+        - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+        - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+        - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+        - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+    Args:
+        vae ([`AutoencoderKL`]):
+            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+        text_encoder ([`~transformers.CLIPTextModel`]):
+            Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+        tokenizer ([`~transformers.CLIPTokenizer`]):
+            A `CLIPTokenizer` to tokenize text.
+        unet ([`UNet2DConditionModel`]):
+            A `UNet2DConditionModel` to denoise the encoded image latents.
+        scheduler ([`SchedulerMixin`]):
+            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+        safety_checker ([`StableDiffusionSafetyChecker`]):
+            Classification module that estimates whether generated images could be considered offensive or harmful.
+            Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+            about a model's potential harms.
+        feature_extractor ([`~transformers.CLIPImageProcessor`]):
+            A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
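+        image_encoder ([`~transformers.CLIPVisionModelWithProjection`], *optional*):
+            Pre-trained CLIP vision model used to obtain image features. Not required by this pipeline, which
+            instead consumes pre-computed face `image_embeds` (e.g. from insightface) at call time.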
+    """
+
+    model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+    _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+    _exclude_from_cpu_offload = ["safety_checker"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+    def __init__(
+        self,
+        vae: AutoencoderKL,
+        text_encoder: CLIPTextModel,
+        tokenizer: CLIPTokenizer,
+        unet: UNet2DConditionModel,
+        scheduler: KarrasDiffusionSchedulers,
+        safety_checker: StableDiffusionSafetyChecker,
+        feature_extractor: CLIPImageProcessor,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        requires_safety_checker: bool = True,
+    ):
+        super().__init__()
+
+        if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+                f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+                "to update the config accordingly, as leaving `steps_offset` might lead to incorrect results"
+                " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+                " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+                " file"
+            )
+            deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["steps_offset"] = 1
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+            deprecation_message = (
+                f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+                " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+                " config accordingly, as not setting `clip_sample` in the config might lead to incorrect results in"
+                " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+                " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+            )
+            deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+            new_config = dict(scheduler.config)
+            new_config["clip_sample"] = False
+            scheduler._internal_dict = FrozenDict(new_config)
+
+        if safety_checker is None and requires_safety_checker:
+            logger.warning(
+                f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+                " that you abide by the conditions of the Stable Diffusion license and do not expose unfiltered"
+                " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+                " strongly recommend keeping the safety filter enabled in all public-facing circumstances, disabling"
+                " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+                " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+            )
+
+        if safety_checker is not None and feature_extractor is None:
+            raise ValueError(
+                f"Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+                " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + def load_ip_adapter_face_id(self, pretrained_model_name_or_path_or_dict, weight_name, **kwargs): + cache_dir = kwargs.pop("cache_dir", None) + force_download = kwargs.pop("force_download", False) + resume_download = kwargs.pop("resume_download", False) + proxies = kwargs.pop("proxies", None) + local_files_only = kwargs.pop("local_files_only", None) + token = kwargs.pop("token", None) + revision = kwargs.pop("revision", None) + subfolder = kwargs.pop("subfolder", None) + + user_agent = { + "file_type": "attn_procs_weights", + "framework": "pytorch", + } + model_file = _get_model_file( + pretrained_model_name_or_path_or_dict, + weights_name=weight_name, + cache_dir=cache_dir, + force_download=force_download, + resume_download=resume_download, + proxies=proxies, + local_files_only=local_files_only, + token=token, + revision=revision, + subfolder=subfolder, + user_agent=user_agent, + ) + if weight_name.endswith(".safetensors"): + state_dict = {"image_proj": {}, "ip_adapter": {}} + with safe_open(model_file, framework="pt", device="cpu") as f: + for key in f.keys(): + if key.startswith("image_proj."): + state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key) + elif key.startswith("ip_adapter."): + state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key) + else: + state_dict = torch.load(model_file, map_location="cpu") + self._load_ip_adapter_weights(state_dict) + + def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): + updated_state_dict = {} + clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] + clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] + multiplier = clip_embeddings_dim_out // 
clip_embeddings_dim_in
+        norm_layer = "norm.weight"
+        cross_attention_dim = state_dict[norm_layer].shape[0]
+        num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
+
+        image_projection = IPAdapterFullImageProjection(
+            cross_attention_dim=cross_attention_dim,
+            image_embed_dim=clip_embeddings_dim_in,
+            mult=multiplier,
+            num_tokens=num_tokens,
+        )
+
+        for key, value in state_dict.items():
+            diffusers_name = key.replace("proj.0", "ff.net.0.proj")
+            diffusers_name = diffusers_name.replace("proj.2", "ff.net.2")
+            updated_state_dict[diffusers_name] = value
+
+        image_projection.load_state_dict(updated_state_dict)
+        return image_projection
+
+    def _load_ip_adapter_weights(self, state_dict):
+        num_image_text_embeds = 4
+
+        self.unet.encoder_hid_proj = None
+
+        # set ip-adapter cross-attention processors & load state_dict
+        attn_procs = {}
+        lora_dict = {}
+        key_id = 0
+        for name in self.unet.attn_processors.keys():
+            cross_attention_dim = None if name.endswith("attn1.processor") else self.unet.config.cross_attention_dim
+            if name.startswith("mid_block"):
+                hidden_size = self.unet.config.block_out_channels[-1]
+            elif name.startswith("up_blocks"):
+                block_id = int(name[len("up_blocks.")])
+                hidden_size = list(reversed(self.unet.config.block_out_channels))[block_id]
+            elif name.startswith("down_blocks"):
+                block_id = int(name[len("down_blocks.")])
+                hidden_size = self.unet.config.block_out_channels[block_id]
+            if cross_attention_dim is None or "motion_modules" in name:
+                attn_processor_class = (
+                    AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                )
+                attn_procs[name] = attn_processor_class()
+            else:
+                attn_processor_class = (
+                    IPAdapterAttnProcessor2_0
+                    if hasattr(F, "scaled_dot_product_attention")
+                    else IPAdapterAttnProcessor
+                )
+                attn_procs[name] = attn_processor_class(
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                    scale=1.0,
+                    num_tokens=num_image_text_embeds,
+                ).to(dtype=self.dtype, device=self.device)
+
+                value_dict = {}
+                value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]})
+                value_dict.update({"to_v_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]})
+                attn_procs[name].load_state_dict(value_dict)
+
+            # both branches carry the same set of LoRA weights, so collect them in one place
+            for lora_key in ("to_q_lora", "to_k_lora", "to_v_lora", "to_out_lora"):
+                for direction in ("down", "up"):
+                    lora_dict[f"unet.{name}.{lora_key}.{direction}.weight"] = state_dict["ip_adapter"][
+                        f"{key_id}.{lora_key}.{direction}.weight"
+                    ]
+            key_id += 1
+
+        self.unet.set_attn_processor(attn_procs)
+
+        self.load_lora_weights(lora_dict, adapter_name="faceid")
+        self.set_adapters(["faceid"], adapter_weights=[1.0])
+
+        # convert IP-Adapter Image Projection layers to diffusers
+        image_projection = self.convert_ip_adapter_image_proj_to_diffusers(state_dict["image_proj"])
+        image_projection_layers = [image_projection.to(device=self.device, dtype=self.dtype)]
+
+        self.unet.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers)
+        self.unet.config.encoder_hid_dim_type = "ip_image_proj"
+
+    def set_ip_adapter_scale(self, scale):
+        unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet
+        for attn_processor in unet.attn_processors.values():
+            if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
+                attn_processor.scale = [scale]
+
+    def _encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        **kwargs,
+    ):
+        deprecation_message = "`_encode_prompt()` is deprecated and will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+        deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+        prompt_embeds_tuple = self.encode_prompt(
+            prompt=prompt,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            lora_scale=lora_scale,
+            **kwargs,
+        )
+
+        # concatenate for backwards compatibility
+        prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+        return prompt_embeds
+
+    def encode_prompt(
+        self,
+        prompt,
+        device,
+        num_images_per_prompt,
+        do_classifier_free_guidance,
+        negative_prompt=None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            do_classifier_free_guidance (`bool`):
+                whether to use classifier free guidance or not
+            negative_prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead. 
Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+                prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        # get unconditional embeddings for classifier free guidance
+        if do_classifier_free_guidance and negative_prompt_embeds is None:
+            uncond_tokens: List[str]
+            if negative_prompt is None:
+                uncond_tokens = [""] * batch_size
+            elif prompt is not None and type(prompt) is not type(negative_prompt):
+                raise TypeError(
+                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
+                    f" {type(prompt)}."
+                )
+            elif isinstance(negative_prompt, str):
+                uncond_tokens = [negative_prompt]
+            elif batch_size != len(negative_prompt):
+                raise ValueError(
+                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                    " the batch size of `prompt`."
+                )
+            else:
+                uncond_tokens = negative_prompt
+
+            # textual inversion: process multi-vector tokens if necessary
+            if isinstance(self, TextualInversionLoaderMixin):
+                uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+            max_length = prompt_embeds.shape[1]
+            uncond_input = self.tokenizer(
+                uncond_tokens,
+                padding="max_length",
+                max_length=max_length,
+                truncation=True,
+                return_tensors="pt",
+            )
+
+            if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+                attention_mask = uncond_input.attention_mask.to(device)
+            else:
+                attention_mask = None
+
+            negative_prompt_embeds = self.text_encoder(
+                uncond_input.input_ids.to(device),
+                attention_mask=attention_mask,
+            )
+            negative_prompt_embeds = negative_prompt_embeds[0]
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+            # Retrieve the original scale by scaling back the LoRA layers
+            unscale_lora_layers(self.text_encoder, lora_scale)
+
+        return prompt_embeds, negative_prompt_embeds
+
+    def run_safety_checker(self, image, device, dtype):
+        if self.safety_checker is None:
+            has_nsfw_concept = None
+        else:
+            if torch.is_tensor(image):
+                feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+            else:
+                feature_extractor_input = self.image_processor.numpy_to_pil(image)
+            safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+            image, has_nsfw_concept = self.safety_checker(
+                images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+            )
+        return image, has_nsfw_concept
+
+    def 
decode_latents(self, latents): + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) + + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents, return_dict=False)[0] + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + callback_on_step_end_tensor_inputs=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." 
+                )
+
+    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
+        if latents is None:
+            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+        else:
+            latents = latents.to(device)
+
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma
+        return latents
+
+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values at which to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def clip_skip(self):
+        return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
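+    # i.e. noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)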
+ @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + @property + def interrupt(self): + return self._interrupt + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + image_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. 
Can be used to tweak the same generation with different prompts. If not provided, a latents
+                tensor is generated by sampling using the supplied random `generator`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+                provided, text embeddings are generated from the `prompt` input argument.
+            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+                not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            image_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated image embeddings.
+            output_type (`str`, *optional*, defaults to `"pil"`):
+                The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+                plain tuple.
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            guidance_rescale (`float`, *optional*, defaults to 0.0):
+                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
+                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
+                using zero terminal SNR.
+            clip_skip (`int`, *optional*):
+                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+                the output of the pre-final layer will be used for computing the prompt embeddings.
+            callback_on_step_end (`Callable`, *optional*):
+                A function that is called at the end of each denoising step during inference. The function is called
+                with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+                callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+                `callback_on_step_end_tensor_inputs`.
+            callback_on_step_end_tensor_inputs (`List`, *optional*):
+                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+                will be passed as the `callback_kwargs` argument. You will only be able to include variables listed in
+                the `._callback_tensor_inputs` attribute of your pipeline class.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images and the
+                second element is a list of `bool`s indicating whether the corresponding generated image contains
+                "not-safe-for-work" (nsfw) content.
+        """
+
+        callback = kwargs.pop("callback", None)
+        callback_steps = kwargs.pop("callback_steps", None)
+
+        if callback is not None:
+            deprecate(
+                "callback",
+                "1.0.0",
+                "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+        if callback_steps is not None:
+            deprecate(
+                "callback_steps",
+                "1.0.0",
+                "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+            )
+
+        # 0. 
Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + # to deal with lora scaling and other possible forward hooks + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + height, + width, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._guidance_rescale = guidance_rescale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + self._interrupt = False + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # 3. Encode input prompt + lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + clip_skip=self.clip_skip, + ) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if image_embeds is not None: + image_embeds = torch.stack([image_embeds] * num_images_per_prompt, dim=0).to( + device=device, dtype=prompt_embeds.dtype + ) + negative_image_embeds = torch.zeros_like(image_embeds) + if self.do_classifier_free_guidance: + image_embeds = torch.cat([negative_image_embeds, image_embeds]) + image_embeds = [image_embeds] + # 4. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 6.1 Add image embeds for IP-Adapter + added_cond_kwargs = {"image_embeds": image_embeds} if image_embeds is not None else {} + + # 6.2 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. 
Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + self._num_timesteps = len(timesteps) + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if self.interrupt: + continue + + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + if self.do_classifier_free_guidance and self.guidance_rescale > 0.0: + # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf + noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + if not output_type == "latent": + image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False, generator=generator)[ + 0 + ] + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + else: + image = latents + has_nsfw_concept = None + + if has_nsfw_concept is None: + do_denormalize = [True] * image.shape[0] + else: + do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept] + + image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image, has_nsfw_concept) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) From aad463c453d85fc0d7a2c210f0dc3e1bf1c94366 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 13 Mar 2024 18:27:32 +0100 Subject: [PATCH 11/29] Revert copies --- .../animatediff/pipeline_animatediff.py | 51 +++++++------------ .../pipeline_animatediff_video2video.py | 51 +++++++------------ .../controlnet/pipeline_controlnet.py | 49 +++++++----------- .../controlnet/pipeline_controlnet_img2img.py | 49 +++++++----------- .../controlnet/pipeline_controlnet_inpaint.py | 47 +++++++---------- .../pipeline_controlnet_inpaint_sd_xl.py | 47 +++++++---------- .../controlnet/pipeline_controlnet_sd_xl.py | 49 +++++++----------- 
.../pipeline_controlnet_sd_xl_img2img.py | 49 +++++++----------- .../pipeline_latent_consistency_img2img.py | 49 +++++++----------- .../pipeline_latent_consistency_text2img.py | 51 +++++++------------ src/diffusers/pipelines/pia/pipeline_pia.py | 51 +++++++------------ .../pipeline_stable_diffusion_img2img.py | 49 +++++++----------- .../pipeline_stable_diffusion_inpaint.py | 51 +++++++------------ ...eline_stable_diffusion_instruct_pix2pix.py | 49 +++++++----------- .../pipeline_stable_diffusion_ldm3d.py | 51 +++++++------------ .../pipeline_stable_diffusion_panorama.py | 51 +++++++------------ .../pipeline_stable_diffusion_safe.py | 51 +++++++------------ .../pipeline_stable_diffusion_sag.py | 51 +++++++------------ .../pipeline_stable_diffusion_xl_img2img.py | 49 +++++++----------- .../pipeline_stable_diffusion_xl_inpaint.py | 51 +++++++------------ .../pipeline_stable_diffusion_xl_adapter.py | 51 +++++++------------ 21 files changed, 387 insertions(+), 660 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index a93003c1bd62..cd7f0a283b63 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -345,40 +345,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = 
image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index eaad150f2fc6..cb6b71351faf 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -423,40 +423,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git 
a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index 039d50a104a5..8f31dfc2678a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -455,40 +455,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index ecc8cb0554a7..9d2c76fd7483 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -448,40 +448,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, 
output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 67cf76d0eb7a..c4f1bff5efcd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -573,40 +573,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype + dtype = next(self.image_encoder.parameters()).dtype - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = 
self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 6e2f55a1d1c7..52ffe5a3f356 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -482,40 +482,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype + dtype = next(self.image_encoder.parameters()).dtype - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds 
= self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 114e3bd056f5..eca81083be7b 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -460,40 +460,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = 
image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index e71acc3ad412..86a0e2c570d8 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -512,40 +512,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + 
num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index 411caec36097..f64854ea982b 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -398,40 +398,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py index 349f5f58d9a4..e9bacaa89ba5 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py @@ -382,40 +382,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 688b826351c2..507088991a5e 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -470,40 +470,27 @@ def encode_prompt( # Copied from 
diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py index 2b305afefea9..b43e0eb2abcd 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py @@ -509,40 +509,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - 
image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py index 93df3c1362ba..221d5c2cfd3f 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py @@ -581,40 +581,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return 
image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py index 06efcca6bb9e..89d4278937fe 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py @@ -632,40 +632,27 @@ def _encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is 
loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker diff --git a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py index 2dc97079416d..c7c05feaf013 100644 --- a/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +++ b/src/diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py @@ -448,40 +448,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = 
uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py index 5f98e36cb9d9..feda710e0049 100644 --- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py @@ -359,40 +359,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, 
dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py index 309aa2ff168b..ae74e09678e3 100644 --- a/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +++ b/src/diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py @@ -476,40 +476,27 @@ def perform_safety_guidance( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds @torch.no_grad() diff --git a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py index 
4b09279f6f45..96aa006d2ab3 100644 --- a/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +++ b/src/diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py @@ -378,40 +378,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 1ac018b8f668..fd4c412f48cb 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -727,40 +727,27 @@ def prepare_latents( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if 
self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds - else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") + dtype = next(self.image_encoder.parameters()).dtype - if image.ndim < 2: - image = image.unsqueeze(0) + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values - image_embeds = image.to(device=device, dtype=dtype) + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 11692fae62b7..c25628c22c7b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -431,40 +431,27 @@ def __init__( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = 
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index e09639613096..4e0cc61f5c1d 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -508,40 +508,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = 
image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds From e3c35188d2d995adb384c56acd1de0b653872db6 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 13 Mar 2024 18:47:35 +0100 Subject: [PATCH 12/29] Revert encode_image and loading warning --- src/diffusers/loaders/ip_adapter.py | 4 +- .../pipeline_stable_diffusion.py | 51 +++++++------------ .../pipeline_stable_diffusion_xl.py | 51 +++++++------------ 3 files changed, 40 insertions(+), 66 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index c0fe1d5c1a90..d511171e8155 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -214,8 +214,8 @@ def load_ip_adapter( ) else: logger.warning( - "image_encoder is not loaded since `image_encoder_folder=None` passed. `ip_adapter_image` is allowed only if you are loading an IP-Adapter Face ID model." - "If you don't load an IP Adapter Face ID model, always use `ip_adapter_image_embeds` to pass pre-generated image embedding instead." + "image_encoder is not loaded since `image_encoder_folder=None` passed. You will not be able to use `ip_adapter_image` when calling the pipeline with IP-Adapter. " + "Use `ip_adapter_image_embeds` to pass pre-generated image embedding instead."
) # create feature extractor if it has not been registered to the pipeline yet diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py index 38bbde58ae44..9e4e6c186ffa 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py @@ -466,40 +466,27 @@ def encode_prompt( return prompt_embeds, negative_prompt_embeds def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds def prepare_ip_adapter_image_embeds( diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 5069359004f6..776696e9d486 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -494,40 +494,27 @@ def encode_prompt( # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image def encode_image(self, image, device, num_images_per_prompt, 
output_hidden_states=None): - if self.image_encoder is not None: - dtype = next(self.image_encoder.parameters()).dtype - - if not isinstance(image, torch.Tensor): - image = self.feature_extractor(image, return_tensors="pt").pixel_values - - image = image.to(device=device, dtype=dtype) - if output_hidden_states: - image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] - image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_enc_hidden_states = self.image_encoder( - torch.zeros_like(image), output_hidden_states=True - ).hidden_states[-2] - uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( - num_images_per_prompt, dim=0 - ) - return image_enc_hidden_states, uncond_image_enc_hidden_states - else: - image_embeds = self.image_encoder(image).image_embeds - image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) - uncond_image_embeds = torch.zeros_like(image_embeds) - - return image_embeds, uncond_image_embeds + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states else: - dtype = next(self.unet.parameters()).dtype - - if not isinstance(image, torch.Tensor): - raise ValueError("When no image encoder is loaded, `image` must be a torch.Tensor") - - if image.ndim < 2: - image = image.unsqueeze(0) - - image_embeds = image.to(device=device, dtype=dtype) + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) uncond_image_embeds = torch.zeros_like(image_embeds) + return image_embeds, uncond_image_embeds # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds From b8fe711c210ed8e5026fe0d9faa77b15f7cfa7dd Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 13 Mar 2024 19:11:54 +0100 Subject: [PATCH 13/29] Add a separate loop to load lora weights --- src/diffusers/loaders/ip_adapter.py | 11 +- src/diffusers/loaders/unet.py | 188 ++++++++++------------------ 2 files changed, 71 insertions(+), 128 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index d511171e8155..01ce666689bc 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -225,14 +225,15 @@ def load_ip_adapter( # load ip-adapter into unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - extra_lora = unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + extra_loras = unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) - if extra_lora != {}: + if extra_loras != {}: # apply the IP Adapter Face ID LoRA weights peft_config = getattr(unet, "peft_config", {}) - if "faceid" not in 
peft_config: - self.load_lora_weights(extra_lora, adapter_name="faceid") - self.set_adapters(["faceid"], adapter_weights=[1.0]) + for k, lora in extra_loras.items(): + if f"faceid_{k}" not in peft_config: + self.load_lora_weights(lora, adapter_name=f"faceid_{k}") + self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0]) def set_ip_adapter_scale(self, scale): """ diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index ec55298d1679..28bf6b778fa7 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -823,11 +823,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F # set ip-adapter cross-attention processors & load state_dict attn_procs = {} - lora_dict = {} key_id = 1 - for state_dict in state_dicts: - if "0.to_k_lora.down.weight" in state_dict["ip_adapter"]: - key_id = 0 init_context = init_empty_weights if low_cpu_mem_usage else nullcontext for name in self.attn_processors.keys(): cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim @@ -846,66 +842,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F ) attn_procs[name] = attn_processor_class() - for state_dict in state_dicts: - if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: - lora_dict.update( - { - f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_k_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_q_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_v_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_out_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_k_lora.up.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_q_lora.up.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_v_lora.up.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_out_lora.up.weight" - ] - } - ) - key_id += 1 - break else: attn_processor_class = ( IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor @@ -937,63 +873,6 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F for i, state_dict in enumerate(state_dicts): value_dict.update({f"to_k_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) value_dict.update({f"to_v_ip.{i}.weight": state_dict["ip_adapter"][f"{key_id}.to_v_ip.weight"]}) - if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: - lora_dict.update( - { - f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_k_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_q_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_v_lora.down.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ - f"{key_id}.to_out_lora.down.weight" - ] - } - ) - 
lora_dict.update( - { - f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_k_lora.up.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_q_lora.up.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_v_lora.up.weight" - ] - } - ) - lora_dict.update( - { - f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_out_lora.up.weight" - ] - } - ) if not low_cpu_mem_usage: attn_procs[name].load_state_dict(value_dict) @@ -1002,9 +881,72 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F dtype = next(iter(value_dict.values())).dtype load_model_dict_into_meta(attn_procs[name], value_dict, device=device, dtype=dtype) - key_id += 2 if "0.to_k_lora.down.weight" not in state_dict["ip_adapter"] else 1 + key_id += 2 + + lora_dicts = {} + for key_id, name in enumerate(self.attn_processors.keys()): + for i, state_dict in enumerate(state_dicts): + if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: + if i not in lora_dicts: + lora_dicts[i] = {} + lora_dicts[i].update( + { + f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_k_lora.up.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_q_lora.up.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_v_lora.up.weight" + ] + } + ) + lora_dicts[i].update( + { + f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.up.weight" + ] + } + ) - return attn_procs, lora_dict + return attn_procs, lora_dicts def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): if not isinstance(state_dicts, list): From 3c9382d87aba9609870917ff3cad595467faa5bd Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 13 Mar 2024 22:21:49 +0100 Subject: [PATCH 14/29] Fix style --- examples/community/ip_adapter_face_id.py | 77 ++++++++++++++++++------ src/diffusers/loaders/unet.py | 18 +----- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/examples/community/ip_adapter_face_id.py b/examples/community/ip_adapter_face_id.py index 3bb1f5d6b566..f644c287be04 100644 --- a/examples/community/ip_adapter_face_id.py +++ b/examples/community/ip_adapter_face_id.py @@ -328,7 +328,6 @@ def convert_ip_adapter_image_proj_to_diffusers(self, state_dict): return image_projection def _load_ip_adapter_weights(self, state_dict): - num_image_text_embeds = 4 self.unet.encoder_hid_proj = None @@ -353,20 +352,38 @@ def _load_ip_adapter_weights(self, state_dict): ) attn_procs[name] = attn_processor_class() - lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) - 
lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) + lora_dict.update( + {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]} + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dict.update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]} + ) key_id += 1 else: attn_processor_class = ( - IPAdapterAttnProcessor2_0 - if hasattr(F, "scaled_dot_product_attention") - else IPAdapterAttnProcessor + IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor ) attn_procs[name] = attn_processor_class( hidden_size=hidden_size, @@ -375,14 +392,34 @@ def _load_ip_adapter_weights(self, state_dict): num_tokens=num_image_text_embeds, ).to(dtype=self.dtype, device=self.device) - lora_dict.update({f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.down.weight"]}) - lora_dict.update({f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]}) - lora_dict.update({f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]}) + lora_dict.update( + {f"unet.{name}.to_k_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.down.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.down.weight"]} + ) + lora_dict.update( + 
{f"unet.{name}.to_v_lora.down.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.down.weight"]} + ) + lora_dict.update( + { + f"unet.{name}.to_out_lora.down.weight": state_dict["ip_adapter"][ + f"{key_id}.to_out_lora.down.weight" + ] + } + ) + lora_dict.update( + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} + ) + lora_dict.update( + {f"unet.{name}.to_out_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_out_lora.up.weight"]} + ) value_dict = {} value_dict.update({"to_k_ip.0.weight": state_dict["ip_adapter"][f"{key_id}.to_k_ip.weight"]}) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 81edc0bdd46e..529f9f925bff 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -923,25 +923,13 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F } ) lora_dicts[i].update( - { - f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_k_lora.up.weight" - ] - } + {f"unet.{name}.to_k_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_k_lora.up.weight"]} ) lora_dicts[i].update( - { - f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_q_lora.up.weight" - ] - } + {f"unet.{name}.to_q_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_q_lora.up.weight"]} ) lora_dicts[i].update( - { - f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][ - f"{key_id}.to_v_lora.up.weight" - ] - } + {f"unet.{name}.to_v_lora.up.weight": state_dict["ip_adapter"][f"{key_id}.to_v_lora.up.weight"]} ) lora_dicts[i].update( { From c6f106e746845951786285a4ef9e86601d2e59ee Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sat, 23 Mar 2024 15:24:42 +0100 Subject: [PATCH 15/29] Split Full and Face ID blocks --- src/diffusers/loaders/unet.py | 38 +++++++++++++++++++++--------- src/diffusers/models/embeddings.py | 20 +++++++++++----- 2 files changed, 41 insertions(+), 17 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 529f9f925bff..757474602da5 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -27,6 +27,7 @@ from ..models.embeddings import ( ImageProjection, + IPAdapterFaceIDImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, @@ -732,17 +733,33 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us diffusers_name = key.replace("proj", "image_embeds") updated_state_dict[diffusers_name] = value - elif "proj.0.weight" in state_dict: - # IP-Adapter Full and Face ID + elif "proj.3.weight" in state_dict: + # IP-Adapter Full + clip_embeddings_dim = state_dict["proj.0.weight"].shape[0] + cross_attention_dim = state_dict["proj.3.weight"].shape[0] + + with init_context(): + image_projection = IPAdapterFullImageProjection( + cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("proj.0", "ff.net.0.proj") + diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") + diffusers_name = diffusers_name.replace("proj.3", "norm") + updated_state_dict[diffusers_name] = value + + elif "norm.weight" in state_dict: + # IP-Adapter Face ID 
clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0] multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in - norm_layer = "norm.weight" if "norm.weight" in state_dict else "proj.3.weight" + norm_layer = "norm.weight" cross_attention_dim = state_dict[norm_layer].shape[0] num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim with init_context(): - image_projection = IPAdapterFullImageProjection( + image_projection = IPAdapterFaceIDImageProjection( cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim_in, mult=multiplier, @@ -752,7 +769,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us for key, value in state_dict.items(): diffusers_name = key.replace("proj.0", "ff.net.0.proj") diffusers_name = diffusers_name.replace("proj.2", "ff.net.2") - diffusers_name = diffusers_name.replace("proj.3", "norm") updated_state_dict[diffusers_name] = value else: @@ -856,12 +872,12 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F if "proj.weight" in state_dict["image_proj"]: # IP-Adapter num_image_text_embeds += [4] - elif "proj.0.weight" in state_dict["image_proj"]: - # IP-Adapter Full Face and Face ID - if f"{key_id}.to_k_lora.down.weight" in state_dict["ip_adapter"]: - num_image_text_embeds += [4] - else: - num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "proj.3.weight" in state_dict["image_proj"]: + # IP-Adapter Full Face + num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "norm.weight" in state_dict["image_proj"]: + # IP-Adapter Face ID + num_image_text_embeds += [4] else: # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 3cd4d272ceb6..a6473baf7e0b 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -462,6 +462,17 @@ def forward(self, image_embeds: torch.FloatTensor): class IPAdapterFullImageProjection(nn.Module): + def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): + super().__init__() + from .attention import FeedForward + + self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu") + self.norm = nn.LayerNorm(cross_attention_dim) + + def forward(self, image_embeds: torch.FloatTensor): + return self.norm(self.ff(image_embeds)) + +class IPAdapterFaceIDImageProjection(nn.Module): def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() from .attention import FeedForward @@ -472,12 +483,9 @@ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_t self.norm = nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds: torch.FloatTensor): - if self.num_tokens == 4: - x = self.ff(image_embeds) - x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) - return self.norm(x) - else: - return self.norm(self.ff(image_embeds)) + x = self.ff(image_embeds) + x = x.reshape(-1, self.num_tokens, self.cross_attention_dim) + return self.norm(x) class CombinedTimestepLabelEmbeddings(nn.Module): From 21ed0cc3421f1ef525fec53b350565fc55a0f401 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 27 Mar 2024 21:31:27 +0100 Subject: [PATCH 16/29] Load Face ID Plus --- src/diffusers/loaders/unet.py | 56 +++++++++++++++++++++++++++++++++++ 1 file changed, 56 insertions(+) diff --git 
a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 757474602da5..54f92397de1c 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -28,6 +28,7 @@ from ..models.embeddings import ( ImageProjection, IPAdapterFaceIDImageProjection, + IPAdapterFaceIDPlusImageProjection, IPAdapterFullImageProjection, IPAdapterPlusImageProjection, MultiIPAdapterImageProjection, @@ -749,6 +750,58 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us diffusers_name = diffusers_name.replace("proj.3", "norm") updated_state_dict[diffusers_name] = value + elif "perceiver_resampler.proj_in.weight" in state_dict: + # IP-Adapter Face ID Plus + id_embeddings_dim = state_dict["proj.0.weight"].shape[1] + embed_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[0] + hidden_dims = state_dict["perceiver_resampler.proj_in.weight"].shape[1] + output_dims = state_dict["perceiver_resampler.proj_out.weight"].shape[0] + heads = state_dict["perceiver_resampler.layers.0.0.to_q.weight"].shape[0] // 64 + + with init_context(): + image_projection = IPAdapterFaceIDPlusImageProjection( + embed_dims=embed_dims, + output_dims=output_dims, + hidden_dims=hidden_dims, + heads=heads, + id_embeddings_dim=id_embeddings_dim, + ) + + for key, value in state_dict.items(): + diffusers_name = key.replace("perceiver_resampler.", "") + diffusers_name = diffusers_name.replace("0.to", "2.to") + diffusers_name = diffusers_name.replace("0.1.0.", "0.3.0.") + diffusers_name = diffusers_name.replace("0.1.1.", "0.3.1.") + diffusers_name = diffusers_name.replace("1.1.0.", "1.3.0.") + diffusers_name = diffusers_name.replace("1.1.1.", "1.3.1.") + diffusers_name = diffusers_name.replace("2.1.0.", "2.3.0.") + diffusers_name = diffusers_name.replace("2.1.1.", "2.3.1.") + diffusers_name = diffusers_name.replace("3.1.0.", "3.3.0.") + diffusers_name = diffusers_name.replace("3.1.1.", "3.3.1.") + diffusers_name = diffusers_name.replace(".3.1.weight", ".3.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace(".1.3.weight", ".3.1.net.2.weight") + + if "norm1" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value + elif "norm2" in diffusers_name: + updated_state_dict[diffusers_name.replace("0.norm2", "1")] = value + elif "to_kv" in diffusers_name: + v_chunk = value.chunk(2, dim=0) + updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0] + updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1] + elif "to_out" in diffusers_name: + updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value + elif "proj.0.weight" == diffusers_name: + updated_state_dict["proj.net.0.proj.weight"] = value + elif "proj.0.bias" == diffusers_name: + updated_state_dict["proj.net.0.proj.bias"] = value + elif "proj.2.weight" == diffusers_name: + updated_state_dict["proj.net.2.weight"] = value + elif "proj.2.bias" == diffusers_name: + updated_state_dict["proj.net.2.bias"] = value + else: + updated_state_dict[diffusers_name] = value + elif "norm.weight" in state_dict: # IP-Adapter Face ID clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1] @@ -875,6 +928,9 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F elif "proj.3.weight" in state_dict["image_proj"]: # IP-Adapter Full Face num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token + elif "perceiver_resampler.proj_in.weight" in state_dict: + # IP-Adapter Face ID Plus + num_image_text_embeds += [4] elif "norm.weight" in 
state_dict["image_proj"]: # IP-Adapter Face ID num_image_text_embeds += [4] From 50261bfc511464e7768ca8732513fd2d9362408b Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Wed, 27 Mar 2024 21:32:30 +0100 Subject: [PATCH 17/29] Add Face ID Plus proj layers --- src/diffusers/models/embeddings.py | 98 ++++++++++++++++++++++++++++++ 1 file changed, 98 insertions(+) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index a6473baf7e0b..d6e18585c619 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -894,6 +894,104 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return self.norm_out(latents) +class IPAdapterFaceIDPlusImageProjection(nn.Module): + """FacePerceiverResampler of IP-Adapter Plus. + + Args: + ---- + embed_dims (int): The feature dimension. Defaults to 768. + output_dims (int): The number of output channels, that is the same + number of the channels in the + `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): The number of hidden channels. Defaults to 1280. + depth (int): The number of blocks. Defaults to 8. + dim_head (int): The number of head channels. Defaults to 64. + heads (int): Parallel attention heads. Defaults to 16. + num_queries (int): The number of queries. Defaults to 8. + ffn_ratio (float): The expansion ratio of feedforward network hidden + layer channels. Defaults to 4. + """ + + def __init__( + self, + embed_dims: int = 768, + output_dims: int = 768, + hidden_dims: int = 1280, + id_embeddings_dim = 512, + depth: int = 4, + dim_head: int = 64, + heads: int = 16, + num_tokens=4, + num_queries: int = 8, + ffn_ratio: float = 4, + ffproj_ratio: int = 2, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.num_tokens = num_tokens + self.embed_dim = embed_dims + self.clip_embeds = None + + self.proj = FeedForward(id_embeddings_dim, embed_dims*num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.norm = nn.LayerNorm(embed_dims) + + self.proj_in = nn.Linear(hidden_dims, embed_dims) + + self.proj_out = nn.Linear(embed_dims, output_dims) + self.norm_out = nn.LayerNorm(output_dims) + + self.layers = nn.ModuleList([]) + for _ in range(depth): + self.layers.append( + nn.ModuleList( + [ + nn.LayerNorm(embed_dims), + nn.LayerNorm(embed_dims), + Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ), + nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ), + ] + ) + ) + + def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: + """Forward pass. + + Args: + ---- + id_embeds (torch.Tensor): Input Tensor (ID embeds). + + Returns: + ------- + torch.Tensor: Output Tensor. 
+ """ + id_embeds = self.proj(id_embeds) + id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) + latents = self.norm(id_embeds) + + clip_embeds = self.proj_in(self.clip_embeds) + x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) + + for ln0, ln1, attn, ff in self.layers: + + encoder_hidden_states = ln0(x) + latents = ln1(latents) + latents = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = attn(latents, encoder_hidden_states) + latents + latents = ff(latents) + latents + + latents = self.proj_out(latents) + return self.norm_out(latents) + + class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() From 66f911710bb7c2b87185b45fcfae1cfb4c48eb09 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Sun, 7 Apr 2024 12:00:54 +0200 Subject: [PATCH 18/29] Bugfixes + add shortcut --- src/diffusers/loaders/unet.py | 2 +- src/diffusers/models/embeddings.py | 21 ++++++++++++++++----- 2 files changed, 17 insertions(+), 6 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 54f92397de1c..744243ea8362 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -928,7 +928,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F elif "proj.3.weight" in state_dict["image_proj"]: # IP-Adapter Full Face num_image_text_embeds += [257] # 256 CLIP tokens + 1 CLS token - elif "perceiver_resampler.proj_in.weight" in state_dict: + elif "perceiver_resampler.proj_in.weight" in state_dict["image_proj"]: # IP-Adapter Face ID Plus num_image_text_embeds += [4] elif "norm.weight" in state_dict["image_proj"]: diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index d6e18585c619..bff987f5cd64 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -907,9 +907,12 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module): depth (int): The number of blocks. Defaults to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. Defaults to 16. + num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. + ffproj_ratio (float): The expansion ratio of feedforward network hidden + layer channels (for ID embeddings). Defaults to 4. """ def __init__( @@ -921,7 +924,7 @@ def __init__( depth: int = 4, dim_head: int = 64, heads: int = 16, - num_tokens=4, + num_tokens: int = 4, num_queries: int = 8, ffn_ratio: float = 4, ffproj_ratio: int = 2, @@ -932,6 +935,8 @@ def __init__( self.num_tokens = num_tokens self.embed_dim = embed_dims self.clip_embeds = None + self.shortcut = False + self.shortcut_scale = 1.0 self.proj = FeedForward(id_embeddings_dim, embed_dims*num_tokens, activation_fn="gelu", mult=ffproj_ratio) self.norm = nn.LayerNorm(embed_dims) @@ -973,23 +978,29 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: ------- torch.Tensor: Output Tensor. 
""" + id_embeds = id_embeds.to(self.clip_embeds.dtype) id_embeds = self.proj(id_embeds) id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) - latents = self.norm(id_embeds) + id_embeds = self.norm(id_embeds) + latents = id_embeds clip_embeds = self.proj_in(self.clip_embeds) x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) for ln0, ln1, attn, ff in self.layers: + residual = latents encoder_hidden_states = ln0(x) latents = ln1(latents) - latents = torch.cat([encoder_hidden_states, latents], dim=-2) - latents = attn(latents, encoder_hidden_states) + latents + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = attn(latents, encoder_hidden_states) + residual latents = ff(latents) + latents latents = self.proj_out(latents) - return self.norm_out(latents) + out = self.norm_out(latents) + if self.shortcut: + out = id_embeds + self.shortcut_scale * out + return out class MultiIPAdapterImageProjection(nn.Module): From b9344895bd6addff40a30a75745b367f2137a786 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Thu, 11 Apr 2024 21:56:17 +0200 Subject: [PATCH 19/29] Update docs --- docs/source/en/using-diffusers/ip_adapter.md | 56 +++++++++++++++++-- .../en/using-diffusers/loading_adapters.md | 37 ++++++++++++ examples/community/README.md | 2 - 3 files changed, 89 insertions(+), 6 deletions(-) diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index 4ae403538d2b..e2e6a0812cb9 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -362,14 +362,12 @@ IP-Adapter's image prompting and compatibility with other adapters and models ma ### Face model -Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces: +Generating accurate faces is challenging because they are complex and nuanced. Diffusers supports two IP-Adapter checkpoints specifically trained to generate faces from the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repository: * [ip-adapter-full-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-full-face_sd15.safetensors) is conditioned with images of cropped faces and removed backgrounds * [ip-adapter-plus-face_sd15.safetensors](https://huggingface.co/h94/IP-Adapter/blob/main/models/ip-adapter-plus-face_sd15.safetensors) uses patch embeddings and is conditioned with images of cropped faces -> [!TIP] -> -> [IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) is a face-specific IP-Adapter trained with face ID embeddings instead of CLIP image embeddings, allowing you to generate more consistent faces in different contexts and styles. Try out this popular [community pipeline](https://github.com/huggingface/diffusers/tree/main/examples/community#ip-adapter-face-id) and see how it compares to the other face IP-Adapters. +Additionally, Diffusers supports all IP-Adapter checkpoints trained with face embeddings extracted by `insightface` face models. Supported models are from the [h94/IP-Adapter-FaceID](https://huggingface.co/h94/IP-Adapter-FaceID) repository. For face models, use the [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) checkpoint. It is also recommended to use [`DDIMScheduler`] or [`EulerDiscreteScheduler`] for face models. @@ -411,6 +409,56 @@ image +To use IP-Adapter FaceID models, first extract face embeddings with `insightface`. 
Then pass the list of tensors to the pipeline as `ip_adapter_image_embeds`.
+
+```py
+import cv2
+import numpy as np
+import torch
+from diffusers import StableDiffusionPipeline, DDIMScheduler
+from diffusers.utils import load_image
+from insightface.app import FaceAnalysis
+
+pipeline = StableDiffusionPipeline.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    torch_dtype=torch.float16,
+).to("cuda")
+pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
+pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None)
+pipeline.set_ip_adapter_scale(0.6)
+
+_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png")
+
+ref_images_embeds = []
+app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
+app.prepare(ctx_id=0, det_size=(640, 640))
+image = cv2.cvtColor(np.asarray(_image), cv2.COLOR_BGR2RGB)
+faces = app.get(image)
+image = torch.from_numpy(faces[0].normed_embedding)
+ref_images_embeds.append(image.unsqueeze(0))
+ref_images_embeds = torch.stack(ref_images_embeds, dim=0).unsqueeze(0)
+neg_ref_images_embeds = torch.zeros_like(ref_images_embeds)
+id_embeds = torch.cat([neg_ref_images_embeds, ref_images_embeds]).to(dtype=torch.float16, device="cuda")
+
+generator = torch.Generator(device="cpu").manual_seed(42)
+
+images = pipeline(
+    prompt="A photo of a girl",
+    ip_adapter_image_embeds=[id_embeds],
+    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
+    num_inference_steps=20, num_images_per_prompt=1,
+    generator=generator
+).images
+```
+
+Both IP-Adapter FaceID Plus and Plus v2 models require CLIP image embeddings. You can prepare face embeddings as shown previously, then extract and pass CLIP embeddings to the hidden image projection layers.
+
+```py
+clip_embeds = pipeline.prepare_ip_adapter_image_embeds([ip_adapter_images], None, torch.device("cuda"), num_images, True)[0]
+
+pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = clip_embeds.to(dtype=torch.float16)
+pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = False # True if Plus v2
+```
+

### Multi IP-Adapter

More than one IP-Adapter can be used at the same time to generate specific images in more diverse styles. For example, you can use IP-Adapter-Face to generate consistent faces and characters, and IP-Adapter Plus to generate those faces in a specific style.
diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index b079d2165ece..45abcdbd0a72 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -320,3 +320,40 @@ pipeline = AutoPipelineForText2Image.from_pretrained(
pipeline.load_ip_adapter("h94/IP-Adapter", subfolder="sdxl_models", weight_name="ip-adapter-plus_sdxl_vit-h.safetensors")
```
+
+### IP-Adapter Face ID models
+
+The IP-Adapter FaceID models are experimental IP Adapters that use image embeddings generated by `insightface` instead of CLIP image embeddings. Some of these models also use LoRA to improve ID consistency.
+You need to install `insightface` and all its requirements to use these models.
+
+As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and is not intended for commercial use.
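Because a missing face library otherwise only surfaces once the checkpoint is used, a guarded import makes the `insightface` requirement explicit. A minimal sketch, assuming the `insightface` and `onnxruntime` packages cover the requirements on your platform:

```py
try:
    import insightface  # noqa: F401  # provides the face-embedding models used by FaceID
except ImportError as e:
    raise ImportError(
        "IP-Adapter FaceID checkpoints require `insightface`; "
        "install it with an ONNX runtime, e.g. `pip install insightface onnxruntime`."
    ) from e
```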
+
+
+```py
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-1.0",
+    torch_dtype=torch.float16
+).to("cuda")
+
+pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sdxl.bin", image_encoder_folder=None)
+```
+
+If you want to use one of the two IP-Adapter FaceID Plus models, you must also load the CLIP image encoder, as these models use both `insightface` and CLIP image embeddings to achieve better photorealism.
+
+```py
+from transformers import CLIPVisionModelWithProjection
+
+image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+    "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
+    torch_dtype=torch.float16,
+)
+
+pipeline = AutoPipelineForText2Image.from_pretrained(
+    "runwayml/stable-diffusion-v1-5",
+    image_encoder=image_encoder,
+    torch_dtype=torch.float16
+).to("cuda")
+
+pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plus_sd15.bin")
+```
diff --git a/examples/community/README.md b/examples/community/README.md
index cc471874ca02..5cebc4f9f049 100755
--- a/examples/community/README.md
+++ b/examples/community/README.md
@@ -3819,12 +3819,10 @@ export_to_gif(frames, "animation.gif")
IP Adapter FaceID is an experimental IP Adapter model that uses image embeddings generated by `insightface`, so no image encoder needs to be loaded.
You need to install `insightface` and all its requirements to use this model.
You must pass the image embedding tensor as `image_embeds` to the StableDiffusionPipeline instead of `ip_adapter_image`.
-You have to disable PEFT BACKEND in order to load weights.
You can find more results [here](https://github.com/huggingface/diffusers/pull/6276).

```py
import diffusers
-diffusers.utils.USE_PEFT_BACKEND = False
import torch
from diffusers.utils import load_image
import cv2

From 32d6943476d92c05a9eea7839c1386746265e204 Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Fri, 12 Apr 2024 20:43:41 +0200
Subject: [PATCH 20/29] Update docs/source/en/using-diffusers/loading_adapters.md

Co-authored-by: Sayak Paul
---
 docs/source/en/using-diffusers/loading_adapters.md | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/docs/source/en/using-diffusers/loading_adapters.md b/docs/source/en/using-diffusers/loading_adapters.md
index 45abcdbd0a72..5871823aefe0 100644
--- a/docs/source/en/using-diffusers/loading_adapters.md
+++ b/docs/source/en/using-diffusers/loading_adapters.md
@@ -327,7 +327,7 @@ The IP-Adapter FaceID models are experimental IP Adapters that use image embeddi
You need to install `insightface` and all its requirements to use these models.

-As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and is not intended for commercial use.
+As InsightFace pretrained models are available for non-commercial research purposes, IP-Adapter-FaceID models are released exclusively for research purposes and are not intended for commercial use.
```py From dc0a5fb5b98f5b12cd3f8e80185e7cab3db4b2d9 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 12 Apr 2024 21:23:21 +0200 Subject: [PATCH 21/29] Fix test and docs --- docs/source/en/using-diffusers/ip_adapter.md | 4 ++-- .../ip_adapters/test_ip_adapter_stable_diffusion.py | 13 +++++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/docs/source/en/using-diffusers/ip_adapter.md b/docs/source/en/using-diffusers/ip_adapter.md index e2e6a0812cb9..dc64b2548529 100644 --- a/docs/source/en/using-diffusers/ip_adapter.md +++ b/docs/source/en/using-diffusers/ip_adapter.md @@ -425,12 +425,12 @@ pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config) pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid_sd15.bin", image_encoder_folder=None) pipeline.set_ip_adapter_scale(0.6) -_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png") +image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/ip_mask_girl1.png") ref_images_embeds = [] app = FaceAnalysis(name="buffalo_l", providers=['CUDAExecutionProvider', 'CPUExecutionProvider']) app.prepare(ctx_id=0, det_size=(640, 640)) -image = cv2.cvtColor(np.asarray(_image), cv2.COLOR_BGR2RGB) +image = cv2.cvtColor(np.asarray(image), cv2.COLOR_BGR2RGB) faces = app.get(image) image = torch.from_numpy(faces[0].normed_embedding) ref_images_embeds.append(image.unsqueeze(0)) diff --git a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py index d8d1ff8b5439..ef70baa05f19 100644 --- a/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py +++ b/tests/pipelines/ip_adapters/test_ip_adapter_stable_diffusion.py @@ -321,13 +321,18 @@ def test_text_to_image_face_id(self): pipeline.set_ip_adapter_scale(0.7) inputs = self.get_dummy_inputs() - inputs["ip_adapter_image"] = load_pt( - "https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt" - ) + id_embeds = load_pt("https://huggingface.co/datasets/fabiorigano/testing-images/resolve/main/ai_face2.ipadpt")[ + 0 + ] + id_embeds = id_embeds.reshape((2, 1, 1, 512)) + inputs["ip_adapter_image_embeds"] = [id_embeds] + inputs["ip_adapter_image"] = None images = pipeline(**inputs).images image_slice = images[0, :3, :3, -1].flatten() - expected_slice = np.array([0.1665, 0.1626, 0.2187, 0.1882, 0.1702, 0.2144, 0.1624, 0.2012, 0.2173]) + expected_slice = np.array( + [0.32714844, 0.3239746, 0.3466797, 0.31835938, 0.30004883, 0.3251953, 0.3215332, 0.3552246, 0.3251953] + ) max_diff = numpy_cosine_similarity_distance(image_slice, expected_slice) assert max_diff < 5e-4 From 4073be8eb7b3c9c6756dcb13c096865d5da61e90 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Fri, 12 Apr 2024 21:35:42 +0200 Subject: [PATCH 22/29] Fix style --- src/diffusers/models/embeddings.py | 36 ++++++++++++------------------ 1 file changed, 14 insertions(+), 22 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 0ee9c34aec1c..ebd270819f73 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -471,6 +471,7 @@ def __init__(self, image_embed_dim=1024, cross_attention_dim=1024): def forward(self, image_embeds: torch.FloatTensor): return self.norm(self.ff(image_embeds)) + class IPAdapterFaceIDImageProjection(nn.Module): def __init__(self, 
image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1): super().__init__() @@ -809,13 +810,14 @@ class IPAdapterPlusImageProjection(nn.Module): """Resampler of IP-Adapter Plus. Args: - ---- embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, that is the same number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. - Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio + Defaults to 16. num_queries (int): + The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. """ @@ -866,11 +868,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. Args: - ---- x (torch.Tensor): Input Tensor. - Returns: - ------- torch.Tensor: Output Tensor. """ latents = self.latents.repeat(x.size(0), 1, 1) @@ -894,17 +893,13 @@ class IPAdapterFaceIDPlusImageProjection(nn.Module): """FacePerceiverResampler of IP-Adapter Plus. Args: - ---- - embed_dims (int): The feature dimension. Defaults to 768. - output_dims (int): The number of output channels, that is the same - number of the channels in the - `unet.config.cross_attention_dim`. Defaults to 1024. - hidden_dims (int): The number of hidden channels. Defaults to 1280. - depth (int): The number of blocks. Defaults to 8. - dim_head (int): The number of head channels. Defaults to 64. - heads (int): Parallel attention heads. Defaults to 16. - num_tokens (int): Number of tokens - num_queries (int): The number of queries. Defaults to 8. + embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, + that is the same + number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. + hidden_dims (int): + The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults + to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. + Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. ffproj_ratio (float): The expansion ratio of feedforward network hidden @@ -916,7 +911,7 @@ def __init__( embed_dims: int = 768, output_dims: int = 768, hidden_dims: int = 1280, - id_embeddings_dim = 512, + id_embeddings_dim: int = 512, depth: int = 4, dim_head: int = 64, heads: int = 16, @@ -934,7 +929,7 @@ def __init__( self.shortcut = False self.shortcut_scale = 1.0 - self.proj = FeedForward(id_embeddings_dim, embed_dims*num_tokens, activation_fn="gelu", mult=ffproj_ratio) + self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) self.norm = nn.LayerNorm(embed_dims) self.proj_in = nn.Linear(hidden_dims, embed_dims) @@ -967,11 +962,8 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: """Forward pass. Args: - ---- id_embeds (torch.Tensor): Input Tensor (ID embeds). - Returns: - ------- torch.Tensor: Output Tensor. 
""" id_embeds = id_embeds.to(self.clip_embeds.dtype) From 2b9d5a5b9733238ca5ecd749bc0a8b6528bd79f6 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 15 Apr 2024 21:38:26 +0200 Subject: [PATCH 23/29] Move lora loading to separate function --- src/diffusers/loaders/ip_adapter.py | 3 +- src/diffusers/loaders/unet.py | 59 +++++++++++++++-------------- 2 files changed, 32 insertions(+), 30 deletions(-) diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py index 2d795dae90e3..3178c2ffd075 100644 --- a/src/diffusers/loaders/ip_adapter.py +++ b/src/diffusers/loaders/ip_adapter.py @@ -226,8 +226,9 @@ def load_ip_adapter( # load ip-adapter into unet unet = getattr(self, self.unet_name) if not hasattr(self, "unet") else self.unet - extra_loras = unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + unet._load_ip_adapter_weights(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) + extra_loras = unet._load_ip_adapter_loras(state_dicts) if extra_loras != {}: # apply the IP Adapter Face ID LoRA weights peft_config = getattr(unet, "peft_config", {}) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index d1d42dc9c090..acc28490705b 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -968,6 +968,35 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F key_id += 2 + return attn_procs + + def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): + if not isinstance(state_dicts, list): + state_dicts = [state_dicts] + # Set encoder_hid_proj after loading ip_adapter weights, + # because `IPAdapterPlusImageProjection` also has `attn_processors`. + self.encoder_hid_proj = None + + attn_procs = self._convert_ip_adapter_attn_to_diffusers( + state_dicts, low_cpu_mem_usage=low_cpu_mem_usage + ) + self.set_attn_processor(attn_procs) + + # convert IP-Adapter Image Projection layers to diffusers + image_projection_layers = [] + for state_dict in state_dicts: + image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( + state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage + ) + image_projection_layers.append(image_projection_layer) + + self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) + self.config.encoder_hid_dim_type = "ip_image_proj" + + self.to(dtype=self.dtype, device=self.device) + + def _load_ip_adapter_loras(self, state_dicts): + lora_dicts = {} for key_id, name in enumerate(self.attn_processors.keys()): for i, state_dict in enumerate(state_dicts): @@ -1018,35 +1047,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F ] } ) - - return attn_procs, lora_dicts - - def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): - if not isinstance(state_dicts, list): - state_dicts = [state_dicts] - # Set encoder_hid_proj after loading ip_adapter weights, - # because `IPAdapterPlusImageProjection` also has `attn_processors`. 
- self.encoder_hid_proj = None - - attn_procs, lora_dict = self._convert_ip_adapter_attn_to_diffusers( - state_dicts, low_cpu_mem_usage=low_cpu_mem_usage - ) - self.set_attn_processor(attn_procs) - - # convert IP-Adapter Image Projection layers to diffusers - image_projection_layers = [] - for state_dict in state_dicts: - image_projection_layer = self._convert_ip_adapter_image_proj_to_diffusers( - state_dict["image_proj"], low_cpu_mem_usage=low_cpu_mem_usage - ) - image_projection_layers.append(image_projection_layer) - - self.encoder_hid_proj = MultiIPAdapterImageProjection(image_projection_layers) - self.config.encoder_hid_dim_type = "ip_image_proj" - - self.to(dtype=self.dtype, device=self.device) - - return lora_dict + return lora_dicts class FromOriginalUNetMixin: From c14f0777b11c591b8ee3ad0e21b33c9c8f458759 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 15 Apr 2024 22:49:44 +0200 Subject: [PATCH 24/29] Add IPAdapterPlusImageProjectionBlock --- src/diffusers/loaders/unet.py | 36 +++++++++++------ src/diffusers/models/embeddings.py | 64 +++++++++++++++++------------- 2 files changed, 61 insertions(+), 39 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index acc28490705b..9fd3fc6254da 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -774,20 +774,30 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us heads=heads, id_embeddings_dim=id_embeddings_dim, ) - + print(state_dict.keys()) for key, value in state_dict.items(): diffusers_name = key.replace("perceiver_resampler.", "") - diffusers_name = diffusers_name.replace("0.to", "2.to") - diffusers_name = diffusers_name.replace("0.1.0.", "0.3.0.") - diffusers_name = diffusers_name.replace("0.1.1.", "0.3.1.") - diffusers_name = diffusers_name.replace("1.1.0.", "1.3.0.") - diffusers_name = diffusers_name.replace("1.1.1.", "1.3.1.") - diffusers_name = diffusers_name.replace("2.1.0.", "2.3.0.") - diffusers_name = diffusers_name.replace("2.1.1.", "2.3.1.") - diffusers_name = diffusers_name.replace("3.1.0.", "3.3.0.") - diffusers_name = diffusers_name.replace("3.1.1.", "3.3.1.") - diffusers_name = diffusers_name.replace(".3.1.weight", ".3.1.net.0.proj.weight") - diffusers_name = diffusers_name.replace(".1.3.weight", ".3.1.net.2.weight") + diffusers_name = diffusers_name.replace("0.to", "attn.to") + diffusers_name = diffusers_name.replace("0.1.0.", "0.ff.0.") + diffusers_name = diffusers_name.replace("0.1.1.weight", "0.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("0.1.3.weight", "0.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("1.1.0.", "1.ff.0.") + diffusers_name = diffusers_name.replace("1.1.1.weight", "1.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("1.1.3.weight", "1.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("2.1.0.", "2.ff.0.") + diffusers_name = diffusers_name.replace("2.1.1.weight", "2.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("2.1.3.weight", "2.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("3.1.0.", "3.ff.0.") + diffusers_name = diffusers_name.replace("3.1.1.weight", "3.ff.1.net.0.proj.weight") + diffusers_name = diffusers_name.replace("3.1.3.weight", "3.ff.1.net.2.weight") + diffusers_name = diffusers_name.replace("layers.0.0", "layers.0.ln0") + diffusers_name = diffusers_name.replace("layers.0.1", "layers.0.ln1") + diffusers_name = diffusers_name.replace("layers.1.0", "layers.1.ln0") + diffusers_name = 
diffusers_name.replace("layers.1.1", "layers.1.ln1") + diffusers_name = diffusers_name.replace("layers.2.0", "layers.2.ln0") + diffusers_name = diffusers_name.replace("layers.2.1", "layers.2.ln1") + diffusers_name = diffusers_name.replace("layers.3.0", "layers.3.ln0") + diffusers_name = diffusers_name.replace("layers.3.1", "layers.3.ln1") if "norm1" in diffusers_name: updated_state_dict[diffusers_name.replace("0.norm1", "0")] = value @@ -809,6 +819,8 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us updated_state_dict["proj.net.2.bias"] = value else: updated_state_dict[diffusers_name] = value + print(updated_state_dict.keys()) + print(image_projection.state_dict().keys()) elif "norm.weight" in state_dict: # IP-Adapter Face ID diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index ebd270819f73..4b891d4ac90a 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -888,6 +888,38 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: latents = self.proj_out(latents) return self.norm_out(latents) +class IPAdapterPlusImageProjectionBlock(nn.Module): + + def __init__( + self, + embed_dims: int = 768, + dim_head: int = 64, + heads: int = 16, + ffn_ratio: float = 4, + ) -> None: + super().__init__() + from .attention import FeedForward + + self.ln0 = nn.LayerNorm(embed_dims) + self.ln1 = nn.LayerNorm(embed_dims) + self.attn = Attention( + query_dim=embed_dims, + dim_head=dim_head, + heads=heads, + out_bias=False, + ) + self.ff = nn.Sequential( + nn.LayerNorm(embed_dims), + FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), + ) + + def forward(self, x, latents, residual): + encoder_hidden_states = self.ln0(x) + latents = self.ln1(latents) + encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) + latents = self.attn(latents, encoder_hidden_states) + residual + latents = self.ff(latents) + latents + return latents class IPAdapterFaceIDPlusImageProjection(nn.Module): """FacePerceiverResampler of IP-Adapter Plus. @@ -937,26 +969,9 @@ def __init__( self.proj_out = nn.Linear(embed_dims, output_dims) self.norm_out = nn.LayerNorm(output_dims) - self.layers = nn.ModuleList([]) - for _ in range(depth): - self.layers.append( - nn.ModuleList( - [ - nn.LayerNorm(embed_dims), - nn.LayerNorm(embed_dims), - Attention( - query_dim=embed_dims, - dim_head=dim_head, - heads=heads, - out_bias=False, - ), - nn.Sequential( - nn.LayerNorm(embed_dims), - FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), - ), - ] - ) - ) + self.layers = nn.ModuleList([ + IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth) + ]) def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: """Forward pass. 
@@ -975,14 +990,9 @@ def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: clip_embeds = self.proj_in(self.clip_embeds) x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) - for ln0, ln1, attn, ff in self.layers: + for block in self.layers: residual = latents - - encoder_hidden_states = ln0(x) - latents = ln1(latents) - encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) - latents = attn(latents, encoder_hidden_states) + residual - latents = ff(latents) + latents + latents = block(x, latents, residual) latents = self.proj_out(latents) out = self.norm_out(latents) From 5cab6dacfe0c15265d44c702fd19f0ca10c0a387 Mon Sep 17 00:00:00 2001 From: Fabio Rigano Date: Mon, 15 Apr 2024 23:01:35 +0200 Subject: [PATCH 25/29] Fix style --- src/diffusers/loaders/unet.py | 5 +---- src/diffusers/models/embeddings.py | 11 ++++++----- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py index 9fd3fc6254da..9c1984c469ea 100644 --- a/src/diffusers/loaders/unet.py +++ b/src/diffusers/loaders/unet.py @@ -989,9 +989,7 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): # because `IPAdapterPlusImageProjection` also has `attn_processors`. self.encoder_hid_proj = None - attn_procs = self._convert_ip_adapter_attn_to_diffusers( - state_dicts, low_cpu_mem_usage=low_cpu_mem_usage - ) + attn_procs = self._convert_ip_adapter_attn_to_diffusers(state_dicts, low_cpu_mem_usage=low_cpu_mem_usage) self.set_attn_processor(attn_procs) # convert IP-Adapter Image Projection layers to diffusers @@ -1008,7 +1006,6 @@ def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=False): self.to(dtype=self.dtype, device=self.device) def _load_ip_adapter_loras(self, state_dicts): - lora_dicts = {} for key_id, name in enumerate(self.attn_processors.keys()): for i, state_dict in enumerate(state_dicts): diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 4b891d4ac90a..ced520bb8204 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -888,15 +888,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: latents = self.proj_out(latents) return self.norm_out(latents) -class IPAdapterPlusImageProjectionBlock(nn.Module): +class IPAdapterPlusImageProjectionBlock(nn.Module): def __init__( self, embed_dims: int = 768, dim_head: int = 64, heads: int = 16, ffn_ratio: float = 4, - ) -> None: + ) -> None: super().__init__() from .attention import FeedForward @@ -921,6 +921,7 @@ def forward(self, x, latents, residual): latents = self.ff(latents) + latents return latents + class IPAdapterFaceIDPlusImageProjection(nn.Module): """FacePerceiverResampler of IP-Adapter Plus. @@ -969,9 +970,9 @@ def __init__( self.proj_out = nn.Linear(embed_dims, output_dims) self.norm_out = nn.LayerNorm(output_dims) - self.layers = nn.ModuleList([ - IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth) - ]) + self.layers = nn.ModuleList( + [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] + ) def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: """Forward pass. 
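Taken together, the two patches above reduce the Face ID Plus resampler to a stack of `IPAdapterPlusImageProjectionBlock`s. As a shape walk-through, here is a minimal sketch, assuming a checkout with these patches applied and illustrative ViT-H-like CLIP features (257 tokens of width 1280):

```py
import torch

from diffusers.models.embeddings import IPAdapterFaceIDPlusImageProjection

# Class defaults: embed_dims=768, output_dims=768, hidden_dims=1280,
# id_embeddings_dim=512, depth=4, num_tokens=4.
proj = IPAdapterFaceIDPlusImageProjection()

# forward() reads the CLIP features from this attribute; the caller assigns it
# before running the pipeline. 4-D layout: (CFG batch, images, tokens, hidden).
proj.clip_embeds = torch.randn(2, 1, 257, 1280)

id_embeds = torch.randn(2, 512)  # one insightface `normed_embedding` per CFG branch
tokens = proj(id_embeds)
print(tokens.shape)  # torch.Size([2, 4, 768]): num_tokens cross-attention tokens
```

Keeping one named submodule per depth step (`layers.N.ln0`, `layers.N.ln1`, `layers.N.attn`, `layers.N.ff`) also matches the key renaming performed in `_convert_ip_adapter_image_proj_to_diffusers` above.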
From 0e034469b451c038672d8447a2c8e16d29689c1c Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Tue, 16 Apr 2024 20:22:04 +0200
Subject: [PATCH 26/29] Fix quality + add PEFT check

---
 src/diffusers/loaders/ip_adapter.py | 16 ++++++++++------
 src/diffusers/loaders/unet.py       |  4 +---
 2 files changed, 11 insertions(+), 9 deletions(-)

diff --git a/src/diffusers/loaders/ip_adapter.py b/src/diffusers/loaders/ip_adapter.py
index 3178c2ffd075..fdddc382212f 100644
--- a/src/diffusers/loaders/ip_adapter.py
+++ b/src/diffusers/loaders/ip_adapter.py
@@ -21,6 +21,7 @@
 from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
 from ..utils import (
+    USE_PEFT_BACKEND,
     _get_model_file,
     is_accelerate_available,
     is_torch_version,
@@ -230,12 +231,15 @@ def load_ip_adapter(
         extra_loras = unet._load_ip_adapter_loras(state_dicts)
         if extra_loras != {}:
-            # apply the IP Adapter Face ID LoRA weights
-            peft_config = getattr(unet, "peft_config", {})
-            for k, lora in extra_loras.items():
-                if f"faceid_{k}" not in peft_config:
-                    self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
-                    self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])
+            if not USE_PEFT_BACKEND:
+                logger.warning("PEFT backend is required to load these weights.")
+            else:
+                # apply the IP Adapter Face ID LoRA weights
+                peft_config = getattr(unet, "peft_config", {})
+                for k, lora in extra_loras.items():
+                    if f"faceid_{k}" not in peft_config:
+                        self.load_lora_weights(lora, adapter_name=f"faceid_{k}")
+                        self.set_adapters([f"faceid_{k}"], adapter_weights=[1.0])

     def set_ip_adapter_scale(self, scale):
         """
diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 9c1984c469ea..71070216f94d 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -774,7 +774,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
                 heads=heads,
                 id_embeddings_dim=id_embeddings_dim,
             )
-            print(state_dict.keys())
+
             for key, value in state_dict.items():
                 diffusers_name = key.replace("perceiver_resampler.", "")
                 diffusers_name = diffusers_name.replace("0.to", "attn.to")
@@ -819,8 +819,6 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
                     updated_state_dict["proj.net.2.bias"] = value
                 else:
                     updated_state_dict[diffusers_name] = value
-            print(updated_state_dict.keys())
-            print(image_projection.state_dict().keys())

         elif "norm.weight" in state_dict:
             # IP-Adapter Face ID

From e6495a22679bec3b8fc0c0a4cbc31db9b169a543 Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Tue, 16 Apr 2024 20:58:55 +0200
Subject: [PATCH 27/29] Fix names

---
 src/diffusers/loaders/unet.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/loaders/unet.py b/src/diffusers/loaders/unet.py
index 71070216f94d..294db44ee61d 100644
--- a/src/diffusers/loaders/unet.py
+++ b/src/diffusers/loaders/unet.py
@@ -822,9 +822,9 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         elif "norm.weight" in state_dict:
             # IP-Adapter Face ID
-            clip_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
-            clip_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
-            multiplier = clip_embeddings_dim_out // clip_embeddings_dim_in
+            id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]
+            id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]
+            multiplier = id_embeddings_dim_out // id_embeddings_dim_in
             norm_layer = "norm.weight"
             cross_attention_dim = state_dict[norm_layer].shape[0]
             num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim
@@ -832,7 +832,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
             with init_context():
                 image_projection = IPAdapterFaceIDImageProjection(
                     cross_attention_dim=cross_attention_dim,
-                    image_embed_dim=clip_embeddings_dim_in,
+                    image_embed_dim=id_embeddings_dim_in,
                     mult=multiplier,
                     num_tokens=num_tokens,
                 )
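The rename clarifies that FaceID checkpoints carry face-recognition ("ID") embeddings rather than CLIP image embeddings, and every constructor argument of `IPAdapterFaceIDImageProjection` is inferred from tensor shapes in the checkpoint. A toy illustration with assumed dimensions (512-d ID embeddings, 768-d cross-attention, as in SD 1.5-sized models):

```python
import torch

# Illustrative only: the shapes below are assumptions chosen so the
# arithmetic matches the inference logic in the loader diff above.
state_dict = {
    "proj.0.weight": torch.zeros(1024, 512),   # (id_dim * mult, id_dim)
    "proj.2.weight": torch.zeros(3072, 1024),  # (num_tokens * cross_dim, id_dim * mult)
    "norm.weight": torch.zeros(768),           # (cross_attention_dim,)
}

id_embeddings_dim_in = state_dict["proj.0.weight"].shape[1]                # 512
id_embeddings_dim_out = state_dict["proj.0.weight"].shape[0]               # 1024
multiplier = id_embeddings_dim_out // id_embeddings_dim_in                 # 2
cross_attention_dim = state_dict["norm.weight"].shape[0]                   # 768
num_tokens = state_dict["proj.2.weight"].shape[0] // cross_attention_dim   # 4
```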
From d7772e94a0f3157ad6e437d136997fc0cd9790b8 Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Tue, 16 Apr 2024 20:59:49 +0200
Subject: [PATCH 28/29] Add fast test

---
 .../unets/test_models_unet_2d_condition.py | 60 ++++++++++++++++++-
 tests/pipelines/test_pipelines_common.py   | 49 ++++++++++++++-
 2 files changed, 107 insertions(+), 2 deletions(-)

diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index bf6b7fe99b7f..1b8a998cfd66 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -30,7 +30,7 @@
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
 )
-from diffusers.models.embeddings import ImageProjection, IPAdapterPlusImageProjection
+from diffusers.models.embeddings import ImageProjection, IPAdapterFaceIDImageProjection, IPAdapterPlusImageProjection
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
@@ -190,6 +190,64 @@ def create_ip_adapter_plus_state_dict(model):
     return ip_state_dict


+def create_ip_adapter_faceid_state_dict(model):
+    # "ip_adapter" (cross-attention weights)
+    # no LoRA weights
+    ip_cross_attn_state_dict = {}
+    key_id = 1
+
+    for name in model.attn_processors.keys():
+        cross_attention_dim = (
+            None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
+        )
+
+        if name.startswith("mid_block"):
+            hidden_size = model.config.block_out_channels[-1]
+        elif name.startswith("up_blocks"):
+            block_id = int(name[len("up_blocks.")])
+            hidden_size = list(reversed(model.config.block_out_channels))[block_id]
+        elif name.startswith("down_blocks"):
+            block_id = int(name[len("down_blocks.")])
+            hidden_size = model.config.block_out_channels[block_id]
+
+        if cross_attention_dim is not None:
+            sd = IPAdapterAttnProcessor(
+                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+            ).state_dict()
+            ip_cross_attn_state_dict.update(
+                {
+                    f"{key_id}.to_k_ip.weight": sd["to_k_ip.0.weight"],
+                    f"{key_id}.to_v_ip.weight": sd["to_v_ip.0.weight"],
+                }
+            )
+
+            key_id += 2
+
+    # "image_proj" (ImageProjection layer weights)
+    cross_attention_dim = model.config["cross_attention_dim"]
+    image_projection = IPAdapterFaceIDImageProjection(
+        cross_attention_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, mult=2, num_tokens=4
+    )
+
+    ip_image_projection_state_dict = {}
+    sd = image_projection.state_dict()
+    ip_image_projection_state_dict.update(
+        {
+            "proj.0.weight": sd["ff.net.0.proj.weight"],
+            "proj.0.bias": sd["ff.net.0.proj.bias"],
+            "proj.2.weight": sd["ff.net.2.weight"],
+            "proj.2.bias": sd["ff.net.2.bias"],
+            "norm.weight": sd["norm.weight"],
+            "norm.bias": sd["norm.bias"],
+        }
+    )
+
+    del sd
+    ip_state_dict = {}
+    ip_state_dict.update({"image_proj": ip_image_projection_state_dict, "ip_adapter": ip_cross_attn_state_dict})
+    return ip_state_dict
+
+
 def create_custom_diffusion_layers(model, mock_weights: bool = True):
     train_kv = True
     train_q_out = True
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index acff5f2cdf8f..be5dfbe19291 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -47,7 +47,10 @@
     get_autoencoder_tiny_config,
     get_consistency_vae_config,
 )
-from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
+from ..models.unets.test_models_unet_2d_condition import (
+    create_ip_adapter_faceid_state_dict,
+    create_ip_adapter_state_dict,
+)
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -238,6 +241,9 @@ def test_pipeline_signature(self):
     def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
         return torch.randn((2, 1, cross_attention_dim), device=torch_device)

+    def _get_dummy_faceid_image_embeds(self, cross_attention_dim: int = 32):
+        return torch.randn((2, 1, 1, cross_attention_dim), device=torch_device)
+
     def _get_dummy_masks(self, input_size: int = 64):
         _masks = torch.zeros((1, 1, input_size, input_size), device=torch_device)
         _masks[0, :, :, : int(input_size / 2)] = 1
@@ -415,6 +421,47 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
             max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
         )

+    def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+        output_without_adapter = output_without_adapter[0, -3:, -3:, -1].flatten()
+
+        adapter_state_dict = create_ip_adapter_faceid_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+        # forward pass with single ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(0.0)
+        output_without_adapter_scale = pipe(**inputs)[0]
+        output_without_adapter_scale = output_without_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with single ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_faceid_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(42.0)
+        output_with_adapter_scale = pipe(**inputs)[0]
+        output_with_adapter_scale = output_with_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_adapter_scale,
+            expected_max_diff,
+            "Output without ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_adapter_scale, 1e-3, "Output with ip-adapter must be different from normal inference"
+        )
+

 class PipelineLatentTesterMixin:
     """

From f77739f0565b64bd885e10adae9dffae5875303e Mon Sep 17 00:00:00 2001
From: Fabio Rigano
Date: Tue, 16 Apr 2024 21:07:11 +0200
Subject: [PATCH 29/29] Fix style

---
 tests/pipelines/test_pipelines_common.py | 1 -
 1 file changed, 1 deletion(-)
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index be5dfbe19291..1c89a8190b3d 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -422,7 +422,6 @@ def test_ip_adapter_masks(self, expected_max_diff: float = 1e-4):
         )

     def test_ip_adapter_faceid(self, expected_max_diff: float = 1e-4):
-
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components).to(torch_device)
         pipe.set_progress_bar_config(disable=None)
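Taken together, the series gives diffusers fast-test coverage for IP-Adapter FaceID and routes its LoRA weights through the PEFT backend. A hedged end-to-end usage sketch follows; the checkpoint names and the 512-d embedding size reflect the public `h94/IP-Adapter-FaceID` release for SD 1.5, so treat the exact values as assumptions rather than something this diff guarantees:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# requires the PEFT backend, per the check added in PATCH 26/29
pipe.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid_sd15.bin",
    image_encoder_folder=None,  # FaceID consumes face embeddings, not CLIP images
)
pipe.set_ip_adapter_scale(0.7)

# Real embeddings come from a face-recognition model (e.g. insightface); a
# random tensor stands in here, laid out like the test helper above:
# (2 * batch for classifier-free guidance, num_images, 1, embed_dim).
faceid_embeds = torch.randn(2, 1, 1, 512, dtype=torch.float16, device="cuda")

image = pipe(
    prompt="a photo of a person in a garden",
    ip_adapter_image_embeds=[faceid_embeds],
    num_inference_steps=30,
).images[0]
```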