diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 6e7e72f27..97214cec1 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -567,10 +567,8 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {
 
 VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
-    StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
-    return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
-  });
+  AVFrameStream avFrameStream = decodeAVFrame(
+      [this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
   return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
 }
 
@@ -842,7 +840,9 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
 // --------------------------------------------------------------------------
 
 void VideoDecoder::setCursorPtsInSeconds(double seconds) {
-  desiredPtsSeconds_ = seconds;
+  cursorWasJustSet_ = true;
+  cursor_ =
+      secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);
 }
 
 /*
@@ -870,14 +870,14 @@ I P P P I P P P I P P I P P I P
 
 (2) is more efficient than (1) if there is an I frame between x and y.
 */
-bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
+bool VideoDecoder::canWeAvoidSeeking() const {
   int64_t lastDecodedAvFramePts =
       streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
-  if (targetPts < lastDecodedAvFramePts) {
+  if (cursor_ < lastDecodedAvFramePts) {
     // We can never skip a seek if we are seeking backwards.
     return false;
   }
-  if (lastDecodedAvFramePts == targetPts) {
+  if (lastDecodedAvFramePts == cursor_) {
     // We are seeking to the exact same frame as we are currently at. Without
     // caching we have to rewind back and decode the frame again.
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/84 we could
@@ -885,10 +885,10 @@ bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
     return false;
   }
   // We are seeking forwards.
-  // We can only skip a seek if both lastDecodedAvFramePts and targetPts share
-  // the same keyframe.
+  // We can only skip a seek if both lastDecodedAvFramePts and
+  // cursor_ share the same keyframe.
   int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
-  int targetKeyFrameIndex = getKeyFrameIndexForPts(targetPts);
+  int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_);
   return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
       lastDecodedAvFrameIndex == targetKeyFrameIndex;
 }
@@ -900,16 +900,14 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
   validateActiveStream(AVMEDIA_TYPE_VIDEO);
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
 
-  int64_t desiredPts =
-      secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
-  streamInfo.discardFramesBeforePts = desiredPts;
-
   decodeStats_.numSeeksAttempted++;
-  if (canWeAvoidSeeking(desiredPts)) {
+  if (canWeAvoidSeeking()) {
     decodeStats_.numSeeksSkipped++;
     return;
   }
 
+  int64_t desiredPts = cursor_;
+
   // For some encodings like H265, FFMPEG sometimes seeks past the point we
   // set as the max_ts. So we use our own index to give it the exact pts of
   // the key frame that we want to seek to.
@@ -948,10 +946,9 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(
 
   resetDecodeStats();
 
-  // Seek if needed.
-  if (desiredPtsSeconds_.has_value()) {
+  if (cursorWasJustSet_) {
     maybeSeekToBeforeDesiredPts();
-    desiredPtsSeconds_ = std::nullopt;
+    cursorWasJustSet_ = false;
   }
 
   StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index a28dcf9cb..a41ea50c2 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -332,15 +332,11 @@ class VideoDecoder {
     std::vector<FrameInfo> keyFrames;
     std::vector<FrameInfo> allFrames;
 
-    // The current position of the cursor in the stream, and associated frame
-    // duration.
+    // TODO since the decoder is single-stream, these should be decoder fields,
+    // not streamInfo fields. And they should be defined right next to
+    // `cursor_`, with joint documentation.
     int64_t lastDecodedAvFramePts = 0;
     int64_t lastDecodedAvFrameDuration = 0;
-    // The desired position of the cursor in the stream. We send frames >=
-    // this pts to the user when they request a frame.
-    // We update this field if the user requested a seek. This typically
-    // corresponds to the decoder's desiredPts_ attribute.
-    int64_t discardFramesBeforePts = INT64_MIN;
     VideoStreamOptions videoStreamOptions;
 
     // color-conversion fields. Only one of FilterGraphContext and
@@ -363,7 +359,7 @@ class VideoDecoder {
   // DECODING APIS AND RELATED UTILS
   // --------------------------------------------------------------------------
 
-  bool canWeAvoidSeeking(int64_t targetPts) const;
+  bool canWeAvoidSeeking() const;
 
   void maybeSeekToBeforeDesiredPts();
 
@@ -466,9 +462,11 @@ class VideoDecoder {
   std::map<int, StreamInfo> streamInfos_;
   const int NO_ACTIVE_STREAM = -2;
   int activeStreamIndex_ = NO_ACTIVE_STREAM;
-  // Set when the user wants to seek and stores the desired pts that the user
-  // wants to seek to.
-  std::optional<double> desiredPtsSeconds_;
+
+  bool cursorWasJustSet_ = false;
+  // The desired position of the cursor in the stream. We send frames >= this
+  // pts to the user when they request a frame.
+  int64_t cursor_ = INT64_MIN;
   // Stores various internal decoding stats.
   DecodeStats decodeStats_;
   // Stores the AVIOContext for the input buffer.
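
Note for reviewers: below is a minimal, self-contained sketch of the seeking flow this diff converges on, useful for sanity-checking the state machine outside of FFmpeg. Only the member names and the decision logic (cursor_, cursorWasJustSet_, lastDecodedAvFramePts, the pts >= cursor_ filter, the canWeAvoidSeeking() rules) are taken from the diff; ToyFrame, ToyDecoder, the in-memory frame list, and the fixed 1/10 time base are hypothetical scaffolding, not torchcodec code.

// Standalone sketch of the lazy-seek cursor flow (hypothetical ToyDecoder,
// not the real VideoDecoder). Compile with: g++ -std=c++17 toy_cursor.cpp
#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <utility>
#include <vector>

struct ToyFrame {
  int64_t pts;
  bool isKeyFrame;
};

class ToyDecoder {
 public:
  explicit ToyDecoder(std::vector<ToyFrame> frames) : frames_(std::move(frames)) {}

  // Mirrors setCursorPtsInSeconds(): eagerly convert seconds to a pts (here a
  // fixed time base of 1/10, standing in for secondsToClosestPts()) and mark
  // the cursor as just set. The seek itself is deferred to the next decode.
  void setCursorPtsInSeconds(double seconds) {
    cursorWasJustSet_ = true;
    cursor_ = static_cast<int64_t>(std::llround(seconds * 10));
  }

  // Mirrors the decodeAVFrame() flow: seek lazily only when we cannot avoid
  // it, then decode forward until a frame with pts >= cursor_ shows up.
  ToyFrame getNextFrame() {
    if (cursorWasJustSet_) {
      if (!canWeAvoidSeeking()) {
        seekToKeyFrameBefore(cursor_);
      }
      cursorWasJustSet_ = false;
    }
    while (pos_ < frames_.size()) {
      ToyFrame frame = frames_[pos_++];
      lastDecodedAvFramePts_ = frame.pts;
      if (frame.pts >= cursor_) {  // same filter as the lambda in the diff
        return frame;
      }
    }
    throw std::runtime_error("end of stream");
  }

 private:
  // Same rules as canWeAvoidSeeking() in the diff: never skip a backwards
  // seek or an exact re-seek; skip a forwards seek only when the last decoded
  // frame and cursor_ fall under the same key frame.
  bool canWeAvoidSeeking() const {
    if (cursor_ <= lastDecodedAvFramePts_) {
      return false;
    }
    int lastKeyFrameIndex = keyFrameIndexForPts(lastDecodedAvFramePts_);
    int targetKeyFrameIndex = keyFrameIndexForPts(cursor_);
    return lastKeyFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
        lastKeyFrameIndex == targetKeyFrameIndex;
  }

  // Index (into frames_) of the last key frame at or before pts, or -1.
  int keyFrameIndexForPts(int64_t pts) const {
    int index = -1;
    for (size_t i = 0; i < frames_.size(); ++i) {
      if (frames_[i].isKeyFrame && frames_[i].pts <= pts) {
        index = static_cast<int>(i);
      }
    }
    return index;
  }

  void seekToKeyFrameBefore(int64_t pts) {
    int keyIndex = keyFrameIndexForPts(pts);
    pos_ = keyIndex >= 0 ? static_cast<size_t>(keyIndex) : 0;
  }

  std::vector<ToyFrame> frames_;
  size_t pos_ = 0;
  int64_t lastDecodedAvFramePts_ = 0;
  bool cursorWasJustSet_ = false;
  int64_t cursor_ = std::numeric_limits<int64_t>::min();  // INT64_MIN, as in the diff
};

int main() {
  // pts 0..7 at 10 pts per second, key frames at pts 0 and 4 (I P P P I P P P).
  ToyDecoder decoder({{0, true}, {1, false}, {2, false}, {3, false},
                      {4, true}, {5, false}, {6, false}, {7, false}});
  std::cout << decoder.getNextFrame().pts << "\n";  // 0
  decoder.setCursorPtsInSeconds(0.2);  // cursor_ = 2: same key frame, seek skipped
  std::cout << decoder.getNextFrame().pts << "\n";  // 2
  decoder.setCursorPtsInSeconds(0.6);  // cursor_ = 6: different key frame, seek to pts 4
  std::cout << decoder.getNextFrame().pts << "\n";  // 6
  return 0;
}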