diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index b773e0188..95ea88e8d 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -538,6 +538,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const { return containerMetadata_; } +torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) { + validateUserProvidedStreamIndex(streamIndex); + validateScannedAllStreams("getKeyFrameIndices"); + + const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames; + torch::Tensor keyFrameIndices = + torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64}); + for (size_t i = 0; i < keyFrames.size(); ++i) { + keyFrameIndices[i] = keyFrames[i].frameIndex; + } + + return keyFrameIndices; +} + int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex( AVStream* stream, int64_t pts) const { @@ -654,7 +668,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { return frameInfo1.pts < frameInfo2.pts; }); + size_t keyIndex = 0; for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { + streamInfo.allFrames[i].frameIndex = i; + + // For correctly encoded files, we shouldn't need to ensure that keyIndex + // is less than the number of key frames. That is, the relationship + // between the frames in allFrames and keyFrames should be such that + // keyIndex is always a valid index into keyFrames. But we're being + // defensive in case we encounter incorrectly encoded files. 
+ if (keyIndex < streamInfo.keyFrames.size() && + streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) { + streamInfo.keyFrames[keyIndex].frameIndex = i; + ++keyIndex; + } + if (i + 1 < streamInfo.allFrames.size()) { streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 0d4bfb1c7..2f7372895 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -97,6 +97,10 @@ class VideoDecoder { // Returns the metadata for the container. ContainerMetadata getContainerMetadata() const; + // Returns the key frame indices as a tensor. The tensor is 1D and contains + // int64 values, where each value is the frame index for a key frame. + torch::Tensor getKeyFrameIndices(int streamIndex); + // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- @@ -284,12 +288,19 @@ class VideoDecoder { struct FrameInfo { int64_t pts = 0; - // The value of this default is important: the last frame's nextPts will be - // INT64_MAX, which ensures that the allFrames vec contains FrameInfo - // structs with *increasing* nextPts values. That's a necessary condition - // for the binary searches on those values to work properly (as typically - // done during pts -> index conversions.) + + // The value of the nextPts default is important: the last frame's nextPts + // will be INT64_MAX, which ensures that the allFrames vec contains + // FrameInfo structs with *increasing* nextPts values. That's a necessary + // condition for the binary searches on those values to work properly (as + // typically done during pts -> index conversions). 
int64_t nextPts = INT64_MAX; + + // Note that frameIndex is ALWAYS the index into all of the frames in that + // stream, even when the FrameInfo is part of the key frame index. Given a + // FrameInfo for a key frame, the frameIndex allows us to know which frame + // that is in the stream. + int64_t frameIndex = 0; }; struct FilterGraphContext { diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 05bc903bf..78ecc4258 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -48,6 +48,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)"); + m.def( + "_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor"); m.def("get_json_metadata(Tensor(a!) decoder) -> str"); m.def("get_container_json_metadata(Tensor(a!) 
decoder) -> str"); m.def( @@ -334,6 +336,13 @@ bool _test_frame_pts_equality( videoDecoder->getPtsSecondsForFrame(stream_index, frame_index); } +torch::Tensor _get_key_frame_indices( + at::Tensor& decoder, + int64_t stream_index) { + auto videoDecoder = unwrapTensorToGetDecoder(decoder); + return videoDecoder->getKeyFrameIndices(stream_index); +} + std::string get_json_metadata(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); @@ -526,6 +535,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); m.impl("get_next_frame", &get_next_frame); + m.impl("_get_key_frame_indices", &_get_key_frame_indices); m.impl("get_json_metadata", &get_json_metadata); m.impl("get_container_json_metadata", &get_container_json_metadata); m.impl("get_stream_json_metadata", &get_stream_json_metadata); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 5b25e7f69..241d80983 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -137,6 +137,8 @@ bool _test_frame_pts_equality( int64_t frame_index, double pts_seconds_to_test); +torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index); + // Get the metadata from the video as a string. 
std::string get_json_metadata(at::Tensor& decoder); diff --git a/src/torchcodec/decoders/_core/__init__.py b/src/torchcodec/decoders/_core/__init__.py index a1ac9a478..d39d3d237 100644 --- a/src/torchcodec/decoders/_core/__init__.py +++ b/src/torchcodec/decoders/_core/__init__.py @@ -13,6 +13,7 @@ ) from .video_decoder_ops import ( _add_video_stream, + _get_key_frame_indices, _test_frame_pts_equality, add_video_stream, create_from_bytes, diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 16e633921..66f5dd9a9 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -82,6 +82,7 @@ def load_torchcodec_extension(): _get_container_json_metadata = ( torch.ops.torchcodec_ns.get_container_json_metadata.default ) +_get_key_frame_indices = torch.ops.torchcodec_ns._get_key_frame_indices.default scan_all_streams_to_update_metadata = ( torch.ops.torchcodec_ns.scan_all_streams_to_update_metadata.default ) @@ -255,6 +256,13 @@ def get_frames_by_pts_in_range_abstract( ) +@register_fake("torchcodec_ns::_get_key_frame_indices") +def get_key_frame_indices_abstract( + decoder: torch.Tensor, stream_index: int +) -> torch.Tensor: + return torch.empty([0], dtype=torch.int64) + + @register_fake("torchcodec_ns::get_json_metadata") def get_json_metadata_abstract(decoder: torch.Tensor) -> str: return "" diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 46279135d..3c5367ab8 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -185,6 +185,11 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor: f"Unsupported key type: {type(key)}. Supported types are int and slice." 
) + def _get_key_frame_indices(self) -> Tensor: + return core._get_key_frame_indices( + self._decoder, stream_index=self.stream_index + ) + def get_frame_at(self, index: int) -> Frame: """Return a single frame at the given index. diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index c05b33d9e..d836fd8ff 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -831,6 +831,49 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): with pytest.raises(ValueError, match="Invalid stop seconds"): frame = decoder.get_frames_played_in_range(0, 23) # noqa + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_key_frame_indices(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact") + key_frame_indices = decoder._get_key_frame_indices() + + # The key frame indices were generated from the following command: + # $ ffprobe -v error -hide_banner -select_streams v:1 -show_frames -of csv test/resources/nasa_13013.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt + # What it's doing: + # 1. Calling ffprobe on the second video stream, which is absolute stream index 3. + # 2. Showing all frames for that stream. + # 3. Using grep to find the "I" frames, which are the key frames. We also get the line + # number, which is also the count of the frames. + # 4. Using cut to extract just the count for the frame. + # Finally, because the above produces a count, which is index + 1, we subtract + # one from all values manually to arrive at the values below. + # TODO: decide if/how we want to incorporate key frame indices into the utils + # framework. 
+ nasa_reference_key_frame_indices = torch.tensor([0, 240]) + + torch.testing.assert_close( + key_frame_indices, nasa_reference_key_frame_indices, atol=0, rtol=0 + ) + + decoder = VideoDecoder(AV1_VIDEO.path, device=device, seek_mode="exact") + key_frame_indices = decoder._get_key_frame_indices() + + # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt + av1_reference_key_frame_indices = torch.tensor([0]) + + torch.testing.assert_close( + key_frame_indices, av1_reference_key_frame_indices, atol=0, rtol=0 + ) + + decoder = VideoDecoder(H265_VIDEO.path, device=device, seek_mode="exact") + key_frame_indices = decoder._get_key_frame_indices() + + # ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt + h265_reference_key_frame_indices = torch.tensor([0, 2, 4, 6, 8]) + + torch.testing.assert_close( + key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 + ) + if __name__ == "__main__": pytest.main()