diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index add9c9bee..8c9f43635 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1125,12 +1125,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); - // If the frame index is larger than the size of allFrames, that means we - // couldn't match the pts value to the pts value of a NEXT FRAME. And - // that means that this timestamp falls during the time between when the - // last frame is displayed, and the video ends. Hence, it should map to the - // index of the last frame. - frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index c0f489cef..ea122b54a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -299,7 +299,12 @@ class VideoDecoder { private: struct FrameInfo { int64_t pts = 0; - int64_t nextPts = 0; + // The value of this default is important: the last frame's nextPts will be + // INT64_MAX, which ensures that the allFrames vec contains FrameInfo + // structs with *increasing* nextPts values. That's a necessary condition + // for the binary searches on those values to work properly (as typically + // done during pts -> index conversions.) + int64_t nextPts = INT64_MAX; }; struct FilterState { UniqueAVFilterGraph filterGraph; diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ed681469..1e2b1a96f 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -186,6 +186,66 @@ def test_get_frames_by_pts(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + def test_pts_apis_against_index_ref(self): + # Non-regression test for https://github.com/pytorch/torchcodec/pull/287 + # Get all frames in the video, then query all frames with all time-based + # APIs exactly where those frames are supposed to start. We assert that + # we get the expected frame. + decoder = create_from_file(str(NASA_VIDEO.path)) + scan_all_streams_to_update_metadata(decoder) + add_video_stream(decoder) + + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + num_frames = metadata_dict["numFrames"] + assert num_frames == 390 + + stream_index = 3 + _, all_pts_seconds_ref, _ = zip( + *[ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + ) + for frame_index in range(num_frames) + ] + ) + all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref) + + assert len(all_pts_seconds_ref.unique() == len(all_pts_seconds_ref)) + + _, pts_seconds, _ = zip( + *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=0, + stop_seconds=all_pts_seconds_ref[-1] + 1e-4, + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = zip( + *[ + get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=pts, + stop_seconds=pts + 1e-4, + ) + for pts in all_pts_seconds_ref + ] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts( + decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist() + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder)