Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 116 additions & 71 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,9 @@ class VideoDecoder {
void resetDecodeStats();

private:
torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
// --------------------------------------------------------------------------
// STREAMINFO AND ASSOCIATED STRUCTS
// --------------------------------------------------------------------------

struct FrameInfo {
int64_t pts = 0;
Expand Down Expand Up @@ -309,73 +311,114 @@ class VideoDecoder {
bool operator!=(const DecodedFrameContext&);
};

// Stores information for each stream.
struct StreamInfo {
int streamIndex = -1;
AVStream* stream = nullptr;
AVRational timeBase = {};
UniqueAVCodecContext codecContext;
// The current position of the cursor in the stream.

// The FrameInfo indices we built when scanFileAndUpdateMetadataAndIndex was
// called.
std::vector<FrameInfo> keyFrames;
std::vector<FrameInfo> allFrames;

// The current position of the cursor in the stream, and associated frame
// duration.
int64_t currentPts = 0;
int64_t currentDuration = 0;
// The desired position of the cursor in the stream. We send frames >=
// this pts to the user when they request a frame.
// We update this field if the user requested a seek.
// We update this field if the user requested a seek. This typically
// corresponds to the decoder's desiredPts_ attribute.
int64_t discardFramesBeforePts = INT64_MIN;
VideoStreamOptions videoStreamOptions;
// The filter state associated with this stream (for video streams). The
// actual graph will be nullptr for inactive streams.

// color-conversion fields. Only one of FilterGraphContext and
// UniqueSwsContext should be non-null.
FilterGraphContext filterGraphContext;
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
std::vector<FrameInfo> keyFrames;
std::vector<FrameInfo> allFrames;
DecodedFrameContext prevFrameContext;
UniqueSwsContext swsContext;

// Used to know whether a new FilterGraphContext or UniqueSwsContext should
// be created before decoding a new frame.
DecodedFrameContext prevFrameContext;
};

// Returns the key frame index of the presentation timestamp using FFMPEG's
// index. Note that this index may be truncated for some files.
int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts)
const;
// Returns the key frame index of the presentation timestamp using our index.
// We build this index by scanning the file in buildKeyFrameIndex().
int getKeyFrameIndexForPtsUsingScannedIndex(
const std::vector<VideoDecoder::FrameInfo>& keyFrames,
int64_t pts) const;
int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const;
// --------------------------------------------------------------------------
// INITIALIZERS
// --------------------------------------------------------------------------

void initializeDecoder();
void updateMetadataWithCodecContext(
int streamIndex,
AVCodecContext* codecContext);

// --------------------------------------------------------------------------
// DECODING APIS AND RELATED UTILS
// --------------------------------------------------------------------------

bool canWeAvoidSeekingForStream(
const StreamInfo& stream,
int64_t currentPts,
int64_t targetPts) const;
// Returns the "best" stream index for a given media type. The "best" is
// determined by various heuristics in FFMPEG.
// See
// https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
// for more details about the heuristics.
int getBestStreamIndex(AVMediaType mediaType);
void initializeDecoder();
void validateUserProvidedStreamIndex(int streamIndex);
void validateScannedAllStreams(const std::string& msg);
void validateFrameIndex(
const StreamMetadata& streamMetadata,
int64_t frameIndex);

// Creates and initializes a filter graph for a stream. The filter graph can
// do rescaling and color conversion.
void maybeSeekToBeforeDesiredPts();

AVFrameStream decodeAVFrame(
std::function<bool(int, AVFrame*)> filterFunction);

FrameOutput getNextFrameNoDemuxInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);

FrameOutput convertAVFrameToFrameOutput(
AVFrameStream& avFrameStream,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

void convertAVFrameToFrameOutputOnCPU(
AVFrameStream& avFrameStream,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

torch::Tensor convertAVFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* avFrame);

int convertAVFrameToTensorUsingSwsScale(
int streamIndex,
const AVFrame* avFrame,
torch::Tensor& outputTensor);

// --------------------------------------------------------------------------
// COLOR CONVERSION LIBRARIES HANDLERS CREATION
// --------------------------------------------------------------------------

void createFilterGraph(
StreamInfo& streamInfo,
int expectedOutputHeight,
int expectedOutputWidth);

int64_t getNumFrames(const StreamMetadata& streamMetadata);
void createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace);

int64_t getPts(
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata,
int64_t frameIndex);
// --------------------------------------------------------------------------
// PTS <-> INDEX CONVERSIONS
// --------------------------------------------------------------------------

double getMinSeconds(const StreamMetadata& streamMetadata);
double getMaxSeconds(const StreamMetadata& streamMetadata);
int getKeyFrameIndexForPts(const StreamInfo& stream, int64_t pts) const;

// Returns the key frame index of the presentation timestamp using our index.
// We build this index by scanning the file in
// scanFileAndUpdateMetadataAndIndex
int getKeyFrameIndexForPtsUsingScannedIndex(
const std::vector<VideoDecoder::FrameInfo>& keyFrames,
int64_t pts) const;
// Return key frame index, from FFmpeg. Potentially less accurate
int getKeyFrameIndexForPtsUsingEncoderIndex(AVStream* stream, int64_t pts)
const;

int64_t secondsToIndexLowerBound(
double seconds,
Expand All @@ -387,40 +430,43 @@ class VideoDecoder {
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata);

void createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace);

void maybeSeekToBeforeDesiredPts();
int64_t getPts(
const StreamInfo& streamInfo,
const StreamMetadata& streamMetadata,
int64_t frameIndex);

AVFrameStream decodeAVFrame(
std::function<bool(int, AVFrame*)> filterFunction);
// --------------------------------------------------------------------------
// STREAM AND METADATA APIS
// --------------------------------------------------------------------------

// Once we create a decoder can update the metadata with the codec context.
// For example, for video streams, we can add the height and width of the
// decoded stream.
void updateMetadataWithCodecContext(
int streamIndex,
AVCodecContext* codecContext);
void populateVideoMetadataFromStreamIndex(int streamIndex);
torch::Tensor convertAVFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* avFrame);
int convertAVFrameToTensorUsingSwsScale(
int streamIndex,
const AVFrame* avFrame,
torch::Tensor& outputTensor);
FrameOutput convertAVFrameToFrameOutput(
AVFrameStream& avFrameStream,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
void convertAVFrameToFrameOutputOnCPU(
AVFrameStream& avFrameStream,
FrameOutput& frameOutput,
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);

FrameOutput getNextFrameNoDemuxInternal(
std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
// Returns the "best" stream index for a given media type. The "best" is
// determined by various heuristics in FFMPEG.
// See
// https://ffmpeg.org/doxygen/trunk/group__lavf__decoding.html#ga757780d38f482deb4d809c6c521fbcc2
// for more details about the heuristics.
// Returns the key frame index of the presentation timestamp using FFMPEG's
// index. Note that this index may be truncated for some files.
int getBestStreamIndex(AVMediaType mediaType);

int64_t getNumFrames(const StreamMetadata& streamMetadata);
double getMinSeconds(const StreamMetadata& streamMetadata);
double getMaxSeconds(const StreamMetadata& streamMetadata);

// --------------------------------------------------------------------------
// VALIDATION UTILS
// --------------------------------------------------------------------------

void validateUserProvidedStreamIndex(int streamIndex);
void validateScannedAllStreams(const std::string& msg);
void validateFrameIndex(
const StreamMetadata& streamMetadata,
int64_t frameIndex);

// --------------------------------------------------------------------------
// ATTRIBUTES
// --------------------------------------------------------------------------

SeekMode seekMode_;
ContainerMetadata containerMetadata_;
Expand All @@ -432,7 +478,6 @@ class VideoDecoder {
// Set when the user wants to seek and stores the desired pts that the user
// wants to seek to.
std::optional<double> desiredPtsSeconds_;

// Stores various internal decoding stats.
DecodeStats decodeStats_;
// Stores the AVIOContext for the input buffer.
Expand Down
Loading