diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 9cfc652ad..c61872b7c 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -506,7 +506,7 @@ void SingleStreamDecoder::addVideoStream( if (seekMode_ == SeekMode::custom_frame_mappings) { TORCH_CHECK( customFrameMappings.has_value(), - "Please provide frame mappings when using custom_frame_mappings seek mode."); + "Missing frame mappings when custom_frame_mappings seek mode is set."); readCustomFrameMappingsUpdateMetadataAndIndex( streamIndex, customFrameMappings.value()); } diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 0a030dbd0..8c1152e8c 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. import io +import json import numbers from pathlib import Path from typing import Literal, Optional, Tuple, Union @@ -62,7 +63,25 @@ class VideoDecoder: probably is. Default: "exact". Read more about this parameter in: :ref:`sphx_glr_generated_examples_decoding_approximate_mode.py` - + custom_frame_mappings (str, bytes, or file-like object, optional): + Mapping of frames to their metadata, typically generated via ffprobe. + This enables accurate frame seeking without requiring a full video scan. + Do not set seek_mode when custom_frame_mappings is provided. + Expected JSON format: + + .. code-block:: json + + { + "frames": [ + { + "pts": 0, + "duration": 1001, + "key_frame": 1 + } + ] + } + + Alternative field names "pkt_pts" and "pkt_duration" are also supported. Attributes: metadata (VideoStreamMetadata): Metadata of the video stream. 
def _read_custom_frame_mappings(
    custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
    """Parse custom frame mappings from JSON data and extract frame metadata.

    Args:
        custom_frame_mappings: JSON data containing frame metadata, provided as:
            - A JSON string (str, bytes)
            - A file-like object with a read() method

    Returns:
        A tuple of three tensors:
        - all_frames (Tensor): Presentation timestamps (PTS) for each frame
        - is_key_frame (Tensor): Key-frame flag of each frame, as found in the
          "key_frame" entry
        - duration (Tensor): Duration of each frame

    Raises:
        ValueError: If the input is not valid JSON, is not a JSON object with
            a non-empty "frames" list, or if any frame lacks the required
            'pts'/'pkt_pts', 'duration'/'pkt_duration' or 'key_frame' entries.
    """
    try:
        input_data = (
            json.load(custom_frame_mappings)
            if hasattr(custom_frame_mappings, "read")
            else json.loads(custom_frame_mappings)
        )
    except json.JSONDecodeError as e:
        raise ValueError(
            f"Invalid custom frame mappings: {e}. It should be a valid JSON string or a file-like object."
        ) from e

    # Valid JSON can still be a scalar, a string or a list; only an object can
    # carry the "frames" key, and an empty frame list cannot be indexed below.
    frames = input_data.get("frames") if isinstance(input_data, dict) else None
    if not isinstance(frames, list) or not frames:
        raise ValueError(
            "Invalid custom frame mappings. The input is empty or missing the required 'frames' key."
        )

    # The field names are detected once, on the first frame, and then applied
    # to every frame: both "pts"/"duration" and the alternative
    # "pkt_pts"/"pkt_duration" spellings are accepted.
    first_frame = frames[0] if isinstance(frames[0], dict) else {}
    pts_key = next((key for key in ("pts", "pkt_pts") if key in first_frame), None)
    duration_key = next(
        (key for key in ("duration", "pkt_duration") if key in first_frame), None
    )
    key_frame_present = "key_frame" in first_frame

    if not pts_key or not duration_key or not key_frame_present:
        raise ValueError(
            "Invalid custom frame mappings. The 'pts'/'pkt_pts', 'duration'/'pkt_duration', and 'key_frame' keys are required in the frame metadata."
        )

    try:
        frame_data = [
            (float(frame[pts_key]), frame["key_frame"], float(frame[duration_key]))
            for frame in frames
        ]
    except (KeyError, TypeError, ValueError) as e:
        # A later frame is malformed (missing key, non-numeric value, or not
        # an object at all); report it with the documented exception type
        # rather than leaking KeyError/TypeError.
        raise ValueError(
            "Invalid custom frame mappings. Every frame must provide numeric "
            f"'{pts_key}' and '{duration_key}' values and a 'key_frame' entry: {e!r}"
        ) from e

    # Transpose the per-frame tuples into three per-field columns; since every
    # tuple has exactly three entries the resulting tensors share one length.
    all_frames, is_key_frame, duration = map(torch.tensor, zip(*frame_data))
    return all_frames, is_key_frame, duration
@pytest.mark.parametrize("device", all_supported_devices())
def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
    """Invalid JSON must be rejected for both string and file-object input."""
    invalid_json_path = tmp_path / "invalid_json"
    invalid_json_path.write_text("invalid input")

    # Read the string variant through a separate call so the file object used
    # below is still positioned at the start. Calling file_obj.read() while
    # building the list of cases would drain the handle first, and the
    # file-object case would then see an empty stream instead of the invalid
    # content.
    invalid_json_str = invalid_json_path.read_text()

    # Test both file object and string
    with open(invalid_json_path, "r") as file_obj:
        for custom_frame_mappings in (file_obj, invalid_json_str):
            # Cheap guard: the handle must be unread whenever it is (re)used.
            file_obj.seek(0)
            with pytest.raises(ValueError, match="Invalid custom frame mappings"):
                VideoDecoder(
                    NASA_VIDEO.path,
                    stream_index=0,
                    device=device,
                    custom_frame_mappings=custom_frame_mappings,
                )
def generate_custom_frame_mappings(self, stream_index: int) -> str:
    """Run ffprobe on this asset and return its JSON frame listing as text.

    Args:
        stream_index: Index of the stream passed to ffprobe's
            ``-select_streams`` option.

    Returns:
        The captured standard output of the ffprobe invocation.

    Raises:
        subprocess.CalledProcessError: If ffprobe exits with a non-zero
            status (``check=True``).
    """
    ffprobe_command = [
        "ffprobe",
        "-i",
        f"{self.path}",
        "-select_streams",
        f"{stream_index}",
        "-show_frames",
        "-of",
        "json",
    ]
    completed = subprocess.run(
        ffprobe_command,
        check=True,
        capture_output=True,
        text=True,
    )
    return completed.stdout