
TAESD-encoded latents are too dark #4676

Closed
keturn opened this issue Aug 19, 2023 · 11 comments · Fixed by #4682
Labels
bug Something isn't working

Comments

@keturn
Contributor

keturn commented Aug 19, 2023

Describe the bug

The AutoencoderTiny (TAESD) decoder seems to work fine. Encoding, on the other hand, produces poor results, and an encode-decode round trip turns out poorly:

input:
[image: benz]

output:
[image: sad benz]

Reproduction

See https://gist.github.com/keturn/b0a10a3b388e1e49cdf38567b76eb30c

import diffusers
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import load_image
from IPython.display import display  # run in a notebook to display the images

device = torch.device("cuda:0")

with torch.inference_mode():
    taesd = diffusers.AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to(device=device)
    vaesd = diffusers.AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae", variant="fp16", torch_dtype=torch.float16).to(device=device)

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/versatile_diffusion/benz.jpg"
)

# preprocess normalizes the image to [-1, 1] by default
vae_processor = VaeImageProcessor()
image_tensor: torch.FloatTensor = vae_processor.preprocess(image).to(dtype=torch.float16, device=device)

print(f"image tensor range: {image_tensor.min()} < {image_tensor.mean()} < {image_tensor.max()}")

with torch.inference_mode():
    taesd_latents = taesd.encode(image_tensor).latents
    print(f"taesd-encoded latent range: {taesd_latents.min()} < {taesd_latents.mean()} (σ={taesd_latents.std()}) < {taesd_latents.max()}")

    vaesd_latents = vaesd.encode(image_tensor).latent_dist.sample()
    print(f"vaesd-encoded latent range: {vaesd_latents.min()} < {vaesd_latents.mean()} (σ={vaesd_latents.std()}) < {vaesd_latents.max()}")

with torch.inference_mode():
    redecoded_tensor = taesd.decode(taesd_latents).sample

redecoded_image = vae_processor.postprocess(redecoded_tensor)
display(image, redecoded_image[0])

# print environment info for the bug report
from diffusers.commands import env
env.EnvironmentCommand().run()

System Info

  • diffusers version: 0.20.0
  • Platform: Linux-5.15.0-79-generic-x86_64-with-glibc2.35
  • Python version: 3.11.4
  • PyTorch version (GPU?): 2.0.1+cu118 (True)
  • Huggingface_hub version: 0.16.4
  • Transformers version: 4.31.0
  • Accelerate version: 0.21.0
  • xFormers version: 0.0.21
  • Using GPU in script?: yes
  • Using distributed or parallel set-up in script?: no

Who can help?

No response

@madebyollin
Contributor

Oops, it's probably this (@sayakpaul 😅)

@keturn
Contributor Author

keturn commented Aug 19, 2023

Yeah, I was side-eyeing that suspiciously.

@sayakpaul
Member

> Oops, it's probably #4384 (comment)

I am not sure what's meant by that comment.

Maybe the encode-decode roundtrip is best done by referring to @madebyollin's original notebook here:
https://github.com/madebyollin/taesd/blob/main/examples/Previewing_During_Image_Generation.ipynb

@keturn
Contributor Author

keturn commented Aug 19, 2023

Is there a round-trip in that notebook? I don't see any encoding.

@madebyollin
Contributor

madebyollin commented Aug 19, 2023

@sayakpaul
Per the README, TAESD's raw model assumes a [0, 1] scaling convention for input / output images.

I think either @keturn's sample code needs to use a special preprocessor for AutoencoderTiny, or (to match the AutoencoderKL behavior) AutoencoderTiny.encode could rescale its inputs, the inverse of the rescaling in AutoencoderTiny.decode.

Right now the encoder is getting an image in [-1, 1], encoding the values that are in [0, 1] (above 50% brightness), and clamping the rest of the values (everything below 50% brightness) to black - which is why the decoder decodes a darkened image.
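For concreteness, here's a minimal sketch of the manual-rescale option (illustrative only, not the library's API; it reuses taesd and image_tensor from the repro above):

# Shift the [-1, 1] preprocessed image into the [0, 1] range the raw TAESD
# encoder expects; decode() in 0.20.0 already rescales its output to [-1, 1].
with torch.inference_mode():
    latents = taesd.encode(image_tensor.add(1).div(2)).latents
    roundtrip = taesd.decode(latents).sample  # back in [-1, 1]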

(My fault for not following up on the review comment - sorry!)

@keturn Since a few people have now been (understandably 😅) tripped up by the way TAESD bakes in the SD-VAE input / output scale-shift transforms, I've added an example "Encoding / Decoding" notebook to hopefully clear things up. I also added (hopefully) clearer language to the README.

[image: encoding / decoding example from the new notebook]

@Isotr0py
Contributor

Isotr0py commented Aug 19, 2023

It seems that VaeImageProcessor can skip normalizing/denormalizing the image in preprocess/postprocess.

Passing do_normalize=False to VaeImageProcessor and do_denormalize to vae_processor.postprocess produces a normal image.

But the workflow would get a little complicated if it has to switch between the VAE and TAESD. Maybe we could move do_normalize from the processor constructor to the preprocess method for flexibility?

Test code

import diffusers
import torch
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.testing_utils import load_image
from IPython.display import display  # run in a notebook to display the images

device = torch.device("cuda:0")

with torch.inference_mode():
    taesd = diffusers.AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16).to(device=device)
    vaesd = diffusers.AutoencoderKL.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="vae", variant="fp16", torch_dtype=torch.float16).to(device=device)

image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/versatile_diffusion/benz.jpg"
)

# do_normalize=False keeps the preprocessed image in [0, 1] instead of [-1, 1]
vae_processor = VaeImageProcessor(do_normalize=False)
image_tensor: torch.FloatTensor = vae_processor.preprocess(image).to(dtype=torch.float16, device=device)

print(f"image tensor range: {image_tensor.min()} < {image_tensor.mean()} < {image_tensor.max()}")

with torch.inference_mode():
    taesd_latents = taesd.encode(image_tensor).latents
    print(f"taesd-encoded latent range: {taesd_latents.min()} < {taesd_latents.mean()} (σ={taesd_latents.std()}) < {taesd_latents.max()}")

    vaesd_latents = vaesd.encode(image_tensor).latent_dist.sample()
    print(f"vaesd-encoded latent range: {vaesd_latents.min()} < {vaesd_latents.mean()} (σ={vaesd_latents.std()}) < {vaesd_latents.max()}")

with torch.inference_mode():
    redecoded_tensor = taesd.decode(taesd_latents).sample

# decode() returns a [-1, 1] image, so denormalize it on the way out
redecoded_image = vae_processor.postprocess(redecoded_tensor, do_denormalize=[True])
display(image, redecoded_image[0])

Outputs

[image: benz round trip, normal brightness]

@keturn
Contributor Author

keturn commented Aug 19, 2023

encode and decode should be inverses of one another, such that x == encode(decode(x)) and y == decode(encode(y)), give or take a bit due to the lossy nature of the encoding. So the domain of encode's input should be consistent with the range of decode's output; it would be quite unexpected for one of them to be [0, 1] and the other [-1, 1].
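Here's a minimal sketch of that contract, reusing the names from my repro above (the tolerance is illustrative; with the current bug this check fails, and after a fix it should pass):

with torch.inference_mode():
    roundtrip = taesd.decode(taesd.encode(image_tensor).latents).sample
    # lossy codec: only approximately equal
    assert torch.allclose(image_tensor, roundtrip, atol=0.1)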

> But the workflow would get a little complicated if it has to switch between the VAE and TAESD.

Yes. Given that TAESD was explicitly designed as a drop-in replacement for the Stable Diffusion VAE, and that the diffusers library implements both, it would be much appreciated if the library offered an interface consistent between the two.

@sayakpaul
Member

Even though it was developed as a drop-in replacement, I think its main usefulness lies in speedy decoding. For now, almost all (if not all) encoders in the library expect values in [-1, 1], so we likely won't be changing that. But I'm happy to review any PRs in this regard.
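For reference, a small sketch of that convention using VaeImageProcessor's normalize/denormalize static helpers (the tensor shape here is arbitrary):

import torch
from diffusers.image_processor import VaeImageProcessor

x = torch.rand(1, 3, 8, 8)                            # pixel values in [0, 1]
normalized = VaeImageProcessor.normalize(x)           # 2*x - 1, now in [-1, 1]
restored = VaeImageProcessor.denormalize(normalized)  # back to [0, 1]
assert torch.allclose(x, restored)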

madebyollin added a commit to madebyollin/diffusers that referenced this issue Aug 19, 2023
  * Add [-1, 1] -> [0, 1] rescaling to EncoderTiny (this fixes huggingface#4676)

  * Move [0, 1] -> [-1, 1] rescaling from AutoencoderTiny.decode to DecoderTiny
    (i.e. immediately after the final conv, as early as possible)

  * Fix missing [0, 255] -> [0, 1] rescaling in AutoencoderTiny.forward

  * Update AutoencoderTinyIntegrationTests to protect against scaling issues.
    The new test constructs a simple image, round-trips it through AutoencoderTiny,
    and confirms the decoded result is approximately equal to the source image.
    This test checks behavior with and without tiling enabled.
    This test will fail if new AutoencoderTiny scaling issues are introduced.

  * Context: Raw TAESD weights expect images in [0, 1], but diffusers'
    convention represents images with zero-centered values in [-1, 1],
    so AutoencoderTiny needs to scale / unscale images at the start of
    encoding and at the end of decoding in order to work with diffusers.
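In code, the shape of that fix is roughly the following (hypothetical helper names; a sketch rather than the PR's actual implementation):

import torch

def scale_encoder_input(x: torch.Tensor) -> torch.Tensor:
    # diffusers-convention image in [-1, 1] -> the [0, 1] range raw TAESD expects
    return x.div(2).add(0.5)

def unscale_decoder_output(x: torch.Tensor) -> torch.Tensor:
    # raw TAESD decoder output in [0, 1] -> diffusers-convention [-1, 1]
    return x.mul(2).sub(1)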
@madebyollin
Contributor

madebyollin commented Aug 19, 2023

@sayakpaul I agree, [-1, 1] is the correct value convention for diffusers images. I've added a PR (#4682) to make AutoencoderTiny's encoder work with [-1, 1] inputs, so that it matches the other encoders.

@keturn I've tested the PR on your sample code, and I think it should now work without modification (though the printed latent ranges still differ, because your sample code isn't manually applying the scaling_factor):
[image: sample-code output with the PR applied]
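For reference, "applying the scaling_factor" amounts to one line; the SD 1.5 VAE config sets scaling_factor = 0.18215, and pipelines multiply KL-VAE latents by it after encoding:

scaled_kl_latents = vaesd_latents * vaesd.config.scaling_factor  # reusing names from the repro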

@keturn
Contributor Author

keturn commented Aug 19, 2023

Thank you, and thanks for the note about the scaling factor as well; I was wondering about that discrepancy. I have some follow-up questions about it, but they go beyond the scope of this AutoencoderTiny issue, so I'll find somewhere else to post them.

@keturn
Contributor Author

keturn commented Aug 19, 2023

scaling_factor question → AutoencoderKL.scaling_factor and VaeImageProcessor
