From 522de2f6181fe1477cac27c1b8a7e0521be82d16 Mon Sep 17 00:00:00 2001 From: Joseph Turian Date: Wed, 5 Nov 2025 14:39:26 +0000 Subject: [PATCH 1/6] Add 16-bit logging option for unconditional example --- .../unconditional_image_generation/README.md | 1 + .../train_unconditional.py | 95 ++++++++++++++++--- 2 files changed, 82 insertions(+), 14 deletions(-) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 22f982509bb1..ce6a29f8b970 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -45,6 +45,7 @@ accelerate launch train_unconditional.py \ --mixed_precision=no \ --push_to_hub ``` +Append `--image_bit_depth 16` if you want evaluation samples logged as 16-bit PNGs (default is 8-bit). TensorBoard image summaries will still display 8-bit previews, while W&B keeps the 16-bit artifacts. An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64 A full training run takes 2 hours on 4xV100 GPUs. diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3ffeef13647c..4a023d6e9635 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,5 +1,6 @@ import argparse import inspect +import io import logging import math import os @@ -9,6 +10,7 @@ import accelerate import datasets +import numpy as np import torch import torch.nn.functional as F from accelerate import Accelerator, InitProcessGroupKwargs @@ -52,6 +54,68 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): return res.expand(broadcast_shape) +def _prepare_sample_images(images: np.ndarray, bit_depth: int): + """Scale pipeline outputs to uint8/uint16 arrays and build a TensorBoard-safe preview.""" + max_pixel_value = 255 if bit_depth == 8 else 65535 + out_dtype = np.uint8 if bit_depth == 8 else np.uint16 + images_processed = ( + np.clip(np.round(images * max_pixel_value), 0, max_pixel_value).astype(out_dtype) + ) + + if bit_depth == 8: + tb_preview = images_processed + else: + # TensorBoard accepts float [0,1] or uint8 visuals, so keep a separate preview for the UI. 
+ tb_preview = np.clip(np.round(images * 255), 0, 255).astype(np.uint8) + + return images_processed, tb_preview + + +def _log_sample_images( + accelerator: Accelerator, + images_processed: np.ndarray, + tb_preview: np.ndarray, + epoch: int, + global_step: int, + logger_name: str, +): + if logger_name == "tensorboard": + if is_accelerate_version(">=", "0.17.0.dev0"): + tracker = accelerator.get_tracker("tensorboard", unwrap=True) + else: + tracker = accelerator.get_tracker("tensorboard") + tracker.add_images("test_samples", tb_preview.transpose(0, 3, 1, 2), epoch) + elif logger_name == "wandb": + import wandb + + tracker = accelerator.get_tracker("wandb") + if images_processed.dtype == np.uint16: + from PIL import Image + + wandb_images = [] + for img16 in images_processed: + buffer = io.BytesIO() + if img16.ndim == 3 and img16.shape[-1] == 3: + contiguous = np.ascontiguousarray(img16) + pil_image = Image.frombytes( + "RGB", + (contiguous.shape[1], contiguous.shape[0]), + contiguous.byteswap().tobytes(), + "raw", + "RGB;16B", + ) + else: + pil_image = Image.fromarray(img16, mode="I;16") + pil_image.save(buffer, format="PNG") + buffer.seek(0) + wandb_images.append(wandb.Image(buffer)) + tracker.log({"test_samples": wandb_images, "epoch": epoch}, step=global_step) + else: + tracker.log( + {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, + step=global_step, + ) + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -129,6 +193,13 @@ def parse_args(): parser.add_argument( "--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation." ) + parser.add_argument( + "--image_bit_depth", + type=int, + choices=[8, 16], + default=8, + help="Bit depth for generated sample images during evaluation/logging. 
Default: 8.", + ) parser.add_argument( "--dataloader_num_workers", type=int, @@ -654,20 +725,16 @@ def transform_images(examples): if args.use_ema: ema_model.restore(unet.parameters()) - # denormalize the images and save to tensorboard - images_processed = (images * 255).round().astype("uint8") - - if args.logger == "tensorboard": - if is_accelerate_version(">=", "0.17.0.dev0"): - tracker = accelerator.get_tracker("tensorboard", unwrap=True) - else: - tracker = accelerator.get_tracker("tensorboard") - tracker.add_images("test_samples", images_processed.transpose(0, 3, 1, 2), epoch) - elif args.logger == "wandb": - # Upcoming `log_images` helper coming in https://github.com/huggingface/accelerate/pull/962/files - accelerator.get_tracker("wandb").log( - {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch}, - step=global_step, + images_processed, tb_preview = _prepare_sample_images(images, args.image_bit_depth) + + if args.logger in {"tensorboard", "wandb"}: + _log_sample_images( + accelerator, + images_processed, + tb_preview, + epoch, + global_step, + args.logger, ) if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: From d5f7f8a53b87fca038eac7d6a80a6beba6f323b8 Mon Sep 17 00:00:00 2001 From: Joseph Turian Date: Wed, 5 Nov 2025 14:49:53 +0000 Subject: [PATCH 2/6] Add optional precision-preserving preprocessing --- .../unconditional_image_generation/README.md | 2 + .../train_unconditional.py | 57 +++++++++++++++++-- 2 files changed, 53 insertions(+), 6 deletions(-) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 22f982509bb1..e18ce4c02b0b 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -104,6 +104,8 @@ To use your own dataset, there are 2 ways: - you can either provide your own folder as `--train_data_dir` - or you can upload your dataset to the hub (possibly as a private repo, if you prefer so), and simply pass the `--dataset_name` argument. +If your dataset contains 16 or 32-bit channels (for example, medical TIFFs), add the `--preserve_input_precision` flag so the preprocessing keeps the original precision while still training a 3-channel model. + Below, we explain both in more detail. #### Provide the dataset as a folder diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index 3ffeef13647c..0cc96220b932 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -52,6 +52,24 @@ def _extract_into_tensor(arr, timesteps, broadcast_shape): return res.expand(broadcast_shape) +def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor: + """ + Ensure the tensor has exactly three channels (C, H, W) by repeating or truncating channels when needed. 
+ """ + if tensor.ndim == 2: + tensor = tensor.unsqueeze(0) + channels = tensor.shape[0] + if channels == 3: + return tensor + if channels == 1: + return tensor.repeat(3, 1, 1) + if channels == 2: + return torch.cat([tensor, tensor[:1]], dim=0) + if channels > 3: + return tensor[:3] + raise ValueError(f"Unsupported number of channels: {channels}") + + def parse_args(): parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( @@ -260,6 +278,11 @@ def parse_args(): parser.add_argument( "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers." ) + parser.add_argument( + "--preserve_input_precision", + action="store_true", + help="Preserve 16/32-bit image precision by avoiding 8-bit RGB conversion while still producing 3-channel tensors.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -453,19 +476,41 @@ def load_model_hook(models, input_dir): # https://huggingface.co/docs/datasets/v2.4.0/en/image_load#imagefolder # Preprocessing the datasets and DataLoaders creation. + spatial_augmentations = [ + transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), + transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + ] + augmentations = transforms.Compose( - [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), - transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), - transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), + spatial_augmentations + + [ transforms.ToTensor(), transforms.Normalize([0.5], [0.5]), ] ) + precision_augmentations = transforms.Compose( + [ + transforms.PILToTensor(), + transforms.Lambda(_ensure_three_channels), + transforms.ConvertImageDtype(torch.float32), + ] + + spatial_augmentations + + [transforms.Normalize([0.5], [0.5])] + ) + def transform_images(examples): - images = [augmentations(image.convert("RGB")) for image in examples["image"]] - return {"input": images} + processed = [] + for image in examples["image"]: + if not args.preserve_input_precision: + processed.append(augmentations(image.convert("RGB"))) + else: + precise_image = image + if precise_image.mode == "P": + precise_image = precise_image.convert("RGB") + processed.append(precision_augmentations(precise_image)) + return {"input": processed} logger.info(f"Dataset size: {len(dataset)}") From d2ed7937bae75c0a895015dc6386a57b6208f452 Mon Sep 17 00:00:00 2001 From: Joseph Turian Date: Thu, 6 Nov 2025 16:57:41 +0000 Subject: [PATCH 3/6] Handle explicit channel cases for 16-bit W&B logging --- .../train_unconditional.py | 60 ++++++++++++++----- 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index e98ce7c6cfb9..cee30b311eca 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -74,17 +74,24 @@ def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor: def _prepare_sample_images(images: np.ndarray, bit_depth: int): """Scale pipeline outputs to uint8/uint16 arrays and build a TensorBoard-safe preview.""" - max_pixel_value = 255 if bit_depth == 
8 else 65535 - out_dtype = np.uint8 if bit_depth == 8 else np.uint16 + if bit_depth == 8: + max_pixel_value = 255 + out_dtype = np.uint8 + elif bit_depth == 16: + max_pixel_value = 65535 + out_dtype = np.uint16 + else: + raise ValueError(f"Unsupported image_bit_depth: {bit_depth}") + images_processed = ( np.clip(np.round(images * max_pixel_value), 0, max_pixel_value).astype(out_dtype) ) - if bit_depth == 8: - tb_preview = images_processed - else: + if bit_depth == 16: # TensorBoard accepts float [0,1] or uint8 visuals, so keep a separate preview for the UI. tb_preview = np.clip(np.round(images * 255), 0, 255).astype(np.uint8) + else: + tb_preview = images_processed return images_processed, tb_preview @@ -97,6 +104,11 @@ def _log_sample_images( global_step: int, logger_name: str, ): + if images_processed.dtype not in {np.uint8, np.uint16}: + raise ValueError( + f"Unsupported dtype for logged images: {images_processed.dtype}; expected uint8 or uint16." + ) + if logger_name == "tensorboard": if is_accelerate_version(">=", "0.17.0.dev0"): tracker = accelerator.get_tracker("tensorboard", unwrap=True) @@ -112,18 +124,36 @@ def _log_sample_images( wandb_images = [] for img16 in images_processed: - buffer = io.BytesIO() - if img16.ndim == 3 and img16.shape[-1] == 3: - contiguous = np.ascontiguousarray(img16) - pil_image = Image.frombytes( - "RGB", - (contiguous.shape[1], contiguous.shape[0]), - contiguous.byteswap().tobytes(), - "raw", - "RGB;16B", + if img16.dtype != np.uint16: + raise ValueError( + f"Expected uint16 image for 16-bit logging, received {img16.dtype}." ) - else: + h, w = img16.shape[:2] + buffer = io.BytesIO() + if img16.ndim == 2: pil_image = Image.fromarray(img16, mode="I;16") + elif img16.ndim == 3: + channels = img16.shape[-1] + if channels == 1: + pil_image = Image.fromarray(img16.squeeze(-1), mode="I;16") + elif channels in {3, 4}: + mode = "RGB" if channels == 3 else "RGBA" + contiguous = np.ascontiguousarray(img16) + pil_image = Image.frombytes( + mode, + (w, h), + contiguous.byteswap().tobytes(), + "raw", + f"{mode};16B", + ) + else: + raise ValueError( + f"Unsupported channel count for 16-bit image: {channels}" + ) + else: + raise ValueError( + f"Unsupported array shape for 16-bit image: {img16.shape}" + ) pil_image.save(buffer, format="PNG") buffer.seek(0) wandb_images.append(wandb.Image(buffer)) From 731ff6f36f9ac4879812a2cc87f75c94262ca9f7 Mon Sep 17 00:00:00 2001 From: Joseph Turian Date: Thu, 6 Nov 2025 17:12:06 +0000 Subject: [PATCH 4/6] Improve 16-bit logging helpers --- .../unconditional_image_generation/README.md | 3 +- .../train_unconditional.py | 80 +++++++++++-------- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index eac7f6d4cf75..2ae92a62cb0d 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -45,7 +45,8 @@ accelerate launch train_unconditional.py \ --mixed_precision=no \ --push_to_hub ``` -Append `--image_bit_depth 16` if you want evaluation samples logged as 16-bit PNGs (default is 8-bit). TensorBoard image summaries will still display 8-bit previews, while W&B keeps the 16-bit artifacts. +Append `--image_bit_depth 16` if you want evaluation samples logged as 16-bit PNGs (default is 8-bit). TensorBoard image summaries will still display 8-bit previews. W&B stores the uploaded 16-bit PNG files, though web previews may appear 8-bit. 
+Note on 16-bit logging: the script writes true 16-bit PNGs when `--image_bit_depth 16` is set. TensorBoard shows 8-bit previews. W&B previews are usually 8-bit and behavior may vary by SDK version; we upload the PNG files themselves to preserve 16-bit data, but preview depth isn’t guaranteed. Download the files from the run’s Files/Artifacts tab if you need full precision. An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64 A full training run takes 2 hours on 4xV100 GPUs. diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index cee30b311eca..de3d312bcb02 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -1,10 +1,10 @@ import argparse import inspect -import io import logging import math import os import shutil +import tempfile from datetime import timedelta from pathlib import Path @@ -73,7 +73,7 @@ def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor: def _prepare_sample_images(images: np.ndarray, bit_depth: int): - """Scale pipeline outputs to uint8/uint16 arrays and build a TensorBoard-safe preview.""" + """Scale NHWC float images in [0, 1] to uint8/uint16 arrays and build a TensorBoard-safe preview.""" if bit_depth == 8: max_pixel_value = 255 out_dtype = np.uint8 @@ -114,7 +114,7 @@ def _log_sample_images( tracker = accelerator.get_tracker("tensorboard", unwrap=True) else: tracker = accelerator.get_tracker("tensorboard") - tracker.add_images("test_samples", tb_preview.transpose(0, 3, 1, 2), epoch) + tracker.add_images("test_samples", tb_preview.transpose(0, 3, 1, 2), global_step) elif logger_name == "wandb": import wandb @@ -123,40 +123,52 @@ def _log_sample_images( from PIL import Image wandb_images = [] - for img16 in images_processed: - if img16.dtype != np.uint16: - raise ValueError( - f"Expected uint16 image for 16-bit logging, received {img16.dtype}." - ) - h, w = img16.shape[:2] - buffer = io.BytesIO() - if img16.ndim == 2: - pil_image = Image.fromarray(img16, mode="I;16") - elif img16.ndim == 3: - channels = img16.shape[-1] - if channels == 1: - pil_image = Image.fromarray(img16.squeeze(-1), mode="I;16") - elif channels in {3, 4}: - mode = "RGB" if channels == 3 else "RGBA" - contiguous = np.ascontiguousarray(img16) - pil_image = Image.frombytes( - mode, - (w, h), - contiguous.byteswap().tobytes(), - "raw", - f"{mode};16B", + temp_paths = [] + try: + for img16 in images_processed: + if img16.dtype != np.uint16: + raise ValueError( + f"Expected uint16 image for 16-bit logging, received {img16.dtype}." 
) - else: + if img16.ndim not in {2, 3}: raise ValueError( - f"Unsupported channel count for 16-bit image: {channels}" + f"Unsupported array shape for 16-bit image: {img16.shape}" ) - else: - raise ValueError( - f"Unsupported array shape for 16-bit image: {img16.shape}" - ) - pil_image.save(buffer, format="PNG") - buffer.seek(0) - wandb_images.append(wandb.Image(buffer)) + h, w = img16.shape[:2] + if img16.ndim == 3: + channels = img16.shape[-1] + if channels == 1: + array_for_pil = np.ascontiguousarray(img16.squeeze(-1)) + pil_image = Image.fromarray(array_for_pil, mode="I;16") + elif channels in {3, 4}: + mode = "RGB" if channels == 3 else "RGBA" + contiguous = np.ascontiguousarray(img16) + pil_image = Image.frombytes( + mode, + (w, h), + contiguous.byteswap().tobytes(), + "raw", + f"{mode};16B", + ) + else: + raise ValueError( + f"Unsupported channel count for 16-bit image: {channels}" + ) + else: + array_for_pil = np.ascontiguousarray(img16) + pil_image = Image.fromarray(array_for_pil, mode="I;16") + + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + temp_paths.append(tmp.name) + tmp.close() + pil_image.save(temp_paths[-1], format="PNG") + wandb_images.append(wandb.Image(temp_paths[-1])) + finally: + for path in temp_paths: + try: + os.unlink(path) + except OSError: + pass tracker.log({"test_samples": wandb_images, "epoch": epoch}, step=global_step) else: tracker.log( From e15369bed1e719f1c83029e9a5a87fbfe3a185b2 Mon Sep 17 00:00:00 2001 From: Joseph Turian Date: Thu, 6 Nov 2025 17:17:10 +0000 Subject: [PATCH 5/6] Polish 16-bit logging details --- .../unconditional_image_generation/README.md | 3 +- .../train_unconditional.py | 77 ++++++++++--------- 2 files changed, 40 insertions(+), 40 deletions(-) diff --git a/examples/unconditional_image_generation/README.md b/examples/unconditional_image_generation/README.md index 2ae92a62cb0d..a740b175069b 100644 --- a/examples/unconditional_image_generation/README.md +++ b/examples/unconditional_image_generation/README.md @@ -45,8 +45,7 @@ accelerate launch train_unconditional.py \ --mixed_precision=no \ --push_to_hub ``` -Append `--image_bit_depth 16` if you want evaluation samples logged as 16-bit PNGs (default is 8-bit). TensorBoard image summaries will still display 8-bit previews. W&B stores the uploaded 16-bit PNG files, though web previews may appear 8-bit. -Note on 16-bit logging: the script writes true 16-bit PNGs when `--image_bit_depth 16` is set. TensorBoard shows 8-bit previews. W&B previews are usually 8-bit and behavior may vary by SDK version; we upload the PNG files themselves to preserve 16-bit data, but preview depth isn’t guaranteed. Download the files from the run’s Files/Artifacts tab if you need full precision. +Append `--image_bit_depth 16` if you want evaluation samples logged as 16-bit PNGs (default is 8-bit). The script writes true 16-bit PNGs, TensorBoard still shows 8-bit previews, and W&B stores the uploaded 16-bit files even though previews are typically 8-bit and may vary by SDK version—download from the Files/Artifacts tab if you need full precision. An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64 A full training run takes 2 hours on 4xV100 GPUs. 
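To confirm what actually landed in a run, the bit depth of a downloaded sample can be checked locally. Below is a minimal sketch with Pillow and NumPy; `sample.png` is a placeholder for any evaluation image pulled from the run's Files/Artifacts tab:

```python
import numpy as np
from PIL import Image

# "sample.png" is a placeholder for a sample image downloaded from the run.
with Image.open("sample.png") as img:
    arr = np.array(img)

# 16-bit grayscale PNGs decode to uint16 (or int32 via Pillow's "I" mode);
# uint8 means the file was written at 8-bit depth.
print(f"mode={img.mode} dtype={arr.dtype} max={int(arr.max())}")
```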
diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py index de3d312bcb02..2a3fc3e4cd4d 100644 --- a/examples/unconditional_image_generation/train_unconditional.py +++ b/examples/unconditional_image_generation/train_unconditional.py @@ -73,7 +73,10 @@ def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor: def _prepare_sample_images(images: np.ndarray, bit_depth: int): - """Scale NHWC float images in [0, 1] to uint8/uint16 arrays and build a TensorBoard-safe preview.""" + """Scale NHWC float images in [0, 1] to uint8/uint16 arrays and build a TensorBoard-safe preview. + + Expects float inputs with shape (N, H, W, C) and returns the scaled array plus an 8-bit preview. + """ if bit_depth == 8: max_pixel_value = 255 out_dtype = np.uint8 @@ -124,52 +127,50 @@ def _log_sample_images( wandb_images = [] temp_paths = [] - try: - for img16 in images_processed: - if img16.dtype != np.uint16: - raise ValueError( - f"Expected uint16 image for 16-bit logging, received {img16.dtype}." - ) - if img16.ndim not in {2, 3}: - raise ValueError( - f"Unsupported array shape for 16-bit image: {img16.shape}" + for img16 in images_processed: + if img16.dtype != np.uint16: + raise ValueError( + f"Expected uint16 image for 16-bit logging, received {img16.dtype}." + ) + if img16.ndim not in {2, 3}: + raise ValueError(f"Unsupported array shape for 16-bit image: {img16.shape}") + h, w = img16.shape[:2] + if img16.ndim == 3: + channels = img16.shape[-1] + if channels == 1: + array_for_pil = np.ascontiguousarray(img16.squeeze(-1)) + pil_image = Image.fromarray(array_for_pil, mode="I;16") + elif channels in {3, 4}: + mode = "RGB" if channels == 3 else "RGBA" + contiguous = np.ascontiguousarray(img16) + pil_image = Image.frombytes( + mode, + (w, h), + contiguous.byteswap().tobytes(), + "raw", + f"{mode};16B", ) - h, w = img16.shape[:2] - if img16.ndim == 3: - channels = img16.shape[-1] - if channels == 1: - array_for_pil = np.ascontiguousarray(img16.squeeze(-1)) - pil_image = Image.fromarray(array_for_pil, mode="I;16") - elif channels in {3, 4}: - mode = "RGB" if channels == 3 else "RGBA" - contiguous = np.ascontiguousarray(img16) - pil_image = Image.frombytes( - mode, - (w, h), - contiguous.byteswap().tobytes(), - "raw", - f"{mode};16B", - ) - else: - raise ValueError( - f"Unsupported channel count for 16-bit image: {channels}" - ) else: - array_for_pil = np.ascontiguousarray(img16) - pil_image = Image.fromarray(array_for_pil, mode="I;16") + raise ValueError(f"Unsupported channel count for 16-bit image: {channels}") + else: + array_for_pil = np.ascontiguousarray(img16) + pil_image = Image.fromarray(array_for_pil, mode="I;16") - tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) - temp_paths.append(tmp.name) - tmp.close() - pil_image.save(temp_paths[-1], format="PNG") - wandb_images.append(wandb.Image(temp_paths[-1])) + tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False) + temp_paths.append(tmp.name) + tmp.close() + # wandb.Image does not accept BytesIO; a file path preserves 16-bit PNG bytes on upload. 
pil_image.save(temp_paths[-1], format="PNG")
+            wandb_images.append(wandb.Image(temp_paths[-1]))
+
+        try:
+            tracker.log({"test_samples": wandb_images, "epoch": epoch}, step=global_step)
+        finally:
             for path in temp_paths:
                 try:
                     os.unlink(path)
                 except OSError:
                     pass
-            tracker.log({"test_samples": wandb_images, "epoch": epoch}, step=global_step)
     else:
         tracker.log(
             {"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
             step=global_step,

From e0b224484deb6a65f36c4a25966fcaa4e0782eab Mon Sep 17 00:00:00 2001
From: Joseph Turian
Date: Thu, 6 Nov 2025 17:30:24 +0000
Subject: [PATCH 6/6] Document 16-bit PNG encoding expectations

---
 examples/unconditional_image_generation/train_unconditional.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/examples/unconditional_image_generation/train_unconditional.py b/examples/unconditional_image_generation/train_unconditional.py
index 2a3fc3e4cd4d..0085d37ad69f 100644
--- a/examples/unconditional_image_generation/train_unconditional.py
+++ b/examples/unconditional_image_generation/train_unconditional.py
@@ -123,6 +123,7 @@ def _log_sample_images(
         tracker = accelerator.get_tracker("wandb")
 
         if images_processed.dtype == np.uint16:
+            # Pillow is required here to encode the 16-bit PNGs uploaded to W&B.
            from PIL import Image
 
             wandb_images = []
@@ -146,6 +147,7 @@ def _log_sample_images(
             pil_image = Image.frombytes(
                 mode,
                 (w, h),
+                # The ";16B" raw mode expects big-endian samples; byteswap() converts from the host's native (assumed little-endian) byte order.
                 contiguous.byteswap().tobytes(),
                 "raw",
                 f"{mode};16B",
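As a sanity check on the encoding path documented in this last patch, a standalone round trip through Pillow's `I;16` mode shows that single-channel 16-bit values survive the PNG write/read cycle. This is a minimal sketch, independent of the training script, assuming a little-endian host and an illustrative `roundtrip.png` filename:

```python
import numpy as np
from PIL import Image

# Build a gradient spanning the full 16-bit range, mirroring the scaling in _prepare_sample_images.
floats = np.linspace(0.0, 1.0, 64 * 64).reshape(64, 64)
arr = np.clip(np.round(floats * 65535), 0, 65535).astype(np.uint16)

Image.fromarray(arr, mode="I;16").save("roundtrip.png", format="PNG")

# Values above 255 surviving the reload prove the PNG kept 16-bit depth.
decoded = np.asarray(Image.open("roundtrip.png")).astype(np.uint16)
np.testing.assert_array_equal(decoded, arr)
print("16-bit grayscale round trip OK, max value:", int(decoded.max()))
```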
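Separately, the channel rules behind `--preserve_input_precision` (patch 2) are easiest to see as a shape table. This standalone sketch mirrors the logic of `_ensure_three_channels`; it is an illustration, not the helper itself:

```python
import torch

def ensure_three_channels(t: torch.Tensor) -> torch.Tensor:
    # Promote (H, W) to (1, H, W), then pad or trim to exactly three channels.
    if t.ndim == 2:
        t = t.unsqueeze(0)
    channels = t.shape[0]
    if channels == 3:
        return t
    if channels == 1:
        return t.repeat(3, 1, 1)  # grayscale: replicate into RGB
    if channels == 2:
        return torch.cat([t, t[:1]], dim=0)  # reuse the first channel as the third
    return t[:3]  # 4+ channels (e.g. RGBA): drop the extras

for shape in [(64, 64), (1, 64, 64), (2, 64, 64), (3, 64, 64), (4, 64, 64)]:
    out = ensure_three_channels(torch.rand(*shape))
    print(f"{shape} -> {tuple(out.shape)}")
```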