From ede27ee1c6d494465d3f1ca0683371558c66d216 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 25 Apr 2024 14:27:59 +0530 Subject: [PATCH 01/27] introduce videoprocessor. --- .../animatediff/pipeline_animatediff.py | 27 +----- .../pipeline_animatediff_video2video.py | 28 +----- .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py | 27 +----- src/diffusers/pipelines/pia/pipeline_pia.py | 27 +----- .../pipeline_stable_video_diffusion.py | 25 +----- .../pipeline_text_to_video_synth.py | 28 +----- .../pipeline_text_to_video_synth_img2img.py | 71 ++------------- src/diffusers/video_processor.py | 87 +++++++++++++++++++ 8 files changed, 114 insertions(+), 206 deletions(-) create mode 100644 src/diffusers/video_processor.py diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 3765db938cd5..293ae658c385 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -41,6 +40,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -65,27 +65,6 @@ """ -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - class AnimateDiffPipeline( DiffusionPipeline, StableDiffusionMixin, @@ -836,7 +815,9 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = VideoProcessor.tensor2vid( + video=video_tensor, processor=self.image_processor, output_type=output_type + ) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 106fabba721b..c616e377fefc 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection @@ -34,6 +33,7 @@ ) from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -95,28 +95,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor, output_type="np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -997,7 +975,9 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = VideoProcessor.tensor2vid( + video=video_tensor, processor=self.image_processor, output_type=output_type + ) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index a6b9499f5542..26122b4c925d 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -31,6 +31,7 @@ replace_example_docstring, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -70,28 +71,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - @dataclass class I2VGenXLPipelineOutput(BaseOutput): r""" @@ -737,7 +716,9 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = VideoProcessor.tensor2vid( + video=video_tensor, processor=self.image_processor, output_type=output_type + ) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index aceb95ae0451..83af3a247243 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -43,6 +43,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -89,28 +90,6 @@ ] -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - def prepare_mask_coef_by_statistics(num_frames: int, cond_frame: int, motion_scale: int): assert num_frames > 0, "video_length should be greater than 0" @@ -959,7 +938,9 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = VideoProcessor.tensor2vid( + video=video_tensor, processor=self.image_processor, output_type=output_type + ) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 070183b92409..b7722f74f514 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -26,6 +26,7 @@ from ...schedulers import EulerDiscreteScheduler from ...utils import BaseOutput, logging, replace_example_docstring from ...utils.torch_utils import is_compiled_module, randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline @@ -61,28 +62,6 @@ def _append_dims(x, target_dims): return x[(...,) + (None,) * dims_to_append] -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: VaeImageProcessor, output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - @dataclass class StableVideoDiffusionPipelineOutput(BaseOutput): r""" @@ -562,7 +541,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=torch.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = tensor2vid(frames, self.image_processor, output_type=output_type) + frames = VideoProcessor.tensor2vid(video=frames, processor=self.image_processor, output_type=output_type) else: frames = latents diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 6c33836e60da..286dbdeb0b3c 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Union -import numpy as np import torch from transformers import CLIPTextModel, CLIPTokenizer @@ -33,6 +32,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import TextToVideoSDPipelineOutput @@ -59,28 +59,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - class TextToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-to-video generation. @@ -652,7 +630,9 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type) + video = VideoProcessor.tensor2vid( + video=video_tensor, processor=self.image_processor, output_type=output_type + ) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 3901946afe46..2b7c3af50168 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -16,7 +16,6 @@ from typing import Any, Callable, Dict, List, Optional, Union import numpy as np -import PIL.Image import torch from transformers import CLIPTextModel, CLIPTokenizer @@ -34,6 +33,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from . import TextToVideoSDPipelineOutput @@ -94,69 +94,6 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - -def preprocess_video(video): - supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) - - if isinstance(video, supported_formats): - video = [video] - elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): - raise ValueError( - f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" - ) - - if isinstance(video[0], PIL.Image.Image): - video = [np.array(frame) for frame in video] - - if isinstance(video[0], np.ndarray): - video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) - - if video.dtype == np.uint8: - video = np.array(video).astype(np.float32) / 255.0 - - if video.ndim == 4: - video = video[None, ...] - - video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) - - elif isinstance(video[0], torch.Tensor): - video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) - - # don't need any preprocess if the video is latents - channel = video.shape[1] - if channel == 4: - return video - - # move channels before num_frames - video = video.permute(0, 2, 1, 3, 4) - - # normalize video - video = 2.0 * video - 1.0 - - return video - - class VideoToVideoSDPipeline(DiffusionPipeline, StableDiffusionMixin, TextualInversionLoaderMixin, LoraLoaderMixin): r""" Pipeline for text-guided video-to-video generation. @@ -687,7 +624,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video - video = preprocess_video(video) + video = VideoProcessor.preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -749,7 +686,9 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type) + video = VideoProcessor.tensor2vid( + video=video_tensor, processor=self.image_processor, output_type=output_type + ) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py new file mode 100644 index 000000000000..89d0a167a1c2 --- /dev/null +++ b/src/diffusers/video_processor.py @@ -0,0 +1,87 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Union + +import numpy as np +import PIL +import torch + +from .image_processor import VaeImageProcessor + + +class VideoProcessor: + """Simple video processor.""" + @staticmethod + def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + """Converts a video tensor to a list of frames for export.""" + batch_size, channels, num_frames, height, width = video.shape + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = processor.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs + + @staticmethod + def preprocess_video(video: List[Union[PIL.Image.Image, np.ndarray, torch.Tensor]]): + """Preprocesses an input video.""" + supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) + + if isinstance(video, supported_formats): + video = [video] + elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): + raise ValueError( + f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" + ) + + if isinstance(video[0], PIL.Image.Image): + video = [np.array(frame) for frame in video] + + if isinstance(video[0], np.ndarray): + video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) + + # Notes from (sayakpaul): do we want to follow something similar to VaeImageProcessor here i.e., + # have methods `normalize()` and `denormalize()`? + if video.dtype == np.uint8: + video = np.array(video).astype(np.float32) / 255.0 + + if video.ndim == 4: + video = video[None, ...] + + video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) + + elif isinstance(video[0], torch.Tensor): + video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) + + # don't need any preprocess if the video is latents + channel = video.shape[1] + if channel == 4: + return video + + # move channels before num_frames + video = video.permute(0, 2, 1, 3, 4) + + # normalize video + video = 2.0 * video - 1.0 + + return video From b68e94d04d0c454b2ad3aecd00827c8bff384271 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Thu, 25 Apr 2024 14:30:58 +0530 Subject: [PATCH 02/27] fix quality --- src/diffusers/video_processor.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 89d0a167a1c2..ae333d1c071c 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -23,6 +23,7 @@ class VideoProcessor: """Simple video processor.""" + @staticmethod def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): """Converts a video tensor to a list of frames for export.""" From 6680d50e405c797aa391c815eb088a0261d33913 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 27 Apr 2024 09:04:35 +0530 Subject: [PATCH 03/27] address yiyi's feedback --- .../animatediff/pipeline_animatediff.py | 8 ++--- .../pipeline_animatediff_video2video.py | 8 ++--- .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py | 6 ++-- src/diffusers/pipelines/pia/pipeline_pia.py | 7 ++-- .../pipeline_stable_video_diffusion.py | 6 ++-- .../pipeline_text_to_video_synth.py | 7 ++-- .../pipeline_text_to_video_synth_img2img.py | 7 ++-- src/diffusers/video_processor.py | 33 +++++++++++-------- 8 files changed, 36 insertions(+), 46 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 293ae658c385..327121bcfdff 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -18,7 +18,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -138,7 +138,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -815,9 +815,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = VideoProcessor.tensor2vid( - video=video_tensor, processor=self.image_processor, output_type=output_type - ) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index c616e377fefc..93de9c794e6e 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -18,7 +18,7 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -227,7 +227,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -975,9 +975,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = VideoProcessor.tensor2vid( - video=video_tensor, processor=self.image_processor, output_type=output_type - ) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index 26122b4c925d..1df76535c6a3 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -135,7 +135,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # `do_resize=False` as we do custom resizing. - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) @property def guidance_scale(self): @@ -716,9 +716,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = VideoProcessor.tensor2vid( - video=video_tensor, processor=self.image_processor, output_type=output_type - ) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 83af3a247243..bb4f666c2af8 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -43,7 +43,6 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor -from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin @@ -197,7 +196,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VaeImageProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -938,9 +937,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = VideoProcessor.tensor2vid( - video=video_tensor, processor=self.image_processor, output_type=output_type - ) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index b7722f74f514..d83c40b4c966 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -21,7 +21,7 @@ import torch from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel from ...schedulers import EulerDiscreteScheduler from ...utils import BaseOutput, logging, replace_example_docstring @@ -118,7 +118,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) def _encode_image( self, @@ -541,7 +541,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=torch.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = VideoProcessor.tensor2vid(video=frames, processor=self.image_processor, output_type=output_type) + frames = self.video_processor.tensor2vid(video=frames, output_type=output_type) else: frames = latents diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 286dbdeb0b3c..bee0bb17bfdf 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -18,7 +18,6 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -105,7 +104,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -630,9 +629,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = VideoProcessor.tensor2vid( - video=video_tensor, processor=self.image_processor, output_type=output_type - ) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 2b7c3af50168..76e040333232 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -19,7 +19,6 @@ import torch from transformers import CLIPTextModel, CLIPTokenizer -from ...image_processor import VaeImageProcessor from ...loaders import LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, UNet3DConditionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -140,7 +139,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -686,9 +685,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = VideoProcessor.tensor2vid( - video=video_tensor, processor=self.image_processor, output_type=output_type - ) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index ae333d1c071c..27ffe2797b90 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -21,17 +21,16 @@ from .image_processor import VaeImageProcessor -class VideoProcessor: - """Simple video processor.""" +class VideoProcessor(VaeImageProcessor): + r"""Simple video processor.""" - @staticmethod - def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): + def tensor2vid(self, video: torch.FloatTensor, output_type: str = "np"): """Converts a video tensor to a list of frames for export.""" - batch_size, channels, num_frames, height, width = video.shape + batch_size = video.shape[0] outputs = [] for batch_idx in range(batch_size): batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) + batch_output = self.postprocess(batch_vid, output_type) outputs.append(batch_output) if output_type == "np": @@ -43,11 +42,11 @@ def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: return outputs - @staticmethod - def preprocess_video(video: List[Union[PIL.Image.Image, np.ndarray, torch.Tensor]]): - """Preprocesses an input video.""" + def preprocess_video(self, video: List[Union[PIL.Image.Image, np.ndarray, torch.Tensor]]) -> torch.FloatTensor: + """Preprocesses input video(s).""" supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) + # Single-frame video. if isinstance(video, supported_formats): video = [video] elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): @@ -55,21 +54,27 @@ def preprocess_video(video: List[Union[PIL.Image.Image, np.ndarray, torch.Tensor f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" ) + # In case the video is a list of PIL images, convert to a list of ndarrays. if isinstance(video[0], PIL.Image.Image): video = [np.array(frame) for frame in video] if isinstance(video[0], np.ndarray): + # When the number of dimension of the first element in `video` is 5, it means + # each element in the `video` list is a video. video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) - # Notes from (sayakpaul): do we want to follow something similar to VaeImageProcessor here i.e., - # have methods `normalize()` and `denormalize()`? if video.dtype == np.uint8: + if video.min() >= 0 and video.max() <= 255: + raise ValueError( + f"The inputs don't have the correct value range for the determined data-type ({video.dtype}): {video.min()=}, {video.max()=}" + ) + # We perform the scaling step here so that `preprocess()` can handle things correctly for us. video = np.array(video).astype(np.float32) / 255.0 if video.ndim == 4: video = video[None, ...] - video = torch.from_numpy(video.transpose(0, 4, 1, 2, 3)) + video = video.permute(0, 4, 1, 2, 3) elif isinstance(video[0], torch.Tensor): video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) @@ -82,7 +87,7 @@ def preprocess_video(video: List[Union[PIL.Image.Image, np.ndarray, torch.Tensor # move channels before num_frames video = video.permute(0, 2, 1, 3, 4) - # normalize video - video = 2.0 * video - 1.0 + # `preprocess()` here would return a PT tensor. + video = torch.stack([self.preprocess(f) for f in video], dim=0) return video From b00e3abaa53e0f949856ab49bad4d4f9b3c2dba7 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 27 Apr 2024 09:07:28 +0530 Subject: [PATCH 04/27] fix preprocess_video call. --- .../pipeline_text_to_video_synth_img2img.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 76e040333232..d517c623fea0 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -623,7 +623,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video - video = VideoProcessor.preprocess_video(video) + video = self.video_processor.preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) From 12515c44de58fd1ea09ef46a6d480f6401934ff1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 27 Apr 2024 09:23:21 +0530 Subject: [PATCH 05/27] video_processor -> image_processor --- src/diffusers/pipelines/animatediff/pipeline_animatediff.py | 4 ++-- .../animatediff/pipeline_animatediff_video2video.py | 4 ++-- src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py | 4 ++-- src/diffusers/pipelines/pia/pipeline_pia.py | 4 ++-- .../pipeline_stable_video_diffusion.py | 4 ++-- .../text_to_video_synthesis/pipeline_text_to_video_synth.py | 4 ++-- .../pipeline_text_to_video_synth_img2img.py | 6 +++--- 7 files changed, 15 insertions(+), 15 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 327121bcfdff..c401311d442d 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -138,7 +138,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -815,7 +815,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 93de9c794e6e..7247e4b1ff87 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -227,7 +227,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -975,7 +975,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index 1df76535c6a3..a476343e8f88 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -135,7 +135,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # `do_resize=False` as we do custom resizing. - self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + self.image_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) @property def guidance_scale(self): @@ -716,7 +716,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index bb4f666c2af8..35728ea5259c 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -196,7 +196,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.video_processor = VaeImageProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VaeImageProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -937,7 +937,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index d83c40b4c966..69fdb0a22fe2 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -118,7 +118,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) def _encode_image( self, @@ -541,7 +541,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=torch.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = self.video_processor.tensor2vid(video=frames, output_type=output_type) + frames = self.image_processor.tensor2vid(video=frames, output_type=output_type) else: frames = latents diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index bee0bb17bfdf..6d7a292523d4 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -104,7 +104,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -629,7 +629,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index d517c623fea0..2c38428b5d5f 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -139,7 +139,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -623,7 +623,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video - video = self.video_processor.preprocess_video(video) + video = self.image_processor.preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -685,7 +685,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() From 6f181046f1acc47b1f101289155d6411157e4d30 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 27 Apr 2024 09:36:37 +0530 Subject: [PATCH 06/27] fix --- src/diffusers/pipelines/pia/pipeline_pia.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 35728ea5259c..35de5d3cbe80 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -21,7 +21,8 @@ import torch from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput +from ...video_processor import VideoProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -196,7 +197,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( From 2208a06775ac01646730b97a23b27435d8a274ef Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 27 Apr 2024 09:38:12 +0530 Subject: [PATCH 07/27] fix more. --- .../pipelines/animatediff/pipeline_animatediff_video2video.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 7247e4b1ff87..7f3fdb2fc0ba 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -227,7 +227,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.image_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( From 778577d227837258185517ca536bea0dca089608 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 27 Apr 2024 09:38:29 +0530 Subject: [PATCH 08/27] quality --- src/diffusers/pipelines/pia/pipeline_pia.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 35de5d3cbe80..9e59e8120c61 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -22,7 +22,6 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection from ...image_processor import PipelineImageInput -from ...video_processor import VideoProcessor from ...loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel from ...models.lora import adjust_lora_scale_text_encoder @@ -44,6 +43,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin From cb8f13852de20bc4554c5e71b838fd39c3ecaefa Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 29 Apr 2024 16:44:15 +0530 Subject: [PATCH 09/27] image_processor -> video_processor --- .../pipelines/animatediff/pipeline_animatediff.py | 4 ++-- .../animatediff/pipeline_animatediff_video2video.py | 6 +++--- .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py | 10 +++++----- src/diffusers/pipelines/pia/pipeline_pia.py | 6 +++--- .../pipeline_text_to_video_synth.py | 4 ++-- .../pipeline_text_to_video_synth_img2img.py | 6 +++--- 6 files changed, 18 insertions(+), 18 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index c401311d442d..327121bcfdff 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -138,7 +138,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -815,7 +815,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 7f3fdb2fc0ba..696384508471 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -227,7 +227,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -620,7 +620,7 @@ def prepare_latents( video = [video] if latents is None: video = torch.cat( - [self.image_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0 + [self.video_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0 ) video = video.to(device=device, dtype=dtype) num_frames = video.shape[1] @@ -975,7 +975,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index a476343e8f88..de78003cf26c 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -135,7 +135,7 @@ def __init__( ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) # `do_resize=False` as we do custom resizing. - self.image_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor, do_resize=False) @property def guidance_scale(self): @@ -321,8 +321,8 @@ def _encode_image(self, image, device, num_videos_per_prompt): dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): - image = self.image_processor.pil_to_numpy(image) - image = self.image_processor.numpy_to_pt(image) + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) # Normalize the image with CLIP training stats. image = self.feature_extractor( @@ -636,7 +636,7 @@ def __call__( # 3.2.2 Image latents. resized_image = _center_crop_wide(image, (width, height)) - image = self.image_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) + image = self.video_processor.preprocess(resized_image).to(device=device, dtype=image_embeddings.dtype) image_latents = self.prepare_image_latents( image, device=device, @@ -716,7 +716,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index 9e59e8120c61..a8a3e441190f 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -197,7 +197,7 @@ def __init__( image_encoder=image_encoder, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt def encode_prompt( @@ -600,7 +600,7 @@ def prepare_masked_condition( ) _, _, _, scaled_height, scaled_width = shape - image = self.image_processor.preprocess(image) + image = self.video_processor.preprocess(image) image = image.to(device, dtype) if isinstance(generator, list): @@ -938,7 +938,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index 6d7a292523d4..bee0bb17bfdf 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -104,7 +104,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -629,7 +629,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index 2c38428b5d5f..d517c623fea0 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -139,7 +139,7 @@ def __init__( scheduler=scheduler, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt def _encode_prompt( @@ -623,7 +623,7 @@ def __call__( prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) # 4. Preprocess video - video = self.image_processor.preprocess_video(video) + video = self.video_processor.preprocess_video(video) # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) @@ -685,7 +685,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.image_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() From 7561c1e3a15821c8bb930d13e7392b3a743475d2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 30 Apr 2024 09:15:19 +0530 Subject: [PATCH 10/27] support List[List[PIL.Image.Image]] --- src/diffusers/video_processor.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 27ffe2797b90..54750b53595f 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -54,6 +54,19 @@ def preprocess_video(self, video: List[Union[PIL.Image.Image, np.ndarray, torch. f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" ) + # In case the video a list of list of PIL images. + if isinstance(video, list) and isinstance(video[0], list) and isinstance(video[0][0], PIL.Image.Image): + video_ = [] + first_video_length = 0 + for i, vid in enumerate(video): + current_video = [np.array(frame) for frame in vid] + if i == 0: + first_video_length = len(current_video) + if len(current_video) != first_video_length: + raise ValueError("Cannot batch together videos of different lengths.") + video_.append(current_video) + video = np.stack(video_, axis=0) + # In case the video is a list of PIL images, convert to a list of ndarrays. if isinstance(video[0], PIL.Image.Image): video = [np.array(frame) for frame in video] From fd4bb8ad9e231f3cc6b3809cceaac5bfc11a6b0f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 30 Apr 2024 09:53:14 +0530 Subject: [PATCH 11/27] change to video_processor. --- .../pipeline_stable_video_diffusion.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 97ff92d64759..2bec127fbe44 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -118,7 +118,7 @@ def __init__( feature_extractor=feature_extractor, ) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor) def _encode_image( self, @@ -130,8 +130,8 @@ def _encode_image( dtype = next(self.image_encoder.parameters()).dtype if not isinstance(image, torch.Tensor): - image = self.image_processor.pil_to_numpy(image) - image = self.image_processor.numpy_to_pt(image) + image = self.video_processor.pil_to_numpy(image) + image = self.video_processor.numpy_to_pt(image) # We normalize the image before resizing to match with the original implementation. # Then we unnormalize it after resizing. @@ -434,7 +434,7 @@ def __call__( fps = fps - 1 # 4. Encode input image using VAE - image = self.image_processor.preprocess(image, height=height, width=width).to(device) + image = self.video_processor.preprocess(image, height=height, width=width).to(device) noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype) image = image + noise_aug_strength * noise @@ -541,7 +541,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=torch.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = self.image_processor.tensor2vid(video=frames, output_type=output_type) + frames = self.video_processor.tensor2vid(video=frames, output_type=output_type) else: frames = latents From 96fd13f9cb1dd13e22af72486ebb43361439981f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 30 Apr 2024 10:04:48 +0530 Subject: [PATCH 12/27] documentation --- src/diffusers/video_processor.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 54750b53595f..adb236581a9f 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -24,8 +24,16 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def tensor2vid(self, video: torch.FloatTensor, output_type: str = "np"): - """Converts a video tensor to a list of frames for export.""" + def tensor2vid( + self, video: torch.FloatTensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.FloatTensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.FloatTensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ batch_size = video.shape[0] outputs = [] for batch_idx in range(batch_size): @@ -42,8 +50,21 @@ def tensor2vid(self, video: torch.FloatTensor, output_type: str = "np"): return outputs - def preprocess_video(self, video: List[Union[PIL.Image.Image, np.ndarray, torch.Tensor]]) -> torch.FloatTensor: - """Preprocesses input video(s).""" + def preprocess_video(self, video) -> torch.FloatTensor: + r""" + Preprocesses input video(s). + + Args: + video: The input video. It can be one of the following: + * List of the PIL images. + * List of list of PIL images. + * List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)). + * List of list of 4D Torch tensors (expected shape for tensor: (num_frames, num_channels, height, + width)). + * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). + * List of list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, + num_channels)). + """ supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) # Single-frame video. From a70fd89e839f6279579a0043b49a7f7d7207a47c Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Wed, 1 May 2024 12:18:33 +0530 Subject: [PATCH 13/27] Apply suggestions from code review --- src/diffusers/video_processor.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index adb236581a9f..1a20bd125065 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -98,9 +98,9 @@ def preprocess_video(self, video) -> torch.FloatTensor: video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) if video.dtype == np.uint8: - if video.min() >= 0 and video.max() <= 255: + if video.min() < 0: raise ValueError( - f"The inputs don't have the correct value range for the determined data-type ({video.dtype}): {video.min()=}, {video.max()=}" + f"The inputs don't have the correct value range for the determined data-type ({video.dtype}): {video.min()=}." ) # We perform the scaling step here so that `preprocess()` can handle things correctly for us. video = np.array(video).astype(np.float32) / 255.0 From 90c0dca5cca2397b77b6b85b0ddca20414a8909b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 3 May 2024 16:43:46 +0530 Subject: [PATCH 14/27] changes --- src/diffusers/video_processor.py | 82 ++++++++----- tests/others/test_video_processor.py | 174 +++++++++++++++++++++++++++ 2 files changed, 223 insertions(+), 33 deletions(-) create mode 100644 tests/others/test_video_processor.py diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index 1a20bd125065..ece646eccf6d 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -64,64 +64,80 @@ def preprocess_video(self, video) -> torch.FloatTensor: * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). * List of list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). + * List of 5D NumPy arrays (expected shape for each array: (batch_size, num_frames, height, width, + num_channels). + * List of 5D Torch tensors (expected shape for each array: (batch_size, num_frames, num_channels, + height, width). + * 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width, + num_channels). + * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height, + width). """ - supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image) + supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image, list) # Single-frame video. - if isinstance(video, supported_formats): + if isinstance(video, supported_formats[:-1]): video = [video] + + # List of PIL images. + elif isinstance(video, list) and isinstance(video[0], PIL.Image.Image): + video = [video] + elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): raise ValueError( - f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(supported_formats)}" + f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(list(map(str, supported_formats)))}" ) - # In case the video a list of list of PIL images. - if isinstance(video, list) and isinstance(video[0], list) and isinstance(video[0][0], PIL.Image.Image): - video_ = [] - first_video_length = 0 - for i, vid in enumerate(video): - current_video = [np.array(frame) for frame in vid] - if i == 0: - first_video_length = len(current_video) - if len(current_video) != first_video_length: - raise ValueError("Cannot batch together videos of different lengths.") - video_.append(current_video) - video = np.stack(video_, axis=0) - - # In case the video is a list of PIL images, convert to a list of ndarrays. - if isinstance(video[0], PIL.Image.Image): - video = [np.array(frame) for frame in video] - if isinstance(video[0], np.ndarray): # When the number of dimension of the first element in `video` is 5, it means # each element in the `video` list is a video. video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) - if video.dtype == np.uint8: - if video.min() < 0: - raise ValueError( - f"The inputs don't have the correct value range for the determined data-type ({video.dtype}): {video.min()=}." - ) - # We perform the scaling step here so that `preprocess()` can handle things correctly for us. - video = np.array(video).astype(np.float32) / 255.0 - if video.ndim == 4: video = video[None, ...] - video = video.permute(0, 4, 1, 2, 3) - elif isinstance(video[0], torch.Tensor): - video = torch.cat(video, axis=0) if video[0].ndim == 5 else torch.stack(video, axis=0) + video = torch.cat(video, dim=0) if video[0].ndim == 5 else torch.stack(video, dim=0) # don't need any preprocess if the video is latents channel = video.shape[1] if channel == 4: return video - # move channels before num_frames - video = video.permute(0, 2, 1, 3, 4) + # List of 5d tensors/ndarrays. + elif isinstance(video[0], list): + if isinstance(video[0][0], (np.ndarray, torch.Tensor)): + all_frames = [] + for list_of_videos in video: + temp_frames = [] + for vid in list_of_videos: + if vid.ndim == 4: + current_vid_frames = np.stack(vid, axis=0) if isinstance(vid, np.ndarray) else vid + elif vid.ndim == 5: + current_vid_frames = ( + np.concatenate(vid, axis=0) if isinstance(vid, np.ndarray) else torch.cat(vid, dim=0) + ) + temp_frames.append(current_vid_frames) + + # Process inner list. + temp_frames = ( + np.stack(temp_frames, axis=0) + if isinstance(temp_frames[0], np.ndarray) + else torch.stack(temp_frames, axis=0) + ) + all_frames.append(temp_frames) + + # Process outer list. + video = ( + np.concatenate(all_frames, axis=0) + if isinstance(all_frames[0], np.ndarray) + else torch.cat(all_frames, dim=0) + ) # `preprocess()` here would return a PT tensor. video = torch.stack([self.preprocess(f) for f in video], dim=0) + # move channels before num_frames + video = video.permute(0, 2, 1, 3, 4) + return video diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py new file mode 100644 index 000000000000..4ea6259545e7 --- /dev/null +++ b/tests/others/test_video_processor.py @@ -0,0 +1,174 @@ +# coding=utf-8 +# Copyright 2024 HuggingFace Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np +import PIL.Image +import torch +from parameterized import parameterized + +from diffusers.video_processor import VideoProcessor + + +np.random.seed(0) +torch.manual_seed(0) + + +class VideoProcessorTest(unittest.TestCase): + def get_dummy_sample(self, input_type): + batch_size = 1 + num_frames = 5 + num_channels = 3 + height = 8 + width = 8 + + def generate_image(): + return PIL.Image.fromarray(np.random.randint(0, 256, size=(height, width, num_channels)).astype("uint8")) + + def generate_4d_array(): + return np.random.rand(num_frames, height, width, num_channels) + + def generate_5d_array(): + return np.random.rand(batch_size, num_frames, height, width, num_channels) + + def generate_4d_tensor(): + return torch.rand(num_frames, num_channels, height, width) + + def generate_5d_tensor(): + return torch.rand(batch_size, num_frames, num_channels, height, width) + + if input_type == "list_images": + sample = [generate_image() for _ in range(num_frames)] + elif input_type == "list_list_images": + sample = [[generate_image() for _ in range(num_frames)] for _ in range(num_frames)] + elif input_type == "list_4d_np": + sample = [generate_4d_array() for _ in range(num_frames)] + elif input_type == "list_list_4d_np": + sample = [[generate_4d_array() for _ in range(num_frames)] for _ in range(num_frames)] + elif input_type == "list_5d_np": + sample = [generate_5d_array() for _ in range(num_frames)] + elif input_type == "5d_np": + sample = generate_5d_array() + elif input_type == "list_4d_pt": + sample = [generate_4d_tensor() for _ in range(num_frames)] + elif input_type == "list_list_4d_pt": + sample = [[generate_4d_tensor() for _ in range(num_frames)] for _ in range(num_frames)] + elif input_type == "list_5d_pt": + sample = [generate_5d_tensor() for _ in range(num_frames)] + elif input_type == "5d_pt": + sample = generate_5d_tensor() + + return sample + + def to_np(self, video): + # List of images. + if isinstance(video[0], PIL.Image.Image): + video = np.stack([np.array(i) for i in video], axis=0) + + # List of list of images. + elif isinstance(video, list) and isinstance(video[0][0], PIL.Image.Image): + frames = [] + for vid in video: + all_current_frames = np.stack([np.array(i) for i in vid], axis=0) + frames.append(all_current_frames) + video = np.stack([np.array(frame) for frame in frames], axis=0) + + # List of 4d/5d {ndarrays, torch tensors}. + elif isinstance(video, list) and isinstance(video[0], (torch.Tensor, np.ndarray)): + if isinstance(video[0], np.ndarray): + video = np.stack(video, axis=0) if video[0].ndim == 4 else np.concatenate(video, axis=0) + else: + if video[0].ndim == 4: + video = np.stack([i.cpu().numpy().transpose(0, 2, 3, 1) for i in video], axis=0) + elif video[0].ndim == 5: + video = np.concatenate([i.cpu().numpy().transpose(0, 1, 3, 4, 2) for i in video], axis=0) + + # List of list of 4d/5d {ndarrays, torch tensors}. + elif ( + isinstance(video, list) + and isinstance(video[0], list) + and isinstance(video[0][0], (torch.Tensor, np.ndarray)) + ): + all_frames = [] + for list_of_videos in video: + temp_frames = [] + for vid in list_of_videos: + if vid.ndim == 4: + current_vid_frames = np.stack( + [i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(1, 2, 0) for i in vid], + axis=0, + ) + elif vid.ndim == 5: + current_vid_frames = np.concatenate( + [i if isinstance(i, np.ndarray) else i.cpu().numpy().transpose(0, 2, 3, 1) for i in vid], + axis=0, + ) + temp_frames.append(current_vid_frames) + temp_frames = np.stack(temp_frames, axis=0) + all_frames.append(temp_frames) + + video = np.concatenate(all_frames, axis=0) + + # Just 5d {ndarrays, torch tensors}. + elif isinstance(video, (torch.Tensor, np.ndarray)) and video.ndim == 5: + video = video if isinstance(video, np.ndarray) else video.cpu().numpy().transpose(0, 1, 3, 4, 2) + + return video + + @parameterized.expand(["list_images", "list_list_images"]) + def test_video_processor_pil(self, input_type): + video_processor = VideoProcessor(do_resize=False, do_normalize=True) + + input = self.get_dummy_sample(input_type=input_type) + + for output_type in ["pt", "np", "pil"]: + out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) + out_np = self.to_np(out) + input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input) + assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" + + @parameterized.expand(["list_4d_np", "list_list_4d_np", "list_5d_np", "5d_np"]) + def test_video_processor_np(self, input_type): + video_processor = VideoProcessor(do_resize=False, do_normalize=True) + + input = self.get_dummy_sample(input_type=input_type) + + for output_type in ["pt", "np", "pil"]: + out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) + out_np = self.to_np(out) + input_np = ( + (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) + ) + assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" + + @parameterized.expand(["list_4d_pt", "list_list_4d_pt", "list_5d_pt", "5d_pt"]) + def test_video_processor_pt(self, input_type): + video_processor = VideoProcessor(do_resize=False, do_normalize=True) + + input = self.get_dummy_sample(input_type=input_type) + + for output_type in ["pt", "np", "pil"]: + if input_type == "list_list_4d_pt": + print(input[0][0].ndim) + out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) + out_np = self.to_np(out) + input_np = ( + (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) + ) + print(f"{input_np.max()=}, {input_np.min()=}, {out_np.max()=}, {out_np.min()=}") + print(input_np[0, :3, :3, -1].flatten()) + print(out_np[0, :3, :3, -1].flatten()) + assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" From 90ff0dab2bfae16b8cccd8dfdc5353ff54b2fe0e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 4 May 2024 07:52:25 +0530 Subject: [PATCH 15/27] remove print. --- tests/others/test_video_processor.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py index 4ea6259545e7..71524b35904b 100644 --- a/tests/others/test_video_processor.py +++ b/tests/others/test_video_processor.py @@ -161,14 +161,9 @@ def test_video_processor_pt(self, input_type): input = self.get_dummy_sample(input_type=input_type) for output_type in ["pt", "np", "pil"]: - if input_type == "list_list_4d_pt": - print(input[0][0].ndim) out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) out_np = self.to_np(out) input_np = ( (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) ) - print(f"{input_np.max()=}, {input_np.min()=}, {out_np.max()=}, {out_np.min()=}") - print(input_np[0, :3, :3, -1].flatten()) - print(out_np[0, :3, :3, -1].flatten()) assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" From e2f61d593c508bda63d908e09c08f2757de7fa01 Mon Sep 17 00:00:00 2001 From: YiYi Xu Date: Mon, 6 May 2024 18:51:37 -1000 Subject: [PATCH 16/27] refactor video processor (part # 7776) (#7861) * update * update remove deprecate * Update src/diffusers/video_processor.py * update * Apply suggestions from code review * deprecate list of 5d for video and list of 4d for image + apply other feedbacks * up --------- Co-authored-by: Sayak Paul --- src/diffusers/image_processor.py | 77 ++++++++++++++----- src/diffusers/video_processor.py | 108 +++++++++------------------ tests/others/test_video_processor.py | 4 +- 3 files changed, 94 insertions(+), 95 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 4ccb9d77d627..027691ad9f2f 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -29,15 +29,34 @@ PipelineImageInput = Union[ PIL.Image.Image, np.ndarray, - torch.FloatTensor, + torch.Tensor, List[PIL.Image.Image], List[np.ndarray], - List[torch.FloatTensor], + List[torch.Tensor], ] PipelineDepthInput = PipelineImageInput +def is_valid_image(image): + return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3) + + +def is_valid_image_imagelist(images): + # check if the image input is one of the supported formats for image and image list: + # it can be either one of below 3 + # (1) a 4d pytorch tensor or numpy array, + # (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor + # (3) a list of valid image + if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4: + return True + elif is_valid_image(images): + return True + elif isinstance(images, list): + return all(is_valid_image(image) for image in images) + return False + + class VaeImageProcessor(ConfigMixin): """ Image processor for VAE. @@ -111,7 +130,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd return images @staticmethod - def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: + def numpy_to_pt(images: np.ndarray) -> torch.Tensor: """ Convert a NumPy image to a PyTorch tensor. """ @@ -122,7 +141,7 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor: return images @staticmethod - def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray: + def pt_to_numpy(images: torch.Tensor) -> np.ndarray: """ Convert a PyTorch tensor to a NumPy image. """ @@ -498,9 +517,29 @@ def preprocess( else: image = np.expand_dims(image, axis=-1) - if isinstance(image, supported_formats): + if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray", + FutureWarning, + ) + image = np.concatenate(image, axis=0) + if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4: + warnings.warn( + "Passing `image` as a list of 4d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor", + FutureWarning, + ) + image = torch.cat(image, axis=0) + + # ensure the input is a list of images: + # - if it is a batch of images (4d torch.Tensor or np.ndarray), it is converted to a list of images (a list of 3d torch.Tensor or np.ndarray) + # - if it is a single image, it is converted to a list of one image + if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4: + image = list(image) + if is_valid_image(image): image = [image] - elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)): + if not all(is_valid_image(img) for img in image): raise ValueError( f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" ) @@ -562,15 +601,15 @@ def preprocess( def postprocess( self, - image: torch.FloatTensor, + image: torch.Tensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, - ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Postprocess the image output from tensor to `output_type`. Args: - image (`torch.FloatTensor`): + image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. @@ -579,7 +618,7 @@ def postprocess( `VaeImageProcessor` config. Returns: - `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The postprocessed image. """ if not isinstance(image, torch.Tensor): @@ -739,15 +778,15 @@ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]: def postprocess( self, - image: torch.FloatTensor, + image: torch.Tensor, output_type: str = "pil", do_denormalize: Optional[List[bool]] = None, - ) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]: + ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]: """ Postprocess the image output from tensor to `output_type`. Args: - image (`torch.FloatTensor`): + image (`torch.Tensor`): The image input, should be a pytorch tensor with shape `B x C x H x W`. output_type (`str`, *optional*, defaults to `pil`): The output type of the image, can be one of `pil`, `np`, `pt`, `latent`. @@ -756,7 +795,7 @@ def postprocess( `VaeImageProcessor` config. Returns: - `PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`: + `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`: The postprocessed image. """ if not isinstance(image, torch.Tensor): @@ -794,8 +833,8 @@ def postprocess( def preprocess( self, - rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], - depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray], + rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray], + depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray], height: Optional[int] = None, width: Optional[int] = None, target_res: Optional[int] = None, @@ -934,13 +973,13 @@ def __init__( ) @staticmethod - def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int): + def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int): """ Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued. Args: - mask (`torch.FloatTensor`): + mask (`torch.Tensor`): The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`. batch_size (`int`): The batch size. @@ -950,7 +989,7 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value The dimensionality of the value embeddings. Returns: - `torch.FloatTensor`: + `torch.Tensor`: The downsampled mask tensor. """ diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index ece646eccf6d..c03736c6398f 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -12,26 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. +import warnings from typing import List, Union import numpy as np import PIL import torch -from .image_processor import VaeImageProcessor +from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" def tensor2vid( - self, video: torch.FloatTensor, output_type: str = "np" - ) -> Union[np.ndarray, torch.FloatTensor, List[PIL.Image.Image]]: + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: r""" Converts a video tensor to a list of frames for export. Args: - video (`torch.FloatTensor`): The video as a tensor. + video (`torch.Tensor`): The video as a tensor. output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. """ batch_size = video.shape[0] @@ -50,7 +51,7 @@ def tensor2vid( return outputs - def preprocess_video(self, video) -> torch.FloatTensor: + def preprocess_video(self, video) -> torch.Tensor: r""" Preprocesses input video(s). @@ -58,86 +59,45 @@ def preprocess_video(self, video) -> torch.FloatTensor: video: The input video. It can be one of the following: * List of the PIL images. * List of list of PIL images. + * 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)). + * 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). * List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)). - * List of list of 4D Torch tensors (expected shape for tensor: (num_frames, num_channels, height, - width)). * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)). - * List of list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, - num_channels)). - * List of 5D NumPy arrays (expected shape for each array: (batch_size, num_frames, height, width, - num_channels). - * List of 5D Torch tensors (expected shape for each array: (batch_size, num_frames, num_channels, - height, width). * 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width, num_channels). * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height, width). """ - supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image, list) - - # Single-frame video. - if isinstance(video, supported_formats[:-1]): - video = [video] - - # List of PIL images. - elif isinstance(video, list) and isinstance(video[0], PIL.Image.Image): + if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d np.ndarray is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray", + FutureWarning, + ) + video = np.concatenate(video, axis=0) + if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5: + warnings.warn( + "Passing `video` as a list of 5d torch.Tensor is deprecated." + "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor", + FutureWarning, + ) + video = torch.cat(video, axis=0) + + # ensure the input is a list of videos: + # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray) + # - if it is is a single video, it is convereted to a list of one video. + if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5: + video = list(video) + elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video): video = [video] - - elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)): + elif isinstance(video, list) and is_valid_image_imagelist(video[0]): + video = video + else: raise ValueError( - f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(list(map(str, supported_formats)))}" + "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - if isinstance(video[0], np.ndarray): - # When the number of dimension of the first element in `video` is 5, it means - # each element in the `video` list is a video. - video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0) - - if video.ndim == 4: - video = video[None, ...] - - elif isinstance(video[0], torch.Tensor): - video = torch.cat(video, dim=0) if video[0].ndim == 5 else torch.stack(video, dim=0) - - # don't need any preprocess if the video is latents - channel = video.shape[1] - if channel == 4: - return video - - # List of 5d tensors/ndarrays. - elif isinstance(video[0], list): - if isinstance(video[0][0], (np.ndarray, torch.Tensor)): - all_frames = [] - for list_of_videos in video: - temp_frames = [] - for vid in list_of_videos: - if vid.ndim == 4: - current_vid_frames = np.stack(vid, axis=0) if isinstance(vid, np.ndarray) else vid - elif vid.ndim == 5: - current_vid_frames = ( - np.concatenate(vid, axis=0) if isinstance(vid, np.ndarray) else torch.cat(vid, dim=0) - ) - temp_frames.append(current_vid_frames) - - # Process inner list. - temp_frames = ( - np.stack(temp_frames, axis=0) - if isinstance(temp_frames[0], np.ndarray) - else torch.stack(temp_frames, axis=0) - ) - all_frames.append(temp_frames) - - # Process outer list. - video = ( - np.concatenate(all_frames, axis=0) - if isinstance(all_frames[0], np.ndarray) - else torch.cat(all_frames, dim=0) - ) - - # `preprocess()` here would return a PT tensor. - video = torch.stack([self.preprocess(f) for f in video], dim=0) - - # move channels before num_frames + video = torch.stack([self.preprocess(img) for img in video], dim=0) video = video.permute(0, 2, 1, 3, 4) return video diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py index 71524b35904b..40f024fc9b2b 100644 --- a/tests/others/test_video_processor.py +++ b/tests/others/test_video_processor.py @@ -140,7 +140,7 @@ def test_video_processor_pil(self, input_type): input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input) assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" - @parameterized.expand(["list_4d_np", "list_list_4d_np", "list_5d_np", "5d_np"]) + @parameterized.expand(["list_4d_np", "list_5d_np", "5d_np"]) def test_video_processor_np(self, input_type): video_processor = VideoProcessor(do_resize=False, do_normalize=True) @@ -154,7 +154,7 @@ def test_video_processor_np(self, input_type): ) assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" - @parameterized.expand(["list_4d_pt", "list_list_4d_pt", "list_5d_pt", "5d_pt"]) + @parameterized.expand(["list_4d_pt", "list_5d_pt", "5d_pt"]) def test_video_processor_pt(self, input_type): video_processor = VideoProcessor(do_resize=False, do_normalize=True) From c5ccf5f5ee1d742b61572903e7b5a94982920d4e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 06:57:32 +0200 Subject: [PATCH 17/27] add doc. --- docs/source/en/_toctree.yml | 2 ++ docs/source/en/api/video_processor.md | 21 +++++++++++++++++++++ 2 files changed, 23 insertions(+) create mode 100644 docs/source/en/api/video_processor.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index 1c21d4cd9f74..d3a6db727b82 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -439,6 +439,8 @@ title: Utilities - local: api/image_processor title: VAE Image Processor + - local: api/video_processor + title: Video Processor title: Internal classes isExpanded: false title: API diff --git a/docs/source/en/api/video_processor.md b/docs/source/en/api/video_processor.md new file mode 100644 index 000000000000..b9ae67f88dd9 --- /dev/null +++ b/docs/source/en/api/video_processor.md @@ -0,0 +1,21 @@ + + +# Video Processor + +The [`VideoProcessor`] provides a unified API for [`StableDiffusionPipeline`]s to prepare video inputs for VAE encoding and post-processing outputs once they're decoded. It inherits the [`VaeImageProcessor`] class, which already includes transformations such as resizing, normalization, and conversion between PIL Image, PyTorch, and NumPy arrays. + +Some pipelines such as [`VideoToVideoSDPipeline`] with [`VideoProcessor`] accept videos as a list of PIL Images, PyTorch tensors, or NumPy arrays as video inputs and return outputs based on the `output_type` argument by the user. You can pass encoded image latents directly to the pipeline and return latents from the pipeline as a specific output with the `output_type` argument (for example `output_type="latent"`). This allows you to take the generated latents from one pipeline and pass it to another pipeline as input without leaving the latent space. It also makes it much easier to use multiple pipelines together by passing PyTorch tensors directly between different pipelines. + +## VideoProcessor + +[[autodoc]] video_processor.VideoProcessor From 00cb9c12d508dbbdeae3fcf4a5d87b7bc7486cac Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 07:06:58 +0200 Subject: [PATCH 18/27] tensor2vid -> postprocess_video. --- .../animatediff/pipeline_animatediff.py | 2 +- .../pipeline_animatediff_video2video.py | 2 +- .../pipelines/i2vgen_xl/pipeline_i2vgen_xl.py | 2 +- src/diffusers/pipelines/pia/pipeline_pia.py | 2 +- .../pipeline_stable_video_diffusion.py | 2 +- .../pipeline_text_to_video_synth.py | 2 +- .../pipeline_text_to_video_synth_img2img.py | 2 +- src/diffusers/video_processor.py | 52 +++++++++---------- tests/others/test_video_processor.py | 6 +-- 9 files changed, 36 insertions(+), 36 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index 327121bcfdff..94654c4a7e17 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -815,7 +815,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index 696384508471..d80ac9fee089 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -975,7 +975,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py index de78003cf26c..a38918e1a0f9 100644 --- a/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +++ b/src/diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py @@ -716,7 +716,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents, decode_chunk_size=decode_chunk_size) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py index a8a3e441190f..a2723187f2db 100644 --- a/src/diffusers/pipelines/pia/pipeline_pia.py +++ b/src/diffusers/pipelines/pia/pipeline_pia.py @@ -938,7 +938,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index 2bec127fbe44..3121614b0002 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -541,7 +541,7 @@ def __call__( if needs_upcasting: self.vae.to(dtype=torch.float16) frames = self.decode_latents(latents, num_frames, decode_chunk_size) - frames = self.video_processor.tensor2vid(video=frames, output_type=output_type) + frames = self.video_processor.postprocess_video(video=frames, output_type=output_type) else: frames = latents diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py index bee0bb17bfdf..0ef769f32a15 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py @@ -629,7 +629,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 9. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py index d517c623fea0..0dc1ca93f873 100644 --- a/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +++ b/src/diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py @@ -685,7 +685,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = self.video_processor.tensor2vid(video=video_tensor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # 10. Offload all models self.maybe_free_model_hooks() diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index c03736c6398f..ec42f74332ba 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -25,32 +25,6 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def tensor2vid( - self, video: torch.Tensor, output_type: str = "np" - ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: - r""" - Converts a video tensor to a list of frames for export. - - Args: - video (`torch.Tensor`): The video as a tensor. - output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. - """ - batch_size = video.shape[0] - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = self.postprocess(batch_vid, output_type) - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - elif output_type == "pt": - outputs = torch.stack(outputs) - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - def preprocess_video(self, video) -> torch.Tensor: r""" Preprocesses input video(s). @@ -101,3 +75,29 @@ def preprocess_video(self, video) -> torch.Tensor: video = video.permute(0, 2, 1, 3, 4) return video + + def postprocess_video( + self, video: torch.Tensor, output_type: str = "np" + ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]: + r""" + Converts a video tensor to a list of frames for export. + + Args: + video (`torch.Tensor`): The video as a tensor. + output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor. + """ + batch_size = video.shape[0] + outputs = [] + for batch_idx in range(batch_size): + batch_vid = video[batch_idx].permute(1, 0, 2, 3) + batch_output = self.postprocess(batch_vid, output_type) + outputs.append(batch_output) + + if output_type == "np": + outputs = np.stack(outputs) + elif output_type == "pt": + outputs = torch.stack(outputs) + elif not output_type == "pil": + raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") + + return outputs diff --git a/tests/others/test_video_processor.py b/tests/others/test_video_processor.py index 40f024fc9b2b..a2fc87717f9d 100644 --- a/tests/others/test_video_processor.py +++ b/tests/others/test_video_processor.py @@ -135,7 +135,7 @@ def test_video_processor_pil(self, input_type): input = self.get_dummy_sample(input_type=input_type) for output_type in ["pt", "np", "pil"]: - out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) + out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type) out_np = self.to_np(out) input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input) assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}" @@ -147,7 +147,7 @@ def test_video_processor_np(self, input_type): input = self.get_dummy_sample(input_type=input_type) for output_type in ["pt", "np", "pil"]: - out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) + out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type) out_np = self.to_np(out) input_np = ( (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) @@ -161,7 +161,7 @@ def test_video_processor_pt(self, input_type): input = self.get_dummy_sample(input_type=input_type) for output_type in ["pt", "np", "pil"]: - out = video_processor.tensor2vid(video_processor.preprocess_video(input), output_type=output_type) + out = video_processor.postprocess_video(video_processor.preprocess_video(input), output_type=output_type) out_np = self.to_np(out) input_np = ( (self.to_np(input) * 255.0).round().astype("uint8") if output_type == "pil" else self.to_np(input) From 814aa79f4466bc137fbb706af28f45704b5b6e9e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 07:27:29 +0200 Subject: [PATCH 19/27] refactor preprocess with preprocess_video --- .../pipeline_animatediff_video2video.py | 4 +--- src/diffusers/video_processor.py | 21 +++++++++++++++---- 2 files changed, 18 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index d80ac9fee089..e6e750c7cbe2 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -619,9 +619,7 @@ def prepare_latents( if video and not isinstance(video[0], list): video = [video] if latents is None: - video = torch.cat( - [self.video_processor.preprocess(vid, height=height, width=width).unsqueeze(0) for vid in video], dim=0 - ) + video = self.video_processor.preprocess_video(video, height=height, width=width, preceed_with_frames=True) video = video.to(device=device, dtype=dtype) num_frames = video.shape[1] else: diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index ec42f74332ba..b42fefc4e228 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -13,7 +13,7 @@ # limitations under the License. import warnings -from typing import List, Union +from typing import List, Optional, Union import numpy as np import PIL @@ -25,7 +25,9 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" - def preprocess_video(self, video) -> torch.Tensor: + def preprocess_video( + self, video, height: Optional[int], width: Optional[int], preceed_with_frames: bool = False + ) -> torch.Tensor: r""" Preprocesses input video(s). @@ -41,6 +43,16 @@ def preprocess_video(self, video) -> torch.Tensor: num_channels). * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height, width). + height (`int`, *optional*, defaults to `None`): + The height in preprocessed frames of the video. If `None`, will use the `get_default_height_width()` to + get default height. + width (`int`, *optional*`, defaults to `None`): + The width in preprocessed frames of the video. If `None`, will use get_default_height_width()` to get + the default width. + preceed_with_frames (`bool`, defaults to False): + Some pipelines keep the number of channels _before_ the number of frames in terms dimensions. + `(batch_size, num_channels, num_frames, height, width)`, for example. Some pipelines don't. This flag + helps to control that behaviour. """ if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5: warnings.warn( @@ -71,8 +83,9 @@ def preprocess_video(self, video) -> torch.Tensor: "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image" ) - video = torch.stack([self.preprocess(img) for img in video], dim=0) - video = video.permute(0, 2, 1, 3, 4) + video = torch.stack([self.preprocess(img, height=height, width=width) for img in video], dim=0) + if not preceed_with_frames: + video = video.permute(0, 2, 1, 3, 4) return video From 28ef7650494d66d63f541b27c5b96cbf7f39ab8b Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 07:34:34 +0200 Subject: [PATCH 20/27] set default values. --- src/diffusers/video_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/video_processor.py b/src/diffusers/video_processor.py index b42fefc4e228..3cd8461fa5f6 100644 --- a/src/diffusers/video_processor.py +++ b/src/diffusers/video_processor.py @@ -26,7 +26,7 @@ class VideoProcessor(VaeImageProcessor): r"""Simple video processor.""" def preprocess_video( - self, video, height: Optional[int], width: Optional[int], preceed_with_frames: bool = False + self, video, height: Optional[int] = None, width: Optional[int] = None, preceed_with_frames: bool = False ) -> torch.Tensor: r""" Preprocesses input video(s). From 833d41500c7bd9e16755cf92878d7e9001b9dd16 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 08:24:49 +0200 Subject: [PATCH 21/27] empty commit From a2168f3bf1d6b15bd5b42aac103915d26bf35175 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 14:04:25 +0200 Subject: [PATCH 22/27] more refactoring of prepare_latents in animatediff vid2vid --- .../pipeline_animatediff_video2video.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index e6e750c7cbe2..11d15e2f897b 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -613,14 +613,7 @@ def prepare_latents( generator, latents=None, ): - # video must be a list of list of images - # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video - # as a list of images - if video and not isinstance(video[0], list): - video = [video] if latents is None: - video = self.video_processor.preprocess_video(video, height=height, width=width, preceed_with_frames=True) - video = video.to(device=device, dtype=dtype) num_frames = video.shape[1] else: num_frames = latents.shape[2] @@ -893,6 +886,14 @@ def __call__( latent_timestep = timesteps[:1].repeat(batch_size * num_videos_per_prompt) # 5. Prepare latent variables + # video must be a list of list of images + # the outer list denotes having multiple videos as input, whereas inner list means the frames of the video + # as a list of images + if video and not isinstance(video[0], list): + video = [video] + if latents is None: + video = self.video_processor.preprocess_video(video, height=height, width=width, preceed_with_frames=True) + video = video.to(device=device, dtype=prompt_embeds.dtype) num_channels_latents = self.unet.config.in_channels latents = self.prepare_latents( video=video, From 7174f4bd001370e1658a5f7e9a6a10e760349515 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 14:08:17 +0200 Subject: [PATCH 23/27] checking documentation --- docs/source/en/api/video_processor.md | 4 ---- 1 file changed, 4 deletions(-) diff --git a/docs/source/en/api/video_processor.md b/docs/source/en/api/video_processor.md index b9ae67f88dd9..cc262a901ab6 100644 --- a/docs/source/en/api/video_processor.md +++ b/docs/source/en/api/video_processor.md @@ -12,10 +12,6 @@ specific language governing permissions and limitations under the License. # Video Processor -The [`VideoProcessor`] provides a unified API for [`StableDiffusionPipeline`]s to prepare video inputs for VAE encoding and post-processing outputs once they're decoded. It inherits the [`VaeImageProcessor`] class, which already includes transformations such as resizing, normalization, and conversion between PIL Image, PyTorch, and NumPy arrays. - -Some pipelines such as [`VideoToVideoSDPipeline`] with [`VideoProcessor`] accept videos as a list of PIL Images, PyTorch tensors, or NumPy arrays as video inputs and return outputs based on the `output_type` argument by the user. You can pass encoded image latents directly to the pipeline and return latents from the pipeline as a specific output with the `output_type` argument (for example `output_type="latent"`). This allows you to take the generated latents from one pipeline and pass it to another pipeline as input without leaving the latent space. It also makes it much easier to use multiple pipelines together by passing PyTorch tensors directly between different pipelines. - ## VideoProcessor [[autodoc]] video_processor.VideoProcessor From d81791bf4429d2c792ed0437abaad049a85a0384 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 May 2024 14:14:58 +0200 Subject: [PATCH 24/27] remove documentation for now. --- docs/source/en/_toctree.yml | 2 -- docs/source/en/api/video_processor.md | 17 ----------------- 2 files changed, 19 deletions(-) delete mode 100644 docs/source/en/api/video_processor.md diff --git a/docs/source/en/_toctree.yml b/docs/source/en/_toctree.yml index d3a6db727b82..1c21d4cd9f74 100644 --- a/docs/source/en/_toctree.yml +++ b/docs/source/en/_toctree.yml @@ -439,8 +439,6 @@ title: Utilities - local: api/image_processor title: VAE Image Processor - - local: api/video_processor - title: Video Processor title: Internal classes isExpanded: false title: API diff --git a/docs/source/en/api/video_processor.md b/docs/source/en/api/video_processor.md deleted file mode 100644 index cc262a901ab6..000000000000 --- a/docs/source/en/api/video_processor.md +++ /dev/null @@ -1,17 +0,0 @@ - - -# Video Processor - -## VideoProcessor - -[[autodoc]] video_processor.VideoProcessor From efeafcdf99f00032e9567fce63b9d69679c92c2a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Wed, 8 May 2024 22:48:58 +0200 Subject: [PATCH 25/27] fix animatediff sdxl --- .../animatediff/pipeline_animatediff_sdxl.py | 30 +++---------------- 1 file changed, 4 insertions(+), 26 deletions(-) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py index f15cd5dbebf7..9dafcfd993ed 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py @@ -15,7 +15,6 @@ import inspect from typing import Any, Callable, Dict, List, Optional, Tuple, Union -import numpy as np import torch from transformers import ( CLIPImageProcessor, @@ -25,7 +24,7 @@ CLIPVisionModelWithProjection, ) -from ...image_processor import PipelineImageInput, VaeImageProcessor +from ...image_processor import PipelineImageInput from ...loaders import ( FromSingleFileMixin, IPAdapterMixin, @@ -57,6 +56,7 @@ unscale_lora_layers, ) from ...utils.torch_utils import randn_tensor +from ...video_processor import VideoProcessor from ..free_init_utils import FreeInitMixin from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin from .pipeline_output import AnimateDiffPipelineOutput @@ -113,28 +113,6 @@ """ -# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid -def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"): - batch_size, channels, num_frames, height, width = video.shape - outputs = [] - for batch_idx in range(batch_size): - batch_vid = video[batch_idx].permute(1, 0, 2, 3) - batch_output = processor.postprocess(batch_vid, output_type) - - outputs.append(batch_output) - - if output_type == "np": - outputs = np.stack(outputs) - - elif output_type == "pt": - outputs = torch.stack(outputs) - - elif not output_type == "pil": - raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']") - - return outputs - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): """ @@ -305,7 +283,7 @@ def __init__( ) self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt) self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) - self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor) self.default_sample_size = self.unet.config.sample_size @@ -1269,7 +1247,7 @@ def __call__( video = latents else: video_tensor = self.decode_latents(latents) - video = tensor2vid(video_tensor, self.image_processor, output_type=output_type) + video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type) # cast back to fp16 if needed if needs_upcasting: From 97aaf8ea7776cf1d929020a16beb942d829c401c Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 10 May 2024 06:54:53 +0200 Subject: [PATCH 26/27] up --- src/diffusers/image_processor.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 5d2e2c2b9c8c..4d8d73c863ae 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -531,17 +531,12 @@ def preprocess( ) image = torch.cat(image, axis=0) - # ensure the input is a list of images: - # - if it is a batch of images (4d torch.Tensor or np.ndarray), it is converted to a list of images (a list of 3d torch.Tensor or np.ndarray) - # - if it is a single image, it is converted to a list of one image - if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4: - image = list(image) - if is_valid_image(image): - image = [image] - if not all(is_valid_image(img) for img in image): + if not is_valid_image_imagelist(image): raise ValueError( - f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" + f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}" ) + if not isinstance(image, list): + image = [image] if isinstance(image[0], PIL.Image.Image): if crops_coords is not None: From 24658ad79e95585ef74e41d0a101f3fe46236491 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Fri, 10 May 2024 06:58:34 +0200 Subject: [PATCH 27/27] style --- .../stable_video_diffusion/pipeline_stable_video_diffusion.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py index ffea166db9a8..d815adab049f 100644 --- a/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +++ b/src/diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py @@ -121,6 +121,7 @@ def retrieve_timesteps( timesteps = scheduler.timesteps return timesteps, num_inference_steps + @dataclass class StableVideoDiffusionPipelineOutput(BaseOutput): r"""