From e5713c0d85236ecf0af28e5be484d6447c690bdc Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 24 Jan 2025 15:45:08 +0000
Subject: [PATCH 1/3] Reorganize public stuff

---
 src/torchcodec/decoders/_core/VideoDecoder.h | 171 +++++++++----------
 1 file changed, 80 insertions(+), 91 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index a80e0c74c..e60c2ae84 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -16,30 +16,7 @@

 namespace facebook::torchcodec {

-/*
-The VideoDecoder class can be used to decode video frames to Tensors.
-
-Example usage of this class:
-std::string video_file_path = "/path/to/video.mp4";
-VideoDecoder video_decoder = VideoDecoder::createFromFilePath(video_file_path);
-
-// After creating the decoder, we can query the metadata:
-auto metadata = video_decoder.getContainerMetadata();
-
-// We can also add streams to the decoder:
-// -1 sets the default stream.
-video_decoder.addVideoStreamDecoder(-1);
-
-// API for seeking and frame extraction:
-// Let's extract the first frame at or after pts=5.0 seconds.
-video_decoder.setCursorPtsInSeconds(5.0);
-auto output = video_decoder->getNextFrameOutput();
-torch::Tensor frame = output.frame;
-double presentation_timestamp = output.ptsSeconds;
-// Note that presentation_timestamp can be any timestamp at 5.0 or above
-// because the frame time may not align exactly with the seek time.
-CHECK_GE(presentation_timestamp, 5.0);
-*/
+// The VideoDecoder class can be used to decode video frames to Tensors.
 // Note that VideoDecoder is not thread-safe.
 // Do not call non-const APIs concurrently on the same object.
 class VideoDecoder {
@@ -53,16 +30,12 @@
   enum class SeekMode { exact, approximate };

   // Creates a VideoDecoder from the video at videoFilePath.
-  explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
-
-  // Creates a VideoDecoder from a given buffer. Note that the buffer is not
-  // owned by the VideoDecoder.
-  explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
-
   static std::unique_ptr<VideoDecoder> createFromFilePath(
       const std::string& videoFilePath,
       SeekMode seekMode = SeekMode::exact);

+  // Creates a VideoDecoder from a given buffer. Note that the buffer is not
+  // owned by the VideoDecoder.
   static std::unique_ptr<VideoDecoder> createFromBuffer(
       const void* buffer,
       size_t length,
@@ -71,8 +44,10 @@
   // --------------------------------------------------------------------------
   // VIDEO METADATA QUERY API
   // --------------------------------------------------------------------------
+
   // Updates the metadata of the video to accurate values obtained by scanning
-  // the contents of the video file.
+  // the contents of the video file. Also updates each StreamInfo's index, i.e.
+  // the allFrames and keyFrames vectors.
   void scanFileAndUpdateMetadataAndIndex();

   struct StreamMetadata {
@@ -88,7 +63,6 @@
     std::optional<int64_t> numKeyFrames;
     std::optional<double> averageFps;
     std::optional<double> bitRate;
-    std::optional<std::vector<int64_t>> keyFrames;

     // More accurate duration, obtained by scanning the file.
     // These presentation timestamps are in time base.
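For reference, the workflow described by the removed comment still holds; it just goes through the factory functions that remain public. A minimal sketch, using only names that appear in this header or in the removed example (error handling omitted):

    std::unique_ptr<VideoDecoder> decoder =
        VideoDecoder::createFromFilePath("/path/to/video.mp4");
    // Query container-level metadata before adding any stream.
    auto metadata = decoder->getContainerMetadata();
    // -1 adds the default (best) stream, as in the removed example.
    decoder->addVideoStreamDecoder(-1);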
@@ -126,6 +100,7 @@
   // --------------------------------------------------------------------------
   // ADDING STREAMS API
   // --------------------------------------------------------------------------
+
   enum ColorConversionLibrary {
     // TODO: Add an AUTO option later.
     // Use the libavfilter library for color conversion.
@@ -164,39 +139,26 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamOptions& audioStreamOptions = AudioStreamOptions());

-  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
-
-  // ---- SINGLE FRAME SEEK AND DECODING API ----
-  // Places the cursor at the first frame on or after the position in seconds.
-  // Calling getNextFrameNoDemuxInternal() will return the first frame at
-  // or after this position.
-  void setCursorPtsInSeconds(double seconds);
-
-  // This structure ensures we always keep the streamIndex and AVFrame together
-  // Note that AVFrame itself doesn't retain the streamIndex.
-  struct AVFrameStream {
-    // The actual decoded output as a unique pointer to an AVFrame.
-    UniqueAVFrame avFrame;
-    // The stream index of the decoded frame.
-    int streamIndex;
-  };
+  // --------------------------------------------------------------------------
+  // DECODING AND SEEKING APIs
+  // --------------------------------------------------------------------------

+  // All public decoding entry points return either a FrameOutput or a
+  // FrameBatchOutput.
+  // They are the equivalent of the user-facing Frame and FrameBatch classes in
+  // Python. They contain RGB decoded frames along with some associated data
+  // like PTS and duration.
   struct FrameOutput {
-    // The actual decoded output as a Tensor.
-    torch::Tensor data;
-    // The stream index of the decoded frame. Used to distinguish
-    // between streams that are of the same type.
+    torch::Tensor data; // 3D: of shape CHW or HWC.
     int streamIndex;
-    // The presentation timestamp of the decoded frame in seconds.
     double ptsSeconds;
-    // The duration of the decoded frame in seconds.
     double durationSeconds;
   };

   struct FrameBatchOutput {
-    torch::Tensor data;
-    torch::Tensor ptsSeconds;
-    torch::Tensor durationSeconds;
+    torch::Tensor data; // 4D: of shape NCHW or NHWC.
+    torch::Tensor ptsSeconds; // 1D of shape (N,).
+    torch::Tensor durationSeconds; // 1D of shape (N,).

     explicit FrameBatchOutput(
         int64_t numFrames,
@@ -204,31 +166,16 @@ class VideoDecoder {
         const StreamMetadata& streamMetadata);
   };

-  class EndOfFileException : public std::runtime_error {
-   public:
-    explicit EndOfFileException(const std::string& msg)
-        : std::runtime_error(msg) {}
-  };
+  // Places the cursor at the first frame on or after the position in seconds.
+  // Calling getNextFrameNoDemux() will return the first frame at
+  // or after this position.
+  void setCursorPtsInSeconds(double seconds);

   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
   FrameOutput getNextFrameNoDemux();
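Continuing the sketch above, the seek-then-read flow from the removed header example maps onto the renamed entry points as follows (the final assertion mirrors the old CHECK_GE):

    decoder->setCursorPtsInSeconds(5.0);
    VideoDecoder::FrameOutput output = decoder->getNextFrameNoDemux();
    // output.ptsSeconds can be any value at or above 5.0, because frame
    // times rarely align exactly with the seek point.
    TORCH_CHECK(output.ptsSeconds >= 5.0);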
-  // Decodes the first frame in any added stream that is visible at a given
-  // timestamp. Frames in the video have a presentation timestamp and a
-  // duration. For example, if a frame has presentation timestamp of 5.0s and a
-  // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
-  // i.e. it will be returned when this function is called with seconds=5.0 or
-  // seconds=5.999, etc.
-  FrameOutput getFramePlayedAtNoDemux(double seconds);

   FrameOutput getFrameAtIndex(int streamIndex, int64_t frameIndex);

-  // This is morally private but needs to be exposed for C++ tests. Once
-  // getFrameAtIndex supports the preAllocatedOutputTensor parameter, we can
-  // move it back to private.
-  FrameOutput getFrameAtIndexInternal(
-      int streamIndex,
-      int64_t frameIndex,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

   // Returns frames at the given indices for a given stream as a single stacked
   // Tensor.
   FrameBatchOutput getFramesAtIndices(
       int streamIndex,
       const std::vector<int64_t>& frameIndices);

+  // Returns frames within a given range. The range is defined by [start,
+  // stop). The values retrieved from the range are: [start, start+step,
+  // start+(2*step), start+(3*step), ..., stop). The default for step is 1.
+  FrameBatchOutput
+  getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
+
+  // Decodes the first frame in any added stream that is visible at a given
+  // timestamp. Frames in the video have a presentation timestamp and a
+  // duration. For example, if a frame has presentation timestamp of 5.0s and a
+  // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0).
+  // i.e. it will be returned when this function is called with seconds=5.0 or
+  // seconds=5.999, etc.
+  FrameOutput getFramePlayedAtNoDemux(double seconds);
+
   FrameBatchOutput getFramesPlayedAt(
       int streamIndex,
       const std::vector<double>& timestamps);

-  // Returns frames within a given range for a given stream as a single stacked
-  // Tensor. The range is defined by [start, stop). The values retrieved from
-  // the range are:
-  //   [start, start+step, start+(2*step), start+(3*step), ..., stop)
-  // The default for step is 1.
-  FrameBatchOutput
-  getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step);
-
-  // Returns frames within a given pts range for a given stream as a single
-  // stacked tensor. The range is defined by [startSeconds, stopSeconds) with
-  // respect to the pts values for frames. The returned frames are in pts order.
+  // Returns frames within a given pts range. The range is defined by
+  // [startSeconds, stopSeconds) with respect to the pts values for frames. The
+  // returned frames are in pts order.
   //
   // Note that while stopSeconds is excluded in the half open range, this really
   // only makes a difference when stopSeconds is exactly the pts value for a
   FrameBatchOutput getFramesPlayedInRange(
       int streamIndex,
       double startSeconds,
       double stopSeconds);
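To make the half-open semantics concrete, a small sketch with hypothetical arguments (decoder as above; streamIndex is whatever stream was added):

    // start=0, stop=10, step=3 decodes frame indices 0, 3, 6, 9; stop itself
    // is excluded from the range.
    auto batch = decoder->getFramesInRange(streamIndex, 0, 10, 3);
    // batch.data then has shape (4, C, H, W) or (4, H, W, C), and
    // batch.ptsSeconds has shape (4,).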
+  class EndOfFileException : public std::runtime_error {
+   public:
+    explicit EndOfFileException(const std::string& msg)
+        : std::runtime_error(msg) {}
+  };
+
   // --------------------------------------------------------------------------
-  // DECODER PERFORMANCE STATISTICS API
+  // MORALLY PRIVATE APIS
   // --------------------------------------------------------------------------
+
+  // These are APIs that should be private, but that are effectively exposed
+  // for practical reasons, typically for testing purposes.
+
+  // This struct is needed because AVFrame doesn't retain the streamIndex. Only
+  // the AVPacket knows its stream. This is what the low-level private decoding
+  // entry points return. The AVFrameStream is then converted to a FrameOutput
+  // with convertAVFrameToFrameOutput. It should be private, but is currently
+  // used by DeviceInterface.
+  struct AVFrameStream {
+    // The actual decoded output as a unique pointer to an AVFrame.
+    // Usually, this is a YUV frame. It'll be converted to RGB in
+    // convertAVFrameToFrameOutput.
+    UniqueAVFrame avFrame;
+    // The stream index of the decoded frame.
+    int streamIndex;
+  };
+
+  // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we
+  // can move it back to private.
+  FrameOutput getFrameAtIndexInternal(
+      int streamIndex,
+      int64_t frameIndex,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  // Exposed for _test_frame_pts_equality, which is used to test non-regression
+  // of pts resolution (64 to 32 bit floats).
+  double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);

-  // Only exposed for performance testing.
+  // Exposed for performance testing.
   struct DecodeStats {
     int64_t numSeeksAttempted = 0;
     int64_t numSeeksDone = 0;
@@ -288,9 +274,11 @@
   DecodeStats getDecodeStats() const;
   void resetDecodeStats();

-  double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex);
-
  private:
+  explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
+  explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
+  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+
   struct FrameInfo {
     int64_t pts = 0;
     // The value of this default is important: the last frame's nextPts will be
@@ -401,6 +389,7 @@
       const enum AVColorSpace colorspace);

   void maybeSeekToBeforeDesiredPts();
+
   AVFrameStream getAVFrameUsingFilterFunction(
       std::function<bool(int, AVFrame*)>);
   // Once we create a decoder can update the metadata with the codec context.
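A sketch of how the performance counters are meant to be driven from a test; the exact assertion is illustrative, not prescribed by this header:

    decoder->resetDecodeStats();
    auto frame = decoder->getNextFrameNoDemux();
    VideoDecoder::DecodeStats stats = decoder->getDecodeStats();
    // Purely sequential decoding should not need to seek.
    TORCH_CHECK(stats.numSeeksAttempted == 0);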
From e233f06664bd3c413abdcab37f0c861ef552aac1 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 24 Jan 2025 16:26:55 +0000
Subject: [PATCH 2/3] Rename FilterState to FilterGraphContext

---
 .../decoders/_core/VideoDecoder.cpp          | 38 ++++++++++---------
 src/torchcodec/decoders/_core/VideoDecoder.h |  4 +-
 2 files changed, 22 insertions(+), 20 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 59936f9e3..2576d75ea 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -337,12 +337,12 @@ void VideoDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
     int expectedOutputWidth) {
-  FilterState& filterState = streamInfo.filterState;
-  filterState.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterState.filterGraph.get() != nullptr);
+  FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext;
+  filterGraphContext.filterGraph.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr);

   if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterState.filterGraph->nb_threads =
+    filterGraphContext.filterGraph->nb_threads =
         streamInfo.videoStreamOptions.ffmpegThreadCount.value();
   }

@@ -360,12 +360,12 @@
       << codecContext->sample_aspect_ratio.den;

   int ffmpegStatus = avfilter_graph_create_filter(
-      &filterState.sourceContext,
+      &filterGraphContext.sourceContext,
       buffersrc,
       "in",
       filterArgs.str().c_str(),
       nullptr,
-      filterState.filterGraph.get());
+      filterGraphContext.filterGraph.get());
   if (ffmpegStatus < 0) {
     throw std::runtime_error(
         std::string("Failed to create filter graph: ") + filterArgs.str() +
         ": " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
   }

   ffmpegStatus = avfilter_graph_create_filter(
-      &filterState.sinkContext,
+      &filterGraphContext.sinkContext,
       buffersink,
       "out",
       nullptr,
       nullptr,
-      filterState.filterGraph.get());
+      filterGraphContext.filterGraph.get());
   if (ffmpegStatus < 0) {
     throw std::runtime_error(
         "Failed to create filter graph: " +
@@ -388,7 +388,7 @@
   enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

   ffmpegStatus = av_opt_set_int_list(
-      filterState.sinkContext,
+      filterGraphContext.sinkContext,
       "pix_fmts",
       pix_fmts,
       AV_PIX_FMT_NONE,
@@ -403,11 +403,11 @@
   UniqueAVFilterInOut outputs(avfilter_inout_alloc());
   UniqueAVFilterInOut inputs(avfilter_inout_alloc());

   outputs->name = av_strdup("in");
-  outputs->filter_ctx = filterState.sourceContext;
+  outputs->filter_ctx = filterGraphContext.sourceContext;
   outputs->pad_idx = 0;
   outputs->next = nullptr;
   inputs->name = av_strdup("out");
-  inputs->filter_ctx = filterState.sinkContext;
+  inputs->filter_ctx = filterGraphContext.sinkContext;
   inputs->pad_idx = 0;
   inputs->next = nullptr;
@@ -418,7 +418,7 @@
   AVFilterInOut* outputsTmp = outputs.release();
   AVFilterInOut* inputsTmp = inputs.release();
   ffmpegStatus = avfilter_graph_parse_ptr(
-      filterState.filterGraph.get(),
+      filterGraphContext.filterGraph.get(),
       description.str().c_str(),
       &inputsTmp,
       &outputsTmp,
@@ -431,7 +431,8 @@
       getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
   }

-  ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr);
+  ffmpegStatus =
+      avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr);
   if (ffmpegStatus < 0) {
     throw std::runtime_error(
         "Failed to configure filter graph: " +
@@ -1027,7 +1028,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
   } else if (
       streamInfo.colorConversionLibrary ==
       ColorConversionLibrary::FILTERGRAPH) {
-    if (!streamInfo.filterState.filterGraph ||
+    if (!streamInfo.filterGraphContext.filterGraph ||
         streamInfo.prevFrameContext != frameContext) {
       createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
       streamInfo.prevFrameContext = frameContext;
@@ -1585,16 +1586,17 @@
 torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
     int streamIndex,
     const AVFrame* avFrame) {
-  FilterState& filterState = streamInfos_[streamIndex].filterState;
+  FilterGraphContext& filterGraphContext =
+      streamInfos_[streamIndex].filterGraphContext;
   int ffmpegStatus =
-      av_buffersrc_write_frame(filterState.sourceContext, avFrame);
+      av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }

   UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  ffmpegStatus =
-      av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get());
+  ffmpegStatus = av_buffersink_get_frame(
+      filterGraphContext.sinkContext, filteredAVFrame.get());
   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index e60c2ae84..851aff396 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -289,7 +289,7 @@
     int64_t nextPts = INT64_MAX;
   };

-  struct FilterState {
+  struct FilterGraphContext {
     UniqueAVFilterGraph filterGraph;
     AVFilterContext* sourceContext = nullptr;
     AVFilterContext* sinkContext = nullptr;
@@ -321,7 +321,7 @@
     VideoStreamOptions videoStreamOptions;
     // The filter state associated with this stream (for video streams). The
     // actual graph will be nullptr for inactive streams.
-    FilterState filterState;
+    FilterGraphContext filterGraphContext;
     ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
     std::vector<FrameInfo> keyFrames;
     std::vector<FrameInfo> allFrames;
From be1fb026160034d9c386eff82f05a34dc8408b96 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Fri, 24 Jan 2025 17:03:55 +0000
Subject: [PATCH 3/3] Reorganize private part of header

---
 src/torchcodec/decoders/_core/VideoDecoder.h | 194 ++++++++++++-------
 1 file changed, 119 insertions(+), 75 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 851aff396..dc5fa0bdc 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -275,10 +275,9 @@
   void resetDecodeStats();

  private:
-  explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
-  explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
-  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
-
+  // --------------------------------------------------------------------------
+  // STREAMINFO AND ASSOCIATED STRUCTS
+  // --------------------------------------------------------------------------
   struct FrameInfo {
     int64_t pts = 0;
     // The value of this default is important: the last frame's nextPts will be
@@ -305,73 +304,117 @@
     bool operator!=(const DecodedFrameContext&);
   };

-  // Stores information for each stream.
   struct StreamInfo {
     int streamIndex = -1;
     AVStream* stream = nullptr;
     AVRational timeBase = {};
     UniqueAVCodecContext codecContext;
-    // The current position of the cursor in the stream.
+
+    // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
+    // called.
+    std::vector<FrameInfo> keyFrames;
+    std::vector<FrameInfo> allFrames;
+
+    // The current position of the cursor in the stream, and associated frame
+    // duration.
     int64_t currentPts = 0;
     int64_t currentDuration = 0;
     // The desired position of the cursor in the stream. We send frames >=
     // this pts to the user when they request a frame.
-    // We update this field if the user requested a seek.
+    // We update this field if the user requested a seek. This typically
+    // corresponds to the decoder's desiredPts_ attribute.
     int64_t discardFramesBeforePts = INT64_MIN;
     VideoStreamOptions videoStreamOptions;
-    // The filter state associated with this stream (for video streams). The
-    // actual graph will be nullptr for inactive streams.
-    FilterGraphContext filterGraphContext;
+
+    // Color-conversion fields. Only one of FilterGraphContext and
+    // UniqueSwsContext should be non-null.
     ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
-    std::vector<FrameInfo> keyFrames;
-    std::vector<FrameInfo> allFrames;
-    DecodedFrameContext prevFrameContext;
+    FilterGraphContext filterGraphContext;
     UniqueSwsContext swsContext;
+
+    // Used to know whether a new FilterGraphContext or UniqueSwsContext should
+    // be created before decoding a new frame.
+    DecodedFrameContext prevFrameContext;
   };
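The prevFrameContext field drives the cache-invalidation pattern already visible in the convertAVFrameToFrameOutputOnCPU hunk of patch 2; schematically:

    // Recreate the color-conversion state only when the incoming frame's
    // properties (dimensions, format, ...) differ from the previous frame's.
    if (!streamInfo.filterGraphContext.filterGraph ||
        streamInfo.prevFrameContext != frameContext) {
      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
      streamInfo.prevFrameContext = frameContext;
    }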
-  // Returns the key frame index of the presentation timestamp using FFMPEG's
-  // index. Note that this index may be truncated for some files.
-  int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts)
-      const;
-  // Returns the key frame index of the presentation timestamp using our index.
-  // We build this index by scanning the file in buildKeyFrameIndex().
-  int getKeyFrameIndexForPtsUsingScannedIndex(
-      const std::vector<FrameInfo>& keyFrames,
-      int64_t pts) const;
-  int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const;
+  // --------------------------------------------------------------------------
+  // CONSTRUCTORS AND INITIALIZERS
+  // --------------------------------------------------------------------------
+  // Don't use these; use the static methods above to create a decoder object.
+
+  explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode);
+  explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode);
+  void initializeDecoder();
+  void updateMetadataWithCodecContext(
+      int streamIndex,
+      AVCodecContext* codecContext);
+
+  // --------------------------------------------------------------------------
+  // DECODING APIS AND RELATED UTILS
+  // --------------------------------------------------------------------------
+
   bool canWeAvoidSeekingForStream(
       const StreamInfo& stream,
       int64_t currentPts,
       int64_t targetPts) const;

-  // Returns the "best" stream index for a given media type. The "best" is
-  // determined by various heuristics in FFMPEG.
-  // See
-  // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
-  // for more details about the heuristics.
-  int getBestStreamIndex(AVMediaType mediaType);
-  void initializeDecoder();
-  void validateUserProvidedStreamIndex(int streamIndex);
-  void validateScannedAllStreams(const std::string& msg);
-  void validateFrameIndex(
-      const StreamMetadata& streamMetadata,
-      int64_t frameIndex);
-  // Creates and initializes a filter graph for a stream. The filter graph can
-  // do rescaling and color conversion.
+  void maybeSeekToBeforeDesiredPts();
+
+  AVFrameStream getAVFrameUsingFilterFunction(
+      std::function<bool(int, AVFrame*)>);
+
+  FrameOutput getNextFrameNoDemuxInternal(
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+
+  FrameOutput convertAVFrameToFrameOutput(
+      AVFrameStream& avFrameStream,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  void convertAVFrameToFrameOutputOnCPU(
+      AVFrameStream& avFrameStream,
+      FrameOutput& frameOutput,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
+  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
+      int streamIndex,
+      const AVFrame* avFrame);
+
+  int convertAVFrameToTensorUsingSwsScale(
+      int streamIndex,
+      const AVFrame* avFrame,
+      torch::Tensor& outputTensor);
+
+  // --------------------------------------------------------------------------
+  // COLOR CONVERSION LIBRARIES HANDLERS CREATION
+  // --------------------------------------------------------------------------
+
   void createFilterGraph(
       StreamInfo& streamInfo,
       int expectedOutputHeight,
       int expectedOutputWidth);

-  int64_t getNumFrames(const StreamMetadata& streamMetadata);
+  void createSwsContext(
+      StreamInfo& streamInfo,
+      const DecodedFrameContext& frameContext,
+      const enum AVColorSpace colorspace);

-  int64_t getPts(
-      const StreamInfo& streamInfo,
-      const StreamMetadata& streamMetadata,
-      int64_t frameIndex);
+  // --------------------------------------------------------------------------
+  // PTS <-> INDEX CONVERSIONS
+  // --------------------------------------------------------------------------

-  double getMinSeconds(const StreamMetadata& streamMetadata);
-  double getMaxSeconds(const StreamMetadata& streamMetadata);
+  int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const;
+
+  // Returns the key frame index of the presentation timestamp using our index.
+  // We build this index by scanning the file in
+  // scanFileAndUpdateMetadataAndIndex.
+  int getKeyFrameIndexForPtsUsingScannedIndex(
+      const std::vector<FrameInfo>& keyFrames,
+      int64_t pts) const;
+
+  // Returns the key frame index from FFmpeg's own index. Potentially less
+  // accurate.
+  int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts)
+      const;
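As a sketch of what the scanned-index lookup amounts to: keyFrames is sorted by pts after the scan, so finding the key frame for a target pts is a binary search. The helper below is hypothetical; its name and edge-case handling are illustrative, not this header's implementation:

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Returns the index of the last key frame whose pts is <= targetPts,
    // or -1 if targetPts comes before the first key frame.
    int keyFrameIndexForPts(
        const std::vector<FrameInfo>& keyFrames, int64_t targetPts) {
      auto it = std::upper_bound(
          keyFrames.begin(), keyFrames.end(), targetPts,
          [](int64_t pts, const FrameInfo& info) { return pts < info.pts; });
      return static_cast<int>(it - keyFrames.begin()) - 1;
    }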
   int64_t secondsToIndexLowerBound(
       double seconds,
@@ -383,39 +426,41 @@
       const StreamInfo& streamInfo,
       const StreamMetadata& streamMetadata);

-  void createSwsContext(
-      StreamInfo& streamInfo,
-      const DecodedFrameContext& frameContext,
-      const enum AVColorSpace colorspace);
+  int64_t getPts(
+      const StreamInfo& streamInfo,
+      const StreamMetadata& streamMetadata,
+      int64_t frameIndex);

-  void maybeSeekToBeforeDesiredPts();
+  // --------------------------------------------------------------------------
+  // STREAM AND METADATA APIS
+  // --------------------------------------------------------------------------

-  AVFrameStream getAVFrameUsingFilterFunction(
-      std::function<bool(int, AVFrame*)>);
-  // Once we create a decoder can update the metadata with the codec context.
-  // For example, for video streams, we can add the height and width of the
-  // decoded stream.
-  void updateMetadataWithCodecContext(
-      int streamIndex,
-      AVCodecContext* codecContext);
-  void populateVideoMetadataFromStreamIndex(int streamIndex);
-  torch::Tensor convertAVFrameToTensorUsingFilterGraph(
-      int streamIndex,
-      const AVFrame* avFrame);
-  int convertAVFrameToTensorUsingSwsScale(
-      int streamIndex,
-      const AVFrame* avFrame,
-      torch::Tensor& outputTensor);
-  FrameOutput convertAVFrameToFrameOutput(
-      AVFrameStream& avFrameStream,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
-  void convertAVFrameToFrameOutputOnCPU(
-      AVFrameStream& avFrameStream,
-      FrameOutput& frameOutput,
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  // Returns the "best" stream index for a given media type. The "best" is
+  // determined by various heuristics in FFMPEG.
+  // See
+  // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
+  // for more details about the heuristics.
+  int getBestStreamIndex(AVMediaType mediaType);

-  FrameOutput getNextFrameNoDemuxInternal(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  int64_t getNumFrames(const StreamMetadata& streamMetadata);
+  double getMinSeconds(const StreamMetadata& streamMetadata);
+  double getMaxSeconds(const StreamMetadata& streamMetadata);
+
+  // --------------------------------------------------------------------------
+  // VALIDATION UTILS
+  // --------------------------------------------------------------------------
+
+  void validateUserProvidedStreamIndex(int streamIndex);
+  void validateScannedAllStreams(const std::string& msg);
+  void validateFrameIndex(
+      const StreamMetadata& streamMetadata,
+      int64_t frameIndex);
+
+  // --------------------------------------------------------------------------
+  // ATTRIBUTES
+  // --------------------------------------------------------------------------

   SeekMode seekMode_;
   ContainerMetadata containerMetadata_;
@@ -427,7 +472,6 @@
   // Set when the user wants to seek and stores the desired pts that the user
   // wants to seek to.
   std::optional<double> desiredPtsSeconds_;
-  // Stores various internal decoding stats.
   DecodeStats decodeStats_;

   // Stores the AVIOContext for the input buffer.