Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 1 addition & 22 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,25 +63,6 @@ class VideoDecoder:
probably is. Default: "exact".
Read more about this parameter in:
:ref:`sphx_glr_generated_examples_decoding_approximate_mode.py`
custom_frame_mappings (str, bytes, or file-like object, optional):
Mapping of frames to their metadata, typically generated via ffprobe.
This enables accurate frame seeking without requiring a full video scan.
Do not set seek_mode when custom_frame_mappings is provided.
Expected JSON format:

.. code-block:: json

{
"frames": [
{
"pts": 0,
"duration": 1001,
"key_frame": 1
}
]
}

Alternative field names "pkt_pts" and "pkt_duration" are also supported.

Attributes:
metadata (VideoStreamMetadata): Metadata of the video stream.
Expand All @@ -99,9 +80,6 @@ def __init__(
num_ffmpeg_threads: int = 1,
device: Optional[Union[str, torch_device]] = "cpu",
seek_mode: Literal["exact", "approximate"] = "exact",
custom_frame_mappings: Optional[
Union[str, bytes, io.RawIOBase, io.BufferedReader]
] = None,
):
torch._C._log_api_usage_once("torchcodec.decoders.VideoDecoder")
allowed_seek_modes = ("exact", "approximate")
Expand All @@ -111,6 +89,7 @@ def __init__(
f"Supported values are {', '.join(allowed_seek_modes)}."
)

custom_frame_mappings = None
# Validate seek_mode and custom_frame_mappings are not mismatched
if custom_frame_mappings is not None and seek_mode == "approximate":
raise ValueError(
Expand Down
219 changes: 109 additions & 110 deletions test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import contextlib
import gc
import json
from functools import partial
from unittest.mock import patch

import numpy
Expand Down Expand Up @@ -1280,115 +1279,115 @@ def test_10bit_videos_cpu(self, asset):
decoder = VideoDecoder(asset.path)
decoder.get_frame_at(10)

def setup_frame_mappings(tmp_path, file, stream_index):
json_path = tmp_path / "custom_frame_mappings.json"
custom_frame_mappings = NASA_VIDEO.generate_custom_frame_mappings(stream_index)
if file:
# Write the custom frame mappings to a JSON file
with open(json_path, "w") as f:
f.write(custom_frame_mappings)
return json_path
else:
# Return the custom frame mappings as a JSON string
return custom_frame_mappings

@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize("stream_index", [0, 3])
@pytest.mark.parametrize(
"method",
(
partial(setup_frame_mappings, file=True),
partial(setup_frame_mappings, file=False),
),
)
def test_custom_frame_mappings_json_and_bytes(
self, tmp_path, device, stream_index, method
):
custom_frame_mappings = method(tmp_path=tmp_path, stream_index=stream_index)
# Optionally open the custom frame mappings file if it is a file path
# or use a null context if it is a string.
with (
open(custom_frame_mappings, "r")
if hasattr(custom_frame_mappings, "read")
else contextlib.nullcontext()
) as custom_frame_mappings:
decoder = VideoDecoder(
NASA_VIDEO.path,
stream_index=stream_index,
device=device,
custom_frame_mappings=custom_frame_mappings,
)
frame_0 = decoder.get_frame_at(0)
frame_5 = decoder.get_frame_at(5)
assert_frames_equal(
frame_0.data,
NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index).to(device),
)
assert_frames_equal(
frame_5.data,
NASA_VIDEO.get_frame_data_by_index(5, stream_index=stream_index).to(device),
)
frames0_5 = decoder.get_frames_played_in_range(
frame_0.pts_seconds, frame_5.pts_seconds
)
assert_frames_equal(
frames0_5.data,
NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to(
device
),
)

@pytest.mark.parametrize("device", all_supported_devices())
@pytest.mark.parametrize(
"custom_frame_mappings,expected_match",
[
pytest.param(
NASA_VIDEO.generate_custom_frame_mappings(0),
"seek_mode",
id="valid_content_approximate",
),
("{}", "The input is empty or missing the required 'frames' key."),
(
'{"valid": "json"}',
"The input is empty or missing the required 'frames' key.",
),
(
'{"frames": [{"missing": "keys"}]}',
"keys are required in the frame metadata.",
),
],
)
def test_custom_frame_mappings_init_fails(
self, device, custom_frame_mappings, expected_match
):
with pytest.raises(ValueError, match=expected_match):
VideoDecoder(
NASA_VIDEO.path,
stream_index=0,
device=device,
custom_frame_mappings=custom_frame_mappings,
seek_mode=("approximate" if expected_match == "seek_mode" else "exact"),
)

@pytest.mark.parametrize("device", all_supported_devices())
def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
invalid_json_path = tmp_path / "invalid_json"
with open(invalid_json_path, "w+") as f:
f.write("invalid input")

# Test both file object and string
with open(invalid_json_path, "r") as file_obj:
for custom_frame_mappings in [
file_obj,
file_obj.read(),
]:
with pytest.raises(ValueError, match="Invalid custom frame mappings"):
VideoDecoder(
NASA_VIDEO.path,
stream_index=0,
device=device,
custom_frame_mappings=custom_frame_mappings,
)
# def setup_frame_mappings(tmp_path, file, stream_index):
# json_path = tmp_path / "custom_frame_mappings.json"
# custom_frame_mappings = NASA_VIDEO.generate_custom_frame_mappings(stream_index)
# if file:
# # Write the custom frame mappings to a JSON file
# with open(json_path, "w") as f:
# f.write(custom_frame_mappings)
# return json_path
# else:
# # Return the custom frame mappings as a JSON string
# return custom_frame_mappings

# @pytest.mark.parametrize("device", all_supported_devices())
# @pytest.mark.parametrize("stream_index", [0, 3])
# @pytest.mark.parametrize(
# "method",
# (
# partial(setup_frame_mappings, file=True),
# partial(setup_frame_mappings, file=False),
# ),
# )
# def test_custom_frame_mappings_json_and_bytes(
# self, tmp_path, device, stream_index, method
# ):
# custom_frame_mappings = method(tmp_path=tmp_path, stream_index=stream_index)
# # Optionally open the custom frame mappings file if it is a file path
# # or use a null context if it is a string.
# with (
# open(custom_frame_mappings, "r")
# if hasattr(custom_frame_mappings, "read")
# else contextlib.nullcontext()
# ) as custom_frame_mappings:
# decoder = VideoDecoder(
# NASA_VIDEO.path,
# stream_index=stream_index,
# device=device,
# custom_frame_mappings=custom_frame_mappings,
# )
# frame_0 = decoder.get_frame_at(0)
# frame_5 = decoder.get_frame_at(5)
# assert_frames_equal(
# frame_0.data,
# NASA_VIDEO.get_frame_data_by_index(0, stream_index=stream_index).to(device),
# )
# assert_frames_equal(
# frame_5.data,
# NASA_VIDEO.get_frame_data_by_index(5, stream_index=stream_index).to(device),
# )
# frames0_5 = decoder.get_frames_played_in_range(
# frame_0.pts_seconds, frame_5.pts_seconds
# )
# assert_frames_equal(
# frames0_5.data,
# NASA_VIDEO.get_frame_data_by_range(0, 5, stream_index=stream_index).to(
# device
# ),
# )

# @pytest.mark.parametrize("device", all_supported_devices())
# @pytest.mark.parametrize(
# "custom_frame_mappings,expected_match",
# [
# pytest.param(
# NASA_VIDEO.generate_custom_frame_mappings(0),
# "seek_mode",
# id="valid_content_approximate",
# ),
# ("{}", "The input is empty or missing the required 'frames' key."),
# (
# '{"valid": "json"}',
# "The input is empty or missing the required 'frames' key.",
# ),
# (
# '{"frames": [{"missing": "keys"}]}',
# "keys are required in the frame metadata.",
# ),
# ],
# )
# def test_custom_frame_mappings_init_fails(
# self, device, custom_frame_mappings, expected_match
# ):
# with pytest.raises(ValueError, match=expected_match):
# VideoDecoder(
# NASA_VIDEO.path,
# stream_index=0,
# device=device,
# custom_frame_mappings=custom_frame_mappings,
# seek_mode=("approximate" if expected_match == "seek_mode" else "exact"),
# )

# @pytest.mark.parametrize("device", all_supported_devices())
# def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
# invalid_json_path = tmp_path / "invalid_json"
# with open(invalid_json_path, "w+") as f:
# f.write("invalid input")

# # Test both file object and string
# with open(invalid_json_path, "r") as file_obj:
# for custom_frame_mappings in [
# file_obj,
# file_obj.read(),
# ]:
# with pytest.raises(ValueError, match="Invalid custom frame mappings"):
# VideoDecoder(
# NASA_VIDEO.path,
# stream_index=0,
# device=device,
# custom_frame_mappings=custom_frame_mappings,
# )


class TestAudioDecoder:
Expand Down
Loading