1 change: 1 addition & 0 deletions examples/unconditional_image_generation/README.md
@@ -45,6 +45,7 @@ accelerate launch train_unconditional.py \
--mixed_precision=no \
--push_to_hub
```
Append `--image_bit_depth 16` if you want evaluation samples logged as 16-bit PNGs (the default is 8-bit). TensorBoard always shows 8-bit previews regardless of this flag. W&B stores the uploaded PNG files themselves, but its previews are typically 8-bit and vary by SDK version, so download from the run's Files/Artifacts tab if you need full precision.
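For example, to rerun the command above with 16-bit sample logging (the dataset and output names below are illustrative; keep the rest of your flags unchanged):

```bash
accelerate launch train_unconditional.py \
  --dataset_name="huggan/flowers-102-categories" \
  --resolution=64 \
  --output_dir="ddpm-ema-flowers-64" \
  --mixed_precision=no \
  --image_bit_depth 16 \
  --push_to_hub
```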
An example trained model: https://huggingface.co/anton-l/ddpm-ema-flowers-64

A full training run takes 2 hours on 4xV100 GPUs.
138 changes: 125 additions & 13 deletions examples/unconditional_image_generation/train_unconditional.py
@@ -4,11 +4,13 @@
import math
import os
import shutil
import tempfile
from datetime import timedelta
from pathlib import Path

import accelerate
import datasets
import numpy as np
import torch
import torch.nn.functional as F
from accelerate import Accelerator, InitProcessGroupKwargs
@@ -70,6 +72,113 @@ def _ensure_three_channels(tensor: torch.Tensor) -> torch.Tensor:
raise ValueError(f"Unsupported number of channels: {channels}")


def _prepare_sample_images(images: np.ndarray, bit_depth: int):
"""Scale NHWC float images in [0, 1] to uint8/uint16 arrays and build a TensorBoard-safe preview.

Expects float inputs with shape (N, H, W, C) and returns the scaled array plus an 8-bit preview.
"""
if bit_depth == 8:
max_pixel_value = 255
out_dtype = np.uint8
elif bit_depth == 16:
max_pixel_value = 65535
out_dtype = np.uint16
else:
raise ValueError(f"Unsupported image_bit_depth: {bit_depth}")

images_processed = np.clip(np.round(images * max_pixel_value), 0, max_pixel_value).astype(out_dtype)

if bit_depth == 16:
# TensorBoard accepts float [0,1] or uint8 visuals, so keep a separate preview for the UI.
tb_preview = np.clip(np.round(images * 255), 0, 255).astype(np.uint8)
else:
tb_preview = images_processed

return images_processed, tb_preview

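# Worked example for _prepare_sample_images (illustrative): with bit_depth=16 a
# pixel value of 0.5 maps to round(0.5 * 65535) = 32768 in the returned array
# (np.round rounds halves to the nearest even integer), while the 8-bit preview
# maps the same pixel to round(0.5 * 255) = 128.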

def _log_sample_images(
accelerator: Accelerator,
images_processed: np.ndarray,
tb_preview: np.ndarray,
epoch: int,
global_step: int,
logger_name: str,
):
# Compare against a tuple, not a set: np.dtype compares equal to the scalar
# types but hashes differently, so set membership would spuriously fail.
if images_processed.dtype not in (np.uint8, np.uint16):
raise ValueError(
f"Unsupported dtype for logged images: {images_processed.dtype}; expected uint8 or uint16."
)

if logger_name == "tensorboard":
if is_accelerate_version(">=", "0.17.0.dev0"):
tracker = accelerator.get_tracker("tensorboard", unwrap=True)
else:
tracker = accelerator.get_tracker("tensorboard")
tracker.add_images("test_samples", tb_preview.transpose(0, 3, 1, 2), global_step)
elif logger_name == "wandb":
import wandb

tracker = accelerator.get_tracker("wandb")
if images_processed.dtype == np.uint16:
# Pillow is required to encode the PNGs uploaded to W&B; without it this import raises ImportError.
from PIL import Image

wandb_images = []
temp_paths = []
for img16 in images_processed:
if img16.dtype != np.uint16:
raise ValueError(
f"Expected uint16 image for 16-bit logging, received {img16.dtype}."
)
if img16.ndim not in {2, 3}:
raise ValueError(f"Unsupported array shape for 16-bit image: {img16.shape}")
h, w = img16.shape[:2]
if img16.ndim == 3:
channels = img16.shape[-1]
if channels == 1:
array_for_pil = np.ascontiguousarray(img16.squeeze(-1))
pil_image = Image.fromarray(array_for_pil, mode="I;16")
elif channels in {3, 4}:
mode = "RGB" if channels == 3 else "RGBA"
contiguous = np.ascontiguousarray(img16)
pil_image = Image.frombytes(
mode,
(w, h),
# Pillow expects big-endian for RGB(A); use ";16B" and byteswap() to match.
contiguous.byteswap().tobytes(),
"raw",
f"{mode};16B",
)
else:
raise ValueError(f"Unsupported channel count for 16-bit image: {channels}")
else:
array_for_pil = np.ascontiguousarray(img16)
pil_image = Image.fromarray(array_for_pil, mode="I;16")

tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
temp_paths.append(tmp.name)
tmp.close()
# wandb.Image does not accept BytesIO; a file path preserves 16-bit PNG bytes on upload.
pil_image.save(temp_paths[-1], format="PNG")
wandb_images.append(wandb.Image(temp_paths[-1]))

try:
tracker.log({"test_samples": wandb_images, "epoch": epoch}, step=global_step)
finally:
for path in temp_paths:
try:
os.unlink(path)
except OSError:
pass
else:
tracker.log(
{"test_samples": [wandb.Image(img) for img in images_processed], "epoch": epoch},
step=global_step,
)
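# NOTE: Pillow's RGB(A) modes store 8 bits per channel, so the ";16B" raw
# unpacker above keeps only the high byte of each 16-bit sample; multichannel
# PNGs therefore come out at 8 bits per channel, and only the single-channel
# "I;16" path stores full 16-bit precision. Writing true 16-bit multichannel
# PNGs would need a different encoder, for example with an optional OpenCV
# dependency (not used by this script):
#   cv2.imwrite("sample16.png", img16[..., ::-1])  # uint16 RGB -> BGR, 16-bit PNG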

def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
parser.add_argument(
@@ -147,6 +256,13 @@ def parse_args():
parser.add_argument(
"--eval_batch_size", type=int, default=16, help="The number of images to generate for evaluation."
)
parser.add_argument(
"--image_bit_depth",
type=int,
choices=[8, 16],
default=8,
help="Bit depth for generated sample images during evaluation/logging. Default: 8.",
)
parser.add_argument(
"--dataloader_num_workers",
type=int,
@@ -699,20 +815,16 @@ def transform_images(examples):
if args.use_ema:
ema_model.restore(unet.parameters())

# denormalize the images and log them to the configured tracker
images_processed, tb_preview = _prepare_sample_images(images, args.image_bit_depth)

if args.logger in {"tensorboard", "wandb"}:
    _log_sample_images(
        accelerator,
        images_processed,
        tb_preview,
        epoch,
        global_step,
        args.logger,
    )

if epoch % args.save_model_epochs == 0 or epoch == args.num_epochs - 1: