diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 3f9c527b5..76c744936 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -873,6 +873,15 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   return output;
 }
 
+// Note [preAllocatedOutputTensor with swscale and filtergraph]:
+// Callers may pass a pre-allocated tensor, in which the output frame will be
+// stored. This parameter is always honored, but it only leads to a speed-up
+// when swscale is used. With swscale, we can tell ffmpeg to place the decoded
+// frame directly into `preAllocatedOutputTensor.data_ptr()`. We haven't yet
+// found a way to do that with filtergraph.
+// TODO: Figure out whether that's possible!
+// The dimension order of preAllocatedOutputTensor must be HWC, regardless of
+// the `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     VideoDecoder::RawDecodedOutput& rawOutput,
     DecodedOutput& output,
@@ -880,9 +889,9 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   int streamIndex = rawOutput.streamIndex;
   AVFrame* frame = rawOutput.frame.get();
   auto& streamInfo = streams_[streamIndex];
+  torch::Tensor tensor;
   if (output.streamType == AVMEDIA_TYPE_VIDEO) {
     if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      torch::Tensor tensor;
       int width = streamInfo.options.width.value_or(frame->width);
       int height = streamInfo.options.height.value_or(frame->height);
       if (preAllocatedOutputTensor.has_value()) {
@@ -908,7 +917,13 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
     } else if (
         streamInfo.colorConversionLibrary ==
         ColorConversionLibrary::FILTERGRAPH) {
-      output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+      if (preAllocatedOutputTensor.has_value()) {
+        preAllocatedOutputTensor.value().copy_(tensor);
+        output.frame = preAllocatedOutputTensor.value();
+      } else {
+        output.frame = tensor;
+      }
     } else {
       throw std::runtime_error(
           "Invalid color conversion library: " +
@@ -1060,10 +1075,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     } else {
       DecodedOutput singleOut = getFrameAtIndex(
           streamIndex, indexInVideo, output.frames[indexInOutput]);
-      if (options.colorConversionLibrary ==
-          ColorConversionLibrary::FILTERGRAPH) {
-        output.frames[indexInOutput] = singleOut.frame;
-      }
       output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
       output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
@@ -1140,9 +1151,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
     DecodedOutput singleOut =
         getFrameAtIndex(streamIndex, i, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
-    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
@@ -1236,9 +1244,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   BatchDecodedOutput output(numFrames, options, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
     DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
-    if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-      output.frames[f] = singleOut.frame;
-    }
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp
index 1f7876bf5..389d47f4f 100644
--- a/test/decoders/VideoDecoderTest.cpp
+++ b/test/decoders/VideoDecoderTest.cpp
@@ -387,6 +387,42 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   }
 }
 
+TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+  auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8});
+
+  std::unique_ptr<VideoDecoder> ourDecoder =
+      VideoDecoderTest::createDecoderFromPath(path, GetParam());
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  int bestVideoStreamIndex =
+      *ourDecoder->getContainerMetadata().bestVideoStreamIndex;
+  ourDecoder->addVideoStreamDecoder(
+      bestVideoStreamIndex,
+      VideoDecoder::VideoStreamDecoderOptions(
+          "color_conversion_library=filtergraph"));
+  auto output = ourDecoder->getFrameAtIndex(
+      bestVideoStreamIndex, 0, preAllocatedOutputTensor);
+  EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr());
+}
+
+TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) {
+  std::string path = getResourcePath("nasa_13013.mp4");
+  auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8});
+
+  std::unique_ptr<VideoDecoder> ourDecoder =
+      VideoDecoderTest::createDecoderFromPath(path, GetParam());
+  ourDecoder->scanFileAndUpdateMetadataAndIndex();
+  int bestVideoStreamIndex =
+      *ourDecoder->getContainerMetadata().bestVideoStreamIndex;
+  ourDecoder->addVideoStreamDecoder(
+      bestVideoStreamIndex,
+      VideoDecoder::VideoStreamDecoderOptions(
+          "color_conversion_library=swscale"));
+  auto output = ourDecoder->getFrameAtIndex(
+      bestVideoStreamIndex, 0, preAllocatedOutputTensor);
+  EXPECT_EQ(output.frame.data_ptr(), preAllocatedOutputTensor.data_ptr());
+}
+
 TEST_P(VideoDecoderTest, GetAudioMetadata) {
   std::string path = getResourcePath("nasa_13013.mp4.audio.mp3");
   std::unique_ptr<VideoDecoder> decoder =
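
For context on the Note above: the reason the swscale path can honor the pre-allocated tensor for free is that `sws_scale()` writes into caller-provided destination planes, so we can point it at the tensor's own memory. Below is a minimal sketch of that idea, not this repo's actual code: `scaleIntoTensor` and its parameters are hypothetical names, error handling is omitted, and an HWC uint8 RGB24 output is assumed.

```cpp
#include <optional>
#include <torch/torch.h>
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

// Hypothetical helper: convert `frame` to RGB24, writing directly into the
// caller's tensor when one is provided.
torch::Tensor scaleIntoTensor(
    const AVFrame* frame,
    int height,
    int width,
    std::optional<torch::Tensor> preAllocatedOutputTensor) {
  // Reuse the caller's buffer when given one, else allocate a fresh one.
  torch::Tensor out = preAllocatedOutputTensor.value_or(
      torch::empty({height, width, 3}, torch::kUInt8));
  SwsContext* ctx = sws_getContext(
      frame->width,
      frame->height,
      static_cast<AVPixelFormat>(frame->format),
      width,
      height,
      AV_PIX_FMT_RGB24,
      SWS_BILINEAR,
      nullptr,
      nullptr,
      nullptr);
  // Point the destination plane at the tensor's memory: the converted frame
  // lands directly in `out`, with no intermediate buffer or copy.
  uint8_t* dst[4] = {out.data_ptr<uint8_t>(), nullptr, nullptr, nullptr};
  int dstLinesize[4] = {width * 3, 0, 0, 0};
  sws_scale(
      ctx, frame->data, frame->linesize, 0, frame->height, dst, dstLinesize);
  sws_freeContext(ctx);
  return out;
}
```

The filtergraph path has no equivalent hook as far as we know: `av_buffersink_get_frame()` hands back a frame backed by its own buffers, which is why this PR falls back to `copy_()` into the pre-allocated tensor there.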