From d95f9b8f22e08eafd099cfcbaa36eb4336de79c0 Mon Sep 17 00:00:00 2001 From: Thomas Date: Mon, 26 Feb 2024 07:20:14 -0800 Subject: [PATCH 1/5] Add vae_roundtrip.py example --- examples/inference/vae_roundtrip.py | 183 ++++++++++++++++++++++++++++ 1 file changed, 183 insertions(+) create mode 100644 examples/inference/vae_roundtrip.py diff --git a/examples/inference/vae_roundtrip.py b/examples/inference/vae_roundtrip.py new file mode 100644 index 000000000000..b76c6b764b63 --- /dev/null +++ b/examples/inference/vae_roundtrip.py @@ -0,0 +1,183 @@ +#!/usr/bin/env python +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and + +import argparse +import torch + +from diffusers import AutoencoderKL +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.models.autoencoders.vae import DecoderOutput +from PIL import Image +from torchvision import transforms # type: ignore +from typing import Optional + + +def load_vae_model( + *, + model_name_or_path: str, + revision: Optional[str], + variant: Optional[str], + # NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE + subfolder: Optional[str], +) -> AutoencoderKL: + vae = AutoencoderKL.from_pretrained( # type: ignore + model_name_or_path, + subfolder=subfolder, + revision=revision, + variant=variant, + ) + assert isinstance(vae, AutoencoderKL) + vae.eval() # Set the model to inference mode + return vae + + +def preprocess_image( + *, + image_path: str, +) -> torch.FloatTensor: + image = Image.open(image_path).convert("RGB") + transform = transforms.Compose( + [ + transforms.ToTensor(), + ] + ) + nhwc = transform(image).unsqueeze(0) # type: ignore + assert isinstance(nhwc, torch.FloatTensor) + return nhwc + + +def postprocess_image( + *, + nhwc: torch.FloatTensor, +) -> Image.Image: + assert nhwc.shape[0] == 1 + hwc = nhwc.squeeze(0) + return transforms.ToPILImage()(hwc) # type: ignore + + +def concatenate_images( + *, + left: Image.Image, + right: Image.Image, +) -> Image.Image: + width1, height1 = left.size + width2, height2 = right.size + total_width = width1 + width2 + max_height = max(height1, height2) + + new_image = Image.new("RGB", (total_width, max_height)) + new_image.paste(left, (0, 0)) + new_image.paste(right, (width1, 0)) + return new_image + + +def infer_and_show_images( + *, + input_image_path: str, + pretrained_model_name_or_path: str, + revision: Optional[str], + variant: Optional[str], + subfolder: Optional[str], +) -> None: + vae = load_vae_model( + model_name_or_path=pretrained_model_name_or_path, + revision=revision, + variant=variant, + subfolder=subfolder, + ) + original_image = preprocess_image(image_path=input_image_path) + with torch.no_grad(): + encoding = vae.encode(original_image) + assert isinstance(encoding, AutoencoderKLOutput) + latent = encoding.latent_dist.sample() # type: ignore + assert isinstance(latent, torch.FloatTensor) + decoding = vae.decode(latent) # type: ignore + assert isinstance(decoding, DecoderOutput) + reconstructed_image = decoding.sample + + original_pil = postprocess_image(nhwc=original_image) + reconstructed_pil = postprocess_image(nhwc=reconstructed_image) + + combined_image = concatenate_images( + left=original_pil, + right=reconstructed_pil, + ) + combined_image.show("Original | Reconstruction") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description="Inference with VAE") + parser.add_argument( + "--input_image", + type=str, + required=True, + help="Path to the input image for inference.", + ) + parser.add_argument( + "--pretrained_model_name_or_path", + type=str, + required=True, + help="Path to pretrained VAE model.", + ) + parser.add_argument( + "--revision", + type=str, + default=None, + help="Model version.", + ) + parser.add_argument( + "--variant", + type=str, + default=None, + help="Model file variant, e.g., 'fp16'.", + ) + parser.add_argument( + "--subfolder", + type=str, + default=None, + help="Subfolder in the model file.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + input_image_path = args.input_image + assert isinstance(input_image_path, str) + + pretrained_model_name_or_path = args.pretrained_model_name_or_path + assert isinstance(pretrained_model_name_or_path, str) + + revision = args.revision + assert revision is None or isinstance(revision, str) + + variant = args.variant + assert variant is None or isinstance(variant, str) + + subfolder = args.subfolder + assert subfolder is None or isinstance(subfolder, str) + + infer_and_show_images( + input_image_path=input_image_path, + pretrained_model_name_or_path=pretrained_model_name_or_path, + revision=revision, + variant=variant, + subfolder=subfolder, + ) + + +if __name__ == "__main__": + main() From 56895375e88a1a992df31ba3c161c27bc5ef3460 Mon Sep 17 00:00:00 2001 From: Thomas Date: Mon, 26 Feb 2024 07:52:15 -0800 Subject: [PATCH 2/5] Add cuda support to vae_roundtrip --- examples/inference/vae_roundtrip.py | 36 ++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 8 deletions(-) diff --git a/examples/inference/vae_roundtrip.py b/examples/inference/vae_roundtrip.py index b76c6b764b63..d380e7aff8bd 100644 --- a/examples/inference/vae_roundtrip.py +++ b/examples/inference/vae_roundtrip.py @@ -15,6 +15,7 @@ import argparse import torch +import typing from diffusers import AutoencoderKL from diffusers.models.modeling_outputs import AutoencoderKLOutput @@ -26,6 +27,7 @@ def load_vae_model( *, + device: torch.device, model_name_or_path: str, revision: Optional[str], variant: Optional[str], @@ -39,31 +41,33 @@ def load_vae_model( variant=variant, ) assert isinstance(vae, AutoencoderKL) + vae = vae.to(device) vae.eval() # Set the model to inference mode return vae def preprocess_image( *, + device: torch.device, image_path: str, -) -> torch.FloatTensor: +) -> torch.Tensor: image = Image.open(image_path).convert("RGB") transform = transforms.Compose( [ transforms.ToTensor(), ] ) - nhwc = transform(image).unsqueeze(0) # type: ignore - assert isinstance(nhwc, torch.FloatTensor) + nhwc = transform(image).unsqueeze(0).to(device) # type: ignore + assert isinstance(nhwc, torch.Tensor) return nhwc def postprocess_image( *, - nhwc: torch.FloatTensor, + nhwc: torch.Tensor, ) -> Image.Image: assert nhwc.shape[0] == 1 - hwc = nhwc.squeeze(0) + hwc = nhwc.squeeze(0).cpu() return transforms.ToPILImage()(hwc) # type: ignore @@ -85,6 +89,7 @@ def concatenate_images( def infer_and_show_images( *, + device: torch.device, input_image_path: str, pretrained_model_name_or_path: str, revision: Optional[str], @@ -92,17 +97,21 @@ def infer_and_show_images( subfolder: Optional[str], ) -> None: vae = load_vae_model( + device=device, model_name_or_path=pretrained_model_name_or_path, revision=revision, variant=variant, subfolder=subfolder, ) - original_image = preprocess_image(image_path=input_image_path) + original_image = preprocess_image( + device=device, + image_path=input_image_path, + ) with torch.no_grad(): - encoding = vae.encode(original_image) + encoding = vae.encode(typing.cast(torch.FloatTensor, original_image)) assert isinstance(encoding, AutoencoderKLOutput) latent = encoding.latent_dist.sample() # type: ignore - assert isinstance(latent, torch.FloatTensor) + assert isinstance(latent, torch.Tensor) decoding = vae.decode(latent) # type: ignore assert isinstance(decoding, DecoderOutput) reconstructed_image = decoding.sample @@ -149,6 +158,11 @@ def parse_args() -> argparse.Namespace: default=None, help="Subfolder in the model file.", ) + parser.add_argument( + "--use_cuda", + action="store_true", + help="Use CUDA if available.", + ) return parser.parse_args() @@ -170,7 +184,13 @@ def main() -> None: subfolder = args.subfolder assert subfolder is None or isinstance(subfolder, str) + use_cuda = args.use_cuda + assert isinstance(use_cuda, bool) + + device = torch.device("cuda" if use_cuda else "cpu") + infer_and_show_images( + device=device, input_image_path=input_image_path, pretrained_model_name_or_path=pretrained_model_name_or_path, revision=revision, From c65229e416a9de0f962eadeadc53e18e64754473 Mon Sep 17 00:00:00 2001 From: Thomas Date: Wed, 28 Feb 2024 19:51:56 -0800 Subject: [PATCH 3/5] Move vae_roundtrip.py into research_projects/vae --- examples/research_projects/vae/README.md | 11 +++++++++++ .../vae}/vae_roundtrip.py | 0 2 files changed, 11 insertions(+) create mode 100644 examples/research_projects/vae/README.md rename examples/{inference => research_projects/vae}/vae_roundtrip.py (100%) diff --git a/examples/research_projects/vae/README.md b/examples/research_projects/vae/README.md new file mode 100644 index 000000000000..2e24c955b7ae --- /dev/null +++ b/examples/research_projects/vae/README.md @@ -0,0 +1,11 @@ +# VAE + +`vae_roundtrip.py` Demonstrates the use of a VAE by roundtripping an image through the encoder and decoder. Original and reconstructed images are displayed side by side. + +``` +cd examples/research_projects/vae +python vae_roundtrip.py \ + --pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5" \ + --subfolder="vae" \ + --input_image="/path/to/your/input.png" +``` diff --git a/examples/inference/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py similarity index 100% rename from examples/inference/vae_roundtrip.py rename to examples/research_projects/vae/vae_roundtrip.py From 90cd240baff3d32d1c48a7d20eb7eb03350f3b01 Mon Sep 17 00:00:00 2001 From: Thomas Date: Sun, 30 Jun 2024 20:38:18 -0700 Subject: [PATCH 4/5] Fix channel scaling in vae roundrip and also support taesd. --- .../research_projects/vae/vae_roundtrip.py | 170 +++++++++++++----- 1 file changed, 124 insertions(+), 46 deletions(-) diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py index d380e7aff8bd..5981010a6d05 100644 --- a/examples/research_projects/vae/vae_roundtrip.py +++ b/examples/research_projects/vae/vae_roundtrip.py @@ -17,12 +17,22 @@ import torch import typing -from diffusers import AutoencoderKL -from diffusers.models.modeling_outputs import AutoencoderKLOutput +from diffusers.image_processor import VaeImageProcessor +from diffusers.models.autoencoders.autoencoder_kl import ( + AutoencoderKL, + AutoencoderKLOutput, +) +from diffusers.models.autoencoders.autoencoder_tiny import ( + AutoencoderTiny, + AutoencoderTinyOutput, +) from diffusers.models.autoencoders.vae import DecoderOutput from PIL import Image from torchvision import transforms # type: ignore -from typing import Optional +from typing import Optional, Union + + +SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny] def load_vae_model( @@ -33,36 +43,47 @@ def load_vae_model( variant: Optional[str], # NOTE: use subfolder="vae" if the pointed model is for stable diffusion as a whole instead of just the VAE subfolder: Optional[str], -) -> AutoencoderKL: - vae = AutoencoderKL.from_pretrained( # type: ignore - model_name_or_path, - subfolder=subfolder, - revision=revision, - variant=variant, - ) - assert isinstance(vae, AutoencoderKL) + use_tiny_nn: bool, +) -> SupportedAutoencoder: + if use_tiny_nn: + # NOTE: These scaling factors don't have to be the same as each other. + down_scale = 2 + up_scale = 2 + vae = AutoencoderTiny.from_pretrained( # type: ignore + model_name_or_path, + subfolder=subfolder, + revision=revision, + variant=variant, + downscaling_scaling_factor=down_scale, + upsampling_scaling_factor=up_scale, + ) + assert isinstance(vae, AutoencoderTiny) + else: + vae = AutoencoderKL.from_pretrained( # type: ignore + model_name_or_path, + subfolder=subfolder, + revision=revision, + variant=variant, + ) + assert isinstance(vae, AutoencoderKL) vae = vae.to(device) vae.eval() # Set the model to inference mode return vae -def preprocess_image( +def pil_to_nhwc( *, device: torch.device, - image_path: str, + image: Image.Image, ) -> torch.Tensor: - image = Image.open(image_path).convert("RGB") - transform = transforms.Compose( - [ - transforms.ToTensor(), - ] - ) + assert image.mode == "RGB" + transform = transforms.ToTensor() nhwc = transform(image).unsqueeze(0).to(device) # type: ignore assert isinstance(nhwc, torch.Tensor) return nhwc -def postprocess_image( +def nhwc_to_pil( *, nhwc: torch.Tensor, ) -> Image.Image: @@ -75,19 +96,60 @@ def concatenate_images( *, left: Image.Image, right: Image.Image, + vertical: bool = False, ) -> Image.Image: width1, height1 = left.size width2, height2 = right.size - total_width = width1 + width2 - max_height = max(height1, height2) - - new_image = Image.new("RGB", (total_width, max_height)) - new_image.paste(left, (0, 0)) - new_image.paste(right, (width1, 0)) + if vertical: + total_height = height1 + height2 + max_width = max(width1, width2) + new_image = Image.new("RGB", (max_width, total_height)) + new_image.paste(left, (0, 0)) + new_image.paste(right, (0, height1)) + else: + total_width = width1 + width2 + max_height = max(height1, height2) + new_image = Image.new("RGB", (total_width, max_height)) + new_image.paste(left, (0, 0)) + new_image.paste(right, (width1, 0)) return new_image -def infer_and_show_images( +def to_latent( + *, + rgb_nchw: torch.Tensor, + vae: SupportedAutoencoder, +) -> torch.Tensor: + rgb_nchw = VaeImageProcessor.normalize(rgb_nchw) # type: ignore + encoding_nchw = vae.encode(typing.cast(torch.FloatTensor, rgb_nchw)) + if isinstance(encoding_nchw, AutoencoderKLOutput): + latent = encoding_nchw.latent_dist.sample() # type: ignore + assert isinstance(latent, torch.Tensor) + elif isinstance(encoding_nchw, AutoencoderTinyOutput): + latent = encoding_nchw.latents + do_internal_vae_scaling = False # Is this needed? + if do_internal_vae_scaling: + latent = vae.scale_latents(latent).mul(255).round().byte() # type: ignore + latent = vae.unscale_latents(latent / 255.0) # type: ignore + assert isinstance(latent, torch.Tensor) + else: + assert False, f"Unknown encoding type: {type(encoding_nchw)}" + return latent + + +def from_latent( + *, + latent_nchw: torch.Tensor, + vae: SupportedAutoencoder, +) -> torch.Tensor: + decoding_nchw = vae.decode(latent_nchw) # type: ignore + assert isinstance(decoding_nchw, DecoderOutput) + rgb_nchw = VaeImageProcessor.denormalize(decoding_nchw.sample) # type: ignore + assert isinstance(rgb_nchw, torch.Tensor) + return rgb_nchw + + +def main_kwargs( *, device: torch.device, input_image_path: str, @@ -95,6 +157,7 @@ def infer_and_show_images( revision: Optional[str], variant: Optional[str], subfolder: Optional[str], + use_tiny_nn: bool, ) -> None: vae = load_vae_model( device=device, @@ -102,28 +165,28 @@ def infer_and_show_images( revision=revision, variant=variant, subfolder=subfolder, + use_tiny_nn=use_tiny_nn, ) - original_image = preprocess_image( + original_pil = Image.open(input_image_path).convert("RGB") + original_image = pil_to_nhwc( device=device, - image_path=input_image_path, + image=original_pil, ) - with torch.no_grad(): - encoding = vae.encode(typing.cast(torch.FloatTensor, original_image)) - assert isinstance(encoding, AutoencoderKLOutput) - latent = encoding.latent_dist.sample() # type: ignore - assert isinstance(latent, torch.Tensor) - decoding = vae.decode(latent) # type: ignore - assert isinstance(decoding, DecoderOutput) - reconstructed_image = decoding.sample - - original_pil = postprocess_image(nhwc=original_image) - reconstructed_pil = postprocess_image(nhwc=reconstructed_image) + print(f"Original image shape: {original_image.shape}") + reconstructed_image: Optional[torch.Tensor] = None + with torch.no_grad(): + latent_image = to_latent(rgb_nchw=original_image, vae=vae) + print(f"Latent shape: {latent_image.shape}") + reconstructed_image = from_latent(latent_nchw=latent_image, vae=vae) + reconstructed_pil = nhwc_to_pil(nhwc=reconstructed_image) combined_image = concatenate_images( left=original_pil, right=reconstructed_pil, + vertical=False, ) combined_image.show("Original | Reconstruction") + print(f"Reconstructed image shape: {reconstructed_image.shape}") def parse_args() -> argparse.Namespace: @@ -163,10 +226,21 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Use CUDA if available.", ) + parser.add_argument( + "--use_tiny_nn", + action="store_true", + help="Use tiny neural network.", + ) return parser.parse_args() -def main() -> None: +# EXAMPLE USAGE: +# +# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" --subfolder "vae" --input_image "foo.png" +# +# python vae_roundtrip.py --use_cuda --pretrained_model_name_or_path "madebyollin/taesd" --use_tiny_nn --input_image "foo.png" +# +def main_cli() -> None: args = parse_args() input_image_path = args.input_image @@ -176,28 +250,32 @@ def main() -> None: assert isinstance(pretrained_model_name_or_path, str) revision = args.revision - assert revision is None or isinstance(revision, str) + assert isinstance(revision, (str, type(None))) variant = args.variant - assert variant is None or isinstance(variant, str) + assert isinstance(variant, (str, type(None))) subfolder = args.subfolder - assert subfolder is None or isinstance(subfolder, str) + assert isinstance(subfolder, (str, type(None))) use_cuda = args.use_cuda assert isinstance(use_cuda, bool) + use_tiny_nn = args.use_tiny_nn + assert isinstance(use_tiny_nn, bool) + device = torch.device("cuda" if use_cuda else "cpu") - infer_and_show_images( + main_kwargs( device=device, input_image_path=input_image_path, pretrained_model_name_or_path=pretrained_model_name_or_path, revision=revision, variant=variant, subfolder=subfolder, + use_tiny_nn=use_tiny_nn, ) if __name__ == "__main__": - main() + main_cli() From b59b43589096e68e7cf7ad4bd393beccee9007bc Mon Sep 17 00:00:00 2001 From: Thomas Date: Wed, 3 Jul 2024 19:31:11 -0700 Subject: [PATCH 5/5] Apply ruff --fix for CI gatekeep check --- examples/research_projects/vae/vae_roundtrip.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/examples/research_projects/vae/vae_roundtrip.py b/examples/research_projects/vae/vae_roundtrip.py index 5981010a6d05..65c2b43a9bde 100644 --- a/examples/research_projects/vae/vae_roundtrip.py +++ b/examples/research_projects/vae/vae_roundtrip.py @@ -14,8 +14,12 @@ # See the License for the specific language governing permissions and import argparse -import torch import typing +from typing import Optional, Union + +import torch +from PIL import Image +from torchvision import transforms # type: ignore from diffusers.image_processor import VaeImageProcessor from diffusers.models.autoencoders.autoencoder_kl import ( @@ -27,9 +31,6 @@ AutoencoderTinyOutput, ) from diffusers.models.autoencoders.vae import DecoderOutput -from PIL import Image -from torchvision import transforms # type: ignore -from typing import Optional, Union SupportedAutoencoder = Union[AutoencoderKL, AutoencoderTiny]