From e5713c0d85236ecf0af28e5be484d6447c690bdc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 24 Jan 2025 15:45:08 +0000 Subject: [PATCH 1/2] Reorgnize public stuff --- src/torchcodec/decoders/_core/VideoDecoder.h | 171 +++++++++---------- 1 file changed, 80 insertions(+), 91 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index a80e0c74c..e60c2ae84 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -16,30 +16,7 @@ namespace facebook::torchcodec { -/* -The VideoDecoder class can be used to decode video frames to Tensors. - -Example usage of this class: -std::string video_file_path = "/path/to/video.mp4"; -VideoDecoder video_decoder = VideoDecoder::createFromFilePath(video_file_path); - -// After creating the decoder, we can query the metadata: -auto metadata = video_decoder.getContainerMetadata(); - -// We can also add streams to the decoder: -// -1 sets the default stream. -video_decoder.addVideoStreamDecoder(-1); - -// API for seeking and frame extraction: -// Let's extract the first frame at or after pts=5.0 seconds. -video_decoder.setCursorPtsInSeconds(5.0); -auto output = video_decoder->getNextFrameOutput(); -torch::Tensor frame = output.frame; -double presentation_timestamp = output.ptsSeconds; -// Note that presentation_timestamp can be any timestamp at 5.0 or above -// because the frame time may not align exactly with the seek time. -CHECK_GE(presentation_timestamp, 5.0); -*/ +// The VideoDecoder class can be used to decode video frames to Tensors. // Note that VideoDecoder is not thread-safe. // Do not call non-const APIs concurrently on the same object. class VideoDecoder { @@ -53,16 +30,12 @@ class VideoDecoder { enum class SeekMode { exact, approximate }; // Creates a VideoDecoder from the video at videoFilePath. - explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode); - - // Creates a VideoDecoder from a given buffer. Note that the buffer is not - // owned by the VideoDecoder. - explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode); - static std::unique_ptr createFromFilePath( const std::string& videoFilePath, SeekMode seekMode = SeekMode::exact); + // Creates a VideoDecoder from a given buffer. Note that the buffer is not + // owned by the VideoDecoder. static std::unique_ptr createFromBuffer( const void* buffer, size_t length, @@ -71,8 +44,10 @@ class VideoDecoder { // -------------------------------------------------------------------------- // VIDEO METADATA QUERY API // -------------------------------------------------------------------------- + // Updates the metadata of the video to accurate values obtained by scanning - // the contents of the video file. + // the contents of the video file. Also updates each StreamInfo's index, i.e. + // the allFrames and keyFrames vectors. void scanFileAndUpdateMetadataAndIndex(); struct StreamMetadata { @@ -88,7 +63,6 @@ class VideoDecoder { std::optional numKeyFrames; std::optional averageFps; std::optional bitRate; - std::optional> keyFrames; // More accurate duration, obtained by scanning the file. // These presentation timestamps are in time base. @@ -126,6 +100,7 @@ class VideoDecoder { // -------------------------------------------------------------------------- // ADDING STREAMS API // -------------------------------------------------------------------------- + enum ColorConversionLibrary { // TODO: Add an AUTO option later. 
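For reference, a minimal usage sketch of the factory functions and metadata APIs declared above, assembled from the new declarations and from the doc-comment example this patch removes. The include path is the one from this diff, the sample file path is a placeholder, and getContainerMetadata() is taken from the removed comment rather than from a declaration visible in this hunk.

#include <memory>
#include <string>

#include "src/torchcodec/decoders/_core/VideoDecoder.h"

using facebook::torchcodec::VideoDecoder;

void inspectVideo() {
  const std::string videoFilePath = "/path/to/video.mp4";  // placeholder path

  // The factory functions replace the now-private constructors.
  std::unique_ptr<VideoDecoder> decoder = VideoDecoder::createFromFilePath(
      videoFilePath, VideoDecoder::SeekMode::exact);

  // Optional full scan: refines the metadata and builds the per-stream
  // keyFrames/allFrames index mentioned above.
  decoder->scanFileAndUpdateMetadataAndIndex();

  auto containerMetadata = decoder->getContainerMetadata();
  (void)containerMetadata;  // inspect per-stream durations, fps, bit rate, ...
}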
// Use the libavfilter library for color conversion. @@ -164,39 +139,26 @@ class VideoDecoder { int streamIndex, const AudioStreamOptions& audioStreamOptions = AudioStreamOptions()); - torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor); - - // ---- SINGLE FRAME SEEK AND DECODING API ---- - // Places the cursor at the first frame on or after the position in seconds. - // Calling getNextFrameNoDemuxInternal() will return the first frame at - // or after this position. - void setCursorPtsInSeconds(double seconds); - - // This structure ensures we always keep the streamIndex and AVFrame together - // Note that AVFrame itself doesn't retain the streamIndex. - struct AVFrameStream { - // The actual decoded output as a unique pointer to an AVFrame. - UniqueAVFrame avFrame; - // The stream index of the decoded frame. - int streamIndex; - }; + // -------------------------------------------------------------------------- + // DECODING AND SEEKING APIs + // -------------------------------------------------------------------------- + // All public decoding entry points return either a FrameOutput or a + // FrameBatchOutput. + // They are the equivalent of the user-facing Frame and FrameBatch classes in + // Python. They contain RGB decoded frames along with some associated data + // like PTS and duration. struct FrameOutput { - // The actual decoded output as a Tensor. - torch::Tensor data; - // The stream index of the decoded frame. Used to distinguish - // between streams that are of the same type. + torch::Tensor data; // 3D: of shape CHW or HWC. int streamIndex; - // The presentation timestamp of the decoded frame in seconds. double ptsSeconds; - // The duration of the decoded frame in seconds. double durationSeconds; }; struct FrameBatchOutput { - torch::Tensor data; - torch::Tensor ptsSeconds; - torch::Tensor durationSeconds; + torch::Tensor data; // 4D: of shape NCHW or NHWC. + torch::Tensor ptsSeconds; // 1D of shape (N,) + torch::Tensor durationSeconds; // 1D of shape (N,) explicit FrameBatchOutput( int64_t numFrames, @@ -204,31 +166,16 @@ class VideoDecoder { const StreamMetadata& streamMetadata); }; - class EndOfFileException : public std::runtime_error { - public: - explicit EndOfFileException(const std::string& msg) - : std::runtime_error(msg) {} - }; + // Places the cursor at the first frame on or after the position in seconds. + // Calling getNextFrameNoDemux() will return the first frame at + // or after this position. + void setCursorPtsInSeconds(double seconds); // Decodes the frame where the current cursor position is. It also advances // the cursor to the next frame. FrameOutput getNextFrameNoDemux(); - // Decodes the first frame in any added stream that is visible at a given - // timestamp. Frames in the video have a presentation timestamp and a - // duration. For example, if a frame has presentation timestamp of 5.0s and a - // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0). - // i.e. it will be returned when this function is called with seconds=5.0 or - // seconds=5.999, etc. - FrameOutput getFramePlayedAtNoDemux(double seconds); FrameOutput getFrameAtIndex(int streamIndex, int64_t frameIndex); - // This is morally private but needs to be exposed for C++ tests. Once - // getFrameAtIndex supports the preAllocatedOutputTensor parameter, we can - // move it back to private. 
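A short sketch of the cursor-based decoding flow described above, using setCursorPtsInSeconds and getNextFrameNoDemux; the note that the returned pts can be at or above the seek target comes from the usage example removed in this patch. The decoder is assumed to already have a video stream added.

#include "src/torchcodec/decoders/_core/VideoDecoder.h"

using facebook::torchcodec::VideoDecoder;

void decodeFromTimestamp(VideoDecoder& decoder) {
  // The next decoded frame is the first frame on or after 5.0 seconds.
  decoder.setCursorPtsInSeconds(5.0);

  VideoDecoder::FrameOutput out = decoder.getNextFrameNoDemux();
  torch::Tensor frame = out.data;   // 3D, CHW or HWC
  double pts = out.ptsSeconds;      // >= 5.0, but rarely exactly 5.0
  double duration = out.durationSeconds;
  (void)frame;
  (void)pts;
  (void)duration;

  // Random access by frame index is also available:
  VideoDecoder::FrameOutput tenth =
      decoder.getFrameAtIndex(out.streamIndex, 10);
  (void)tenth;
}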
- FrameOutput getFrameAtIndexInternal( - int streamIndex, - int64_t frameIndex, - std::optional preAllocatedOutputTensor = std::nullopt); // Returns frames at the given indices for a given stream as a single stacked // Tensor. @@ -236,21 +183,27 @@ class VideoDecoder { int streamIndex, const std::vector& frameIndices); + // Returns frames within a given range. The range is defined by [start, stop). + // The values retrieved from the range are: [start, start+step, + // start+(2*step), start+(3*step), ..., stop). The default for step is 1. + FrameBatchOutput + getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step); + + // Decodes the first frame in any added stream that is visible at a given + // timestamp. Frames in the video have a presentation timestamp and a + // duration. For example, if a frame has presentation timestamp of 5.0s and a + // duration of 1.0s, it will be visible in the timestamp range [5.0, 6.0). + // i.e. it will be returned when this function is called with seconds=5.0 or + // seconds=5.999, etc. + FrameOutput getFramePlayedAtNoDemux(double seconds); + FrameBatchOutput getFramesPlayedAt( int streamIndex, const std::vector& timestamps); - // Returns frames within a given range for a given stream as a single stacked - // Tensor. The range is defined by [start, stop). The values retrieved from - // the range are: - // [start, start+step, start+(2*step), start+(3*step), ..., stop) - // The default for step is 1. - FrameBatchOutput - getFramesInRange(int streamIndex, int64_t start, int64_t stop, int64_t step); - - // Returns frames within a given pts range for a given stream as a single - // stacked tensor. The range is defined by [startSeconds, stopSeconds) with - // respect to the pts values for frames. The returned frames are in pts order. + // Returns frames within a given pts range. The range is defined by + // [startSeconds, stopSeconds) with respect to the pts values for frames. The + // returned frames are in pts order. // // Note that while stopSeconds is excluded in the half open range, this really // only makes a difference when stopSeconds is exactly the pts value for a @@ -270,11 +223,44 @@ class VideoDecoder { double startSeconds, double stopSeconds); + class EndOfFileException : public std::runtime_error { + public: + explicit EndOfFileException(const std::string& msg) + : std::runtime_error(msg) {} + }; + // -------------------------------------------------------------------------- - // DECODER PERFORMANCE STATISTICS API + // MORALLY PRIVATE APIS // -------------------------------------------------------------------------- + // These are APIs that should be private, but that are effectively exposed for + // practical reasons, typically for testing purposes. + + // This struct is needed because AVFrame doesn't retain the streamIndex. Only + // the AVPacket knows its stream. This is what the low-level private decoding + // entry points return. The AVFrameStream is then converted to a FrameOutput + // with convertAVFrameToFrameOutput. It should be private, but is currently + // used by DeviceInterface. + struct AVFrameStream { + // The actual decoded output as a unique pointer to an AVFrame. + // Usually, this is a YUV frame. It'll be converted to RGB in + // convertAVFrameToFrameOutput. + UniqueAVFrame avFrame; + // The stream index of the decoded frame. + int streamIndex; + }; + + // Once getFrameAtIndex supports the preAllocatedOutputTensor parameter, we + // can move it back to private. 
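A sketch of the batch entry points declared above, each returning a stacked FrameBatchOutput. The element types of the index and timestamp vectors (int64_t and double) are assumed, since the template arguments were lost in this listing.

#include "src/torchcodec/decoders/_core/VideoDecoder.h"

using facebook::torchcodec::VideoDecoder;

void batchDecodeExamples(VideoDecoder& decoder, int streamIndex) {
  // Frames 0, 10 and 25 of the given stream, stacked into one tensor.
  VideoDecoder::FrameBatchOutput byIndex =
      decoder.getFramesAtIndices(streamIndex, {0, 10, 25});

  // Every other frame in the index range [100, 200).
  VideoDecoder::FrameBatchOutput byRange =
      decoder.getFramesInRange(streamIndex, 100, 200, 2);

  // The frames displayed at the given timestamps, in seconds.
  VideoDecoder::FrameBatchOutput byTime =
      decoder.getFramesPlayedAt(streamIndex, {1.0, 2.5, 7.25});

  // In each output, data is 4D (NCHW or NHWC); ptsSeconds and
  // durationSeconds are 1D tensors of shape (N,).
  (void)byIndex;
  (void)byRange;
  (void)byTime;
}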
+ FrameOutput getFrameAtIndexInternal( + int streamIndex, + int64_t frameIndex, + std::optional preAllocatedOutputTensor = std::nullopt); + + // Exposed for _test_frame_pts_equality, which is used to test non-regression + // of pts resolution (64 to 32 bit floats) + double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex); - // Only exposed for performance testing. + // Exposed for performance testing. struct DecodeStats { int64_t numSeeksAttempted = 0; int64_t numSeeksDone = 0; @@ -288,9 +274,11 @@ class VideoDecoder { DecodeStats getDecodeStats() const; void resetDecodeStats(); - double getPtsSecondsForFrame(int streamIndex, int64_t frameIndex); - private: + explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode); + explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode); + torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor); + struct FrameInfo { int64_t pts = 0; // The value of this default is important: the last frame's nextPts will be @@ -401,6 +389,7 @@ class VideoDecoder { const enum AVColorSpace colorspace); void maybeSeekToBeforeDesiredPts(); + AVFrameStream getAVFrameUsingFilterFunction( std::function); // Once we create a decoder can update the metadata with the codec context. From e233f06664bd3c413abdcab37f0c861ef552aac1 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 24 Jan 2025 16:26:55 +0000 Subject: [PATCH 2/2] Rename FilterState into FilterGraphContext --- .../decoders/_core/VideoDecoder.cpp | 38 ++++++++++--------- src/torchcodec/decoders/_core/VideoDecoder.h | 4 +- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 59936f9e3..2576d75ea 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -337,12 +337,12 @@ void VideoDecoder::createFilterGraph( StreamInfo& streamInfo, int expectedOutputHeight, int expectedOutputWidth) { - FilterState& filterState = streamInfo.filterState; - filterState.filterGraph.reset(avfilter_graph_alloc()); - TORCH_CHECK(filterState.filterGraph.get() != nullptr); + FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext; + filterGraphContext.filterGraph.reset(avfilter_graph_alloc()); + TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr); if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) { - filterState.filterGraph->nb_threads = + filterGraphContext.filterGraph->nb_threads = streamInfo.videoStreamOptions.ffmpegThreadCount.value(); } @@ -360,12 +360,12 @@ void VideoDecoder::createFilterGraph( << codecContext->sample_aspect_ratio.den; int ffmpegStatus = avfilter_graph_create_filter( - &filterState.sourceContext, + &filterGraphContext.sourceContext, buffersrc, "in", filterArgs.str().c_str(), nullptr, - filterState.filterGraph.get()); + filterGraphContext.filterGraph.get()); if (ffmpegStatus < 0) { throw std::runtime_error( std::string("Failed to create filter graph: ") + filterArgs.str() + @@ -373,12 +373,12 @@ void VideoDecoder::createFilterGraph( } ffmpegStatus = avfilter_graph_create_filter( - &filterState.sinkContext, + &filterGraphContext.sinkContext, buffersink, "out", nullptr, nullptr, - filterState.filterGraph.get()); + filterGraphContext.filterGraph.get()); if (ffmpegStatus < 0) { throw std::runtime_error( "Failed to create filter graph: " + @@ -388,7 +388,7 @@ void VideoDecoder::createFilterGraph( enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, 
AV_PIX_FMT_NONE}; ffmpegStatus = av_opt_set_int_list( - filterState.sinkContext, + filterGraphContext.sinkContext, "pix_fmts", pix_fmts, AV_PIX_FMT_NONE, @@ -403,11 +403,11 @@ void VideoDecoder::createFilterGraph( UniqueAVFilterInOut inputs(avfilter_inout_alloc()); outputs->name = av_strdup("in"); - outputs->filter_ctx = filterState.sourceContext; + outputs->filter_ctx = filterGraphContext.sourceContext; outputs->pad_idx = 0; outputs->next = nullptr; inputs->name = av_strdup("out"); - inputs->filter_ctx = filterState.sinkContext; + inputs->filter_ctx = filterGraphContext.sinkContext; inputs->pad_idx = 0; inputs->next = nullptr; @@ -418,7 +418,7 @@ void VideoDecoder::createFilterGraph( AVFilterInOut* outputsTmp = outputs.release(); AVFilterInOut* inputsTmp = inputs.release(); ffmpegStatus = avfilter_graph_parse_ptr( - filterState.filterGraph.get(), + filterGraphContext.filterGraph.get(), description.str().c_str(), &inputsTmp, &outputsTmp, @@ -431,7 +431,8 @@ void VideoDecoder::createFilterGraph( getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } - ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr); + ffmpegStatus = + avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr); if (ffmpegStatus < 0) { throw std::runtime_error( "Failed to configure filter graph: " + @@ -1027,7 +1028,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU( } else if ( streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - if (!streamInfo.filterState.filterGraph || + if (!streamInfo.filterGraphContext.filterGraph || streamInfo.prevFrameContext != frameContext) { createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth); streamInfo.prevFrameContext = frameContext; @@ -1585,16 +1586,17 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale( torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph( int streamIndex, const AVFrame* avFrame) { - FilterState& filterState = streamInfos_[streamIndex].filterState; + FilterGraphContext& filterGraphContext = + streamInfos_[streamIndex].filterGraphContext; int ffmpegStatus = - av_buffersrc_write_frame(filterState.sourceContext, avFrame); + av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame); if (ffmpegStatus < AVSUCCESS) { throw std::runtime_error("Failed to add frame to buffer source context"); } UniqueAVFrame filteredAVFrame(av_frame_alloc()); - ffmpegStatus = - av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get()); + ffmpegStatus = av_buffersink_get_frame( + filterGraphContext.sinkContext, filteredAVFrame.get()); TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24); auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get()); diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index e60c2ae84..851aff396 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -289,7 +289,7 @@ class VideoDecoder { int64_t nextPts = INT64_MAX; }; - struct FilterState { + struct FilterGraphContext { UniqueAVFilterGraph filterGraph; AVFilterContext* sourceContext = nullptr; AVFilterContext* sinkContext = nullptr; @@ -321,7 +321,7 @@ class VideoDecoder { VideoStreamOptions videoStreamOptions; // The filter state associated with this stream (for video streams). The // actual graph will be nullptr for inactive streams. 
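For context, createFilterGraph and convertAVFrameToTensorUsingFilterGraph above follow the standard FFmpeg buffersrc/buffersink pattern: push a decoded frame into the source filter, pull the converted frame out of the sink. The sketch below is generic, assumes an already-configured graph, and reduces error handling to a bool.

extern "C" {
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
#include <libavutil/frame.h>
}

bool filterOneFrame(
    AVFilterContext* sourceContext,
    AVFilterContext* sinkContext,
    const AVFrame* decodedFrame,
    AVFrame* filteredFrame) {
  // Push the decoded (typically YUV) frame into the graph...
  if (av_buffersrc_write_frame(sourceContext, decodedFrame) < 0) {
    return false;
  }
  // ...and pull the converted (here RGB24-constrained) frame out of the sink.
  return av_buffersink_get_frame(sinkContext, filteredFrame) >= 0;
}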
- FilterState filterState; + FilterGraphContext filterGraphContext; ColorConversionLibrary colorConversionLibrary = FILTERGRAPH; std::vector<FrameInfo> keyFrames; std::vector<FrameInfo> allFrames;
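Purely illustrative and not taken from this codebase: one way a pts-sorted frame index like the allFrames vector above, built from FrameInfo-style entries, can answer "which frame is displayed at this pts". The struct and function names below are hypothetical.

#include <algorithm>
#include <cstdint>
#include <vector>

struct FrameInfoSketch {
  int64_t pts = 0;
  int64_t nextPts = INT64_MAX;  // pts of the following frame, in time base
};

// Returns the index of the frame visible at targetPts, or -1 if none.
// Assumes allFrames is sorted by pts, as produced by a full file scan.
int64_t findFrameDisplayedAt(
    const std::vector<FrameInfoSketch>& allFrames,
    int64_t targetPts) {
  // First frame whose pts is strictly greater than the target...
  auto it = std::upper_bound(
      allFrames.begin(),
      allFrames.end(),
      targetPts,
      [](int64_t t, const FrameInfoSketch& f) { return t < f.pts; });
  if (it == allFrames.begin()) {
    return -1;  // target lies before the first frame
  }
  --it;  // ...so the previous frame is the last one starting at or before it.
  return (targetPts < it->nextPts) ? (it - allFrames.begin()) : -1;
}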