Skip to content

Conversation

Dalanke
Copy link
Contributor

@Dalanke Dalanke commented Jun 14, 2024

What does this PR do?

This PR is trying to fix the bug when you specified any number greater than 1 in num_images_per_prompt when you call StableDiffusion3Pipeline . An expection occurs when you create the pipeline without T5 text encoder (set text_encoder_3=None)

Reproduction (follow the documentation here):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_single_file(
    './stable-diffusion-3-medium/sd3_medium_incl_clips.safetensors',
    torch_dtype=torch.float16,
    text_encoder_3=None
    )
pipe = pipe.to("cuda")

image = pipe(
    "a picture of a cat holding a sign that says hello world",
    negative_prompt="",
    # could not specify number of images
    num_images_per_prompt=4,
    num_inference_steps=28,
    guidance_scale=7.0,
).images

for i, img in enumerate(image):
    with open(f'./output/test_{i}.jpg','w+') as f:
        img.save(f)

Bug output

Loading pipeline components...:  62%|███████████████████████████████████████▍                       | 5/8 [00:00<00:00, 16.23it/s]You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
Loading pipeline components...: 100%|███████████████████████████████████████████████████████████████| 8/8 [00:01<00:00,  5.47it/s]
Traceback (most recent call last):
  File "/home/xxx/workspace/sd3/sd3_inference.py", line 16, in <module>
    image = pipe(
  File "/home/xxx/miniconda3/envs/diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/xxx/workspace/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 778, in __call__
    ) = self.encode_prompt(
  File "/home/xxx/workspace/diffusers/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py", line 413, in encode_prompt
    prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 4 but got size 1 for tensor number 1 in the list.

Mitigation:
Bug due to the shape mis-match. For example, in num_images_per_prompt=4 settting, line 413:

prompt_embeds = torch.cat([clip_prompt_embeds, t5_prompt_embed], dim=-2)

will have different shape in torch.Size([4, 77, 4096]) and torch.Size([1, 77, 4096])

fix in the function _get_t5_prompt_embeds return when text_encoder_3=None

if self.text_encoder_3 is None:
        return torch.zeros(
        # change shape here
        # (batch_size, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
            (batch_size * num_images_per_prompt, self.tokenizer_max_length, self.transformer.config.joint_attention_dim),
            device=device,
            dtype=dtype,
        )

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@Dalanke
Copy link
Contributor Author

Dalanke commented Jun 18, 2024

Seems like some checks were not successful but I could not figure out the reason. One line code changed should not lead to code quality issue. Can anyone kindly look into it?

@yiyixuxu yiyixuxu merged commit 2921a20 into huggingface:main Jun 18, 2024
@yiyixuxu
Copy link
Collaborator

thanks!
we have a make style and make fix-copies command you can run to pass the quality test https://huggingface.co/docs/diffusers/en/conceptual/contribution#how-to-open-a-pr

yiyixuxu added a commit that referenced this pull request Jun 20, 2024
…out T5 (text_encoder_3=None) (#8558)

* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
sayakpaul pushed a commit that referenced this pull request Dec 23, 2024
…out T5 (text_encoder_3=None) (#8558)

* fix shape mismatch when num_images_per_prompt > 1 and text_encoder_3=None

* style

* fix copies

---------

Co-authored-by: YiYi Xu <yixu310@gmail.com>
Co-authored-by: yiyixuxu <yixu310@gmail,com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants