diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 0c9f02a2e..760b534df 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -281,7 +281,9 @@ class VideoDecoder { void resetDecodeStats(); private: - torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor); + // -------------------------------------------------------------------------- + // STREAMINFO AND ASSOCIATED STRUCTS + // -------------------------------------------------------------------------- struct FrameInfo { int64_t pts = 0; @@ -309,73 +311,114 @@ class VideoDecoder { bool operator!=(const DecodedFrameContext&); }; - // Stores information for each stream. struct StreamInfo { int streamIndex = -1; AVStream* stream = nullptr; AVRational timeBase = {}; UniqueAVCodecContext codecContext; - // The current position of the cursor in the stream. + + // The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was + // called. + std::vector<FrameInfo> keyFrames; + std::vector<FrameInfo> allFrames; + + // The current position of the cursor in the stream, and associated frame + // duration. int64_t currentPts = 0; int64_t currentDuration = 0; // The desired position of the cursor in the stream. We send frames >= // this pts to the user when they request a frame. - // We update this field if the user requested a seek. + // We update this field if the user requested a seek. This typically + // corresponds to the decoder's desiredPts_ attribute. int64_t discardFramesBeforePts = INT64_MIN; VideoStreamOptions videoStreamOptions; - // The filter state associated with this stream (for video streams). The - // actual graph will be nullptr for inactive streams. + + // color-conversion fields. Only one of FilterGraphContext and + // UniqueSwsContext should be non-null. 
FilterGraphContext filterGraphContext; ColorConversionLibrary colorConversionLibrary = FILTERGRAPH; - std::vector<FrameInfo> keyFrames; - std::vector<FrameInfo> allFrames; - DecodedFrameContext prevFrameContext; UniqueSwsContext swsContext; + + // Used to know whether a new FilterGraphContext or UniqueSwsContext should + // be created before decoding a new frame. + DecodedFrameContext prevFrameContext; }; - // Returns the key frame index of the presentation timestamp using FFMPEG's - // index. Note that this index may be truncated for some files. - int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts) - const; - // Returns the key frame index of the presentation timestamp using our index. - // We build this index by scanning the file in buildKeyFrameIndex(). - int getKeyFrameIndexForPtsUsingScannedIndex( - const std::vector<FrameInfo>& keyFrames, - int64_t pts) const; - int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const; + // -------------------------------------------------------------------------- + // INITIALIZERS + // -------------------------------------------------------------------------- + + void initializeDecoder(); + void updateMetadataWithCodecContext( + int streamIndex, + AVCodecContext* codecContext); + + // -------------------------------------------------------------------------- + // DECODING APIS AND RELATED UTILS + // -------------------------------------------------------------------------- + bool canWeAvoidSeekingForStream( const StreamInfo& stream, int64_t currentPts, int64_t targetPts) const; - // Returns the "best" stream index for a given media type. The "best" is - // determined by various heuristics in FFMPEG. - // See - // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2 - // for more details about the heuristics. 
- int getBestStreamIndex(AVMediaType mediaType); - void initializeDecoder(); - void validateUserProvidedStreamIndex(int streamIndex); - void validateScannedAllStreams(const std::string& msg); - void validateFrameIndex( - const StreamMetadata& streamMetadata, - int64_t frameIndex); - // Creates and initializes a filter graph for a stream. The filter graph can - // do rescaling and color conversion. + void maybeSeekToBeforeDesiredPts(); + + AVFrameStream decodeAVFrame( + std::function<bool(AVFrame*)> filterFunction); + + FrameOutput getNextFrameNoDemuxInternal( + std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + + torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor); + + FrameOutput convertAVFrameToFrameOutput( + AVFrameStream& avFrameStream, + std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + + void convertAVFrameToFrameOutputOnCPU( + AVFrameStream& avFrameStream, + FrameOutput& frameOutput, + std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + + torch::Tensor convertAVFrameToTensorUsingFilterGraph( + int streamIndex, + const AVFrame* avFrame); + + int convertAVFrameToTensorUsingSwsScale( + int streamIndex, + const AVFrame* avFrame, + torch::Tensor& outputTensor); + + // -------------------------------------------------------------------------- + // COLOR CONVERSION LIBRARIES HANDLERS CREATION + // -------------------------------------------------------------------------- + void createFilterGraph( StreamInfo& streamInfo, int expectedOutputHeight, int expectedOutputWidth); - int64_t getNumFrames(const StreamMetadata& streamMetadata); + void createSwsContext( + StreamInfo& streamInfo, + const DecodedFrameContext& frameContext, + const enum AVColorSpace colorspace); - int64_t getPts( - const StreamInfo& streamInfo, - const StreamMetadata& streamMetadata, - int64_t frameIndex); + // -------------------------------------------------------------------------- + // PTS <-> INDEX CONVERSIONS + // 
-------------------------------------------------------------------------- - double getMinSeconds(const StreamMetadata& streamMetadata); - double getMaxSeconds(const StreamMetadata& streamMetadata); + int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const; + + // Returns the key frame index of the presentation timestamp using our index. + // We build this index by scanning the file in + // scanFileAndUpdateMetadataAndIndex + int getKeyFrameIndexForPtsUsingScannedIndex( + const std::vector<FrameInfo>& keyFrames, + int64_t pts) const; + // Return key frame index, from FFmpeg. Potentially less accurate + int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts) + const; int64_t secondsToIndexLowerBound( double seconds, @@ -387,40 +430,43 @@ class VideoDecoder { const StreamInfo& streamInfo, const StreamMetadata& streamMetadata); - void createSwsContext( - StreamInfo& streamInfo, - const DecodedFrameContext& frameContext, - const enum AVColorSpace colorspace); - - void maybeSeekToBeforeDesiredPts(); + int64_t getPts( + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata, + int64_t frameIndex); - AVFrameStream decodeAVFrame( - std::function<bool(AVFrame*)> filterFunction); + // -------------------------------------------------------------------------- + // STREAM AND METADATA APIS + // -------------------------------------------------------------------------- - // Once we create a decoder can update the metadata with the codec context. - // For example, for video streams, we can add the height and width of the - // decoded stream. 
- void updateMetadataWithCodecContext( - int streamIndex, - AVCodecContext* codecContext); void populateVideoMetadataFromStreamIndex(int streamIndex); - torch::Tensor convertAVFrameToTensorUsingFilterGraph( - int streamIndex, - const AVFrame* avFrame); - int convertAVFrameToTensorUsingSwsScale( - int streamIndex, - const AVFrame* avFrame, - torch::Tensor& outputTensor); - FrameOutput convertAVFrameToFrameOutput( - AVFrameStream& avFrameStream, - std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); - void convertAVFrameToFrameOutputOnCPU( - AVFrameStream& avFrameStream, - FrameOutput& frameOutput, - std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); - FrameOutput getNextFrameNoDemuxInternal( - std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + // Returns the "best" stream index for a given media type. The "best" is + // determined by various heuristics in FFMPEG. + // See + // https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2 + // for more details about the heuristics. + // Returns the key frame index of the presentation timestamp using FFMPEG's + // index. Note that this index may be truncated for some files. 
+ int getBestStreamIndex(AVMediaType mediaType); + + int64_t getNumFrames(const StreamMetadata& streamMetadata); + double getMinSeconds(const StreamMetadata& streamMetadata); + double getMaxSeconds(const StreamMetadata& streamMetadata); + + // -------------------------------------------------------------------------- + // VALIDATION UTILS + // -------------------------------------------------------------------------- + + void validateUserProvidedStreamIndex(int streamIndex); + void validateScannedAllStreams(const std::string& msg); + void validateFrameIndex( + const StreamMetadata& streamMetadata, + int64_t frameIndex); + + // -------------------------------------------------------------------------- + // ATTRIBUTES + // -------------------------------------------------------------------------- SeekMode seekMode_; ContainerMetadata containerMetadata_; @@ -432,7 +478,6 @@ class VideoDecoder { // Set when the user wants to seek and stores the desired pts that the user // wants to seek to. std::optional<double> desiredPtsSeconds_; - // Stores various internal decoding stats. DecodeStats decodeStats_; // Stores the AVIOContext for the input buffer.