refactor video processor (part #7776) #7861
Changes from all commits: 7ca8759, feb561f, 2721d9c, 964508b, d59a596, 7a09ea3, 90fac4f, 3967cca, 02594f2
@@ -12,26 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import warnings
 from typing import List, Union

 import numpy as np
 import PIL
 import torch

-from .image_processor import VaeImageProcessor
+from .image_processor import VaeImageProcessor, is_valid_image, is_valid_image_imagelist


 class VideoProcessor(VaeImageProcessor):
     r"""Simple video processor."""

     def tensor2vid(
-        self, video: torch.FloatTensor, output_type: str = "np"
-    ) -> Union[np.ndarray, torch.FloatTensor, List[PIL.Image.Image]]:
+        self, video: torch.Tensor, output_type: str = "np"
+    ) -> Union[np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
         r"""
         Converts a video tensor to a list of frames for export.

         Args:
-            video (`torch.FloatTensor`): The video as a tensor.
+            video (`torch.Tensor`): The video as a tensor.
             output_type (`str`, defaults to `"np"`): Output type of the postprocessed `video` tensor.
         """
         batch_size = video.shape[0]

@@ -50,94 +51,53 @@ def tensor2vid(
         return outputs

-    def preprocess_video(self, video) -> torch.FloatTensor:
+    def preprocess_video(self, video) -> torch.Tensor:
         r"""
         Preprocesses input video(s).

         Args:
             video: The input video. It can be one of the following:
                 * List of the PIL images.
                 * List of list of PIL images.
                 * 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
                 * 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
-                * List of 4D Torch tensors (expected shape for each tensor: (num_frames, num_channels, height, width)).
-                * List of list of 4D Torch tensors (expected shape for tensor: (num_frames, num_channels, height, width)).
-                * List of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
-                * List of list of 4D NumPy arrays (expected shape for each array: (num_frames, height, width, num_channels)).
-                * List of 5D NumPy arrays (expected shape for each array: (batch_size, num_frames, height, width, num_channels)).
-                * List of 5D Torch tensors (expected shape for each array: (batch_size, num_frames, num_channels, height, width)).
+                * 5D NumPy arrays: expected shape for each array: (batch_size, num_frames, height, width, num_channels).
+                * 5D Torch tensors: expected shape for each array: (batch_size, num_frames, num_channels, height, width).
         """
-        supported_formats = (np.ndarray, torch.Tensor, PIL.Image.Image, list)
-
-        # Single-frame video.
-        if isinstance(video, supported_formats[:-1]):
-            video = [video]
-
-        # List of PIL images.
-        elif isinstance(video, list) and isinstance(video[0], PIL.Image.Image):
+        if isinstance(video, list) and isinstance(video[0], np.ndarray) and video[0].ndim == 5:
Collaborator:
I think it's okay to not accept a list of 5D videos. Perhaps just raise an error if a 5D list is passed here, with a message asking to concatenate along the batch dimension?

Collaborator (Author):
I throw a warning and deprecate it instead, just to be safer.
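For context, here is a minimal self-contained sketch of the deprecation path discussed above (the array shapes are illustrative assumptions): passing a list of 5D arrays now emits a `FutureWarning` and the list is concatenated along the batch dimension instead of being rejected.

```python
import warnings

import numpy as np

# Two hypothetical 5D (batch, frames, height, width, channels) arrays.
videos = [np.random.rand(1, 8, 32, 32, 3) for _ in range(2)]

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Mirrors the branch added in this diff.
    if isinstance(videos, list) and isinstance(videos[0], np.ndarray) and videos[0].ndim == 5:
        warnings.warn(
            "Passing `video` as a list of 5d np.ndarray is deprecated."
            "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
            FutureWarning,
        )
        videos = np.concatenate(videos, axis=0)

print(videos.shape)        # (2, 8, 32, 32, 3): batch dimension merged
print(caught[0].category)  # <class 'FutureWarning'>
```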
+            warnings.warn(
+                "Passing `video` as a list of 5d np.ndarray is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 5d np.ndarray",
+                FutureWarning,
+            )
+            video = np.concatenate(video, axis=0)
+        if isinstance(video, list) and isinstance(video[0], torch.Tensor) and video[0].ndim == 5:
+            warnings.warn(
+                "Passing `video` as a list of 5d torch.Tensor is deprecated."
+                "Please concatenate the list along the batch dimension and pass it as a single 5d torch.Tensor",
+                FutureWarning,
+            )
+            video = torch.cat(video, axis=0)
+
+        # Ensure the input is a list of videos:
+        # - if it is a batch of videos (5d torch.Tensor or np.ndarray), it is converted to a list of videos (a list of 4d torch.Tensor or np.ndarray)
+        # - if it is a single video, it is converted to a list of one video.
+        if isinstance(video, (np.ndarray, torch.Tensor)) and video.ndim == 5:
+            video = list(video)
+        elif isinstance(video, list) and is_valid_image(video[0]) or is_valid_image_imagelist(video):
+            video = [video]
-        elif not (isinstance(video, list) and all(isinstance(i, supported_formats) for i in video)):
+        elif isinstance(video, list) and is_valid_image_imagelist(video[0]):
+            video = video
         else:
             raise ValueError(
-                f"Input is in incorrect format: {[type(i) for i in video]}. Currently, we only support {', '.join(list(map(str, supported_formats)))}"
+                "Input is in incorrect format. Currently, we only support numpy.ndarray, torch.Tensor, PIL.Image.Image"
             )

-        if isinstance(video[0], np.ndarray):
-            # When the number of dimensions of the first element in `video` is 5, it means
-            # each element in the `video` list is a video.
-            video = np.concatenate(video, axis=0) if video[0].ndim == 5 else np.stack(video, axis=0)
-
-            if video.ndim == 4:
-                video = video[None, ...]
-
-        elif isinstance(video[0], torch.Tensor):
-            video = torch.cat(video, dim=0) if video[0].ndim == 5 else torch.stack(video, dim=0)
-
-            # don't need any preprocess if the video is latents
-            channel = video.shape[1]
-            if channel == 4:
-                return video
-
-        # List of 5d tensors/ndarrays.
-        elif isinstance(video[0], list):
-            if isinstance(video[0][0], (np.ndarray, torch.Tensor)):
-                all_frames = []
-                for list_of_videos in video:
-                    temp_frames = []
-                    for vid in list_of_videos:
-                        if vid.ndim == 4:
-                            current_vid_frames = np.stack(vid, axis=0) if isinstance(vid, np.ndarray) else vid
-                        elif vid.ndim == 5:
-                            current_vid_frames = (
-                                np.concatenate(vid, axis=0) if isinstance(vid, np.ndarray) else torch.cat(vid, dim=0)
-                            )
-                        temp_frames.append(current_vid_frames)
-
-                    # Process inner list.
-                    temp_frames = (
-                        np.stack(temp_frames, axis=0)
-                        if isinstance(temp_frames[0], np.ndarray)
-                        else torch.stack(temp_frames, axis=0)
-                    )
-                    all_frames.append(temp_frames)
-
-                # Process outer list.
-                video = (
-                    np.concatenate(all_frames, axis=0)
-                    if isinstance(all_frames[0], np.ndarray)
-                    else torch.cat(all_frames, dim=0)
-                )
-
-        # `preprocess()` here would return a PT tensor.
-        video = torch.stack([self.preprocess(f) for f in video], dim=0)
+        video = torch.stack([self.preprocess(img) for img in video], dim=0)
+
+        # move channels before num_frames
         video = video.permute(0, 2, 1, 3, 4)

         return video
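To see the refactored entry points end to end, here is a brief usage sketch. It assumes the module is importable as `diffusers.video_processor` and that the inherited `VaeImageProcessor` constructor arguments behave as in the existing image pipelines; treat the shapes in the comments as expected values under those assumptions, not a specification.

```python
import numpy as np
import PIL.Image

from diffusers.video_processor import VideoProcessor

# Inherited VaeImageProcessor kwargs; resizing disabled for this sketch.
processor = VideoProcessor(do_resize=False, do_normalize=True)

# One video given as a list of PIL frames (8 frames of 32x32 RGB).
frames = [PIL.Image.new("RGB", (32, 32)) for _ in range(8)]
video = processor.preprocess_video(frames)
# After the final permute, channels come before frames:
print(video.shape)  # torch.Size([1, 3, 8, 32, 32]) -> (batch, channels, frames, H, W)

# A batch of two videos given as a single 5D array, (batch, frames, H, W, channels).
batch = np.random.rand(2, 8, 32, 32, 3).astype(np.float32)
video = processor.preprocess_video(batch)
print(video.shape)  # torch.Size([2, 3, 8, 32, 32])

# And back to frames for export; for output_type="np" the result should be
# (batch, frames, H, W, channels).
out = processor.tensor2vid(video, output_type="np")
print(out.shape)  # (2, 8, 32, 32, 3)
```

Note how both input styles normalize to the same internal layout: the 5D batch is split into a list of 4D videos, the list-of-frames case is wrapped into a list of one video, and each entry then goes through the inherited `preprocess()` before stacking and permuting.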