Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 20 additions & 18 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,12 @@ void VideoDecoder::createFilterGraph(
StreamInfo& streamInfo,
int expectedOutputHeight,
int expectedOutputWidth) {
FilterState& filterState = streamInfo.filterState;
filterState.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterState.filterGraph.get() != nullptr);
FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext;
filterGraphContext.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr);

if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) {
filterState.filterGraph->nb_threads =
filterGraphContext.filterGraph->nb_threads =
streamInfo.videoStreamOptions.ffmpegThreadCount.value();
}

Expand All @@ -360,25 +360,25 @@ void VideoDecoder::createFilterGraph(
<< codecContext->sample_aspect_ratio.den;

int ffmpegStatus = avfilter_graph_create_filter(
&filterState.sourceContext,
&filterGraphContext.sourceContext,
buffersrc,
"in",
filterArgs.str().c_str(),
nullptr,
filterState.filterGraph.get());
filterGraphContext.filterGraph.get());
if (ffmpegStatus < 0) {
throw std::runtime_error(
std::string("Failed to create filter graph: ") + filterArgs.str() +
": " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}

ffmpegStatus = avfilter_graph_create_filter(
&filterState.sinkContext,
&filterGraphContext.sinkContext,
buffersink,
"out",
nullptr,
nullptr,
filterState.filterGraph.get());
filterGraphContext.filterGraph.get());
if (ffmpegStatus < 0) {
throw std::runtime_error(
"Failed to create filter graph: " +
Expand All @@ -388,7 +388,7 @@ void VideoDecoder::createFilterGraph(
enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

ffmpegStatus = av_opt_set_int_list(
filterState.sinkContext,
filterGraphContext.sinkContext,
"pix_fmts",
pix_fmts,
AV_PIX_FMT_NONE,
Expand All @@ -403,11 +403,11 @@ void VideoDecoder::createFilterGraph(
UniqueAVFilterInOut inputs(avfilter_inout_alloc());

outputs->name = av_strdup("in");
outputs->filter_ctx = filterState.sourceContext;
outputs->filter_ctx = filterGraphContext.sourceContext;
outputs->pad_idx = 0;
outputs->next = nullptr;
inputs->name = av_strdup("out");
inputs->filter_ctx = filterState.sinkContext;
inputs->filter_ctx = filterGraphContext.sinkContext;
inputs->pad_idx = 0;
inputs->next = nullptr;

Expand All @@ -418,7 +418,7 @@ void VideoDecoder::createFilterGraph(
AVFilterInOut* outputsTmp = outputs.release();
AVFilterInOut* inputsTmp = inputs.release();
ffmpegStatus = avfilter_graph_parse_ptr(
filterState.filterGraph.get(),
filterGraphContext.filterGraph.get(),
description.str().c_str(),
&inputsTmp,
&outputsTmp,
Expand All @@ -431,7 +431,8 @@ void VideoDecoder::createFilterGraph(
getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
}

ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr);
ffmpegStatus =
avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr);
if (ffmpegStatus < 0) {
throw std::runtime_error(
"Failed to configure filter graph: " +
Expand Down Expand Up @@ -1057,7 +1058,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
} else if (
streamInfo.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
if (!streamInfo.filterState.filterGraph ||
if (!streamInfo.filterGraphContext.filterGraph ||
streamInfo.prevFrameContext != frameContext) {
createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
streamInfo.prevFrameContext = frameContext;
Expand Down Expand Up @@ -1615,16 +1616,17 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
int streamIndex,
const AVFrame* avFrame) {
FilterState& filterState = streamInfos_[streamIndex].filterState;
FilterGraphContext& filterGraphContext =
streamInfos_[streamIndex].filterGraphContext;
int ffmpegStatus =
av_buffersrc_write_frame(filterState.sourceContext, avFrame);
av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
if (ffmpegStatus < AVSUCCESS) {
throw std::runtime_error("Failed to add frame to buffer source context");
}

UniqueAVFrame filteredAVFrame(av_frame_alloc());
ffmpegStatus =
av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get());
ffmpegStatus = av_buffersink_get_frame(
filterGraphContext.sinkContext, filteredAVFrame.get());
TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);

auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
Expand Down
4 changes: 2 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ class VideoDecoder {
int64_t nextPts = INT64_MAX;
};

struct FilterState {
struct FilterGraphContext {
UniqueAVFilterGraph filterGraph;
AVFilterContext* sourceContext = nullptr;
AVFilterContext* sinkContext = nullptr;
Expand Down Expand Up @@ -325,7 +325,7 @@ class VideoDecoder {
VideoStreamOptions videoStreamOptions;
// The filter graph context associated with this stream (for video streams).
// The actual graph will be nullptr for inactive streams.
FilterState filterState;
FilterGraphContext filterGraphContext;
ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
std::vector<FrameInfo> keyFrames;
std::vector<FrameInfo> allFrames;
Expand Down
Loading