StableDiffusionXLPAGPipeline does not work with FaceID Plus V2 IP-Adapter #9202

@detkov

Describe the bug

When switching from StableDiffusionXLPipeline to its PAG version, StableDiffusionXLPAGPipeline, a workflow that uses an additional IP-Adapter (ip-adapter-faceid-plusv2_sdxl) starts failing inside the UNet with a tensor shape mismatch.

Reproduction

It works fine with plain SDXL:

import torch
from diffusers import StableDiffusionXLPipeline


pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid-plusv2_sdxl.bin",
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale(1.0)

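# FaceID Plus V2 setup: enable the V2 shortcut and attach placeholder (random) CLIP image embeds.
# The CLIP embeds use a batch of 2 to match the batch of the ID embeds passed below.
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = True
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut_scale = 2.0
pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = torch.randn(2, 1, 257, 1280).cuda().half()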

image = pipeline(
    prompt="A portrait of a person with a neutral expression",
    num_inference_steps=20,
    ip_adapter_image_embeds=[torch.randn(2, 1, 512).cuda().half()],
).images[0]

Switching to PAG, with only the pipeline class and pag_scale changed, fails:

import torch
from diffusers import StableDiffusionXLPAGPipeline


pipeline = StableDiffusionXLPAGPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float16,
).to("cuda")
pipeline.load_ip_adapter(
    "h94/IP-Adapter-FaceID",
    subfolder=None,
    weight_name="ip-adapter-faceid-plusv2_sdxl.bin",
    image_encoder_folder=None,
)
pipeline.set_ip_adapter_scale(1.0)

pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = True
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut_scale = 2.0
pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = torch.randn(2, 1, 257, 1280).cuda().half()

image = pipeline(
    prompt="A portrait of a person with a neutral expression",
    num_inference_steps=20,
    pag_scale=3.0,
    ip_adapter_image_embeds=[torch.randn(2, 1, 512).cuda().half()],
).images[0]
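The only differences from the working script are the pipeline class and pag_scale, while the "Expected size 2 but got size 3" error in the logs points at a batch-size disagreement. A minimal sketch of one plausible explanation, assuming classifier-free guidance is active in both runs (default guidance_scale) and that PAG adds one extra conditional entry for the perturbed-attention pass. batch_layout is a hypothetical helper that models the ordering; it is not a diffusers function.

```python
def batch_layout(do_cfg: bool, do_pag: bool) -> list:
    # Hypothetical model of the per-step batch order (an assumption, not diffusers source).
    layout = ["cond"]
    if do_pag:
        # Assumed: PAG adds a second conditional entry run with perturbed self-attention.
        layout.append("cond_perturbed")
    if do_cfg:
        # Classifier-free guidance prepends the unconditional (negative) entry.
        layout.insert(0, "uncond")
    return layout


# StableDiffusionXLPipeline with CFG: a batch of 2, matching the clip_embeds batch.
print(batch_layout(do_cfg=True, do_pag=False))  # ['uncond', 'cond']
# StableDiffusionXLPAGPipeline with CFG and pag_scale=3.0: a batch of 3.
print(batch_layout(do_cfg=True, do_pag=True))   # ['uncond', 'cond', 'cond_perturbed']
```

Under this assumption the ID embeds grow to a batch of 3, while the clip_embeds set directly on the projection layer stay at 2.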

Logs

{
	"name": "RuntimeError",
	"message": "Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list.",
	"stack": "---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Cell In[2], line 16
     13 pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut_scale = 2.0
     14 pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = torch.randn(2, 1, 257, 1280).cuda().half()
---> 16 image = pipeline(
     17     prompt=\"A portrait of a person with a neutral expression\",
     18     num_inference_steps=20,
     19     pag_scale=3.0,
     20     ip_adapter_image_embeds=[torch.randn(2, 1, 512).cuda().half()],
     21 ).images[0]

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl.py:1228, in StableDiffusionXLPAGPipeline.__call__(self, prompt, prompt_2, height, width, num_inference_steps, timesteps, sigmas, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, output_type, return_dict, cross_attention_kwargs, guidance_rescale, original_size, crops_coords_top_left, target_size, negative_original_size, negative_crops_coords_top_left, negative_target_size, clip_skip, callback_on_step_end, callback_on_step_end_tensor_inputs, pag_scale, pag_adaptive_scale)
   1226 if ip_adapter_image_embeds is not None:
   1227     added_cond_kwargs[\"image_embeds\"] = ip_adapter_image_embeds
-> 1228 noise_pred = self.unet(
   1229     latent_model_input,
   1230     t,
   1231     encoder_hidden_states=prompt_embeds,
   1232     timestep_cond=timestep_cond,
   1233     cross_attention_kwargs=self.cross_attention_kwargs,
   1234     added_cond_kwargs=added_cond_kwargs,
   1235     return_dict=False,
   1236 )[0]
   1238 # perform guidance
   1239 if self.do_perturbed_attention_guidance:

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1164, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1161 if self.time_embed_act is not None:
   1162     emb = self.time_embed_act(emb)
-> 1164 encoder_hidden_states = self.process_encoder_hidden_states(
   1165     encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
   1166 )
   1168 # 2. pre-process
   1169 sample = self.conv_in(sample)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1035, in UNet2DConditionModel.process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs)
   1032         encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
   1034     image_embeds = added_cond_kwargs.get(\"image_embeds\")
-> 1035     image_embeds = self.encoder_hid_proj(image_embeds)
   1036     encoder_hidden_states = (encoder_hidden_states, image_embeds)
   1037 return encoder_hidden_states

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/embeddings.py:1551, in MultiIPAdapterImageProjection.forward(self, image_embeds)
   1549 batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
   1550 image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
-> 1551 image_embed = image_projection_layer(image_embed)
   1552 image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
   1554 projected_image_embeds.append(image_embed)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/embeddings.py:1515, in IPAdapterFaceIDPlusImageProjection.forward(self, id_embeds)
   1513 for block in self.layers:
   1514     residual = latents
-> 1515     latents = block(x, latents, residual)
   1517 latents = self.proj_out(latents)
   1518 out = self.norm_out(latents)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
   1530     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1531 else:
-> 1532     return self._call_impl(*args, **kwargs)

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
   1536 # If we don't have any hooks, we want to skip the rest of the logic in
   1537 # this function, and just call forward.
   1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1539         or _global_backward_pre_hooks or _global_backward_hooks
   1540         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541     return forward_call(*args, **kwargs)
   1543 try:
   1544     result = None

File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/embeddings.py:1379, in IPAdapterPlusImageProjectionBlock.forward(self, x, latents, residual)
   1377 encoder_hidden_states = self.ln0(x)
   1378 latents = self.ln1(latents)
-> 1379 encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
   1380 latents = self.attn(latents, encoder_hidden_states) + residual
   1381 latents = self.ff(latents) + latents

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list."
}
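The failing call is the torch.cat in IPAdapterPlusImageProjectionBlock.forward, which joins tensor 0 (most likely derived from clip_embeds, batch 2) with the ID latents (tensor 1, batch 3). The sketch below reproduces the same kind of error with small illustrative shapes and shows one possible workaround: expanding the CLIP side from [negative, positive] to [negative, positive, positive]. It assumes the PAG pipeline duplicates the conditional entry; to_pag_layout is a hypothetical helper, not part of diffusers, and the sizes are placeholders.

```python
import torch


def to_pag_layout(embeds: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: reorder a [negative, positive] batch to [negative, positive, positive].

    Assumes the PAG pipeline repeats the conditional entry for the perturbed pass.
    """
    negative, positive = embeds.chunk(2)  # split along the batch dimension
    return torch.cat([negative, positive, positive], dim=0)


# Illustrative sizes only: 257 CLIP tokens, 4 ID tokens, hidden size 8.
clip_hidden = torch.randn(2, 257, 8)  # CLIP-side features, CFG batch of 2
id_latents = torch.randn(3, 4, 8)     # ID latents after the assumed PAG expansion, batch of 3

# Same concatenation as in the traceback: batch sizes 2 and 3 cannot be joined on dim=-2.
try:
    torch.cat([clip_hidden, id_latents], dim=-2)
except RuntimeError as err:
    print("mismatch:", err)

# Expanding the CLIP side to the 3-way layout makes the batch sizes agree.
fixed = torch.cat([to_pag_layout(clip_hidden), id_latents], dim=-2)
print(fixed.shape)  # torch.Size([3, 261, 8])
```

In the PAG reproduction this would amount to setting clip_embeds to to_pag_layout(torch.randn(2, 1, 257, 1280).cuda().half()); this is untested against the diffusers build listed under System Info.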

System Info

  • 🤗 Diffusers version: 0.31.0.dev0 (https://github.com/asomoza/diffusers/tree/8dc04eb5f6b61959a3da6fcf20ca79738ebc01b3)
  • Platform: Linux-5.15.0-1055-aws-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.14
  • PyTorch version (GPU?): 2.3.1+cu121 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.24.5
  • Transformers version: 4.40.1
  • Accelerate version: 0.31.0
  • PEFT version: 0.10.0
  • Bitsandbytes version: 0.43.1
  • Safetensors version: 0.4.3
  • xFormers version: not installed
  • Accelerator: NVIDIA A10G, 23028 MiB
  • Using GPU in script?: YES
  • Using distributed or parallel set-up in script?: NO

Who can help?

@sayakpaul @DN6

Metadata

Labels: bug (Something isn't working)