diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 6ccf9b465ebd..097257c1a5a8 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -24,6 +24,16 @@
 from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
 
 
+PipelineImageInput = Union[
+    PIL.Image.Image,
+    np.ndarray,
+    torch.FloatTensor,
+    List[PIL.Image.Image],
+    List[np.ndarray],
+    List[torch.FloatTensor],
+]
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
@@ -38,8 +48,12 @@ class VaeImageProcessor(ConfigMixin):
             Resampling filter to use when resizing the image.
         do_normalize (`bool`, *optional*, defaults to `True`):
             Whether to normalize the image to [-1,1].
+        do_binarize (`bool`, *optional*, defaults to `False`):
+            Whether to binarize the image to 0/1.
         do_convert_rgb (`bool`, *optional*, defaults to be `False`):
             Whether to convert the images to RGB format.
+        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
+            Whether to convert the images to grayscale format.
     """
 
     config_name = CONFIG_NAME
 
@@ -51,9 +65,17 @@ def __init__(
         vae_scale_factor: int = 8,
         resample: str = "lanczos",
         do_normalize: bool = True,
+        do_binarize: bool = False,
         do_convert_rgb: bool = False,
+        do_convert_grayscale: bool = False,
     ):
         super().__init__()
+        if do_convert_rgb and do_convert_grayscale:
+            raise ValueError(
+                "`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`."
+                " If you intended to convert the image into RGB format, please set `do_convert_grayscale = False`."
+                " If you intended to convert the image into grayscale format, please set `do_convert_rgb = False`."
+            )
 
     @staticmethod
     def numpy_to_pil(images: np.ndarray) -> PIL.Image.Image:
@@ -119,31 +142,84 @@ def denormalize(images):
     @staticmethod
     def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
         """
-        Converts an image to RGB format.
+        Converts a PIL image to RGB format.
         """
         image = image.convert("RGB")
+
         return image
 
-    def resize(
+    @staticmethod
+    def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Converts a PIL image to grayscale format.
+        """
+        image = image.convert("L")
+
+        return image
+
+    def get_default_height_width(
         self,
-        image: PIL.Image.Image,
+        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
         height: Optional[int] = None,
         width: Optional[int] = None,
-    ) -> PIL.Image.Image:
+    ):
         """
-        Resize a PIL image. Both height and width are downscaled to the next integer multiple of `vae_scale_factor`.
+        This function returns the height and width that are downscaled to the next integer multiple of
+        `vae_scale_factor`.
+
+        Args:
+            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
+                The image input, which can be a PIL image, numpy array or pytorch tensor. If it is a numpy array, it
+                should have shape `[batch, height, width]` or `[batch, height, width, channel]`; if it is a pytorch
+                tensor, it should have shape `[batch, channel, height, width]`.
+            height (`int`, *optional*, defaults to `None`):
+                The height of the preprocessed image. If `None`, the height of the `image` input is used.
+            width (`int`, *optional*, defaults to `None`):
+                The width of the preprocessed image. If `None`, the width of the `image` input is used.
+        """
+
         if height is None:
-            height = image.height
+            if isinstance(image, PIL.Image.Image):
+                height = image.height
+            elif isinstance(image, torch.Tensor):
+                height = image.shape[2]
+            else:
+                height = image.shape[1]
+
         if width is None:
-            width = image.width
+            if isinstance(image, PIL.Image.Image):
+                width = image.width
+            elif isinstance(image, torch.Tensor):
+                width = image.shape[3]
+            else:
+                width = image.shape[2]
 
         width, height = (
             x - x % self.config.vae_scale_factor for x in (width, height)
         )  # resize to integer multiple of vae_scale_factor
+
+        return height, width
+
+    def resize(
+        self,
+        image: PIL.Image.Image,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+    ) -> PIL.Image.Image:
+        """
+        Resize a PIL image.
+        """
         image = image.resize((width, height), resample=PIL_INTERPOLATION[self.config.resample])
         return image
 
+    def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
+        """
+        Create a binary mask: pixel values below 0.5 are set to 0, values at or above 0.5 are set to 1.
+        """
+        image[image < 0.5] = 0
+        image[image >= 0.5] = 1
+        return image
+
     def preprocess(
         self,
         image: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
@@ -154,6 +230,25 @@ def preprocess(
         Preprocess the image input. Accepted formats are PIL images, NumPy arrays or PyTorch tensors.
         """
         supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)
+
+        # Expand the missing dimension for a 3-dimensional pytorch tensor or numpy array that represents a grayscale image
+        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
+            if isinstance(image, torch.Tensor):
+                # if image is a pytorch tensor, it could have 2 possible shapes:
+                #    1. batch x height x width: we should insert the channel dimension at position 1
+                #    2. channel x height x width: we should insert the batch dimension at position 0;
+                #    however, since both the channel and batch dimensions have size 1, inserting at position 1
+                #    gives the same result, so for simplicity we insert a dimension of size 1 at position 1 for both cases
+                image = image.unsqueeze(1)
+            else:
+                # if it is a numpy array, it could have 2 possible shapes:
+                #   1. batch x height x width: insert the channel dimension at the last position
+                #   2. height x width x channel: insert the batch dimension at the first position
+                if image.shape[-1] == 1:
+                    image = np.expand_dims(image, axis=0)
+                else:
+                    image = np.expand_dims(image, axis=-1)
+
         if isinstance(image, supported_formats):
             image = [image]
         elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
@@ -164,42 +259,47 @@ def preprocess(
         if isinstance(image[0], PIL.Image.Image):
             if self.config.do_convert_rgb:
                 image = [self.convert_to_rgb(i) for i in image]
+            elif self.config.do_convert_grayscale:
+                image = [self.convert_to_grayscale(i) for i in image]
             if self.config.do_resize:
+                height, width = self.get_default_height_width(image[0], height, width)
                 image = [self.resize(i, height, width) for i in image]
             image = self.pil_to_numpy(image)  # to np
             image = self.numpy_to_pt(image)  # to pt
 
         elif isinstance(image[0], np.ndarray):
             image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)
+
             image = self.numpy_to_pt(image)
-            _, _, height, width = image.shape
-            if self.config.do_resize and (
-                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
-            ):
+
+            height, width = self.get_default_height_width(image, height, width)
+            if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your numpy array to be divisible by {self.config.vae_scale_factor}"
-                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+                    f"Currently we only support resizing for PIL image - please resize your numpy array to be {height} and {width}. "
+                    f"Currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use the resize option in VAEImageProcessor."
                 )
 
         elif isinstance(image[0], torch.Tensor):
             image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)
-            _, channel, height, width = image.shape
+
+            if self.config.do_convert_grayscale and image.ndim == 3:
+                image = image.unsqueeze(1)
+
+            channel = image.shape[1]
             # don't need any preprocess if the image is latents
             if channel == 4:
                 return image
 
-            if self.config.do_resize and (
-                height % self.config.vae_scale_factor != 0 or width % self.config.vae_scale_factor != 0
-            ):
+            height, width = self.get_default_height_width(image, height, width)
+            if self.config.do_resize and (image.shape[2] != height or image.shape[3] != width):
                 raise ValueError(
-                    f"Currently we only support resizing for PIL image - please resize your pytorch tensor to be divisible by {self.config.vae_scale_factor}"
-                    f"currently the sizes are {height} and {width}. You can also pass a PIL image instead to use resize option in VAEImageProcessor"
+                    f"Currently we only support resizing for PIL image - please resize your torch tensor to be {height} and {width}. "
+                    f"Currently the sizes are {image.shape[2]} and {image.shape[3]}. You can also pass a PIL image instead to use the resize option in VAEImageProcessor."
                 )
 
         # expected range [0,1], normalize to [-1,1]
         do_normalize = self.config.do_normalize
-        if image.min() < 0:
+        if image.min() < 0 and do_normalize:
             warnings.warn(
                 "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                 f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
@@ -210,6 +310,9 @@ def preprocess(
         if do_normalize:
             image = self.normalize(image)
 
+        if self.config.do_binarize:
+            image = self.binarize(image)
+
         return image
 
     def postprocess(
diff --git a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
index 3ceff61f79e1..64be253d3bbf 100644
--- a/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
+++ b/src/diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py
@@ -25,7 +25,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 
 from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -567,14 +567,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -597,7 +590,10 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
             image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
-                `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                 latents as `image`, but if passing latents directly it is not encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
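The `image_processor.py` changes above are the core of this patch: a single `VaeImageProcessor` class can now be configured either for RGB images or for single-channel binary masks. A minimal sketch of the new mask path, assuming a diffusers build that includes this patch (the `mask_processor` name mirrors how the pipelines below construct theirs; the random mask is illustrative):

```python
import torch
from diffusers.image_processor import VaeImageProcessor

# Mirrors the mask processor the inpaint pipelines in this diff construct:
# keep values in [0, 1] (no normalization), threshold them to {0, 1}, and
# collapse the input to a single luminance channel.
mask_processor = VaeImageProcessor(
    vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

# A bare (H, W) tensor is accepted; preprocess expands it to (1, 1, H, W).
mask = torch.rand(64, 64)
out = mask_processor.preprocess(mask)
print(out.shape)     # torch.Size([1, 1, 64, 64])
print(out.unique())  # tensor([0., 1.]) -- binarized at the 0.5 threshold
```

The 2D case works because `preprocess` first stacks the bare tensor into a batch and then, with `do_convert_grayscale` enabled, unsqueezes the missing channel dimension; the new tests at the end of this diff assert exactly this behavior.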
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index b5d8ea8cb7f9..734c8d025fb9 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -23,7 +23,7 @@
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -678,14 +678,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
index 1a55ffaaf1da..ec58304600d3 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py
@@ -23,7 +23,7 @@
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -750,22 +750,8 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
-        control_image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 0.8,
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
index 83bec1e190a9..2719b385f207 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py
@@ -24,11 +24,12 @@
 import torch.nn.functional as F
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_compiled_module,
@@ -133,7 +134,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image=False
         tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
             dimensions: ``batch x channels x height x width``.
     """
-
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
     if image is None:
         raise ValueError("`image` input cannot be undefined.")
 
@@ -316,6 +322,9 @@ def __init__(
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
         )
@@ -615,7 +624,7 @@ def check_inputs(
         control_guidance_start=0.0,
         control_guidance_end=1.0,
     ):
-        if height % 8 != 0 or width % 8 != 0:
+        if (height is not None and height % 8 != 0) or (width is not None and width % 8 != 0):
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
         if (callback_steps is None) or (
@@ -863,31 +872,6 @@ def prepare_latents(
 
         return outputs
 
-    def _default_height_width(self, height, width, image):
-        # NOTE: It is possible that a list of images have different
-        # dimensions for each image, so just checking the first image
-        # is not _exactly_ correct, but it is simple.
-        while isinstance(image, list):
-            image = image[0]
-
-        if height is None:
-            if isinstance(image, PIL.Image.Image):
-                height = image.height
-            elif isinstance(image, torch.Tensor):
-                height = image.shape[2]
-
-            height = (height // 8) * 8  # round down to nearest multiple of 8
-
-        if width is None:
-            if isinstance(image, PIL.Image.Image):
-                width = image.width
-            elif isinstance(image, torch.Tensor):
-                width = image.shape[3]
-
-            width = (width // 8) * 8  # round down to nearest multiple of 8
-
-        return height, width
-
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_inpaint.StableDiffusionInpaintPipeline.prepare_mask_latents
     def prepare_mask_latents(
         self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
@@ -950,16 +934,9 @@ def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.Tensor, PIL.Image.Image] = None,
-        mask_image: Union[torch.Tensor, PIL.Image.Image] = None,
-        control_image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
+        control_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 1.0,
@@ -989,14 +966,29 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
-            image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
+                `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image
+                to be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and
+                pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors,
+                the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of
+                arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents
+                as `image`, but if passing latents directly it is not encoded again.
+            mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`,
+                `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+                `(B, H, W)`, `(1, H, W)`, or `(H, W)`; for a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`,
+                `(H, W, 1)`, or `(H, W)`.
+            control_image (`torch.FloatTensor`, `PIL.Image.Image`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
                 `List[List[torch.FloatTensor]]`, or `List[List[PIL.Image.Image]]`):
-                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-                the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
-                also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
-                height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
-                specified in init, images must be passed as a list such that each element of the list can be correctly
-                batched for input to a single controlnet.
+                The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. The
+                dimensions of the output image default to `image`'s dimensions. If height and/or width are passed,
+                `image` is resized according to them. If multiple ControlNets are specified in init, images must be
+                passed as a list such that each element of the list can be correctly batched for input to a single
+                controlnet.
             height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
@@ -1080,9 +1072,6 @@ def __call__(
         """
         controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
 
-        # 0. Default height and width to unet
-        height, width = self._default_height_width(height, width, image)
-
         # align format for control guidance
         if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
             control_guidance_start = len(control_guidance_end) * [control_guidance_start]
@@ -1184,9 +1173,13 @@ def __call__(
             assert False
 
         # 4. Preprocess mask and image - resizes image and mask w.r.t height and width
-        mask, masked_image, init_image = prepare_mask_and_masked_image(
-            image, mask_image, height, width, return_image=True
-        )
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        masked_image = init_image * (mask < 0.5)
+        _, _, height, width = init_image.shape
 
         # 5. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index 9866875425f7..d203960c1536 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -25,7 +25,7 @@
 
 from diffusers.utils.import_utils import is_invisible_watermark_available
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
 from ...models.attention_processor import (
@@ -755,14 +755,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 50,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
index bcaf2fad0ad9..7ce4734b131c 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py
@@ -25,7 +25,7 @@
 from diffusers.utils import is_accelerate_available, is_accelerate_version
 
 from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import DDIMScheduler
@@ -578,14 +578,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]],
         source_prompt: Union[str, List[str]],
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
index f580115a3cc8..1713050364d5 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py
@@ -24,7 +24,7 @@
 from transformers import CLIPTextModel, CLIPTokenizer, DPTFeatureExtractor, DPTForDepthEstimation
 
 from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -499,14 +499,7 @@ def prepare_depth_map(self, image, depth_map, batch_size, do_classifier_free_gui
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
        depth_map: Optional[torch.FloatTensor] = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 72c7c3ad4c16..e3133cc1076e 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -23,7 +23,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -573,14 +573,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         strength: float = 0.8,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
@@ -603,7 +596,10 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
             image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
-                `Image` or tensor representing an image batch to be used as the starting point. Can also accept image
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a
+                list of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or
+                a list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image
                 latents as `image`, but if passing latents directly it is not encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
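The docstring updates above and below pin down one contract for every pipeline: pytorch tensors are channel-first, numpy arrays are channel-last, and both are expected in `[0, 1]`. A small sketch of the numpy path under that contract (again assuming this patch is applied; the random array is dummy data):

```python
import numpy as np
from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor(vae_scale_factor=8)  # do_normalize=True by default

# A (H, W, C) numpy array in [0, 1], per the updated docstrings.
image = np.random.rand(64, 64, 3).astype(np.float32)

out = processor.preprocess(image)
print(out.shape)                           # torch.Size([1, 3, 64, 64]) -- (B, C, H, W)
print(float(out.min()), float(out.max()))  # within [-1, 1] after normalization
```

Passing an already-normalized tensor in `[-1, 1]` still works but now only triggers the deprecation warning when `do_normalize` is set, which is what the `image.min() < 0 and do_normalize` change in `preprocess` is for.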
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index ab2abdc05f51..5c9d8d85145f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -22,7 +22,7 @@
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AsymmetricAutoencoderKL, AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -63,7 +63,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
         tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
             dimensions: ``batch x channels x height x width``.
     """
-
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
     if image is None:
         raise ValueError("`image` input cannot be undefined.")
 
@@ -280,6 +285,9 @@ def __init__(
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload
@@ -679,8 +687,8 @@ def get_timesteps(self, num_inference_steps, strength, device):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 1.0,
@@ -705,14 +713,20 @@ def __call__(
         Args:
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
-            image (`PIL.Image.Image`):
-                `Image` or tensor representing an image batch to be inpainted (which parts of the image to be masked
-                out with `mask_image` and repainted according to `prompt`).
-            mask_image (`PIL.Image.Image`):
-                `Image` or tensor representing an image batch to mask `image`. White pixels in the mask are repainted
-                while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a single channel
-                (luminance) before use. If it's a tensor, it should contain one color channel (L) instead of 3, so the
-                expected shape would be `(B, H, W, 1)`.
+            image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be inpainted (which parts of the image
+                to be masked out with `mask_image` and repainted according to `prompt`). For both numpy array and
+                pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list of tensors,
+                the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a list of
+                arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`. It can also accept image latents
+                as `image`, but if passing latents directly it is not encoded again.
+            mask_image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to mask `image`. White pixels in the mask
+                are repainted while black pixels are preserved. If `mask_image` is a PIL image, it is converted to a
+                single channel (luminance) before use. If it's a numpy array or pytorch tensor, it should contain one
+                color channel (L) instead of 3, so the expected shape for a pytorch tensor would be `(B, 1, H, W)`,
+                `(B, H, W)`, `(1, H, W)`, or `(H, W)`; for a numpy array, it would be `(B, H, W, 1)`, `(B, H, W)`,
+                `(H, W, 1)`, or `(H, W)`.
             height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                 The height in pixels of the generated image.
             width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
@@ -865,9 +879,14 @@ def __call__(
         is_strength_max = strength == 1.0
 
         # 5. Preprocess mask and image
-        mask, masked_image, init_image = prepare_mask_and_masked_image(
-            image, mask_image, height, width, return_image=True
-        )
+
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        masked_image = init_image * (mask < 0.5)
+
         mask_condition = mask.clone()
 
         # 6. Prepare latent variables
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
index d27f8a21f369..3253c135d6e6 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
@@ -21,7 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import KarrasDiffusionSchedulers
@@ -147,14 +147,7 @@ def __init__(
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         num_inference_steps: int = 100,
         guidance_scale: float = 7.5,
         image_guidance_scale: float = 1.5,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
index cad82cb71940..79501a78cdd1 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py
@@ -21,7 +21,7 @@
 import torch.nn.functional as F
 from transformers import CLIPTextModel, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...schedulers import EulerDiscreteScheduler
 from ...utils import logging, randn_tensor
@@ -257,14 +257,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
index db26221d4234..a67bb3e87e59 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py
@@ -29,7 +29,7 @@
     CLIPTokenizer,
 )
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import Attention
@@ -1066,14 +1066,7 @@ def __call__(
     def invert(
         self,
         prompt: Optional[str] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         num_inference_steps: int = 50,
         guidance_scale: float = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
index 11e75f7e9556..8f0776361fc3 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py
@@ -21,7 +21,7 @@
 import torch
 from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
@@ -496,14 +496,7 @@ def upcast_vae(self):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         num_inference_steps: int = 75,
         guidance_scale: float = 9.0,
         noise_level: int = 20,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
index d07405d45bfc..d244f0d7ab2b 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -16,12 +16,11 @@
 import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import PIL.Image
 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
@@ -656,14 +655,7 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         strength: float = 0.3,
         num_inference_steps: int = 50,
         denoising_start: Optional[float] = None,
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
index c480549aebb3..4b912b9790a8 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -21,7 +21,7 @@
 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
@@ -32,6 +32,7 @@
 )
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
+    deprecate,
     is_accelerate_available,
     is_accelerate_version,
     is_invisible_watermark_available,
@@ -140,6 +141,12 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     """
 
     # checkpoint. TODO(Yiyi) - need to clean this up later
+    deprecation_message = "The prepare_mask_and_masked_image method is deprecated and will be removed in a future version. Please use VaeImageProcessor.preprocess instead"
+    deprecate(
+        "prepare_mask_and_masked_image",
+        "0.30.0",
+        deprecation_message,
+    )
     if image is None:
         raise ValueError("`image` input cannot be undefined.")
 
@@ -290,6 +297,9 @@ def __init__(
         self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.mask_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
 
         add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
 
@@ -857,8 +867,8 @@ def __call__(
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
-        image: Union[torch.FloatTensor, PIL.Image.Image] = None,
-        mask_image: Union[torch.FloatTensor, PIL.Image.Image] = None,
+        image: PipelineImageInput = None,
+        mask_image: PipelineImageInput = None,
         height: Optional[int] = None,
         width: Optional[int] = None,
         strength: float = 1.0,
@@ -1098,9 +1108,16 @@ def denoising_value_valid(dnv):
         is_strength_max = strength == 1.0
 
         # 5. Preprocess mask and image
-        mask, masked_image, init_image = prepare_mask_and_masked_image(
-            image, mask_image, height, width, return_image=True
-        )
+        init_image = self.image_processor.preprocess(image, height=height, width=width)
+        init_image = init_image.to(dtype=torch.float32)
+
+        mask = self.mask_processor.preprocess(mask_image, height=height, width=width)
+
+        if init_image.shape[1] == 4:
+            # if images are in latent space, we can't mask them
+            masked_image = None
+        else:
+            masked_image = init_image * (mask < 0.5)
 
         # 6. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index e252d5294626..e4cec6482e3c 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -15,12 +15,11 @@
 import inspect
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import PIL.Image
 import torch
 from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
 
-from ...image_processor import VaeImageProcessor
+from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, LoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, UNet2DConditionModel
 from ...models.attention_processor import (
@@ -587,14 +586,7 @@ def upcast_vae(self):
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
-        image: Union[
-            torch.FloatTensor,
-            PIL.Image.Image,
-            np.ndarray,
-            List[torch.FloatTensor],
-            List[PIL.Image.Image],
-            List[np.ndarray],
-        ] = None,
+        image: PipelineImageInput = None,
         num_inference_steps: int = 100,
         guidance_scale: float = 7.5,
         image_guidance_scale: float = 1.5,
diff --git a/tests/others/test_image_processor.py b/tests/others/test_image_processor.py
index c2cd6f4a04f4..f8d22ce9a01b 100644
--- a/tests/others/test_image_processor.py
+++ b/tests/others/test_image_processor.py
@@ -34,6 +34,17 @@ def dummy_sample(self):
 
         return sample
 
+    @property
+    def dummy_mask(self):
+        batch_size = 1
+        num_channels = 1
+        height = 8
+        width = 8
+
+        sample = torch.rand((batch_size, num_channels, height, width))
+
+        return sample
+
     def to_np(self, image):
         if isinstance(image[0], PIL.Image.Image):
             return np.stack([np.array(i) for i in image], axis=0)
@@ -133,17 +144,144 @@ def test_preprocess_input_list(self):
         )
 
         input_np_4d = self.to_np(self.dummy_sample)
-        list(input_np_4d)
+        input_np_list = list(input_np_4d)
 
         out_np_4d = image_processor.postprocess(
-            image_processor.preprocess(input_pt_4d),
+            image_processor.preprocess(input_np_4d),
             output_type="np",
         )
         out_np_list = image_processor.postprocess(
-            image_processor.preprocess(input_pt_list),
+            image_processor.preprocess(input_np_list),
             output_type="np",
         )
 
         assert np.abs(out_pt_4d - out_pt_list).max() < 1e-6
         assert np.abs(out_np_4d - out_np_list).max() < 1e-6
+
+    def test_preprocess_input_mask_3d(self):
+        image_processor = VaeImageProcessor(
+            do_resize=False, do_normalize=False, do_binarize=True, do_convert_grayscale=True
+        )
+
+        input_pt_4d = self.dummy_mask
+        input_pt_3d = input_pt_4d.squeeze(0)
+        input_pt_2d = input_pt_3d.squeeze(0)
+
+        out_pt_4d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_4d),
+            output_type="np",
+        )
+        out_pt_3d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_3d),
+            output_type="np",
+        )
+
+        out_pt_2d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_2d),
+            output_type="np",
+        )
+
+        input_np_4d = self.to_np(self.dummy_mask)
+        input_np_3d = input_np_4d.squeeze(0)
+        input_np_3d_1 = input_np_4d.squeeze(-1)
+        input_np_2d = input_np_3d.squeeze(-1)
+
+        out_np_4d = image_processor.postprocess(
+            image_processor.preprocess(input_np_4d),
+            output_type="np",
+        )
+        out_np_3d = image_processor.postprocess(
+            image_processor.preprocess(input_np_3d),
+            output_type="np",
+        )
+
+        out_np_3d_1 = image_processor.postprocess(
+            image_processor.preprocess(input_np_3d_1),
+            output_type="np",
+        )
+
+        out_np_2d = image_processor.postprocess(
+            image_processor.preprocess(input_np_2d),
+            output_type="np",
+        )
+
+        assert np.abs(out_pt_4d - out_pt_3d).max() == 0
+        assert np.abs(out_pt_4d - out_pt_2d).max() == 0
+        assert np.abs(out_np_4d - out_np_3d).max() == 0
+        assert np.abs(out_np_4d - out_np_3d_1).max() == 0
+        assert np.abs(out_np_4d - out_np_2d).max() == 0
+
+    def test_preprocess_input_mask_list(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)
+
+        input_pt_4d = self.dummy_mask
+        input_pt_3d = input_pt_4d.squeeze(0)
+        input_pt_2d = input_pt_3d.squeeze(0)
+
+        inputs_pt = [input_pt_4d, input_pt_3d, input_pt_2d]
+        inputs_pt_list = [[input_pt] for input_pt in inputs_pt]
+
+        for input_pt, input_pt_list in zip(inputs_pt, inputs_pt_list):
+            out_pt = image_processor.postprocess(
+                image_processor.preprocess(input_pt),
+                output_type="np",
+            )
+            out_pt_list = image_processor.postprocess(
+                image_processor.preprocess(input_pt_list),
+                output_type="np",
+            )
+            assert np.abs(out_pt - out_pt_list).max() < 1e-6
+
+        input_np_4d = self.to_np(self.dummy_mask)
+        input_np_3d = input_np_4d.squeeze(0)
+        input_np_2d = input_np_3d.squeeze(-1)
+
+        inputs_np = [input_np_4d, input_np_3d, input_np_2d]
+        inputs_np_list = [[input_np] for input_np in inputs_np]
+
+        for input_np, input_np_list in zip(inputs_np, inputs_np_list):
+            out_np = image_processor.postprocess(
+                image_processor.preprocess(input_np),
+                output_type="np",
+            )
+            out_np_list = image_processor.postprocess(
+                image_processor.preprocess(input_np_list),
+                output_type="np",
+            )
+            assert np.abs(out_np - out_np_list).max() < 1e-6
+
+    def test_preprocess_input_mask_3d_batch(self):
+        image_processor = VaeImageProcessor(do_resize=False, do_normalize=False, do_convert_grayscale=True)
+
+        # create a dummy mask input with batch_size 2
+        dummy_mask_batch = torch.cat([self.dummy_mask] * 2, axis=0)
+
+        # squeeze out the channel dimension
+        input_pt_3d = dummy_mask_batch.squeeze(1)
+        input_np_3d = self.to_np(dummy_mask_batch).squeeze(-1)
+
+        input_pt_3d_list = list(input_pt_3d)
+        input_np_3d_list = list(input_np_3d)
+
+        out_pt_3d = image_processor.postprocess(
+            image_processor.preprocess(input_pt_3d),
+            output_type="np",
+        )
+        out_pt_3d_list = image_processor.postprocess(
+            image_processor.preprocess(input_pt_3d_list),
+            output_type="np",
+        )
+
+        assert np.abs(out_pt_3d - out_pt_3d_list).max() < 1e-6
+
+        out_np_3d = image_processor.postprocess(
+            image_processor.preprocess(input_np_3d),
+            output_type="np",
+        )
+        out_np_3d_list = image_processor.postprocess(
+            image_processor.preprocess(input_np_3d_list),
+            output_type="np",
+        )
+
+        assert np.abs(out_np_3d - out_np_3d_list).max() < 1e-6
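For reference, the behavior of the deprecated `prepare_mask_and_masked_image` helper can be reproduced in user code with the two processors the inpaint pipelines now construct. A hedged end-to-end sketch (dummy tensors; `vae_scale_factor=8` is assumed, as in the Stable Diffusion pipelines):

```python
import torch
from diffusers.image_processor import VaeImageProcessor

vae_scale_factor = 8
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor)
mask_processor = VaeImageProcessor(
    vae_scale_factor=vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
)

image = torch.rand(1, 3, 64, 64)       # RGB input in [0, 1]
mask_image = torch.rand(1, 1, 64, 64)  # single-channel mask in [0, 1]

init_image = image_processor.preprocess(image).to(dtype=torch.float32)  # normalized to [-1, 1]
mask = mask_processor.preprocess(mask_image)                            # binarized to {0, 1}

# Keep pixels where the mask is black (0); zero out the white (1) region
# that the pipeline will repaint. This is the `init_image * (mask < 0.5)`
# expression used throughout the refactored pipelines above.
masked_image = init_image * (mask < 0.5)
print(masked_image.shape)  # torch.Size([1, 3, 64, 64])
```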