Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 16 additions & 19 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -567,10 +567,8 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrame() {

VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor) {
AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) {
StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_];
return avFrame->pts >= activeStreamInfo.discardFramesBeforePts;
});
AVFrameStream avFrameStream = decodeAVFrame(
[this](AVFrame* avFrame) { return avFrame->pts >= cursor_; });
return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor);
}

Expand Down Expand Up @@ -842,7 +840,9 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange(
// --------------------------------------------------------------------------

void VideoDecoder::setCursorPtsInSeconds(double seconds) {
desiredPtsSeconds_ = seconds;
cursorWasJustSet_ = true;
cursor_ =
secondsToClosestPts(seconds, streamInfos_[activeStreamIndex_].timeBase);
}

/*
Expand Down Expand Up @@ -870,25 +870,25 @@ I P P P I P P P I P P I P P I P

(2) is more efficient than (1) if there is an I frame between x and y.
*/
bool VideoDecoder::canWeAvoidSeeking(int64_t targetPts) const {
bool VideoDecoder::canWeAvoidSeeking() const {
int64_t lastDecodedAvFramePts =
streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts;
if (targetPts < lastDecodedAvFramePts) {
if (cursor_ < lastDecodedAvFramePts) {
// We can never skip a seek if we are seeking backwards.
return false;
}
if (lastDecodedAvFramePts == targetPts) {
if (lastDecodedAvFramePts == cursor_) {
// We are seeking to the exact same frame as we are currently at. Without
// caching we have to rewind back and decode the frame again.
// TODO: https://github.com/pytorch-labs/torchcodec/issues/84 we could
// implement caching.
return false;
}
// We are seeking forwards.
// We can only skip a seek if both lastDecodedAvFramePts and targetPts share
// the same keyframe.
// We can only skip a seek if both lastDecodedAvFramePts and
// cursor_ share the same keyframe.
int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts);
int targetKeyFrameIndex = getKeyFrameIndexForPts(targetPts);
int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_);
return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 &&
lastDecodedAvFrameIndex == targetKeyFrameIndex;
}
Expand All @@ -900,16 +900,14 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() {
validateActiveStream(AVMEDIA_TYPE_VIDEO);
StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];

int64_t desiredPts =
secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase);
streamInfo.discardFramesBeforePts = desiredPts;

decodeStats_.numSeeksAttempted++;
if (canWeAvoidSeeking(desiredPts)) {
if (canWeAvoidSeeking()) {
decodeStats_.numSeeksSkipped++;
return;
}

int64_t desiredPts = cursor_;

// For some encodings like H265, FFMPEG sometimes seeks past the point we
// set as the max_ts. So we use our own index to give it the exact pts of
// the key frame that we want to seek to.
Expand Down Expand Up @@ -948,10 +946,9 @@ VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame(

resetDecodeStats();

// Seek if needed.
if (desiredPtsSeconds_.has_value()) {
if (cursorWasJustSet_) {
maybeSeekToBeforeDesiredPts();
desiredPtsSeconds_ = std::nullopt;
cursorWasJustSet_ = false;
Copy link
Contributor Author

@NicolasHug NicolasHug Mar 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not particularly fond of this logic, but don't shoot the messenger. This PR only makes the existing stateful logic a little more obvious: we now have an explicit bool field, where before we had slightly obscure std::optional logic.

We could try to simplify this further. I gave it a quick try, but our tests failed in some weird edge cases, so I'm leaving this for later.

}

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
Expand Down
20 changes: 9 additions & 11 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -332,15 +332,11 @@ class VideoDecoder {
std::vector<FrameInfo> keyFrames;
std::vector<FrameInfo> allFrames;

// The current position of the cursor in the stream, and associated frame
// duration.
// TODO since the decoder is single-stream, these should be decoder fields,
// not streamInfo fields. And they should be defined right next to
// `cursor_`, with joint documentation.
int64_t lastDecodedAvFramePts = 0;
int64_t lastDecodedAvFrameDuration = 0;
// The desired position of the cursor in the stream. We send frames >=
// this pts to the user when they request a frame.
// We update this field if the user requested a seek. This typically
// corresponds to the decoder's desiredPts_ attribute.
int64_t discardFramesBeforePts = INT64_MIN;
VideoStreamOptions videoStreamOptions;

// color-conversion fields. Only one of FilterGraphContext and
Expand All @@ -363,7 +359,7 @@ class VideoDecoder {
// DECODING APIS AND RELATED UTILS
// --------------------------------------------------------------------------

bool canWeAvoidSeeking(int64_t targetPts) const;
bool canWeAvoidSeeking() const;

void maybeSeekToBeforeDesiredPts();

Expand Down Expand Up @@ -466,9 +462,11 @@ class VideoDecoder {
std::map<int, StreamInfo> streamInfos_;
const int NO_ACTIVE_STREAM = -2;
int activeStreamIndex_ = NO_ACTIVE_STREAM;
// Set when the user wants to seek and stores the desired pts that the user
// wants to seek to.
std::optional<double> desiredPtsSeconds_;

bool cursorWasJustSet_ = false;
// The desired position of the cursor in the stream. We send frames >= this
// pts to the user when they request a frame.
int64_t cursor_ = INT64_MIN;
// Stores various internal decoding stats.
DecodeStats decodeStats_;
// Stores the AVIOContext for the input buffer.
Expand Down
Loading