diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py
index 4ccb9d77d627..027691ad9f2f 100644
--- a/src/diffusers/image_processor.py
+++ b/src/diffusers/image_processor.py
@@ -29,15 +29,34 @@
 PipelineImageInput = Union[
     PIL.Image.Image,
     np.ndarray,
-    torch.FloatTensor,
+    torch.Tensor,
     List[PIL.Image.Image],
     List[np.ndarray],
-    List[torch.FloatTensor],
+    List[torch.Tensor],
 ]

 PipelineDepthInput = PipelineImageInput


+def is_valid_image(image):
+    return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
+
+
+def is_valid_image_imagelist(images):
+    # check if the image input is one of the supported formats for an image or a list of images:
+    # it can be one of the 3 formats below
+    # (1) a 4d pytorch tensor or numpy array,
+    # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
+    # (3) a list of valid images
+    if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
+        return True
+    elif is_valid_image(images):
+        return True
+    elif isinstance(images, list):
+        return all(is_valid_image(image) for image in images)
+    return False
+
+
 class VaeImageProcessor(ConfigMixin):
     """
     Image processor for VAE.
@@ -111,7 +130,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd
         return images

     @staticmethod
-    def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
+    def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
         """
         Convert a NumPy image to a PyTorch tensor.
         """
@@ -122,7 +141,7 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
         return images

     @staticmethod
-    def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
+    def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
         """
         Convert a PyTorch tensor to a NumPy image.
         """
@@ -498,9 +517,29 @@ def preprocess(
             else:
                 image = np.expand_dims(image, axis=-1)

-        if isinstance(image, supported_formats):
+        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
+                FutureWarning,
+            )
+            image = np.concatenate(image, axis=0)
+        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
+            warnings.warn(
+                "Passing `image` as a list of 4d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
+                FutureWarning,
+            )
+            image = torch.cat(image, axis=0)
+
+        # ensure the input is a list of images:
+        # - if it is a batch of images (4d torch.Tensor or np.ndarray), it is converted to a list of images (a list of 3d torch.Tensor or np.ndarray)
+        # - if it is a single image, it is converted to a list of one image
+        if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
+            image = list(image)
+        if is_valid_image(image):
             image = [image]
-        elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
+        if not all(is_valid_image(img) for img in image):
             raise ValueError(
                 f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
             )
@@ -562,15 +601,15 @@ def preprocess(

     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.

         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -579,7 +618,7 @@ def postprocess(
                 `VaeImageProcessor` config.

         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -739,15 +778,15 @@ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:

     def postprocess(
         self,
-        image: torch.FloatTensor,
+        image: torch.Tensor,
         output_type: str = "pil",
         do_denormalize: Optional[List[bool]] = None,
-    ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
+    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
         """
         Postprocess the image output from tensor to `output_type`.

         Args:
-            image (`torch.FloatTensor`):
+            image (`torch.Tensor`):
                 The image input, should be a pytorch tensor with shape `B x C x H x W`.
             output_type (`str`, *optional*, defaults to `pil`):
                 The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -756,7 +795,7 @@ def postprocess(
                 `VaeImageProcessor` config.

         Returns:
-            `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
+            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                 The postprocessed image.
         """
         if not isinstance(image, torch.Tensor):
@@ -794,8 +833,8 @@ def postprocess(

     def preprocess(
         self,
-        rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
-        depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
+        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
+        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
         height: Optional[int] = None,
         width: Optional[int] = None,
         target_res: Optional[int] = None,
@@ -934,13 +973,13 @@ def __init__(
         )

     @staticmethod
-    def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
+    def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
         """
         Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
         aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

         Args:
-            mask (`torch.FloatTensor`):
+            mask (`torch.Tensor`):
                 The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
             batch_size (`int`):
                 The batch size.
@@ -950,7 +989,7 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value
                 The dimensionality of the value embeddings.

         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 The downsampled mask tensor.

diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py
index ece646eccf6d..c03736c6398f 100644
--- a/src/diffusers/video_processor.py
+++ b/src/diffusers/video_processor.py
@@ -12,26 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from typing import List, Union

 import numpy as np
 import PIL
 import torch

-from .image_processor import VaeImageProcessor
+from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist


 class VideoProcessor(VaeImageProcessor):
     r"""Simple video processor."""

     def tensor2vid(
-        self, video: torch.FloatTensor, output_type: str = "np"
-    ) -> Union[np.ndarray, torch.FloatTensor, List[PIL.Image.Image]]:
+        self, video: torch.Tensor, output_type: str = "np"
+    ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
         r"""
         Converts a video tensor to a list of frames for export.

         Args:
-            video (`torch.FloatTensor`): The video as a tensor.
+            video (`torch.Tensor`): The video as a tensor.
             output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
         """
         batch_size = video.shape[0]
@@ -50,7 +51,7 @@ def tensor2vid(

         return outputs

-    def preprocess_video(self, video) -> torch.FloatTensor:
+    def preprocess_video(self, video) -> torch.Tensor:
         r"""
         Preprocesses input video(s).

@@ -58,86 +59,45 @@ def preprocess_video(self, video) -> torch.FloatTensor:
             video: The input video. It can be one of the following:
                 * List of the PIL images.
                 * List of list of PIL images.
+                * 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
+                * 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
                 * List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
-                * List of list of 4D Torch tensors (expected shape for tensor: (num_frames, num_channels, height,
-                  width)).
                 * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
-                * List of list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width,
-                  num_channels)).
-                * List of 5D NumPy arrays (expected shape for each array: (batch_size, num_frames, height, width,
-                  num_channels).
-                * List of 5D Torch tensors (expected shape for each array: (batch_size, num_frames, num_channels,
-                  height, width).
                 * 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width,
                   num_channels).
                 * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height,
                   width).
         """
-        supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image, list)
-
-        # Single-frame video.
-        if isinstance(video, supported_formats[:-1]):
-            video = [video]
-
-        # List of PIL images.
-        elif isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
+        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
+                FutureWarning,
+            )
+            video = np.concatenate(video, axis=0)
+        if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
+                FutureWarning,
+            )
+            video = torch.cat(video, axis=0)
+
+        # ensure the input is a list of videos:
+        # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
+        # - if it is a single video, it is converted to a list of one video.
+        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
+            video = list(video)
+        elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
             video = [video]
-
-        elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
+        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
+            video = video
+        else:
             raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(list(map(str, supported_formats)))}"
+                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
             )

-        if isinstance(video[0], np.ndarray):
-            # When the number of dimension of the first element in `video` is 5, it means
-            # each element in the `video` list is a video.
-            video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)
-
-            if video.ndim == 4:
-                video = video[None, ...]
-
-        elif isinstance(video[0], torch.Tensor):
-            video = torch.cat(video, dim=0) if video[0].ndim == 5 else torch.stack(video, dim=0)
-
-            # don't need any preprocess if the video is latents
-            channel = video.shape[1]
-            if channel == 4:
-                return video
-
-        # List of 5d tensors/ndarrays.
-        elif isinstance(video[0], list):
-            if isinstance(video[0][0], (np.ndarray, torch.Tensor)):
-                all_frames = []
-                for list_of_videos in video:
-                    temp_frames = []
-                    for vid in list_of_videos:
-                        if vid.ndim == 4:
-                            current_vid_frames = np.stack(vid, axis=0) if isinstance(vid, np.ndarray) else vid
-                        elif vid.ndim == 5:
-                            current_vid_frames = (
-                                np.concatenate(vid, axis=0) if isinstance(vid, np.ndarray) else torch.cat(vid, dim=0)
-                            )
-                        temp_frames.append(current_vid_frames)
-
-                    # Process inner list.
-                    temp_frames = (
-                        np.stack(temp_frames, axis=0)
-                        if isinstance(temp_frames[0], np.ndarray)
-                        else torch.stack(temp_frames, axis=0)
-                    )
-                    all_frames.append(temp_frames)
-
-                # Process outer list.
-                video = (
-                    np.concatenate(all_frames, axis=0)
-                    if isinstance(all_frames[0], np.ndarray)
-                    else torch.cat(all_frames, dim=0)
-                )
-
-        # `preprocess()` here would return a PT tensor.
-        video = torch.stack([self.preprocess(f) for f in video], dim=0)
-
-        # move channels before num_frames
+        video = torch.stack([self.preprocess(img) for img in video], dim=0)
         video = video.permute(0, 2, 1, 3, 4)

         return video
diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py
index 71524b35904b..40f024fc9b2b 100644
--- a/tests/others/test_video_processor.py
+++ b/tests/others/test_video_processor.py
@@ -140,7 +140,7 @@ def test_video_processor_pil(self, input_type):
             input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input)
             assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"

-    @parameterized.expand(["list_4d_np", "list_list_4d_np", "list_5d_np", "5d_np"])
+    @parameterized.expand(["list_4d_np", "list_5d_np", "5d_np"])
     def test_video_processor_np(self, input_type):
         video_processor = VideoProcessor(do_resize=False, do_normalize=True)

@@ -154,7 +154,7 @@ def test_video_processor_np(self, input_type):
             )
             assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"

-    @parameterized.expand(["list_4d_pt", "list_list_4d_pt", "list_5d_pt", "5d_pt"])
+    @parameterized.expand(["list_4d_pt", "list_5d_pt", "5d_pt"])
     def test_video_processor_pt(self, input_type):
         video_processor = VideoProcessor(do_resize=False, do_normalize=True)
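For reference (not part of the patch): a minimal usage sketch of the new input handling, assuming a `diffusers` build that includes these changes. The constructor arguments and shapes mirror the tests above; the variable names are illustrative only.

```python
import numpy as np
import torch

from diffusers.image_processor import VaeImageProcessor, is_valid_image_imagelist
from diffusers.video_processor import VideoProcessor

image_processor = VaeImageProcessor(do_resize=False, do_normalize=True)
video_processor = VideoProcessor(do_resize=False, do_normalize=True)

# A single 4d numpy batch (B x H x W x C) is now split into a list of 3d images internally.
batch = np.random.rand(2, 32, 32, 3).astype("float32")
assert is_valid_image_imagelist(batch)
images_pt = image_processor.preprocess(batch)  # torch.Tensor with shape (2, 3, 32, 32)

# A list of 4d arrays still works, but emits a FutureWarning and is concatenated along the batch axis.
images_pt = image_processor.preprocess([batch[:1], batch[1:]])

# Videos: a single 5d tensor (batch, frames, channels, height, width), or a list of 4d tensors.
video = torch.rand(1, 8, 3, 32, 32)
video_pt = video_processor.preprocess_video(video)  # torch.Tensor with shape (1, 3, 8, 32, 32)
```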