Mismatching size in matmul when using StableDiffusionInstructPix2Pix pipeline with IP-Adapter #7799

@misshimichka

Describe the bug

I tried to combine the InstructPix2Pix model with IP-Adapter (pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")), but I always get RuntimeError: mat1 and mat2 shapes cannot be multiplied (771x1280 and 1024x3072). With other models, e.g. Stable Diffusion or SD-XL, everything works.
I think it's because the dimensions expected by InstructPix2Pix and IP-Adapter differ.
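
To make the mismatch concrete, here is a minimal sketch of the shape check behind the error, written in plain Python so it runs without torch. The helper name `linear_output_shape` is hypothetical and only for illustration; the numbers are taken from the error message. The second operand (1024x3072) is the transposed weight of a linear layer with 1024 inputs and 3072 outputs, while the tensor reaching it has 1280 features per row. Reading the 771 rows as 3 x 257 is only arithmetic; what those factors correspond to inside the pipeline is an assumption I have not verified.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Return the output shape of a linear layer applied to a 2-D input.

    Hypothetical helper for illustration: it mimics the shape rule of a
    matrix multiply and raises when the last input dimension does not
    equal the layer's input size.
    """
    rows, cols = input_shape
    if cols != in_features:
        raise ValueError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{cols} and {in_features}x{out_features})"
        )
    return (rows, out_features)


# Shapes from the error message: input 771x1280, transposed weight 1024x3072.
observed_input = (771, 1280)
in_features, out_features = 1024, 3072

try:
    linear_output_shape(observed_input, in_features, out_features)
    mismatch = None
except ValueError as err:
    mismatch = str(err)

print(mismatch)

# An input whose last dimension is 1024 would pass through the layer.
compatible = linear_output_shape((3, 1024), in_features, out_features)
print(compatible)

# Pure arithmetic: 771 rows divide evenly into 3 groups of 257.
rows_factorization = divmod(771, 257)
print(rows_factorization)
```

So the layer expects 1024 features per row, but something upstream hands it rows with 1280 features.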

Reproduction

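import torch
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionInstructPix2PixPipeline
from diffusers.utils import load_image

noise_scheduler = DDIMScheduler(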
    num_train_timesteps=1000,
    beta_start=0.00085,
    beta_end=0.012,
    beta_schedule="scaled_linear",
    clip_sample=False,
    set_alpha_to_one=False,
    steps_offset=1,
)

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").to(dtype=torch.float32)

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "timbrooks/instruct-pix2pix",
    torch_dtype=torch.float32,
    scheduler=noise_scheduler,
    vae=vae,
    feature_extractor=None,
    safety_checker=None
)
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

image = load_image("https://huggingface.co/datasets/huggingface/documentation/images/resolve/main/diffusers/load_neg_embed.png")

generator = torch.Generator(device="cpu").manual_seed(33)
images = pipe(
    prompt='best quality, high quality',
    image=image,
    ip_adapter_image=image,
    negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
    num_inference_steps=50,
    generator=generator
).images[0]

Logs

RuntimeError                              Traceback (most recent call last)
Cell In[32], line 2
      1 generator = torch.Generator(device="cpu").manual_seed(33)
----> 2 images = pipe(
      3     prompt='best quality, high quality',
      4     image=image,
      5     ip_adapter_image=image,
      6     negative_prompt="monochrome, lowres, bad anatomy, worst quality, low quality",
      7     num_inference_steps=50,
      8     generator=generator
      9 ).images[0]
     10 images

File /opt/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py:404, in StableDiffusionInstructPix2PixPipeline.__call__(self, prompt, image, num_inference_steps, guidance_scale, image_guidance_scale, negative_prompt, num_images_per_prompt, eta, generator, latents, prompt_embeds, negative_prompt_embeds, ip_adapter_image, output_type, return_dict, callback_on_step_end, callback_on_step_end_tensor_inputs, **kwargs)
    401 scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
    403 # predict the noise residual
--> 404 noise_pred = self.unet(
    405     scaled_latent_model_input,
    406     t,
    407     encoder_hidden_states=prompt_embeds,
    408     added_cond_kwargs=added_cond_kwargs,
    409     return_dict=False,
    410 )[0]
    412 # perform guidance
    413 if self.do_classifier_free_guidance:

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1164, in UNet2DConditionModel.forward(self, sample, timestep, encoder_hidden_states, class_labels, timestep_cond, attention_mask, cross_attention_kwargs, added_cond_kwargs, down_block_additional_residuals, mid_block_additional_residual, down_intrablock_additional_residuals, encoder_attention_mask, return_dict)
   1161 if self.time_embed_act is not None:
   1162     emb = self.time_embed_act(emb)
-> 1164 encoder_hidden_states = self.process_encoder_hidden_states(
   1165     encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
   1166 )
   1168 # 2. pre-process
   1169 sample = self.conv_in(sample)

File /opt/conda/lib/python3.10/site-packages/diffusers/models/unets/unet_2d_condition.py:1035, in UNet2DConditionModel.process_encoder_hidden_states(self, encoder_hidden_states, added_cond_kwargs)
   1031         raise ValueError(
   1032             f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in  `added_conditions`"
   1033         )
   1034     image_embeds = added_cond_kwargs.get("image_embeds")
-> 1035     image_embeds = self.encoder_hid_proj(image_embeds)
   1036     encoder_hidden_states = (encoder_hidden_states, image_embeds)
   1037 return encoder_hidden_states

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/diffusers/models/embeddings.py:909, in MultiIPAdapterImageProjection.forward(self, image_embeds)
    907 batch_size, num_images = image_embed.shape[0], image_embed.shape[1]
    908 image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:])
--> 909 image_embed = image_projection_layer(image_embed)
    910 image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:])
    912 projected_image_embeds.append(image_embed)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/diffusers/models/embeddings.py:458, in ImageProjection.forward(self, image_embeds)
    455 batch_size = image_embeds.shape[0]
    457 # image
--> 458 image_embeds = self.image_embeds(image_embeds)
    459 image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
    460 image_embeds = self.norm(image_embeds)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File /opt/conda/lib/python3.10/site-packages/torch/nn/modules/linear.py:114, in Linear.forward(self, input)
    113 def forward(self, input: Tensor) -> Tensor:
--> 114     return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (771x1280 and 1024x3072)

System Info

- `diffusers` version: 0.27.2
- Platform: Linux-5.15.133+-x86_64-with-glibc2.31
- Python version: 3.10.13
- PyTorch version (GPU?): 2.1.2 (True)
- Huggingface_hub version: 0.22.2
- Transformers version: 4.39.3
- Accelerate version: 0.29.3
- xFormers version: not installed
- Using GPU in script?: No
- Using distributed or parallel set-up in script?: No

Who can help?

No response
