Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1119,21 +1119,21 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesDisplayedByTimestamps(

auto it = std::lower_bound(
stream.allFrames.begin(),
stream.allFrames.end(),
// We have to end the search at end() - 1 to exclude the last frame from
// the search: the last frame's nextPts field is 0, which breaks the
// working assumption of std::lower_bound() that the search space must
// be sorted. The last frame can still be correctly returned: when the
// binary search ends without a match, `end() - 1` will be returned, and
// that corresponds to the last frame.
// See https://github.com/pytorch/torchcodec/pull/286 for more details.
stream.allFrames.end() - 1,
framePts,
[&stream](const FrameInfo& info, double framePts) {
return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts;
});
int64_t frameIndex = it - stream.allFrames.begin();
// If the frame index is larger than the size of allFrames, that means we
// couldn't match the pts value to the pts value of a NEXT FRAME. And
// that means that this timestamp falls during the time between when the
// last frame is displayed, and the video ends. Hence, it should map to the
// index of the last frame.
frameIndex = std::min(frameIndex, (int64_t)stream.allFrames.size() - 1);
frameIndices[i] = frameIndex;
}

return getFramesAtIndices(streamIndex, frameIndices);
}

Expand Down
46 changes: 46 additions & 0 deletions test/decoders/test_video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,52 @@ def test_get_frames_by_pts(self):
with pytest.raises(AssertionError):
assert_tensor_equal(frames[0], frames[-1])

def test_pts_apis_against_index_ref(self):
    """Check that every time-based API agrees with index-based decoding.

    Non-regression test for https://github.com/pytorch/torchcodec/pull/286

    All frames of the video are first retrieved by index to build a
    reference list of pts values. Each time-based API is then queried
    exactly where those frames are supposed to start, and must report
    the same pts values as the reference.
    """
    decoder = create_from_file(str(NASA_VIDEO.path))
    scan_all_streams_to_update_metadata(decoder)
    add_video_stream(decoder)

    metadata = get_json_metadata(decoder)
    metadata_dict = json.loads(metadata)
    num_frames = metadata_dict["numFrames"]
    assert num_frames == 390

    # Reference: the pts of every frame, obtained through index-based access.
    stream_index = 3
    _, all_pts_seconds_ref, _ = zip(
        *[
            get_frame_at_index(
                decoder, stream_index=stream_index, frame_index=frame_index
            )
            for frame_index in range(num_frames)
        ]
    )
    all_pts_seconds_ref = torch.tensor(all_pts_seconds_ref)

    # Sanity check: every frame must have a distinct pts, otherwise the
    # pts-based lookups below could not be compared one-to-one against the
    # index-based reference. Note that both len() calls must be compared
    # with each other; the comparison must not be nested inside len().
    assert len(all_pts_seconds_ref.unique()) == len(all_pts_seconds_ref)

    # Single-frame lookup: one query per reference pts.
    _, pts_seconds, _ = zip(
        *[get_frame_at_pts(decoder, seconds=pts) for pts in all_pts_seconds_ref]
    )
    pts_seconds = torch.tensor(pts_seconds)
    assert_tensor_equal(pts_seconds, all_pts_seconds_ref)

    # Range lookup: stop slightly after the last reference pts so that the
    # last frame is part of the requested range.
    _, pts_seconds, _ = get_frames_by_pts_in_range(
        decoder,
        stream_index=stream_index,
        start_seconds=0,
        stop_seconds=all_pts_seconds_ref[-1] + 1e-4,
    )
    assert_tensor_equal(pts_seconds, all_pts_seconds_ref)

    # Batch lookup: all reference timestamps in a single call.
    _, pts_seconds, _ = get_frames_by_pts(
        decoder, stream_index=stream_index, timestamps=all_pts_seconds_ref.tolist()
    )
    assert_tensor_equal(pts_seconds, all_pts_seconds_ref)
Comment on lines +230 to +233
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This fails on main: we return the last frame instead of the second-to-last frame, due to the logic explained in the main comment.


def test_get_frames_in_range(self):
decoder = create_from_file(str(NASA_VIDEO.path))
scan_all_streams_to_update_metadata(decoder)
Expand Down
Loading