77 changes: 58 additions & 19 deletions src/diffusers/image_processor.py
@@ -29,15 +29,34 @@
PipelineImageInput = Union[
PIL.Image.Image,
np.ndarray,
torch.FloatTensor,
torch.Tensor,
List[PIL.Image.Image],
List[np.ndarray],
List[torch.FloatTensor],
List[torch.Tensor],
]

PipelineDepthInput = PipelineImageInput


def is_valid_image(image):
return isinstance(image, PIL.Image.Image) or isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)


def is_valid_image_imagelist(images):
# check if the image input is one of the supported formats for image and image list:
# it can be one of the three below:
# (1) a 4d pytorch tensor or numpy array,
# (2) a valid image: PIL.Image.Image, 2-d np.ndarray or torch.Tensor (grayscale image), 3-d np.ndarray or torch.Tensor,
# (3) a list of valid images
if isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4:
return True
elif is_valid_image(images):
return True
elif isinstance(images, list):
return all(is_valid_image(image) for image in images)
return False
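
A minimal usage sketch of the two validators (illustrative only; shapes are arbitrary, and the helpers are assumed importable as the module-level functions added above):

import numpy as np
import PIL.Image
import torch

from diffusers.image_processor import is_valid_image, is_valid_image_imagelist

# A 4d array is a batch of images, not a single valid image.
batch = np.zeros((2, 64, 64, 3), dtype=np.float32)
assert not is_valid_image(batch)
assert is_valid_image_imagelist(batch)

# PIL images and 2d/3d arrays or tensors are single valid images.
assert is_valid_image(PIL.Image.new("RGB", (64, 64)))
assert is_valid_image(np.zeros((64, 64), dtype=np.float32))  # grayscale
assert is_valid_image(torch.zeros(3, 64, 64))

# A list is accepted only if every element is a valid single image.
assert is_valid_image_imagelist([PIL.Image.new("RGB", (64, 64))] * 2)
assert not is_valid_image_imagelist([batch])  # a list of 4d arrays is rejected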


class VaeImageProcessor(ConfigMixin):
"""
Image processor for VAE.
@@ -111,7 +130,7 @@ def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.nd
return images

@staticmethod
def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
"""
Convert a NumPy image to a PyTorch tensor.
"""
@@ -122,7 +141,7 @@ def numpy_to_pt(images: np.ndarray) -> torch.FloatTensor:
return images

@staticmethod
def pt_to_numpy(images: torch.FloatTensor) -> np.ndarray:
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
"""
Convert a PyTorch tensor to a NumPy image.
"""
@@ -498,9 +517,29 @@ def preprocess(
else:
image = np.expand_dims(image, axis=-1)

if isinstance(image, supported_formats):
if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d np.ndarray is deprecated. "
"Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
FutureWarning,
)
image = np.concatenate(image, axis=0)
if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
warnings.warn(
"Passing `image` as a list of 4d torch.Tensor is deprecated. "
"Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
FutureWarning,
)
image = torch.cat(image, axis=0)

# ensure the input is a list of images:
# - if it is a batch of images (4d torch.Tensor or np.ndarray), it is converted to a list of images (a list of 3d torch.Tensor or np.ndarray)
# - if it is a single image, it is converted to a list of one image
if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4:
image = list(image)
if is_valid_image(image):
image = [image]
elif not (isinstance(image, list) and all(isinstance(i, supported_formats) for i in image)):
if not all(is_valid_image(img) for img in image):
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}"
)
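
A sketch of the migration the new FutureWarning asks for; the shapes below are made up, and `processor` stands for any configured `VaeImageProcessor`:

import numpy as np

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()

# Deprecated: passing a list of 4d arrays.
chunks = [np.random.rand(1, 64, 64, 3).astype(np.float32) for _ in range(4)]

# Preferred: concatenate along the batch dimension and pass a single 4d array,
# which preprocess() then splits back into a list of 3d images internally.
batch = np.concatenate(chunks, axis=0)      # shape (4, 64, 64, 3)
out = processor.preprocess(batch)           # torch.Tensor, shape (4, 3, 64, 64)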
@@ -562,15 +601,15 @@ def preprocess(

def postprocess(
self,
image: torch.FloatTensor,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.

Args:
image (`torch.FloatTensor`):
image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -579,7 +618,7 @@ def postprocess(
`VaeImageProcessor` config.

Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
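
A short, illustrative postprocess call for the common output types (values and shapes are arbitrary; denormalization follows the processor's do_normalize config):

import torch

from diffusers.image_processor import VaeImageProcessor

processor = VaeImageProcessor()

# A decoded batch in [-1, 1], shape B x C x H x W
decoded = torch.rand(2, 3, 64, 64) * 2 - 1

pil_images = processor.postprocess(decoded, output_type="pil")  # list of PIL.Image.Image
np_images = processor.postprocess(decoded, output_type="np")    # np.ndarray, shape (2, 64, 64, 3)
pt_images = processor.postprocess(decoded, output_type="pt")    # torch.Tensor, denormalized to [0, 1]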
@@ -739,15 +778,15 @@ def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:

def postprocess(
self,
image: torch.FloatTensor,
image: torch.Tensor,
output_type: str = "pil",
do_denormalize: Optional[List[bool]] = None,
) -> Union[PIL.Image.Image, np.ndarray, torch.FloatTensor]:
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
"""
Postprocess the image output from tensor to `output_type`.

Args:
image (`torch.FloatTensor`):
image (`torch.Tensor`):
The image input, should be a pytorch tensor with shape `B x C x H x W`.
output_type (`str`, *optional*, defaults to `pil`):
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
@@ -756,7 +795,7 @@ def postprocess(
`VaeImageProcessor` config.

Returns:
`PIL.Image.Image`, `np.ndarray` or `torch.FloatTensor`:
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
The postprocessed image.
"""
if not isinstance(image, torch.Tensor):
@@ -794,8 +833,8 @@ def postprocess(

def preprocess(
self,
rgb: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.FloatTensor, PIL.Image.Image, np.ndarray],
rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
height: Optional[int] = None,
width: Optional[int] = None,
target_res: Optional[int] = None,
@@ -934,13 +973,13 @@ def __init__(
)

@staticmethod
def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value_embed_dim: int):
def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
"""
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.

Args:
mask (`torch.FloatTensor`):
mask (`torch.Tensor`):
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
batch_size (`int`):
The batch size.
@@ -950,7 +989,7 @@ def downsample(mask: torch.FloatTensor, batch_size: int, num_queries: int, value
The dimensionality of the value embeddings.

Returns:
`torch.FloatTensor`:
`torch.Tensor`:
The downsampled mask tensor.

"""
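A minimal sketch of downsample with made-up sizes; the mask is assumed to come from IPAdapterMaskProcessor.preprocess() and to share the output image's aspect ratio, so no padding/truncation warning fires:

import torch

from diffusers.image_processor import IPAdapterMaskProcessor

mask = torch.ones(1, 64, 64)   # one mask, H x W matching the output aspect ratio
batch_size = 2
num_queries = 16 * 16          # number of latent tokens at the current attention resolution
value_embed_dim = 80           # dimensionality of the value embeddings

out = IPAdapterMaskProcessor.downsample(mask, batch_size, num_queries, value_embed_dim)
assert out.shape == (batch_size, num_queries, value_embed_dim)  # (2, 256, 80)
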
108 changes: 34 additions & 74 deletions src/diffusers/video_processor.py
@@ -12,26 +12,27 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from typing import List, Union

import numpy as np
import PIL
import torch

from .image_processor import VaeImageProcessor
from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist


class VideoProcessor(VaeImageProcessor):
r"""Simple video processor."""

def tensor2vid(
self, video: torch.FloatTensor, output_type: str = "np"
) -> Union[np.ndarray, torch.FloatTensor, List[PIL.Image.Image]]:
self, video: torch.Tensor, output_type: str = "np"
) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
r"""
Converts a video tensor to a list of frames for export.

Args:
video (`torch.FloatTensor`): The video as a tensor.
video (`torch.Tensor`): The video as a tensor.
output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
"""
batch_size = video.shape[0]
@@ -50,94 +51,53 @@ def tensor2vid(

return outputs
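
An illustrative tensor2vid call, assuming the (batch, channels, frames, height, width) layout that decoded video latents are typically handed over in:

import torch

from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

# A decoded video batch in [-1, 1]
video = torch.rand(1, 3, 8, 64, 64) * 2 - 1

frames_np = video_processor.tensor2vid(video, output_type="np")    # np.ndarray, shape (1, 8, 64, 64, 3)
frames_pil = video_processor.tensor2vid(video, output_type="pil")  # per-video lists of PIL.Image.Image frames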

def preprocess_video(self, video) -> torch.FloatTensor:
def preprocess_video(self, video) -> torch.Tensor:
r"""
Preprocesses input video(s).

Args:
video: The input video. It can be one of the following:
* List of the PIL images.
* List of list of PIL images.
* 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
* 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
* List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
* List of list of 4D Torch tensors (expected shape for tensor: (num_frames, num_channels, height,
width)).
* List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
* List of list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width,
num_channels)).
* List of 5D NumPy arrays (expected shape for each array: (batch_size, num_frames, height, width,
num_channels).
* List of 5D Torch tensors (expected shape for each array: (batch_size, num_frames, num_channels,
height, width).
* 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width,
num_channels).
* 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height,
width).
"""
supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image, list)

# Single-frame video.
if isinstance(video, supported_formats[:-1]):
video = [video]

# List of PIL images.
elif isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
DN6 (Collaborator), May 6, 2024: Think it's okay to not accept a list of 5D videos. Perhaps just raise an error if a 5D list is passed here, with a message asking to concatenate along the batch dimension?

Author reply: I throw a warning and deprecate it instead, just to be safer.

warnings.warn(
"Passing `video` as a list of 5d np.ndarray is deprecated. "
"Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
FutureWarning,
)
video = np.concatenate(video, axis=0)
if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
warnings.warn(
"Passing `video` as a list of 5d torch.Tensor is deprecated. "
"Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
FutureWarning,
)
video = torch.cat(video, axis=0)

# ensure the input is a list of videos:
# - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
# - if it is a single video, it is converted to a list of one video.
if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
video = list(video)
elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
video = [video]

elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
video = video
else:
raise ValueError(
f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(list(map(str, supported_formats)))}"
"Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
)

if isinstance(video[0], np.ndarray):
# When the number of dimension of the first element in `video` is 5, it means
# each element in the `video` list is a video.
video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)

if video.ndim == 4:
video = video[None, ...]

elif isinstance(video[0], torch.Tensor):
video = torch.cat(video, dim=0) if video[0].ndim == 5 else torch.stack(video, dim=0)

# don't need any preprocess if the video is latents
channel = video.shape[1]
if channel == 4:
return video

# List of 5d tensors/ndarrays.
elif isinstance(video[0], list):
if isinstance(video[0][0], (np.ndarray, torch.Tensor)):
all_frames = []
for list_of_videos in video:
temp_frames = []
for vid in list_of_videos:
if vid.ndim == 4:
current_vid_frames = np.stack(vid, axis=0) if isinstance(vid, np.ndarray) else vid
elif vid.ndim == 5:
current_vid_frames = (
np.concatenate(vid, axis=0) if isinstance(vid, np.ndarray) else torch.cat(vid, dim=0)
)
temp_frames.append(current_vid_frames)

# Process inner list.
temp_frames = (
np.stack(temp_frames, axis=0)
if isinstance(temp_frames[0], np.ndarray)
else torch.stack(temp_frames, axis=0)
)
all_frames.append(temp_frames)

# Process outer list.
video = (
np.concatenate(all_frames, axis=0)
if isinstance(all_frames[0], np.ndarray)
else torch.cat(all_frames, dim=0)
)

# `preprocess()` here would return a PT tensor.
video = torch.stack([self.preprocess(f) for f in video], dim=0)

# move channels before num_frames
video = torch.stack([self.preprocess(img) for img in video], dim=0)
video = video.permute(0, 2, 1, 3, 4)

return video
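
An end-to-end sketch of the reworked preprocess_video, including the migration away from the deprecated list-of-5d input (shapes are arbitrary):

import numpy as np

from diffusers.video_processor import VideoProcessor

video_processor = VideoProcessor(do_resize=False, do_normalize=True)

# Preferred batched input: one 5d array, (batch, frames, height, width, channels).
video_np = np.random.rand(2, 8, 32, 32, 3).astype(np.float32)
out = video_processor.preprocess_video(video_np)
assert out.shape == (2, 3, 8, 32, 32)   # (batch, channels, frames, height, width)

# If you still hold a list of 5d arrays (now deprecated), concatenate it first:
clips = [video_np[:1], video_np[1:]]
out = video_processor.preprocess_video(np.concatenate(clips, axis=0))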
4 changes: 2 additions & 2 deletions tests/others/test_video_processor.py
@@ -140,7 +140,7 @@ def test_video_processor_pil(self, input_type):
input_np = self.to_np(input).astype("float32") / 255.0 if output_type != "pil" else self.to_np(input)
assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"

@parameterized.expand(["list_4d_np", "list_list_4d_np", "list_5d_np", "5d_np"])
@parameterized.expand(["list_4d_np", "list_5d_np", "5d_np"])
def test_video_processor_np(self, input_type):
video_processor = VideoProcessor(do_resize=False, do_normalize=True)

@@ -154,7 +154,7 @@ def test_video_processor_np(self, input_type):
)
assert np.abs(input_np - out_np).max() < 1e-6, f"Decoded output does not match input for {output_type=}"

@parameterized.expand(["list_4d_pt", "list_list_4d_pt", "list_5d_pt", "5d_pt"])
@parameterized.expand(["list_4d_pt", "list_5d_pt", "5d_pt"])
def test_video_processor_pt(self, input_type):
video_processor = VideoProcessor(do_resize=False, do_normalize=True)
