From 0620470594dab657b270c490c09d2471cc6a4910 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 12 Sep 2025 09:59:05 -0700 Subject: [PATCH 01/35] First pass on transforms. Committing to switch branches --- src/torchcodec/_core/Transform.cpp | 31 ++++++++++++++++++++++++ src/torchcodec/_core/Transform.h | 38 ++++++++++++++++++++++++++++++ 2 files changed, 69 insertions(+) create mode 100644 src/torchcodec/_core/Transform.cpp create mode 100644 src/torchcodec/_core/Transform.h diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp new file mode 100644 index 000000000..7ac46fbad --- /dev/null +++ b/src/torchcodec/_core/Transform.cpp @@ -0,0 +1,31 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "src/torchcodec/_core/Transform.h" +#include + +namespace facebook::torchcodec { + +std::string toStringFilterGraph(Transform::InterpolationMode mode) { + switch (mode) { + case Transform::InterpolationMode::BILINEAR: + return "BILINEAR"; + case Transform::InterpolationMode::BICUBIC: + return "BICUBIC"; + case Transform::InterpolationMode::NEAREST: + return "NEAREST"; + default: + TORCH_CHECK(false, "Unknown interpolation mode: " + std::to_string(mode)); + } +} + +std::string Transform::getFilterGraphCpu() const { + return "scale=width=" + std::to_string(width_) + + ":height=" + std::to_string(height_) + + ":sws_flags=" + toStringFilterGraph(interpolationMode_); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h new file mode 100644 index 000000000..20b05ae12 --- /dev/null +++ b/src/torchcodec/_core/Transform.h @@ -0,0 +1,38 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
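Note on Transform.cpp above: std::to_string() has no overload for an enum class, so the TORCH_CHECK line will not compile as written, and the uppercase mode names are replaced with lowercase filter-flag spellings in patch 02 below. A minimal usage sketch assuming the patch-02 behavior; the 640x480 size is only an example:

#include <iostream>
#include "src/torchcodec/_core/Transform.h"

int main() {
  // 640x480 is an arbitrary example size.
  facebook::torchcodec::ResizeTransform resize(/*width=*/640, /*height=*/480);
  // With the patch-02 implementation this prints
  // "scale=640:480:sws_flags=bilinear".
  std::cout << resize.getFilterGraphCpu() << std::endl;
  return 0;
}
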
+ +#pragma once + +#include + +namespace facebook::torchcodec { + +class Transform { + public: + std::string getFilterGraphCpu() const = 0 +}; + +class ResizeTransform : public Transform { + public: + ResizeTransform(int width, int height) + : width_(width), + height_(height), + interpolation_(InterpolationMode::BILINEAR) {} + + ResizeTransform(int width, int height, InterpolationMode interpolation) = + default; + + std::string getFilterGraphCpu() const override; + + enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST }; + + private: + int width_; + int height_; + InterpolationMode interpolation_; +} + +} // namespace facebook::torchcodec From c59de3613e63dc5bdd78ffc2a4c2f7e2fd48eb39 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 07:30:03 -0700 Subject: [PATCH 02/35] Initial C++ implementaiton of transforms --- src/torchcodec/_core/CMakeLists.txt | 1 + src/torchcodec/_core/CpuDeviceInterface.cpp | 150 ++++++----- src/torchcodec/_core/CpuDeviceInterface.h | 20 +- src/torchcodec/_core/CudaDeviceInterface.cpp | 34 ++- src/torchcodec/_core/CudaDeviceInterface.h | 10 +- src/torchcodec/_core/DeviceInterface.h | 10 +- src/torchcodec/_core/FilterGraph.cpp | 3 +- src/torchcodec/_core/Frame.cpp | 24 +- src/torchcodec/_core/Frame.h | 41 ++- src/torchcodec/_core/Metadata.h | 4 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 256 ++++++++----------- src/torchcodec/_core/SingleStreamDecoder.h | 37 ++- src/torchcodec/_core/StreamOptions.h | 2 - src/torchcodec/_core/Transform.cpp | 60 ++++- src/torchcodec/_core/Transform.h | 35 ++- src/torchcodec/_core/custom_ops.cpp | 23 +- test/test_ops.py | 1 + 17 files changed, 413 insertions(+), 298 deletions(-) diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index e3f9102e2..1de8d2d28 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -95,6 +95,7 @@ function(make_torchcodec_libraries SingleStreamDecoder.cpp Encoder.cpp ValidationUtils.cpp + Transform.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 77eaf3d09..817c06398 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -46,6 +46,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) device_.type() == torch::kCPU, "Unsupported device: ", device_.str()); } +void CpuDeviceInterface::initialize( + [[maybe_unused]] AVCodecContext* codecContext, + const VideoStreamOptions& videoStreamOptions, + const std::vector>& transforms, + const AVRational& timeBase, + const FrameDims& outputDims) { + videoStreamOptions_ = videoStreamOptions; + timeBase_ = timeBase; + outputDims_ = outputDims; + + // TODO: rationalize comment below with new stuff. + // By default, we want to use swscale for color conversion because it is + // faster. However, it has width requirements, so we may need to fall back + // to filtergraph. We also need to respect what was requested from the + // options; we respect the options unconditionally, so it's possible for + // swscale's width requirements to be violated. We don't expose the ability to + // choose color conversion library publicly; we only use this ability + // internally. + + // If any transforms are not swscale compatible, then we can't use swscale. 
+ bool areTransformsSwScaleCompatible = true; + for (const auto& transform : transforms) { + areTransformsSwScaleCompatible = + areTransformsSwScaleCompatible && transform->isSwScaleCompatible(); + } + + // swscale requires widths to be multiples of 32: + // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements + bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0; + + bool userRequestedSwScale = + videoStreamOptions_.colorConversionLibrary.has_value() && + videoStreamOptions_.colorConversionLibrary.value() == + ColorConversionLibrary::SWSCALE; + + // Note that we treat the transform limitation differently from the width + // limitation. That is, we consider the transforms being compatible with + // sws_scale as a hard requirement. If the transforms are not compatiable, + // then we will end up not applying the transforms, and that is wrong. + // + // The width requirement, however, is a soft requirement. Even if we don't + // meet it, we let the user override it. We have tests that depend on this + // behavior. Since we don't expose the ability to choose swscale or + // filtergraph in our public API, this is probably okay. It's also the only + // way that we can be certain we are testing one versus the other. + if (areTransformsSwScaleCompatible && + (userRequestedSwScale || isWidthSwScaleCompatible)) { + colorConversionLibrary_ = ColorConversionLibrary::SWSCALE; + } else { + colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH; + + // If we have any transforms, replace filters_ with the filter strings from + // the transforms. + std::stringstream filters; + bool first = true; + for (const auto& transform : transforms) { + if (!first) { + filters << ","; + } + filters << transform->getFilterGraphCpu(); + first = false; + } + if (!transforms.empty()) { + filters_ = filters.str(); + } + } +} + // Note [preAllocatedOutputTensor with swscale and filtergraph]: // Callers may pass a pre-allocated tensor, where the output.data tensor will // be stored. This parameter is honored in any case, but it only leads to a @@ -56,25 +124,18 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of // `dimension_order` parameter. It's up to callers to re-shape it if needed. void CpuDeviceInterface::convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); - int expectedOutputHeight = frameDims.height; - int expectedOutputWidth = frameDims.width; - if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == expectedOutputHeight) && - (shape[1] == expectedOutputWidth) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == outputDims_.height) && + (shape[1] == outputDims_.width) && (shape[2] == 3), "Expected pre-allocated tensor of shape ", - expectedOutputHeight, + outputDims_.height, "x", - expectedOutputWidth, + outputDims_.width, "x3, got ", shape); } @@ -83,25 +144,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( enum AVPixelFormat frameFormat = static_cast(avFrame->format); - // By default, we want to use swscale for color conversion because it is - // faster. 
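The initialize() logic above comma-joins each transform's filtergraph string and otherwise keeps the default no-op "copy" filter. A standalone sketch of that joining step; the second filter string ("hflip") is hypothetical and only illustrates the separator:

#include <string>
#include <vector>

// Sketch, not the patch's exact code: comma-join per-transform filter strings,
// falling back to the no-op "copy" filter when there are no transforms.
std::string joinFilterStrings(const std::vector<std::string>& filterStrings) {
  if (filterStrings.empty()) {
    return "copy";
  }
  std::string joined;
  for (const auto& f : filterStrings) {
    if (!joined.empty()) {
      joined += ",";
    }
    joined += f;
  }
  return joined;
}

// joinFilterStrings({"scale=640:480:sws_flags=bilinear", "hflip"})
// yields "scale=640:480:sws_flags=bilinear,hflip".
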
However, it has width requirements, so we may need to fall back - // to filtergraph. We also need to respect what was requested from the - // options; we respect the options unconditionally, so it's possible for - // swscale's width requirements to be violated. We don't expose the ability to - // choose color conversion library publicly; we only use this ability - // internally. - - // swscale requires widths to be multiples of 32: - // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements - // so we fall back to filtergraph if the width is not a multiple of 32. - auto defaultLibrary = (expectedOutputWidth % 32 == 0) - ? ColorConversionLibrary::SWSCALE - : ColorConversionLibrary::FILTERGRAPH; - - ColorConversionLibrary colorConversionLibrary = - videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); - - if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) { + if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) { // We need to compare the current frame context with our previous frame // context. If they are different, then we need to re-create our colorspace // conversion objects. We create our colorspace conversion objects late so @@ -113,11 +156,11 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( avFrame->width, avFrame->height, frameFormat, - expectedOutputWidth, - expectedOutputHeight); + outputDims_.width, + outputDims_.height); - outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor( - expectedOutputHeight, expectedOutputWidth, torch::kCPU)); + outputTensor = preAllocatedOutputTensor.value_or( + allocateEmptyHWCTensor(outputDims_, torch::kCPU)); if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { createSwsContext(swsFrameContext, avFrame->colorspace); @@ -129,34 +172,28 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // the expected height. // TODO: Can we do the same check for width? TORCH_CHECK( - resultHeight == expectedOutputHeight, - "resultHeight != expectedOutputHeight: ", + resultHeight == outputDims_.height, + "resultHeight != outputDims_.height: ", resultHeight, " != ", - expectedOutputHeight); + outputDims_.height); frameOutput.data = outputTensor; - } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - // See comment above in swscale branch about the filterGraphContext_ - // creation. creation - std::stringstream filters; - filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight; - filters << ":sws_flags=bilinear"; - + } else if (colorConversionLibrary_ == ColorConversionLibrary::FILTERGRAPH) { FiltersContext filtersContext( avFrame->width, avFrame->height, frameFormat, avFrame->sample_aspect_ratio, - expectedOutputWidth, - expectedOutputHeight, + outputDims_.width, + outputDims_.height, AV_PIX_FMT_RGB24, - filters.str(), - timeBase); + filters_, + timeBase_); if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) { filterGraphContext_ = - std::make_unique(filtersContext, videoStreamOptions); + std::make_unique(filtersContext, videoStreamOptions_); prevFiltersContext_ = std::move(filtersContext); } outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame); @@ -165,12 +202,12 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // reshaped to its expected dimensions by filtergraph. 
auto shape = outputTensor.sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == expectedOutputHeight) && - (shape[1] == expectedOutputWidth) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == outputDims_.height) && + (shape[1] == outputDims_.width) && (shape[2] == 3), "Expected output tensor of shape ", - expectedOutputHeight, + outputDims_.height, "x", - expectedOutputWidth, + outputDims_.width, "x3, got ", shape); @@ -186,7 +223,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( TORCH_CHECK( false, "Invalid color conversion library: ", - static_cast(colorConversionLibrary)); + static_cast(colorConversionLibrary_)); } } @@ -214,9 +251,8 @@ torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( TORCH_CHECK_EQ(filteredAVFrame->format, AV_PIX_FMT_RGB24); - auto frameDims = getHeightAndWidthFromResizedAVFrame(*filteredAVFrame.get()); - int height = frameDims.height; - int width = frameDims.width; + int height = filteredAVFrame->height; + int width = filteredAVFrame->width; std::vector shape = {height, width, 3}; std::vector strides = {filteredAVFrame->linesize[0], 3, 1}; AVFrame* filteredAVFramePtr = filteredAVFrame.release(); diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index d6004ca3b..2bba6fd21 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -23,12 +23,14 @@ class CpuDeviceInterface : public DeviceInterface { return std::nullopt; } - void initializeContext( - [[maybe_unused]] AVCodecContext* codecContext) override {} - - void convertAVFrameToFrameOutput( + virtual void initialize( + [[maybe_unused]] AVCodecContext* codecContext, const VideoStreamOptions& videoStreamOptions, + const std::vector>& transforms, const AVRational& timeBase, + const FrameDims& outputDims) override; + + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = @@ -64,6 +66,16 @@ class CpuDeviceInterface : public DeviceInterface { const SwsFrameContext& swsFrameContext, const enum AVColorSpace colorspace); + VideoStreamOptions videoStreamOptions_; + ColorConversionLibrary colorConversionLibrary_; + AVRational timeBase_; + FrameDims outputDims_; + + // The copy filter just copies the input to the output. Computationally, it + // should be a no-op. If we get no user-provided transforms, we will use the + // copy filter. + std::string filters_ = "copy"; + // color-conversion fields. Only one of FilterGraphContext and // UniqueSwsContext should be non-null. std::unique_ptr filterGraphContext_; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 74b556ed0..d4f704b88 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -185,9 +185,16 @@ CudaDeviceInterface::~CudaDeviceInterface() { } } -void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) { +void CudaDeviceInterface::initialize( + AVCodecContext* codecContext, + [[maybe_unsued]] const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const std::vector>& transforms, + [[maybe_unused]] const AVRational& timeBase, + const FrameDims& outputDims) { TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + outputDims_ = outputDims; + // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. 
// This is a dummy tensor to initialize the cuda context. @@ -196,12 +203,9 @@ void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) { ctx_ = getCudaContext(device_); nppCtx_ = getNppStreamContext(device_); codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); - return; } void CudaDeviceInterface::convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - [[maybe_unused]] const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { @@ -219,11 +223,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( FrameOutput cpuFrameOutput; cpuInterface->convertAVFrameToFrameOutput( - videoStreamOptions, - timeBase, - avFrame, - cpuFrameOutput, - preAllocatedOutputTensor); + avFrame, cpuFrameOutput, preAllocatedOutputTensor); frameOutput.data = cpuFrameOutput.data.to(device_); return; @@ -253,25 +253,21 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "If the video is 10bit, we are tracking 10bit support in " "https://github.com/pytorch/torchcodec/issues/776"); - auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); - int height = frameDims.height; - int width = frameDims.width; torch::Tensor& dst = frameOutput.data; if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); auto shape = dst.sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && - (shape[2] == 3), + (shape.size() == 3) && (shape[0] == outputDims_.height) && + (shape[1] == outputDims_.width) && (shape[2] == 3), "Expected tensor of shape ", - height, + outputDims_.height, "x", - width, + outputDims_.width, "x3, got ", shape); } else { - dst = allocateEmptyHWCTensor(height, width, device_); + dst = allocateEmptyHWCTensor(outputDims_, device_); } torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_); @@ -308,7 +304,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "cudaStreamGetFlags failed: ", cudaGetErrorString(err)); - NppiSize oSizeROI = {width, height}; + NppiSize oSizeROI = {outputDims_.width, outputDims_.height}; Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]}; NppStatus status; diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index f29caff42..aab5a8539 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -19,17 +19,25 @@ class CudaDeviceInterface : public DeviceInterface { std::optional findCodec(const AVCodecID& codecId) override; - void initializeContext(AVCodecContext* codecContext) override; + void initialize( + AVCodecContext* codecContext, + [[maybe_unsued]] const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const std::vector>& + transforms, + [[maybe_unused]] const AVRational& timeBase, + const FrameDims& outputDims) override; void convertAVFrameToFrameOutput( const VideoStreamOptions& videoStreamOptions, const AVRational& timeBase, UniqueAVFrame& avFrame, + const FrameDims& outputDims, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) override; private: + FrameDims outputDims_; UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; }; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 9a7288eb0..a28e03ddc 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -14,6 +14,7 @@ #include "FFMPEGCommon.h" #include "src/torchcodec/_core/Frame.h" #include 
"src/torchcodec/_core/StreamOptions.h" +#include "src/torchcodec/_core/Transform.h" namespace facebook::torchcodec { @@ -31,11 +32,14 @@ class DeviceInterface { // Initialize the hardware device that is specified in `device`. Some builds // support CUDA and others only support CPU. - virtual void initializeContext(AVCodecContext* codecContext) = 0; - - virtual void convertAVFrameToFrameOutput( + virtual void initialize( + AVCodecContext* codecContext, const VideoStreamOptions& videoStreamOptions, + const std::vector>& transforms, const AVRational& timeBase, + const FrameDims& outputDims) = 0; + + virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index 43a12f092..c4a77da95 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -137,7 +137,8 @@ FilterGraph::FilterGraph( TORCH_CHECK( status >= 0, "Failed to parse filter description: ", - getFFMPEGErrorStringFromErrorCode(status)); + getFFMPEGErrorStringFromErrorCode(status), + ", provided filters: " + filtersContext.filtergraphStr); status = avfilter_graph_config(filterGraph_.get(), nullptr); TORCH_CHECK( diff --git a/src/torchcodec/_core/Frame.cpp b/src/torchcodec/_core/Frame.cpp index bc3bbb788..9fa87a1cb 100644 --- a/src/torchcodec/_core/Frame.cpp +++ b/src/torchcodec/_core/Frame.cpp @@ -8,24 +8,34 @@ namespace facebook::torchcodec { +FrameBatchOutput::FrameBatchOutput( + int64_t numFrames, + const FrameDims& outputDims, + const torch::Device& device) + : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), + durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { + data = allocateEmptyHWCTensor(outputDims, device, numFrames); +} + torch::Tensor allocateEmptyHWCTensor( - int height, - int width, - torch::Device device, + const FrameDims& frameDims, + const torch::Device& device, std::optional numFrames) { auto tensorOptions = torch::TensorOptions() .dtype(torch::kUInt8) .layout(torch::kStrided) .device(device); - TORCH_CHECK(height > 0, "height must be > 0, got: ", height); - TORCH_CHECK(width > 0, "width must be > 0, got: ", width); + TORCH_CHECK( + frameDims.height > 0, "height must be > 0, got: ", frameDims.height); + TORCH_CHECK(frameDims.width > 0, "width must be > 0, got: ", frameDims.width); if (numFrames.has_value()) { auto numFramesValue = numFrames.value(); TORCH_CHECK( numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue); - return torch::empty({numFramesValue, height, width, 3}, tensorOptions); + return torch::empty( + {numFramesValue, frameDims.height, frameDims.width, 3}, tensorOptions); } else { - return torch::empty({height, width, 3}, tensorOptions); + return torch::empty({frameDims.height, frameDims.width, 3}, tensorOptions); } } diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 84ccc7288..99faaa816 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -13,6 +13,15 @@ namespace facebook::torchcodec { +struct FrameDims { + int width = 0; + int height = 0; + + FrameDims() = default; + + FrameDims(int w, int h) : width(w), height(h) {} +}; + // All public video decoding entry points return either a FrameOutput or a // FrameBatchOutput. 
// They are the equivalent of the user-facing Frame and FrameBatch classes in @@ -34,10 +43,10 @@ struct FrameBatchOutput { torch::Tensor ptsSeconds; // 1D of shape (N,) torch::Tensor durationSeconds; // 1D of shape (N,) - explicit FrameBatchOutput( + FrameBatchOutput( int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata); + const FrameDims& outputDims, + const torch::Device& device); }; struct AudioFramesOutput { @@ -56,6 +65,8 @@ struct AudioFramesOutput { // the high-level decoding entry-points to permute that back to CHW, by calling // maybePermuteHWC2CHW(). // +// TODO: Rationalize the comment below with refactoring. +// // Also, importantly, the way we figure out the the height and width of the // output frame tensor varies, and depends on the decoding entry-point. In // *decreasing order of accuracy*, we use the following sources for determining @@ -90,29 +101,9 @@ struct AudioFramesOutput { // it is very important to check the height and width assumptions where the // tensors memory is used/filled in order to avoid segfaults. -struct FrameDims { - int height; - int width; - - FrameDims(int h, int w) : height(h), width(w) {} -}; - -// There's nothing preventing you from calling this on a non-resized frame, but -// please don't. -FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame); - -FrameDims getHeightAndWidthFromOptionsOrMetadata( - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata); - -FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame); - torch::Tensor allocateEmptyHWCTensor( - int height, - int width, - torch::Device device, + const FrameDims& frameDims, + const torch::Device& device, std::optional numFrames = std::nullopt); } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h index f9ca85e67..ace6cf84c 100644 --- a/src/torchcodec/_core/Metadata.h +++ b/src/torchcodec/_core/Metadata.h @@ -44,8 +44,8 @@ struct StreamMetadata { std::optional numFramesFromContent; // Video-only fields derived from the AVCodecContext. - std::optional width; - std::optional height; + std::optional width; + std::optional height; std::optional sampleAspectRatio; // Audio-only fields diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index b52556e93..e152c4539 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -412,8 +412,6 @@ void SingleStreamDecoder::addStream( streamInfo.stream = formatContext_->streams[activeStreamIndex_]; streamInfo.avMediaType = mediaType; - deviceInterface_ = createDeviceInterface(device); - // This should never happen, checking just to be safe. TORCH_CHECK( streamInfo.stream->codecpar->codec_type == mediaType, @@ -421,14 +419,17 @@ void SingleStreamDecoder::addStream( activeStreamIndex_, " which is of the wrong media type."); + deviceInterface_ = createDeviceInterface(device); + TORCH_CHECK( + deviceInterface_ != nullptr, + "Failed to create device interface. 
This should never happen, please report."); + // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (deviceInterface_) { - avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) - .value_or(avCodec)); - } + avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( + deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) + .value_or(avCodec)); } AVCodecContext* codecContext = avcodec_alloc_context3(avCodec); @@ -442,13 +443,6 @@ void SingleStreamDecoder::addStream( streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0); streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; - // TODO_CODE_QUALITY same as above. - if (mediaType == AVMEDIA_TYPE_VIDEO) { - if (deviceInterface_) { - deviceInterface_->initializeContext(codecContext); - } - } - retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); @@ -469,6 +463,7 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, + std::vector& transforms, const VideoStreamOptions& videoStreamOptions, std::optional customFrameMappings) { addStream( @@ -503,6 +498,30 @@ void SingleStreamDecoder::addVideoStream( readCustomFrameMappingsUpdateMetadataAndIndex( streamIndex, customFrameMappings.value()); } + + TORCH_CHECK( + !videoStreamOptions.width.has_value(), "width should have no value!"); + TORCH_CHECK( + !videoStreamOptions.height.has_value(), "height should have no value!"); + outputDims_ = + FrameDims(streamMetadata.width.value(), streamMetadata.height.value()); + for (auto& transform : transforms) { + TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); + if (transform->getOutputFrameDims().has_value()) { + outputDims_ = transform->getOutputFrameDims().value(); + } + transforms_.push_back(std::unique_ptr(transform)); + } + + // We initialize the device context late because we want to know a lot of + // information that we can only know after resolving the codec, opening the + // stream and inspecting the metadata. 
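The loop above takes ownership of the caller's raw Transform pointers by wrapping them in unique_ptr, and any transform that reports output dimensions overrides the metadata-derived outputDims_. A hypothetical caller-side sketch (stream index and size are examples only; decoder is a SingleStreamDecoder*):

std::vector<Transform*> transforms;
transforms.push_back(new ResizeTransform(/*width=*/320, /*height=*/240));
// The decoder takes ownership of the ResizeTransform; outputDims_ becomes
// 320x240 via ResizeTransform::getOutputFrameDims(), overriding the stream
// metadata dimensions.
decoder->addVideoStream(/*streamIndex=*/-1, transforms);
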
+ deviceInterface_->initialize( + streamInfo.codecContext.get(), + videoStreamOptions, + transforms_, + streamInfo.timeBase, + outputDims_); } void SingleStreamDecoder::addAudioStream( @@ -596,7 +615,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( std::vector argsort; if (!indicesAreSorted) { // if frameIndices is [13, 10, 12, 11] - // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we + // want // to use to decode the frames // and argsort is [ 1, 3, 2, 0] argsort.resize(frameIndices.size()); @@ -609,12 +629,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( }); } - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - frameIndices.size(), videoStreamOptions, streamMetadata); + frameIndices.size(), outputDims_, videoStreamOptions.device); auto previousIndexInVideo = -1; for (size_t f = 0; f < frameIndices.size(); ++f) { @@ -657,8 +675,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); - // Note that if we do not have the number of frames available in our metadata, - // then we assume that the upper part of the range is valid. + // Note that if we do not have the number of frames available in our + // metadata, then we assume that the upper part of the range is valid. std::optional numFrames = getNumFrames(streamMetadata); if (numFrames.has_value()) { TORCH_CHECK( @@ -671,7 +689,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - numOutputFrames, videoStreamOptions, streamMetadata); + numOutputFrames, outputDims_, videoStreamOptions.device); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput frameOutput = @@ -687,9 +705,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double lastDecodedStartTime = - ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); + ptsToSeconds(lastDecodedAvFramePts_, streamInfo.timeBase); double lastDecodedEndTime = ptsToSeconds( - streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration, + lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_, streamInfo.timeBase); if (seconds >= lastDecodedStartTime && seconds < lastDecodedEndTime) { // We are in the same frame as the one we just returned. However, since we @@ -709,9 +727,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() // in maybeSeekToBeforeDesiredPts(). - // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137 - // In this case we return the very next frame instead of throwing an - // exception. + // This could be a bug in FFMPEG: + // https://trac.ffmpeg.org/ticket/11137 In this case we return the + // very next frame instead of throwing an exception. // TODO: Maybe log to stderr for Debug builds? 
return true; } @@ -791,13 +809,14 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // interval B: [0.2, 0.15) // // Both intervals take place between the pts values for frame 0 and frame 1, - // which by our abstract player, means that both intervals map to frame 0. By - // the definition of a half open interval, interval A should return no frames. - // Interval B should return frame 0. However, for both A and B, the individual - // values of the intervals will map to the same frame indices below. Hence, we - // need this special case below. + // which by our abstract player, means that both intervals map to frame 0. + // By the definition of a half open interval, interval A should return no + // frames. Interval B should return frame 0. However, for both A and B, the + // individual values of the intervals will map to the same frame indices + // below. Hence, we need this special case below. if (startSeconds == stopSeconds) { - FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); + FrameBatchOutput frameBatchOutput( + 0, outputDims_, videoStreamOptions.device); frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } @@ -809,8 +828,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( "; must be greater than or equal to " + std::to_string(minSeconds) + "."); - // Note that if we can't determine the maximum seconds from the metadata, then - // we assume upper range is valid. + // Note that if we can't determine the maximum seconds from the metadata, + // then we assume upper range is valid. std::optional maxSeconds = getMaxSeconds(streamMetadata); if (maxSeconds.has_value()) { TORCH_CHECK( @@ -842,7 +861,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( int64_t numFrames = stopFrameIndex - startFrameIndex; FrameBatchOutput frameBatchOutput( - numFrames, videoStreamOptions, streamMetadata); + numFrames, outputDims_, videoStreamOptions.device); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput frameOutput = getFrameAtIndexInternal(i, frameBatchOutput.data[f]); @@ -863,25 +882,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // `numChannels` values. An audio frame, or a sequence thereof, is always // converted into a tensor of shape `(numChannels, numSamplesPerChannel)`. // -// The notion of 'frame' in audio isn't what users want to interact with. Users -// want to interact with samples. The C++ and core APIs return frames, because -// we want those to be close to FFmpeg concepts, but the higher-level public -// APIs expose samples. As a result: +// The notion of 'frame' in audio isn't what users want to interact with. +// Users want to interact with samples. The C++ and core APIs return frames, +// because we want those to be close to FFmpeg concepts, but the higher-level +// public APIs expose samples. As a result: // - We don't expose index-based APIs for audio, because that would mean -// exposing the concept of audio frame. For now, we think exposing time-based -// APIs is more natural. -// - We never perform a scan for audio streams. We don't need to, since we won't +// exposing the concept of audio frame. For now, we think exposing +// time-based APIs is more natural. +// - We never perform a scan for audio streams. We don't need to, since we +// won't // be converting timestamps to indices. 
That's why we enforce the seek_mode -// to be "approximate" (which is slightly misleading, because technically the -// output samples will be at their exact positions. But this incongruence is -// only exposed at the C++/core private levels). +// to be "approximate" (which is slightly misleading, because technically +// the output samples will be at their exact positions. But this +// incongruence is only exposed at the C++/core private levels). // // Audio frames are of variable dimensions: in the same stream, a frame can // contain 1024 samples and the next one may contain 512 [1]. This makes it // impossible to stack audio frames in the same way we can stack video frames. -// This is one of the main reasons we cannot reuse the same pre-allocation logic -// we have for videos in getFramesPlayedInRange(): pre-allocating a batch -// requires constant (and known) frame dimensions. That's also why +// This is one of the main reasons we cannot reuse the same pre-allocation +// logic we have for videos in getFramesPlayedInRange(): pre-allocating a +// batch requires constant (and known) frame dimensions. That's also why // *concatenated* along the samples dimension, not stacked. // // [IMPORTANT!] There is one key invariant that we must respect when decoding @@ -889,10 +909,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // // BEFORE DECODING FRAME i, WE MUST DECODE ALL FRAMES j < i. // -// Always. Why? We don't know. What we know is that if we don't, we get clipped, -// incorrect audio as output [2]. All other (correct) libraries like TorchAudio -// or Decord do something similar, whether it was intended or not. This has a -// few implications: +// Always. Why? We don't know. What we know is that if we don't, we get +// clipped, incorrect audio as output [2]. All other (correct) libraries like +// TorchAudio or Decord do something similar, whether it was intended or not. +// This has a few implications: // - The **only** place we're allowed to seek to in an audio stream is the // stream's beginning. This ensures that if we need a frame, we'll have // decoded all previous frames. @@ -900,8 +920,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually // seek. // - We try not to seek, when we can avoid it. Typically if the next frame we -// need is in the future, we don't seek back to the beginning, we just decode -// all the frames in-between. +// need is in the future, we don't seek back to the beginning, we just +// decode all the frames in-between. // // [2] If you're brave and curious, you can read the long "Seek offset for // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which @@ -928,11 +948,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( } auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); - if (startPts < streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration) { - // If we need to seek backwards, then we have to seek back to the beginning - // of the stream. - // See [Audio Decoding Design]. + if (startPts < lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_) { + // If we need to seek backwards, then we have to seek back to the + // beginning of the stream. See [Audio Decoding Design]. setCursor(INT64_MIN); } @@ -966,9 +984,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( // stop decoding more frames. 
Note that if we were to use [begin, end), // which may seem more natural, then we would decode the frame starting at // stopSeconds, which isn't what we want! - auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration; - finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts && + auto lastDecodedAvFrameEnd = + lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_; + finished |= (lastDecodedAvFramePts_) <= stopPts && (stopPts <= lastDecodedAvFrameEnd); } @@ -1035,18 +1053,16 @@ I P P P I P P P I P P I P P I P bool SingleStreamDecoder::canWeAvoidSeeking() const { const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - // For audio, we only need to seek if a backwards seek was requested within - // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called. - // For more context, see [Audio Decoding Design] + // For audio, we only need to seek if a backwards seek was requested + // within getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was + // called. For more context, see [Audio Decoding Design] return !cursorWasJustSet_; } - int64_t lastDecodedAvFramePts = - streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; - if (cursor_ < lastDecodedAvFramePts) { + if (cursor_ < lastDecodedAvFramePts_) { // We can never skip a seek if we are seeking backwards. return false; } - if (lastDecodedAvFramePts == cursor_) { + if (lastDecodedAvFramePts_ == cursor_) { // We are seeking to the exact same frame as we are currently at. Without // caching we have to rewind back and decode the frame again. // TODO: https://github.com/pytorch/torchcodec/issues/84 we could @@ -1056,7 +1072,7 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const { // We are seeking forwards. // We can only skip a seek if both lastDecodedAvFramePts and // cursor_ share the same keyframe. - int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts); + int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_); int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_); return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 && lastDecodedAvFrameIndex == targetKeyFrameIndex; @@ -1216,14 +1232,15 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( getFFMPEGErrorStringFromErrorCode(status)); } - // Note that we don't flush the decoder when we reach EOF (even though that's - // mentioned in https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). - // This is because we may have packets internally in the decoder that we - // haven't received as frames. Eventually we will either hit AVERROR_EOF from - // av_receive_frame() or the user will have seeked to a different location in - // the file and that will flush the decoder. - streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame); - streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame); + // Note that we don't flush the decoder when we reach EOF (even though + // that's mentioned in + // https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). This is + // because we may have packets internally in the decoder that we haven't + // received as frames. Eventually we will either hit AVERROR_EOF from + // av_receive_frame() or the user will have seeked to a different location + // in the file and that will flush the decoder. 
+ lastDecodedAvFramePts_ = getPtsOrDts(avFrame); + lastDecodedAvFrameDuration_ = getDuration(avFrame); return avFrame; } @@ -1247,16 +1264,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); } else { - TORCH_CHECK( - deviceInterface_ != nullptr, - "No device interface available for video decoding. This ", - "shouldn't happen, please report."); deviceInterface_->convertAVFrameToFrameOutput( - streamInfo.videoStreamOptions, - streamInfo.timeBase, - avFrame, - frameOutput, - preAllocatedOutputTensor); + avFrame, frameOutput, preAllocatedOutputTensor); } return frameOutput; } @@ -1292,8 +1301,8 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( UniqueAVFrame convertedAVFrame; if (mustConvert) { - if (!streamInfo.swrContext) { - streamInfo.swrContext.reset(createSwrContext( + if (!swrContext_) { + swrContext_.reset(createSwrContext( srcSampleFormat, outSampleFormat, srcSampleRate, @@ -1303,7 +1312,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( } convertedAVFrame = convertAudioAVFrameSamples( - streamInfo.swrContext, + swrContext_, srcAVFrame, outSampleFormat, outSampleRate, @@ -1351,15 +1360,15 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // When sample rate conversion is involved, swresample buffers some of the // samples in-between calls to swr_convert (see the libswresample docs). - // That's because the last few samples in a given frame require future samples - // from the next frame to be properly converted. This function flushes out the - // samples that are stored in swresample's buffers. + // That's because the last few samples in a given frame require future + // samples from the next frame to be properly converted. This function + // flushes out the samples that are stored in swresample's buffers. auto& streamInfo = streamInfos_[activeStreamIndex_]; - if (!streamInfo.swrContext) { + if (!swrContext_) { return std::nullopt; } auto numRemainingSamples = // this is an upper bound - swr_get_out_samples(streamInfo.swrContext.get(), 0); + swr_get_out_samples(swrContext_.get(), 0); if (numRemainingSamples == 0) { return std::nullopt; @@ -1376,11 +1385,7 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { } auto actualNumRemainingSamples = swr_convert( - streamInfo.swrContext.get(), - outputBuffers.data(), - numRemainingSamples, - nullptr, - 0); + swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); return lastSamples.narrow( /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); @@ -1390,25 +1395,10 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- -FrameBatchOutput::FrameBatchOutput( - int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata) - : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), - durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { - auto frameDims = getHeightAndWidthFromOptionsOrMetadata( - videoStreamOptions, streamMetadata); - int height = frameDims.height; - int width = frameDims.width; - data = allocateEmptyHWCTensor( - height, width, videoStreamOptions.device, numFrames); -} - -// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so. 
-// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D -// or 4D. -// Calling permute() is guaranteed to return a view as per the docs: -// https://pytorch.org/docs/stable/generated/torch.permute.html +// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require +// so. The [N] leading batch-dimension is optional i.e. the input tensor can +// be 3D or 4D. Calling permute() is guaranteed to return a view as per the +// docs: https://pytorch.org/docs/stable/generated/torch.permute.html torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( torch::Tensor& hwcTensor) { if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder == @@ -1628,8 +1618,8 @@ void SingleStreamDecoder::validateFrameIndex( "and the number of frames must be known."); } - // Note that if we do not have the number of frames available in our metadata, - // then we assume that the frameIndex is valid. + // Note that if we do not have the number of frames available in our + // metadata, then we assume that the frameIndex is valid. std::optional numFrames = getNumFrames(streamMetadata); if (numFrames.has_value()) { if (frameIndex >= numFrames.value()) { @@ -1680,28 +1670,4 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) { streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); } -// -------------------------------------------------------------------------- -// FrameDims APIs -// -------------------------------------------------------------------------- - -FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) { - return FrameDims(resizedAVFrame.height, resizedAVFrame.width); -} - -FrameDims getHeightAndWidthFromOptionsOrMetadata( - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata) { - return FrameDims( - videoStreamOptions.height.value_or(*streamMetadata.height), - videoStreamOptions.width.value_or(*streamMetadata.width)); -} - -FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame) { - return FrameDims( - videoStreamOptions.height.value_or(avFrame->height), - videoStreamOptions.width.value_or(avFrame->width)); -} - } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 56bb8bb58..2c93accfc 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -17,6 +17,7 @@ #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/Frame.h" #include "src/torchcodec/_core/StreamOptions.h" +#include "src/torchcodec/_core/Transform.h" namespace facebook::torchcodec { @@ -83,6 +84,7 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, + std::vector& transforms, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), std::optional customFrameMappings = std::nullopt); void addAudioStream( @@ -226,17 +228,8 @@ class SingleStreamDecoder { std::vector keyFrames; std::vector allFrames; - // TODO since the decoder is single-stream, these should be decoder fields, - // not streamInfo fields. And they should be defined right next to - // `cursor_`, with joint documentation. - int64_t lastDecodedAvFramePts = 0; - int64_t lastDecodedAvFrameDuration = 0; VideoStreamOptions videoStreamOptions; AudioStreamOptions audioStreamOptions; - - // color-conversion fields. Only one of FilterGraphContext and - // UniqueSwsContext should be non-null. 
- UniqueSwrContext swrContext; }; // -------------------------------------------------------------------------- @@ -356,16 +349,40 @@ class SingleStreamDecoder { const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; - bool cursorWasJustSet_ = false; // The desired position of the cursor in the stream. We send frames >= this // pts to the user when they request a frame. int64_t cursor_ = INT64_MIN; + bool cursorWasJustSet_ = false; + int64_t lastDecodedAvFramePts_ = 0; + int64_t lastDecodedAvFrameDuration_ = 0; + + // Audio only. We cache it for performance. The video equivalents live in + // deviceInterface_. We store swrContext_ here because we only handle audio + // on the CPU. + UniqueSwrContext swrContext_; + // Stores various internal decoding stats. DecodeStats decodeStats_; + // Stores the AVIOContext for the input buffer. std::unique_ptr avioContextHolder_; + + // We will receive a vector of transforms upon adding a stream and store it + // here. However, we need to know if any of those operations change the + // dimensions of the output frame. If they do, we need to figure out what are + // the final dimensions of the output frame after ALL transformations. We + // figure this out as soon as we receive the transforms. If any of the + // transforms change the final output frame dimensions, we store that in + // finalDims_. + // + // By default, we set outputDims_ to the dimensions from the metadata from the + // stream. + FrameDims outputDims_; + std::vector> transforms_; + // Whether or not we have already scanned all streams to update the metadata. bool scannedAllStreams_ = false; + // Tracks that we've already been initialized. bool initialized_ = false; }; diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 19cc5126c..7665d6c9f 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -31,8 +31,6 @@ struct VideoStreamOptions { // Currently the dimension order can be either NHWC or NCHW. // H=height, W=width, C=channel. std::string dimensionOrder = "NCHW"; - // The output height and width of the frame. If not specified, the output - // is the same as the original video. 
std::optional width; std::optional height; std::optional colorConversionLibrary; diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index 7ac46fbad..c9504387e 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -6,26 +6,62 @@ #include "src/torchcodec/_core/Transform.h" #include +#include "src/torchcodec/_core/FFMPEGCommon.h" namespace facebook::torchcodec { -std::string toStringFilterGraph(Transform::InterpolationMode mode) { +namespace { + +std::string toFilterGraphInterpolation( + ResizeTransform::InterpolationMode mode) { switch (mode) { - case Transform::InterpolationMode::BILINEAR: - return "BILINEAR"; - case Transform::InterpolationMode::BICUBIC: - return "BICUBIC"; - case Transform::InterpolationMode::NEAREST: - return "NEAREST"; + case ResizeTransform::InterpolationMode::BILINEAR: + return "bilinear"; + case ResizeTransform::InterpolationMode::BICUBIC: + return "bicubic"; + case ResizeTransform::InterpolationMode::NEAREST: + return "nearest"; default: - TORCH_CHECK(false, "Unknown interpolation mode: " + std::to_string(mode)); + TORCH_CHECK( + false, + "Unknown interpolation mode: " + + std::to_string(static_cast(mode))); } } -std::string Transform::getFilterGraphCpu() const { - return "scale=width=" + std::to_string(width_) + - ":height=" + std::to_string(height_) + - ":sws_flags=" + toStringFilterGraph(interpolationMode_); +int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { + switch (mode) { + case ResizeTransform::InterpolationMode::BILINEAR: + return SWS_BILINEAR; + case ResizeTransform::InterpolationMode::BICUBIC: + return SWS_BICUBIC; + case ResizeTransform::InterpolationMode::NEAREST: + return SWS_POINT; + default: + TORCH_CHECK( + false, + "Unknown interpolation mode: " + + std::to_string(static_cast(mode))); + } +} + +} // namespace + +std::string ResizeTransform::getFilterGraphCpu() const { + return "scale=" + std::to_string(width_) + ":" + std::to_string(height_) + + ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_); +} + +std::optional ResizeTransform::getOutputFrameDims() const { + return FrameDims(width_, height_); +} + +bool ResizeTransform::isSwScaleCompatible() const { + return true; +} + +int ResizeTransform::getSwsFlags() const { + return toSwsInterpolation(interpolationMode_); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 20b05ae12..66c47612f 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -6,33 +6,54 @@ #pragma once +#include #include +#include "src/torchcodec/_core/Frame.h" namespace facebook::torchcodec { class Transform { public: - std::string getFilterGraphCpu() const = 0 + virtual std::string getFilterGraphCpu() const = 0; + + // If the transformation does not change the output frame dimensions, then + // there is no need to override this member function. The default + // implementation returns an empty optional, indicating that the output frame + // has the same dimensions as the input frame. + // + // If the transformation does change the output frame dimensions, then it + // must override this member function and return the output frame dimensions. 
+ virtual std::optional getOutputFrameDims() const { + return std::nullopt; + } + + virtual bool isSwScaleCompatible() const { + return false; + } }; class ResizeTransform : public Transform { public: + enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST }; + ResizeTransform(int width, int height) : width_(width), height_(height), - interpolation_(InterpolationMode::BILINEAR) {} + interpolationMode_(InterpolationMode::BILINEAR) {} - ResizeTransform(int width, int height, InterpolationMode interpolation) = - default; + ResizeTransform(int width, int height, InterpolationMode interpolationMode) + : width_(width), height_(height), interpolationMode_(interpolationMode) {} std::string getFilterGraphCpu() const override; + std::optional getOutputFrameDims() const override; + bool isSwScaleCompatible() const override; - enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST }; + int getSwsFlags() const; private: int width_; int height_; - InterpolationMode interpolation_; -} + InterpolationMode interpolationMode_; +}; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index a865bdaed..b6987036e 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -262,10 +262,24 @@ void _add_video_stream( custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; - videoStreamOptions.width = width; - videoStreamOptions.height = height; videoStreamOptions.ffmpegThreadCount = num_threads; + // TODO: Eliminate this temporary bridge code. This exists because we have + // not yet exposed the transforms API on the Python side. We also want + // to remove the `width` and `height` arguments from the Python API. + // + // TEMPORARY BRIDGE CODE START + TORCH_CHECK( + width.has_value() == height.has_value(), + "width and height must both be set or unset."); + std::vector transforms; + if (width.has_value()) { + transforms.push_back(new ResizeTransform(width.value(), height.value())); + width.reset(); + height.reset(); + } + // TEMPORARY BRIDGE CODE END + if (dimension_order.has_value()) { std::string stdDimensionOrder{dimension_order.value()}; TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW"); @@ -296,7 +310,10 @@ void _add_video_stream( : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, converted_mappings); + stream_index.value_or(-1), + transforms, + videoStreamOptions, + converted_mappings); } // Add a new video stream at `stream_index` using the provided options. diff --git a/test/test_ops.py b/test/test_ops.py index 64ec09063..3f704dc80 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -602,6 +602,7 @@ def test_color_conversion_library_with_scaling( ) swscale_frame0, _, _ = get_next_frame(swscale_decoder) assert_frames_equal(filtergraph_frame0, swscale_frame0) + assert filtergraph_frame0.shape == (3, target_height, target_width) @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW")) @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) From 890e2b47033f8d0359da9724586b9fd7e6bdacc9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 07:50:50 -0700 Subject: [PATCH 03/35] Ha, "maybe unsued". 
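For reference, a transform that keeps the input dimensions can rely on the default getOutputFrameDims() (no size change) and isSwScaleCompatible() (false, so the filtergraph path is used). A hypothetical sketch, not part of these patches; "hflip" is FFmpeg's horizontal-flip filter and is used here only as an example:

class HorizontalFlipTransform : public Transform {
 public:
  std::string getFilterGraphCpu() const override {
    return "hflip";
  }
  // Default getOutputFrameDims() -> std::nullopt (same dims as the input).
  // Default isSwScaleCompatible() -> false (forces the filtergraph path).
};
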
--- src/torchcodec/_core/CudaDeviceInterface.cpp | 2 +- src/torchcodec/_core/CudaDeviceInterface.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index d4f704b88..2a68c2ce4 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -187,7 +187,7 @@ CudaDeviceInterface::~CudaDeviceInterface() { void CudaDeviceInterface::initialize( AVCodecContext* codecContext, - [[maybe_unsued]] const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const VideoStreamOptions& videoStreamOptions, [[maybe_unused]] const std::vector>& transforms, [[maybe_unused]] const AVRational& timeBase, const FrameDims& outputDims) { diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index aab5a8539..0afe52305 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -21,7 +21,7 @@ class CudaDeviceInterface : public DeviceInterface { void initialize( AVCodecContext* codecContext, - [[maybe_unsued]] const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const VideoStreamOptions& videoStreamOptions, [[maybe_unused]] const std::vector>& transforms, [[maybe_unused]] const AVRational& timeBase, From d07f7d8f0b2ab494ef392ad71622a9b964bf3b87 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 08:07:05 -0700 Subject: [PATCH 04/35] Update C++ tests --- test/VideoDecoderTest.cpp | 41 ++++++++++++++++++++++----------------- 1 file changed, 23 insertions(+), 18 deletions(-) diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 0c21f0d46..643c6c83a 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -146,14 +146,6 @@ double computeAverageCosineSimilarity( return averageCosineSimilarity; } -// TEST(DecoderOptionsTest, ConvertsFromStringToOptions) { -// std::string optionsString = -// "ffmpeg_thread_count=3,dimension_order=NCHW,width=100,height=120"; -// SingleStreamDecoder::DecoderOptions options = -// SingleStreamDecoder::DecoderOptions(optionsString); -// EXPECT_EQ(options.ffmpegThreadCount, 3); -// } - TEST(SingleStreamDecoderTest, RespectsWidthAndHeightFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = @@ -161,7 +153,8 @@ TEST(SingleStreamDecoderTest, RespectsWidthAndHeightFromOptions) { VideoStreamOptions videoStreamOptions; videoStreamOptions.width = 100; videoStreamOptions.height = 120; - decoder->addVideoStream(-1, videoStreamOptions); + std::vector transforms; + decoder->addVideoStream(-1, transforms, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrame().data; EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); } @@ -172,7 +165,8 @@ TEST(SingleStreamDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { std::make_unique(path); VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; - decoder->addVideoStream(-1, videoStreamOptions); + std::vector transforms; + decoder->addVideoStream(-1, transforms, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrame().data; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -181,7 +175,8 @@ TEST_P(SingleStreamDecoderTest, ReturnsFirstTwoFramesOfVideo) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + 
ourDecoder->addVideoStream(-1, transforms); auto output = ourDecoder->getNextFrame(); torch::Tensor tensor0FromOurDecoder = output.data; EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({3, 270, 480})); @@ -220,7 +215,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) { ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - ourDecoder->addVideoStream(bestVideoStreamIndex); + std::vector transforms; + ourDecoder->addVideoStream(bestVideoStreamIndex, transforms); // Frame with index 180 corresponds to timestamp 6.006. auto output = ourDecoder->getFramesAtIndices({0, 180}); auto tensor = output.data; @@ -244,7 +240,9 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; - ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); + std::vector transforms; + ourDecoder->addVideoStream( + bestVideoStreamIndex, transforms, videoStreamOptions); // Frame with index 180 corresponds to timestamp 6.006. auto output = ourDecoder->getFramesAtIndices({0, 180}); auto tensor = output.data; @@ -264,7 +262,8 @@ TEST_P(SingleStreamDecoderTest, SeeksCloseToEof) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); ourDecoder->setCursorPtsInSeconds(388388. / 30'000); auto output = ourDecoder->getNextFrame(); EXPECT_EQ(output.ptsSeconds, 388'388. / 30'000); @@ -277,7 +276,8 @@ TEST_P(SingleStreamDecoderTest, GetsFramePlayedAtTimestamp) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); auto output = ourDecoder->getFramePlayedAt(6.006); EXPECT_EQ(output.ptsSeconds, 6.006); // The frame's duration is 0.033367 according to ffprobe, @@ -307,7 +307,8 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); ourDecoder->setCursorPtsInSeconds(6.0); auto output = ourDecoder->getNextFrame(); torch::Tensor tensor6FromOurDecoder = output.data; @@ -410,7 +411,9 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorFilterGraph) { VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = ColorConversionLibrary::FILTERGRAPH; - ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); + std::vector transforms; + ourDecoder->addVideoStream( + bestVideoStreamIndex, transforms, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); @@ -427,7 +430,9 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorSwscale) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = ColorConversionLibrary::SWSCALE; - ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); + std::vector transforms; + ourDecoder->addVideoStream( + bestVideoStreamIndex, 
transforms, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); From cc4e2ecc6ed221373f91cb097343cbc6b91bcdd3 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 08:29:00 -0700 Subject: [PATCH 05/35] Remove C++ test that we no longer need --- test/VideoDecoderTest.cpp | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 643c6c83a..ed8f56cd6 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -146,19 +146,6 @@ double computeAverageCosineSimilarity( return averageCosineSimilarity; } -TEST(SingleStreamDecoderTest, RespectsWidthAndHeightFromOptions) { - std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr decoder = - std::make_unique(path); - VideoStreamOptions videoStreamOptions; - videoStreamOptions.width = 100; - videoStreamOptions.height = 120; - std::vector transforms; - decoder->addVideoStream(-1, transforms, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrame().data; - EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); -} - TEST(SingleStreamDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = From f471776d9fa3dce22e3afebaeb33799e552729b2 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 08:38:59 -0700 Subject: [PATCH 06/35] Virtual classes need virtual destructors --- src/torchcodec/_core/Transform.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 66c47612f..957758088 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -15,6 +15,7 @@ namespace facebook::torchcodec { class Transform { public: virtual std::string getFilterGraphCpu() const = 0; + virtual ~Transform() = default; // If the transformation does not change the output frame dimensions, then // there is no need to override this member function. 
The default From c06aa947a4917c4c303a2aa258639c24fbcf8fc9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 08:53:18 -0700 Subject: [PATCH 07/35] Cuda device convert frames function --- src/torchcodec/_core/CudaDeviceInterface.h | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 0afe52305..b8224e6f2 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -28,10 +28,7 @@ class CudaDeviceInterface : public DeviceInterface { const FrameDims& outputDims) override; void convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, - const FrameDims& outputDims, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) override; From 781f956d16bcd98f42619d22449041bad1302148 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 11:32:25 -0700 Subject: [PATCH 08/35] Fix cuda --- src/torchcodec/_core/CpuDeviceInterface.cpp | 10 ++++++++-- src/torchcodec/_core/CpuDeviceInterface.h | 8 +++++++- src/torchcodec/_core/CudaDeviceInterface.cpp | 7 +++++++ 3 files changed, 22 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 817c06398..3dda3b968 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -94,6 +94,9 @@ void CpuDeviceInterface::initialize( if (areTransformsSwScaleCompatible && (userRequestedSwScale || isWidthSwScaleCompatible)) { colorConversionLibrary_ = ColorConversionLibrary::SWSCALE; + + // SCOTT NEXT TODO: set swsFlags_ + } else { colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH; @@ -112,6 +115,8 @@ void CpuDeviceInterface::initialize( filters_ = filters.str(); } } + + initialized_ = true; } // Note [preAllocatedOutputTensor with swscale and filtergraph]: @@ -127,6 +132,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { + TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); TORCH_CHECK( @@ -167,7 +173,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( prevSwsFrameContext_ = swsFrameContext; } int resultHeight = - convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor); + convertAVFrameToTensorUsingSwScale(avFrame, outputTensor); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? 
@@ -227,7 +233,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( } } -int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale( +int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor) { uint8_t* pointers[4] = { diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 2bba6fd21..380e0d7cb 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -37,7 +37,7 @@ class CpuDeviceInterface : public DeviceInterface { std::nullopt) override; private: - int convertAVFrameToTensorUsingSwsScale( + int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); @@ -71,6 +71,10 @@ class CpuDeviceInterface : public DeviceInterface { AVRational timeBase_; FrameDims outputDims_; + // If we use swscale for resizing, the flags control the resizing algorithm. + // We exclusively get the value from the ResizeTransform. + int swsFlags_ = 0; + // The copy filter just copies the input to the output. Computationally, it // should be a no-op. If we get no user-provided transforms, we will use the // copy filter. @@ -85,6 +89,8 @@ class CpuDeviceInterface : public DeviceInterface { // be created before decoding a new frame. SwsFrameContext prevSwsFrameContext_; FiltersContext prevFiltersContext_; + + bool initialized_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 2a68c2ce4..26ccda8b0 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -216,10 +216,17 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // Typically that happens if the video's encoder isn't supported by NVDEC. // Below, we choose to convert the frame's color-space using the CPU // codepath, and send it back to the GPU at the very end. + // // TODO: A possibly better solution would be to send the frame to the GPU // first, and do the color conversion there. + // + // TODO: If we're going to keep this around, we should probably cache it? auto cpuDevice = torch::Device(torch::kCPU); auto cpuInterface = createDeviceInterface(cpuDevice); + TORCH_CHECK( + cpuInterface != nullptr, "Failed to create CPU device interface"); + cpuDeviceInterface->initialize( + nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); FrameOutput cpuFrameOutput; cpuInterface->convertAVFrameToFrameOutput( From 8e7072f99022101839fd24b4b42abfa7975e7bdd Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 12:37:36 -0700 Subject: [PATCH 09/35] Handle swscale correctly --- src/torchcodec/_core/CpuDeviceInterface.cpp | 22 +++++++++++++-------- src/torchcodec/_core/CpuDeviceInterface.h | 4 ++-- src/torchcodec/_core/Transform.cpp | 2 +- src/torchcodec/_core/Transform.h | 6 ++++-- 4 files changed, 21 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 3dda3b968..658290261 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -65,12 +65,12 @@ void CpuDeviceInterface::initialize( // choose color conversion library publicly; we only use this ability // internally. - // If any transforms are not swscale compatible, then we can't use swscale. 
- bool areTransformsSwScaleCompatible = true; - for (const auto& transform : transforms) { - areTransformsSwScaleCompatible = - areTransformsSwScaleCompatible && transform->isSwScaleCompatible(); - } + // We can only use swscale when we have a single resize transform. Note that + // this means swscale will not support the case of having several, + // back-to-back resizes. There's no strong reason to even do that, but if + // someone does, it's more correct to implement that with filtergraph. + bool areTransformsSwScaleCompatible = transforms.empty() || + (transforms.size() == 1 && transforms[0]->isResize()); // swscale requires widths to be multiples of 32: // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements @@ -95,8 +95,14 @@ (userRequestedSwScale || isWidthSwScaleCompatible)) { colorConversionLibrary_ = ColorConversionLibrary::SWSCALE; - // SCOTT NEXT TODO: set swsFlags_ - + // We established above that if the transforms are swscale compatible and + // non-empty, then they must have only one transform, and that transform is + // ResizeTransform. if (!transforms.empty()) { + auto resize = dynamic_cast(transforms[0].get()); + TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!"); + swsFlags_ = resize->getSwsFlags(); + } } else { colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH; diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 380e0d7cb..e26dfa0e9 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -72,8 +72,8 @@ class CpuDeviceInterface : public DeviceInterface { FrameDims outputDims_; // If we use swscale for resizing, the flags control the resizing algorithm. - // We exclusively get the value from the ResizeTransform. - int swsFlags_ = 0; + // We default to bilinear. Users can override this with a ResizeTransform. + int swsFlags_ = SWS_BILINEAR; // The copy filter just copies the input to the output. Computationally, it // should be a no-op. If we get no user-provided transforms, we will use the diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index c9504387e..4884f92e9 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -56,7 +56,7 @@ std::optional ResizeTransform::getOutputFrameDims() const { return FrameDims(width_, height_); } -bool ResizeTransform::isSwScaleCompatible() const { +bool ResizeTransform::isResize() const { return true; } diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 957758088..307e18b73 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -28,7 +28,9 @@ class Transform { return std::nullopt; } - virtual bool isSwScaleCompatible() const { + // The ResizeTransform is special, because it is the only transform that + // swscale can handle.
+ virtual bool isResize() const { return false; } }; @@ -47,7 +49,7 @@ class ResizeTransform : public Transform { std::string getFilterGraphCpu() const override; std::optional getOutputFrameDims() const override; - bool isSwScaleCompatible() const override; + bool isResize() const override; int getSwsFlags() const; From 1d0c2754c23fe48dd2670e4c901e072bf3efefab Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 12:49:14 -0700 Subject: [PATCH 10/35] Variable names matter --- src/torchcodec/_core/CudaDeviceInterface.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 26ccda8b0..3ef293938 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -225,7 +225,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( auto cpuInterface = createDeviceInterface(cpuDevice); TORCH_CHECK( cpuInterface != nullptr, "Failed to create CPU device interface"); - cpuDeviceInterface->initialize( + cpuInterface->initialize( nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); FrameOutput cpuFrameOutput; From 30622a7060e6b0b1ff3b76205dcd489492542593 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 13:08:13 -0700 Subject: [PATCH 11/35] Timebase --- src/torchcodec/_core/CudaDeviceInterface.cpp | 3 ++- src/torchcodec/_core/CudaDeviceInterface.h | 1 + 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 3ef293938..400027fca 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -189,10 +189,11 @@ void CudaDeviceInterface::initialize( AVCodecContext* codecContext, [[maybe_unused]] const VideoStreamOptions& videoStreamOptions, [[maybe_unused]] const std::vector>& transforms, - [[maybe_unused]] const AVRational& timeBase, + const AVRational& timeBase, const FrameDims& outputDims) { TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + timeBase_ = timeBase; outputDims_ = outputDims; // It is important for pytorch itself to create the cuda context. 
If ffmpeg diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index b8224e6f2..319e1525e 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -34,6 +34,7 @@ class CudaDeviceInterface : public DeviceInterface { std::nullopt) override; private: + AVRational timeBase_; FrameDims outputDims_; UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; From 7a41bfd69c19725c44379367c8f3fae8218ea379 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 19 Sep 2025 13:50:17 -0700 Subject: [PATCH 12/35] Removes width and height from StreamOptions --- src/torchcodec/_core/CpuDeviceInterface.cpp | 6 ++---- src/torchcodec/_core/Encoder.cpp | 4 ++-- src/torchcodec/_core/SingleStreamDecoder.cpp | 4 ---- src/torchcodec/_core/StreamOptions.h | 12 ++++++++---- 4 files changed, 12 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 658290261..a81dbe2c9 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -76,10 +76,8 @@ void CpuDeviceInterface::initialize( // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0; - bool userRequestedSwScale = - videoStreamOptions_.colorConversionLibrary.has_value() && - videoStreamOptions_.colorConversionLibrary.value() == - ColorConversionLibrary::SWSCALE; + bool userRequestedSwScale = videoStreamOptions_.colorConversionLibrary == + ColorConversionLibrary::SWSCALE; // Note that we treat the transform limitation differently from the width // limitation. That is, we consider the transforms being compatible with diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index b4d2c5609..acef42ad6 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -603,8 +603,8 @@ void VideoEncoder::initializeEncoder( // Use specified dimensions or input dimensions // TODO-VideoEncoder: Allow height and width to be set - outWidth_ = videoStreamOptions.width.value_or(inWidth_); - outHeight_ = videoStreamOptions.height.value_or(inHeight_); + outWidth_ = inWidth_; + outHeight_ = inHeight_; // Use YUV420P as default output format // TODO-VideoEncoder: Enable other pixel formats diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index e152c4539..2de51fca0 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -499,10 +499,6 @@ void SingleStreamDecoder::addVideoStream( streamIndex, customFrameMappings.value()); } - TORCH_CHECK( - !videoStreamOptions.width.has_value(), "width should have no value!"); - TORCH_CHECK( - !videoStreamOptions.height.has_value(), "height should have no value!"); outputDims_ = FrameDims(streamMetadata.width.value(), streamMetadata.height.value()); for (auto& transform : transforms) { diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 7665d6c9f..711ef331d 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -13,7 +13,6 @@ namespace facebook::torchcodec { enum ColorConversionLibrary { - // TODO: Add an AUTO option later. // Use the libavfilter library for color conversion. FILTERGRAPH, // Use the libswscale library for color conversion. 
@@ -28,12 +27,17 @@ struct VideoStreamOptions { // utilize all cores. If not set, it will be the default FFMPEG behavior for // the given codec. std::optional ffmpegThreadCount; + // Currently the dimension order can be either NHWC or NCHW. // H=height, W=width, C=channel. std::string dimensionOrder = "NCHW"; - std::optional width; - std::optional height; - std::optional colorConversionLibrary; + + // By default we have to use filtergraph, as it is more general. We can only + // use swscale when we have met strict requirements. See + // CpuDeviceInterface::initialize() for the logic. + ColorConversionLibrary colorConversionLibrary = + ColorConversionLibrary::FILTERGRAPH; + // By default we use CPU for decoding for both C++ and python users. torch::Device device = torch::kCPU; From 8e55bd491bf553ba59aa72747d05d8fe88d2e54f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 22 Sep 2025 07:33:42 -0700 Subject: [PATCH 13/35] More cuda error checking --- src/torchcodec/_core/CudaDeviceInterface.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 400027fca..dd984a58e 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -192,6 +192,7 @@ void CudaDeviceInterface::initialize( const AVRational& timeBase, const FrameDims& outputDims) { TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); timeBase_ = timeBase; outputDims_ = outputDims; @@ -222,8 +223,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // first, and do the color conversion there. // // TODO: If we're going to keep this around, we should probably cache it? - auto cpuDevice = torch::Device(torch::kCPU); - auto cpuInterface = createDeviceInterface(cpuDevice); + auto cpuInterface = createDeviceInterface(torch::Device(torch::kCPU)); TORCH_CHECK( cpuInterface != nullptr, "Failed to create CPU device interface"); cpuInterface->initialize( nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); FrameOutput cpuFrameOutput; @@ -294,6 +294,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "The AVFrame's hw_frames_ctx does not have a device_ctx. "); auto cudaDeviceCtx = static_cast(hwFramesCtx->device_ctx->hwctx); + TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null"); + at::cuda::CUDAEvent nvdecDoneEvent; at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad.
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex); From a032cb7db1f2b0cd87c8c13c648c2ab4fa89065a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 22 Sep 2025 20:36:35 -0700 Subject: [PATCH 14/35] Don't pass pre-allocated GPU tensor to CPU decoding --- src/torchcodec/_core/CpuDeviceInterface.cpp | 3 +++ src/torchcodec/_core/CudaDeviceInterface.cpp | 13 +++++++++++-- test/test_ops.py | 4 ++++ 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index a81dbe2c9..63ed4bce5 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -169,15 +169,18 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( outputDims_.width, outputDims_.height); + outputTensor = preAllocatedOutputTensor.value_or( allocateEmptyHWCTensor(outputDims_, torch::kCPU)); + if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { createSwsContext(swsFrameContext, avFrame->colorspace); prevSwsFrameContext_ = swsFrameContext; } int resultHeight = convertAVFrameToTensorUsingSwScale(avFrame, outputTensor); + // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index dd984a58e..04614a841 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -231,9 +231,18 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( FrameOutput cpuFrameOutput; cpuInterface->convertAVFrameToFrameOutput( - avFrame, cpuFrameOutput, preAllocatedOutputTensor); + avFrame, cpuFrameOutput); + + // TODO: explain that the pre-allocated tensor is on the GPU, but we need + // to do the decoding on the CPU, and we can't pass the pre-allocated tensor + // to do it. BUT WHY did it work before? 
+ if (preAllocatedOutputTensor.has_value()) { + preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data); + frameOutput.data = preAllocatedOutputTensor.value(); + } else { + frameOutput.data = cpuFrameOutput.data.to(device_); + } - frameOutput.data = cpuFrameOutput.data.to(device_); return; } diff --git a/test/test_ops.py b/test/test_ops.py index 3f704dc80..9fc323466 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -146,9 +146,13 @@ def test_get_frame_with_info_at_index(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_at_indices(self, device): + print("test_get_frames_at_indices") decoder = create_from_file(str(NASA_VIDEO.path)) + print("decoder created") add_video_stream(decoder, device=device) + print("stream added") frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180]) + print("frames retrieved") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) reference_frame180 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS From 9aa85c2032f8872afbab19a08c5fc37c7a115935 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 23 Sep 2025 03:35:39 -0700 Subject: [PATCH 15/35] Lint --- src/torchcodec/_core/CpuDeviceInterface.cpp | 2 -- src/torchcodec/_core/CudaDeviceInterface.cpp | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 63ed4bce5..cd6407ef8 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -169,11 +169,9 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( outputDims_.width, outputDims_.height); - outputTensor = preAllocatedOutputTensor.value_or( allocateEmptyHWCTensor(outputDims_, torch::kCPU)); - if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { createSwsContext(swsFrameContext, avFrame->colorspace); prevSwsFrameContext_ = swsFrameContext; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 04614a841..e5a8d2929 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -230,8 +230,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); FrameOutput cpuFrameOutput; - cpuInterface->convertAVFrameToFrameOutput( - avFrame, cpuFrameOutput); + cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); // TODO: explain that the pre-allocated tensor is on the GPU, but we need // to do the decoding on the CPU, and we can't pass the pre-allocated tensor From 4e6c6f81efb18aa06c01ecdbfbc38de5feec069f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 23 Sep 2025 06:48:28 -0700 Subject: [PATCH 16/35] Remove prints from test --- test/test_ops.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 9fc323466..3f704dc80 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -146,13 +146,9 @@ def test_get_frame_with_info_at_index(self, device): @pytest.mark.parametrize("device", all_supported_devices()) def test_get_frames_at_indices(self, device): - print("test_get_frames_at_indices") decoder = create_from_file(str(NASA_VIDEO.path)) - print("decoder created") add_video_stream(decoder, device=device) - print("stream added") frames0and180, *_ = get_frames_at_indices(decoder, frame_indices=[0, 180]) - print("frames retrieved") reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) 
reference_frame180 = NASA_VIDEO.get_frame_data_by_index( INDEX_OF_FRAME_AT_6_SECONDS From 3737099c26dc8d2a98dea429bbb02ea8a6d4b5f5 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 23 Sep 2025 19:04:46 -0700 Subject: [PATCH 17/35] Lint --- src/torchcodec/_core/CpuDeviceInterface.cpp | 3 +-- src/torchcodec/_core/CudaDeviceInterface.cpp | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 196b1fe14..cd0dea81b 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -157,8 +157,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // This is an early-return optimization: if the format is already what we // need, and the dimensions are also what we need, we don't need to call // swscale or filtergraph. We can just convert the AVFrame to a tensor. - if (frameFormat == AV_PIX_FMT_RGB24 && - avFrame->width == outputDims_.width && + if (frameFormat == AV_PIX_FMT_RGB24 && avFrame->width == outputDims_.width && avFrame->height == outputDims_.height) { outputTensor = toTensor(avFrame); if (preAllocatedOutputTensor.has_value()) { diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 31340e11c..698b957e9 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -291,8 +291,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // Finally, we want to reuse the filter graph as much as possible for // performance reasons. if (!filterGraph_ || *filtersContext_ != *newFiltersContext) { - filterGraph_ = - std::make_unique(*newFiltersContext, videoStreamOptions_); + filterGraph_ = std::make_unique( + *newFiltersContext, videoStreamOptions_); filtersContext_ = std::move(newFiltersContext); } avFilteredFrame = filterGraph_->convert(avInputFrame); From 139e4ff51fd798ddf0d5f764005a5bf45ae46554 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 24 Sep 2025 08:52:06 -0700 Subject: [PATCH 18/35] Refactor NV12 stuff; test if we need format for FFmpeg 4 --- src/torchcodec/_core/CudaDeviceInterface.cpp | 86 ++++++++++---------- src/torchcodec/_core/CudaDeviceInterface.h | 15 ++-- 2 files changed, 52 insertions(+), 49 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 698b957e9..dbcae7090 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -208,19 +208,17 @@ void CudaDeviceInterface::initialize( codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); } -std::unique_ptr CudaDeviceInterface::initializeFiltersContext( - const UniqueAVFrame& avFrame) { +UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( + UniqueAVFrame& avFrame) { // We need FFmpeg filters to handle those conversion cases which are not // directly implemented in CUDA or CPU device interface (in case of a // fallback). - enum AVPixelFormat frameFormat = - static_cast(avFrame->format); // Input frame is on CPU, we will just pass it to CPU device interface, so // skipping filters context as CPU device interface will handle everything for // us. if (avFrame->format != AV_PIX_FMT_CUDA) { - return nullptr; + return std::move(avFrame); } TORCH_CHECK( @@ -234,7 +232,7 @@ std::unique_ptr CudaDeviceInterface::initializeFiltersContext( // NV12 conversion is implemented directly with NPP, no need for filters. 
if (actualFormat == AV_PIX_FMT_NV12) { - return nullptr; + return std::move(avFrame); } AVPixelFormat outputFormat; @@ -248,13 +246,15 @@ std::unique_ptr CudaDeviceInterface::initializeFiltersContext( // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95 outputFormat = AV_PIX_FMT_RGB24; + /* auto actualFormatName = av_get_pix_fmt_name(actualFormat); TORCH_CHECK( actualFormatName != nullptr, "The actual format of a frame is unknown to FFmpeg. " "That's unexpected, please report this to the TorchCodec repo."); - filters << "hwdownload,format=" << actualFormatName; + */ + filters << "hwdownload"; } else { // Actual output color format will be set via filter options outputFormat = AV_PIX_FMT_CUDA; @@ -262,7 +262,10 @@ std::unique_ptr CudaDeviceInterface::initializeFiltersContext( filters << "scale_cuda=format=nv12:interp_algo=bilinear"; } - return std::make_unique( + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); + + auto newContext = std::make_unique( avFrame->width, avFrame->height, frameFormat, @@ -273,46 +276,43 @@ std::unique_ptr CudaDeviceInterface::initializeFiltersContext( filters.str(), timeBase_, av_buffer_ref(avFrame->hw_frames_ctx)); + + // We need to compare the current filter context with our previous filter + // context. If they are different, then we need to re-create a filter + // graph. We create a filter graph late so that we don't have to depend + // on the unreliable metadata in the header. And we sometimes re-create + // it because it's possible for frame resolution to change mid-stream. + // Finally, we want to reuse the filter graph as much as possible for + // performance reasons. + if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) { + nv12Conversion_ = + std::make_unique(*newContext, videoStreamOptions_); + nv12ConversionContext_ = std::move(newContext); + } + auto filteredAVFrame = nv12Conversion_->convert(avFrame); + + // If this check fails it means the frame wasn't + // reshaped to its expected dimensions by filtergraph. + TORCH_CHECK( + (filteredAVFrame->width == nv12ConversionContext_->outputWidth) && + (filteredAVFrame->height == nv12ConversionContext_->outputHeight), + "Expected frame from filter graph of ", + nv12ConversionContext_->outputWidth, + "x", + nv12ConversionContext_->outputHeight, + ", got ", + filteredAVFrame->width, + "x", + filteredAVFrame->height); + + return filteredAVFrame; } void CudaDeviceInterface::convertAVFrameToFrameOutput( - UniqueAVFrame& avInputFrame, + UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - std::unique_ptr newFiltersContext = - initializeFiltersContext(avInputFrame); - UniqueAVFrame avFilteredFrame; - if (newFiltersContext) { - // We need to compare the current filter context with our previous filter - // context. If they are different, then we need to re-create a filter - // graph. We create a filter graph late so that we don't have to depend - // on the unreliable metadata in the header. And we sometimes re-create - // it because it's possible for frame resolution to change mid-stream. - // Finally, we want to reuse the filter graph as much as possible for - // performance reasons. 
- if (!filterGraph_ || *filtersContext_ != *newFiltersContext) { - filterGraph_ = std::make_unique( - *newFiltersContext, videoStreamOptions_); - filtersContext_ = std::move(newFiltersContext); - } - avFilteredFrame = filterGraph_->convert(avInputFrame); - - // If this check fails it means the frame wasn't - // reshaped to its expected dimensions by filtergraph. - TORCH_CHECK( - (avFilteredFrame->width == filtersContext_->outputWidth) && - (avFilteredFrame->height == filtersContext_->outputHeight), - "Expected frame from filter graph of ", - filtersContext_->outputWidth, - "x", - filtersContext_->outputHeight, - ", got ", - avFilteredFrame->width, - "x", - avFilteredFrame->height); - } - - UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame; + avFrame = maybeConvertAVFrameToNV12(avFrame); // The filtered frame might be on CPU if CPU fallback has happenned on filter // graph level. For example, that's how we handle color format conversion diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index e04499394..c41515c1e 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -35,8 +35,10 @@ class CudaDeviceInterface : public DeviceInterface { std::nullopt) override; private: - std::unique_ptr initializeFiltersContext( - const UniqueAVFrame& avFrame); + // Our CUDA decoding code assumes NV12 format. In order to handle other + // kinds of input, we need to convert them to NV12. Our current implementation + // does this using filtergraph. + UniqueAVFrame maybeConvertAVFrameToNV12(UniqueAVFrame& avFrame); VideoStreamOptions videoStreamOptions_; AVRational timeBase_; FrameDims outputDims_; UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; - // Current filter context. Used to know whether a new FilterGraph - // should be created to process the next frame. - std::unique_ptr filtersContext_; - std::unique_ptr filterGraph_; + + // This filtergraph instance is only used for NV12 format conversion in + // maybeConvertAVFrameToNV12(). 
+ std::unique_ptr nv12ConversionContext_; + std::unique_ptr nv12Conversion_; }; } // namespace facebook::torchcodec From 6668f4bbea5389415b06f17167b41849c718d7e1 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 24 Sep 2025 11:40:34 -0700 Subject: [PATCH 19/35] Specify hwdownload format as rgb24 --- src/torchcodec/_core/CudaDeviceInterface.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index dbcae7090..17e2daf0f 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -254,7 +254,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( "That's unexpected, please report this to the TorchCodec repo."); */ - filters << "hwdownload"; + filters << "hwdownload,format=rgb24"; } else { // Actual output color format will be set via filter options outputFormat = AV_PIX_FMT_CUDA; From 9f357c75a821fdb7ceefdd0e9eb2c62f7c05f24f Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 24 Sep 2025 12:35:50 -0700 Subject: [PATCH 20/35] Do all nv12 conversions on GPU --- src/torchcodec/_core/CudaDeviceInterface.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 17e2daf0f..92607eb64 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -254,7 +254,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( "That's unexpected, please report this to the TorchCodec repo."); */ - filters << "hwdownload,format=rgb24"; + filters << "hwupload,format=nv12"; } else { // Actual output color format will be set via filter options outputFormat = AV_PIX_FMT_CUDA; From 3dc20b84991abab900e4cc667bf79648d307c41d Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 24 Sep 2025 12:39:07 -0700 Subject: [PATCH 21/35] Wrong output format --- src/torchcodec/_core/CudaDeviceInterface.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 92607eb64..cb5415415 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -244,7 +244,8 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( // n5.0. With the earlier version of ffmpeg we have no choice but use CPU // filters. See: // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95 - outputFormat = AV_PIX_FMT_RGB24; + //outputFormat = AV_PIX_FMT_RGB24; + outputFormat = AV_PIX_FMT_CUDA; /* auto actualFormatName = av_get_pix_fmt_name(actualFormat); From 7f88e60b5d0ac60e1467224e081fffb379e96500 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 24 Sep 2025 13:34:50 -0700 Subject: [PATCH 22/35] Back to RGB24 --- src/torchcodec/_core/CudaDeviceInterface.cpp | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index cb5415415..174ea95b5 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -244,18 +244,15 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( // n5.0. With the earlier version of ffmpeg we have no choice but use CPU // filters. 
See: // https://github.com/FFmpeg/FFmpeg/commit/62dc5df941f5e196164c151691e4274195523e95 - //outputFormat = AV_PIX_FMT_RGB24; - outputFormat = AV_PIX_FMT_CUDA; + outputFormat = AV_PIX_FMT_RGB24; - /* auto actualFormatName = av_get_pix_fmt_name(actualFormat); TORCH_CHECK( actualFormatName != nullptr, "The actual format of a frame is unknown to FFmpeg. " "That's unexpected, please report this to the TorchCodec repo."); - */ - filters << "hwupload,format=nv12"; + filters << "hwdownload,format=" << actualFormatName; } else { // Actual output color format will be set via filter options outputFormat = AV_PIX_FMT_CUDA; From dda26498eef4e162bd31c0798fbce283a8c22fc2 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 25 Sep 2025 06:37:34 -0700 Subject: [PATCH 23/35] CUDA and CPU refactoring regarding NV12. --- src/torchcodec/_core/CpuDeviceInterface.cpp | 17 ----- src/torchcodec/_core/CpuDeviceInterface.h | 4 +- src/torchcodec/_core/CudaDeviceInterface.cpp | 74 +++++++++++++------- 3 files changed, 49 insertions(+), 46 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index cd0dea81b..398b448ef 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -154,23 +154,6 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( enum AVPixelFormat frameFormat = static_cast(avFrame->format); - // This is an early-return optimization: if the format is already what we - // need, and the dimensions are also what we need, we don't need to call - // swscale or filtergraph. We can just convert the AVFrame to a tensor. - if (frameFormat == AV_PIX_FMT_RGB24 && avFrame->width == outputDims_.width && - avFrame->height == outputDims_.height) { - outputTensor = toTensor(avFrame); - if (preAllocatedOutputTensor.has_value()) { - // We have already validated that preAllocatedOutputTensor and - // outputTensor have the same shape. - preAllocatedOutputTensor.value().copy_(outputTensor); - frameOutput.data = preAllocatedOutputTensor.value(); - } else { - frameOutput.data = outputTensor; - } - return; - } - if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) { // We need to compare the current frame context with our previous frame // context. 
If they are different, then we need to re-create our colorspace diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 0eafea7da..d50013285 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -36,13 +36,13 @@ class CpuDeviceInterface : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; + torch::Tensor toTensor(const UniqueAVFrame& avFrame); + private: int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); - torch::Tensor toTensor(const UniqueAVFrame& avFrame); - struct SwsFrameContext { int inputWidth = 0; int inputHeight = 0; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 174ea95b5..b9b77c5c8 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -5,6 +5,7 @@ #include #include "src/torchcodec/_core/Cache.h" +#include "src/torchcodec/_core/CpuDeviceInterface.h" #include "src/torchcodec/_core/CudaDeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" @@ -230,7 +231,7 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( reinterpret_cast(avFrame->hw_frames_ctx->data); AVPixelFormat actualFormat = hwFramesCtx->sw_format; - // NV12 conversion is implemented directly with NPP, no need for filters. + // If the frame is already in NV12 format, we don't need to do anything. if (actualFormat == AV_PIX_FMT_NV12) { return std::move(avFrame); } @@ -310,35 +311,64 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { + if (preAllocatedOutputTensor.has_value()) { + auto shape = preAllocatedOutputTensor.value().sizes(); + TORCH_CHECK( + (shape.size() == 3) && (shape[0] == outputDims_.height) && + (shape[1] == outputDims_.width) && (shape[2] == 3), + "Expected tensor of shape ", + outputDims_.height, + "x", + outputDims_.width, + "x3, got ", + shape); + } + + // All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by + // converting them to NV12. avFrame = maybeConvertAVFrameToNV12(avFrame); - // The filtered frame might be on CPU if CPU fallback has happenned on filter - // graph level. For example, that's how we handle color format conversion - // on FFmpeg 4.4 where scale_cuda did not have this supported implemented yet. if (avFrame->format != AV_PIX_FMT_CUDA) { // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on - // the GPU. In this branch, the frame is on the CPU: this is what NVDEC - // gives us if it wasn't able to decode a frame, for whatever reason. - // Typically that happens if the video's encoder isn't supported by NVDEC. - // Below, we choose to convert the frame's color-space using the CPU - // codepath, and send it back to the GPU at the very end. + // the GPU. In this branch, the frame is on the CPU. There are two possible + // reasons: + // + // 1. During maybeConvertAVFrameToNV12(), we had a non-NV12 format frame + // and we're on FFmpeg 4.4 or earlier. In such cases, we had to use CPU + // filters and we just converted the frame to RGB24. + // 2. This is what NVDEC gave us if it wasn't able to decode a frame, for + // whatever reason. Typically that happens if the video's encoder isn't + // supported by NVDEC. 
// - // TODO: A possibly better solution would be to send the frame to the GPU - // first, and do the color conversion there. + // In both cases, we have a frame on the CPU, and we need a CPU device to + // handle it. We send the frame back to the CUDA device when we're done. // - // TODO: If we're going to keep this around, we should probably cache it? - auto cpuInterface = createDeviceInterface(torch::Device(torch::kCPU)); + // TODO: Perhaps we should cache cpuInterface? + auto cpuInterface = std::make_unique(torch::kCPU); TORCH_CHECK( cpuInterface != nullptr, "Failed to create CPU device interface"); cpuInterface->initialize( nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); + FrameOutput cpuFrameOutput; - cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); - // TODO: explain that the pre-allocated tensor is on the GPU, but we need - // to do the decoding on the CPU, and we can't pass the pre-allocated tensor - // to do it. BUT WHY did it work before? + if (frameFormat == AV_PIX_FMT_RGB24 && + avFrame->width == outputDims_.width && + avFrame->height == outputDims_.height) { + // Reason 1 above. The frame is already in the format and dimensions that + // we need, we just need to convert it to a tensor. + cpuFrameOutput.data = cpuInterface->toTensor(avFrame); + } else { + // Reason 2 above. We need to do a full conversion. + cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); + } + + // Finally, we need to send the frame back to the GPU. Note that the + // pre-allocated tensor is on the GPU, so we can't send that to the CPU + // device interface. We copy it over here. if (preAllocatedOutputTensor.has_value()) { preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data); frameOutput.data = preAllocatedOutputTensor.value(); @@ -372,16 +402,6 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( torch::Tensor& dst = frameOutput.data; if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); - auto shape = dst.sizes(); - TORCH_CHECK( - (shape.size() == 3) && (shape[0] == outputDims_.height) && - (shape[1] == outputDims_.width) && (shape[2] == 3), - "Expected tensor of shape ", - outputDims_.height, - "x", - outputDims_.width, - "x3, got ", - shape); } else { dst = allocateEmptyHWCTensor(outputDims_, device_); } From fc5468ec5ba753e2f42043416662142c2d51fdc4 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 25 Sep 2025 12:16:59 -0700 Subject: [PATCH 24/35] Test to ensure transforms are not used with non-CPU --- src/torchcodec/_core/SingleStreamDecoder.cpp | 4 ++++ test/test_ops.py | 9 +++++++++ 2 files changed, 13 insertions(+) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 2de51fca0..dbef7a587 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -466,6 +466,10 @@ void SingleStreamDecoder::addVideoStream( std::vector& transforms, const VideoStreamOptions& videoStreamOptions, std::optional customFrameMappings) { + TORCH_CHECK( + transforms.empty() || videoStreamOptions.device == torch::kCPU, + "Transforms are only supported for CPU devices."); + addStream( streamIndex, AVMEDIA_TYPE_VIDEO, diff --git a/test/test_ops.py b/test/test_ops.py index f8d9e2b3c..eb6ed31bc 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -614,6 +614,15 @@ def test_color_conversion_library_with_scaling( assert_frames_equal(filtergraph_frame0, 
swscale_frame0) assert filtergraph_frame0.shape == (3, target_height, target_width) + @needs_cuda + def test_scaling_on_cuda_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + with pytest.raises( + RuntimeError, + match="Transforms are only supported for CPU devices.", + ): + add_video_stream(decoder, device="cuda", width=100, height=100) + @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW")) @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale")) def test_color_conversion_library_with_dimension_order( From 48e3ea3eedf9f453c1ca77d9124cbe317f64ec6a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 26 Sep 2025 08:19:56 -0700 Subject: [PATCH 25/35] Better comments; refactor toTensor --- src/torchcodec/_core/CpuDeviceInterface.cpp | 34 +++++--------------- src/torchcodec/_core/CpuDeviceInterface.h | 2 -- src/torchcodec/_core/CudaDeviceInterface.cpp | 5 ++- src/torchcodec/_core/DeviceInterface.cpp | 15 +++++++++ src/torchcodec/_core/DeviceInterface.h | 2 ++ 5 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 398b448ef..868438652 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -56,14 +56,9 @@ void CpuDeviceInterface::initialize( timeBase_ = timeBase; outputDims_ = outputDims; - // TODO: rationalize comment below with new stuff. - // By default, we want to use swscale for color conversion because it is - // faster. However, it has width requirements, so we may need to fall back - // to filtergraph. We also need to respect what was requested from the - // options; we respect the options unconditionally, so it's possible for - // swscale's width requirements to be violated. We don't expose the ability to - // choose color conversion library publicly; we only use this ability - // internally. + // We want to use swscale for color conversion if possible because it is + // faster than filtergraph. The following are the conditions we need to meet + // to use it. // We can only use swscale when we have a single resize transform. Note that // this means swscale will not support the case of having several, @@ -76,12 +71,14 @@ void CpuDeviceInterface::initialize( // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0; + // Note that we do not expose this capability in the public API, only through + // the core API. bool userRequestedSwScale = videoStreamOptions_.colorConversionLibrary == ColorConversionLibrary::SWSCALE; // Note that we treat the transform limitation differently from the width // limitation. That is, we consider the transforms being compatible with - // sws_scale as a hard requirement. If the transforms are not compatiable, + // swscale as a hard requirement. If the transforms are not compatiable, // then we will end up not applying the transforms, and that is wrong. // // The width requirement, however, is a soft requirement. Even if we don't @@ -94,7 +91,7 @@ void CpuDeviceInterface::initialize( colorConversionLibrary_ = ColorConversionLibrary::SWSCALE; // We established above that if the transforms are swscale compatible and - // non-empty, then they must have only one transforms, and that transform is + // non-empty, then they must have only one transform, and that transform is // ResizeTransform. 
if (!transforms.empty()) { auto resize = dynamic_cast(transforms[0].get()); @@ -207,7 +204,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( std::make_unique(filtersContext, videoStreamOptions_); prevFiltersContext_ = std::move(filtersContext); } - outputTensor = toTensor(filterGraphContext_->convert(avFrame)); + outputTensor = rgbAVFrameToTensor(filterGraphContext_->convert(avFrame)); // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. @@ -256,21 +253,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale( return resultHeight; } -torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) { - TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24); - - int height = avFrame->height; - int width = avFrame->width; - std::vector shape = {height, width, 3}; - std::vector strides = {avFrame->linesize[0], 3, 1}; - AVFrame* avFrameClone = av_frame_clone(avFrame.get()); - auto deleter = [avFrameClone](void*) { - UniqueAVFrame avFrameToDelete(avFrameClone); - }; - return torch::from_blob( - avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8}); -} - void CpuDeviceInterface::createSwsContext( const SwsFrameContext& swsFrameContext, const enum AVColorSpace colorspace) { diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index d50013285..89b307cc9 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -36,8 +36,6 @@ class CpuDeviceInterface : public DeviceInterface { std::optional preAllocatedOutputTensor = std::nullopt) override; - torch::Tensor toTensor(const UniqueAVFrame& avFrame); - private: int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index b9b77c5c8..e8c78a9e9 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -5,7 +5,6 @@ #include #include "src/torchcodec/_core/Cache.h" -#include "src/torchcodec/_core/CpuDeviceInterface.h" #include "src/torchcodec/_core/CudaDeviceInterface.h" #include "src/torchcodec/_core/FFMPEGCommon.h" @@ -344,7 +343,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // handle it. We send the frame back to the CUDA device when we're done. // // TODO: Perhaps we should cache cpuInterface? - auto cpuInterface = std::make_unique(torch::kCPU); + auto cpuInterface = createDeviceInterface(torch::kCPU); TORCH_CHECK( cpuInterface != nullptr, "Failed to create CPU device interface"); cpuInterface->initialize( @@ -360,7 +359,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( avFrame->height == outputDims_.height) { // Reason 1 above. The frame is already in the format and dimensions that // we need, we just need to convert it to a tensor. - cpuFrameOutput.data = cpuInterface->toTensor(avFrame); + cpuFrameOutput.data = rgbAVFrameToTensor(avFrame); } else { // Reason 2 above. We need to do a full conversion. 
cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index 70b00fb62..356f29c64 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -76,4 +76,19 @@ std::unique_ptr createDeviceInterface( return std::unique_ptr(deviceMap[deviceType](device)); } +torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame) { + TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24); + + int height = avFrame->height; + int width = avFrame->width; + std::vector shape = {height, width, 3}; + std::vector strides = {avFrame->linesize[0], 3, 1}; + AVFrame* avFrameClone = av_frame_clone(avFrame.get()); + auto deleter = [avFrameClone](void*) { + UniqueAVFrame avFrameToDelete(avFrameClone); + }; + return torch::from_blob( + avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8}); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index a28e03ddc..219852616 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -60,4 +60,6 @@ torch::Device createTorchDevice(const std::string device); std::unique_ptr createDeviceInterface( const torch::Device& device); +torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame); + } // namespace facebook::torchcodec From 7813005b8e3bab12ef591a4da55a5f79a922ce09 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 26 Sep 2025 13:53:20 -0700 Subject: [PATCH 26/35] Deal with variable resolution and lying metadata - again --- src/torchcodec/_core/CpuDeviceInterface.cpp | 159 +++++++++++-------- src/torchcodec/_core/CpuDeviceInterface.h | 56 ++++--- src/torchcodec/_core/CudaDeviceInterface.cpp | 39 +++-- src/torchcodec/_core/CudaDeviceInterface.h | 8 +- src/torchcodec/_core/DeviceInterface.h | 3 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 25 ++- src/torchcodec/_core/SingleStreamDecoder.h | 17 +- 7 files changed, 189 insertions(+), 118 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 868438652..69ef59005 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -51,31 +51,70 @@ void CpuDeviceInterface::initialize( const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& outputDims) { + [[maybe_unused]] const FrameDims& metadataDims, + const std::optional& resizedOutputDims) { videoStreamOptions_ = videoStreamOptions; timeBase_ = timeBase; - outputDims_ = outputDims; - - // We want to use swscale for color conversion if possible because it is - // faster than filtergraph. The following are the conditions we need to meet - // to use it. + resizedOutputDims_ = resizedOutputDims; // We can only use swscale when we have a single resize transform. Note that // this means swscale will not support the case of having several, // back-to-base resizes. There's no strong reason to even do that, but if // someone does, it's more correct to implement that with filtergraph. - bool areTransformsSwScaleCompatible = transforms.empty() || + // + // We calculate this value during initilization but we don't refer to it until + // getColorConversionLibrary() is called. Calculating this value during + // initialization saves us from having to save all of the transforms. 
+ areTransformsSwScaleCompatible_ = transforms.empty() || (transforms.size() == 1 && transforms[0]->isResize()); - // swscale requires widths to be multiples of 32: - // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements - bool isWidthSwScaleCompatible = (outputDims_.width % 32) == 0; - // Note that we do not expose this capability in the public API, only through // the core API. - bool userRequestedSwScale = videoStreamOptions_.colorConversionLibrary == + // + // Same as above, we calculate this value during initialization and refer to + // it in getColorConversionLibrary(). + userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary == ColorConversionLibrary::SWSCALE; + // We can only use swscale when we have a single resize transform. Note that + // we actually decide on whether or not to actually use swscale at the last + // possible moment, when we actually convert the frame. This is because we + // need to know the actual frame dimensions. + if (transforms.size() == 1 && transforms[0]->isResize()) { + auto resize = dynamic_cast(transforms[0].get()); + TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!") + swsFlags_ = resize->getSwsFlags(); + } + + // If we have any transforms, replace filters_ with the filter strings from + // the transforms. As noted above, we decide between swscale and filtergraph + // when we actually decode a frame. + std::stringstream filters; + bool first = true; + for (const auto& transform : transforms) { + if (!first) { + filters << ","; + } + filters << transform->getFilterGraphCpu(); + first = false; + } + if (!transforms.empty()) { + filters_ = filters.str(); + } + + initialized_ = true; +} + +ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( + const FrameDims& outputDims) { + // swscale requires widths to be multiples of 32: + // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements + bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0; + + // We want to use swscale for color conversion if possible because it is + // faster than filtergraph. The following are the conditions we need to meet + // to use it. + // // Note that we treat the transform limitation differently from the width // limitation. That is, we consider the transforms being compatible with // swscale as a hard requirement. If the transforms are not compatiable, @@ -86,38 +125,12 @@ void CpuDeviceInterface::initialize( // behavior. Since we don't expose the ability to choose swscale or // filtergraph in our public API, this is probably okay. It's also the only // way that we can be certain we are testing one versus the other. - if (areTransformsSwScaleCompatible && - (userRequestedSwScale || isWidthSwScaleCompatible)) { - colorConversionLibrary_ = ColorConversionLibrary::SWSCALE; - - // We established above that if the transforms are swscale compatible and - // non-empty, then they must have only one transform, and that transform is - // ResizeTransform. 
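Editor's note, an illustration of the initialization above rather than patch text. It shows what the cached fields end up holding for a single bilinear resize and how getColorConversionLibrary() then chooses a library per frame; the 480x270 target, the 300-wide counter-example, and the sketch() wrapper are hypothetical.

    // A minimal sketch, assuming Transform.h as defined at this point in the
    // series (a later patch switches the constructor to FrameDims(height, width)).
    #include "src/torchcodec/_core/Transform.h"

    using namespace facebook::torchcodec;

    void sketch() {
      ResizeTransform resize(480, 270); // hypothetical target size
      // resize.getFilterGraphCpu() == "scale=480:270:sws_flags=bilinear"
      // resize.getSwsFlags()       == SWS_BILINEAR
      // resize.isResize()          == true, so a transform list holding only
      // this object keeps areTransformsSwScaleCompatible_ == true.
      //
      // Per-frame choice in getColorConversionLibrary():
      //   output width 480 -> 480 % 32 == 0 -> SWSCALE
      //   output width 300 -> 300 % 32 != 0 -> FILTERGRAPH, unless swscale was
      //   explicitly requested through the core API (the soft requirement).
    }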
- if (!transforms.empty()) { - auto resize = dynamic_cast(transforms[0].get()); - TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!") - swsFlags_ = resize->getSwsFlags(); - } + if (areTransformsSwScaleCompatible_ && + (userRequestedSwScale_ || isWidthSwScaleCompatible)) { + return ColorConversionLibrary::SWSCALE; } else { - colorConversionLibrary_ = ColorConversionLibrary::FILTERGRAPH; - - // If we have any transforms, replace filters_ with the filter strings from - // the transforms. - std::stringstream filters; - bool first = true; - for (const auto& transform : transforms) { - if (!first) { - filters << ","; - } - filters << transform->getFilterGraphCpu(); - first = false; - } - if (!transforms.empty()) { - filters_ = filters.str(); - } + return ColorConversionLibrary::FILTERGRAPH; } - - initialized_ = true; } // Note [preAllocatedOutputTensor with swscale and filtergraph]: @@ -134,24 +147,42 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized."); + + // Note that we ignore the dimensions from the metadata; we don't even bother + // storing them. The resized dimensions take priority. If we don't have any, + // then we use the dimensions from the actual decoded frame. We use the actual + // decoded frame and not the metadata for two reasons: + // + // 1. Metadata may be wrong. If we access to more accurate information, we + // should use it. + // 2. Video streams can have variable resolution. This fact is not captured + // in the stream metadata. + // + // Both cases cause problems for our batch APIs, as we allocate + // FrameBatchOutputs based on the the stream metadata. But single-frame APIs + // can still work in such situations, so they should. + auto outputDims = + resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height)); + if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == outputDims_.height) && - (shape[1] == outputDims_.width) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == outputDims.height) && + (shape[1] == outputDims.width) && (shape[2] == 3), "Expected pre-allocated tensor of shape ", - outputDims_.height, + outputDims.height, "x", - outputDims_.width, + outputDims.width, "x3, got ", shape); } + auto colorConversionLibrary = getColorConversionLibrary(outputDims); torch::Tensor outputTensor; enum AVPixelFormat frameFormat = static_cast(avFrame->format); - if (colorConversionLibrary_ == ColorConversionLibrary::SWSCALE) { + if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) { // We need to compare the current frame context with our previous frame // context. If they are different, then we need to re-create our colorspace // conversion objects. We create our colorspace conversion objects late so @@ -163,11 +194,11 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( avFrame->width, avFrame->height, frameFormat, - outputDims_.width, - outputDims_.height); + outputDims.width, + outputDims.height); outputTensor = preAllocatedOutputTensor.value_or( - allocateEmptyHWCTensor(outputDims_, torch::kCPU)); + allocateEmptyHWCTensor(outputDims, torch::kCPU)); if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { createSwsContext(swsFrameContext, avFrame->colorspace); @@ -180,42 +211,42 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // the expected height. 
// TODO: Can we do the same check for width? TORCH_CHECK( - resultHeight == outputDims_.height, - "resultHeight != outputDims_.height: ", + resultHeight == outputDims.height, + "resultHeight != outputDims.height: ", resultHeight, " != ", - outputDims_.height); + outputDims.height); frameOutput.data = outputTensor; - } else if (colorConversionLibrary_ == ColorConversionLibrary::FILTERGRAPH) { + } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { FiltersContext filtersContext( avFrame->width, avFrame->height, frameFormat, avFrame->sample_aspect_ratio, - outputDims_.width, - outputDims_.height, + outputDims.width, + outputDims.height, AV_PIX_FMT_RGB24, filters_, timeBase_); - if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) { - filterGraphContext_ = + if (!filterGraph_ || prevFiltersContext_ != filtersContext) { + filterGraph_ = std::make_unique(filtersContext, videoStreamOptions_); prevFiltersContext_ = std::move(filtersContext); } - outputTensor = rgbAVFrameToTensor(filterGraphContext_->convert(avFrame)); + outputTensor = rgbAVFrameToTensor(filterGraph_->convert(avFrame)); // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. auto shape = outputTensor.sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == outputDims_.height) && - (shape[1] == outputDims_.width) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == outputDims.height) && + (shape[1] == outputDims.width) && (shape[2] == 3), "Expected output tensor of shape ", - outputDims_.height, + outputDims.height, "x", - outputDims_.width, + outputDims.width, "x3, got ", shape); @@ -231,7 +262,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( TORCH_CHECK( false, "Invalid color conversion library: ", - static_cast(colorConversionLibrary_)); + static_cast(colorConversionLibrary)); } } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 89b307cc9..6f69bfd7f 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -28,7 +28,8 @@ class CpuDeviceInterface : public DeviceInterface { const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& outputDims) override; + [[maybe_unused]] const FrameDims& metadataDims, + const std::optional& resizedOutputDims) override; void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, @@ -41,6 +42,9 @@ class CpuDeviceInterface : public DeviceInterface { const UniqueAVFrame& avFrame, torch::Tensor& outputTensor); + ColorConversionLibrary getColorConversionLibrary( + const FrameDims& inputFrameDims); + struct SwsFrameContext { int inputWidth = 0; int inputHeight = 0; @@ -64,28 +68,44 @@ class CpuDeviceInterface : public DeviceInterface { const enum AVColorSpace colorspace); VideoStreamOptions videoStreamOptions_; - ColorConversionLibrary colorConversionLibrary_; AVRational timeBase_; - FrameDims outputDims_; - - // If we use swscale for resizing, the flags control the resizing algorithm. - // We default to bilinear. Users can override this with a ResizeTransform. - int swsFlags_ = SWS_BILINEAR; + std::optional resizedOutputDims_; + + // Color-conversion objects. Only one of filterGraph_ and swsContext_ should + // be non-null. Which one we use is controlled by colorConversionLibrary_. + // + // Creating both filterGraph_ and swsContext_ is relatively expensive, so we + // reuse them across frames. 
However, it is possbile that subsequent frames + // are different enough (change in dimensions) that we can't reuse the color + // conversion object. We store the relevant frame context from the frame used + // to create the object last time. We always compare the current frame's info + // against the previous one to determine if we need to recreate the color + // conversion object. + // + // TODO: The names of these fields is confusing, as the actual color + // conversion object for Sws has "context" in the name, and we use + // "context" for the structs we store to know if we need to recreate a + // color conversion object. We should clean that up. + std::unique_ptr filterGraph_; + FiltersContext prevFiltersContext_; + UniqueSwsContext swsContext_; + SwsFrameContext prevSwsFrameContext_; - // The copy filter just copies the input to the output. Computationally, it - // should be a no-op. If we get no user-provided transforms, we will use the - // copy filter. + // The filter we supply to filterGraph_, if it is used. The copy filter just + // copies the input to the output. Computationally, it should be a no-op. If + // we get no user-provided transforms, we will use the copy filter. Otherwise, + // we will construct the string from the transforms. std::string filters_ = "copy"; - // color-conversion fields. Only one of FilterGraphContext and - // UniqueSwsContext should be non-null. - std::unique_ptr filterGraphContext_; - UniqueSwsContext swsContext_; + // The flags we supply to swsContext_, if it used. The flags control the + // resizing algorithm. We default to bilinear. Users can override this with a + // ResizeTransform. + int swsFlags_ = SWS_BILINEAR; - // Used to know whether a new FilterGraphContext or UniqueSwsContext should - // be created before decoding a new frame. - SwsFrameContext prevSwsFrameContext_; - FiltersContext prevFiltersContext_; + // Values set during initialization and referred to in + // getColorConversionLibrary(). + bool areTransformsSwScaleCompatible_; + bool userRequestedSwScale_; bool initialized_ = false; }; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index e8c78a9e9..1b64a9d98 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -190,13 +190,14 @@ void CudaDeviceInterface::initialize( const VideoStreamOptions& videoStreamOptions, [[maybe_unused]] const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& outputDims) { + const FrameDims& metadataDims, + [[maybe_unused]] const std::optional& resizedOutputDims) { TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); videoStreamOptions_ = videoStreamOptions; timeBase_ = timeBase; - outputDims_ = outputDims; + metadataDims_ = metadataDims; // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. @@ -268,20 +269,13 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( avFrame->height, frameFormat, avFrame->sample_aspect_ratio, - outputDims_.width, - outputDims_.height, + metadataDims_.width, + metadataDims_.height, outputFormat, filters.str(), timeBase_, av_buffer_ref(avFrame->hw_frames_ctx)); - // We need to compare the current filter context with our previous filter - // context. If they are different, then we need to re-create a filter - // graph. 
We create a filter graph late so that we don't have to depend - // on the unreliable metadata in the header. And we sometimes re-create - // it because it's possible for frame resolution to change mid-stream. - // Finally, we want to reuse the filter graph as much as possible for - // performance reasons. if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) { nv12Conversion_ = std::make_unique(*newContext, videoStreamOptions_); @@ -313,12 +307,12 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == outputDims_.height) && - (shape[1] == outputDims_.width) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == metadataDims_.height) && + (shape[1] == metadataDims_.width) && (shape[2] == 3), "Expected tensor of shape ", - outputDims_.height, + metadataDims_.height, "x", - outputDims_.width, + metadataDims_.width, "x3, got ", shape); } @@ -347,7 +341,12 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( TORCH_CHECK( cpuInterface != nullptr, "Failed to create CPU device interface"); cpuInterface->initialize( - nullptr, VideoStreamOptions(), {}, timeBase_, outputDims_); + nullptr, + VideoStreamOptions(), + {}, + timeBase_, + metadataDims_, + std::nullopt); enum AVPixelFormat frameFormat = static_cast(avFrame->format); @@ -355,8 +354,8 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( FrameOutput cpuFrameOutput; if (frameFormat == AV_PIX_FMT_RGB24 && - avFrame->width == outputDims_.width && - avFrame->height == outputDims_.height) { + avFrame->width == metadataDims_.width && + avFrame->height == metadataDims_.height) { // Reason 1 above. The frame is already in the format and dimensions that // we need, we just need to convert it to a tensor. cpuFrameOutput.data = rgbAVFrameToTensor(avFrame); @@ -402,7 +401,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); } else { - dst = allocateEmptyHWCTensor(outputDims_, device_); + dst = allocateEmptyHWCTensor(metadataDims_, device_); } torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_); @@ -441,7 +440,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "cudaStreamGetFlags failed: ", cudaGetErrorString(err)); - NppiSize oSizeROI = {outputDims_.width, outputDims_.height}; + NppiSize oSizeROI = {metadataDims_.width, metadataDims_.height}; Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]}; NppStatus status; diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index c41515c1e..1cc5d2044 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -26,7 +26,9 @@ class CudaDeviceInterface : public DeviceInterface { [[maybe_unused]] const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& outputDims) override; + const FrameDims& metadataOutputDims, + [[maybe_unused]] const std::optional& resizedOutputDims) + override; void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, @@ -36,13 +38,13 @@ class CudaDeviceInterface : public DeviceInterface { private: // Our CUDA decoding code assumes NV12 format. In order to handle other - // kindsof input, we need to convert them to NV12. Our current implementation + // kinds of input, we need to convert them to NV12. Our current implementation // does this using filtergraph. 
UniqueAVFrame maybeConvertAVFrameToNV12(UniqueAVFrame& avFrame); VideoStreamOptions videoStreamOptions_; AVRational timeBase_; - FrameDims outputDims_; + FrameDims metadataDims_; UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 219852616..275b2bea0 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -37,7 +37,8 @@ class DeviceInterface { const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& outputDims) = 0; + const FrameDims& metadataDims, + const std::optional& resizedOutputDims) = 0; virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index dbef7a587..9bfae18b9 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -468,7 +468,7 @@ void SingleStreamDecoder::addVideoStream( std::optional customFrameMappings) { TORCH_CHECK( transforms.empty() || videoStreamOptions.device == torch::kCPU, - "Transforms are only supported for CPU devices."); + " Transforms are only supported for CPU devices."); addStream( streamIndex, @@ -503,12 +503,12 @@ void SingleStreamDecoder::addVideoStream( streamIndex, customFrameMappings.value()); } - outputDims_ = + metadataDims_ = FrameDims(streamMetadata.width.value(), streamMetadata.height.value()); for (auto& transform : transforms) { TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); if (transform->getOutputFrameDims().has_value()) { - outputDims_ = transform->getOutputFrameDims().value(); + resizedOutputDims_ = transform->getOutputFrameDims().value(); } transforms_.push_back(std::unique_ptr(transform)); } @@ -521,7 +521,8 @@ void SingleStreamDecoder::addVideoStream( videoStreamOptions, transforms_, streamInfo.timeBase, - outputDims_); + metadataDims_, + resizedOutputDims_); } void SingleStreamDecoder::addAudioStream( @@ -632,7 +633,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - frameIndices.size(), outputDims_, videoStreamOptions.device); + frameIndices.size(), + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); auto previousIndexInVideo = -1; for (size_t f = 0; f < frameIndices.size(); ++f) { @@ -689,7 +692,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - numOutputFrames, outputDims_, videoStreamOptions.device); + numOutputFrames, + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput frameOutput = @@ -816,7 +821,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // below. Hence, we need this special case below. 
if (startSeconds == stopSeconds) { FrameBatchOutput frameBatchOutput( - 0, outputDims_, videoStreamOptions.device); + 0, + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } @@ -861,7 +868,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( int64_t numFrames = stopFrameIndex - startFrameIndex; FrameBatchOutput frameBatchOutput( - numFrames, outputDims_, videoStreamOptions.device); + numFrames, + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput frameOutput = getFrameAtIndexInternal(i, frameBatchOutput.data[f]); diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 2c93accfc..3d33ffaba 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -373,12 +373,21 @@ class SingleStreamDecoder { // the final dimensions of the output frame after ALL transformations. We // figure this out as soon as we receive the transforms. If any of the // transforms change the final output frame dimensions, we store that in - // finalDims_. + // resizedOutputDims_. If resizedOutputDims_ has no value, that means there + // are no transforms that change the output frame dimensions. // - // By default, we set outputDims_ to the dimensions from the metadata from the - // stream. - FrameDims outputDims_; + // The priority order for output frame dimension is: + // + // 1. resizedOutputDims_; the resize requested by the user always takes + // priority. + // 2. The dimemnsions of the actual decoded AVFrame. This can change + // per-decoded frame, and is unknown in SingleStreamDecoder. Only the + // DeviceInterface learns it immediately after decoding a raw frame but + // before the color transformation. + // 3. metdataDims_; the dimensions we learned from the metadata. std::vector> transforms_; + std::optional resizedOutputDims_; + FrameDims metadataDims_; // Whether or not we have already scanned all streams to update the metadata. bool scannedAllStreams_ = false; From 23ec35f4e20c0ad2c06245cf8e98d60eca7fb1ef Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 26 Sep 2025 13:58:06 -0700 Subject: [PATCH 27/35] Better comment --- src/torchcodec/_core/CpuDeviceInterface.h | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 6f69bfd7f..29a25c3b1 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -72,7 +72,8 @@ class CpuDeviceInterface : public DeviceInterface { std::optional resizedOutputDims_; // Color-conversion objects. Only one of filterGraph_ and swsContext_ should - // be non-null. Which one we use is controlled by colorConversionLibrary_. + // be non-null. Which one we use is determined dynamically in + // getColorConversionLibrary() each time we decode a frame. // // Creating both filterGraph_ and swsContext_ is relatively expensive, so we // reuse them across frames. 
However, it is possbile that subsequent frames From fb06f87b3dfb67ed137813e2a7e155821225e190 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 26 Sep 2025 19:24:59 -0700 Subject: [PATCH 28/35] Proper frame dims handling in CUDA --- src/torchcodec/_core/CpuDeviceInterface.cpp | 3 +- src/torchcodec/_core/CpuDeviceInterface.h | 3 +- src/torchcodec/_core/CudaDeviceInterface.cpp | 62 ++++++++++---------- src/torchcodec/_core/CudaDeviceInterface.h | 2 - src/torchcodec/_core/DeviceInterface.h | 1 - src/torchcodec/_core/SingleStreamDecoder.cpp | 1 - 6 files changed, 33 insertions(+), 39 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 69ef59005..d22531e2c 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -51,7 +51,6 @@ void CpuDeviceInterface::initialize( const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const AVRational& timeBase, - [[maybe_unused]] const FrameDims& metadataDims, const std::optional& resizedOutputDims) { videoStreamOptions_ = videoStreamOptions; timeBase_ = timeBase; @@ -106,7 +105,7 @@ void CpuDeviceInterface::initialize( } ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary( - const FrameDims& outputDims) { + const FrameDims& outputDims) const { // swscale requires widths to be multiples of 32: // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0; diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 29a25c3b1..0651a7e0d 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -28,7 +28,6 @@ class CpuDeviceInterface : public DeviceInterface { const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const AVRational& timeBase, - [[maybe_unused]] const FrameDims& metadataDims, const std::optional& resizedOutputDims) override; void convertAVFrameToFrameOutput( @@ -43,7 +42,7 @@ class CpuDeviceInterface : public DeviceInterface { torch::Tensor& outputTensor); ColorConversionLibrary getColorConversionLibrary( - const FrameDims& inputFrameDims); + const FrameDims& inputFrameDims) const; struct SwsFrameContext { int inputWidth = 0; diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 1b64a9d98..be5fa010a 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -190,14 +190,12 @@ void CudaDeviceInterface::initialize( const VideoStreamOptions& videoStreamOptions, [[maybe_unused]] const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& metadataDims, [[maybe_unused]] const std::optional& resizedOutputDims) { TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); videoStreamOptions_ = videoStreamOptions; timeBase_ = timeBase; - metadataDims_ = metadataDims; // It is important for pytorch itself to create the cuda context. If ffmpeg // creates the context it may not be compatible with pytorch. 
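Editor's note summarizing the effect of this refactor across the two interfaces; the quoted lines are taken from hunks in this and the surrounding patches, not new code.

    // CPU path (CpuDeviceInterface::convertAVFrameToFrameOutput): a requested
    // resize wins; otherwise the decoded frame's own size is used, never the
    // container metadata:
    //   auto outputDims =
    //       resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height));
    //
    // CUDA path: transforms are rejected in addVideoStream() (patch 24), so the
    // output is always the decoded frame's own size:
    //   auto frameDims = FrameDims(avFrame->width, avFrame->height);
    //
    // A later patch in this series flips the FrameDims constructor to take
    // (height, width); the logic above is otherwise unchanged.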
@@ -269,8 +267,8 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( avFrame->height, frameFormat, avFrame->sample_aspect_ratio, - metadataDims_.width, - metadataDims_.height, + avFrame->width, + avFrame->height, outputFormat, filters.str(), timeBase_, @@ -304,15 +302,19 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { + // Note that CUDA does not yet support transforms, so the only possible + // frame dimensions are the raw decoded frame's dimensions. + auto frameDims = FrameDims(avFrame->width, avFrame->height); + if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); TORCH_CHECK( - (shape.size() == 3) && (shape[0] == metadataDims_.height) && - (shape[1] == metadataDims_.width) && (shape[2] == 3), + (shape.size() == 3) && (shape[0] == frameDims.height) && + (shape[1] == frameDims.width) && (shape[2] == 3), "Expected tensor of shape ", - metadataDims_.height, + frameDims.height, "x", - metadataDims_.width, + frameDims.width, "x3, got ", shape); } @@ -333,34 +335,32 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // whatever reason. Typically that happens if the video's encoder isn't // supported by NVDEC. // - // In both cases, we have a frame on the CPU, and we need a CPU device to - // handle it. We send the frame back to the CUDA device when we're done. - // - // TODO: Perhaps we should cache cpuInterface? - auto cpuInterface = createDeviceInterface(torch::kCPU); - TORCH_CHECK( - cpuInterface != nullptr, "Failed to create CPU device interface"); - cpuInterface->initialize( - nullptr, - VideoStreamOptions(), - {}, - timeBase_, - metadataDims_, - std::nullopt); + // In both cases, we have a frame on the CPU. We send the frame back to the + // CUDA device when we're done. enum AVPixelFormat frameFormat = static_cast(avFrame->format); FrameOutput cpuFrameOutput; - - if (frameFormat == AV_PIX_FMT_RGB24 && - avFrame->width == metadataDims_.width && - avFrame->height == metadataDims_.height) { - // Reason 1 above. The frame is already in the format and dimensions that - // we need, we just need to convert it to a tensor. + if (frameFormat == AV_PIX_FMT_RGB24) { + // Reason 1 above. The frame is already in RGB24, we just need to convert + // it to a tensor. cpuFrameOutput.data = rgbAVFrameToTensor(avFrame); } else { - // Reason 2 above. We need to do a full conversion. + // Reason 2 above. We need to do a full conversion which requires an + // actual CPU device. + // + // TODO: Perhaps we should cache cpuInterface? 
+ auto cpuInterface = createDeviceInterface(torch::kCPU); + TORCH_CHECK( + cpuInterface != nullptr, "Failed to create CPU device interface"); + cpuInterface->initialize( + nullptr, + VideoStreamOptions(), + {}, + timeBase_, + /*resizedOutputDims=*/std::nullopt); + cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); } @@ -401,7 +401,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); } else { - dst = allocateEmptyHWCTensor(metadataDims_, device_); + dst = allocateEmptyHWCTensor(frameDims, device_); } torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_); @@ -440,7 +440,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "cudaStreamGetFlags failed: ", cudaGetErrorString(err)); - NppiSize oSizeROI = {metadataDims_.width, metadataDims_.height}; + NppiSize oSizeROI = {frameDims.width, frameDims.height}; Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]}; NppStatus status; diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 1cc5d2044..c7771a019 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -26,7 +26,6 @@ class CudaDeviceInterface : public DeviceInterface { [[maybe_unused]] const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& metadataOutputDims, [[maybe_unused]] const std::optional& resizedOutputDims) override; @@ -44,7 +43,6 @@ class CudaDeviceInterface : public DeviceInterface { VideoStreamOptions videoStreamOptions_; AVRational timeBase_; - FrameDims metadataDims_; UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 275b2bea0..f6888bda4 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -37,7 +37,6 @@ class DeviceInterface { const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, const AVRational& timeBase, - const FrameDims& metadataDims, const std::optional& resizedOutputDims) = 0; virtual void convertAVFrameToFrameOutput( diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 9bfae18b9..be01b7aa5 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -521,7 +521,6 @@ void SingleStreamDecoder::addVideoStream( videoStreamOptions, transforms_, streamInfo.timeBase, - metadataDims_, resizedOutputDims_); } From 36268541045deb30a9eb414e22a7de71a3d06fae Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Sun, 28 Sep 2025 18:54:53 -0700 Subject: [PATCH 29/35] Make swscale and filtergraph look more similar --- src/torchcodec/_core/CpuDeviceInterface.cpp | 92 ++++++++++++--------- src/torchcodec/_core/CpuDeviceInterface.h | 7 +- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index d22531e2c..49e54c492 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -178,33 +178,13 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( auto colorConversionLibrary = getColorConversionLibrary(outputDims); torch::Tensor outputTensor; - enum AVPixelFormat frameFormat = - static_cast(avFrame->format); if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) { - // We need to compare the current frame 
context with our previous frame - // context. If they are different, then we need to re-create our colorspace - // conversion objects. We create our colorspace conversion objects late so - // that we don't have to depend on the unreliable metadata in the header. - // And we sometimes re-create them because it's possible for frame - // resolution to change mid-stream. Finally, we want to reuse the colorspace - // conversion objects as much as possible for performance reasons. - SwsFrameContext swsFrameContext( - avFrame->width, - avFrame->height, - frameFormat, - outputDims.width, - outputDims.height); - outputTensor = preAllocatedOutputTensor.value_or( allocateEmptyHWCTensor(outputDims, torch::kCPU)); - if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { - createSwsContext(swsFrameContext, avFrame->colorspace); - prevSwsFrameContext_ = swsFrameContext; - } int resultHeight = - convertAVFrameToTensorUsingSwScale(avFrame, outputTensor); + convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims); // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. @@ -218,23 +198,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( frameOutput.data = outputTensor; } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - FiltersContext filtersContext( - avFrame->width, - avFrame->height, - frameFormat, - avFrame->sample_aspect_ratio, - outputDims.width, - outputDims.height, - AV_PIX_FMT_RGB24, - filters_, - timeBase_); - - if (!filterGraph_ || prevFiltersContext_ != filtersContext) { - filterGraph_ = - std::make_unique(filtersContext, videoStreamOptions_); - prevFiltersContext_ = std::move(filtersContext); - } - outputTensor = rgbAVFrameToTensor(filterGraph_->convert(avFrame)); + outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims); // Similarly to above, if this check fails it means the frame wasn't // reshaped to its expected dimensions by filtergraph. @@ -267,7 +231,30 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor) { + torch::Tensor& outputTensor, + const FrameDims& outputDims) { + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); + + // We need to compare the current frame context with our previous frame + // context. If they are different, then we need to re-create our colorspace + // conversion objects. We create our colorspace conversion objects late so + // that we don't have to depend on the unreliable metadata in the header. + // And we sometimes re-create them because it's possible for frame + // resolution to change mid-stream. Finally, we want to reuse the colorspace + // conversion objects as much as possible for performance reasons. 
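Editor's sketch of the caching behaviour this comment describes; the frame sizes are hypothetical, and only inputWidth/inputHeight are field names visible in the patch (the others are inferred from the constructor call that follows).

    // Frame 1: 1920x1080 yuv420p -> 480x270 RGB24
    //   swsContext_ is null                          -> create a new context
    // Frame 2: 1920x1080 yuv420p -> 480x270
    //   SwsFrameContext equals prevSwsFrameContext_  -> reuse swsContext_
    // Frame 3: 1280x720 yuv420p -> 480x270   (mid-stream resolution change)
    //   key differs from prevSwsFrameContext_        -> recreate the context
    //
    // The same compare-then-recreate pattern guards filterGraph_ via
    // FiltersContext / prevFiltersContext_ in the filtergraph branch.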
+ SwsFrameContext swsFrameContext( + avFrame->width, + avFrame->height, + frameFormat, + outputDims.width, + outputDims.height); + + if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { + createSwsContext(swsFrameContext, avFrame->colorspace); + prevSwsFrameContext_ = swsFrameContext; + } + uint8_t* pointers[4] = { outputTensor.data_ptr(), nullptr, nullptr, nullptr}; int expectedOutputWidth = outputTensor.sizes()[1]; @@ -293,7 +280,7 @@ void CpuDeviceInterface::createSwsContext( swsFrameContext.outputWidth, swsFrameContext.outputHeight, AV_PIX_FMT_RGB24, - SWS_BILINEAR, + swsFlags_, nullptr, nullptr, nullptr); @@ -328,4 +315,29 @@ void CpuDeviceInterface::createSwsContext( swsContext_.reset(swsContext); } +torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( + const UniqueAVFrame& avFrame, + const FrameDims& outputDims) { + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); + + FiltersContext filtersContext( + avFrame->width, + avFrame->height, + frameFormat, + avFrame->sample_aspect_ratio, + outputDims.width, + outputDims.height, + AV_PIX_FMT_RGB24, + filters_, + timeBase_); + + if (!filterGraph_ || prevFiltersContext_ != filtersContext) { + filterGraph_ = + std::make_unique(filtersContext, videoStreamOptions_); + prevFiltersContext_ = std::move(filtersContext); + } + return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 0651a7e0d..4678188e8 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -39,7 +39,12 @@ class CpuDeviceInterface : public DeviceInterface { private: int convertAVFrameToTensorUsingSwScale( const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor); + torch::Tensor& outputTensor, + const FrameDims& outputDims); + + torch::Tensor convertAVFrameToTensorUsingFilterGraph( + const UniqueAVFrame& avFrame, + const FrameDims& outputDims); ColorConversionLibrary getColorConversionLibrary( const FrameDims& inputFrameDims) const; From 1a078287cb8e23133b24312ff88c55086ead5536 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Sun, 28 Sep 2025 18:59:36 -0700 Subject: [PATCH 30/35] Better comment formatting --- src/torchcodec/_core/SingleStreamDecoder.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index be01b7aa5..ddecfcf39 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -615,8 +615,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( std::vector argsort; if (!indicesAreSorted) { // if frameIndices is [13, 10, 12, 11] - // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we - // want + // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want // to use to decode the frames // and argsort is [ 1, 3, 2, 0] argsort.resize(frameIndices.size()); From ee3b9b72c4a286115ce517b502a83eb489c8edaf Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 1 Oct 2025 08:02:18 -0700 Subject: [PATCH 31/35] Apply reviewer suggestions --- src/torchcodec/_core/CpuDeviceInterface.cpp | 2 +- src/torchcodec/_core/CpuDeviceInterface.h | 12 +++++--- src/torchcodec/_core/CudaDeviceInterface.cpp | 31 ++++++++++---------- src/torchcodec/_core/Frame.h | 4 +-- src/torchcodec/_core/SingleStreamDecoder.cpp | 2 +- src/torchcodec/_core/Transform.cpp 
| 13 ++------ src/torchcodec/_core/Transform.h | 15 ++++------ src/torchcodec/_core/custom_ops.cpp | 3 +- 8 files changed, 39 insertions(+), 43 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 49e54c492..cf72f9f45 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -161,7 +161,7 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput( // FrameBatchOutputs based on the the stream metadata. But single-frame APIs // can still work in such situations, so they should. auto outputDims = - resizedOutputDims_.value_or(FrameDims(avFrame->width, avFrame->height)); + resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width)); if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 4678188e8..9d2313832 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -96,10 +96,14 @@ class CpuDeviceInterface : public DeviceInterface { UniqueSwsContext swsContext_; SwsFrameContext prevSwsFrameContext_; - // The filter we supply to filterGraph_, if it is used. The copy filter just - // copies the input to the output. Computationally, it should be a no-op. If - // we get no user-provided transforms, we will use the copy filter. Otherwise, - // we will construct the string from the transforms. + // The filter we supply to filterGraph_, if it is used. The default is the + // copy filter, which just copies the input to the output. Computationally, it + // should be a no-op. If we get no user-provided transforms, we will use the + // copy filter. Otherwise, we will construct the string from the transforms. + // + // Note that even if we only use the copy filter, we still get the desired + // colorspace conversion. We construct the filtergraph with its output sink + // set to RGB24. std::string filters_ = "copy"; // The flags we supply to swsContext_, if it used. The flags control the diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index be5fa010a..a822f02cb 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -174,6 +174,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!"); TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + + // It is important for pytorch itself to create the cuda context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the cuda context. 
+ torch::Tensor dummyTensorForCudaInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + ctx_ = getCudaContext(device_); + nppCtx_ = getNppStreamContext(device_); } CudaDeviceInterface::~CudaDeviceInterface() { @@ -191,20 +199,12 @@ void CudaDeviceInterface::initialize( [[maybe_unused]] const std::vector>& transforms, const AVRational& timeBase, [[maybe_unused]] const std::optional& resizedOutputDims) { - TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); + TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); videoStreamOptions_ = videoStreamOptions; timeBase_ = timeBase; - - // It is important for pytorch itself to create the cuda context. If ffmpeg - // creates the context it may not be compatible with pytorch. - // This is a dummy tensor to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); - ctx_ = getCudaContext(device_); - nppCtx_ = getNppStreamContext(device_); - codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); } UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( @@ -304,7 +304,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( std::optional preAllocatedOutputTensor) { // Note that CUDA does not yet support transforms, so the only possible // frame dimensions are the raw decoded frame's dimensions. - auto frameDims = FrameDims(avFrame->width, avFrame->height); + auto frameDims = FrameDims(avFrame->height, avFrame->width); if (preAllocatedOutputTensor.has_value()) { auto shape = preAllocatedOutputTensor.value().sizes(); @@ -379,14 +379,15 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // Above we checked that the AVFrame was on GPU, but that's not enough, we // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), - // because this is what the NPP color conversion routines expect. + // because this is what the NPP color conversion routines expect. This SHOULD + // be enforced by our call to maybeConvertAVFrameToNV12() above. + auto hwFramesCtx = + reinterpret_cast(avFrame->hw_frames_ctx->data); TORCH_CHECK( - avFrame->hw_frames_ctx != nullptr, + hwFramesCtx != nullptr, "The AVFrame does not have a hw_frames_ctx. 
" "That's unexpected, please report this to the TorchCodec repo."); - auto hwFramesCtx = - reinterpret_cast(avFrame->hw_frames_ctx->data); AVPixelFormat actualFormat = hwFramesCtx->sw_format; TORCH_CHECK( diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 99faaa816..a140c1eed 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -14,12 +14,12 @@ namespace facebook::torchcodec { struct FrameDims { - int width = 0; int height = 0; + int width = 0; FrameDims() = default; - FrameDims(int w, int h) : width(w), height(h) {} + FrameDims(int h, int w) : height(h), width(w) {} }; // All public video decoding entry points return either a FrameOutput or a diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index ddecfcf39..284cb8226 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -504,7 +504,7 @@ void SingleStreamDecoder::addVideoStream( } metadataDims_ = - FrameDims(streamMetadata.width.value(), streamMetadata.height.value()); + FrameDims(streamMetadata.height.value(), streamMetadata.width.value()); for (auto& transform : transforms) { TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); if (transform->getOutputFrameDims().has_value()) { diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index 4884f92e9..d0a5104f3 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -17,10 +17,6 @@ std::string toFilterGraphInterpolation( switch (mode) { case ResizeTransform::InterpolationMode::BILINEAR: return "bilinear"; - case ResizeTransform::InterpolationMode::BICUBIC: - return "bicubic"; - case ResizeTransform::InterpolationMode::NEAREST: - return "nearest"; default: TORCH_CHECK( false, @@ -33,10 +29,6 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { switch (mode) { case ResizeTransform::InterpolationMode::BILINEAR: return SWS_BILINEAR; - case ResizeTransform::InterpolationMode::BICUBIC: - return SWS_BICUBIC; - case ResizeTransform::InterpolationMode::NEAREST: - return SWS_POINT; default: TORCH_CHECK( false, @@ -48,12 +40,13 @@ int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { } // namespace std::string ResizeTransform::getFilterGraphCpu() const { - return "scale=" + std::to_string(width_) + ":" + std::to_string(height_) + + return "scale=" + std::to_string(outputDims_.width) + ":" + + std::to_string(outputDims_.height) + ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_); } std::optional ResizeTransform::getOutputFrameDims() const { - return FrameDims(width_, height_); + return outputDims_; } bool ResizeTransform::isResize() const { diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index 307e18b73..6aea255ab 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -37,15 +37,13 @@ class Transform { class ResizeTransform : public Transform { public: - enum class InterpolationMode { BILINEAR, BICUBIC, NEAREST }; + enum class InterpolationMode { BILINEAR }; - ResizeTransform(int width, int height) - : width_(width), - height_(height), - interpolationMode_(InterpolationMode::BILINEAR) {} + ResizeTransform(const FrameDims& dims) + : outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {} - ResizeTransform(int width, int height, InterpolationMode interpolationMode) - : width_(width), height_(height), 
interpolationMode_(interpolationMode) {} + ResizeTransform(const FrameDims& dims, InterpolationMode interpolationMode) + : outputDims_(dims), interpolationMode_(interpolationMode) {} std::string getFilterGraphCpu() const override; std::optional getOutputFrameDims() const override; @@ -54,8 +52,7 @@ class ResizeTransform : public Transform { int getSwsFlags() const; private: - int width_; - int height_; + FrameDims outputDims_; InterpolationMode interpolationMode_; }; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index b6987036e..bd1d3ff35 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -274,7 +274,8 @@ void _add_video_stream( "width and height must both be set or unset."); std::vector transforms; if (width.has_value()) { - transforms.push_back(new ResizeTransform(width.value(), height.value())); + transforms.push_back( + new ResizeTransform(FrameDims(height.value(), width.value()))); width.reset(); height.reset(); } From d2e9bde61f02afbe0bb7abc6354240c8b7dd7dcb Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 1 Oct 2025 11:41:02 -0700 Subject: [PATCH 32/35] Refactor device interface, again. --- src/torchcodec/_core/CpuDeviceInterface.cpp | 7 ++++-- src/torchcodec/_core/CpuDeviceInterface.h | 12 +++++++++- src/torchcodec/_core/CudaDeviceInterface.cpp | 24 +++++++++++--------- src/torchcodec/_core/CudaDeviceInterface.h | 7 +++--- src/torchcodec/_core/DeviceInterface.h | 16 ++++++++----- src/torchcodec/_core/SingleStreamDecoder.cpp | 21 +++++++++-------- 6 files changed, 55 insertions(+), 32 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index cf72f9f45..8376be900 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -48,12 +48,15 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) void CpuDeviceInterface::initialize( [[maybe_unused]] AVCodecContext* codecContext, + const AVRational& timeBase) { + timeBase_ = timeBase; +} + +void CpuDeviceInterface::initializeVideo( const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, - const AVRational& timeBase, const std::optional& resizedOutputDims) { videoStreamOptions_ = videoStreamOptions; - timeBase_ = timeBase; resizedOutputDims_ = resizedOutputDims; // We can only use swscale when we have a single resize transform. Note that diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 9d2313832..a2927b141 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -25,9 +25,11 @@ class CpuDeviceInterface : public DeviceInterface { virtual void initialize( [[maybe_unused]] AVCodecContext* codecContext, + const AVRational& timeBase) override; + + virtual void initializeVideo( const VideoStreamOptions& videoStreamOptions, const std::vector>& transforms, - const AVRational& timeBase, const std::optional& resizedOutputDims) override; void convertAVFrameToFrameOutput( @@ -73,6 +75,14 @@ class CpuDeviceInterface : public DeviceInterface { VideoStreamOptions videoStreamOptions_; AVRational timeBase_; + + // If the resized output dimensions are present, then we always use those as + // the output frame's dimensions. If they are not present, then we use the + // dimensions of the raw decoded frame. 
Note that we do not know the + // dimensions of the raw decoded frame until very late; we learn it in + // convertAVFrameToFrameOutput(). Deciding the final output frame's actual + // dimensions late allows us to handle video streams with variable + // resolutions. std::optional resizedOutputDims_; // Color-conversion objects. Only one of filterGraph_ and swsContext_ should diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index a822f02cb..b0471ebae 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -195,18 +195,20 @@ CudaDeviceInterface::~CudaDeviceInterface() { void CudaDeviceInterface::initialize( AVCodecContext* codecContext, - const VideoStreamOptions& videoStreamOptions, - [[maybe_unused]] const std::vector>& transforms, - const AVRational& timeBase, - [[maybe_unused]] const std::optional& resizedOutputDims) { + const AVRational& timeBase) { TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); - codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); - videoStreamOptions_ = videoStreamOptions; timeBase_ = timeBase; } +void CudaDeviceInterface::initializeVideo( + const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const std::vector>& transforms, + [[maybe_unused]] const std::optional& resizedOutputDims) { + videoStreamOptions_ = videoStreamOptions; +} + UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( UniqueAVFrame& avFrame) { // We need FFmpeg filters to handle those conversion cases which are not @@ -220,13 +222,13 @@ UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( return std::move(avFrame); } + auto hwFramesCtx = + reinterpret_cast(avFrame->hw_frames_ctx->data); TORCH_CHECK( - avFrame->hw_frames_ctx != nullptr, + hwFramesCtx != nullptr, "The AVFrame does not have a hw_frames_ctx. " "That's unexpected, please report this to the TorchCodec repo."); - auto hwFramesCtx = - reinterpret_cast(avFrame->hw_frames_ctx->data); AVPixelFormat actualFormat = hwFramesCtx->sw_format; // If the frame is already in NV12 format, we don't need to do anything. 
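The conversion decision above hinges on the AVHWFramesContext that FFmpeg attaches to GPU-backed frames: its sw_format field records the pixel layout of the actual data. Below is a minimal, self-contained sketch of that lookup; the helper name isHardwareFrameNV12 is illustrative and not part of this patch.

extern "C" {
#include <libavutil/frame.h>
#include <libavutil/hwcontext.h>
#include <libavutil/pixfmt.h>
}

// Hypothetical helper: returns true when a GPU-backed AVFrame already stores
// its pixels as NV12, i.e. when the NV12 conversion step could be skipped.
bool isHardwareFrameNV12(const AVFrame* avFrame) {
  // Software frames, or frames without a hardware frames context, cannot be
  // NV12 GPU frames.
  if (avFrame == nullptr || avFrame->hw_frames_ctx == nullptr) {
    return false;
  }
  // hw_frames_ctx->data points at the AVHWFramesContext describing the pool
  // the frame was allocated from; sw_format is the underlying pixel format.
  auto hwFramesCtx =
      reinterpret_cast<AVHWFramesContext*>(avFrame->hw_frames_ctx->data);
  return hwFramesCtx->sw_format == AV_PIX_FMT_NV12;
}

In the hunk above the same fields are read inline, with the null check expressed as a TORCH_CHECK rather than an early return.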
@@ -355,10 +357,10 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( TORCH_CHECK( cpuInterface != nullptr, "Failed to create CPU device interface"); cpuInterface->initialize( - nullptr, + /*codecContext=*/nullptr, timeBase_); + cpuInterface->initializeVideo( VideoStreamOptions(), {}, - timeBase_, /*resizedOutputDims=*/std::nullopt); cpuInterface->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index c7771a019..4daf29a8b 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -20,12 +20,13 @@ class CudaDeviceInterface : public DeviceInterface { std::optional findCodec(const AVCodecID& codecId) override; - void initialize( - AVCodecContext* codecContext, + void initialize(AVCodecContext* codecContext, const AVRational& timeBase) + override; + + void initializeVideo( const VideoStreamOptions& videoStreamOptions, [[maybe_unused]] const std::vector>& transforms, - const AVRational& timeBase, [[maybe_unused]] const std::optional& resizedOutputDims) override; diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index f6888bda4..a3c9e94b0 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -30,14 +30,18 @@ class DeviceInterface { virtual std::optional findCodec(const AVCodecID& codecId) = 0; - // Initialize the hardware device that is specified in `device`. Some builds - // support CUDA and others only support CPU. + // Initialize the device with parameters generic to all kinds of decoding. virtual void initialize( AVCodecContext* codecContext, - const VideoStreamOptions& videoStreamOptions, - const std::vector>& transforms, - const AVRational& timeBase, - const std::optional& resizedOutputDims) = 0; + const AVRational& timeBase) = 0; + + // Initialize the device with parameters specific to video decoding. There is + // a default empty implementation. + virtual void initializeVideo( + [[maybe_unused]] const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const std::vector>& + transforms, + [[maybe_unused]] const std::optional& resizedOutputDims) {} virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 284cb8226..b150d54e4 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -436,6 +436,9 @@ void SingleStreamDecoder::addStream( TORCH_CHECK(codecContext != nullptr); streamInfo.codecContext.reset(codecContext); + deviceInterface_->initialize( + streamInfo.codecContext.get(), streamInfo.timeBase); + int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); TORCH_CHECK_EQ(retVal, AVSUCCESS); @@ -443,6 +446,10 @@ void SingleStreamDecoder::addStream( streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0); streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base; + // Note that we must make sure to call avcodec_open2() AFTER we initialize + // the device interface. Device initialization tells the codec context which + // device to use. If we initialize the device interface after avcodec_open2(), + // then all decoding may fall back to the CPU. 
retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); @@ -510,18 +517,14 @@ void SingleStreamDecoder::addVideoStream( if (transform->getOutputFrameDims().has_value()) { resizedOutputDims_ = transform->getOutputFrameDims().value(); } + + // Note that we are claiming ownership of the transform objects passed in to + // us. transforms_.push_back(std::unique_ptr(transform)); } - // We initialize the device context late because we want to know a lot of - // information that we can only know after resolving the codec, opening the - // stream and inspecting the metadata. - deviceInterface_->initialize( - streamInfo.codecContext.get(), - videoStreamOptions, - transforms_, - streamInfo.timeBase, - resizedOutputDims_); + deviceInterface_->initializeVideo( + videoStreamOptions, transforms_, resizedOutputDims_); } void SingleStreamDecoder::addAudioStream( From db2ea07c9770c37c2e916f895a4f4896f6cde497 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 1 Oct 2025 12:25:33 -0700 Subject: [PATCH 33/35] Clean up comment --- src/torchcodec/_core/Frame.h | 39 +----------------------------------- 1 file changed, 1 insertion(+), 38 deletions(-) diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index a140c1eed..4b27d5bdd 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -58,49 +58,12 @@ struct AudioFramesOutput { // FRAME TENSOR ALLOCATION APIs // -------------------------------------------------------------------------- -// Note [Frame Tensor allocation and height and width] +// Note [Frame Tensor allocation] // // We always allocate [N]HWC tensors. The low-level decoding functions all // assume HWC tensors, since this is what FFmpeg natively handles. It's up to // the high-level decoding entry-points to permute that back to CHW, by calling // maybePermuteHWC2CHW(). -// -// TODO: Rationalize the comment below with refactoring. -// -// Also, importantly, the way we figure out the the height and width of the -// output frame tensor varies, and depends on the decoding entry-point. In -// *decreasing order of accuracy*, we use the following sources for determining -// height and width: -// - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the -// AVframe, *post*-resizing. This is only used for single-frame decoding APIs, -// on CPU, with filtergraph. -// - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width from -// the user-specified options if they exist, or the height and width of the -// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within -// our code or within FFmpeg code, this should be exactly the same as -// getHeightAndWidthFromResizedAVFrame(). This is used by single-frame -// decoding APIs, on CPU with swscale, and on GPU. -// - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width from -// the user-specified options if they exist, or the height and width form the -// stream metadata, which itself got its value from the CodecContext, when the -// stream was added. This is used by batch decoding APIs, for both GPU and -// CPU. -// -// The source of truth for height and width really is the (resized) AVFrame: it -// comes from the decoded ouptut of FFmpeg. The info from the metadata (i.e. -// from the CodecContext) may not be as accurate. 
However, the AVFrame is only -// available late in the call stack, when the frame is decoded, while the -// CodecContext is available early when a stream is added. This is why we use -// the CodecContext for pre-allocating batched output tensors (we could -// pre-allocate those only once we decode the first frame to get the info frame -// the AVFrame, but that's a more complex logic). -// -// Because the sources for height and width may disagree, we may end up with -// conflicts: e.g. if we pre-allocate a batch output tensor based on the -// metadata info, but the decoded AVFrame has a different height and width. -// it is very important to check the height and width assumptions where the -// tensors memory is used/filled in order to avoid segfaults. - torch::Tensor allocateEmptyHWCTensor( const FrameDims& frameDims, const torch::Device& device, From 1753f9cadb6879d2733caf92f04d87df75dddd64 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 2 Oct 2025 09:59:14 -0700 Subject: [PATCH 34/35] Name change --- src/torchcodec/_core/CudaDeviceInterface.cpp | 12 ++++++------ src/torchcodec/_core/CudaDeviceInterface.h | 2 +- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index b0471ebae..5cdc450cd 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -209,7 +209,7 @@ void CudaDeviceInterface::initializeVideo( videoStreamOptions_ = videoStreamOptions; } -UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12( +UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( UniqueAVFrame& avFrame) { // We need FFmpeg filters to handle those conversion cases which are not // directly implemented in CUDA or CPU device interface (in case of a @@ -323,16 +323,16 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by // converting them to NV12. - avFrame = maybeConvertAVFrameToNV12(avFrame); + avFrame = maybeConvertAVFrameToNV12OrRGB24(avFrame); if (avFrame->format != AV_PIX_FMT_CUDA) { // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on // the GPU. In this branch, the frame is on the CPU. There are two possible // reasons: // - // 1. During maybeConvertAVFrameToNV12(), we had a non-NV12 format frame - // and we're on FFmpeg 4.4 or earlier. In such cases, we had to use CPU - // filters and we just converted the frame to RGB24. + // 1. During maybeConvertAVFrameToNV12OrRGB24(), we had a non-NV12 format + // frame and we're on FFmpeg 4.4 or earlier. In such cases, we had to + // use CPU filters and we just converted the frame to RGB24. // 2. This is what NVDEC gave us if it wasn't able to decode a frame, for // whatever reason. Typically that happens if the video's encoder isn't // supported by NVDEC. @@ -382,7 +382,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // Above we checked that the AVFrame was on GPU, but that's not enough, we // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), // because this is what the NPP color conversion routines expect. This SHOULD - // be enforced by our call to maybeConvertAVFrameToNV12() above. + // be enforced by our call to maybeConvertAVFrameToNV12OrRGB24() above. 
auto hwFramesCtx = reinterpret_cast(avFrame->hw_frames_ctx->data); TORCH_CHECK( diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 4daf29a8b..9e757722a 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -40,7 +40,7 @@ class CudaDeviceInterface : public DeviceInterface { // Our CUDA decoding code assumes NV12 format. In order to handle other // kinds of input, we need to convert them to NV12. Our current implementation // does this using filtergraph. - UniqueAVFrame maybeConvertAVFrameToNV12(UniqueAVFrame& avFrame); + UniqueAVFrame maybeConvertAVFrameToNV12OrRGB24(UniqueAVFrame& avFrame); VideoStreamOptions videoStreamOptions_; AVRational timeBase_; From 9efb767b947c50f790dcc8079d5bfc964fa6a68c Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 3 Oct 2025 08:34:27 -0700 Subject: [PATCH 35/35] Stragglers --- .../_core/BetaCudaDeviceInterface.cpp | 25 +++++++------------ .../_core/BetaCudaDeviceInterface.h | 4 +-- src/torchcodec/_core/CpuDeviceInterface.cpp | 7 +++--- src/torchcodec/_core/CpuDeviceInterface.h | 4 +-- src/torchcodec/_core/CudaDeviceInterface.h | 9 +++++-- 5 files changed, 21 insertions(+), 28 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 7e88efbd3..f0b4cbfcc 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -156,13 +156,20 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { } } -void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) { +void BetaCudaDeviceInterface::initialize(const AVStream* avStream) { torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); timeBase_ = avStream->time_base; + auto cudaDevice = torch::Device(torch::kCUDA); + defaultCudaInterface_ = + std::unique_ptr(createDeviceInterface(cudaDevice)); + AVCodecContext dummyCodecContext = {}; + defaultCudaInterface_->initialize(avStream); + defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext); + const AVCodecParameters* codecpar = avStream->codecpar; TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null"); @@ -523,8 +530,6 @@ void BetaCudaDeviceInterface::flush() { } void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { @@ -535,20 +540,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( // TODONVDEC P1: we use the 'default' cuda device interface for color // conversion. That's a temporary hack to make things work. we should abstract // the color conversion stuff separately. 
- if (!defaultCudaInterface_) { - auto cudaDevice = torch::Device(torch::kCUDA); - defaultCudaInterface_ = - std::unique_ptr(createDeviceInterface(cudaDevice)); - AVCodecContext dummyCodecContext = {}; - defaultCudaInterface_->initializeContext(&dummyCodecContext); - } - defaultCudaInterface_->convertAVFrameToFrameOutput( - videoStreamOptions, - timeBase, - avFrame, - frameOutput, - preAllocatedOutputTensor); + avFrame, frameOutput, preAllocatedOutputTensor); } BetaCudaDeviceInterface::FrameBuffer::Slot* diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index d42885c75..b19112f0d 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -37,11 +37,9 @@ class BetaCudaDeviceInterface : public DeviceInterface { explicit BetaCudaDeviceInterface(const torch::Device& device); virtual ~BetaCudaDeviceInterface(); - void initializeInterface(AVStream* stream) override; + void initialize(const AVStream* avStream) override; void convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 3fbc15ac1..8c85c2fcf 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -46,10 +46,9 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device) device_.type() == torch::kCPU, "Unsupported device: ", device_.str()); } -void CpuDeviceInterface::initialize( - [[maybe_unused]] AVCodecContext* codecContext, - const AVRational& timeBase) { - timeBase_ = timeBase; +void CpuDeviceInterface::initialize(const AVStream* avStream) { + TORCH_CHECK(avStream != nullptr, "avStream is null"); + timeBase_ = avStream->time_base; } void CpuDeviceInterface::initializeVideo( diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index a2927b141..305b5ae14 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -23,9 +23,7 @@ class CpuDeviceInterface : public DeviceInterface { return std::nullopt; } - virtual void initialize( - [[maybe_unused]] AVCodecContext* codecContext, - const AVRational& timeBase) override; + virtual void initialize(const AVStream* avStream) override; virtual void initializeVideo( const VideoStreamOptions& videoStreamOptions, diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 9e757722a..42d517a72 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -20,8 +20,7 @@ class CudaDeviceInterface : public DeviceInterface { std::optional findCodec(const AVCodecID& codecId) override; - void initialize(AVCodecContext* codecContext, const AVRational& timeBase) - override; + void initialize(const AVStream* avStream) override; void initializeVideo( const VideoStreamOptions& videoStreamOptions, @@ -30,6 +29,8 @@ class CudaDeviceInterface : public DeviceInterface { [[maybe_unused]] const std::optional& resizedOutputDims) override; + void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override; + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, @@ -42,6 +43,10 @@ class CudaDeviceInterface : public DeviceInterface { // does this using filtergraph. 
  UniqueAVFrame maybeConvertAVFrameToNV12OrRGB24(UniqueAVFrame& avFrame);
 
+  // We sometimes encounter frames that cannot be decoded on the CUDA device.
+  // Rather than erroring out, we decode them on the CPU.
+  std::unique_ptr<DeviceInterface> cpuInterface_;
+
   VideoStreamOptions videoStreamOptions_;
   AVRational timeBase_;
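With these changes, device setup is split across three calls whose ordering matters: initialize() and registerHardwareDeviceWithCodec() before avcodec_open2(), and initializeVideo() once the transforms are known. The sketch below shows the expected sequence, matching the ordering comment added to SingleStreamDecoder::addStream(). It assumes createDeviceInterface() hands back a std::unique_ptr<DeviceInterface>; the function name, local variables, and omitted error handling are illustrative only.

#include <memory>
#include <vector>

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}

#include "src/torchcodec/_core/DeviceInterface.h"
#include "src/torchcodec/_core/Transform.h"

namespace facebook::torchcodec {

// Illustrative only: mirrors the order the decoder is expected to follow.
void setUpVideoStreamSketch(
    const torch::Device& device,
    const AVStream* avStream,
    AVCodecContext* codecContext,
    const AVCodec* avCodec,
    const VideoStreamOptions& videoStreamOptions,
    const std::vector<std::unique_ptr<Transform>>& transforms) {
  std::unique_ptr<DeviceInterface> deviceInterface =
      createDeviceInterface(device);

  // Generic initialization: the interface learns the stream's time base.
  deviceInterface->initialize(avStream);

  // The hardware device must be attached to the codec context BEFORE
  // avcodec_open2(); otherwise decoding silently falls back to the CPU.
  deviceInterface->registerHardwareDeviceWithCodec(codecContext);
  avcodec_open2(codecContext, avCodec, nullptr);

  // Video-specific initialization happens later, once the transforms and the
  // resized output dimensions are known.
  deviceInterface->initializeVideo(
      videoStreamOptions, transforms, /*resizedOutputDims=*/std::nullopt);
}

} // namespace facebook::torchcodec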