143 changes: 123 additions & 20 deletions src/diffusers/image_processor.py
@@ -24,6 +24,16 @@
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate


PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
]

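For orientation, a minimal sketch of how this alias is meant to be consumed. The `to_tensor` helper below is hypothetical and not part of the diff, and it assumes a diffusers build that includes this change:

```python
import numpy as np
import PIL.Image

# Assumes a diffusers version that ships this change.
from diffusers.image_processor import PipelineImageInput, VaeImageProcessor


def to_tensor(image: PipelineImageInput):
    """Hypothetical helper: accept any supported image input and return a batched torch tensor."""
    return VaeImageProcessor().preprocess(image)


print(to_tensor(PIL.Image.new("RGB", (64, 64))).shape)              # torch.Size([1, 3, 64, 64])
print(to_tensor(np.zeros((1, 64, 64, 3), dtype=np.float32)).shape)  # torch.Size([1, 3, 64, 64])
```
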

class VaeImageProcessor(ConfigMixin):
"""
Image processor for VAE.
@@ -38,8 +48,12 @@ class VaeImageProcessor(ConfigMixin):
Resampling filter to use when resizing the image.
do_normalize (`bool`, *optional*, defaults to `True`):
Whether to normalize the image to [-1,1].
do_binarize (`bool`, *optional*, defaults to `False`):
Whether to binarize the image to 0/1.
do_convert_rgb (`bool`, *optional*, defaults to `False`):
Whether to convert the images to RGB format.
do_convert_grayscale (`bool`, *optional*, defaults to `False`):
Whether to convert the images to grayscale format.
"""

config_name = CONFIG_NAME
@@ -51,9 +65,18 @@ def __init__(
vae_scale_factor: int = 8,
resample: str = "lanczos",
do_normalize: bool = True,
do_binarize: bool = False,
do_convert_rgb: bool = False,
do_convert_grayscale: bool = False,
Contributor: nice idea!

):
super().__init__()
if do_convert_rgb and do_convert_grayscale:
raise ValueError(
"`do_convert_rgb` and `do_convert_grayscale` cannot both be set to `True`;"
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`,"
" and if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`."
)
Comment on lines +74 to +78
Member: Very descriptive! Looks good.

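As a quick illustration of the new guard (a sketch, assuming a diffusers build with this change; the error message is the one constructed above):

```python
from diffusers.image_processor import VaeImageProcessor

# Valid: grayscale-only preprocessing, e.g. for inpainting masks.
mask_processor = VaeImageProcessor(do_convert_grayscale=True, do_normalize=False)

# Invalid: asking for both conversions is ambiguous and raises the ValueError above.
try:
    VaeImageProcessor(do_convert_rgb=True, do_convert_grayscale=True)
except ValueError as err:
    print(err)
```
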
self.config.do_convert_rgb = False

@staticmethod
def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
@@ -119,31 +142,84 @@ def denormalize(images):
@staticmethod
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts an image to RGB format.
Converts a PIL image to RGB format.
"""
image = image.convert("RGB")

return image

def resize(
@staticmethod
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
"""
Converts a PIL image to grayscale format.
"""
image = image.convert("L")

return image

def get_default_height_width(
self,
image: PIL.Image.Image,
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
height: Optional[int] = None,
width: Optional[int] = None,
) -> PIL.Image.Image:
):
"""
Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
This function returns the height and width that are downscaled to the next integer multiple of
`vae_scale_factor`.

Args:
image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
The image input, which can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it should
have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch tensor, it
should have shape `[batch, channel, height, width]`.
height (`int`, *optional*, defaults to `None`):
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
width (`int`, *optional*, defaults to `None`):
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
"""

if height is None:
height = image.height
if isinstance(image, PIL.Image.Image):
height = image.height
elif isinstance(image, torch.Tensor):
height = image.shape[2]
else:
height = image.shape[1]

if width is None:
width = image.width
if isinstance(image, PIL.Image.Image):
width = image.width
elif isinstance(image, torch.Tensor):
width = image.shape[3]
else:
width = image.shape[2]

width, height = (
x - x % self.config.vae_scale_factor for x in (width, height)
) # resize to integer multiple of vae_scale_factor

return height, width
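
A small sketch of the rounding behaviour (values chosen for illustration; assumes the default `vae_scale_factor` of 8):

```python
import numpy as np
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()  # vae_scale_factor defaults to 8

# Batched numpy input with shape [batch, height, width, channel].
batch = np.zeros((1, 515, 771, 3), dtype=np.float32)

# 515 and 771 are rounded down to the nearest multiple of 8.
print(processor.get_default_height_width(batch))  # (512, 768)
```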

def resize(
self,
image: PIL.Image.Image,
height: Optional[int] = None,
width: Optional[int] = None,
) -> PIL.Image.Image:
"""
Resize a PIL image.
"""
image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
return image

def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
"""
Create a mask: pixel values below 0.5 are set to 0, values greater than or equal to 0.5 are set to 1.
"""
image[image < 0.5] = 0
image[image >= 0.5] = 1
return image
Comment on lines +219 to +221
Member: Should 0 and 1 be registered into the config vars?
Contributor: I think 0 and 1 are global enough not to have to be added to the config.

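Although `binarize` is annotated for PIL images, the elementwise thresholding above operates on tensors (as it is used in `preprocess`); a minimal sketch:

```python
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()

# A soft mask in [0, 1]: values below 0.5 become 0, the rest become 1.
soft_mask = torch.tensor([[0.1, 0.49], [0.5, 0.9]])
print(processor.binarize(soft_mask))
# tensor([[0., 0.],
#         [1., 1.]])
```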

def preprocess(
self,
image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
@@ -154,6 +230,25 @@ def preprocess(
Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
"""
supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

# Expand the missing dimension for a 3-dimensional pytorch tensor or numpy array that represents a grayscale image
if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
if isinstance(image, torch.Tensor):
# if image is a pytorch tensor it could have 2 possible shapes:
# 1. batch x height x width: we should insert the channel dimension at position 1
# 2. channel x height x width: we should insert the batch dimension at position 0;
# however, since both the channel and batch dimensions have size 1, inserting at position 1 covers both cases,
# so for simplicity we insert a dimension of size 1 at position 1 in both cases
image = image.unsqueeze(1)
else:
# if it is a numpy array, it could have 2 possible shapes:
# 1. batch x height x width: insert channel dimension on last position
# 2. height x width x channel: insert batch dimension on first position
if image.shape[-1] == 1:
image = np.expand_dims(image, axis=0)
else:
image = np.expand_dims(image, axis=-1)

Comment on lines +236 to +251
Member: For easier operability, does it make sense to first convert the input tensor to a NumPy array and then operate from there?
Collaborator (author) @yiyixuxu, Aug 4, 2023: @sayakpaul the output of preprocess is pytorch tensors, so why would we want to convert to a numpy array first? The reason we want to do this step here is that the rest of our preprocessing logic assumes a 3D tensor has shape channel x height x width, which doesn't apply to grayscale images. For example, if we have a tensor with shape (5, 256, 256) and try to process it with the current logic directly, we would do: (5, 256, 256) -> put it into a list: [(5, 256, 256)] -> torch.stack the list when not 4D: (1, 5, 256, 256). The correct result would be (5, 1, 256, 256).
Member: Why can't we add a dummy channel to represent the grayscale images then? Just trying to understand it better.
Collaborator (author): @sayakpaul yeah, and that's what this section of code is doing: adding a dummy dimension of size 1 to represent either a missing channel or batch dimension, removing the ambiguity.

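To make the shape handling discussed above concrete, a small sketch using the example from the thread (the processor configuration here is an assumption for illustration):

```python
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(do_convert_grayscale=True, do_normalize=False)

# A batch of 5 grayscale images passed as a 3D tensor: batch x height x width.
masks = torch.rand(5, 256, 256)

# Without the dummy dimension the stacking logic would yield (1, 5, 256, 256);
# with it, preprocess returns the intended batch of single-channel images.
print(processor.preprocess(masks).shape)  # torch.Size([5, 1, 256, 256])
```
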
if isinstance(image, supported_formats):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
@@ -164,42 +259,47 @@ def preprocess(
if isinstance(image[0], PIL.Image.Image):
if self.config.do_convert_rgb:
image = [self.convert_to_rgb(i) for i in image]
elif self.config.do_convert_grayscale:
image = [self.convert_to_grayscale(i) for i in image]
if self.config.do_resize:
height, width = self.get_default_height_width(image[0], height, width)
image = [self.resize(i, height, width) for i in image]
image = self.pil_to_numpy(image) # to np
image = self.numpy_to_pt(image) # to pt

elif isinstance(image[0], np.ndarray):
image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)

image = self.numpy_to_pt(image)
_, _, height, width = image.shape
if self.config.do_resize and (
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
):

height, width = self.get_default_height_width(image, height, width)
if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}. "
f"Currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
)

elif isinstance(image[0], torch.Tensor):
image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
_, channel, height, width = image.shape

if self.config.do_convert_grayscale and image.ndim == 3:
image = image.unsqueeze(1)

channel = image.shape[1]
# don't need any preprocess if the image is latents
if channel == 4:
return image

if self.config.do_resize and (
height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
):
height, width = self.get_default_height_width(image, height, width)
if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
raise ValueError(
f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}. "
f"Currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use the resize option in VaeImageProcessor."
)

# expected range [0,1], normalize to [-1,1]
do_normalize = self.config.do_normalize
if image.min() < 0:
if image.min() < 0 and do_normalize:
Collaborator (author): If the user configured the image_processor to have do_normalize=False, the expected range should be [-1, 1] and we shouldn't send a warning (this is the case for the controlnet image).

warnings.warn(
"Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
@@ -210,6 +310,9 @@ def preprocess(
if do_normalize:
image = self.normalize(image)

if self.config.do_binarize:
image = self.binarize(image)

return image
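
Putting the new flags together, a hedged end-to-end sketch of how an inpainting mask could be preprocessed (the configuration below mirrors the flags added in this file; the exact values a given pipeline uses may differ):

```python
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

# Mask processor: grayscale input, no [-1, 1] normalization, hard 0/1 values.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

mask = PIL.Image.new("L", (512, 512), color=200)  # dummy grayscale mask
mask_tensor = mask_processor.preprocess(mask)
print(mask_tensor.shape, mask_tensor.unique())  # torch.Size([1, 1, 512, 512]) tensor([1.])
```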

def postprocess(
@@ -25,7 +25,7 @@
from diffusers.utils import is_accelerate_available, is_accelerate_version

from ...configuration_utils import FrozenDict
from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -567,14 +567,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
strength: float = 0.8,
num_inference_steps: Optional[int] = 50,
guidance_scale: Optional[float] = 7.5,
@@ -597,7 +590,10 @@ def __call__(
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
`Image` or tensor representing an image batch to be used as the starting point. Can also accept image
`Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
numpy arrays and pytorch tensors, the expected value range is between `[0, 1]`. If it's a tensor or a list
of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
latents as `image`, but if passing latents directly it is not encoded again.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
11 changes: 2 additions & 9 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -23,7 +23,7 @@
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -678,14 +678,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
20 changes: 3 additions & 17 deletions src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -23,7 +23,7 @@
import torch.nn.functional as F
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

from ...image_processor import VaeImageProcessor
from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
from ...schedulers import KarrasDiffusionSchedulers
@@ -750,22 +750,8 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
def __call__(
self,
prompt: Union[str, List[str]] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
control_image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
] = None,
image: PipelineImageInput = None,
control_image: PipelineImageInput = None,
height: Optional[int] = None,
width: Optional[int] = None,
strength: float = 0.8,