diff --git a/benchmarks/decoders/BenchmarkDecodersMain.cpp b/benchmarks/decoders/BenchmarkDecodersMain.cpp
index 25d518de4..fd6b137dc 100644
--- a/benchmarks/decoders/BenchmarkDecodersMain.cpp
+++ b/benchmarks/decoders/BenchmarkDecodersMain.cpp
@@ -63,7 +63,7 @@ void runNDecodeIterations(
     decoder->addVideoStreamDecoder(-1);
     for (double pts : ptsList) {
       decoder->setCursorPtsInSeconds(pts);
-      torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+      torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
     }
     if (i + 1 == warmupIterations) {
       start = std::chrono::high_resolution_clock::now();
@@ -95,7 +95,7 @@ void runNdecodeIterationsGrabbingConsecutiveFrames(
         VideoDecoder::createFromFilePath(videoPath);
     decoder->addVideoStreamDecoder(-1);
     for (int j = 0; j < consecutiveFrameCount; ++j) {
-      torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+      torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
    }
     if (i + 1 == warmupIterations) {
       start = std::chrono::high_resolution_clock::now();
diff --git a/packaging/check_glibcxx.py b/packaging/check_glibcxx.py
index 37ff654c7..b7efd9813 100644
--- a/packaging/check_glibcxx.py
+++ b/packaging/check_glibcxx.py
@@ -46,7 +46,9 @@
     all_symbols.add(match.group(0))
 
 if not all_symbols:
-    raise ValueError(f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong.")
+    raise ValueError(
+        f"No GLIBCXX symbols found in {symbol_matches}. Something is wrong."
+    )
 
 all_versions = (symbol.split("_")[1].split(".") for symbol in all_symbols)
 all_versions = (tuple(int(v) for v in version) for version in all_versions)
diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp
index dea0e7293..5da3f4928 100644
--- a/src/torchcodec/decoders/_core/CudaDevice.cpp
+++ b/src/torchcodec/decoders/_core/CudaDevice.cpp
@@ -240,11 +240,6 @@ void convertAVFrameToDecodedOutputOnCuda(
   std::chrono::duration<double, std::micro> duration = end - start;
   VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width
           << " took: " << duration.count() << "us" << std::endl;
-  if (options.dimensionOrder == "NCHW") {
-    // The docs guaranty this to return a view:
-    // https://pytorch.org/docs/stable/generated/torch.permute.html
-    dst = dst.permute({2, 0, 1});
-  }
 }
 
 } // namespace facebook::torchcodec
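The permute removed from CudaDevice.cpp above is deferred to the VideoDecoder
entry points below. The whole refactor leans on permute() returning a view;
a minimal standalone libtorch sketch of that guarantee (not part of the
patch, shapes taken from the test video):

    #include <torch/torch.h>

    int main() {
      torch::Tensor hwc = torch::rand({270, 480, 3}); // HWC frame
      torch::Tensor chw = hwc.permute({2, 0, 1});     // CHW view, no copy
      TORCH_CHECK(chw.data_ptr() == hwc.data_ptr());  // same storage
      TORCH_CHECK(!chw.is_contiguous());              // a view, not a copy
      return 0;
    }
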
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 76c744936..94d2b9538 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -34,31 +34,6 @@ double ptsToSeconds(int64_t pts, const AVRational& timeBase) {
   return ptsToSeconds(pts, timeBase.den);
 }
 
-// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
-// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
-// or 4D.
-// Calling permute() is guaranteed to return a view as per the docs:
-// https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor MaybePermuteHWC2CHW(
-    const VideoDecoder::VideoStreamDecoderOptions& options,
-    torch::Tensor& hwcTensor) {
-  if (options.dimensionOrder == "NHWC") {
-    return hwcTensor;
-  }
-  auto numDimensions = hwcTensor.dim();
-  auto shape = hwcTensor.sizes();
-  if (numDimensions == 3) {
-    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
-    return hwcTensor.permute({2, 0, 1});
-  } else if (numDimensions == 4) {
-    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
-    return hwcTensor.permute({0, 3, 1, 2});
-  } else {
-    TORCH_CHECK(
-        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
-  }
-}
-
 struct AVInput {
   UniqueAVFormatContext formatContext;
   std::unique_ptr<AVIOBytesContext> ioBytesContext;
@@ -136,6 +111,31 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
 
 } // namespace
 
+// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so.
+// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D
+// or 4D.
+// Calling permute() is guaranteed to return a view as per the docs:
+// https://pytorch.org/docs/stable/generated/torch.permute.html
+torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+    int streamIndex,
+    torch::Tensor& hwcTensor) {
+  if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
+    return hwcTensor;
+  }
+  auto numDimensions = hwcTensor.dim();
+  auto shape = hwcTensor.sizes();
+  if (numDimensions == 3) {
+    TORCH_CHECK(shape[2] == 3, "Not a HWC tensor: ", shape);
+    return hwcTensor.permute({2, 0, 1});
+  } else if (numDimensions == 4) {
+    TORCH_CHECK(shape[3] == 3, "Not a NHWC tensor: ", shape);
+    return hwcTensor.permute({0, 3, 1, 2});
+  } else {
+    TORCH_CHECK(
+        false, "Expected tensor with 3 or 4 dimensions, got ", numDimensions);
+  }
+}
+
 VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions(
     const std::string& optionsString) {
   std::vector<std::string> tokens =
@@ -929,14 +929,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
           "Invalid color conversion library: " +
           std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
     }
-    if (!preAllocatedOutputTensor.has_value()) {
-      // We only convert to CHW if a pre-allocated tensor wasn't passed. When a
-      // pre-allocated tensor is passed, it's up to the caller (typically a
-      // batch API) to do the conversion. This is more efficient as it allows
-      // batch NHWC tensors to be permuted only once, instead of permuting HWC
-      // tensors N times.
-      output.frame = MaybePermuteHWC2CHW(streamInfo.options, output.frame);
-    }
   } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
     // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
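Aside on the deletion above: the per-frame CHW conversion moves out of
convertAVFrameToDecodedOutputOnCPU so that batch callers can permute once per
batch instead of once per frame. A sketch of why the two are equivalent
(batch size is hypothetical):

    // Permuting the NHWC batch once yields the same frames as permuting
    // each HWC frame individually, but creates only a single view.
    torch::Tensor batch = torch::rand({16, 270, 480, 3}); // NHWC
    torch::Tensor nchw = batch.permute({0, 3, 1, 2});     // one NCHW view
    for (int64_t i = 0; i < batch.size(0); ++i) {
      TORCH_CHECK(torch::equal(nchw[i], batch[i].permute({2, 0, 1})));
    }
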
@@ -978,7 +970,9 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
         return seconds >= frameStartTime && seconds < frameEndTime;
       });
   // Convert the frame to tensor.
-  return convertAVFrameToDecodedOutput(rawOutput);
+  auto output = convertAVFrameToDecodedOutput(rawOutput);
+  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  return output;
 }
 
 void VideoDecoder::validateUserProvidedStreamIndex(uint64_t streamIndex) {
@@ -1015,6 +1009,16 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
+  auto output = getFrameAtIndexInternal(
+      streamIndex, frameIndex, preAllocatedOutputTensor);
+  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+  return output;
+}
+
+VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
+    int streamIndex,
+    int64_t frameIndex,
+    std::optional<torch::Tensor> preAllocatedOutputTensor) {
   validateUserProvidedStreamIndex(streamIndex);
   validateScannedAllStreams("getFrameAtIndex");
 
@@ -1023,7 +1027,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
 
   int64_t pts = stream.allFrames[frameIndex].pts;
   setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase));
-  return getNextDecodedOutputNoDemux(preAllocatedOutputTensor);
+  return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor);
 }
 
 VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
@@ -1073,14 +1077,14 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
       output.durationSeconds[indexInOutput] =
           output.durationSeconds[previousIndexInOutput];
     } else {
-      DecodedOutput singleOut = getFrameAtIndex(
+      DecodedOutput singleOut = getFrameAtIndexInternal(
           streamIndex, indexInVideo, output.frames[indexInOutput]);
       output.ptsSeconds[indexInOutput] = singleOut.ptsSeconds;
       output.durationSeconds[indexInOutput] = singleOut.durationSeconds;
     }
     previousIndexInVideo = indexInVideo;
   }
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1150,11 +1154,12 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
 
   BatchDecodedOutput output(numOutputFrames, options, streamMetadata);
   for (int64_t i = start, f = 0; i < stop; i += step, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
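The batch APIs above all follow the same pattern: decode each frame into a
pre-allocated NHWC slot via getFrameAtIndexInternal, then permute the whole
batch once at the end. A shape sketch of that pattern (sizes hypothetical):

    torch::Tensor frames = torch::empty({5, 270, 480, 3}); // NHWC batch
    torch::Tensor slot = frames[2];        // HWC view handed to the decoder
    // ...decoder writes the decoded pixels into `slot` in place...
    frames = frames.permute({0, 3, 1, 2}); // single NCHW view at the end
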
@@ -1207,7 +1212,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
-    output.frames = MaybePermuteHWC2CHW(options, output.frames);
+    output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
 
@@ -1243,11 +1248,12 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   int64_t numFrames = stopFrameIndex - startFrameIndex;
   BatchDecodedOutput output(numFrames, options, streamMetadata);
   for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-    DecodedOutput singleOut = getFrameAtIndex(streamIndex, i, output.frames[f]);
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(options, output.frames);
+  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1261,7 +1267,13 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
   return rawOutput;
 }
 
-VideoDecoder::DecodedOutput VideoDecoder::getNextDecodedOutputNoDemux(
+VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
+  auto output = getNextFrameOutputNoDemuxInternal();
+  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  return output;
+}
+
+VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal(
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
   auto rawOutput = getNextRawDecodedOutputNoDemux();
   return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor);
@@ -1283,7 +1295,7 @@ double VideoDecoder::getPtsSecondsForFrame(
     int streamIndex,
     int64_t frameIndex) {
   validateUserProvidedStreamIndex(streamIndex);
-  validateScannedAllStreams("getFrameAtIndex");
+  validateScannedAllStreams("getPtsSecondsForFrame");
 
   const auto& stream = streams_[streamIndex];
   validateFrameIndex(stream, frameIndex);
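Caller-facing effect of the rename (a hypothetical usage sketch; the file
name is made up, the rest comes from the APIs in this patch): with the
default "NCHW" dimensionOrder the frame comes back already permuted.

    auto decoder = VideoDecoder::createFromFilePath("video.mp4");
    decoder->addVideoStreamDecoder(-1);
    decoder->setCursorPtsInSeconds(5.0);
    auto out = decoder->getNextFrameNoDemux();
    // out.frame has shape {3, H, W}; set dimensionOrder = "NHWC" in the
    // stream options to get {H, W, 3} instead.
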
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index c6b70c895..7411af187 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -157,10 +157,12 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
 
+  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
-  // Calling getNextDecodedOutputNoDemux() will return the first frame at or
-  // after this position.
+  // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at
+  // or after this position.
   void setCursorPtsInSeconds(double seconds);
   // This is an internal structure that is used to store the decoded output
   // from decoding a frame through color conversion. Example usage is:
@@ -214,8 +216,7 @@
   };
   // Decodes the frame where the current cursor position is. It also advances
   // the cursor to the next frame.
-  DecodedOutput getNextDecodedOutputNoDemux(
-      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  DecodedOutput getNextFrameNoDemux();
   // Decodes the first frame in any added stream that is visible at a given
   // timestamp. Frames in the video have a presentation timestamp and a
   // duration. For example, if a frame has presentation timestamp of 5.0s and a
@@ -386,6 +387,13 @@ class VideoDecoder {
       DecodedOutput& output,
       std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
 
+  DecodedOutput getFrameAtIndexInternal(
+      int streamIndex,
+      int64_t frameIndex,
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+  DecodedOutput getNextFrameOutputNoDemuxInternal(
+      std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt);
+
   DecoderOptions options_;
   ContainerMetadata containerMetadata_;
   UniqueAVFormatContext formatContext_;
diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
index a769c8b53..7117ccab4 100644
--- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp
@@ -193,7 +193,7 @@ OpsDecodedOutput get_next_frame(at::Tensor& decoder) {
   auto videoDecoder = unwrapTensorToGetDecoder(decoder);
   VideoDecoder::DecodedOutput result;
   try {
-    result = videoDecoder->getNextDecodedOutputNoDemux();
+    result = videoDecoder->getNextFrameNoDemux();
  } catch (const VideoDecoder::EndOfFileException& e) {
     C10_THROW_ERROR(IndexError, e.what());
   }
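The VideoDecoderOps.cpp hunk above keeps the existing error translation
around the renamed call. A sketch of the resulting flow (the exception
hierarchy is assumed, not shown in this patch):

    try {
      decoder->getNextFrameNoDemux(); // called past the last frame
    } catch (const VideoDecoder::EndOfFileException& e) {
      // Decoder-level EOF. The ops layer above rethrows this via
      // C10_THROW_ERROR(IndexError, ...), so op callers see an IndexError.
      // Both types presumably derive from std::exception, which is what the
      // EXPECT_THROW(..., std::exception) check in the test below relies on.
    }
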
diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp
index 389d47f4f..5f2e62203 100644
--- a/test/decoders/VideoDecoderTest.cpp
+++ b/test/decoders/VideoDecoderTest.cpp
@@ -148,7 +148,7 @@ TEST(VideoDecoderTest, RespectsWidthAndHeightFromOptions) {
   streamOptions.width = 100;
   streamOptions.height = 120;
   decoder->addVideoStreamDecoder(-1, streamOptions);
-  torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+  torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
   EXPECT_EQ(tensor.sizes(), std::vector<long>({3, 120, 100}));
 }
 
@@ -159,7 +159,7 @@ TEST(VideoDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) {
   VideoDecoder::VideoStreamDecoderOptions streamOptions;
   streamOptions.dimensionOrder = "NHWC";
   decoder->addVideoStreamDecoder(-1, streamOptions);
-  torch::Tensor tensor = decoder->getNextDecodedOutputNoDemux().frame;
+  torch::Tensor tensor = decoder->getNextFrameNoDemux().frame;
   EXPECT_EQ(tensor.sizes(), std::vector<long>({270, 480, 3}));
 }
 
@@ -168,12 +168,12 @@ TEST_P(VideoDecoderTest, ReturnsFirstTwoFramesOfVideo) {
   std::unique_ptr<VideoDecoder> ourDecoder =
       createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor0FromOurDecoder = output.frame;
   EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 0.0);
   EXPECT_EQ(output.pts, 0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor1FromOurDecoder = output.frame;
   EXPECT_EQ(tensor1FromOurDecoder.sizes(), std::vector<long>({3, 270, 480}));
   EXPECT_EQ(output.ptsSeconds, 1'001. / 30'000);
@@ -254,11 +254,11 @@ TEST_P(VideoDecoderTest, SeeksCloseToEof) {
       createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
   ourDecoder->setCursorPtsInSeconds(388388. / 30'000);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextFrameNoDemux();
   EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
   EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
-  EXPECT_THROW(ourDecoder->getNextDecodedOutputNoDemux(), std::exception);
+  EXPECT_THROW(ourDecoder->getNextFrameNoDemux(), std::exception);
 }
 
 TEST_P(VideoDecoderTest, GetsFramePlayedAtTimestamp) {
@@ -298,7 +298,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
       createDecoderFromPath(path, GetParam());
   ourDecoder->addVideoStreamDecoder(-1);
   ourDecoder->setCursorPtsInSeconds(6.0);
-  auto output = ourDecoder->getNextDecodedOutputNoDemux();
+  auto output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor6FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
   torch::Tensor tensor6FromFFMPEG =
@@ -314,7 +314,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 180);
 
   ourDecoder->setCursorPtsInSeconds(6.1);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor61FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 183'183. / 30'000);
   torch::Tensor tensor61FromFFMPEG =
@@ -334,7 +334,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   EXPECT_LT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 10);
 
   ourDecoder->setCursorPtsInSeconds(10.0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor10FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 300'300. / 30'000);
   torch::Tensor tensor10FromFFMPEG =
@@ -351,7 +351,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   EXPECT_GT(ourDecoder->getDecodeStats().numPacketsSentToDecoder, 60);
 
   ourDecoder->setCursorPtsInSeconds(6.0);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
   tensor6FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 180'180. / 30'000);
   EXPECT_TRUE(torch::equal(tensor6FromOurDecoder, tensor6FromFFMPEG));
@@ -366,7 +366,7 @@ TEST_P(VideoDecoderTest, SeeksToFrameWithSpecificPts) {
   constexpr double kPtsOfLastFrameInVideoStream = 389'389. / 30'000; // ~12.9
 
   ourDecoder->setCursorPtsInSeconds(kPtsOfLastFrameInVideoStream);
-  output = ourDecoder->getNextDecodedOutputNoDemux();
+  output = ourDecoder->getNextFrameNoDemux();
   torch::Tensor tensor7FromOurDecoder = output.frame;
   EXPECT_EQ(output.ptsSeconds, 389'389. / 30'000);
   torch::Tensor tensor7FromFFMPEG =
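A closing sketch mirroring the two option-driven tests above (file name
hypothetical, shapes from the test video): opting into "NHWC" makes
MaybePermuteHWC2CHW a no-op, so the frame keeps its decoded layout.

    auto decoder = VideoDecoder::createFromFilePath("video.mp4");
    VideoDecoder::VideoStreamDecoderOptions streamOptions;
    streamOptions.dimensionOrder = "NHWC";
    decoder->addVideoStreamDecoder(-1, streamOptions);
    torch::Tensor frame = decoder->getNextFrameNoDemux().frame; // {H, W, 3}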