2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -506,7 +506,7 @@ void SingleStreamDecoder::addVideoStream(
if (seekMode_ == SeekMode::custom_frame_mappings) {
TORCH_CHECK(
customFrameMappings.has_value(),
"Please provide frame mappings when using custom_frame_mappings seek mode.");
"Missing frame mappings when custom_frame_mappings seek mode is set.");
readCustomFrameMappingsUpdateMetadataAndIndex(
streamIndex, customFrameMappings.value());
}
94 changes: 93 additions & 1 deletion src/torchcodec/decoders/_video_decoder.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import io
import json
import numbers
from pathlib import Path
from typing import Literal, Optional, Tuple, Union
@@ -62,7 +63,25 @@ class VideoDecoder:
probably is. Default: "exact".
Read more about this parameter in:
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`

custom_frame_mappings (str, bytes, or file-like object, optional):
Mapping of frames to their metadata, typically generated via ffprobe.
This enables accurate frame seeking without requiring a full video scan.
Do not set seek_mode when custom_frame_mappings is provided.
Expected JSON format:

.. code-block:: json

{
"frames": [
{
"pts": 0,
"duration": 1001,
"key_frame": 1
}
]
}
Contributor comment: We should also say something about the relationship with seek_mode.

Contributor comment: We can also add that, in order to accommodate different FFmpeg versions, we also allow pkt_pts and pkt_duration?


Alternative field names "pkt_pts" and "pkt_duration" are also supported.

Attributes:
metadata (VideoStreamMetadata): Metadata of the video stream.
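A minimal usage sketch of the new parameter (the video path and the timestamp values are placeholders, not taken from this PR). It also illustrates the alternative pkt_pts / pkt_duration field names mentioned above, and leaves seek_mode unset so the custom_frame_mappings seek mode is selected automatically:

import json

from torchcodec.decoders import VideoDecoder

# Mappings in the shape older ffprobe versions emit, using "pkt_pts" and
# "pkt_duration" instead of "pts" and "duration" (values are illustrative).
mappings = json.dumps(
    {
        "frames": [
            {"pkt_pts": 0, "pkt_duration": 1001, "key_frame": 1},
            {"pkt_pts": 1001, "pkt_duration": 1001, "key_frame": 0},
        ]
    }
)

# seek_mode is left at its default; passing seek_mode="approximate" together
# with custom_frame_mappings raises a ValueError (see __init__ below).
decoder = VideoDecoder("video.mp4", custom_frame_mappings=mappings)
frame = decoder.get_frame_at(0)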
@@ -80,6 +99,9 @@ def __init__(
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
] = None,
):
torch._C._log_api_usage_once("torchcodec.decoders.VideoDecoder")
allowed_seek_modes = ("exact", "approximate")
@@ -89,6 +111,21 @@
f"Supported values are {', '.join(allowed_seek_modes)}."
)

# Validate seek_mode and custom_frame_mappings are not mismatched
if custom_frame_mappings is not None and seek_mode == "approximate":
raise ValueError(
"custom_frame_mappings is incompatible with seek_mode='approximate'. "
"Use seek_mode='custom_frame_mappings' or leave it unspecified to automatically use custom frame mappings."
)

# Auto-select custom_frame_mappings seek_mode and process data when mappings are provided
custom_frame_mappings_data = None
if custom_frame_mappings is not None:
seek_mode = "custom_frame_mappings" # type: ignore[assignment]
custom_frame_mappings_data = _read_custom_frame_mappings(
custom_frame_mappings
)

self._decoder = create_decoder(source=source, seek_mode=seek_mode)

allowed_dimension_orders = ("NCHW", "NHWC")
@@ -110,6 +147,7 @@ def __init__(
dimension_order=dimension_order,
num_threads=num_ffmpeg_threads,
device=device,
custom_frame_mappings=custom_frame_mappings_data,
)

(
@@ -379,3 +417,57 @@ def _get_and_validate_stream_metadata(
end_stream_seconds,
num_frames,
)


def _read_custom_frame_mappings(
custom_frame_mappings: Union[str, bytes, io.RawIOBase, io.BufferedReader]
) -> tuple[Tensor, Tensor, Tensor]:
"""Parse custom frame mappings from JSON data and extract frame metadata.

Args:
custom_frame_mappings: JSON data containing frame metadata, provided as:
- A JSON string (str, bytes)
- A file-like object with a read() method

Returns:
A tuple of three tensors:
- all_frames (Tensor): Presentation timestamps (PTS) for each frame
- is_key_frame (Tensor): Boolean tensor indicating which frames are key frames
- duration (Tensor): Duration of each frame
"""
try:
input_data = (
json.load(custom_frame_mappings)
if hasattr(custom_frame_mappings, "read")
else json.loads(custom_frame_mappings)
)
except json.JSONDecodeError as e:
raise ValueError(
f"Invalid custom frame mappings: {e}. It should be a valid JSON string or a file-like object."
) from e

if not input_data or "frames" not in input_data:
raise ValueError(
"Invalid custom frame mappings. The input is empty or missing the required 'frames' key."
)

first_frame = input_data["frames"][0]
pts_key = next((key for key in ("pts", "pkt_pts") if key in first_frame), None)
duration_key = next(
(key for key in ("duration", "pkt_duration") if key in first_frame), None
)
key_frame_present = "key_frame" in first_frame

if not pts_key or not duration_key or not key_frame_present:
raise ValueError(
"Invalid custom frame mappings. The 'pts'/'pkt_pts', 'duration'/'pkt_duration', and 'key_frame' keys are required in the frame metadata."
)

frame_data = [
(float(frame[pts_key]), frame["key_frame"], float(frame[duration_key]))
for frame in input_data["frames"]
]
all_frames, is_key_frame, duration = map(torch.tensor, zip(*frame_data))
if not (len(all_frames) == len(is_key_frame) == len(duration)):
raise ValueError("Mismatched lengths in frame index data")
return all_frames, is_key_frame, duration
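For completeness, a sketch of producing the mappings end-to-end with ffprobe, mirroring the invocation used by the test utilities later in this diff, and passing the raw JSON straight to the decoder (the path and stream index are placeholders):

import subprocess

from torchcodec.decoders import VideoDecoder

# Dump per-frame metadata for stream 0 as JSON; ffprobe writes it to stdout.
ffprobe_json = subprocess.run(
    [
        "ffprobe",
        "-i", "video.mp4",
        "-select_streams", "0",
        "-show_frames",
        "-of", "json",
    ],
    check=True,
    capture_output=True,
    text=True,
).stdout

# The ffprobe output is accepted as-is; no full scan of the file is needed.
decoder = VideoDecoder("video.mp4", custom_frame_mappings=ffprobe_json)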
107 changes: 107 additions & 0 deletions test/test_decoders.py
@@ -7,6 +7,7 @@
import contextlib
import gc
import json
from functools import partial
from unittest.mock import patch

import numpy
@@ -1279,6 +1280,112 @@ def test_10bit_videos_cpu(self, asset):
decoder = VideoDecoder(asset.path)
decoder.get_frame_at(10)

def setup_frame_mappings(tmp_path, file, stream_index):
json_path = tmp_path / "custom_frame_mappings.json"
custom_frame_mappings = NASA_VIDEO.generate_custom_frame_mappings(stream_index)
if file:
# Write the custom frame mappings to a JSON file
with open(json_path, "w") as f:
f.write(custom_frame_mappings)
return json_path
else:
# Return the custom frame mappings as a JSON string
return custom_frame_mappings

@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("stream_index", [0, 3])
@pytest.mark.parametrize(
"method",
(
partial(setup_frame_mappings, file=True),
partial(setup_frame_mappings, file=False),
),
)
def test_custom_frame_mappings_json_and_bytes(
self, tmp_path, device, stream_index, method
):
custom_frame_mappings = method(tmp_path=tmp_path, stream_index=stream_index)
# Optionally open the custom frame mappings file if it is a file path
# or use a null context if it is a string.
with (
open(custom_frame_mappings, "r")
if hasattr(custom_frame_mappings, "read")
else contextlib.nullcontext()
) as custom_frame_mappings:
decoder = VideoDecoder(
NASA_VIDEO.path,
stream_index=stream_index,
device=device,
custom_frame_mappings=custom_frame_mappings,
)
Contributor comment: Nit: everything below can probably be out of the context manager. It's preferable to end the CM's scope as soon as it's not needed. (See the sketch after this test.)

frame_0 = decoder.get_frame_at(0)
frame_5 = decoder.get_frame_at(5)
assert_frames_equal(
frame_0.data,
NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index).to(device),
)
assert_frames_equal(
frame_5.data,
NASA_VIDEO.get_frame_data_by_index(5, stream_index=stream_index).to(device),
)
frames0_5 = decoder.get_frames_played_in_range(
frame_0.pts_seconds, frame_5.pts_seconds
)
assert_frames_equal(
frames0_5.data,
NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to(
device
),
)
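A sketch of the restructuring suggested in the reviewer comment above: only the decoder construction needs the mappings file handle, so the context manager can close before any frames are decoded. Names come from the test as written; the isinstance check on pathlib.Path is an assumption about how the path case would be detected, not code from this PR:

with (
    open(custom_frame_mappings, "r")
    if isinstance(custom_frame_mappings, Path)
    else contextlib.nullcontext(custom_frame_mappings)
) as custom_frame_mappings:
    decoder = VideoDecoder(
        NASA_VIDEO.path,
        stream_index=stream_index,
        device=device,
        custom_frame_mappings=custom_frame_mappings,
    )

# The file handle, if any, is closed here; decoding and the assertions run
# outside the with block.
frame_0 = decoder.get_frame_at(0)
frame_5 = decoder.get_frame_at(5)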

@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize(
"custom_frame_mappings,expected_match",
[
(NASA_VIDEO.generate_custom_frame_mappings(0), "seek_mode"),
("{}", "The input is empty or missing the required 'frames' key."),
(
'{"valid": "json"}',
"The input is empty or missing the required 'frames' key.",
),
(
'{"frames": [{"missing": "keys"}]}',
"keys are required in the frame metadata.",
),
],
)
def test_custom_frame_mappings_init_fails(
self, device, custom_frame_mappings, expected_match
):
with pytest.raises(ValueError, match=expected_match):
VideoDecoder(
NASA_VIDEO.path,
stream_index=0,
device=device,
custom_frame_mappings=custom_frame_mappings,
seek_mode=("approximate" if expected_match == "seek_mode" else "exact"),
)

@pytest.mark.parametrize("device", all_supported_devices())
def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
invalid_json_path = tmp_path / "invalid_json"
with open(invalid_json_path, "w+") as f:
f.write("invalid input")

# Test both file object and string
with open(invalid_json_path, "r") as file_obj:
for custom_frame_mappings in [
file_obj,
file_obj.read(),
]:
with pytest.raises(ValueError, match="Invalid custom frame mappings"):
VideoDecoder(
NASA_VIDEO.path,
stream_index=0,
device=device,
custom_frame_mappings=custom_frame_mappings,
)


class TestAudioDecoder:
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))
7 changes: 1 addition & 6 deletions test/test_ops.py
@@ -44,7 +44,6 @@
from .utils import (
all_supported_devices,
assert_frames_equal,
get_ffmpeg_major_version,
NASA_AUDIO,
NASA_AUDIO_MP3,
NASA_VIDEO,
@@ -485,7 +484,7 @@ def test_seek_mode_custom_frame_mappings_fails(self):
)
with pytest.raises(
RuntimeError,
match="Please provide frame mappings when using custom_frame_mappings seek mode.",
match="Missing frame mappings when custom_frame_mappings seek mode is set.",
):
add_video_stream(decoder, stream_index=0, custom_frame_mappings=None)

@@ -505,10 +504,6 @@ def test_seek_mode_custom_frame_mappings_fails(self):
decoder, stream_index=0, custom_frame_mappings=different_lengths
)

@pytest.mark.skipif(
get_ffmpeg_major_version() in (4, 5),
reason="ffprobe isn't accurate on ffmpeg 4 and 5",
)
Contributor (author) comment: ffprobe does include the pts and duration fields in older versions, under the names pkt_pts and pkt_duration, so we can enable these tests by falling back to those keys.

@pytest.mark.parametrize("device", all_supported_devices())
def test_seek_mode_custom_frame_mappings(self, device):
stream_index = 3 # custom_frame_index seek mode requires a stream index
55 changes: 23 additions & 32 deletions test/utils.py
@@ -14,6 +14,7 @@
import torch

from torchcodec._core import get_ffmpeg_library_versions
from torchcodec.decoders._video_decoder import _read_custom_frame_mappings


# Decorator for skipping CUDA tests when CUDA isn't available. The tests are
@@ -267,40 +268,30 @@ def get_custom_frame_mappings(
if stream_index is None:
stream_index = self.default_stream_index
if self._custom_frame_mappings_data.get(stream_index) is None:
self.generate_custom_frame_mappings(stream_index)
self._custom_frame_mappings_data[stream_index] = (
_read_custom_frame_mappings(
self.generate_custom_frame_mappings(stream_index)
)
)
return self._custom_frame_mappings_data[stream_index]

def generate_custom_frame_mappings(self, stream_index: int) -> None:
result = json.loads(
subprocess.run(
[
"ffprobe",
"-i",
f"{self.path}",
"-select_streams",
f"{stream_index}",
"-show_frames",
"-of",
"json",
],
check=True,
capture_output=True,
text=True,
).stdout
)
all_frames = torch.tensor([float(frame["pts"]) for frame in result["frames"]])
is_key_frame = torch.tensor([frame["key_frame"] for frame in result["frames"]])
duration = torch.tensor(
[float(frame["duration"]) for frame in result["frames"]]
)
assert (
len(all_frames) == len(is_key_frame) == len(duration)
), "Mismatched lengths in frame index data"
self._custom_frame_mappings_data[stream_index] = (
all_frames,
is_key_frame,
duration,
)
def generate_custom_frame_mappings(self, stream_index: int) -> str:
result = subprocess.run(
[
"ffprobe",
"-i",
f"{self.path}",
"-select_streams",
f"{stream_index}",
"-show_frames",
"-of",
"json",
],
check=True,
capture_output=True,
text=True,
).stdout
return result

@property
def empty_pts_seconds(self) -> torch.Tensor: