43 commits
ede27ee
introduce videoprocessor.
sayakpaul Apr 25, 2024
b68e94d
fix quality
sayakpaul Apr 25, 2024
3d8a263
Merge branch 'main' into video-processor
sayakpaul Apr 27, 2024
6680d50
address yiyi's feedback
sayakpaul Apr 27, 2024
b00e3ab
fix preprocess_video call.
sayakpaul Apr 27, 2024
12515c4
video_processor -> image_processor
sayakpaul Apr 27, 2024
6f18104
fix
sayakpaul Apr 27, 2024
2208a06
fix more.
sayakpaul Apr 27, 2024
778577d
quality
sayakpaul Apr 27, 2024
cbc5638
Merge branch 'main' into video-processor
sayakpaul Apr 29, 2024
1cdd919
Merge branch 'main' into video-processor
sayakpaul Apr 29, 2024
cb8f138
image_processor -> video_processor
sayakpaul Apr 29, 2024
6246244
Merge branch 'main' into video-processor
sayakpaul Apr 29, 2024
6c8f300
Merge branch 'main' into video-processor
sayakpaul Apr 30, 2024
459843a
Merge branch 'main' into video-processor
sayakpaul Apr 30, 2024
7561c1e
support List[List[PIL.Image.Image]]
sayakpaul Apr 30, 2024
fd4bb8a
change to video_processor.
sayakpaul Apr 30, 2024
50e5498
Merge branch 'main' into video-processor
sayakpaul Apr 30, 2024
96fd13f
documentation
sayakpaul Apr 30, 2024
c5d22e6
Merge branch 'main' into video-processor
sayakpaul May 1, 2024
a70fd89
Apply suggestions from code review
sayakpaul May 1, 2024
2170c90
Merge branch 'main' into video-processor
sayakpaul May 3, 2024
90c0dca
changes
sayakpaul May 3, 2024
90ff0da
remove print.
sayakpaul May 4, 2024
1c474ba
Merge branch 'main' into video-processor
sayakpaul May 5, 2024
e2f61d5
refactor video processor (part # 7776) (#7861)
yiyixuxu May 7, 2024
06e7420
Merge branch 'main' into video-processor
sayakpaul May 7, 2024
c5ccf5f
add doc.
sayakpaul May 7, 2024
00cb9c1
tensor2vid -> postprocess_video.
sayakpaul May 7, 2024
814aa79
refactor preprocess with preprocess_video
sayakpaul May 7, 2024
28ef765
set default values.
sayakpaul May 7, 2024
833d415
empty commit
sayakpaul May 7, 2024
2fc3528
Merge branch 'main' into video-processor
sayakpaul May 7, 2024
a2168f3
more refactoring of prepare_latents in animatediff vid2vid
sayakpaul May 7, 2024
7174f4b
checking documentation
sayakpaul May 7, 2024
d81791b
remove documentation for now.
sayakpaul May 7, 2024
2c052ab
Merge branch 'main' into video-processor
sayakpaul May 8, 2024
efeafcd
fix animatediff sdxl
sayakpaul May 8, 2024
1d30685
Merge branch 'main' into video-processor
sayakpaul May 9, 2024
05d3e90
Merge branch 'main' into test-test
yiyixuxu May 10, 2024
97aaf8e
up
yiyixuxu May 10, 2024
fe65900
Merge branch 'test-test' of github.com:huggingface/diffusers into tes…
yiyixuxu May 10, 2024
24658ad
style
yiyixuxu May 10, 2024
76 changes: 55 additions & 21 deletions src/diffusers/image_processor.py
@@ -29,15 +29,34 @@
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
torch.Tensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
List[torch.Tensor],
]

PipelineDepthInput = PipelineImageInput


def is_valid_image(image):
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)


def is_valid_image_imagelist(images):
# check if the image input is one of the supported formats for image and image list:
# it can be one of the 3 formats below:
# (1) a 4d pytorch tensor or numpy array,
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor
# (3) a list of valid images
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
return True
elif is_valid_image(images):
return True
elif isinstance(images, list):
return all(is_valid_image(image) for image in images)
return False


class VaeImageProcessor(ConfigMixin):
"""
Image processor for VAE.
@@ -110,7 +129,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd
return images

@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
"""
Convert a NumPy image to a PyTorch tensor.
"""
@@ -121,7 +140,7 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
return images

@staticmethod
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
"""
Convert a PyTorch tensor to a NumPy image.
"""
@@ -497,12 +516,27 @@ def preprocess(
else:
image = np.expand_dims(image, axis=-1)

if isinstance(image, supported_formats):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d np.ndarray is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
FutureWarning,
)
image = np.concatenate(image, axis=0)
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d torch.Tensor is deprecated."
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
FutureWarning,
)
image = torch.cat(image, axis=0)

if not is_valid_image_imagelist(image):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}"
)
if not isinstance(image, list):
image = [image]

if isinstance(image[0], PIL.Image.Image):
if crops_coords is not None:
@@ -561,15 +595,15 @@ def preprocess(

def postprocess(
self,
image: torch.FloatTensor,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.

Args:
image (`torch.FloatTensor`):
image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -578,7 +612,7 @@ def postprocess(
`VaeImageProcessor` config.

Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
@@ -738,15 +772,15 @@ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:

def postprocess(
self,
image: torch.FloatTensor,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.

Args:
image (`torch.FloatTensor`):
image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
Expand All @@ -755,7 +789,7 @@ def postprocess(
`VaeImageProcessor` config.

Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
@@ -793,8 +827,8 @@ def postprocess(

def preprocess(
self,
rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
target_res: Optional[int] = None,
@@ -933,13 +967,13 @@ def __init__(
)

@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

Args:
mask (`torch.FloatTensor`):
mask (`torch.Tensor`):
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
batch_size (`int`):
The batch size.
@@ -949,7 +983,7 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value
The dimensionality of the value embeddings.

Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The downsampled mask tensor.

"""
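Taken together, the new `is_valid_image`/`is_valid_image_imagelist` helpers and the deprecation of list-of-4d inputs above change what `VaeImageProcessor.preprocess` accepts. A minimal sketch of the resulting behavior, assuming a diffusers build that includes this patch (the shapes in the comments are expectations, not captured output):

```python
import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import VaeImageProcessor, is_valid_image_imagelist

processor = VaeImageProcessor()

# (1) a 4d batch: torch.Tensor or np.ndarray shaped (batch, channels, height, width)
assert is_valid_image_imagelist(torch.rand(2, 3, 64, 64))

# (2) a single valid image: PIL.Image.Image, or a 2d/3d array
pil_image = PIL.Image.new("RGB", (64, 64))
assert is_valid_image_imagelist(pil_image)

# (3) a list of valid images
assert is_valid_image_imagelist([pil_image, pil_image])

# All accepted inputs are normalized to a (batch, channels, height, width) tensor.
out = processor.preprocess(pil_image, height=64, width=64)
print(out.shape)  # expected: torch.Size([1, 3, 64, 64])

# A list of 4d arrays still works, but now emits a FutureWarning and is
# concatenated along the batch dimension instead of being treated as a nested list.
legacy = [np.random.rand(1, 64, 64, 3).astype("float32")] * 2
print(processor.preprocess(legacy).shape)  # expected: torch.Size([2, 3, 64, 64])
```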
29 changes: 4 additions & 25 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff.py
@@ -15,11 +15,10 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
from ...models.lora import adjust_lora_scale_text_encoder
@@ -41,6 +40,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -65,27 +65,6 @@
"""


def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)

outputs.append(batch_output)

if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs


class AnimateDiffPipeline(
DiffusionPipeline,
StableDiffusionMixin,
@@ -159,7 +138,7 @@ def __init__(
image_encoder=image_encoder,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(do_resize=False, vae_scale_factor=self.vae_scale_factor)

# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt with num_images_per_prompt -> num_videos_per_prompt
def encode_prompt(
@@ -836,7 +815,7 @@ def __call__(
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# 10. Offload all models
self.maybe_free_model_hooks()
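With `tensor2vid` gone, the decode path in this pipeline now goes through the shared `VideoProcessor`. A minimal sketch of the standalone call, assuming `diffusers.video_processor.VideoProcessor` is importable as added in this PR and that `postprocess_video` keeps the `(batch, channels, frames, height, width)` layout the removed helper permuted one batch item at a time:

```python
import torch

from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(vae_scale_factor=8)

# Decoded latents arrive as (batch, channels, num_frames, height, width),
# the same layout the removed tensor2vid helper expected.
video_tensor = torch.rand(1, 3, 16, 64, 64)

# "np" mirrors the old tensor2vid default; "pil" and "pt" should also be accepted.
frames = video_processor.postprocess_video(video=video_tensor, output_type="np")
print(type(frames), frames.shape)  # expected: <class 'numpy.ndarray'> (1, 16, 64, 64, 3)
```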
30 changes: 4 additions & 26 deletions src/diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py
@@ -15,7 +15,6 @@
import inspect
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from transformers import (
CLIPImageProcessor,
@@ -25,7 +24,7 @@
CLIPVisionModelWithProjection,
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...image_processor import PipelineImageInput
from ...loaders import (
FromSingleFileMixin,
IPAdapterMixin,
@@ -57,6 +56,7 @@
unscale_lora_layers,
)
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..free_init_utils import FreeInitMixin
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
from .pipeline_output import AnimateDiffPipelineOutput
@@ -113,28 +113,6 @@
"""


# Copied from diffusers.pipelines.animatediff.pipeline_animatediff.tensor2vid
def tensor2vid(video: torch.Tensor, processor: "VaeImageProcessor", output_type: str = "np"):
batch_size, channels, num_frames, height, width = video.shape
outputs = []
for batch_idx in range(batch_size):
batch_vid = video[batch_idx].permute(1, 0, 2, 3)
batch_output = processor.postprocess(batch_vid, output_type)

outputs.append(batch_output)

if output_type == "np":
outputs = np.stack(outputs)

elif output_type == "pt":
outputs = torch.stack(outputs)

elif not output_type == "pil":
raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

return outputs


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
@@ -320,7 +298,7 @@ def __init__(
)
self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor)

self.default_sample_size = self.unet.config.sample_size

@@ -1291,7 +1269,7 @@ def __call__(
video = latents
else:
video_tensor = self.decode_latents(latents)
video = tensor2vid(video_tensor, self.image_processor, output_type=output_type)
video = self.video_processor.postprocess_video(video=video_tensor, output_type=output_type)

# cast back to fp16 if needed
if needs_upcasting:
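The input side of `VideoProcessor` is not visible in this diff, but the commit messages ("fix preprocess_video call.", "support List[List[PIL.Image.Image]]") point to a matching `preprocess_video` entry point used by the vid2vid pipelines. The sketch below is an assumption about that API based on those commits; the signature and accepted layouts are not confirmed by the hunks above:

```python
import PIL.Image

from diffusers.video_processor import VideoProcessor  # module added by this PR

video_processor = VideoProcessor(vae_scale_factor=8)

# One video as a list of PIL frames; a batch is assumed to be a list of such
# lists (List[List[PIL.Image.Image]]), per the commit message above.
frames = [PIL.Image.new("RGB", (64, 64)) for _ in range(16)]
batch = [frames, frames]

# Hypothetical call shape: preprocess_video(video, height=None, width=None).
video = video_processor.preprocess_video(batch, height=64, width=64)
print(video.shape)  # expected: torch.Size([2, 3, 16, 64, 64]), ready for VAE encoding
```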