Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 17 additions & 12 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -873,16 +873,25 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
return output;
}

// Note [preAllocatedOutputTensor with swscale and filtergraph]:
// Callers may pass a pre-allocated tensor, where the output frame tensor will
// be stored. This parameter is honored in any case, but it only leads to a
// speed-up when swscale is used. With swscale, we can tell ffmpeg to place the
// decoded frame directly into `preAllocatedOutputTensor.data_ptr()`. We haven't
// yet found a way to do that with filtergraph.
// TODO: Figure out whether that's possible!
// Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
// `dimension_order` parameter. It's up to callers to re-shape it if needed.
void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
VideoDecoder::RawDecodedOutput& rawOutput,
DecodedOutput& output,
std::optional<torch::Tensor> preAllocatedOutputTensor) {
int streamIndex = rawOutput.streamIndex;
AVFrame* frame = rawOutput.frame.get();
auto& streamInfo = streams_[streamIndex];
torch::Tensor tensor;
if (output.streamType == AVMEDIA_TYPE_VIDEO) {
if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
torch::Tensor tensor;
int width = streamInfo.options.width.value_or(frame->width);
int height = streamInfo.options.height.value_or(frame->height);
if (preAllocatedOutputTensor.has_value()) {
Expand All @@ -908,7 +917,13 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
} else if (
streamInfo.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
tensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
if (preAllocatedOutputTensor.has_value()) {
preAllocatedOutputTensor.value().copy_(tensor);
output.frame = preAllocatedOutputTensor.value();
} else {
output.frame = tensor;
}
} else {
throw std::runtime_error(
"Invalid color conversion library: " +
Expand Down Expand Up @@ -1060,10 +1075,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
} else {
DecodedOutput singleOut = getFrameAtIndex(
streamIndex, indexInVideo, output.frames[indexInOutput]);
if (options.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
output.frames[indexInOutput] = singleOut.frame;
}
output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
}
Expand Down Expand Up @@ -1140,9 +1151,6 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(

for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
output.frames[f] = singleOut.frame;
}
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
Expand Down Expand Up @@ -1236,9 +1244,6 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
BatchDecodedOutput output(numFrames, options, streamMetadata);
for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
if (options.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
output.frames[f] = singleOut.frame;
}
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
Expand Down
36 changes: 36 additions & 0 deletions test/decoders/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,42 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
}
}

TEST_P(VideoDecoderTest, PreAllocatedTensorFilterGraph) {
  // Decode frame 0 with the filtergraph color-conversion path and verify that
  // the frame lands in the caller-provided pre-allocated HWC uint8 buffer
  // (i.e. output.frame aliases the pre-allocated tensor's storage).
  const std::string path = getResourcePath("nasa_13013.mp4");

  std::unique_ptr<VideoDecoder> decoder =
      VideoDecoderTest::createDecoderFromPath(path, GetParam());
  decoder->scanFileAndUpdateMetadataAndIndex();
  int videoStreamIndex =
      *decoder->getContainerMetadata().bestVideoStreamIndex;
  decoder->addVideoStreamDecoder(
      videoStreamIndex,
      VideoDecoder::VideoStreamDecoderOptions(
          "color_conversion_library=filtergraph"));

  // Height x Width x Channels, matching the test asset's frame size.
  auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8});
  auto result =
      decoder->getFrameAtIndex(videoStreamIndex, 0, preAllocatedOutputTensor);
  EXPECT_EQ(result.frame.data_ptr(), preAllocatedOutputTensor.data_ptr());
}

TEST_P(VideoDecoderTest, PreAllocatedTensorSwscale) {
  // Decode frame 0 with the swscale color-conversion path and verify that
  // the frame lands in the caller-provided pre-allocated HWC uint8 buffer
  // (i.e. output.frame aliases the pre-allocated tensor's storage).
  const std::string path = getResourcePath("nasa_13013.mp4");

  std::unique_ptr<VideoDecoder> decoder =
      VideoDecoderTest::createDecoderFromPath(path, GetParam());
  decoder->scanFileAndUpdateMetadataAndIndex();
  int videoStreamIndex =
      *decoder->getContainerMetadata().bestVideoStreamIndex;
  decoder->addVideoStreamDecoder(
      videoStreamIndex,
      VideoDecoder::VideoStreamDecoderOptions(
          "color_conversion_library=swscale"));

  // Height x Width x Channels, matching the test asset's frame size.
  auto preAllocatedOutputTensor = torch::empty({270, 480, 3}, {torch::kUInt8});
  auto result =
      decoder->getFrameAtIndex(videoStreamIndex, 0, preAllocatedOutputTensor);
  EXPECT_EQ(result.frame.data_ptr(), preAllocatedOutputTensor.data_ptr());
}

TEST_P(VideoDecoderTest, GetAudioMetadata) {
std::string path = getResourcePath("nasa_13013.mp4.audio.mp3");
std::unique_ptr<VideoDecoder> decoder =
Expand Down
Loading