diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 8375753ae..c8870dfb1 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1092,13 +1092,6 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const { // Returns true if we can avoid seeking in the AVFormatContext based on // heuristics that rely on the target cursor_ and the last decoded frame. // Seeking is expensive, so we try to avoid it when possible. - // Note that this function itself isn't always that cheap to call: in - // particular the calls to getKeyFrameIndexForPts below in approximate mode - // are sometimes slow. - // TODO we should understand why (is it because it reads the file?) and - // potentially optimize it. E.g. we may not want to ever seek, or even *check* - // if we need to seek in some cases, like if we're going to decode 80% of the - // frames anyway. const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { // For audio, we only need to seek if a backwards seek was requested @@ -1145,10 +1138,10 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const { // I P P P I P P P I P P I P // x j y // (2) is only more efficient than (1) if there is an I frame between x and y. - int lastKeyFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_); - int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_); - return lastKeyFrameIndex >= 0 && targetKeyFrameIndex >= 0 && - lastKeyFrameIndex == targetKeyFrameIndex; + int lastKeyFrame = getKeyFrameIdentifier(lastDecodedAvFramePts_); + int targetKeyFrame = getKeyFrameIdentifier(cursor_); + return lastKeyFrame >= 0 && targetKeyFrame >= 0 && + lastKeyFrame == targetKeyFrame; } // This method looks at currentPts and desiredPts and seeks in the @@ -1365,7 +1358,19 @@ torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- -int SingleStreamDecoder::getKeyFrameIndexForPts(int64_t pts) const { +int SingleStreamDecoder::getKeyFrameIdentifier(int64_t pts) const { + // This function "identifies" a key frame for a given pts value. + // We use the term "identifier" rather than "index" because the nature of the + // index that is returned depends on various factors: + // - If seek_mode is exact, we return the index of the key frame in the + // scanned key-frame vector (streamInfo.keyFrames). So the returned value is + // in [0, num_key_frames). + // - If seek_mode is approximate, we use av_index_search_timestamp() which + // may return a value in [0, num_key_frames) like for mkv, but also a value + // in [0, num_frames) like for mp4. It really depends on the container. + // + // The range of the "identifier" doesn't matter that much, for now we only + // use it to uniquely identify a key frame in canWeAvoidSeeking(). const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.keyFrames.empty()) { return av_index_search_timestamp( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 305a9c2a6..cbb9cbc2f 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -282,7 +282,7 @@ class SingleStreamDecoder { // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- - int getKeyFrameIndexForPts(int64_t pts) const; + int getKeyFrameIdentifier(int64_t pts) const; // Returns the key frame index of the presentation timestamp using our index. // We build this index by scanning the file in