From c03294bbf95486f3f12bc9bb6746f26ddb3a54da Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 10:36:52 +0100 Subject: [PATCH 1/4] Fix binary search of getFramesDisplayedByTimestamps --- .../decoders/_core/VideoDecoder.cpp | 9 +--- test/decoders/test_video_decoder_ops.py | 45 +++++++++++++++++++ 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index add9c9bee..9243365c5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,21 +1119,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - stream.allFrames.end(), + stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double framePts) { return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); - // If the frame index is larger than the size of allFrames, that means we - // couldn't match the pts value to the pts value of a NEXT FRAME. And - // that means that this timestamp falls during the time between when the - // last frame is displayed, and the video ends. Hence, it should map to the - // index of the last frame. - frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } - return getFramesAtIndices(streamIndex, frameIndices); } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ed681469..5ebd7830e 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -186,6 +186,51 @@ def test_get_frames_by_pts(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + def test_pts_apis_against_index_ref(self): + # Get all frames in the video, then query all frames with all time-based + # APIs exactly where those frames are supposed to start. We assert that + # we get the expected frame. + decoder = create_from_file(str(NASA_VIDEO.path)) + scan_all_streams_to_update_metadata(decoder) + add_video_stream(decoder) + + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + num_frames = metadata_dict["numFrames"] + assert num_frames == 390 + + stream_index = 3 + _, all_pts_seconds_ref, _ = zip( + *[ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + ) + for frame_index in range(num_frames) + ] + ) + all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref) + + assert len(all_pts_seconds_ref.unique() == len(all_pts_seconds_ref)) + + _, pts_seconds, _ = zip( + *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=0, + stop_seconds=all_pts_seconds_ref[-1] + 1e-4, + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts( + decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist() + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From 5ab33b982d582ce5a8bd2b795a7eae69d396b1e6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 10:39:07 +0100 Subject: [PATCH 2/4] Comment --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 9243365c5..97d55c067 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,6 +1119,8 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), + // See https://github.com/pytorch/torchcodec/pull/286 for why the `- 1` + // is needed. stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double framePts) { From fa374bc0709baa7dc993de69e80ba8626cb438f6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 11:13:40 +0100 Subject: [PATCH 3/4] comment --- test/decoders/test_video_decoder_ops.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 5ebd7830e..6ad774b5d 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -187,6 +187,7 @@ def test_get_frames_by_pts(self): assert_tensor_equal(frames[0], frames[-1]) def test_pts_apis_against_index_ref(self): + # Non-regression test for https://github.com/pytorch/torchcodec/pull/286 # Get all frames in the video, then query all frames with all time-based # APIs exactly where those frames are supposed to start. We assert that # we get the expected frame. From 448c4a60ee3ad436b3e6c30bbacc9c9053611970 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 11:53:28 +0100 Subject: [PATCH 4/4] More inline comment --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 97d55c067..997269fe7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,8 +1119,13 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - // See https://github.com/pytorch/torchcodec/pull/286 for why the `- 1` - // is needed. + // We have to end the search at end() - 1 to exclude the last frame from + // the search: the last frame's nextPts field is 0, which breaks the + // working assumption of std::lower_bound() that the search space must + // be sorted. The last frame can still be correctly returned: when the + // binary search ends without a match, `end() - 1` will be returned, and + // that corresponds to the last frame. + // See https://github.com/pytorch/torchcodec/pull/286 for more details. stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double framePts) {