From 06ed6be84c3d8aca764db8f78d50d08de2146618 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 15:18:34 +0100 Subject: [PATCH 1/3] Alternative fix --- .../decoders/_core/VideoDecoder.cpp | 6 -- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- test/decoders/test_video_decoder_ops.py | 60 +++++++++++++++++++ 3 files changed, 61 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index add9c9bee..8c9f43635 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1125,12 +1125,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); - // If the frame index is larger than the size of allFrames, that means we - // couldn't match the pts value to the pts value of a NEXT FRAME. And - // that means that this timestamp falls during the time between when the - // last frame is displayed, and the video ends. Hence, it should map to the - // index of the last frame. - frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index c0f489cef..49acfb1d8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -299,7 +299,7 @@ class VideoDecoder { private: struct FrameInfo { int64_t pts = 0; - int64_t nextPts = 0; + int64_t nextPts = INT64_MAX; }; struct FilterState { UniqueAVFilterGraph filterGraph; diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ed681469..da8a9541b 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -186,6 +186,66 @@ def test_get_frames_by_pts(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + def test_pts_apis_against_index_ref(self): + # Non-regression test for https://github.com/pytorch/torchcodec/pull/286 + # Get all frames in the video, then query all frames with all time-based + # APIs exactly where those frames are supposed to start. We assert that + # we get the expected frame. + decoder = create_from_file(str(NASA_VIDEO.path)) + scan_all_streams_to_update_metadata(decoder) + add_video_stream(decoder) + + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + num_frames = metadata_dict["numFrames"] + assert num_frames == 390 + + stream_index = 3 + _, all_pts_seconds_ref, _ = zip( + *[ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + ) + for frame_index in range(num_frames) + ] + ) + all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref) + + assert len(all_pts_seconds_ref.unique() == len(all_pts_seconds_ref)) + + _, pts_seconds, _ = zip( + *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=0, + stop_seconds=all_pts_seconds_ref[-1] + 1e-4, + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = zip( + *[ + get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=pts, + stop_seconds=pts + 1e-4, + ) + for pts in all_pts_seconds_ref + ] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts( + decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist() + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder) From 98a675871273ab81908d2171315d1b10cf61a0a6 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 15:19:32 +0100 Subject: [PATCH 2/3] nit --- test/decoders/test_video_decoder_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index da8a9541b..1e2b1a96f 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -187,7 +187,7 @@ def test_get_frames_by_pts(self): assert_tensor_equal(frames[0], frames[-1]) def test_pts_apis_against_index_ref(self): - # Non-regression test for https://github.com/pytorch/torchcodec/pull/286 + # Non-regression test for https://github.com/pytorch/torchcodec/pull/287 # Get all frames in the video, then query all frames with all time-based # APIs exactly where those frames are supposed to start. We assert that # we get the expected frame. From 7d5972eb009a5e2df277ba0955f57edc190fe19f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 24 Oct 2024 15:32:06 +0100 Subject: [PATCH 3/3] Comment --- src/torchcodec/decoders/_core/VideoDecoder.h | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 49acfb1d8..ea122b54a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -299,6 +299,11 @@ class VideoDecoder { private: struct FrameInfo { int64_t pts = 0; + // The value of this default is important: the last frame's nextPts will be + // INT64_MAX, which ensures that the allFrames vec contains FrameInfo + // structs with *increasing* nextPts values. That's a necessary condition + // for the binary searches on those values to work properly (as typically + // done during pts -> index conversions.) int64_t nextPts = INT64_MAX; }; struct FilterState {