diff --git a/src/transformers/video_processing_utils.py b/src/transformers/video_processing_utils.py index 9813a81b4e81..184860978d0e 100644 --- a/src/transformers/video_processing_utils.py +++ b/src/transformers/video_processing_utils.py @@ -57,12 +57,12 @@ VideoInput, VideoMetadata, group_videos_by_shape, + infer_channel_dimension_format, is_valid_video, load_video, make_batched_metadata, make_batched_videos, reorder_videos, - to_channel_dimension_format, ) @@ -338,10 +338,16 @@ def _prepare_input_videos( for video in videos: # `make_batched_videos` always returns a 4D array per video if isinstance(video, np.ndarray): - video = to_channel_dimension_format(video, ChannelDimension.FIRST, input_data_format) # not using F.to_tensor as it doesn't handle (C, H, W) numpy arrays video = torch.from_numpy(video).contiguous() + # Infer the channel dimension format if not provided + if input_data_format is None: + input_data_format = infer_channel_dimension_format(video) + + if input_data_format == ChannelDimension.LAST: + video = video.permute(0, 3, 1, 2).contiguous() + if device is not None: video = video.to(device)