Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/unconditional_image_generation/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways:
- you can either provide your own folder as `--train_data_dir`
- or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument.

If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. Precision still depends on the decoder: Pillow keeps 16-bit grayscale and float inputs, but many 16-bit RGB files are decoded as 8-bit RGB, and the flag cannot recover precision lost at load time.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It occurs to me that there is a possible footgun here for people using 16-bit RGB or RGBA images, but I wanted to keep the diff minimal.


Below, we explain both in more detail.

#### Provide the dataset as a folder
Expand Down
57 changes: 51 additions & 6 deletions examples/unconditional_image_generation/train_unconditional.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape):
return res.expand(broadcast_shape)


def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
"""
Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed.
"""
if tensor.ndim == 2:
tensor = tensor.unsqueeze(0)
channels = tensor.shape[0]
if channels == 3:
return tensor
if channels == 1:
return tensor.repeat(3, 1, 1)
if channels == 2:
return torch.cat([tensor, tensor[:1]], dim=0)
if channels > 3:
return tensor[:3]
raise ValueError(f"Unsupported number of channels: {channels}")


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
Expand Down Expand Up @@ -260,6 +278,11 @@ def parse_args():
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--preserve_input_precision",
action="store_true",
help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.",
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
Expand Down Expand Up @@ -453,19 +476,41 @@ def load_model_hook(models, input_dir):
# https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder

# Preprocessing the datasets and DataLoaders creation.
spatial_augmentations = [
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
]

augmentations = transforms.Compose(
[
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
spatial_augmentations
+ [
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5]),
]
)

precision_augmentations = transforms.Compose(
[
transforms.PILToTensor(),
transforms.Lambda(_ensure_three_channels),
transforms.ConvertImageDtype(torch.float32),
]
+ spatial_augmentations
+ [transforms.Normalize([0.5], [0.5])]
)

def transform_images(examples):
images = [augmentations(image.convert("RGB")) for image in examples["image"]]
return {"input": images}
processed = []
for image in examples["image"]:
if not args.preserve_input_precision:
processed.append(augmentations(image.convert("RGB")))
else:
precise_image = image
if precise_image.mode == "P":
precise_image = precise_image.convert("RGB")
processed.append(precision_augmentations(precise_image))
return {"input": processed}

logger.info(f"Dataset size: {len(dataset)}")

Expand Down
Loading