diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 8dcb1bb46..bba8b4e4a 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -337,12 +337,12 @@ void VideoDecoder::createFilterGraph(
     StreamInfo& streamInfo,
     int expectedOutputHeight,
     int expectedOutputWidth) {
-  FilterState& filterState = streamInfo.filterState;
-  filterState.filterGraph.reset(avfilter_graph_alloc());
-  TORCH_CHECK(filterState.filterGraph.get() != nullptr);
+  FilterGraphContext& filterGraphContext = streamInfo.filterGraphContext;
+  filterGraphContext.filterGraph.reset(avfilter_graph_alloc());
+  TORCH_CHECK(filterGraphContext.filterGraph.get() != nullptr);

   if (streamInfo.videoStreamOptions.ffmpegThreadCount.has_value()) {
-    filterState.filterGraph->nb_threads =
+    filterGraphContext.filterGraph->nb_threads =
         streamInfo.videoStreamOptions.ffmpegThreadCount.value();
   }

@@ -360,12 +360,12 @@ void VideoDecoder::createFilterGraph(
              << codecContext->sample_aspect_ratio.den;

   int ffmpegStatus = avfilter_graph_create_filter(
-      &filterState.sourceContext,
+      &filterGraphContext.sourceContext,
       buffersrc,
       "in",
       filterArgs.str().c_str(),
       nullptr,
-      filterState.filterGraph.get());
+      filterGraphContext.filterGraph.get());
   if (ffmpegStatus < 0) {
     throw std::runtime_error(
         std::string("Failed to create filter graph: ") + filterArgs.str() +
@@ -373,12 +373,12 @@ void VideoDecoder::createFilterGraph(
   }

   ffmpegStatus = avfilter_graph_create_filter(
-      &filterState.sinkContext,
+      &filterGraphContext.sinkContext,
       buffersink,
       "out",
       nullptr,
       nullptr,
-      filterState.filterGraph.get());
+      filterGraphContext.filterGraph.get());
   if (ffmpegStatus < 0) {
     throw std::runtime_error(
         "Failed to create filter graph: " +
@@ -388,7 +388,7 @@ void VideoDecoder::createFilterGraph(
   enum AVPixelFormat pix_fmts[] = {AV_PIX_FMT_RGB24, AV_PIX_FMT_NONE};

   ffmpegStatus = av_opt_set_int_list(
-      filterState.sinkContext,
+      filterGraphContext.sinkContext,
       "pix_fmts",
       pix_fmts,
       AV_PIX_FMT_NONE,
@@ -403,11 +403,11 @@ void VideoDecoder::createFilterGraph(
   UniqueAVFilterInOut inputs(avfilter_inout_alloc());

   outputs->name = av_strdup("in");
-  outputs->filter_ctx = filterState.sourceContext;
+  outputs->filter_ctx = filterGraphContext.sourceContext;
   outputs->pad_idx = 0;
   outputs->next = nullptr;

   inputs->name = av_strdup("out");
-  inputs->filter_ctx = filterState.sinkContext;
+  inputs->filter_ctx = filterGraphContext.sinkContext;
   inputs->pad_idx = 0;
   inputs->next = nullptr;
@@ -418,7 +418,7 @@ void VideoDecoder::createFilterGraph(
   AVFilterInOut* outputsTmp = outputs.release();
   AVFilterInOut* inputsTmp = inputs.release();
   ffmpegStatus = avfilter_graph_parse_ptr(
-      filterState.filterGraph.get(),
+      filterGraphContext.filterGraph.get(),
       description.str().c_str(),
       &inputsTmp,
       &outputsTmp,
@@ -431,7 +431,8 @@ void VideoDecoder::createFilterGraph(
         getFFMPEGErrorStringFromErrorCode(ffmpegStatus));
   }

-  ffmpegStatus = avfilter_graph_config(filterState.filterGraph.get(), nullptr);
+  ffmpegStatus =
+      avfilter_graph_config(filterGraphContext.filterGraph.get(), nullptr);
   if (ffmpegStatus < 0) {
     throw std::runtime_error(
         "Failed to configure filter graph: " +
@@ -1057,7 +1058,7 @@ void VideoDecoder::convertAVFrameToFrameOutputOnCPU(
   } else if (
       streamInfo.colorConversionLibrary ==
       ColorConversionLibrary::FILTERGRAPH) {
-    if (!streamInfo.filterState.filterGraph ||
+    if (!streamInfo.filterGraphContext.filterGraph ||
        streamInfo.prevFrameContext != frameContext) {
      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
      streamInfo.prevFrameContext = frameContext;
@@ -1615,16 +1616,17 @@ int VideoDecoder::convertAVFrameToTensorUsingSwsScale(
 torch::Tensor VideoDecoder::convertAVFrameToTensorUsingFilterGraph(
     int streamIndex,
     const AVFrame* avFrame) {
-  FilterState& filterState = streamInfos_[streamIndex].filterState;
+  FilterGraphContext& filterGraphContext =
+      streamInfos_[streamIndex].filterGraphContext;

   int ffmpegStatus =
-      av_buffersrc_write_frame(filterState.sourceContext, avFrame);
+      av_buffersrc_write_frame(filterGraphContext.sourceContext, avFrame);
   if (ffmpegStatus < AVSUCCESS) {
     throw std::runtime_error("Failed to add frame to buffer source context");
   }
   UniqueAVFrame filteredAVFrame(av_frame_alloc());
-  ffmpegStatus =
-      av_buffersink_get_frame(filterState.sinkContext, filteredAVFrame.get());
+  ffmpegStatus = av_buffersink_get_frame(
+      filterGraphContext.sinkContext, filteredAVFrame.get());

   TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24);
   auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get());
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 5b406f85c..0c9f02a2e 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -293,7 +293,7 @@ class VideoDecoder {
     int64_t nextPts = INT64_MAX;
   };

-  struct FilterState {
+  struct FilterGraphContext {
     UniqueAVFilterGraph filterGraph;
     AVFilterContext* sourceContext = nullptr;
     AVFilterContext* sinkContext = nullptr;
@@ -325,7 +325,7 @@ class VideoDecoder {
     VideoStreamOptions videoStreamOptions;
     // The filter state associated with this stream (for video streams). The
     // actual graph will be nullptr for inactive streams.
-    FilterState filterState;
+    FilterGraphContext filterGraphContext;
     ColorConversionLibrary colorConversionLibrary = FILTERGRAPH;
     std::vector<FrameInfo> keyFrames;
     std::vector<FrameInfo> allFrames;
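
For reviewers less familiar with the FFmpeg filtergraph API, the renamed FilterGraphContext bundles a graph with its buffersrc/buffersink endpoints, which is exactly what createFilterGraph() builds and convertAVFrameToTensorUsingFilterGraph() drives. Below is a minimal self-contained sketch of that pattern, not torchcodec code: the helper name makeFilterGraph, the yuv420p input format, and the 1/25 time base are illustrative assumptions, and it links source to sink directly instead of parsing a scale/format description with avfilter_graph_parse_ptr().

// Minimal sketch (assumed names, not torchcodec code): build a trivial
// buffer -> buffersink graph mirroring the ownership split in
// FilterGraphContext. The graph owns both endpoint contexts; only the
// graph itself needs explicit freeing.
extern "C" {
#include <libavfilter/avfilter.h>
#include <libavfilter/buffersink.h>
#include <libavfilter/buffersrc.h>
}
#include <memory>
#include <stdexcept>
#include <string>

struct FilterGraphContext {
  std::unique_ptr<AVFilterGraph, void (*)(AVFilterGraph*)> filterGraph{
      nullptr, [](AVFilterGraph* g) { avfilter_graph_free(&g); }};
  AVFilterContext* sourceContext = nullptr; // owned by filterGraph
  AVFilterContext* sinkContext = nullptr;   // owned by filterGraph
};

// Hypothetical helper; width/height stand in for the expected output dims.
FilterGraphContext makeFilterGraph(int width, int height) {
  FilterGraphContext ctx;
  ctx.filterGraph.reset(avfilter_graph_alloc());
  if (!ctx.filterGraph) {
    throw std::runtime_error("avfilter_graph_alloc failed");
  }

  // The buffer source must be told the geometry and format of incoming
  // frames. Assumed values: yuv420p input, 1/25 time base, square pixels.
  std::string srcArgs = "video_size=" + std::to_string(width) + "x" +
      std::to_string(height) +
      ":pix_fmt=yuv420p:time_base=1/25:pixel_aspect=1/1";
  int status = avfilter_graph_create_filter(
      &ctx.sourceContext,
      avfilter_get_by_name("buffer"),
      "in",
      srcArgs.c_str(),
      nullptr,
      ctx.filterGraph.get());
  if (status < 0) {
    throw std::runtime_error("failed to create buffer source");
  }

  status = avfilter_graph_create_filter(
      &ctx.sinkContext,
      avfilter_get_by_name("buffersink"),
      "out",
      nullptr,
      nullptr,
      ctx.filterGraph.get());
  if (status < 0) {
    throw std::runtime_error("failed to create buffer sink");
  }

  // Pass-through wiring, then validate the whole graph.
  if (avfilter_link(ctx.sourceContext, 0, ctx.sinkContext, 0) < 0 ||
      avfilter_graph_config(ctx.filterGraph.get(), nullptr) < 0) {
    throw std::runtime_error("failed to wire/configure filter graph");
  }
  return ctx;
}

Frames then flow the same way convertAVFrameToTensorUsingFilterGraph() pushes and pulls them: av_buffersrc_write_frame(ctx.sourceContext, frame) feeds the graph, and av_buffersink_get_frame(ctx.sinkContext, out) retrieves the filtered result.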