Closed
Labels: bug (Something isn't working)
Description
Describe the bug
When switching from StableDiffusionXLPipeline to its PAG counterpart, StableDiffusionXLPAGPipeline, the same workflow with an additional IP-Adapter (ip-adapter-faceid-plusv2_sdxl) starts failing inside the UNet's IP-Adapter image projection with a tensor shape mismatch.
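The sizes in the error below (expected 2, got 3) look consistent with how the two pipelines lay out the batch. The sketch is a hedged reading of the diffusers source around this version, not an official API. With classifier-free guidance alone (the repro leaves guidance_scale at its default, which is greater than 1), the image embeds are split and re-stacked as [uncond, cond]. The PAG pipeline appears to duplicate the conditional half for the perturbed-attention branch (see `PAGMixin._prepare_perturbed_attention_guidance`), giving [uncond, cond, cond]. The manually assigned `clip_embeds` still has a batch of 2, so the concat in `IPAdapterPlusImageProjectionBlock.forward` would see 2 vs 3. `batch_layout` is a hypothetical helper for illustration only.

```python
def batch_layout(do_cfg: bool, do_pag: bool) -> list[str]:
    """Hypothetical helper: row layout of the IP-Adapter image embeds that
    reach the UNet for one prompt and one image, as assumed in this sketch."""
    layout = ["cond"]
    if do_pag:
        # assumed: PAG duplicates the conditional embeds for the perturbed branch
        layout = layout * 2
    if do_cfg:
        # classifier-free guidance prepends the unconditional embeds
        layout = ["uncond"] + layout
    return layout


# Plain SDXL call with default guidance: CFG only
print(batch_layout(do_cfg=True, do_pag=False))  # ['uncond', 'cond']
# PAG pipeline with default guidance and pag_scale=3.0: CFG + PAG
print(batch_layout(do_cfg=True, do_pag=True))   # ['uncond', 'cond', 'cond']
```

Under this reading, the plain SDXL run lines up by coincidence (both tensors have batch 2), while the PAG run adds a third row only on the FaceID side.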
Reproduction
This works fine with plain SDXL:
import torch
from diffusers import StableDiffusionXLPipeline
pipeline = StableDiffusionXLPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plusv2_sdxl.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(1.0)
# FaceID Plus v2 settings; the random clip_embeds stand in for real CLIP image embeddings, batch 2 = [uncond, cond]
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = True
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut_scale = 2.0
pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = torch.randn(2, 1, 257, 1280).cuda().half()
image = pipeline(
prompt="A portrait of a person with a neutral expression",
num_inference_steps=20,
ip_adapter_image_embeds=[torch.randn(2, 1, 512).cuda().half()],
).images[0]
Now switching to PAG (same setup, plus pag_scale=3.0):
import torch
from diffusers import StableDiffusionXLPAGPipeline
pipeline = StableDiffusionXLPAGPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-1.0",
torch_dtype=torch.float16,
).to("cuda")
pipeline.load_ip_adapter("h94/IP-Adapter-FaceID", subfolder=None, weight_name="ip-adapter-faceid-plusv2_sdxl.bin", image_encoder_folder=None)
pipeline.set_ip_adapter_scale(1.0)
# Same FaceID Plus v2 settings as above; clip_embeds still has batch 2 = [uncond, cond]
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut = True
pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut_scale = 2.0
pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = torch.randn(2, 1, 257, 1280).cuda().half()
image = pipeline(
prompt="A portrait of a person with a neutral expression",
num_inference_steps=20,
pag_scale=3.0,
ip_adapter_image_embeds=[torch.randn(2, 1, 512).cuda().half()],
).images[0]
Logs
{
"name": "RuntimeError",
"message": "Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list.",
"stack": "---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[2], line 16
13 pipeline.unet.encoder_hid_proj.image_projection_layers[0].shortcut_scale = 2.0
14 pipeline.unet.encoder_hid_proj.image_projection_layers[0].clip_embeds = torch.randn(2, 1, 257, 1280).cuda().half()
---> 16 image = pipeline(
17 prompt=\"A portrait of a person with a neutral expression\",
18 num_inference_steps=20,
19 pag_scale=3.0,
20 ip_adapter_image_embeds=[torch.randn(2, 1, 512).cuda().half()],
21 ).images[0]
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
112 @functools.wraps(func)
113 def decorate_context(*args, **kwargs):
114 with ctx_factory():
--> 115 return func(*args, **kwargs)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/pipelines/pag/pipeline_pag_sd_xl.py:1228, in StableDiffusionXLPAGPipeline.__call__(self, prompt, prompt_2, height, width, num_inference_steps, timesteps, sigmas, denoising_end, guidance_scale, negative_prompt, negative_prompt_2, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds, ip_adapter_image, ip_adapter_image_embeds, output_type, return_dict, cross_attention_kwargs, guidance_rescale, original_size, crops_coords_top_left, target_size, negative_original_size, negative_crops_coords_top_left, negative_target_size, clip_skip, callback_on_step_end, callback_on_step_end_tensor_inputs, pag_scale, pag_adaptive_scale)
1226 if ip_adapter_image_embeds is not None:
1227 added_cond_kwargs[\"image_embeds\"] = ip_adapter_image_embeds
-> 1228 noise_pred = self.unet(
1229 latent_model_input,
1230 t,
1231 encoder_hidden_states=prompt_embeds,
1232 timestep_cond=timestep_cond,
1233 cross_attention_kwargs=self.cross_attention_kwargs,
1234 added_cond_kwargs=added_cond_kwargs,
1235 return_dict=False,
1236 )[0]
1238 # perform guidance
1239 if self.do_perturbed_attention_guidance:
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1164, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
1161 if self.time_embed_act is not None:
1162 emb = self.time_embed_act(emb)
-> 1164 encoder_hidden_states = self.process_encoder_hidden_states(
1165 encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
1166 )
1168 # 2. pre-process
1169 sample = self.conv_in(sample)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1035, in UNet2DConditionModel.process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs)
1032 encoder_hidden_states = self.text_encoder_hid_proj(encoder_hidden_states)
1034 image_embeds = added_cond_kwargs.get(\"image_embeds\")
-> 1035 image_embeds = self.encoder_hid_proj(image_embeds)
1036 encoder_hidden_states = (encoder_hidden_states, image_embeds)
1037 return encoder_hidden_states
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/embeddings.py:1551, in MultiIPAdapterImageProjection.forward(self, image_embeds)
1549 batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
1550 image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
-> 1551 image_embed = image_projection_layer(image_embed)
1552 image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
1554 projected_image_embeds.append(image_embed)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/embeddings.py:1515, in IPAdapterFaceIDPlusImageProjection.forward(self, id_embeds)
1513 for block in self.layers:
1514 residual = latents
-> 1515 latents = block(x, latents, residual)
1517 latents = self.proj_out(latents)
1518 out = self.norm_out(latents)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1532, in Module._wrapped_call_impl(self, *args, **kwargs)
1530 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
1531 else:
-> 1532 return self._call_impl(*args, **kwargs)
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/torch/nn/modules/module.py:1541, in Module._call_impl(self, *args, **kwargs)
1536 # If we don't have any hooks, we want to skip the rest of the logic in
1537 # this function, and just call forward.
1538 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1539 or _global_backward_pre_hooks or _global_backward_hooks
1540 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1541 return forward_call(*args, **kwargs)
1543 try:
1544 result = None
File /opt/.cache/virtualenvs/diffusion-exps-R-NqVyK3-py3.10/lib/python3.10/site-packages/diffusers/models/embeddings.py:1379, in IPAdapterPlusImageProjectionBlock.forward(self, x, latents, residual)
1377 encoder_hidden_states = self.ln0(x)
1378 latents = self.ln1(latents)
-> 1379 encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
1380 latents = self.attn(latents, encoder_hidden_states) + residual
1381 latents = self.ff(latents) + latents
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 2 but got size 3 for tensor number 1 in the list."
}
System Info
- 🤗 Diffusers version: 0.31.0.dev0 (https://github.com/asomoza/diffusers/tree/8dc04eb5f6b61959a3da6fcf20ca79738ebc01b3)
- Platform: Linux-5.15.0-1055-aws-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.14
- PyTorch version (GPU?): 2.3.1+cu121 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.24.5
- Transformers version: 4.40.1
- Accelerate version: 0.31.0
- PEFT version: 0.10.0
- Bitsandbytes version: 0.43.1
- Safetensors version: 0.4.3
- xFormers version: not installed
- Accelerator: NVIDIA A10G, 23028 MiB
- Using GPU in script?: YES
- Using distributed or parallel set-up in script?: NO
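Based on the traceback, the failing `torch.cat` in `IPAdapterPlusImageProjectionBlock.forward` joins tokens derived from the manually set `clip_embeds` (batch 2) with tokens derived from the FaceID embeds (batch 3 under PAG). A possible, untested workaround is to give `clip_embeds` an [uncond, cond, cond] layout when CFG and PAG are both active. `expand_clip_embeds_for_pag` is a hypothetical helper, the [uncond, cond, cond] ordering is an assumption about how the PAG pipeline expands the batch, and the last dimension is shrunk from 1280 to 16 (and the 4 ID tokens are illustrative) so the sketch runs on CPU.

```python
import torch


def expand_clip_embeds_for_pag(clip_embeds: torch.Tensor, do_cfg: bool = True) -> torch.Tensor:
    """Hypothetical helper: match clip_embeds to the assumed PAG batch layout.

    Input is assumed to be [uncond, cond] along dim 0 with CFG, or [cond] without it.
    """
    if do_cfg:
        uncond, cond = clip_embeds.chunk(2, dim=0)
        # assumed layout: the second cond row feeds the perturbed-attention branch
        return torch.cat([uncond, cond, cond], dim=0)
    # PAG without CFG: [cond, cond]
    return torch.cat([clip_embeds, clip_embeds], dim=0)


# Reduced stand-in for the failing concat: batch 2 clip tokens vs batch 3 ID tokens
x = torch.randn(2, 257, 16)      # from the manually set clip_embeds (batch 2)
latents = torch.randn(3, 4, 16)  # from the FaceID embeds after PAG expansion (batch 3)
try:
    torch.cat([x, latents], dim=-2)
except RuntimeError as err:
    print("shape mismatch:", err)

# With the expanded layout, the clip side also has batch 3
clip_embeds = torch.randn(2, 1, 257, 16)
expanded = expand_clip_embeds_for_pag(clip_embeds)
print(expanded.shape)  # torch.Size([3, 1, 257, 16])
```

In the PAG repro this would mean assigning something like `expand_clip_embeds_for_pag(torch.randn(2, 1, 257, 1280)).cuda().half()` to `clip_embeds`; whether the pipeline should perform this expansion itself is a separate question.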
Who can help?