refactor prepare_mask_and_masked_image with VaeImageProcessor #4444
@@ -24,6 +24,16 @@
 from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 
 
+PipelineImageInput = Union[
+    PIL.Image.Image,
+    np.ndarray,
+    torch.FloatTensor,
+    List[PIL.Image.Image],
+    List[np.ndarray],
+    List[torch.FloatTensor],
+]
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
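The new `PipelineImageInput` alias collects every input format the pipelines accept, single or batched. A rough sketch of how a pipeline signature can use it; `run_inpaint` below is a hypothetical function for illustration, not part of this diff:

```python
import PIL.Image
from diffusers.image_processor import PipelineImageInput

# Hypothetical signature: accepts a single image or a batch,
# as PIL images, NumPy arrays, or torch tensors.
def run_inpaint(image: PipelineImageInput, mask_image: PipelineImageInput) -> None:
    ...

run_inpaint(PIL.Image.new("RGB", (512, 512)), PIL.Image.new("L", (512, 512)))
```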
@@ -38,8 +48,12 @@ class VaeImageProcessor(ConfigMixin):
             Resampling filter to use when resizing the image.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
         do_convert_rgb (`bool`, *optional*, defaults to `False`):
             Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to `False`):
+            Whether to convert the images to grayscale format.
     """
 
     config_name = CONFIG_NAME
@@ -51,9 +65,18 @@ def __init__(
         vae_scale_factor: int = 8,
         resample: str = "lanczos",
         do_normalize: bool = True,
+        do_binarize: bool = False,
         do_convert_rgb: bool = False,
+        do_convert_grayscale: bool = False,
     ):
         super().__init__()
+        if do_convert_rgb and do_convert_grayscale:
+            raise ValueError(
+                "`do_convert_rgb` and `do_convert_grayscale` cannot both be set to `True`;"
+                " if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`;"
+                " if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`"
+            )

Comment on lines +74 to +78:
Very descriptive! Looks good.

+            self.config.do_convert_rgb = False
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
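A quick sketch of the resulting constructor contract, assuming `VaeImageProcessor` is imported from `diffusers.image_processor`:

```python
from diffusers.image_processor import VaeImageProcessor

# One processor per modality: grayscale masks skip RGB conversion and normalization.
mask_processor = VaeImageProcessor(do_convert_grayscale=True, do_normalize=False)

# Requesting both conversions is contradictory and now fails fast.
try:
    VaeImageProcessor(do_convert_rgb=True, do_convert_grayscale=True)
except ValueError as err:
    print(err)
```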
@@ -119,31 +142,84 @@ def denormalize(images):
     @staticmethod
     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
         """
-        Converts an image to RGB format.
+        Converts a PIL image to RGB format.
         """
         image = image.convert("RGB")
 
         return image
 
-    def resize(
+    @staticmethod
+    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts a PIL image to grayscale format.
+        """
+        image = image.convert("L")
+
+        return image
+
+    def get_default_height_width(
         self,
-        image: PIL.Image.Image,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ) -> PIL.Image.Image:
+    ):
         """
-        Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
+        This function returns the height and width that are downscaled to the next integer multiple of
+        `vae_scale_factor`.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, which can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it
+                should have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch
+                tensor, it should have shape `[batch, channel, height, width]`.
+            height (`int`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input is used.
+            width (`int`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input is used.
         """
+
         if height is None:
-            height = image.height
+            if isinstance(image, PIL.Image.Image):
+                height = image.height
+            elif isinstance(image, torch.Tensor):
+                height = image.shape[2]
+            else:
+                height = image.shape[1]
 
         if width is None:
-            width = image.width
+            if isinstance(image, PIL.Image.Image):
+                width = image.width
+            elif isinstance(image, torch.Tensor):
+                width = image.shape[3]
+            else:
+                width = image.shape[2]
 
         width, height = (
             x - x % self.config.vae_scale_factor for x in (width, height)
         )  # resize to integer multiple of vae_scale_factor
 
+        return height, width
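Concretely, with the default `vae_scale_factor=8` the helper rounds both dimensions down to the nearest multiple of 8, whatever the input type; a small sketch with sizes chosen for illustration:

```python
import numpy as np
import PIL.Image
import torch
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()  # vae_scale_factor defaults to 8

pil_image = PIL.Image.new("RGB", (514, 517))  # PIL size is (width, height)
np_batch = np.zeros((1, 517, 514, 3))         # batch x height x width x channel
pt_batch = torch.zeros((1, 3, 517, 514))      # batch x channel x height x width

for img in (pil_image, np_batch, pt_batch):
    print(processor.get_default_height_width(img))  # (512, 512) for each
```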
+    def resize(
+        self,
+        image: PIL.Image.Image,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> PIL.Image.Image:
+        """
+        Resize a PIL image.
+        """
         image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
         return image
 
+    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Create a mask: pixel values below 0.5 become 0, values at or above 0.5 become 1.
+        """
+        image[image < 0.5] = 0
+        image[image >= 0.5] = 1
+        return image
Comment on lines +219 to +221:
Should 0 and 1 be registered into the config vars?

Reply:
Think 0 and 1 is global enough to not have to be added to the config.
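For reference, the thresholding is plain boolean indexing; a minimal sketch of what `binarize` does to a mask in [0, 1]:

```python
import torch

mask = torch.tensor([[0.1, 0.49], [0.5, 0.9]])
mask[mask < 0.5] = 0   # everything below the threshold becomes 0
mask[mask >= 0.5] = 1  # everything at or above it becomes 1
print(mask)            # tensor([[0., 0.], [1., 1.]])
```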
 
     def preprocess(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
@@ -154,6 +230,25 @@ def preprocess(
         Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
 
+        # Expand the missing dimension for a 3-dimensional pytorch tensor or numpy array that represents a grayscale image
+        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
+            if isinstance(image, torch.Tensor):
+                # if image is a pytorch tensor, it could have 2 possible shapes:
+                #   1. batch x height x width: we should insert the channel dimension at position 1
+                #   2. channel x height x width: we should insert the batch dimension at position 0;
+                #      however, since both the channel and batch dimensions have size 1, inserting at position 1 is equivalent
+                #   for simplicity, we insert a dimension of size 1 at position 1 for both cases
+                image = image.unsqueeze(1)
+            else:
+                # if it is a numpy array, it could have 2 possible shapes:
+                #   1. batch x height x width: insert the channel dimension at the last position
+                #   2. height x width x channel: insert the batch dimension at the first position
+                if image.shape[-1] == 1:
+                    image = np.expand_dims(image, axis=0)
+                else:
+                    image = np.expand_dims(image, axis=-1)
Comment on lines +236 to +251:
For easier operability, does it make sense to first convert the input tensor to a NumPy array and then operate from there?

Reply:
The output of preprocess is pytorch tensors - why would we want to convert to a numpy array first? The reason we need this step here is that the rest of our preprocessing logic assumes a 3D tensor has shape […]; for example, if we have a tensor with shape […], the correct result would be […].

Reply:
Why can we not add a dummy channel to represent the grayscale images then? Just trying to understand it better.

Reply:
@sayakpaul […]
 
         if isinstance(image, supported_formats):
             image = [image]
         elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
@@ -164,42 +259,47 @@ def preprocess(
         if isinstance(image[0], PIL.Image.Image):
             if self.config.do_convert_rgb:
                 image = [self.convert_to_rgb(i) for i in image]
+            elif self.config.do_convert_grayscale:
+                image = [self.convert_to_grayscale(i) for i in image]
             if self.config.do_resize:
+                height, width = self.get_default_height_width(image[0], height, width)
                 image = [self.resize(i, height, width) for i in image]
             image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt
 
         elif isinstance(image[0], np.ndarray):
             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+
             image = self.numpy_to_pt(image)
-            _, _, height, width = image.shape
-            if self.config.do_resize and (
-                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
-            ):
+
+            height, width = self.get_default_height_width(image, height, width)
+            if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
-                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+                    f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}. "
+                    f"Currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use the resize option in VAEImageProcessor."
                 )
 
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-            _, channel, height, width = image.shape
+
+            if self.config.do_convert_grayscale and image.ndim == 3:
+                image = image.unsqueeze(1)
+
+            channel = image.shape[1]
             # don't need any preprocess if the image is latents
             if channel == 4:
                 return image
 
-            if self.config.do_resize and (
-                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
-            ):
+            height, width = self.get_default_height_width(image, height, width)
+            if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
-                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+                    f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}. "
+                    f"Currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use the resize option in VAEImageProcessor."
                )
 
         # expected range [0,1], normalize to [-1,1]
         do_normalize = self.config.do_normalize
-        if image.min() < 0:
+        if image.min() < 0 and do_normalize:
Comment on this line:
If the user configured the image_processor with `do_normalize=False`, we don't need to throw this warning.
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                 f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
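The normalization being gated here maps [0, 1] to [-1, 1], the range the VAE expects, which `VaeImageProcessor.normalize` computes as `2.0 * images - 1.0`; a quick sketch of the arithmetic:

```python
import torch

image = torch.tensor([0.0, 0.5, 1.0])
normalized = 2.0 * image - 1.0  # [0,1] -> [-1,1]
print(normalized)               # tensor([-1.,  0.,  1.])
```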
@@ -210,6 +310,9 @@ def preprocess(
         if do_normalize:
             image = self.normalize(image)
 
+        if self.config.do_binarize:
+            image = self.binarize(image)
+
         return image
 
     def postprocess(
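Putting the pieces together, a sketch of how the refactored processor could stand in for `prepare_mask_and_masked_image` on the mask path. The configuration mirrors this diff; the blank mask is a stand-in for real data:

```python
import PIL.Image
from diffusers.image_processor import VaeImageProcessor

# Mask path: grayscale input, no [-1,1] normalization, hard 0/1 thresholding.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8,
    do_normalize=False,
    do_binarize=True,
    do_convert_grayscale=True,
)

mask = PIL.Image.new("L", (512, 512))          # stand-in for a real mask image
mask_tensor = mask_processor.preprocess(mask)  # shape [1, 1, 512, 512], values in {0, 1}
```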
Review comment:
Nice idea!