diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index add9c9bee..997269fe7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1119,21 +1119,21 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps( auto it = std::lower_bound( stream.allFrames.begin(), - stream.allFrames.end(), + // We have to end the search at end() - 1 to exclude the last frame from + // the search: the last frame's nextPts field is 0, which breaks the + // working assumption of std::lower_bound() that the search space must + // be sorted. The last frame can still be correctly returned: when the + // binary search ends without a match, `end() - 1` will be returned, and + // that corresponds to the last frame. + // See https://github.com/pytorch/torchcodec/pull/286 for more details. + stream.allFrames.end() - 1, framePts, [&stream](const FrameInfo& info, double framePts) { return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; }); int64_t frameIndex = it - stream.allFrames.begin(); - // If the frame index is larger than the size of allFrames, that means we - // couldn't match the pts value to the pts value of a NEXT FRAME. And - // that means that this timestamp falls during the time between when the - // last frame is displayed, and the video ends. Hence, it should map to the - // index of the last frame. 
- frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1); frameIndices[i] = frameIndex; } - return getFramesAtIndices(streamIndex, frameIndices); } diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 0ed681469..6ad774b5d 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -186,6 +186,52 @@ def test_get_frames_by_pts(self): with pytest.raises(AssertionError): assert_tensor_equal(frames[0], frames[-1]) + def test_pts_apis_against_index_ref(self): + # Non-regression test for https://github.com/pytorch/torchcodec/pull/286 + # Get all frames in the video, then query all frames with all time-based + # APIs exactly where those frames are supposed to start. We assert that + # we get the expected frame. + decoder = create_from_file(str(NASA_VIDEO.path)) + scan_all_streams_to_update_metadata(decoder) + add_video_stream(decoder) + + metadata = get_json_metadata(decoder) + metadata_dict = json.loads(metadata) + num_frames = metadata_dict["numFrames"] + assert num_frames == 390 + + stream_index = 3 + _, all_pts_seconds_ref, _ = zip( + *[ + get_frame_at_index( + decoder, stream_index=stream_index, frame_index=frame_index + ) + for frame_index in range(num_frames) + ] + ) + all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref) + + assert len(all_pts_seconds_ref.unique()) == len(all_pts_seconds_ref) + + _, pts_seconds, _ = zip( + *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref] + ) + pts_seconds = torch.tensor(pts_seconds) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts_in_range( + decoder, + stream_index=stream_index, + start_seconds=0, + stop_seconds=all_pts_seconds_ref[-1] + 1e-4, + ) + assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + + _, pts_seconds, _ = get_frames_by_pts( + decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist() + ) + 
assert_tensor_equal(pts_seconds, all_pts_seconds_ref) + def test_get_frames_in_range(self): decoder = create_from_file(str(NASA_VIDEO.path)) scan_all_streams_to_update_metadata(decoder)