From 0219f92468985064367d45809c8e33324b845895 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 27 Jan 2025 13:02:15 -0800 Subject: [PATCH 1/8] Rough implementation of getting key frame indices --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 14 ++++++++++++++ src/torchcodec/decoders/_core/VideoDecoder.h | 2 ++ src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 9 +++++++++ src/torchcodec/decoders/_core/VideoDecoderOps.h | 4 ++++ src/torchcodec/decoders/_core/__init__.py | 1 + src/torchcodec/decoders/_core/video_decoder_ops.py | 6 ++++++ src/torchcodec/decoders/_video_decoder.py | 3 +++ test/decoders/test_video_decoder.py | 6 ++++++ 8 files changed, 45 insertions(+) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index bba8b4e4a..88724c131 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -553,6 +553,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const { return containerMetadata_; } +std::vector VideoDecoder::getKeyFrameIndices(int streamIndex) { + validateUserProvidedStreamIndex(streamIndex); + validateScannedAllStreams("getKeyFrameIndices"); + + std::vector keyFrameIndices; + const StreamInfo& streamInfo = streamInfos_[streamIndex]; + for (const FrameInfo& frameInfo : streamInfo.keyFrames) { + keyFrameIndices.push_back( + getKeyFrameIndexForPtsUsingScannedIndex(streamInfo.keyFrames, frameInfo.pts)); + } + + return keyFrameIndices; +} + int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex( AVStream* stream, int64_t pts) const { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 760b534df..8688d8c9e 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -100,6 +100,8 @@ class VideoDecoder { // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; + std::vector getKeyFrameIndices(int streamIndex); + // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index e46a5066f..4018cdbf8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -48,6 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); + m.def("get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -334,6 +335,13 @@ bool _test_frame_pts_equality( videoDecoder->getPtsSecondsForFrame(stream_index, frame_index); } +std::vector get_key_frame_indices( + at::Tensor& decoder, + int64_t stream_index) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + return videoDecoder->getKeyFrameIndices(stream_index); +} + std::string get_json_metadata(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -526,6 +534,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); m.impl("get_next_frame", &get_next_frame); + m.impl("get_key_frame_indices", &get_key_frame_indices); m.impl("get_json_metadata", &get_json_metadata); m.impl("get_container_json_metadata", &get_container_json_metadata); m.impl("get_stream_json_metadata", &get_stream_json_metadata); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 5b25e7f69..7e2c5040c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -137,6 +137,10 @@ bool _test_frame_pts_equality( int64_t frame_index, double pts_seconds_to_test); +std::vector get_key_frame_indices( + at::Tensor& decoder, + int64_t stream_index); + // Get the metadata from the video as a string. std::string get_json_metadata(at::Tensor& decoder); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index a1ac9a478..512323036 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -26,6 +26,7 @@ get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, + get_key_frame_indices, get_next_frame, scan_all_streams_to_update_metadata, seek_to_pts, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 16e633921..1b56a4eb0 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -77,6 +77,7 @@ def load_torchcodec_extension(): get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default +get_key_frame_indices = torch.ops.torchcodec_ns.get_key_frame_indices.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default _test_frame_pts_equality = torch.ops.torchcodec_ns._test_frame_pts_equality.default _get_container_json_metadata = ( @@ -255,6 +256,11 @@ def get_frames_by_pts_in_range_abstract( ) +@register_fake("torchcodec_ns::get_key_frame_indices") +def get_key_frame_indices_abstract(decoder: torch.Tensor, *, stream_index: int) -> List[int]: + return [] + + @register_fake("torchcodec_ns::get_json_metadata") def get_json_metadata_abstract(decoder: torch.Tensor) -> str: return "" diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 46279135d..6117acdcc 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -185,6 +185,9 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor: f"Unsupported key type: {type(key)}. Supported types are int and slice." ) + def _get_key_frame_indices(self) -> list[int]: + return core.get_key_frame_indices(self._decoder, stream_index=self.stream_index) + def get_frame_at(self, index: int) -> Frame: """Return a single frame at the given index. diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index c05b33d9e..eb0f2beb4 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -831,6 +831,12 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): with pytest.raises(ValueError, match="Invalid stop seconds"): frame = decoder.get_frames_played_in_range(0, 23) # noqa + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_key_frame_indices(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") + key_frame_indices = decoder._get_key_frame_indices() + assert len(key_frame_indices) > 0 + if __name__ == "__main__": pytest.main() From 1bcfaa5a579013a7964b6fdc34538d93c3c01fab Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 27 Jan 2025 13:04:04 -0800 Subject: [PATCH 2/8] Formatting --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 4 ++-- src/torchcodec/decoders/_core/video_decoder_ops.py | 4 +++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 88724c131..f9f1dbcfe 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -560,8 +560,8 @@ std::vector VideoDecoder::getKeyFrameIndices(int streamIndex) { std::vector keyFrameIndices; const StreamInfo& streamInfo = streamInfos_[streamIndex]; for (const FrameInfo& frameInfo : streamInfo.keyFrames) { - keyFrameIndices.push_back( - getKeyFrameIndexForPtsUsingScannedIndex(streamInfo.keyFrames, frameInfo.pts)); + keyFrameIndices.push_back(getKeyFrameIndexForPtsUsingScannedIndex( + streamInfo.keyFrames, frameInfo.pts)); } return keyFrameIndices; diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 1b56a4eb0..9ea882b1c 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -257,7 +257,9 @@ def get_frames_by_pts_in_range_abstract( @register_fake("torchcodec_ns::get_key_frame_indices") -def get_key_frame_indices_abstract(decoder: torch.Tensor, *, stream_index: int) -> List[int]: +def get_key_frame_indices_abstract( + decoder: torch.Tensor, *, stream_index: int +) -> List[int]: return [] From 736401cb51166220452e0abb27560396fc697ebe Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 28 Jan 2025 06:09:33 -0800 Subject: [PATCH 3/8] Make core op private --- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 6 +++--- src/torchcodec/decoders/_core/__init__.py | 2 +- src/torchcodec/decoders/_core/video_decoder_ops.py | 4 ++-- src/torchcodec/decoders/_video_decoder.py | 2 +- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 4018cdbf8..bdec13ca5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -48,7 +48,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); - m.def("get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]"); + m.def("_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -335,7 +335,7 @@ bool _test_frame_pts_equality( videoDecoder->getPtsSecondsForFrame(stream_index, frame_index); } -std::vector get_key_frame_indices( +std::vector _get_key_frame_indices( at::Tensor& decoder, int64_t stream_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -534,7 +534,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); m.impl("get_next_frame", &get_next_frame); - m.impl("get_key_frame_indices", &get_key_frame_indices); + m.impl("_get_key_frame_indices", &_get_key_frame_indices); m.impl("get_json_metadata", &get_json_metadata); m.impl("get_container_json_metadata", &get_container_json_metadata); m.impl("get_stream_json_metadata", &get_stream_json_metadata); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index 512323036..d39d3d237 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -13,6 +13,7 @@ ) from .video_decoder_ops import ( _add_video_stream, + _get_key_frame_indices, _test_frame_pts_equality, add_video_stream, create_from_bytes, @@ -26,7 +27,6 @@ get_frames_by_pts_in_range, get_frames_in_range, get_json_metadata, - get_key_frame_indices, get_next_frame, scan_all_streams_to_update_metadata, seek_to_pts, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 9ea882b1c..d04b49496 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -77,12 +77,12 @@ def load_torchcodec_extension(): get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default -get_key_frame_indices = torch.ops.torchcodec_ns.get_key_frame_indices.default get_json_metadata = torch.ops.torchcodec_ns.get_json_metadata.default _test_frame_pts_equality = torch.ops.torchcodec_ns._test_frame_pts_equality.default _get_container_json_metadata = ( torch.ops.torchcodec_ns.get_container_json_metadata.default ) +_get_key_frame_indices = torch.ops.torchcodec_ns._get_key_frame_indices.default scan_all_streams_to_update_metadata = ( torch.ops.torchcodec_ns.scan_all_streams_to_update_metadata.default ) @@ -256,7 +256,7 @@ def get_frames_by_pts_in_range_abstract( ) -@register_fake("torchcodec_ns::get_key_frame_indices") +@register_fake("torchcodec_ns::_get_key_frame_indices") def get_key_frame_indices_abstract( decoder: torch.Tensor, *, stream_index: int ) -> List[int]: diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 6117acdcc..f32dc0d8c 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -186,7 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor: ) def _get_key_frame_indices(self) -> list[int]: - return core.get_key_frame_indices(self._decoder, stream_index=self.stream_index) + return core._get_key_frame_indices(self._decoder, stream_index=self.stream_index) def get_frame_at(self, index: int) -> Frame: """Return a single frame at the given index. From 121a9fdf7ed39247c4d579bd0385a25001c80908 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 28 Jan 2025 06:36:11 -0800 Subject: [PATCH 4/8] Change return type to a 1D tensor of int64 --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 14 ++++++++------ src/torchcodec/decoders/_core/VideoDecoder.h | 4 +++- src/torchcodec/decoders/_core/VideoDecoderOps.cpp | 5 +++-- src/torchcodec/decoders/_core/VideoDecoderOps.h | 4 +--- src/torchcodec/decoders/_video_decoder.py | 4 +++- test/decoders/test_video_decoder.py | 4 +++- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index f9f1dbcfe..fcfa9e8ad 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -553,15 +553,17 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const { return containerMetadata_; } -std::vector VideoDecoder::getKeyFrameIndices(int streamIndex) { +torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) { validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getKeyFrameIndices"); - std::vector keyFrameIndices; - const StreamInfo& streamInfo = streamInfos_[streamIndex]; - for (const FrameInfo& frameInfo : streamInfo.keyFrames) { - keyFrameIndices.push_back(getKeyFrameIndexForPtsUsingScannedIndex( - streamInfo.keyFrames, frameInfo.pts)); + const std::vector& keyFrames = streamInfos_[streamIndex].keyFrames; + torch::Tensor keyFrameIndices = + torch::empty({static_cast(keyFrames.size())}, {torch::kInt64}); + for (size_t i = 0; i < keyFrames.size(); ++i) { + int64_t pts = keyFrames[i].pts; + keyFrameIndices[i] = + getKeyFrameIndexForPtsUsingScannedIndex(keyFrames, pts); } return keyFrameIndices; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8688d8c9e..249c4ec5c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -100,7 +100,9 @@ class VideoDecoder { // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; - std::vector getKeyFrameIndices(int streamIndex); + // Returns the key frame indices as a tensor. The tensor is 1D and contains + // int64 values, where each value is the frame index for a key frame. + torch::Tensor getKeyFrameIndices(int streamIndex); // -------------------------------------------------------------------------- // ADDING STREAMS API diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index bdec13ca5..41b1dd031 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -48,7 +48,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); - m.def("_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> int[]"); + m.def( + "_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) decoder) -> str"); m.def( @@ -335,7 +336,7 @@ bool _test_frame_pts_equality( videoDecoder->getPtsSecondsForFrame(stream_index, frame_index); } -std::vector _get_key_frame_indices( +torch::Tensor _get_key_frame_indices( at::Tensor& decoder, int64_t stream_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 7e2c5040c..241d80983 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -137,9 +137,7 @@ bool _test_frame_pts_equality( int64_t frame_index, double pts_seconds_to_test); -std::vector get_key_frame_indices( - at::Tensor& decoder, - int64_t stream_index); +torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index); // Get the metadata from the video as a string. std::string get_json_metadata(at::Tensor& decoder); diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index f32dc0d8c..3c5367ab8 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -186,7 +186,9 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor: ) def _get_key_frame_indices(self) -> list[int]: - return core._get_key_frame_indices(self._decoder, stream_index=self.stream_index) + return core._get_key_frame_indices( + self._decoder, stream_index=self.stream_index + ) def get_frame_at(self, index: int) -> Frame: """Return a single frame at the given index. diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index eb0f2beb4..7f04e6a4b 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -835,7 +835,9 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): def test_get_key_frame_indices(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") key_frame_indices = decoder._get_key_frame_indices() - assert len(key_frame_indices) > 0 + size = key_frame_indices.size() + assert size[0] > 0 + assert len(size) == 1 if __name__ == "__main__": From ef644f97b68b77b3e3390a0f9d1c606d539536f8 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 28 Jan 2025 12:56:35 -0800 Subject: [PATCH 5/8] Create key frame index manually --- .../decoders/_core/VideoDecoder.cpp | 10 +++++++--- src/torchcodec/decoders/_core/VideoDecoder.h | 17 +++++++++++----- test/decoders/test_video_decoder.py | 20 ++++++++++++++++--- 3 files changed, 36 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index fcfa9e8ad..7f1020cf0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -561,9 +561,7 @@ torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) { torch::Tensor keyFrameIndices = torch::empty({static_cast(keyFrames.size())}, {torch::kInt64}); for (size_t i = 0; i < keyFrames.size(); ++i) { - int64_t pts = keyFrames[i].pts; - keyFrameIndices[i] = - getKeyFrameIndexForPtsUsingScannedIndex(keyFrames, pts); + keyFrameIndices[i] = keyFrames[i].frameIndex; } return keyFrameIndices; @@ -685,7 +683,13 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { return frameInfo1.pts < frameInfo2.pts; }); + size_t keyIndex = 0; for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { + streamInfo.allFrames[i].frameIndex = i; + if (streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) { + streamInfo.keyFrames[keyIndex].frameIndex = i; + ++keyIndex; + } if (i + 1 < streamInfo.allFrames.size()) { streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 249c4ec5c..53428ca45 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -291,12 +291,19 @@ class VideoDecoder { struct FrameInfo { int64_t pts = 0; - // The value of this default is important: the last frame's nextPts will be - // INT64_MAX, which ensures that the allFrames vec contains FrameInfo - // structs with *increasing* nextPts values. That's a necessary condition - // for the binary searches on those values to work properly (as typically - // done during pts -> index conversions.) + + // The value of the nextPts default is important: the last frame's nextPts + // will be INT64_MAX, which ensures that the allFrames vec contains + // FrameInfo structs with *increasing* nextPts values. That's a necessary + // condition for the binary searches on those values to work properly (as + // typically done during pts -> index conversions). int64_t nextPts = INT64_MAX; + + // Note that frameIndex is ALWAYS the index into all of the frames in that + // stream, even when the FrameInfo is part of the key frame index. Given a + // FrameInfo for a key frame, the frameIndex allows us to know which frame + // that is in the stream. + int64_t frameIndex = 0; }; struct FilterGraphContext { diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 7f04e6a4b..ed2645e1d 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -835,9 +835,23 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): def test_get_key_frame_indices(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") key_frame_indices = decoder._get_key_frame_indices() - size = key_frame_indices.size() - assert size[0] > 0 - assert len(size) == 1 + + # The key frame indices were generated from the following command: + # $ ffprobe -v error -hide_banner -select_streams v:1 -show_frames -of csv test/resources/nasa_13013.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt + # What it's doing: + # 1. Calling ffprobe on the second video stream, which is absolute stream index 3. + # 2. Showing all frames for that stream. + # 3. Using grep to find and count the "I" frames, which are the key frames. + # 4. Using cut to extract just the count for the frame. + # Finally, because the above produces a count, which is index + 1, we subtract + # one from all values manually to arrive at the values below. + # TODO: decide if/how we want to incorporate key frame indices into the utils + # framework. + reference_key_frame_indices = torch.tensor([0, 240]) + + torch.testing.assert_close( + key_frame_indices, reference_key_frame_indices, atol=0, rtol=0 + ) if __name__ == "__main__": From 09ada7ceedbf52db9293d676f95d5857b45968ef Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 28 Jan 2025 13:04:34 -0800 Subject: [PATCH 6/8] Fix return type --- src/torchcodec/decoders/_core/video_decoder_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index d04b49496..66f5dd9a9 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -259,8 +259,8 @@ def get_frames_by_pts_in_range_abstract( @register_fake("torchcodec_ns::_get_key_frame_indices") def get_key_frame_indices_abstract( decoder: torch.Tensor, *, stream_index: int -) -> List[int]: - return [] +) -> torch.Tensor: + return torch.empty([], dtype=torch.int) @register_fake("torchcodec_ns::get_json_metadata") From d80e795dfd619d7cdbbbe0b56d65e31bcd6ca089 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 29 Jan 2025 10:59:08 -0800 Subject: [PATCH 7/8] More defensive init and more testing --- .../decoders/_core/VideoDecoder.cpp | 10 ++++++- test/decoders/test_video_decoder.py | 26 ++++++++++++++++--- 2 files changed, 32 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 42ed5592d..95ea88e8d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -671,10 +671,18 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { size_t keyIndex = 0; for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { streamInfo.allFrames[i].frameIndex = i; - if (streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) { + + // For correctly encoded files, we shouldn't need to ensure that keyIndex + // is less than the number of key frames. That is, the relationship + // between the frames in allFrames and keyFrames should be such that + // keyIndex is always a valid index into keyFrames. But we're being + // defensive in case we encounter incorrectly encoded files. + if (keyIndex < streamInfo.keyFrames.size() && + streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) { streamInfo.keyFrames[keyIndex].frameIndex = i; ++keyIndex; } + if (i + 1 < streamInfo.allFrames.size()) { streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; } diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index ed2645e1d..6e6ceafb4 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -841,18 +841,38 @@ def test_get_key_frame_indices(self, device): # What it's doing: # 1. Calling ffprobe on the second video stream, which is absolute stream index 3. # 2. Showing all frames for that stream. - # 3. Using grep to find and count the "I" frames, which are the key frames. + # 3. Using grep to find the "I" frames, which are the key frames. We also get the line + # number, which is also the count of the rames. # 4. Using cut to extract just the count for the frame. # Finally, because the above produces a count, which is index + 1, we subtract # one from all values manually to arrive at the values below. # TODO: decide if/how we want to incorporate key frame indices into the utils # framework. - reference_key_frame_indices = torch.tensor([0, 240]) + nasa_reference_key_frame_indices = torch.tensor([0, 240]) torch.testing.assert_close( - key_frame_indices, reference_key_frame_indices, atol=0, rtol=0 + key_frame_indices, nasa_reference_key_frame_indices, atol=0, rtol=0 ) + decoder = VideoDecoder(AV1_VIDEO.path, device=device, seek_mode="exact") + key_frame_indices = decoder._get_key_frame_indices() + + # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt + av1_reference_key_frame_indices = torch.tensor([0]) + + torch.testing.assert_close( + key_frame_indices, av1_reference_key_frame_indices, atol=0, rtol=0 + ) + + decoder = VideoDecoder(H265_VIDEO.path, device=device, seek_mode="exact") + key_frame_indices = decoder._get_key_frame_indices() + + # ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt + h265_reference_key_frame_indices = torch.tensor([0, 2, 4, 6, 8]) + + torch.testing.assert_close( + key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 + ) if __name__ == "__main__": pytest.main() From 1c06f1e35a35f2b90a174b907b14d17a4f248388 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 29 Jan 2025 10:59:41 -0800 Subject: [PATCH 8/8] Formatting --- test/decoders/test_video_decoder.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 6e6ceafb4..d836fd8ff 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -874,5 +874,6 @@ def test_get_key_frame_indices(self, device): key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 ) + if __name__ == "__main__": pytest.main()