diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 09c451182..eebbfcef5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -661,9 +661,10 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) { StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double frameStartTime = - ptsToSeconds(streamInfo.currentPts, streamInfo.timeBase); + ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); double frameEndTime = ptsToSeconds( - streamInfo.currentPts + streamInfo.currentDuration, streamInfo.timeBase); + streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration, + streamInfo.timeBase); if (seconds >= frameStartTime && seconds < frameEndTime) { // We are in the same frame as the one we just returned. However, since we // don't cache it locally, we have to rewind back. @@ -824,8 +825,8 @@ void VideoDecoder::setCursorPtsInSeconds(double seconds) { Videos have I frames and non-I frames (P and B frames). Non-I frames need data from the previous I frame to be decoded. -Imagine the cursor is at a random frame with PTS=x and we wish to seek to a -user-specified PTS=y. +Imagine the cursor is at a random frame with PTS=lastDecodedAvFramePts (x for +brevity) and we wish to seek to a user-specified PTS=y. If y < x, we don't have a choice but to seek backwards to the highest I frame before y. @@ -845,13 +846,14 @@ I P P P I P P P I P P I P P I P (2) is more efficient than (1) if there is an I frame between x and y. */ -bool VideoDecoder::canWeAvoidSeeking(int64_t currentPts, int64_t targetPts) - const { - if (targetPts < currentPts) { +bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const { + int64_t lastDecodedAvFramePts = + streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; + if (targetPts < lastDecodedAvFramePts) { // We can never skip a seek if we are seeking backwards. return false; } - if (currentPts == targetPts) { + if (lastDecodedAvFramePts == targetPts) { // We are seeking to the exact same frame as we are currently at. Without // caching we have to rewind back and decode the frame again. // TODO: https://github.com/pytorch-labs/torchcodec/issues/84 we could @@ -859,9 +861,9 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t currentPts, int64_t targetPts) return false; } // We are seeking forwards. - // We can only skip a seek if both currentPts and targetPts share the same - // keyframe. - int currentKeyFrameIndex = getKeyFrameIndexForPts(currentPts); + // We can only skip a seek if both lastDecodedAvFramePts and targetPts share + // the same keyframe. + int currentKeyFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts); int targetKeyFrameIndex = getKeyFrameIndexForPts(targetPts); return currentKeyFrameIndex >= 0 && targetKeyFrameIndex >= 0 && currentKeyFrameIndex == targetKeyFrameIndex; @@ -879,7 +881,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { decodeStats_.numSeeksAttempted++; int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den; - if (canWeAvoidSeeking(streamInfo.currentPts, desiredPtsForStream)) { + if (canWeAvoidSeeking(desiredPtsForStream)) { decodeStats_.numSeeksSkipped++; return; } @@ -1032,8 +1034,8 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame( // haven't received as frames. Eventually we will either hit AVERROR_EOF from // av_receive_frame() or the user will have seeked to a different location in // the file and that will flush the decoder. - streamInfo.currentPts = avFrame->pts; - streamInfo.currentDuration = getDuration(avFrame); + streamInfo.lastDecodedAvFramePts = avFrame->pts; + streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame); return AVFrameStream(std::move(avFrame), activeStreamIndex_); } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 11f11a4ce..8eae2a76a 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -332,8 +332,8 @@ class VideoDecoder { // The current position of the cursor in the stream, and associated frame // duration. - int64_t currentPts = 0; - int64_t currentDuration = 0; + int64_t lastDecodedAvFramePts = 0; + int64_t lastDecodedAvFrameDuration = 0; // The desired position of the cursor in the stream. We send frames >= // this pts to the user when they request a frame. // We update this field if the user requested a seek. This typically @@ -361,7 +361,7 @@ class VideoDecoder { // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- - bool canWeAvoidSeeking(int64_t currentPts, int64_t targetPts) const; + bool canWeAvoidSeeking(int64_t targetPts) const; void maybeSeekToBeforeDesiredPts();