diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md
index 6f8276a632f7..a740b175069b 100644
--- a/examples/unconditional_image_generation/README.md
+++ b/examples/unconditional_image_generation/README.md
@@ -45,6 +45,17 @@ accelerate launch train_unconditional.py \
   --mixed_precision=no \
   --push_to_hub
 ```
+Append `--image_bit_depth 16` to log evaluation samples as 16-bit PNGs (the default is 8-bit). The script writes true 16-bit PNGs, but TensorBoard always shows 8-bit previews, and while W&B stores the uploaded 16-bit files, its previews are typically 8-bit and can vary by SDK version. Download the files from the Files/Artifacts tab if you need full precision.
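+
+To sanity-check the bit depth of a sample downloaded from the Files/Artifacts tab, read it back with a reader that preserves 16-bit PNGs. A minimal sketch using OpenCV (an assumption here, not a dependency of this example; `sample.png` stands in for your downloaded file):
+
+```python
+import cv2
+
+# IMREAD_UNCHANGED keeps the file's original bit depth instead of converting to 8-bit
+img = cv2.imread("sample.png", cv2.IMREAD_UNCHANGED)  # "sample.png" is a placeholder path
+print(img.dtype)  # uint16 for a 16-bit PNG, uint8 for an 8-bit one
+```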
 
 An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64
 A full training run takes 2 hours on 4xV100 GPUs.
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 0cc96220b932..0085d37ad69f 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -4,11 +4,13 @@
 import math
 import os
 import shutil
+import tempfile
 from datetime import timedelta
 from pathlib import Path
 
 import accelerate
 import datasets
+import numpy as np
 import torch
 import torch.nn.functional as F
 from accelerate import Accelerator, InitProcessGroupKwargs
@@ -70,6 +72,113 @@ def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
     raise ValueError(f"Unsupported number of channels: {channels}")
 
 
+def _prepare_sample_images(images: np.ndarray, bit_depth: int):
+    """Scale NHWC float images in [0, 1] to uint8/uint16 arrays and build a TensorBoard-safe preview.
+
+    Expects float inputs with shape (N, H, W, C) and returns the scaled array plus an 8-bit preview.
+    """
+    if bit_depth == 8:
+        max_pixel_value = 255
+        out_dtype = np.uint8
+    elif bit_depth == 16:
+        max_pixel_value = 65535
+        out_dtype = np.uint16
+    else:
+        raise ValueError(f"Unsupported image_bit_depth: {bit_depth}")
+
+    images_processed = (
+        np.clip(np.round(images * max_pixel_value), 0, max_pixel_value).astype(out_dtype)
+    )
+
+    if bit_depth == 16:
+        # TensorBoard accepts float [0, 1] or uint8 visuals, so keep a separate preview for the UI.
+        tb_preview = np.clip(np.round(images * 255), 0, 255).astype(np.uint8)
+    else:
+        tb_preview = images_processed
+
+    return images_processed, tb_preview
+
+
+def _log_sample_images(
+    accelerator: Accelerator,
+    images_processed: np.ndarray,
+    tb_preview: np.ndarray,
+    epoch: int,
+    global_step: int,
+    logger_name: str,
+):
+    """Log prepared sample images to the active TensorBoard or W&B tracker."""
+    if images_processed.dtype not in {np.uint8, np.uint16}:
+        raise ValueError(
+            f"Unsupported dtype for logged images: {images_processed.dtype}; expected uint8 or uint16."
+        )
+
+    if logger_name == "tensorboard":
+        if is_accelerate_version(">=", "0.17.0.dev0"):
+            tracker = accelerator.get_tracker("tensorboard", unwrap=True)
+        else:
+            tracker = accelerator.get_tracker("tensorboard")
+        tracker.add_images("test_samples", tb_preview.transpose(0, 3, 1, 2), global_step)
+    elif logger_name == "wandb":
+        import wandb
+
+        tracker = accelerator.get_tracker("wandb")
+        if images_processed.dtype == np.uint16:
+            # Pillow is required here to encode the 16-bit samples as PNGs for the W&B upload.
+            from PIL import Image
+
+            wandb_images = []
+            temp_paths = []
+            for img16 in images_processed:
+                if img16.dtype != np.uint16:
+                    raise ValueError(
+                        f"Expected uint16 image for 16-bit logging, received {img16.dtype}."
+                    )
+                if img16.ndim not in {2, 3}:
+                    raise ValueError(f"Unsupported array shape for 16-bit image: {img16.shape}")
+                h, w = img16.shape[:2]
+                if img16.ndim == 3:
+                    channels = img16.shape[-1]
+                    if channels == 1:
+                        array_for_pil = np.ascontiguousarray(img16.squeeze(-1))
+                        pil_image = Image.fromarray(array_for_pil, mode="I;16")
+                    elif channels in {3, 4}:
+                        mode = "RGB" if channels == 3 else "RGBA"
+                        contiguous = np.ascontiguousarray(img16)
+                        pil_image = Image.frombytes(
+                            mode,
+                            (w, h),
+                            # Pillow's raw ";16B" unpacker expects big-endian samples;
+                            # astype(">u2") guarantees that regardless of host byte order.
+                            contiguous.astype(">u2").tobytes(),
+                            "raw",
+                            f"{mode};16B",
+                        )
+                    else:
+                        raise ValueError(f"Unsupported channel count for 16-bit image: {channels}")
+                else:
+                    array_for_pil = np.ascontiguousarray(img16)
+                    pil_image = Image.fromarray(array_for_pil, mode="I;16")
+
+                tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
+                temp_paths.append(tmp.name)
+                tmp.close()
+                # wandb.Image does not accept BytesIO; a file path preserves the 16-bit PNG bytes on upload.
+                pil_image.save(temp_paths[-1], format="PNG")
+                wandb_images.append(wandb.Image(temp_paths[-1]))
+
+            try:
+                tracker.log({"test_samples": wandb_images, "epoch": epoch}, step=global_step)
+            finally:
+                for path in temp_paths:
+                    try:
+                        os.unlink(path)
+                    except OSError:
+                        pass
+        else:
+            tracker.log(
+                {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
+                step=global_step,
+            )
+
+
 def parse_args():
     parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
@@ -147,6 +256,13 @@ def parse_args():
     parser.add_argument(
         "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
     )
+    parser.add_argument(
+        "--image_bit_depth",
+        type=int,
+        choices=[8, 16],
+        default=8,
+        help="Bit depth for generated sample images during evaluation/logging. Default: 8.",
+    )
     parser.add_argument(
         "--dataloader_num_workers",
         type=int,
@@ -699,20 +815,16 @@ def transform_images(examples):
                 if args.use_ema:
                     ema_model.restore(unet.parameters())
 
-                # denormalize the images and save to tensorboard
-                images_processed = (images * 255).round().astype("uint8")
+                images_processed, tb_preview = _prepare_sample_images(images, args.image_bit_depth)
 
-                if args.logger == "tensorboard":
-                    if is_accelerate_version(">=", "0.17.0.dev0"):
-                        tracker = accelerator.get_tracker("tensorboard", unwrap=True)
-                    else:
-                        tracker = accelerator.get_tracker("tensorboard")
-                    tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch)
-                elif args.logger == "wandb":
-                    # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files
-                    accelerator.get_tracker("wandb").log(
-                        {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
-                        step=global_step,
+                if args.logger in {"tensorboard", "wandb"}:
+                    _log_sample_images(
+                        accelerator,
+                        images_processed,
+                        tb_preview,
+                        epoch,
+                        global_step,
+                        args.logger,
                     )
 
                 if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: