diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 0f4481570829..4d8d73c863ae 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -29,15 +29,34 @@ PipelineImageInput = Union[ PIL.Image.Image, np.ndarray, - torch.FloatTensor, + torch.Tensor, List[PIL.Image.Image], List[np.ndarray], - List[torch.FloatTensor], + List[torch.Tensor], ] PipelineDepthInput = PipelineImageInput +def is_valid_image(image): + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + # check if the image input is one of the supported formats for image and image list: + # it can be either one of below 3 + # (1) a 4d pytorch tensor or numpy array, + # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor + # (3) a list of valid image + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + + class VaeImageProcessor(ConfigMixin): """ Image processor for VAE. @@ -110,7 +129,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd return images @staticmethod - def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: + def numpy_to_pt(images: np.ndarray) -> torch.Tensor: """ Convert a NumPy image to a PyTorch tensor. """ @@ -121,7 +140,7 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: return images @staticmethod - def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: """ Convert a PyTorch tensor to a NumPy image. """ @@ -497,12 +516,27 @@ def preprocess( else: image = np.expand_dims(image, axis=-1) - if isinstance(image, supported_formats): - image = [image] - elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + if not is_valid_image_imagelist(image): raise ValueError( - f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" + f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}" ) + if not isinstance(image, list): + image = [image] if isinstance(image[0], PIL.Image.Image): if crops_coords is not None: @@ -561,15 +595,15 @@ def preprocess( def postprocess( self, - image: torch.FloatTensor, + image: torch.Tensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, - ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Postprocess the image output from tensor to `output_type`. 
Args: - image (`torch.FloatTensor`): + image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. @@ -578,7 +612,7 @@ def postprocess( `VaeImageProcessor` config. Returns: - `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The postprocessed image. """ if not isinstance(image, torch.Tensor): @@ -738,15 +772,15 @@ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: def postprocess( self, - image: torch.FloatTensor, + image: torch.Tensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, - ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Postprocess the image output from tensor to `output_type`. Args: - image (`torch.FloatTensor`): + image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. @@ -755,7 +789,7 @@ def postprocess( `VaeImageProcessor` config. Returns: - `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The postprocessed image. """ if not isinstance(image, torch.Tensor): @@ -793,8 +827,8 @@ def postprocess( def preprocess( self, - rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], - depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray], + depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray], height: Optional[int] = None, width: Optional[int] = None, target_res: Optional[int] = None, @@ -933,13 +967,13 @@ def __init__( ) @staticmethod - def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int): + def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int): """ Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. Args: - mask (`torch.FloatTensor`): + mask (`torch.Tensor`): The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. batch_size (`int`): The batch size. @@ -949,7 +983,7 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value The dimensionality of the value embeddings. Returns: - `torch.FloatTensor`: + `torch.Tensor`: The downsampled mask tensor. 
""" diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 3765db938cd5..94654c4a7e17 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -15,11 +15,10 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -41,6 +40,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -65,27 +65,6 @@ """ -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - class AnimateDiffPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -159,7 +138,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -836,7 +815,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index 56146b6d05ca..11ccafdf57b0 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch from transformers import ( CLIPImageProcessor, @@ -25,7 +24,7 @@ CLIPVisionModelWithProjection, ) -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, @@ -57,6 +56,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -113,28 +113,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -320,7 +298,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = self.unet.config.sample_size @@ -1291,7 +1269,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # cast back to fp16 if needed if needs_upcasting: diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 67fa8b12a34b..00f773c98181 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -15,11 +15,10 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models 
import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -34,6 +33,7 @@ ) from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -95,28 +95,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -264,7 +242,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -650,16 +628,7 @@ def prepare_latents( generator, latents=None, ): - # video must be a list of list of images - # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video - # as a list of images - if video and not isinstance(video[0], list): - video = [video] if latents is None: - video = torch.cat( - [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0 - ) - video = video.to(device=device, dtype=dtype) num_frames = video.shape[1] else: num_frames = latents.shape[2] @@ -943,6 +912,14 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) # 5. Prepare latent variables + # video must be a list of list of images + # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video + # as a list of images + if video and not isinstance(video[0], list): + video = [video] + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width, preceed_with_frames=True) + video = video.to(device=device, dtype=prompt_embeds.dtype) num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( video=video, @@ -1023,7 +1000,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index a6b9499f5542..a38918e1a0f9 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -31,6 +31,7 @@ replace_example_docstring, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -70,28 +71,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - @dataclass class I2VGenXLPipelineOutput(BaseOutput): r""" @@ -156,7 +135,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # `do_resize=False` as we do custom resizing. - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) @property def guidance_scale(self): @@ -342,8 +321,8 @@ def _encode_image(self, image, device, num_videos_per_prompt): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): - image = self.image_processor.pil_to_numpy(image) - image = self.image_processor.numpy_to_pt(image) + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) # Normalize the image with CLIP training stats. image = self.feature_extractor( @@ -657,7 +636,7 @@ def __call__( # 3.2.2 Image latents. resized_image = _center_crop_wide(image, (width, height)) - image = self.image_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) + image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) image_latents = self.prepare_image_latents( image, device=device, @@ -737,7 +716,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 9. 
Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index aceb95ae0451..a2723187f2db 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -21,7 +21,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -43,6 +43,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -89,28 +90,6 @@ ] -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int): assert num_frames > 0, "video_length should be greater than 0" @@ -218,7 +197,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -621,7 +600,7 @@ def prepare_masked_condition( ) _, _, _, scaled_height, scaled_width = shape - image = self.image_processor.preprocess(image) + image = self.video_processor.preprocess(image) image = image.to(device, dtype) if isinstance(generator, list): @@ -959,7 +938,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. 
Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index b5b048d5cbe0..d815adab049f 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -21,11 +21,12 @@ import torch from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline @@ -61,28 +62,6 @@ def _append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps def retrieve_timesteps( scheduler, @@ -199,7 +178,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) def _encode_image( self, @@ -211,8 +190,8 @@ def _encode_image( dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): - image = self.image_processor.pil_to_numpy(image) - image = self.image_processor.numpy_to_pt(image) + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) # We normalize the image before resizing to match with the original implementation. # Then we unnormalize it after resizing. @@ -520,7 +499,7 @@ def __call__( fps = fps - 1 # 4. 
Encode input image using VAE - image = self.image_processor.preprocess(image, height=height, width=width).to(device) + image = self.video_processor.preprocess(image, height=height, width=width).to(device) noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) image = image + noise_aug_strength * noise @@ -626,7 +605,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=torch.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = tensor2vid(frames, self.image_processor, output_type=output_type) + frames = self.video_processor.postprocess_video(video=frames, output_type=output_type) else: frames = latents diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 6c33836e60da..0ef769f32a15 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -15,11 +15,9 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -33,6 +31,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import TextToVideoSDPipelineOutput @@ -59,28 +58,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -127,7 +104,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -652,7 +629,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 9. 
Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 3901946afe46..0dc1ca93f873 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -16,11 +16,9 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import PIL.Image import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -34,6 +32,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import TextToVideoSDPipelineOutput @@ -94,69 +93,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - -def preprocess_video(video): - supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) - - if isinstance(video, supported_formats): - video = [video] - elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): - raise ValueError( - f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" - ) - - if isinstance(video[0], PIL.Image.Image): - video = [np.array(frame) for frame in video] - - if isinstance(video[0], np.ndarray): - video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) - - if video.dtype == np.uint8: - video = np.array(video).astype(np.float32) / 255.0 - - if video.ndim == 4: - video = video[None, ...] - - video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) - - elif isinstance(video[0], torch.Tensor): - video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) - - # don't need any preprocess if the video is latents - channel = video.shape[1] - if channel == 4: - return video - - # move channels before num_frames - video = video.permute(0, 2, 1, 3, 4) - - # normalize video - video = 2.0 * video - 1.0 - - return video - - class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided video-to-video generation. 
@@ -203,7 +139,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -687,7 +623,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video - video = preprocess_video(video) + video = self.video_processor.preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -749,7 +685,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py new file mode 100644 index 000000000000..3cd8461fa5f6 --- /dev/null +++ b/src/diffusers/video_processor.py @@ -0,0 +1,116 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings +from typing import List, Optional, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist + + +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" + + def preprocess_video( + self, video, height: Optional[int] = None, width: Optional[int] = None, preceed_with_frames: bool = False + ) -> torch.Tensor: + r""" + Preprocesses input video(s). + + Args: + video: The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)). + * 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). + * List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)). + * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). + * 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width, + num_channels). + * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height, + width). + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. 
+ preceed_with_frames (`bool`, defaults to False): + Some pipelines keep the number of channels _before_ the number of frames in terms dimensions. + `(batch_size, num_channels, num_frames, height, width)`, for example. Some pipelines don't. This flag + helps to control that behaviour. + """ + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): + video = [video] + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: + raise ValueError( + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" + ) + + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + if not preceed_with_frames: + video = video.permute(0, 2, 1, 3, 4) + + return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py new file mode 100644 index 000000000000..a2fc87717f9d --- /dev/null +++ b/tests/others/test_video_processor.py @@ -0,0 +1,169 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +import PIL.Image +import torch +from parameterized import parameterized + +from diffusers.video_processor import VideoProcessor + + +np.random.seed(0) +torch.manual_seed(0) + + +class VideoProcessorTest(unittest.TestCase): + def get_dummy_sample(self, input_type): + batch_size = 1 + num_frames = 5 + num_channels = 3 + height = 8 + width = 8 + + def generate_image(): + return PIL.Image.fromarray(np.random.randint(0, 256, size=(height, width, num_channels)).astype("uint8")) + + def generate_4d_array(): + return np.random.rand(num_frames, height, width, num_channels) + + def generate_5d_array(): + return np.random.rand(batch_size, num_frames, height, width, num_channels) + + def generate_4d_tensor(): + return torch.rand(num_frames, num_channels, height, width) + + def generate_5d_tensor(): + return torch.rand(batch_size, num_frames, num_channels, height, width) + + if input_type == "list_images": + sample = [generate_image() for _ in range(num_frames)] + elif input_type == "list_list_images": + sample = [[generate_image() for _ in range(num_frames)] for _ in range(num_frames)] + elif input_type == "list_4d_np": + sample = [generate_4d_array() for _ in range(num_frames)] + elif input_type == "list_list_4d_np": + sample = [[generate_4d_array() for _ in range(num_frames)] for _ in range(num_frames)] + elif input_type == "list_5d_np": + sample = [generate_5d_array() for _ in range(num_frames)] + elif input_type == "5d_np": + sample = generate_5d_array() + elif input_type == "list_4d_pt": + sample = [generate_4d_tensor() for _ in range(num_frames)] + elif input_type == "list_list_4d_pt": + sample = [[generate_4d_tensor() for _ in range(num_frames)] for _ in range(num_frames)] + elif input_type == "list_5d_pt": + sample = [generate_5d_tensor() for _ in range(num_frames)] + elif input_type == "5d_pt": + sample = generate_5d_tensor() + + return sample + + def to_np(self, video): + # List of images. + if isinstance(video[0], PIL.Image.Image): + video = np.stack([np.array(i) for i in video], axis=0) + + # List of list of images. + elif isinstance(video, list) and isinstance(video[0][0], PIL.Image.Image): + frames = [] + for vid in video: + all_current_frames = np.stack([np.array(i) for i in vid], axis=0) + frames.append(all_current_frames) + video = np.stack([np.array(frame) for frame in frames], axis=0) + + # List of 4d/5d {ndarrays, torch tensors}. + elif isinstance(video, list) and isinstance(video[0], (torch.Tensor, np.ndarray)): + if isinstance(video[0], np.ndarray): + video = np.stack(video, axis=0) if video[0].ndim == 4 else np.concatenate(video, axis=0) + else: + if video[0].ndim == 4: + video = np.stack([i.cpu().numpy().transpose(0, 2, 3, 1) for i in video], axis=0) + elif video[0].ndim == 5: + video = np.concatenate([i.cpu().numpy().transpose(0, 1, 3, 4, 2) for i in video], axis=0) + + # List of list of 4d/5d {ndarrays, torch tensors}. 
+ elif ( + isinstance(video, list) + and isinstance(video[0], list) + and isinstance(video[0][0], (torch.Tensor, np.ndarray)) + ): + all_frames = [] + for list_of_videos in video: + temp_frames = [] + for vid in list_of_videos: + if vid.ndim == 4: + current_vid_frames = np.stack( + [i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(1, 2, 0) for i in vid], + axis=0, + ) + elif vid.ndim == 5: + current_vid_frames = np.concatenate( + [i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(0, 2, 3, 1) for i in vid], + axis=0, + ) + temp_frames.append(current_vid_frames) + temp_frames = np.stack(temp_frames, axis=0) + all_frames.append(temp_frames) + + video = np.concatenate(all_frames, axis=0) + + # Just 5d {ndarrays, torch tensors}. + elif isinstance(video, (torch.Tensor, np.ndarray)) and video.ndim == 5: + video = video if isinstance(video, np.ndarray) else video.cpu().numpy().transpose(0, 1, 3, 4, 2) + + return video + + @parameterized.expand(["list_images", "list_list_images"]) + def test_video_processor_pil(self, input_type): + video_processor = VideoProcessor(do_resize=False, do_normalize=True) + + input = self.get_dummy_sample(input_type=input_type) + + for output_type in ["pt", "np", "pil"]: + out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type) + out_np = self.to_np(out) + input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input) + assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" + + @parameterized.expand(["list_4d_np", "list_5d_np", "5d_np"]) + def test_video_processor_np(self, input_type): + video_processor = VideoProcessor(do_resize=False, do_normalize=True) + + input = self.get_dummy_sample(input_type=input_type) + + for output_type in ["pt", "np", "pil"]: + out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type) + out_np = self.to_np(out) + input_np = ( + (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) + ) + assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" + + @parameterized.expand(["list_4d_pt", "list_5d_pt", "5d_pt"]) + def test_video_processor_pt(self, input_type): + video_processor = VideoProcessor(do_resize=False, do_normalize=True) + + input = self.get_dummy_sample(input_type=input_type) + + for output_type in ["pt", "np", "pil"]: + out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type) + out_np = self.to_np(out) + input_np = ( + (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) + ) + assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"
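The hunk in `src/diffusers/image_processor.py` above adds two module-level validation helpers and reroutes `VaeImageProcessor.preprocess` through them. Below is a minimal sketch of how those helpers classify inputs; the helper names and import path come from this diff, while the array shapes and values are made up for illustration.

import numpy as np
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

grayscale = np.random.rand(32, 32)      # 2-d array: a single grayscale image
rgb_chw = torch.rand(3, 32, 32)         # 3-d tensor: a single image
batch = np.random.rand(4, 32, 32, 3)    # 4-d array: a batch of images

assert is_valid_image(grayscale)
assert is_valid_image(rgb_chw)

# A 4-d array/tensor, a single valid image, or a list of valid images all pass.
assert is_valid_image_imagelist(batch)
assert is_valid_image_imagelist([grayscale, grayscale])

# A list of 4-d arrays is not treated as a valid image list; `preprocess` now
# handles that case separately by concatenating along the batch dimension and
# emitting a FutureWarning before this validation runs.
assert not is_valid_image_imagelist([batch, batch])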
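For the refactored pipelines, the per-pipeline `tensor2vid` helpers are replaced by the new `VideoProcessor`. The following is a minimal round-trip sketch under assumed settings: the constructor arguments mirror the ones used in the new tests, while the frame count, resolution, and random data are arbitrary.

import numpy as np
import PIL.Image
import torch

from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

# One video supplied as a list of PIL frames (one of the accepted input formats).
frames = [
    PIL.Image.fromarray(np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8))
    for _ in range(8)
]

# With the default `preceed_with_frames=False`, preprocess_video returns a 5-d
# tensor laid out as (batch_size, num_channels, num_frames, height, width).
video = video_processor.preprocess_video(frames)
print(video.shape)  # torch.Size([1, 3, 8, 64, 64])

# postprocess_video consumes that same layout, which is what the pipelines'
# decode_latents returns, and yields frames in the requested output type
# ("np", "pt", or "pil").
np_video = video_processor.postprocess_video(video=video, output_type="np")
print(np_video.shape)  # (1, 8, 64, 64, 3)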