From dfa81eebaad39d9eaf256ae56b245def0f88cb21 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 9 Feb 2025 16:10:01 +0000 Subject: [PATCH 1/6] Use validateActiveStream --- .../decoders/_core/VideoDecoder.cpp | 42 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 +- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 47b105294..9d8757562 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -349,7 +349,7 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const { } torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); validateScannedAllStreams("getKeyFrameIndices"); const std::vector& keyFrames = streamInfos_[streamIndex].keyFrames; @@ -565,7 +565,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, std::optional preAllocatedOutputTensor) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); const auto& streamInfo = streamInfos_[streamIndex]; const auto& streamMetadata = @@ -580,7 +580,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( int streamIndex, const std::vector& frameIndices) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); auto indicesAreSorted = std::is_sorted(frameIndices.begin(), frameIndices.end()); @@ -643,7 +643,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( int64_t start, int64_t stop, int64_t step) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); const auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; @@ -720,7 +720,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux( VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt( int streamIndex, const std::vector& timestamps) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); const auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; @@ -754,7 +754,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( int streamIndex, double startSeconds, double stopSeconds) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); const auto& streamMetadata = containerMetadata_.allStreamMetadata[streamIndex]; @@ -898,9 +898,7 @@ bool VideoDecoder::canWeAvoidSeekingForStream( // AVFormatContext if it is needed. We can skip seeking in certain cases. See // the comment of canWeAvoidSeeking() for details. void VideoDecoder::maybeSeekToBeforeDesiredPts() { - if (activeStreamIndex_ == NO_ACTIVE_STREAM) { - return; - } + validateActiveStream(); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; streamInfo.discardFramesBeforePts = secondsToClosestPts(*desiredPtsSeconds_, streamInfo.timeBase); @@ -950,9 +948,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame( std::function filterFunction) { - if (activeStreamIndex_ == NO_ACTIVE_STREAM) { - throw std::runtime_error("No active streams configured."); - } + validateActiveStream(); resetDecodeStats(); @@ -1636,18 +1632,20 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { // VALIDATION UTILS // -------------------------------------------------------------------------- -void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) { - int streamsSize = +void VideoDecoder::validateActiveStream() { + auto errorMsg = + "Provided stream index=" + std::to_string(activeStreamIndex_) + + " was not previously added."; + TORCH_CHECK(activeStreamIndex_ != NO_ACTIVE_STREAM, errorMsg); + TORCH_CHECK(streamInfos_.count(activeStreamIndex_) > 0, errorMsg); + + int allStreamMetadataSize = static_cast(containerMetadata_.allStreamMetadata.size()); TORCH_CHECK( - streamIndex >= 0 && streamIndex < streamsSize, - "Invalid stream index=" + std::to_string(streamIndex) + + activeStreamIndex_ >= 0 && activeStreamIndex_ < allStreamMetadataSize, + "Invalid stream index=" + std::to_string(activeStreamIndex_) + "; valid indices are in the range [0, " + - std::to_string(streamsSize) + ")."); - TORCH_CHECK( - streamInfos_.count(streamIndex) > 0, - "Provided stream index=" + std::to_string(streamIndex) + - " was not previously added."); + std::to_string(allStreamMetadataSize) + ")."); } void VideoDecoder::validateScannedAllStreams(const std::string& msg) { @@ -1697,7 +1695,7 @@ void VideoDecoder::resetDecodeStats() { double VideoDecoder::getPtsSecondsForFrame( int streamIndex, int64_t frameIndex) { - validateUserProvidedStreamIndex(streamIndex); + validateActiveStream(); validateScannedAllStreams("getPtsSecondsForFrame"); const auto& streamInfo = streamInfos_[streamIndex]; diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 7db5a4a67..451edeb14 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -467,7 +467,7 @@ class VideoDecoder { // VALIDATION UTILS // -------------------------------------------------------------------------- - void validateUserProvidedStreamIndex(int streamIndex); + void validateActiveStream(); void validateScannedAllStreams(const std::string& msg); void validateFrameIndex( const StreamMetadata& streamMetadata, From 6df28130356a7af0829d9782036cf4e4c4930a54 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 9 Feb 2025 16:29:44 +0000 Subject: [PATCH 2/6] Stuff --- .../decoders/_core/VideoDecoder.cpp | 116 +++++++----------- src/torchcodec/decoders/_core/VideoDecoder.h | 35 ++---- .../decoders/_core/VideoDecoderOps.cpp | 31 +++-- 3 files changed, 73 insertions(+), 109 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 9d8757562..c01d7cd30 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -348,11 +348,12 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const { return containerMetadata_; } -torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) { +torch::Tensor VideoDecoder::getKeyFrameIndices() { validateActiveStream(); validateScannedAllStreams("getKeyFrameIndices"); - const std::vector& keyFrames = streamInfos_[streamIndex].keyFrames; + const std::vector& keyFrames = + streamInfos_[activeStreamIndex_].keyFrames; torch::Tensor keyFrameIndices = torch::empty({static_cast(keyFrames.size())}, {torch::kInt64}); for (size_t i = 0; i < keyFrames.size(); ++i) { @@ -540,7 +541,7 @@ void VideoDecoder::updateMetadataWithCodecContext( VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { auto output = getNextFrameNoDemuxInternal(); - output.data = maybePermuteHWC2CHW(output.streamIndex, output.data); + output.data = maybePermuteHWC2CHW(output.data); return output; } @@ -553,23 +554,20 @@ VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal( return convertAVFrameToFrameOutput(avFrameStream, preAllocatedOutputTensor); } -VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex( - int streamIndex, - int64_t frameIndex) { - auto frameOutput = getFrameAtIndexInternal(streamIndex, frameIndex); - frameOutput.data = maybePermuteHWC2CHW(streamIndex, frameOutput.data); +VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndex(int64_t frameIndex) { + auto frameOutput = getFrameAtIndexInternal(frameIndex); + frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); return frameOutput; } VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( - int streamIndex, int64_t frameIndex, std::optional preAllocatedOutputTensor) { validateActiveStream(); - const auto& streamInfo = streamInfos_[streamIndex]; + const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; validateFrameIndex(streamMetadata, frameIndex); int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); @@ -578,7 +576,6 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( - int streamIndex, const std::vector& frameIndices) { validateActiveStream(); @@ -602,8 +599,8 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( } const auto& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; - const auto& streamInfo = streamInfos_[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( frameIndices.size(), videoStreamOptions, streamMetadata); @@ -626,28 +623,24 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( frameBatchOutput.durationSeconds[previousIndexInOutput]; } else { FrameOutput frameOutput = getFrameAtIndexInternal( - streamIndex, indexInVideo, frameBatchOutput.data[indexInOutput]); + indexInVideo, frameBatchOutput.data[indexInOutput]); frameBatchOutput.ptsSeconds[indexInOutput] = frameOutput.ptsSeconds; frameBatchOutput.durationSeconds[indexInOutput] = frameOutput.durationSeconds; } previousIndexInVideo = indexInVideo; } - frameBatchOutput.data = - maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } -VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( - int streamIndex, - int64_t start, - int64_t stop, - int64_t step) { +VideoDecoder::FrameBatchOutput +VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { validateActiveStream(); const auto& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; - const auto& streamInfo = streamInfos_[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + const auto& streamInfo = streamInfos_[activeStreamIndex_]; int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); @@ -665,12 +658,11 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesInRange( for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput frameOutput = - getFrameAtIndexInternal(streamIndex, i, frameBatchOutput.data[f]); + getFrameAtIndexInternal(i, frameBatchOutput.data[f]); frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; } - frameBatchOutput.data = - maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } @@ -712,19 +704,17 @@ VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux( // Convert the frame to tensor. FrameOutput frameOutput = convertAVFrameToFrameOutput(avFrameStream); - frameOutput.data = - maybePermuteHWC2CHW(frameOutput.streamIndex, frameOutput.data); + frameOutput.data = maybePermuteHWC2CHW(frameOutput.data); return frameOutput; } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt( - int streamIndex, const std::vector& timestamps) { validateActiveStream(); const auto& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; - const auto& streamInfo = streamInfos_[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; + const auto& streamInfo = streamInfos_[activeStreamIndex_]; double minSeconds = getMinSeconds(streamMetadata); double maxSeconds = getMaxSeconds(streamMetadata); @@ -747,24 +737,23 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedAt( secondsToIndexLowerBound(frameSeconds, streamInfo, streamMetadata); } - return getFramesAtIndices(streamIndex, frameIndices); + return getFramesAtIndices(frameIndices); } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( - int streamIndex, double startSeconds, double stopSeconds) { validateActiveStream(); const auto& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; TORCH_CHECK( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + std::to_string(stopSeconds) + "."); - const auto& streamInfo = streamInfos_[streamIndex]; + const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; // Special case needed to implement a half-open range. At first glance, this @@ -786,8 +775,7 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( // need this special case below. if (startSeconds == stopSeconds) { FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); - frameBatchOutput.data = - maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } @@ -827,12 +815,11 @@ VideoDecoder::FrameBatchOutput VideoDecoder::getFramesPlayedInRange( numFrames, videoStreamOptions, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput frameOutput = - getFrameAtIndexInternal(streamIndex, i, frameBatchOutput.data[f]); + getFrameAtIndexInternal(i, frameBatchOutput.data[f]); frameBatchOutput.ptsSeconds[f] = frameOutput.ptsSeconds; frameBatchOutput.durationSeconds[f] = frameOutput.durationSeconds; } - frameBatchOutput.data = - maybePermuteHWC2CHW(streamIndex, frameBatchOutput.data); + frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } @@ -871,9 +858,8 @@ I P P P I P P P I P P I P P I P (2) is more efficient than (1) if there is an I frame between x and y. */ bool VideoDecoder::canWeAvoidSeekingForStream( - const StreamInfo& streamInfo, int64_t currentPts, - int64_t targetPts) const { + int64_t targetPts) { if (targetPts < currentPts) { // We can never skip a seek if we are seeking backwards. return false; @@ -888,8 +874,8 @@ bool VideoDecoder::canWeAvoidSeekingForStream( // We are seeking forwards. // We can only skip a seek if both currentPts and targetPts share the same // keyframe. - int currentKeyFrameIndex = getKeyFrameIndexForPts(streamInfo, currentPts); - int targetKeyFrameIndex = getKeyFrameIndexForPts(streamInfo, targetPts); + int currentKeyFrameIndex = getKeyFrameIndexForPts(currentPts); + int targetKeyFrameIndex = getKeyFrameIndexForPts(targetPts); return currentKeyFrameIndex >= 0 && targetKeyFrameIndex >= 0 && currentKeyFrameIndex == targetKeyFrameIndex; } @@ -906,8 +892,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { decodeStats_.numSeeksAttempted++; int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den; - if (canWeAvoidSeekingForStream( - streamInfo, streamInfo.currentPts, desiredPtsForStream)) { + if (canWeAvoidSeekingForStream(streamInfo.currentPts, desiredPtsForStream)) { decodeStats_.numSeeksSkipped++; return; } @@ -948,7 +933,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { VideoDecoder::AVFrameStream VideoDecoder::decodeAVFrame( std::function filterFunction) { - validateActiveStream(); + validateActiveStream(); resetDecodeStats(); @@ -1116,9 +1101,8 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( VideoDecoder::AVFrameStream& avFrameStream, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - int streamIndex = avFrameStream.streamIndex; AVFrame* avFrame = avFrameStream.avFrame.get(); - auto& streamInfo = streamInfos_[streamIndex]; + auto& streamInfo = streamInfos_[activeStreamIndex_]; auto frameDims = getHeightAndWidthFromOptionsOrAVFrame( streamInfo.videoStreamOptions, *avFrame); @@ -1164,7 +1148,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( streamInfo.prevFrameContext = frameContext; } int resultHeight = - convertAVFrameToTensorUsingSwsScale(streamIndex, avFrame, outputTensor); + convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? @@ -1184,7 +1168,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth); streamInfo.prevFrameContext = frameContext; } - outputTensor = convertAVFrameToTensorUsingFilterGraph(streamIndex, avFrame); + outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame); // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. @@ -1215,10 +1199,9 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( } int VideoDecoder::convertAVFrameToTensorUsingSwsScale( - int streamIndex, const AVFrame* avFrame, torch::Tensor& outputTensor) { - StreamInfo& activeStreamInfo = streamInfos_[streamIndex]; + StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_]; SwsContext* swsContext = activeStreamInfo.swsContext.get(); uint8_t* pointers[4] = { outputTensor.data_ptr(), nullptr, nullptr, nullptr}; @@ -1236,10 +1219,9 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale( } torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( - int streamIndex, const AVFrame* avFrame) { FilterGraphContext& filterGraphContext = - streamInfos_[streamIndex].filterGraphContext; + streamInfos_[activeStreamIndex_].filterGraphContext; int ffmpegStatus = av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame); if (ffmpegStatus < AVSUCCESS) { @@ -1308,10 +1290,9 @@ torch::Tensor allocateEmptyHWCTensor( // or 4D. // Calling permute() is guaranteed to return a view as per the docs: // https://pytorch.org/docs/stable/generated/torch.permute.html -torch::Tensor VideoDecoder::maybePermuteHWC2CHW( - int streamIndex, - torch::Tensor& hwcTensor) { - if (streamInfos_[streamIndex].videoStreamOptions.dimensionOrder == "NHWC") { +torch::Tensor VideoDecoder::maybePermuteHWC2CHW(torch::Tensor& hwcTensor) { + if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder == + "NHWC") { return hwcTensor; } auto numDimensions = hwcTensor.dim(); @@ -1503,9 +1484,8 @@ void VideoDecoder::createSwsContext( // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- -int VideoDecoder::getKeyFrameIndexForPts( - const StreamInfo& streamInfo, - int64_t pts) const { +int VideoDecoder::getKeyFrameIndexForPts(int64_t pts) { + const StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; if (streamInfo.keyFrames.empty()) { return av_index_search_timestamp( streamInfo.stream, pts, AVSEEK_FLAG_BACKWARD); @@ -1516,7 +1496,7 @@ int VideoDecoder::getKeyFrameIndexForPts( int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex( const std::vector& keyFrames, - int64_t pts) const { + int64_t pts) { auto upperBound = std::upper_bound( keyFrames.begin(), keyFrames.end(), @@ -1692,15 +1672,13 @@ void VideoDecoder::resetDecodeStats() { decodeStats_ = DecodeStats{}; } -double VideoDecoder::getPtsSecondsForFrame( - int streamIndex, - int64_t frameIndex) { +double VideoDecoder::getPtsSecondsForFrame(int64_t frameIndex) { validateActiveStream(); validateScannedAllStreams("getPtsSecondsForFrame"); - const auto& streamInfo = streamInfos_[streamIndex]; + const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; validateFrameIndex(streamMetadata, frameIndex); return ptsToSeconds( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 451edeb14..f4425b7a7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -99,7 +99,7 @@ class VideoDecoder { // Returns the key frame indices as a tensor. The tensor is 1D and contains // int64 values, where each value is the frame index for a key frame. - torch::Tensor getKeyFrameIndices(int streamIndex); + torch::Tensor getKeyFrameIndices(); // -------------------------------------------------------------------------- // ADDING STREAMS API @@ -179,19 +179,16 @@ class VideoDecoder { // the cursor to the next frame. FrameOutput getNextFrameNoDemux(); - FrameOutput getFrameAtIndex(int streamIndex, int64_t frameIndex); + FrameOutput getFrameAtIndex(int64_t frameIndex); // Returns frames at the given indices for a given stream as a single stacked // Tensor. - FrameBatchOutput getFramesAtIndices( - int streamIndex, - const std::vector& frameIndices); + FrameBatchOutput getFramesAtIndices(const std::vector& frameIndices); // Returns frames within a given range. The range is defined by [start, stop). // The values retrieved from the range are: [start, start+step, // start+(2*step), start+(3*step), ..., stop). The default for step is 1. - FrameBatchOutput - getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step); + FrameBatchOutput getFramesInRange(int64_t start, int64_t stop, int64_t step); // Decodes the first frame in any added stream that is visible at a given // timestamp. Frames in the video have a presentation timestamp and a @@ -201,9 +198,7 @@ class VideoDecoder { // seconds=5.999, etc. FrameOutput getFramePlayedAtNoDemux(double seconds); - FrameBatchOutput getFramesPlayedAt( - int streamIndex, - const std::vector& timestamps); + FrameBatchOutput getFramesPlayedAt(const std::vector& timestamps); // Returns frames within a given pts range. The range is defined by // [startSeconds, stopSeconds) with respect to the pts values for frames. The @@ -223,7 +218,6 @@ class VideoDecoder { // // [minPtsSecondsFromScan, maxPtsSecondsFromScan) FrameBatchOutput getFramesPlayedInRange( - int streamIndex, double startSeconds, double stopSeconds); @@ -259,13 +253,12 @@ class VideoDecoder { // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we // can move it back to private. FrameOutput getFrameAtIndexInternal( - int streamIndex, int64_t frameIndex, std::optional preAllocatedOutputTensor = std::nullopt); // Exposed for _test_frame_pts_equality, which is used to test non-regression // of pts resolution (64 to 32 bit floats) - double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex); + double getPtsSecondsForFrame(int64_t frameIndex); // Exposed for performance testing. struct DecodeStats { @@ -372,10 +365,7 @@ class VideoDecoder { // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- - bool canWeAvoidSeekingForStream( - const StreamInfo& stream, - int64_t currentPts, - int64_t targetPts) const; + bool canWeAvoidSeekingForStream(int64_t currentPts, int64_t targetPts); void maybeSeekToBeforeDesiredPts(); @@ -384,7 +374,7 @@ class VideoDecoder { FrameOutput getNextFrameNoDemuxInternal( std::optional preAllocatedOutputTensor = std::nullopt); - torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor); + torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor); FrameOutput convertAVFrameToFrameOutput( AVFrameStream& avFrameStream, @@ -395,12 +385,9 @@ class VideoDecoder { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt); - torch::Tensor convertAVFrameToTensorUsingFilterGraph( - int streamIndex, - const AVFrame* avFrame); + torch::Tensor convertAVFrameToTensorUsingFilterGraph(const AVFrame* avFrame); int convertAVFrameToTensorUsingSwsScale( - int streamIndex, const AVFrame* avFrame, torch::Tensor& outputTensor); @@ -422,14 +409,14 @@ class VideoDecoder { // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- - int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const; + int getKeyFrameIndexForPts(int64_t pts); // Returns the key frame index of the presentation timestamp using our index. // We build this index by scanning the file in // scanFileAndUpdateMetadataAndIndex int getKeyFrameIndexForPtsUsingScannedIndex( const std::vector& keyFrames, - int64_t pts) const; + int64_t pts); int64_t secondsToIndexLowerBound( double seconds, diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 78ecc4258..88c376d77 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -253,54 +253,53 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { OpsFrameOutput get_frame_at_index( at::Tensor& decoder, - int64_t stream_index, + [[maybe_unused]] int64_t stream_index, int64_t frame_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto result = videoDecoder->getFrameAtIndex(stream_index, frame_index); + auto result = videoDecoder->getFrameAtIndex(frame_index); return makeOpsFrameOutput(result); } OpsFrameBatchOutput get_frames_at_indices( at::Tensor& decoder, - int64_t stream_index, + [[maybe_unused]] int64_t stream_index, at::IntArrayRef frame_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector frameIndicesVec( frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndices(stream_index, frameIndicesVec); + auto result = videoDecoder->getFramesAtIndices(frameIndicesVec); return makeOpsFrameBatchOutput(result); } OpsFrameBatchOutput get_frames_in_range( at::Tensor& decoder, - int64_t stream_index, + [[maybe_unused]] int64_t stream_index, int64_t start, int64_t stop, std::optional step) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto result = videoDecoder->getFramesInRange( - stream_index, start, stop, step.value_or(1)); + auto result = videoDecoder->getFramesInRange(start, stop, step.value_or(1)); return makeOpsFrameBatchOutput(result); } OpsFrameBatchOutput get_frames_by_pts( at::Tensor& decoder, - int64_t stream_index, + [[maybe_unused]] int64_t stream_index, at::ArrayRef timestamps) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); std::vector timestampsVec(timestamps.begin(), timestamps.end()); - auto result = videoDecoder->getFramesPlayedAt(stream_index, timestampsVec); + auto result = videoDecoder->getFramesPlayedAt(timestampsVec); return makeOpsFrameBatchOutput(result); } OpsFrameBatchOutput get_frames_by_pts_in_range( at::Tensor& decoder, - int64_t stream_index, + [[maybe_unused]] int64_t stream_index, double start_seconds, double stop_seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto result = videoDecoder->getFramesPlayedInRange( - stream_index, start_seconds, stop_seconds); + auto result = + videoDecoder->getFramesPlayedInRange(start_seconds, stop_seconds); return makeOpsFrameBatchOutput(result); } @@ -328,19 +327,19 @@ std::string mapToJson(const std::map& metadataMap) { bool _test_frame_pts_equality( at::Tensor& decoder, - int64_t stream_index, + [[maybe_unused]] int64_t stream_index, int64_t frame_index, double pts_seconds_to_test) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); return pts_seconds_to_test == - videoDecoder->getPtsSecondsForFrame(stream_index, frame_index); + videoDecoder->getPtsSecondsForFrame(frame_index); } torch::Tensor _get_key_frame_indices( at::Tensor& decoder, - int64_t stream_index) { + [[maybe_unused]] int64_t stream_index) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - return videoDecoder->getKeyFrameIndices(stream_index); + return videoDecoder->getKeyFrameIndices(); } std::string get_json_metadata(at::Tensor& decoder) { From 678d4fef3fd516ff552405894200acf0d1cc34f7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 9 Feb 2025 17:18:34 +0000 Subject: [PATCH 3/6] more --- .../decoders/_core/VideoDecoder.cpp | 52 ++++++++----------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 -- test/decoders/VideoDecoderTest.cpp | 12 ++--- 3 files changed, 28 insertions(+), 40 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c01d7cd30..d8509ab27 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -419,7 +419,7 @@ VideoDecoder::VideoStreamOptions::VideoStreamOptions( } void VideoDecoder::addVideoStreamDecoder( - int preferredStreamIndex, + int streamIndex, const VideoStreamOptions& videoStreamOptions) { TORCH_CHECK( activeStreamIndex_ == NO_ACTIVE_STREAM, @@ -427,26 +427,22 @@ void VideoDecoder::addVideoStreamDecoder( TORCH_CHECK(formatContext_.get() != nullptr); AVCodecOnlyUseForCallingAVFindBestStream avCodec = nullptr; - int streamIndex = av_find_best_stream( - formatContext_.get(), - AVMEDIA_TYPE_VIDEO, - preferredStreamIndex, - -1, - &avCodec, - 0); - if (streamIndex < 0) { + + activeStreamIndex_ = av_find_best_stream( + formatContext_.get(), AVMEDIA_TYPE_VIDEO, streamIndex, -1, &avCodec, 0); + if (activeStreamIndex_ < 0) { throw std::invalid_argument("No valid stream found in input file."); } TORCH_CHECK(avCodec != nullptr); - StreamInfo& streamInfo = streamInfos_[streamIndex]; - streamInfo.streamIndex = streamIndex; - streamInfo.timeBase = formatContext_->streams[streamIndex]->time_base; - streamInfo.stream = formatContext_->streams[streamIndex]; + StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + streamInfo.streamIndex = activeStreamIndex_; + streamInfo.timeBase = formatContext_->streams[activeStreamIndex_]->time_base; + streamInfo.stream = formatContext_->streams[activeStreamIndex_]; if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) { throw std::invalid_argument( - "Stream with index " + std::to_string(streamIndex) + + "Stream with index " + std::to_string(activeStreamIndex_) + " is not a video stream."); } @@ -458,11 +454,12 @@ void VideoDecoder::addVideoStreamDecoder( } StreamMetadata& streamMetadata = - containerMetadata_.allStreamMetadata[streamIndex]; + containerMetadata_.allStreamMetadata[activeStreamIndex_]; if (seekMode_ == SeekMode::approximate && !streamMetadata.averageFps.has_value()) { throw std::runtime_error( - "Seek mode is approximate, but stream " + std::to_string(streamIndex) + + "Seek mode is approximate, but stream " + + std::to_string(activeStreamIndex_) + " does not have an average fps in its metadata."); } @@ -483,6 +480,7 @@ void VideoDecoder::addVideoStreamDecoder( TORCH_CHECK( false, "Invalid device type: " + videoStreamOptions.device.str()); } + streamInfo.videoStreamOptions = videoStreamOptions; retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); if (retVal < AVSUCCESS) { @@ -490,9 +488,14 @@ void VideoDecoder::addVideoStreamDecoder( } codecContext->time_base = streamInfo.stream->time_base; - activeStreamIndex_ = streamIndex; - updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); - streamInfo.videoStreamOptions = videoStreamOptions; + + containerMetadata_.allStreamMetadata[activeStreamIndex_].width = + codecContext->width; + containerMetadata_.allStreamMetadata[activeStreamIndex_].height = + codecContext->height; + auto codedId = codecContext->codec_id; + containerMetadata_.allStreamMetadata[activeStreamIndex_].codecName = + std::string(avcodec_get_name(codedId)); // We will only need packets from the active stream, so we tell FFmpeg to // discard packets from the other streams. Note that av_read_frame() may still @@ -524,17 +527,6 @@ void VideoDecoder::addVideoStreamDecoder( videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); } -void VideoDecoder::updateMetadataWithCodecContext( - int streamIndex, - AVCodecContext* codecContext) { - containerMetadata_.allStreamMetadata[streamIndex].width = codecContext->width; - containerMetadata_.allStreamMetadata[streamIndex].height = - codecContext->height; - auto codedId = codecContext->codec_id; - containerMetadata_.allStreamMetadata[streamIndex].codecName = - std::string(avcodec_get_name(codedId)); -} - // -------------------------------------------------------------------------- // HIGH-LEVEL DECODING ENTRY-POINTS // -------------------------------------------------------------------------- diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index f4425b7a7..7097a904b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -357,10 +357,6 @@ class VideoDecoder { // -------------------------------------------------------------------------- void initializeDecoder(); - void updateMetadataWithCodecContext( - int streamIndex, - AVCodecContext* codecContext); - // -------------------------------------------------------------------------- // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index c297bb7b5..00d73961a 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -208,7 +208,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNCHW) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStreamDecoder(bestVideoStreamIndex); // Frame with index 180 corresponds to timestamp 6.006. - auto output = ourDecoder->getFramesAtIndices(bestVideoStreamIndex, {0, 180}); + auto output = ourDecoder->getFramesAtIndices({0, 180}); auto tensor = output.data; EXPECT_EQ(tensor.sizes(), std::vector({2, 3, 270, 480})); @@ -232,7 +232,7 @@ TEST_P(VideoDecoderTest, DecodesFramesInABatchInNHWC) { bestVideoStreamIndex, VideoDecoder::VideoStreamOptions("dimension_order=NHWC")); // Frame with index 180 corresponds to timestamp 6.006. - auto output = ourDecoder->getFramesAtIndices(bestVideoStreamIndex, {0, 180}); + auto output = ourDecoder->getFramesAtIndices({0, 180}); auto tensor = output.data; EXPECT_EQ(tensor.sizes(), std::vector({2, 270, 480, 3})); @@ -397,8 +397,8 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) { ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, VideoDecoder::VideoStreamOptions("color_conversion_library=filtergraph")); - auto output = ourDecoder->getFrameAtIndexInternal( - bestVideoStreamIndex, 0, preAllocatedOutputTensor); + auto output = + ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); } @@ -414,8 +414,8 @@ TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) { ourDecoder->addVideoStreamDecoder( bestVideoStreamIndex, VideoDecoder::VideoStreamOptions("color_conversion_library=swscale")); - auto output = ourDecoder->getFrameAtIndexInternal( - bestVideoStreamIndex, 0, preAllocatedOutputTensor); + auto output = + ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); } From 10022abf7f0e09669e26ff9531989b89b76164ce Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 9 Feb 2025 17:19:25 +0000 Subject: [PATCH 4/6] Remove NoDemux --- .../decoders/_core/VideoDecoder.cpp | 10 +++--- src/torchcodec/decoders/_core/VideoDecoder.h | 8 ++--- .../decoders/_core/VideoDecoderOps.cpp | 4 +-- test/decoders/VideoDecoderTest.cpp | 32 +++++++++---------- 4 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index d8509ab27..478e3f6cb 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -531,13 +531,13 @@ void VideoDecoder::addVideoStreamDecoder( // HIGH-LEVEL DECODING ENTRY-POINTS // -------------------------------------------------------------------------- -VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemux() { - auto output = getNextFrameNoDemuxInternal(); +VideoDecoder::FrameOutput VideoDecoder::getNextFrame() { + auto output = getNextFrameInternal(); output.data = maybePermuteHWC2CHW(output.data); return output; } -VideoDecoder::FrameOutput VideoDecoder::getNextFrameNoDemuxInternal( +VideoDecoder::FrameOutput VideoDecoder::getNextFrameInternal( std::optional preAllocatedOutputTensor) { AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) { StreamInfo& activeStreamInfo = streamInfos_[activeStreamIndex_]; @@ -564,7 +564,7 @@ VideoDecoder::FrameOutput VideoDecoder::getFrameAtIndexInternal( int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); - return getNextFrameNoDemuxInternal(preAllocatedOutputTensor); + return getNextFrameInternal(preAllocatedOutputTensor); } VideoDecoder::FrameBatchOutput VideoDecoder::getFramesAtIndices( @@ -658,7 +658,7 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { return frameBatchOutput; } -VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAtNoDemux( +VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt( double seconds) { for (auto& [streamIndex, streamInfo] : streamInfos_) { double frameStartTime = diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 7097a904b..8cf1e85c1 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -171,13 +171,13 @@ class VideoDecoder { }; // Places the cursor at the first frame on or after the position in seconds. - // Calling getNextFrameNoDemux() will return the first frame at + // Calling getNextFrame() will return the first frame at // or after this position. void setCursorPtsInSeconds(double seconds); // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. - FrameOutput getNextFrameNoDemux(); + FrameOutput getNextFrame(); FrameOutput getFrameAtIndex(int64_t frameIndex); @@ -196,7 +196,7 @@ class VideoDecoder { // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0). // i.e. it will be returned when this function is called with seconds=5.0 or // seconds=5.999, etc. - FrameOutput getFramePlayedAtNoDemux(double seconds); + FrameOutput getFramePlayedAt(double seconds); FrameBatchOutput getFramesPlayedAt(const std::vector& timestamps); @@ -367,7 +367,7 @@ class VideoDecoder { AVFrameStream decodeAVFrame(std::function filterFunction); - FrameOutput getNextFrameNoDemuxInternal( + FrameOutput getNextFrameInternal( std::optional preAllocatedOutputTensor = std::nullopt); torch::Tensor maybePermuteHWC2CHW(torch::Tensor& hwcTensor); diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 88c376d77..398232b5c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -233,7 +233,7 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); VideoDecoder::FrameOutput result; try { - result = videoDecoder->getNextFrameNoDemux(); + result = videoDecoder->getNextFrame(); } catch (const VideoDecoder::EndOfFileException& e) { C10_THROW_ERROR(IndexError, e.what()); } @@ -247,7 +247,7 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder) { OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - auto result = videoDecoder->getFramePlayedAtNoDemux(seconds); + auto result = videoDecoder->getFramePlayedAt(seconds); return makeOpsFrameOutput(result); } diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 00d73961a..7e3376b5d 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -149,7 +149,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) { videoStreamOptions.width = 100; videoStreamOptions.height = 120; decoder->addVideoStreamDecoder(-1, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrameNoDemux().data; + torch::Tensor tensor = decoder->getNextFrame().data; EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); } @@ -159,7 +159,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { VideoDecoder::VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; decoder->addVideoStreamDecoder(-1, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrameNoDemux().data; + torch::Tensor tensor = decoder->getNextFrame().data; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -168,11 +168,11 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) { std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); - auto output = ourDecoder->getNextFrameNoDemux(); + auto output = ourDecoder->getNextFrame(); torch::Tensor tensor0FromOurDecoder = output.data; EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 0.0); - output = ourDecoder->getNextFrameNoDemux(); + output = ourDecoder->getNextFrame(); torch::Tensor tensor1FromOurDecoder = output.data; EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector({3, 270, 480})); EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000); @@ -252,11 +252,11 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) { createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); ourDecoder->setCursorPtsInSeconds(388388. / 30'000); - auto output = ourDecoder->getNextFrameNoDemux(); + auto output = ourDecoder->getNextFrame(); EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000); - output = ourDecoder->getNextFrameNoDemux(); + output = ourDecoder->getNextFrame(); EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000); - EXPECT_THROW(ourDecoder->getNextFrameNoDemux(), std::exception); + EXPECT_THROW(ourDecoder->getNextFrame(), std::exception); } TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) { @@ -264,18 +264,18 @@ TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) { std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); - auto output = ourDecoder->getFramePlayedAtNoDemux(6.006); + auto output = ourDecoder->getFramePlayedAt(6.006); EXPECT_EQ(output.ptsSeconds, 6.006); // The frame's duration is 0.033367 according to ffprobe, // so the next frame is played at timestamp=6.039367. const double kNextFramePts = 6.039366666666667; // The frame that is played a microsecond before the next frame is still // the previous frame. - output = ourDecoder->getFramePlayedAtNoDemux(kNextFramePts - 1e-6); + output = ourDecoder->getFramePlayedAt(kNextFramePts - 1e-6); EXPECT_EQ(output.ptsSeconds, 6.006); // The frame that is played at the exact pts of the frame is the next // frame. - output = ourDecoder->getFramePlayedAtNoDemux(kNextFramePts); + output = ourDecoder->getFramePlayedAt(kNextFramePts); EXPECT_EQ(output.ptsSeconds, kNextFramePts); // This is the timestamp of the last frame in this video. @@ -286,7 +286,7 @@ TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) { // Sanity check: make sure duration is strictly positive. EXPECT_GT(kPtsPlusDurationOfLastFrame, kPtsOfLastFrameInVideoStream); output = - ourDecoder->getFramePlayedAtNoDemux(kPtsPlusDurationOfLastFrame - 1e-6); + ourDecoder->getFramePlayedAt(kPtsPlusDurationOfLastFrame - 1e-6); EXPECT_EQ(output.ptsSeconds, kPtsOfLastFrameInVideoStream); } @@ -296,7 +296,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { createDecoderFromPath(path, GetParam()); ourDecoder->addVideoStreamDecoder(-1); ourDecoder->setCursorPtsInSeconds(6.0); - auto output = ourDecoder->getNextFrameNoDemux(); + auto output = ourDecoder->getNextFrame(); torch::Tensor tensor6FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); torch::Tensor tensor6FromFFMPEG = @@ -312,7 +312,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 180); ourDecoder->setCursorPtsInSeconds(6.1); - output = ourDecoder->getNextFrameNoDemux(); + output = ourDecoder->getNextFrame(); torch::Tensor tensor61FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 183'183. / 30'000); torch::Tensor tensor61FromFFMPEG = @@ -332,7 +332,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { EXPECT_LT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 10); ourDecoder->setCursorPtsInSeconds(10.0); - output = ourDecoder->getNextFrameNoDemux(); + output = ourDecoder->getNextFrame(); torch::Tensor tensor10FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000); torch::Tensor tensor10FromFFMPEG = @@ -349,7 +349,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 60); ourDecoder->setCursorPtsInSeconds(6.0); - output = ourDecoder->getNextFrameNoDemux(); + output = ourDecoder->getNextFrame(); tensor6FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000); EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG)); @@ -364,7 +364,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) { constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9 ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream); - output = ourDecoder->getNextFrameNoDemux(); + output = ourDecoder->getNextFrame(); torch::Tensor tensor7FromOurDecoder = output.data; EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000); torch::Tensor tensor7FromFFMPEG = From 5ac15231b8d2dd4f83d917ccc930a55fe62218bd Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Sun, 9 Feb 2025 17:20:48 +0000 Subject: [PATCH 5/6] Clean up some logic --- .../decoders/_core/VideoDecoder.cpp | 24 ++++++++----------- test/decoders/VideoDecoderTest.cpp | 3 +-- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 478e3f6cb..ea9706f12 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -658,20 +658,16 @@ VideoDecoder::getFramesInRange(int64_t start, int64_t stop, int64_t step) { return frameBatchOutput; } -VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt( - double seconds) { - for (auto& [streamIndex, streamInfo] : streamInfos_) { - double frameStartTime = - ptsToSeconds(streamInfo.currentPts, streamInfo.timeBase); - double frameEndTime = ptsToSeconds( - streamInfo.currentPts + streamInfo.currentDuration, - streamInfo.timeBase); - if (seconds >= frameStartTime && seconds < frameEndTime) { - // We are in the same frame as the one we just returned. However, since we - // don't cache it locally, we have to rewind back. - seconds = frameStartTime; - break; - } +VideoDecoder::FrameOutput VideoDecoder::getFramePlayedAt(double seconds) { + StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + double frameStartTime = + ptsToSeconds(streamInfo.currentPts, streamInfo.timeBase); + double frameEndTime = ptsToSeconds( + streamInfo.currentPts + streamInfo.currentDuration, streamInfo.timeBase); + if (seconds >= frameStartTime && seconds < frameEndTime) { + // We are in the same frame as the one we just returned. However, since we + // don't cache it locally, we have to rewind back. + seconds = frameStartTime; } setCursorPtsInSeconds(seconds); diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 7e3376b5d..87a69f28b 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -285,8 +285,7 @@ TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) { kPtsOfLastFrameInVideoStream + kDurationOfLastFrameInVideoStream; // Sanity check: make sure duration is strictly positive. EXPECT_GT(kPtsPlusDurationOfLastFrame, kPtsOfLastFrameInVideoStream); - output = - ourDecoder->getFramePlayedAt(kPtsPlusDurationOfLastFrame - 1e-6); + output = ourDecoder->getFramePlayedAt(kPtsPlusDurationOfLastFrame - 1e-6); EXPECT_EQ(output.ptsSeconds, kPtsOfLastFrameInVideoStream); } From 1b2c8878c4cbccfc70339d94b60f07f564366d15 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 10 Feb 2025 10:23:16 +0000 Subject: [PATCH 6/6] Put back const qualifiers --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 13 ++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 6 +++--- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index ea9706f12..1df0ff275 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -845,9 +845,8 @@ I P P P I P P P I P P I P P I P (2) is more efficient than (1) if there is an I frame between x and y. */ -bool VideoDecoder::canWeAvoidSeekingForStream( - int64_t currentPts, - int64_t targetPts) { +bool VideoDecoder::canWeAvoidSeeking(int64_t currentPts, int64_t targetPts) + const { if (targetPts < currentPts) { // We can never skip a seek if we are seeking backwards. return false; @@ -880,7 +879,7 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { decodeStats_.numSeeksAttempted++; int64_t desiredPtsForStream = *desiredPtsSeconds_ * streamInfo.timeBase.den; - if (canWeAvoidSeekingForStream(streamInfo.currentPts, desiredPtsForStream)) { + if (canWeAvoidSeeking(streamInfo.currentPts, desiredPtsForStream)) { decodeStats_.numSeeksSkipped++; return; } @@ -1472,8 +1471,8 @@ void VideoDecoder::createSwsContext( // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- -int VideoDecoder::getKeyFrameIndexForPts(int64_t pts) { - const StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; +int VideoDecoder::getKeyFrameIndexForPts(int64_t pts) const { + const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.keyFrames.empty()) { return av_index_search_timestamp( streamInfo.stream, pts, AVSEEK_FLAG_BACKWARD); @@ -1484,7 +1483,7 @@ int VideoDecoder::getKeyFrameIndexForPts(int64_t pts) { int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex( const std::vector& keyFrames, - int64_t pts) { + int64_t pts) const { auto upperBound = std::upper_bound( keyFrames.begin(), keyFrames.end(), diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 8cf1e85c1..37142bd7b 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -361,7 +361,7 @@ class VideoDecoder { // DECODING APIS AND RELATED UTILS // -------------------------------------------------------------------------- - bool canWeAvoidSeekingForStream(int64_t currentPts, int64_t targetPts); + bool canWeAvoidSeeking(int64_t currentPts, int64_t targetPts) const; void maybeSeekToBeforeDesiredPts(); @@ -405,14 +405,14 @@ class VideoDecoder { // PTS <-> INDEX CONVERSIONS // -------------------------------------------------------------------------- - int getKeyFrameIndexForPts(int64_t pts); + int getKeyFrameIndexForPts(int64_t pts) const; // Returns the key frame index of the presentation timestamp using our index. // We build this index by scanning the file in // scanFileAndUpdateMetadataAndIndex int getKeyFrameIndexForPtsUsingScannedIndex( const std::vector& keyFrames, - int64_t pts); + int64_t pts) const; int64_t secondsToIndexLowerBound( double seconds,