diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 8366fddd0..4085e90b8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -608,6 +608,7 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { // we have scanned all packets and sorted by pts. FrameInfo frameInfo = {packet->pts}; if (packet->flags & AV_PKT_FLAG_KEY) { + frameInfo.isKeyFrame = true; streamInfos_[streamIndex].keyFrames.push_back(frameInfo); } streamInfos_[streamIndex].allFrames.push_back(frameInfo); @@ -658,25 +659,23 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { return frameInfo1.pts < frameInfo2.pts; }); - size_t keyIndex = 0; + size_t keyFrameIndex = 0; for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) { streamInfo.allFrames[i].frameIndex = i; - - // For correctly encoded files, we shouldn't need to ensure that keyIndex - // is less than the number of key frames. That is, the relationship - // between the frames in allFrames and keyFrames should be such that - // keyIndex is always a valid index into keyFrames. But we're being - // defensive in case we encounter incorrectly encoded files. - if (keyIndex < streamInfo.keyFrames.size() && - streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) { - streamInfo.keyFrames[keyIndex].frameIndex = i; - ++keyIndex; + if (streamInfo.allFrames[i].isKeyFrame) { + TORCH_CHECK( + keyFrameIndex < streamInfo.keyFrames.size(), + "The allFrames vec claims it has MORE keyFrames than the keyFrames vec. There's a bug in torchcodec."); + streamInfo.keyFrames[keyFrameIndex].frameIndex = i; + ++keyFrameIndex; } - if (i + 1 < streamInfo.allFrames.size()) { streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts; } } + TORCH_CHECK( + keyFrameIndex == streamInfo.keyFrames.size(), + "The allFrames vec claims it has LESS keyFrames than the keyFrames vec. There's a bug in torchcodec."); } scannedAllStreams_ = true; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 81d7ec9cd..cc2ab57ad 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -294,6 +294,8 @@ class VideoDecoder { // FrameInfo structs with *increasing* nextPts values. That's a necessary // condition for the binary searches on those values to work properly (as // typically done during pts -> index conversions). + // TODO: This field is unset (left to the default) for entries in the + // keyFrames vec! int64_t nextPts = INT64_MAX; // Note that frameIndex is ALWAYS the index into all of the frames in that @@ -301,6 +303,11 @@ class VideoDecoder { // FrameInfo for a key frame, the frameIndex allows us to know which frame // that is in the stream. int64_t frameIndex = 0; + + // Indicates whether a frame is a key frame. It may appear redundant as it's + // only true for FrameInfos in the keyFrames index, but it is needed to + // correctly map frames between allFrames and keyFrames during the scan. + bool isKeyFrame = false; }; struct FilterGraphContext {