
Conversation

@turian turian commented Nov 5, 2025

What does this PR do?

Adds an opt-in --preserve_input_precision flag to examples/unconditional_image_generation/train_unconditional.py so users can keep 16- or 32-bit channel data (e.g. medical TIFFs) in full precision during preprocessing while still emitting 3-channel float32 tensors normalized to [-1, 1]. By default nothing changes: we still hit image.convert("RGB") → ToTensor() → Normalize, preserving byte-for-byte parity with the current pipeline.
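
For reference, a minimal sketch of that default path, assuming a PIL `image` and the script's `args.resolution` value (names here are illustrative, not the exact code in the script):

```python
from torchvision import transforms

# Default path: convert("RGB") forces 8-bit channels, ToTensor() scales them to [0, 1],
# and Normalize maps them to [-1, 1].
default_transform = transforms.Compose(
    [
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.CenterCrop(args.resolution),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ]
)
tensor = default_transform(image.convert("RGB"))
```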

With the flag enabled, the script now:

  • Uses transforms.PILToTensor() + ConvertImageDtype(torch.float32) to avoid 8-bit quantization (see the sketch after this list).
  • Applies the same spatial augmentations on tensors and enforces three channels via the new _ensure_three_channels helper (repeat, pad, or slice as needed).
  • Special-cases palette images to keep today’s behavior when precision isn’t actually higher.
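
A rough sketch of the flag-enabled path under the same assumptions (`args.resolution`, a PIL `image`); the `_ensure_three_channels` body below only illustrates the repeat/pad/slice idea described above, it is not the merged helper verbatim:

```python
import torch
from torchvision import transforms


def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
    # Illustrative channel handling: repeat grayscale, pad two-channel inputs,
    # and slice away extras such as alpha so the model always sees 3 channels.
    channels = tensor.shape[0]
    if channels == 1:
        return tensor.repeat(3, 1, 1)
    if channels == 2:
        return torch.cat([tensor, tensor[-1:]], dim=0)
    return tensor[:3]


high_precision_transform = transforms.Compose(
    [
        transforms.PILToTensor(),                     # keeps the decoded dtype (e.g. int32 for Pillow mode "I")
        transforms.ConvertImageDtype(torch.float32),  # to float32; integer inputs are rescaled by their dtype range, no 8-bit round trip
        transforms.Lambda(_ensure_three_channels),
        transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR, antialias=True),
        transforms.CenterCrop(args.resolution),
        transforms.Normalize([0.5], [0.5]),
    ]
)
tensor = high_precision_transform(image)
```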

README now documents the flag for users with high-bit-depth datasets.

Fixes # (issue)

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline?
  • Did you read our philosophy doc (important for complex PRs)?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.

If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
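
One quick sanity check before relying on the flag is to inspect the mode and dtype Pillow reports for your files; a minimal sketch (the file name is hypothetical):

```python
import numpy as np
from PIL import Image

img = Image.open("scan_0001.tif")  # hypothetical 16-bit medical TIFF
arr = np.asarray(img)
# Modes such as "I;16", "I", or "F" mean the extra precision survived decoding;
# "RGB" or "L" with dtype uint8 means the data was already reduced to 8 bits at load time.
print(img.mode, arr.dtype, arr.min(), arr.max())
```
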
Contributor Author

It occurs to me that there is a possible footgun here for people using 16-bit RGB or RGBA images, but I wanted to keep the diff minimal.

Member

@sayakpaul sayakpaul left a comment

Very cool work! Thank you! If you have some extended reading to add to the README, feel free to!

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul sayakpaul merged commit 58f3771 into huggingface:main Nov 6, 2025
15 of 26 checks passed