
ensure dtype match between diffused latents and vae weights #8391

Open
wants to merge 1 commit into base: main

Conversation

heyalexchoi

What does this PR do?

A simple fix for the diffused latents' dtype not matching the VAE weights' dtype. See the error below. I hit this issue when loading the pipeline in bfloat16 and using accelerate.

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/PixArt-sigma/diffusion/utils/image_evaluation.py", line 150, in generate_images
    batch_images = pipeline(
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/workspace/diffusers/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py", line 866, in __call__
    image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
  File "/workspace/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 305, in decode
    decoded = self._decode(z, return_dict=False)[0]
  File "/workspace/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 277, in _decode
    z = self.post_quant_conv(z)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same
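
The PR's one-line change is not reproduced inline here; as a rough sketch (the exact placement and variable names in the actual commit are an assumption on my part), it amounts to casting the latents to the VAE's dtype right before decoding:

# Hypothetical sketch of the fix in the pipeline's __call__, just before decoding.
# The cast is cheap: the latents tensor is tiny compared to the decoder weights.
if latents.dtype != self.vae.dtype:
    latents = latents.to(self.vae.dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]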

Fixes # (issue)

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@sayakpaul
Member

Thanks for your PR. Does it happen only when using the Sigma pipeline? Would something like this be more prudent to implement?

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
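
For context, this is roughly how the SDXL pipelines handle the decode step today (paraphrased from memory rather than quoted from the diffusers source, so treat the exact lines as an approximation):

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
if needs_upcasting:
    # upcast the fp16 VAE to fp32 and move the latents to the same dtype
    self.upcast_vae()
    latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
if needs_upcasting:
    # cast the VAE back to fp16 afterwards
    self.vae.to(dtype=torch.float16)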


@bghira
Contributor

bghira commented Jun 4, 2024

this also occurs under SD 1.x/2.x and SDXL with accelerate: the default torch dtype is fp32 but the vae dtype is bf16.

here is an error seen when using SDXL Refiner:

2024-06-05 00:41:54,010 [ERROR] (helpers.training.validation) Error generating validation image: Input type (float) and bias type (c10::BFloat16) should be the same, Traceback (most recent call last):
  File "/notebooks/SimpleTuner/helpers/training/validation.py", line 534, in validate_prompt
    validation_image_results = self.pipeline(**pipeline_kwargs).images
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py", line 1422, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 304, in decode
    decoded = self._decode(z).sample
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/diffusers/models/autoencoders/autoencoder_kl.py", line 274, in _decode
    z = self.post_quant_conv(z)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 460, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/notebooks/SimpleTuner/.venv/lib/python3.9/site-packages/torch/nn/modules/conv.py", line 456, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

@bghira
Contributor

bghira commented Jun 4, 2024

#7886 is the same or a similar issue

@heyalexchoi
Author

Thanks for your PR. Does it happen only when using the Sigma pipeline? Would something like this be more prudent to implement?

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

I don't know much about the background of the force_upcast config param. I do know I have hit this issue in PixArt pipelines (maybe alpha too?) a few times. This fix seems simple and I don't see any downside.

@sayakpaul
Member

Will defer to @yiyixuxu for an opinion on how best to proceed. IMO, we should handle it in the same way as

needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast

@bghira
Contributor

bghira commented Jun 5, 2024

do you mean to provide a conditional check instead of unconditionally casting it to the vae's dtype? or do you mean we should set force_upcast in a certain situation?

for the former, i'm curious what problems you foresee with doing it unconditionally. it's not that having a check would hurt, but i also don't see any harm in ensuring the latents match the vae dtype before decode.

for the latter, this is a situation where upcasting the vae to match the latents is unnecessary, e.g. i am using the fp16-fixed SDXL VAE for decode, and upcasting would just waste resources. the problem is that the latents become fp32 after being modified by the pipeline just a few lines before the decode, while the vae itself is bf16.

tl;dr i think casting the latents to the vae dtype is the correct solution, rather than upcasting the vae to the latents dtype.
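
A minimal way to see the difference between the two options (illustrative only; the model id, shapes, and variable names are my own choices, not part of this PR):

import torch
from diffusers import AutoencoderKL

# A bf16 VAE (the "fp16 fix" SDXL VAE mentioned above) plus fp32 latents reproduces the error.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.bfloat16)
latents = torch.randn(1, 4, 64, 64)  # fp32, as left behind by the scheduler math in the pipeline

# Option A (this PR): cast the small latents tensor down to the VAE dtype. Cheap.
image = vae.decode(latents.to(vae.dtype)).sample

# Option B: upcast the whole VAE to match the latents. Works, but wastes memory and compute
# when the VAE is already numerically safe in bf16.
# vae.to(torch.float32); image = vae.decode(latents).sample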
