diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 7e88efbd3..f0b4cbfcc 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -156,13 +156,20 @@ BetaCudaDeviceInterface::~BetaCudaDeviceInterface() { } } -void BetaCudaDeviceInterface::initializeInterface(AVStream* avStream) { +void BetaCudaDeviceInterface::initialize(const AVStream* avStream) { torch::Tensor dummyTensorForCudaInitialization = torch::empty( {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); TORCH_CHECK(avStream != nullptr, "AVStream cannot be null"); timeBase_ = avStream->time_base; + auto cudaDevice = torch::Device(torch::kCUDA); + defaultCudaInterface_ = + std::unique_ptr(createDeviceInterface(cudaDevice)); + AVCodecContext dummyCodecContext = {}; + defaultCudaInterface_->initialize(avStream); + defaultCudaInterface_->registerHardwareDeviceWithCodec(&dummyCodecContext); + const AVCodecParameters* codecpar = avStream->codecpar; TORCH_CHECK(codecpar != nullptr, "CodecParameters cannot be null"); @@ -523,8 +530,6 @@ void BetaCudaDeviceInterface::flush() { } void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { @@ -535,20 +540,8 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( // TODONVDEC P1: we use the 'default' cuda device interface for color // conversion. That's a temporary hack to make things work. we should abstract // the color conversion stuff separately. - if (!defaultCudaInterface_) { - auto cudaDevice = torch::Device(torch::kCUDA); - defaultCudaInterface_ = - std::unique_ptr(createDeviceInterface(cudaDevice)); - AVCodecContext dummyCodecContext = {}; - defaultCudaInterface_->initializeContext(&dummyCodecContext); - } - defaultCudaInterface_->convertAVFrameToFrameOutput( - videoStreamOptions, - timeBase, - avFrame, - frameOutput, - preAllocatedOutputTensor); + avFrame, frameOutput, preAllocatedOutputTensor); } BetaCudaDeviceInterface::FrameBuffer::Slot* diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index d42885c75..b19112f0d 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -37,11 +37,9 @@ class BetaCudaDeviceInterface : public DeviceInterface { explicit BetaCudaDeviceInterface(const torch::Device& device); virtual ~BetaCudaDeviceInterface(); - void initializeInterface(AVStream* stream) override; + void initialize(const AVStream* avStream) override; void convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 7c04d79d4..f9b24ace2 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -95,6 +95,7 @@ function(make_torchcodec_libraries SingleStreamDecoder.cpp Encoder.cpp ValidationUtils.cpp + Transform.cpp ) if(ENABLE_CUDA) diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 692a4aa31..8c85c2fcf 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -46,6 +46,94 @@ 
CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
       device_.type() == torch::kCPU, "Unsupported device: ", device_.str());
 }
 
+void CpuDeviceInterface::initialize(const AVStream* avStream) {
+  TORCH_CHECK(avStream != nullptr, "avStream is null");
+  timeBase_ = avStream->time_base;
+}
+
+void CpuDeviceInterface::initializeVideo(
+    const VideoStreamOptions& videoStreamOptions,
+    const std::vector<std::unique_ptr<Transform>>& transforms,
+    const std::optional<FrameDims>& resizedOutputDims) {
+  videoStreamOptions_ = videoStreamOptions;
+  resizedOutputDims_ = resizedOutputDims;
+
+  // We can only use swscale when we have a single resize transform. Note that
+  // this means swscale will not support the case of having several,
+  // back-to-back resizes. There's no strong reason to even do that, but if
+  // someone does, it's more correct to implement that with filtergraph.
+  //
+  // We calculate this value during initialization but we don't refer to it
+  // until getColorConversionLibrary() is called. Calculating this value during
+  // initialization saves us from having to store all of the transforms.
+  areTransformsSwScaleCompatible_ = transforms.empty() ||
+      (transforms.size() == 1 && transforms[0]->isResize());
+
+  // Note that we do not expose this capability in the public API, only through
+  // the core API.
+  //
+  // Same as above, we calculate this value during initialization and refer to
+  // it in getColorConversionLibrary().
+  userRequestedSwScale_ = videoStreamOptions_.colorConversionLibrary ==
+      ColorConversionLibrary::SWSCALE;
+
+  // As noted above, we can only use swscale when we have a single resize
+  // transform. Note that we decide whether or not to actually use swscale at
+  // the last possible moment, when we convert the frame. This is because we
+  // need to know the actual frame dimensions.
+  if (transforms.size() == 1 && transforms[0]->isResize()) {
+    auto resize = dynamic_cast<ResizeTransform*>(transforms[0].get());
+    TORCH_CHECK(resize != nullptr, "ResizeTransform expected but not found!");
+    swsFlags_ = resize->getSwsFlags();
+  }
+
+  // If we have any transforms, replace filters_ with the filter strings from
+  // the transforms. As noted above, we decide between swscale and filtergraph
+  // when we actually decode a frame.
+  std::stringstream filters;
+  bool first = true;
+  for (const auto& transform : transforms) {
+    if (!first) {
+      filters << ",";
+    }
+    filters << transform->getFilterGraphCpu();
+    first = false;
+  }
+  if (!transforms.empty()) {
+    filters_ = filters.str();
+  }
+
+  initialized_ = true;
+}
+
+ColorConversionLibrary CpuDeviceInterface::getColorConversionLibrary(
+    const FrameDims& outputDims) const {
+  // swscale requires widths to be multiples of 32:
+  // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements
+  bool isWidthSwScaleCompatible = (outputDims.width % 32) == 0;
+
+  // We want to use swscale for color conversion if possible because it is
+  // faster than filtergraph. The following are the conditions we need to meet
+  // to use it.
+  //
+  // Note that we treat the transform limitation differently from the width
+  // limitation. That is, we consider the transforms being compatible with
+  // swscale as a hard requirement. If the transforms are not compatible,
+  // then we will end up not applying the transforms, and that is wrong.
+  //
+  // The width requirement, however, is a soft requirement. Even if we don't
+  // meet it, we let the user override it. We have tests that depend on this
+  // behavior.
Since we don't expose the ability to choose swscale or
+  // filtergraph in our public API, this is probably okay. It's also the only
+  // way that we can be certain we are testing one versus the other.
+  if (areTransformsSwScaleCompatible_ &&
+      (userRequestedSwScale_ || isWidthSwScaleCompatible)) {
+    return ColorConversionLibrary::SWSCALE;
+  } else {
+    return ColorConversionLibrary::FILTERGRAPH;
+  }
+}
+
 // Note [preAllocatedOutputTensor with swscale and filtergraph]:
 // Callers may pass a pre-allocated tensor, where the output.data tensor will
 // be stored. This parameter is honored in any case, but it only leads to a
@@ -56,139 +144,74 @@ CpuDeviceInterface::CpuDeviceInterface(const torch::Device& device)
 // Dimension order of the preAllocatedOutputTensor must be HWC, regardless of
 // `dimension_order` parameter. It's up to callers to re-shape it if needed.
 void CpuDeviceInterface::convertAVFrameToFrameOutput(
-    const VideoStreamOptions& videoStreamOptions,
-    const AVRational& timeBase,
     UniqueAVFrame& avFrame,
     FrameOutput& frameOutput,
     std::optional<torch::Tensor> preAllocatedOutputTensor) {
-  auto frameDims =
-      getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame);
-  int expectedOutputHeight = frameDims.height;
-  int expectedOutputWidth = frameDims.width;
+  TORCH_CHECK(initialized_, "CpuDeviceInterface was not initialized.");
+
+  // Note that we ignore the dimensions from the metadata; we don't even bother
+  // storing them. The resized dimensions take priority. If we don't have any,
+  // then we use the dimensions from the actual decoded frame. We use the
+  // actual decoded frame and not the metadata for two reasons:
+  //
+  // 1. Metadata may be wrong. If we have access to more accurate information,
+  //    we should use it.
+  // 2. Video streams can have variable resolution. This fact is not captured
+  //    in the stream metadata.
+  //
+  // Both cases cause problems for our batch APIs, as we allocate
+  // FrameBatchOutputs based on the stream metadata. But single-frame APIs
+  // can still work in such situations, so they should.
+  auto outputDims =
+      resizedOutputDims_.value_or(FrameDims(avFrame->height, avFrame->width));
 
   if (preAllocatedOutputTensor.has_value()) {
     auto shape = preAllocatedOutputTensor.value().sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims.height) &&
+            (shape[1] == outputDims.width) && (shape[2] == 3),
         "Expected pre-allocated tensor of shape ",
-        expectedOutputHeight,
+        outputDims.height,
         "x",
-        expectedOutputWidth,
+        outputDims.width,
         "x3, got ",
         shape);
   }
 
+  auto colorConversionLibrary = getColorConversionLibrary(outputDims);
   torch::Tensor outputTensor;
 
-  enum AVPixelFormat frameFormat =
-      static_cast<AVPixelFormat>(avFrame->format);
-
-  // This is an early-return optimization: if the format is already what we
-  // need, and the dimensions are also what we need, we don't need to call
-  // swscale or filtergraph. We can just convert the AVFrame to a tensor.
-  if (frameFormat == AV_PIX_FMT_RGB24 &&
-      avFrame->width == expectedOutputWidth &&
-      avFrame->height == expectedOutputHeight) {
-    outputTensor = toTensor(avFrame);
-    if (preAllocatedOutputTensor.has_value()) {
-      // We have already validated that preAllocatedOutputTensor and
-      // outputTensor have the same shape.
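The pre-allocated tensor shape check above exists because the batch entry points allocate one [N, H, W, 3] uint8 tensor up front and hand each decode call a per-frame HWC slice (see getFrameAtIndexInternal(i, frameBatchOutput.data[f]) further down in this diff). A minimal, self-contained sketch of that caller-side pattern; fillFrameSketch() is a hypothetical stand-in for the decoder:

#include <torch/torch.h>

// Hypothetical stand-in for a decode call that writes into an HWC view.
void fillFrameSketch(torch::Tensor hwcSlice) {
  hwcSlice.zero_();
}

torch::Tensor decodeBatchSketch(int64_t numFrames, int height, int width) {
  torch::Tensor batch = torch::empty(
      {numFrames, height, width, 3},
      torch::TensorOptions().dtype(torch::kUInt8).device(torch::kCPU));
  for (int64_t f = 0; f < numFrames; ++f) {
    // batch[f] is an HWC view into the batch allocation; writing into it
    // directly is what makes the pre-allocation worthwhile, and it is why the
    // view's shape must match the decoded frame exactly.
    fillFrameSketch(batch[f]);
  }
  return batch;
}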
- preAllocatedOutputTensor.value().copy_(outputTensor); - frameOutput.data = preAllocatedOutputTensor.value(); - } else { - frameOutput.data = outputTensor; - } - return; - } - - // By default, we want to use swscale for color conversion because it is - // faster. However, it has width requirements, so we may need to fall back - // to filtergraph. We also need to respect what was requested from the - // options; we respect the options unconditionally, so it's possible for - // swscale's width requirements to be violated. We don't expose the ability to - // choose color conversion library publicly; we only use this ability - // internally. - - // swscale requires widths to be multiples of 32: - // https://stackoverflow.com/questions/74351955/turn-off-sw-scale-conversion-to-planar-yuv-32-byte-alignment-requirements - // so we fall back to filtergraph if the width is not a multiple of 32. - auto defaultLibrary = (expectedOutputWidth % 32 == 0) - ? ColorConversionLibrary::SWSCALE - : ColorConversionLibrary::FILTERGRAPH; - - ColorConversionLibrary colorConversionLibrary = - videoStreamOptions.colorConversionLibrary.value_or(defaultLibrary); if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) { - // We need to compare the current frame context with our previous frame - // context. If they are different, then we need to re-create our colorspace - // conversion objects. We create our colorspace conversion objects late so - // that we don't have to depend on the unreliable metadata in the header. - // And we sometimes re-create them because it's possible for frame - // resolution to change mid-stream. Finally, we want to reuse the colorspace - // conversion objects as much as possible for performance reasons. - SwsFrameContext swsFrameContext( - avFrame->width, - avFrame->height, - frameFormat, - expectedOutputWidth, - expectedOutputHeight); - - outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor( - expectedOutputHeight, expectedOutputWidth, torch::kCPU)); - - if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { - createSwsContext(swsFrameContext, avFrame->colorspace); - prevSwsFrameContext_ = swsFrameContext; - } + outputTensor = preAllocatedOutputTensor.value_or( + allocateEmptyHWCTensor(outputDims, torch::kCPU)); + int resultHeight = - convertAVFrameToTensorUsingSwsScale(avFrame, outputTensor); + convertAVFrameToTensorUsingSwScale(avFrame, outputTensor, outputDims); + // If this check failed, it would mean that the frame wasn't reshaped to // the expected height. // TODO: Can we do the same check for width? TORCH_CHECK( - resultHeight == expectedOutputHeight, - "resultHeight != expectedOutputHeight: ", + resultHeight == outputDims.height, + "resultHeight != outputDims.height: ", resultHeight, " != ", - expectedOutputHeight); + outputDims.height); frameOutput.data = outputTensor; } else if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - // See comment above in swscale branch about the filterGraphContext_ - // creation. 
-    std::stringstream filters;
-    filters << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
-    filters << ":sws_flags=bilinear";
-
-    FiltersContext filtersContext(
-        avFrame->width,
-        avFrame->height,
-        frameFormat,
-        avFrame->sample_aspect_ratio,
-        expectedOutputWidth,
-        expectedOutputHeight,
-        AV_PIX_FMT_RGB24,
-        filters.str(),
-        timeBase);
-
-    if (!filterGraphContext_ || prevFiltersContext_ != filtersContext) {
-      filterGraphContext_ =
-          std::make_unique<FilterGraph>(filtersContext, videoStreamOptions);
-      prevFiltersContext_ = std::move(filtersContext);
-    }
-    outputTensor = toTensor(filterGraphContext_->convert(avFrame));
+    outputTensor = convertAVFrameToTensorUsingFilterGraph(avFrame, outputDims);
 
     // Similarly to above, if this check fails it means the frame wasn't
     // reshaped to its expected dimensions by filtergraph.
     auto shape = outputTensor.sizes();
     TORCH_CHECK(
-        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        (shape.size() == 3) && (shape[0] == outputDims.height) &&
+            (shape[1] == outputDims.width) && (shape[2] == 3),
         "Expected output tensor of shape ",
-        expectedOutputHeight,
+        outputDims.height,
         "x",
-        expectedOutputWidth,
+        outputDims.width,
         "x3, got ",
         shape);
@@ -208,9 +231,32 @@ void CpuDeviceInterface::convertAVFrameToFrameOutput(
   }
 }
 
-int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale(
+int CpuDeviceInterface::convertAVFrameToTensorUsingSwScale(
     const UniqueAVFrame& avFrame,
-    torch::Tensor& outputTensor) {
+    torch::Tensor& outputTensor,
+    const FrameDims& outputDims) {
+  enum AVPixelFormat frameFormat =
+      static_cast<AVPixelFormat>(avFrame->format);
+
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
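For context on why the conversion object is cached and only rebuilt when the frame geometry changes: creating it is the expensive part, while the per-frame sws_scale call is comparatively cheap. A rough, standalone sketch of that object's lifecycle, assuming a packed RGB24 destination and with the caching itself omitted:

#include <cstdint>
extern "C" {
#include <libavutil/frame.h>
#include <libswscale/swscale.h>
}

bool convertToRgb24Sketch(const AVFrame* src, uint8_t* dstRgb, int dstW, int dstH) {
  // The expensive step: this is the object the interface caches and only
  // rebuilds when the input/output geometry changes.
  SwsContext* ctx = sws_getContext(
      src->width, src->height, static_cast<AVPixelFormat>(src->format),
      dstW, dstH, AV_PIX_FMT_RGB24,
      SWS_BILINEAR, nullptr, nullptr, nullptr);
  if (ctx == nullptr) {
    return false;
  }
  uint8_t* dstPlanes[4] = {dstRgb, nullptr, nullptr, nullptr};
  int dstStrides[4] = {dstW * 3, 0, 0, 0}; // packed RGB24: 3 bytes per pixel
  int outHeight = sws_scale(
      ctx, src->data, src->linesize, 0, src->height, dstPlanes, dstStrides);
  sws_freeContext(ctx);
  return outHeight == dstH;
}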
+ SwsFrameContext swsFrameContext( + avFrame->width, + avFrame->height, + frameFormat, + outputDims.width, + outputDims.height); + + if (!swsContext_ || prevSwsFrameContext_ != swsFrameContext) { + createSwsContext(swsFrameContext, avFrame->colorspace); + prevSwsFrameContext_ = swsFrameContext; + } + uint8_t* pointers[4] = { outputTensor.data_ptr(), nullptr, nullptr, nullptr}; int expectedOutputWidth = outputTensor.sizes()[1]; @@ -226,22 +272,6 @@ int CpuDeviceInterface::convertAVFrameToTensorUsingSwsScale( return resultHeight; } -torch::Tensor CpuDeviceInterface::toTensor(const UniqueAVFrame& avFrame) { - TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24); - - auto frameDims = getHeightAndWidthFromResizedAVFrame(*avFrame.get()); - int height = frameDims.height; - int width = frameDims.width; - std::vector shape = {height, width, 3}; - std::vector strides = {avFrame->linesize[0], 3, 1}; - AVFrame* avFrameClone = av_frame_clone(avFrame.get()); - auto deleter = [avFrameClone](void*) { - UniqueAVFrame avFrameToDelete(avFrameClone); - }; - return torch::from_blob( - avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8}); -} - void CpuDeviceInterface::createSwsContext( const SwsFrameContext& swsFrameContext, const enum AVColorSpace colorspace) { @@ -252,7 +282,7 @@ void CpuDeviceInterface::createSwsContext( swsFrameContext.outputWidth, swsFrameContext.outputHeight, AV_PIX_FMT_RGB24, - SWS_BILINEAR, + swsFlags_, nullptr, nullptr, nullptr); @@ -287,4 +317,29 @@ void CpuDeviceInterface::createSwsContext( swsContext_.reset(swsContext); } +torch::Tensor CpuDeviceInterface::convertAVFrameToTensorUsingFilterGraph( + const UniqueAVFrame& avFrame, + const FrameDims& outputDims) { + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); + + FiltersContext filtersContext( + avFrame->width, + avFrame->height, + frameFormat, + avFrame->sample_aspect_ratio, + outputDims.width, + outputDims.height, + AV_PIX_FMT_RGB24, + filters_, + timeBase_); + + if (!filterGraph_ || prevFiltersContext_ != filtersContext) { + filterGraph_ = + std::make_unique(filtersContext, videoStreamOptions_); + prevFiltersContext_ = std::move(filtersContext); + } + return rgbAVFrameToTensor(filterGraph_->convert(avFrame)); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 347b738a1..305b5ae14 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -23,23 +23,31 @@ class CpuDeviceInterface : public DeviceInterface { return std::nullopt; } - void initializeContext( - [[maybe_unused]] AVCodecContext* codecContext) override {} + virtual void initialize(const AVStream* avStream) override; - void convertAVFrameToFrameOutput( + virtual void initializeVideo( const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, + const std::vector>& transforms, + const std::optional& resizedOutputDims) override; + + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) override; private: - int convertAVFrameToTensorUsingSwsScale( + int convertAVFrameToTensorUsingSwScale( + const UniqueAVFrame& avFrame, + torch::Tensor& outputTensor, + const FrameDims& outputDims); + + torch::Tensor convertAVFrameToTensorUsingFilterGraph( const UniqueAVFrame& avFrame, - torch::Tensor& outputTensor); + const FrameDims& outputDims); - torch::Tensor toTensor(const UniqueAVFrame& avFrame); + ColorConversionLibrary 
getColorConversionLibrary(
+      const FrameDims& outputDims) const;
 
   struct SwsFrameContext {
     int inputWidth = 0;
@@ -63,15 +71,60 @@ class CpuDeviceInterface : public DeviceInterface {
       const SwsFrameContext& swsFrameContext,
       const enum AVColorSpace colorspace);
 
-  // color-conversion fields. Only one of FilterGraphContext and
-  // UniqueSwsContext should be non-null.
-  std::unique_ptr<FilterGraph> filterGraphContext_;
+  VideoStreamOptions videoStreamOptions_;
+  AVRational timeBase_;
+
+  // If the resized output dimensions are present, then we always use those as
+  // the output frame's dimensions. If they are not present, then we use the
+  // dimensions of the raw decoded frame. Note that we do not know the
+  // dimensions of the raw decoded frame until very late; we learn it in
+  // convertAVFrameToFrameOutput(). Deciding the final output frame's actual
+  // dimensions late allows us to handle video streams with variable
+  // resolutions.
+  std::optional<FrameDims> resizedOutputDims_;
+
+  // Color-conversion objects. Only one of filterGraph_ and swsContext_ should
+  // be non-null. Which one we use is determined dynamically in
+  // getColorConversionLibrary() each time we decode a frame.
+  //
+  // Creating both filterGraph_ and swsContext_ is relatively expensive, so we
+  // reuse them across frames. However, it is possible that subsequent frames
+  // are different enough (change in dimensions) that we can't reuse the color
+  // conversion object. We store the relevant frame context from the frame used
+  // to create the object last time. We always compare the current frame's info
+  // against the previous one to determine if we need to recreate the color
+  // conversion object.
+  //
+  // TODO: The names of these fields are confusing, as the actual color
+  //       conversion object for Sws has "context" in the name, and we use
+  //       "context" for the structs we store to know if we need to recreate a
+  //       color conversion object. We should clean that up.
+  std::unique_ptr<FilterGraph> filterGraph_;
+  FiltersContext prevFiltersContext_;
   UniqueSwsContext swsContext_;
-
-  // Used to know whether a new FilterGraphContext or UniqueSwsContext should
-  // be created before decoding a new frame.
   SwsFrameContext prevSwsFrameContext_;
-  FiltersContext prevFiltersContext_;
+
+  // The filter we supply to filterGraph_, if it is used. The default is the
+  // copy filter, which just copies the input to the output. Computationally,
+  // it should be a no-op. If we get no user-provided transforms, we will use
+  // the copy filter. Otherwise, we will construct the string from the
+  // transforms.
+  //
+  // Note that even if we only use the copy filter, we still get the desired
+  // colorspace conversion. We construct the filtergraph with its output sink
+  // set to RGB24.
+  std::string filters_ = "copy";
+
+  // The flags we supply to swsContext_, if it is used. The flags control the
+  // resizing algorithm. We default to bilinear. Users can override this with a
+  // ResizeTransform.
+  int swsFlags_ = SWS_BILINEAR;
+
+  // Values set during initialization and referred to in
+  // getColorConversionLibrary().
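A small, self-contained sketch of how a transform list is expected to collapse into the single filtergraph description stored in filters_ above. The fragment strings are hypothetical stand-ins for Transform::getFilterGraphCpu() output; the empty case falls back to the "copy" filter, with the RGB24 conversion coming from the filtergraph's output sink rather than the filter string itself:

#include <string>
#include <vector>

std::string joinFilterFragments(const std::vector<std::string>& fragments) {
  if (fragments.empty()) {
    return "copy"; // no transforms: no-op filter, RGB24 still comes from the sink
  }
  std::string joined;
  for (const auto& fragment : fragments) {
    if (!joined.empty()) {
      joined += ",";
    }
    joined += fragment;
  }
  return joined;
}

// joinFilterFragments({"scale=640:480:sws_flags=bilinear", "hflip"})
//   == "scale=640:480:sws_flags=bilinear,hflip"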
+ bool areTransformsSwScaleCompatible_; + bool userRequestedSwScale_; + + bool initialized_ = false; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 5629686b4..c7f02185a 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -184,6 +184,14 @@ CudaDeviceInterface::CudaDeviceInterface(const torch::Device& device) TORCH_CHECK(g_cuda, "CudaDeviceInterface was not registered!"); TORCH_CHECK( device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + + // It is important for pytorch itself to create the cuda context. If ffmpeg + // creates the context it may not be compatible with pytorch. + // This is a dummy tensor to initialize the cuda context. + torch::Tensor dummyTensorForCudaInitialization = torch::empty( + {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); + ctx_ = getCudaContext(device_); + nppCtx_ = getNppStreamContext(device_); } CudaDeviceInterface::~CudaDeviceInterface() { @@ -195,57 +203,67 @@ CudaDeviceInterface::~CudaDeviceInterface() { } } -void CudaDeviceInterface::initializeContext(AVCodecContext* codecContext) { - TORCH_CHECK(!ctx_, "FFmpeg HW device context already initialized"); +void CudaDeviceInterface::initialize(const AVStream* avStream) { + TORCH_CHECK(avStream != nullptr, "avStream is null"); + timeBase_ = avStream->time_base; - // It is important for pytorch itself to create the cuda context. If ffmpeg - // creates the context it may not be compatible with pytorch. - // This is a dummy tensor to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::empty( - {1}, torch::TensorOptions().dtype(torch::kUInt8).device(device_)); - ctx_ = getCudaContext(device_); - nppCtx_ = getNppStreamContext(device_); - codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); - return; + cpuInterface_ = createDeviceInterface(torch::kCPU); + TORCH_CHECK( + cpuInterface_ != nullptr, "Failed to create CPU device interface"); + cpuInterface_->initialize(avStream); + cpuInterface_->initializeVideo( + VideoStreamOptions(), + {}, + /*resizedOutputDims=*/std::nullopt); } -std::unique_ptr CudaDeviceInterface::initializeFiltersContext( +void CudaDeviceInterface::initializeVideo( const VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame, - const AVRational& timeBase) { + [[maybe_unused]] const std::vector>& transforms, + [[maybe_unused]] const std::optional& resizedOutputDims) { + videoStreamOptions_ = videoStreamOptions; +} + +void CudaDeviceInterface::registerHardwareDeviceWithCodec( + AVCodecContext* codecContext) { + TORCH_CHECK(ctx_, "FFmpeg HW device has not been initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + codecContext->hw_device_ctx = av_buffer_ref(ctx_.get()); +} + +UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( + UniqueAVFrame& avFrame) { // We need FFmpeg filters to handle those conversion cases which are not // directly implemented in CUDA or CPU device interface (in case of a // fallback). - enum AVPixelFormat frameFormat = - static_cast(avFrame->format); // Input frame is on CPU, we will just pass it to CPU device interface, so - // skipping filters context as CPU device interface will handle everythong for + // skipping filters context as CPU device interface will handle everything for // us. 
if (avFrame->format != AV_PIX_FMT_CUDA) { - return nullptr; + return std::move(avFrame); } if (avFrame->hw_frames_ctx == nullptr) { // TODONVDEC P2 return early for for beta interface where avFrames don't // have a hw_frames_ctx. We should get rid of this or improve the logic. - return nullptr; + return std::move(avFrame); } auto hwFramesCtx = reinterpret_cast(avFrame->hw_frames_ctx->data); + TORCH_CHECK( + hwFramesCtx != nullptr, + "The AVFrame does not have a hw_frames_ctx. " + "That's unexpected, please report this to the TorchCodec repo."); + AVPixelFormat actualFormat = hwFramesCtx->sw_format; - // NV12 conversion is implemented directly with NPP, no need for filters. + // If the frame is already in NV12 format, we don't need to do anything. if (actualFormat == AV_PIX_FMT_NV12) { - return nullptr; + return std::move(avFrame); } - auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); - int height = frameDims.height; - int width = frameDims.width; - AVPixelFormat outputFormat; std::stringstream filters; @@ -264,100 +282,123 @@ std::unique_ptr CudaDeviceInterface::initializeFiltersContext( "That's unexpected, please report this to the TorchCodec repo."); filters << "hwdownload,format=" << actualFormatName; - filters << ",scale=" << width << ":" << height; - filters << ":sws_flags=bilinear"; } else { // Actual output color format will be set via filter options outputFormat = AV_PIX_FMT_CUDA; - filters << "scale_cuda=" << width << ":" << height; - filters << ":format=nv12:interp_algo=bilinear"; + filters << "scale_cuda=format=nv12:interp_algo=bilinear"; } - return std::make_unique( + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); + + auto newContext = std::make_unique( avFrame->width, avFrame->height, frameFormat, avFrame->sample_aspect_ratio, - width, - height, + avFrame->width, + avFrame->height, outputFormat, filters.str(), - timeBase, + timeBase_, av_buffer_ref(avFrame->hw_frames_ctx)); + + if (!nv12Conversion_ || *nv12ConversionContext_ != *newContext) { + nv12Conversion_ = + std::make_unique(*newContext, videoStreamOptions_); + nv12ConversionContext_ = std::move(newContext); + } + auto filteredAVFrame = nv12Conversion_->convert(avFrame); + + // If this check fails it means the frame wasn't + // reshaped to its expected dimensions by filtergraph. + TORCH_CHECK( + (filteredAVFrame->width == nv12ConversionContext_->outputWidth) && + (filteredAVFrame->height == nv12ConversionContext_->outputHeight), + "Expected frame from filter graph of ", + nv12ConversionContext_->outputWidth, + "x", + nv12ConversionContext_->outputHeight, + ", got ", + filteredAVFrame->width, + "x", + filteredAVFrame->height); + + return filteredAVFrame; } void CudaDeviceInterface::convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - [[maybe_unused]] const AVRational& timeBase, - UniqueAVFrame& avInputFrame, + UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) { - std::unique_ptr newFiltersContext = - initializeFiltersContext(videoStreamOptions, avInputFrame, timeBase); - UniqueAVFrame avFilteredFrame; - if (newFiltersContext) { - // We need to compare the current filter context with our previous filter - // context. If they are different, then we need to re-create a filter - // graph. We create a filter graph late so that we don't have to depend - // on the unreliable metadata in the header. And we sometimes re-create - // it because it's possible for frame resolution to change mid-stream. 
- // Finally, we want to reuse the filter graph as much as possible for - // performance reasons. - if (!filterGraph_ || *filtersContext_ != *newFiltersContext) { - filterGraph_ = - std::make_unique(*newFiltersContext, videoStreamOptions); - filtersContext_ = std::move(newFiltersContext); - } - avFilteredFrame = filterGraph_->convert(avInputFrame); + // Note that CUDA does not yet support transforms, so the only possible + // frame dimensions are the raw decoded frame's dimensions. + auto frameDims = FrameDims(avFrame->height, avFrame->width); - // If this check fails it means the frame wasn't - // reshaped to its expected dimensions by filtergraph. + if (preAllocatedOutputTensor.has_value()) { + auto shape = preAllocatedOutputTensor.value().sizes(); TORCH_CHECK( - (avFilteredFrame->width == filtersContext_->outputWidth) && - (avFilteredFrame->height == filtersContext_->outputHeight), - "Expected frame from filter graph of ", - filtersContext_->outputWidth, - "x", - filtersContext_->outputHeight, - ", got ", - avFilteredFrame->width, + (shape.size() == 3) && (shape[0] == frameDims.height) && + (shape[1] == frameDims.width) && (shape[2] == 3), + "Expected tensor of shape ", + frameDims.height, "x", - avFilteredFrame->height); + frameDims.width, + "x3, got ", + shape); } - UniqueAVFrame& avFrame = (avFilteredFrame) ? avFilteredFrame : avInputFrame; + // All of our CUDA decoding assumes NV12 format. We handle non-NV12 formats by + // converting them to NV12. + avFrame = maybeConvertAVFrameToNV12OrRGB24(avFrame); - // The filtered frame might be on CPU if CPU fallback has happenned on filter - // graph level. For example, that's how we handle color format conversion - // on FFmpeg 4.4 where scale_cuda did not have this supported implemented yet. if (avFrame->format != AV_PIX_FMT_CUDA) { // The frame's format is AV_PIX_FMT_CUDA if and only if its content is on - // the GPU. In this branch, the frame is on the CPU: this is what NVDEC - // gives us if it wasn't able to decode a frame, for whatever reason. - // Typically that happens if the video's encoder isn't supported by NVDEC. - // Below, we choose to convert the frame's color-space using the CPU - // codepath, and send it back to the GPU at the very end. - // TODO: A possibly better solution would be to send the frame to the GPU - // first, and do the color conversion there. - auto cpuDevice = torch::Device(torch::kCPU); - auto cpuInterface = createDeviceInterface(cpuDevice); + // the GPU. In this branch, the frame is on the CPU. There are two possible + // reasons: + // + // 1. During maybeConvertAVFrameToNV12OrRGB24(), we had a non-NV12 format + // frame and we're on FFmpeg 4.4 or earlier. In such cases, we had to + // use CPU filters and we just converted the frame to RGB24. + // 2. This is what NVDEC gave us if it wasn't able to decode a frame, for + // whatever reason. Typically that happens if the video's encoder isn't + // supported by NVDEC. + // + // In both cases, we have a frame on the CPU. We send the frame back to the + // CUDA device when we're done. + + enum AVPixelFormat frameFormat = + static_cast(avFrame->format); FrameOutput cpuFrameOutput; - cpuInterface->convertAVFrameToFrameOutput( - videoStreamOptions, - timeBase, - avFrame, - cpuFrameOutput, - preAllocatedOutputTensor); - - frameOutput.data = cpuFrameOutput.data.to(device_); + if (frameFormat == AV_PIX_FMT_RGB24) { + // Reason 1 above. The frame is already in RGB24, we just need to convert + // it to a tensor. 
+ cpuFrameOutput.data = rgbAVFrameToTensor(avFrame); + } else { + // Reason 2 above. We need to do a full conversion which requires an + // actual CPU device. + cpuInterface_->convertAVFrameToFrameOutput(avFrame, cpuFrameOutput); + } + + // Finally, we need to send the frame back to the GPU. Note that the + // pre-allocated tensor is on the GPU, so we can't send that to the CPU + // device interface. We copy it over here. + if (preAllocatedOutputTensor.has_value()) { + preAllocatedOutputTensor.value().copy_(cpuFrameOutput.data); + frameOutput.data = preAllocatedOutputTensor.value(); + } else { + frameOutput.data = cpuFrameOutput.data.to(device_); + } + return; } // Above we checked that the AVFrame was on GPU, but that's not enough, we // also need to check that the AVFrame is in AV_PIX_FMT_NV12 format (8 bits), - // because this is what the NPP color conversion routines expect. + // because this is what the NPP color conversion routines expect. This SHOULD + // be enforced by our call to maybeConvertAVFrameToNV12OrRGB24() above. // TODONVDEC P2 this can be hit from the beta interface, but there's no // hw_frames_ctx in this case. We should try to understand how that affects // this validation. @@ -376,25 +417,11 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "That's unexpected, please report this to the TorchCodec repo."); } - auto frameDims = - getHeightAndWidthFromOptionsOrAVFrame(videoStreamOptions, avFrame); - int height = frameDims.height; - int width = frameDims.width; torch::Tensor& dst = frameOutput.data; if (preAllocatedOutputTensor.has_value()) { dst = preAllocatedOutputTensor.value(); - auto shape = dst.sizes(); - TORCH_CHECK( - (shape.size() == 3) && (shape[0] == height) && (shape[1] == width) && - (shape[2] == 3), - "Expected tensor of shape ", - height, - "x", - width, - "x3, got ", - shape); } else { - dst = allocateEmptyHWCTensor(height, width, device_); + dst = allocateEmptyHWCTensor(frameDims, device_); } torch::DeviceIndex deviceIndex = getNonNegativeDeviceIndex(device_); @@ -418,6 +445,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "The AVFrame's hw_frames_ctx does not have a device_ctx. "); auto cudaDeviceCtx = static_cast(hwFramesCtx->device_ctx->hwctx); + TORCH_CHECK(cudaDeviceCtx != nullptr, "The hardware context is null"); at::cuda::CUDAEvent nvdecDoneEvent; at::cuda::CUDAStream nvdecStream = // That's always the default stream. Sad. 
c10::cuda::getStreamFromExternal(cudaDeviceCtx->stream, deviceIndex); @@ -435,7 +463,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( "cudaStreamGetFlags failed: ", cudaGetErrorString(err)); - NppiSize oSizeROI = {width, height}; + NppiSize oSizeROI = {frameDims.width, frameDims.height}; Npp8u* yuvData[2] = {avFrame->data[0], avFrame->data[1]}; NppStatus status; diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 678fc2f97..42d517a72 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -20,28 +20,43 @@ class CudaDeviceInterface : public DeviceInterface { std::optional findCodec(const AVCodecID& codecId) override; - void initializeContext(AVCodecContext* codecContext) override; + void initialize(const AVStream* avStream) override; - void convertAVFrameToFrameOutput( + void initializeVideo( const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, + [[maybe_unused]] const std::vector>& + transforms, + [[maybe_unused]] const std::optional& resizedOutputDims) + override; + + void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override; + + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) override; private: - std::unique_ptr initializeFiltersContext( - const VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame, - const AVRational& timeBase); + // Our CUDA decoding code assumes NV12 format. In order to handle other + // kinds of input, we need to convert them to NV12. Our current implementation + // does this using filtergraph. + UniqueAVFrame maybeConvertAVFrameToNV12OrRGB24(UniqueAVFrame& avFrame); + + // We sometimes encounter frames that cannot be decoded on the CUDA device. + // Rather than erroring out, we decode them on the CPU. + std::unique_ptr cpuInterface_; + + VideoStreamOptions videoStreamOptions_; + AVRational timeBase_; UniqueAVBufferRef ctx_; std::unique_ptr nppCtx_; - // Current filter context. Used to know whether a new FilterGraph - // should be created to process the next frame. - std::unique_ptr filtersContext_; - std::unique_ptr filterGraph_; + + // This filtergraph instance is only used for NV12 format conversion in + // maybeConvertAVFrameToNV12(). 
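The convertAVFrameToFrameOutput() code above records an event on NVDEC's stream so that the stream running the NPP color conversion waits for the decoded surface instead of synchronizing the whole device. A condensed sketch of that ordering pattern; the stream and event names here are illustrative, not taken from the diff:

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>

void waitForNvdecSketch(cudaStream_t nvdecRawStream, c10::DeviceIndex deviceIndex) {
  at::cuda::CUDAStream nvdecStream =
      c10::cuda::getStreamFromExternal(nvdecRawStream, deviceIndex);
  at::cuda::CUDAStream conversionStream =
      c10::cuda::getCurrentCUDAStream(deviceIndex);

  at::cuda::CUDAEvent decodeDone;
  decodeDone.record(nvdecStream);      // marks the end of NVDEC's work
  decodeDone.block(conversionStream);  // conversion waits, no device-wide sync
}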
+ std::unique_ptr nv12ConversionContext_; + std::unique_ptr nv12Conversion_; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.cpp b/src/torchcodec/_core/DeviceInterface.cpp index f6c17f7b2..2f910e998 100644 --- a/src/torchcodec/_core/DeviceInterface.cpp +++ b/src/torchcodec/_core/DeviceInterface.cpp @@ -99,4 +99,19 @@ std::unique_ptr createDeviceInterface( "'"); } +torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame) { + TORCH_CHECK_EQ(avFrame->format, AV_PIX_FMT_RGB24); + + int height = avFrame->height; + int width = avFrame->width; + std::vector shape = {height, width, 3}; + std::vector strides = {avFrame->linesize[0], 3, 1}; + AVFrame* avFrameClone = av_frame_clone(avFrame.get()); + auto deleter = [avFrameClone](void*) { + UniqueAVFrame avFrameToDelete(avFrameClone); + }; + return torch::from_blob( + avFrameClone->data[0], shape, strides, deleter, {torch::kUInt8}); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index b5701f8ba..08d94fddc 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -14,6 +14,7 @@ #include "FFMPEGCommon.h" #include "src/torchcodec/_core/Frame.h" #include "src/torchcodec/_core/StreamOptions.h" +#include "src/torchcodec/_core/Transform.h" namespace facebook::torchcodec { @@ -50,16 +51,25 @@ class DeviceInterface { return std::nullopt; }; - // Initialize the hardware device that is specified in `device`. Some builds - // support CUDA and others only support CPU. - virtual void initializeContext( + // Initialize the device with parameters generic to all kinds of decoding. + virtual void initialize(const AVStream* avStream) = 0; + + // Initialize the device with parameters specific to video decoding. There is + // a default empty implementation. + virtual void initializeVideo( + [[maybe_unused]] const VideoStreamOptions& videoStreamOptions, + [[maybe_unused]] const std::vector>& + transforms, + [[maybe_unused]] const std::optional& resizedOutputDims) {} + + // In order for decoding to actually happen on an FFmpeg managed hardware + // device, we need to register the DeviceInterface managed + // AVHardwareDeviceContext with the AVCodecContext. We don't need to do this + // on the CPU and if FFmpeg is not managing the hardware device. + virtual void registerHardwareDeviceWithCodec( [[maybe_unused]] AVCodecContext* codecContext) {} - virtual void initializeInterface([[maybe_unused]] AVStream* stream) {} - virtual void convertAVFrameToFrameOutput( - const VideoStreamOptions& videoStreamOptions, - const AVRational& timeBase, UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; @@ -69,7 +79,11 @@ class DeviceInterface { // ------------------------------------------ // Override to return true if this device interface can decode packets - // directly + // directly. This means that the following two member functions can both + // be called: + // + // 1. sendPacket() + // 2. 
receiveFrame() virtual bool canDecodePacketDirectly() const { return false; } @@ -121,4 +135,6 @@ std::unique_ptr createDeviceInterface( const torch::Device& device, const std::string_view variant = "default"); +torch::Tensor rgbAVFrameToTensor(const UniqueAVFrame& avFrame); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 748ebbc68..9a1f4ee87 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -605,8 +605,8 @@ void VideoEncoder::initializeEncoder( // Use specified dimensions or input dimensions // TODO-VideoEncoder: Allow height and width to be set - outWidth_ = videoStreamOptions.width.value_or(inWidth_); - outHeight_ = videoStreamOptions.height.value_or(inHeight_); + outWidth_ = inWidth_; + outHeight_ = inHeight_; // Use YUV420P as default output format // TODO-VideoEncoder: Enable other pixel formats diff --git a/src/torchcodec/_core/FilterGraph.cpp b/src/torchcodec/_core/FilterGraph.cpp index c22875915..70bc2b5dd 100644 --- a/src/torchcodec/_core/FilterGraph.cpp +++ b/src/torchcodec/_core/FilterGraph.cpp @@ -139,7 +139,8 @@ FilterGraph::FilterGraph( TORCH_CHECK( status >= 0, "Failed to parse filter description: ", - getFFMPEGErrorStringFromErrorCode(status)); + getFFMPEGErrorStringFromErrorCode(status), + ", provided filters: " + filtersContext.filtergraphStr); status = avfilter_graph_config(filterGraph_.get(), nullptr); TORCH_CHECK( diff --git a/src/torchcodec/_core/Frame.cpp b/src/torchcodec/_core/Frame.cpp index bc3bbb788..9fa87a1cb 100644 --- a/src/torchcodec/_core/Frame.cpp +++ b/src/torchcodec/_core/Frame.cpp @@ -8,24 +8,34 @@ namespace facebook::torchcodec { +FrameBatchOutput::FrameBatchOutput( + int64_t numFrames, + const FrameDims& outputDims, + const torch::Device& device) + : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), + durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { + data = allocateEmptyHWCTensor(outputDims, device, numFrames); +} + torch::Tensor allocateEmptyHWCTensor( - int height, - int width, - torch::Device device, + const FrameDims& frameDims, + const torch::Device& device, std::optional numFrames) { auto tensorOptions = torch::TensorOptions() .dtype(torch::kUInt8) .layout(torch::kStrided) .device(device); - TORCH_CHECK(height > 0, "height must be > 0, got: ", height); - TORCH_CHECK(width > 0, "width must be > 0, got: ", width); + TORCH_CHECK( + frameDims.height > 0, "height must be > 0, got: ", frameDims.height); + TORCH_CHECK(frameDims.width > 0, "width must be > 0, got: ", frameDims.width); if (numFrames.has_value()) { auto numFramesValue = numFrames.value(); TORCH_CHECK( numFramesValue >= 0, "numFrames must be >= 0, got: ", numFramesValue); - return torch::empty({numFramesValue, height, width, 3}, tensorOptions); + return torch::empty( + {numFramesValue, frameDims.height, frameDims.width, 3}, tensorOptions); } else { - return torch::empty({height, width, 3}, tensorOptions); + return torch::empty({frameDims.height, frameDims.width, 3}, tensorOptions); } } diff --git a/src/torchcodec/_core/Frame.h b/src/torchcodec/_core/Frame.h index 84ccc7288..4b27d5bdd 100644 --- a/src/torchcodec/_core/Frame.h +++ b/src/torchcodec/_core/Frame.h @@ -13,6 +13,15 @@ namespace facebook::torchcodec { +struct FrameDims { + int height = 0; + int width = 0; + + FrameDims() = default; + + FrameDims(int h, int w) : height(h), width(w) {} +}; + // All public video decoding entry points return either a FrameOutput or a // FrameBatchOutput. 
// They are the equivalent of the user-facing Frame and FrameBatch classes in @@ -34,10 +43,10 @@ struct FrameBatchOutput { torch::Tensor ptsSeconds; // 1D of shape (N,) torch::Tensor durationSeconds; // 1D of shape (N,) - explicit FrameBatchOutput( + FrameBatchOutput( int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata); + const FrameDims& outputDims, + const torch::Device& device); }; struct AudioFramesOutput { @@ -49,70 +58,15 @@ struct AudioFramesOutput { // FRAME TENSOR ALLOCATION APIs // -------------------------------------------------------------------------- -// Note [Frame Tensor allocation and height and width] +// Note [Frame Tensor allocation] // // We always allocate [N]HWC tensors. The low-level decoding functions all // assume HWC tensors, since this is what FFmpeg natively handles. It's up to // the high-level decoding entry-points to permute that back to CHW, by calling // maybePermuteHWC2CHW(). -// -// Also, importantly, the way we figure out the the height and width of the -// output frame tensor varies, and depends on the decoding entry-point. In -// *decreasing order of accuracy*, we use the following sources for determining -// height and width: -// - getHeightAndWidthFromResizedAVFrame(). This is the height and width of the -// AVframe, *post*-resizing. This is only used for single-frame decoding APIs, -// on CPU, with filtergraph. -// - getHeightAndWidthFromOptionsOrAVFrame(). This is the height and width from -// the user-specified options if they exist, or the height and width of the -// AVFrame *before* it is resized. In theory, i.e. if there are no bugs within -// our code or within FFmpeg code, this should be exactly the same as -// getHeightAndWidthFromResizedAVFrame(). This is used by single-frame -// decoding APIs, on CPU with swscale, and on GPU. -// - getHeightAndWidthFromOptionsOrMetadata(). This is the height and width from -// the user-specified options if they exist, or the height and width form the -// stream metadata, which itself got its value from the CodecContext, when the -// stream was added. This is used by batch decoding APIs, for both GPU and -// CPU. -// -// The source of truth for height and width really is the (resized) AVFrame: it -// comes from the decoded ouptut of FFmpeg. The info from the metadata (i.e. -// from the CodecContext) may not be as accurate. However, the AVFrame is only -// available late in the call stack, when the frame is decoded, while the -// CodecContext is available early when a stream is added. This is why we use -// the CodecContext for pre-allocating batched output tensors (we could -// pre-allocate those only once we decode the first frame to get the info frame -// the AVFrame, but that's a more complex logic). -// -// Because the sources for height and width may disagree, we may end up with -// conflicts: e.g. if we pre-allocate a batch output tensor based on the -// metadata info, but the decoded AVFrame has a different height and width. -// it is very important to check the height and width assumptions where the -// tensors memory is used/filled in order to avoid segfaults. - -struct FrameDims { - int height; - int width; - - FrameDims(int h, int w) : height(h), width(w) {} -}; - -// There's nothing preventing you from calling this on a non-resized frame, but -// please don't. 
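A short usage sketch of the reworked allocator (Frame.cpp above, Frame.h declaration just below): FrameDims travels as a single value, and the optional trailing argument turns the same helper into a batch allocator, which is how FrameBatchOutput uses it. The include path is the one used elsewhere in this diff.

#include "src/torchcodec/_core/Frame.h"

torch::Tensor allocationSketch(const torch::Device& device) {
  FrameDims dims(/*h=*/480, /*w=*/640);
  // Single frame: shape {480, 640, 3}, uint8, HWC layout.
  torch::Tensor frame = allocateEmptyHWCTensor(dims, device);
  // Batch of 8 frames: shape {8, 480, 640, 3}; same helper, extra leading dim.
  torch::Tensor batch = allocateEmptyHWCTensor(dims, device, /*numFrames=*/8);
  batch[0].copy_(frame); // a single frame drops into a batch slot unchanged
  return batch;
}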
-FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame);
-
-FrameDims getHeightAndWidthFromOptionsOrMetadata(
-    const VideoStreamOptions& videoStreamOptions,
-    const StreamMetadata& streamMetadata);
-
-FrameDims getHeightAndWidthFromOptionsOrAVFrame(
-    const VideoStreamOptions& videoStreamOptions,
-    const UniqueAVFrame& avFrame);
-
 torch::Tensor allocateEmptyHWCTensor(
-    int height,
-    int width,
-    torch::Device device,
+    const FrameDims& frameDims,
+    const torch::Device& device,
     std::optional<int64_t> numFrames = std::nullopt);
 
 } // namespace facebook::torchcodec
diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h
index f9ca85e67..ace6cf84c 100644
--- a/src/torchcodec/_core/Metadata.h
+++ b/src/torchcodec/_core/Metadata.h
@@ -44,8 +44,8 @@ struct StreamMetadata {
   std::optional<int64_t> numFramesFromContent;
 
   // Video-only fields derived from the AVCodecContext.
-  std::optional<int64_t> width;
-  std::optional<int64_t> height;
+  std::optional<int> width;
+  std::optional<int> height;
   std::optional<AVRational> sampleAspectRatio;
 
   // Audio-only fields
diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index fcb1d2d1b..62be222f6 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -428,8 +428,6 @@ void SingleStreamDecoder::addStream(
   streamInfo.stream = formatContext_->streams[activeStreamIndex_];
   streamInfo.avMediaType = mediaType;
 
-  deviceInterface_ = createDeviceInterface(device, deviceVariant);
-
   // This should never happen, checking just to be safe.
   TORCH_CHECK(
       streamInfo.stream->codecpar->codec_type == mediaType,
@@ -437,14 +435,18 @@
       activeStreamIndex_,
       " which is of the wrong media type.");
 
+  deviceInterface_ = createDeviceInterface(device, deviceVariant);
+  TORCH_CHECK(
+      deviceInterface_ != nullptr,
+      "Failed to create device interface. This should never happen, please report.");
+  deviceInterface_->initialize(streamInfo.stream);
+
   // TODO_CODE_QUALITY it's pretty meh to have a video-specific logic within
   // addStream() which is supposed to be generic
   if (mediaType == AVMEDIA_TYPE_VIDEO) {
-    if (deviceInterface_) {
-      avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
-          deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
-              .value_or(avCodec));
-    }
+    avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream(
+        deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id)
+            .value_or(avCodec));
   }
 
   AVCodecContext* codecContext = avcodec_alloc_context3(avCodec);
@@ -458,14 +460,10 @@
   streamInfo.codecContext->thread_count = ffmpegThreadCount.value_or(0);
   streamInfo.codecContext->pkt_timebase = streamInfo.stream->time_base;
 
-  // TODO_CODE_QUALITY same as above.
-  if (mediaType == AVMEDIA_TYPE_VIDEO) {
-    if (deviceInterface_) {
-      deviceInterface_->initializeContext(codecContext);
-      deviceInterface_->initializeInterface(streamInfo.stream);
-    }
-  }
-
+  // Note that we must make sure to register the hardware device context
+  // with the codec context before calling avcodec_open2(). Otherwise, decoding
+  // will happen on the CPU and not the hardware device.
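Pulled together from addStream()/addVideoStream() in this diff, a condensed sketch of the per-stream setup order the refactor establishes; the key constraint is the one stated above, that the hardware device context is registered before avcodec_open2(). It assumes the project's DeviceInterface.h pulls in the FFmpeg headers (as FFMPEGCommon.h does); error handling and unrelated details are omitted.

#include <optional>
#include "src/torchcodec/_core/DeviceInterface.h"

void perStreamSetupSketch(
    DeviceInterface& deviceInterface,
    const AVStream* avStream,
    AVCodecContext* codecContext,
    const AVCodec* avCodec) {
  // 1. Generic, media-type agnostic initialization (time base, etc.).
  deviceInterface.initialize(avStream);

  // 2. Hand the FFmpeg-managed hardware device context to the codec. This has
  //    to happen before avcodec_open2(); it is a no-op for the CPU interface.
  deviceInterface.registerHardwareDeviceWithCodec(codecContext);
  int status = avcodec_open2(codecContext, avCodec, nullptr);
  (void)status; // the real code checks this with TORCH_CHECK

  // 3. Video-specific initialization; in the real code this happens later, in
  //    addVideoStream(), once the transforms are known.
  deviceInterface.initializeVideo(
      VideoStreamOptions(), {}, /*resizedOutputDims=*/std::nullopt);
}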
+ deviceInterface_->registerHardwareDeviceWithCodec(codecContext); retVal = avcodec_open2(streamInfo.codecContext.get(), avCodec, nullptr); TORCH_CHECK(retVal >= AVSUCCESS, getFFMPEGErrorStringFromErrorCode(retVal)); @@ -487,8 +485,13 @@ void SingleStreamDecoder::addStream( void SingleStreamDecoder::addVideoStream( int streamIndex, + std::vector& transforms, const VideoStreamOptions& videoStreamOptions, std::optional customFrameMappings) { + TORCH_CHECK( + transforms.empty() || videoStreamOptions.device == torch::kCPU, + " Transforms are only supported for CPU devices."); + addStream( streamIndex, AVMEDIA_TYPE_VIDEO, @@ -522,6 +525,22 @@ void SingleStreamDecoder::addVideoStream( readCustomFrameMappingsUpdateMetadataAndIndex( activeStreamIndex_, customFrameMappings.value()); } + + metadataDims_ = + FrameDims(streamMetadata.height.value(), streamMetadata.width.value()); + for (auto& transform : transforms) { + TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); + if (transform->getOutputFrameDims().has_value()) { + resizedOutputDims_ = transform->getOutputFrameDims().value(); + } + + // Note that we are claiming ownership of the transform objects passed in to + // us. + transforms_.push_back(std::unique_ptr(transform)); + } + + deviceInterface_->initializeVideo( + videoStreamOptions, transforms_, resizedOutputDims_); } void SingleStreamDecoder::addAudioStream( @@ -637,12 +656,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( }); } - const auto& streamMetadata = - containerMetadata_.allStreamMetadata[activeStreamIndex_]; const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - frameIndices.numel(), videoStreamOptions, streamMetadata); + frameIndices.numel(), + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); auto previousIndexInVideo = -1; for (int64_t f = 0; f < frameIndices.numel(); ++f) { @@ -685,8 +704,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); - // Note that if we do not have the number of frames available in our metadata, - // then we assume that the upper part of the range is valid. + // Note that if we do not have the number of frames available in our + // metadata, then we assume that the upper part of the range is valid. 
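A hedged caller-side sketch of the new addVideoStream() signature introduced above. The ResizeTransform constructor shown is hypothetical (the real API lives in Transform.h, which is not visible in this diff), and the include paths are assumed from the repo layout; the point being illustrated is that the decoder claims ownership of the raw Transform pointers.

#include <optional>
#include <vector>
#include "src/torchcodec/_core/SingleStreamDecoder.h"
#include "src/torchcodec/_core/Transform.h"

void addResizedVideoStreamSketch(SingleStreamDecoder& decoder, int streamIndex) {
  std::vector<Transform*> transforms;
  // Hypothetical constructor: the real ResizeTransform may take its target
  // dimensions differently.
  transforms.push_back(new ResizeTransform(FrameDims(/*h=*/240, /*w=*/320)));

  decoder.addVideoStream(
      streamIndex,
      transforms,
      VideoStreamOptions(),
      /*customFrameMappings=*/std::nullopt);
  // Do not delete the pointers afterwards: addVideoStream() wraps each one in
  // a std::unique_ptr<Transform> and owns it from here on.
}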
std::optional numFrames = getNumFrames(streamMetadata); if (numFrames.has_value()) { TORCH_CHECK( @@ -699,7 +718,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - numOutputFrames, videoStreamOptions, streamMetadata); + numOutputFrames, + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); for (int64_t i = start, f = 0; i < stop; i += step, ++f) { FrameOutput frameOutput = @@ -715,9 +736,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { validateActiveStream(AVMEDIA_TYPE_VIDEO); StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; double lastDecodedStartTime = - ptsToSeconds(streamInfo.lastDecodedAvFramePts, streamInfo.timeBase); + ptsToSeconds(lastDecodedAvFramePts_, streamInfo.timeBase); double lastDecodedEndTime = ptsToSeconds( - streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration, + lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_, streamInfo.timeBase); if (seconds >= lastDecodedStartTime && seconds < lastDecodedEndTime) { // We are in the same frame as the one we just returned. However, since we @@ -737,9 +758,9 @@ FrameOutput SingleStreamDecoder::getFramePlayedAt(double seconds) { // FFMPEG seeked past the frame we are looking for even though we // set max_ts to be our needed timestamp in avformat_seek_file() // in maybeSeekToBeforeDesiredPts(). - // This could be a bug in FFMPEG: https://trac.ffmpeg.org/ticket/11137 - // In this case we return the very next frame instead of throwing an - // exception. + // This could be a bug in FFMPEG: + // https://trac.ffmpeg.org/ticket/11137 In this case we return the + // very next frame instead of throwing an exception. // TODO: Maybe log to stderr for Debug builds? return true; } @@ -820,13 +841,16 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // interval B: [0.2, 0.15) // // Both intervals take place between the pts values for frame 0 and frame 1, - // which by our abstract player, means that both intervals map to frame 0. By - // the definition of a half open interval, interval A should return no frames. - // Interval B should return frame 0. However, for both A and B, the individual - // values of the intervals will map to the same frame indices below. Hence, we - // need this special case below. + // which by our abstract player, means that both intervals map to frame 0. + // By the definition of a half open interval, interval A should return no + // frames. Interval B should return frame 0. However, for both A and B, the + // individual values of the intervals will map to the same frame indices + // below. Hence, we need this special case below. if (startSeconds == stopSeconds) { - FrameBatchOutput frameBatchOutput(0, videoStreamOptions, streamMetadata); + FrameBatchOutput frameBatchOutput( + 0, + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); frameBatchOutput.data = maybePermuteHWC2CHW(frameBatchOutput.data); return frameBatchOutput; } @@ -838,8 +862,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( "; must be greater than or equal to " + std::to_string(minSeconds) + "."); - // Note that if we can't determine the maximum seconds from the metadata, then - // we assume upper range is valid. + // Note that if we can't determine the maximum seconds from the metadata, + // then we assume upper range is valid. 
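A tiny worked example of the frame-count arithmetic used by getFramesInRange() above: the range is half-open, so start=0, stop=10, step=3 decodes indices 0, 3, 6, 9.

#include <cmath>
#include <cstdint>

int64_t numOutputFramesSketch(int64_t start, int64_t stop, int64_t step) {
  // Half-open range [start, stop): stop itself is never decoded.
  return static_cast<int64_t>(std::ceil((stop - start) / double(step)));
}
// numOutputFramesSketch(0, 10, 3) == 4, matching decoded indices {0, 3, 6, 9}.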
std::optional maxSeconds = getMaxSeconds(streamMetadata); if (maxSeconds.has_value()) { TORCH_CHECK( @@ -871,7 +895,9 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( int64_t numFrames = stopFrameIndex - startFrameIndex; FrameBatchOutput frameBatchOutput( - numFrames, videoStreamOptions, streamMetadata); + numFrames, + resizedOutputDims_.value_or(metadataDims_), + videoStreamOptions.device); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { FrameOutput frameOutput = getFrameAtIndexInternal(i, frameBatchOutput.data[f]); @@ -892,25 +918,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // `numChannels` values. An audio frame, or a sequence thereof, is always // converted into a tensor of shape `(numChannels, numSamplesPerChannel)`. // -// The notion of 'frame' in audio isn't what users want to interact with. Users -// want to interact with samples. The C++ and core APIs return frames, because -// we want those to be close to FFmpeg concepts, but the higher-level public -// APIs expose samples. As a result: +// The notion of 'frame' in audio isn't what users want to interact with. +// Users want to interact with samples. The C++ and core APIs return frames, +// because we want those to be close to FFmpeg concepts, but the higher-level +// public APIs expose samples. As a result: // - We don't expose index-based APIs for audio, because that would mean -// exposing the concept of audio frame. For now, we think exposing time-based -// APIs is more natural. -// - We never perform a scan for audio streams. We don't need to, since we won't +// exposing the concept of audio frame. For now, we think exposing +// time-based APIs is more natural. +// - We never perform a scan for audio streams. We don't need to, since we +// won't // be converting timestamps to indices. That's why we enforce the seek_mode -// to be "approximate" (which is slightly misleading, because technically the -// output samples will be at their exact positions. But this incongruence is -// only exposed at the C++/core private levels). +// to be "approximate" (which is slightly misleading, because technically +// the output samples will be at their exact positions. But this +// incongruence is only exposed at the C++/core private levels). // // Audio frames are of variable dimensions: in the same stream, a frame can // contain 1024 samples and the next one may contain 512 [1]. This makes it // impossible to stack audio frames in the same way we can stack video frames. -// This is one of the main reasons we cannot reuse the same pre-allocation logic -// we have for videos in getFramesPlayedInRange(): pre-allocating a batch -// requires constant (and known) frame dimensions. That's also why +// This is one of the main reasons we cannot reuse the same pre-allocation +// logic we have for videos in getFramesPlayedInRange(): pre-allocating a +// batch requires constant (and known) frame dimensions. That's also why // *concatenated* along the samples dimension, not stacked. // // [IMPORTANT!] There is one key invariant that we must respect when decoding @@ -918,10 +945,10 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // // BEFORE DECODING FRAME i, WE MUST DECODE ALL FRAMES j < i. // -// Always. Why? We don't know. What we know is that if we don't, we get clipped, -// incorrect audio as output [2]. All other (correct) libraries like TorchAudio -// or Decord do something similar, whether it was intended or not. This has a -// few implications: +// Always. Why? 
We don't know. What we know is that if we don't, we get +// clipped, incorrect audio as output [2]. All other (correct) libraries like +// TorchAudio or Decord do something similar, whether it was intended or not. +// This has a few implications: // - The **only** place we're allowed to seek to in an audio stream is the // stream's beginning. This ensures that if we need a frame, we'll have // decoded all previous frames. @@ -929,8 +956,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( // call next() and `getFramesPlayedInRangeAudio()`, but they cannot manually // seek. // - We try not to seek, when we can avoid it. Typically if the next frame we -// need is in the future, we don't seek back to the beginning, we just decode -// all the frames in-between. +// need is in the future, we don't seek back to the beginning, we just +// decode all the frames in-between. // // [2] If you're brave and curious, you can read the long "Seek offset for // audio" note in https://github.com/pytorch/torchcodec/pull/507/files, which @@ -957,11 +984,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( } auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); - if (startPts < streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration) { - // If we need to seek backwards, then we have to seek back to the beginning - // of the stream. - // See [Audio Decoding Design]. + if (startPts < lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_) { + // If we need to seek backwards, then we have to seek back to the + // beginning of the stream. See [Audio Decoding Design]. setCursor(INT64_MIN); } @@ -995,9 +1020,9 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( // stop decoding more frames. Note that if we were to use [begin, end), // which may seem more natural, then we would decode the frame starting at // stopSeconds, which isn't what we want! - auto lastDecodedAvFrameEnd = streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration; - finished |= (streamInfo.lastDecodedAvFramePts) <= stopPts && + auto lastDecodedAvFrameEnd = + lastDecodedAvFramePts_ + lastDecodedAvFrameDuration_; + finished |= (lastDecodedAvFramePts_) <= stopPts && (stopPts <= lastDecodedAvFrameEnd); } @@ -1064,18 +1089,16 @@ I P P P I P P P I P P I P P I P bool SingleStreamDecoder::canWeAvoidSeeking() const { const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - // For audio, we only need to seek if a backwards seek was requested within - // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called. - // For more context, see [Audio Decoding Design] + // For audio, we only need to seek if a backwards seek was requested + // within getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was + // called. For more context, see [Audio Decoding Design] return !cursorWasJustSet_; } - int64_t lastDecodedAvFramePts = - streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; - if (cursor_ < lastDecodedAvFramePts) { + if (cursor_ < lastDecodedAvFramePts_) { // We can never skip a seek if we are seeking backwards. return false; } - if (lastDecodedAvFramePts == cursor_) { + if (lastDecodedAvFramePts_ == cursor_) { // We are seeking to the exact same frame as we are currently at. Without // caching we have to rewind back and decode the frame again. 
// TODO: https://github.com/pytorch/torchcodec/issues/84 we could @@ -1085,7 +1108,7 @@ bool SingleStreamDecoder::canWeAvoidSeeking() const { // We are seeking forwards. // We can only skip a seek if both lastDecodedAvFramePts and // cursor_ share the same keyframe. - int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts); + int lastDecodedAvFrameIndex = getKeyFrameIndexForPts(lastDecodedAvFramePts_); int targetKeyFrameIndex = getKeyFrameIndexForPts(cursor_); return lastDecodedAvFrameIndex >= 0 && targetKeyFrameIndex >= 0 && lastDecodedAvFrameIndex == targetKeyFrameIndex; @@ -1135,9 +1158,7 @@ void SingleStreamDecoder::maybeSeekToBeforeDesiredPts() { decodeStats_.numFlushes++; avcodec_flush_buffers(streamInfo.codecContext.get()); - if (deviceInterface_) { - deviceInterface_->flush(); - } + deviceInterface_->flush(); } // -------------------------------------------------------------------------- @@ -1161,16 +1182,13 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( int status = AVSUCCESS; bool reachedEOF = false; - // TODONVDEC P2: Instead of defining useCustomInterface and rely on if/else - // blocks to dispatch to the interface or to FFmpeg, consider *always* + // TODONVDEC P2: Instead of calling canDecodePacketDirectly() and rely on + // if/else blocks to dispatch to the interface or to FFmpeg, consider *always* // dispatching to the interface. The default implementation of the interface's // receiveFrame and sendPacket could just be calling avcodec_receive_frame and // avcodec_send_packet. This would make the decoding loop even more generic. - bool useCustomInterface = - deviceInterface_ && deviceInterface_->canDecodePacketDirectly(); - while (true) { - if (useCustomInterface) { + if (deviceInterface_->canDecodePacketDirectly()) { status = deviceInterface_->receiveFrame(avFrame, cursor_); } else { status = @@ -1211,7 +1229,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( if (status == AVERROR_EOF) { // End of file reached. We must drain the decoder - if (useCustomInterface) { + if (deviceInterface_->canDecodePacketDirectly()) { // TODONVDEC P0: Re-think this. This should be simpler. AutoAVPacket eofAutoPacket; ReferenceAVPacket eofPacket(eofAutoPacket); @@ -1247,7 +1265,7 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( // We got a valid packet. Send it to the decoder, and we'll receive it in // the next iteration. - if (useCustomInterface) { + if (deviceInterface_->canDecodePacketDirectly()) { status = deviceInterface_->sendPacket(packet); } else { status = avcodec_send_packet(streamInfo.codecContext.get(), packet.get()); @@ -1272,14 +1290,15 @@ UniqueAVFrame SingleStreamDecoder::decodeAVFrame( getFFMPEGErrorStringFromErrorCode(status)); } - // Note that we don't flush the decoder when we reach EOF (even though that's - // mentioned in https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). - // This is because we may have packets internally in the decoder that we - // haven't received as frames. Eventually we will either hit AVERROR_EOF from - // av_receive_frame() or the user will have seeked to a different location in - // the file and that will flush the decoder. - streamInfo.lastDecodedAvFramePts = getPtsOrDts(avFrame); - streamInfo.lastDecodedAvFrameDuration = getDuration(avFrame); + // Note that we don't flush the decoder when we reach EOF (even though + // that's mentioned in + // https://ffmpeg.org/doxygen/trunk/group__lavc__encdec.html). 
This is + // because we may have packets internally in the decoder that we haven't + // received as frames. Eventually we will either hit AVERROR_EOF from + // av_receive_frame() or the user will have seeked to a different location + // in the file and that will flush the decoder. + lastDecodedAvFramePts_ = getPtsOrDts(avFrame); + lastDecodedAvFrameDuration_ = getDuration(avFrame); return avFrame; } @@ -1303,16 +1322,8 @@ FrameOutput SingleStreamDecoder::convertAVFrameToFrameOutput( if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { convertAudioAVFrameToFrameOutputOnCPU(avFrame, frameOutput); } else { - TORCH_CHECK( - deviceInterface_ != nullptr, - "No device interface available for video decoding. This ", - "shouldn't happen, please report."); deviceInterface_->convertAVFrameToFrameOutput( - streamInfo.videoStreamOptions, - streamInfo.timeBase, - avFrame, - frameOutput, - preAllocatedOutputTensor); + avFrame, frameOutput, preAllocatedOutputTensor); } return frameOutput; } @@ -1348,8 +1359,8 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( UniqueAVFrame convertedAVFrame; if (mustConvert) { - if (!streamInfo.swrContext) { - streamInfo.swrContext.reset(createSwrContext( + if (!swrContext_) { + swrContext_.reset(createSwrContext( srcSampleFormat, outSampleFormat, srcSampleRate, @@ -1359,7 +1370,7 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( } convertedAVFrame = convertAudioAVFrameSamples( - streamInfo.swrContext, + swrContext_, srcAVFrame, outSampleFormat, outSampleRate, @@ -1407,15 +1418,15 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // When sample rate conversion is involved, swresample buffers some of the // samples in-between calls to swr_convert (see the libswresample docs). - // That's because the last few samples in a given frame require future samples - // from the next frame to be properly converted. This function flushes out the - // samples that are stored in swresample's buffers. + // That's because the last few samples in a given frame require future + // samples from the next frame to be properly converted. This function + // flushes out the samples that are stored in swresample's buffers. 
auto& streamInfo = streamInfos_[activeStreamIndex_]; - if (!streamInfo.swrContext) { + if (!swrContext_) { return std::nullopt; } auto numRemainingSamples = // this is an upper bound - swr_get_out_samples(streamInfo.swrContext.get(), 0); + swr_get_out_samples(swrContext_.get(), 0); if (numRemainingSamples == 0) { return std::nullopt; @@ -1432,11 +1443,7 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { } auto actualNumRemainingSamples = swr_convert( - streamInfo.swrContext.get(), - outputBuffers.data(), - numRemainingSamples, - nullptr, - 0); + swrContext_.get(), outputBuffers.data(), numRemainingSamples, nullptr, 0); return lastSamples.narrow( /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); @@ -1446,25 +1453,10 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { // OUTPUT ALLOCATION AND SHAPE CONVERSION // -------------------------------------------------------------------------- -FrameBatchOutput::FrameBatchOutput( - int64_t numFrames, - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata) - : ptsSeconds(torch::empty({numFrames}, {torch::kFloat64})), - durationSeconds(torch::empty({numFrames}, {torch::kFloat64})) { - auto frameDims = getHeightAndWidthFromOptionsOrMetadata( - videoStreamOptions, streamMetadata); - int height = frameDims.height; - int width = frameDims.width; - data = allocateEmptyHWCTensor( - height, width, videoStreamOptions.device, numFrames); -} - -// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require so. -// The [N] leading batch-dimension is optional i.e. the input tensor can be 3D -// or 4D. -// Calling permute() is guaranteed to return a view as per the docs: -// https://pytorch.org/docs/stable/generated/torch.permute.html +// Returns a [N]CHW *view* of a [N]HWC input tensor, if the options require +// so. The [N] leading batch-dimension is optional i.e. the input tensor can +// be 3D or 4D. Calling permute() is guaranteed to return a view as per the +// docs: https://pytorch.org/docs/stable/generated/torch.permute.html torch::Tensor SingleStreamDecoder::maybePermuteHWC2CHW( torch::Tensor& hwcTensor) { if (streamInfos_[activeStreamIndex_].videoStreamOptions.dimensionOrder == @@ -1684,8 +1676,8 @@ void SingleStreamDecoder::validateFrameIndex( "and the number of frames must be known."); } - // Note that if we do not have the number of frames available in our metadata, - // then we assume that the frameIndex is valid. + // Note that if we do not have the number of frames available in our + // metadata, then we assume that the frameIndex is valid. 
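The permute() note above is the heart of maybePermuteHWC2CHW(): channels-last tensors become channels-first views with no data copy. Ignoring the dimensionOrder check, an equivalent standalone helper would be roughly:

torch::Tensor toCHWView(const torch::Tensor& hwcTensor) {
  if (hwcTensor.dim() == 4) {
    return hwcTensor.permute({0, 3, 1, 2}); // NHWC -> NCHW, still a view
  }
  return hwcTensor.permute({2, 0, 1}); // HWC -> CHW, still a view
}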
std::optional numFrames = getNumFrames(streamMetadata); if (numFrames.has_value()) { if (frameIndex >= numFrames.value()) { @@ -1736,28 +1728,4 @@ double SingleStreamDecoder::getPtsSecondsForFrame(int64_t frameIndex) { streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); } -// -------------------------------------------------------------------------- -// FrameDims APIs -// -------------------------------------------------------------------------- - -FrameDims getHeightAndWidthFromResizedAVFrame(const AVFrame& resizedAVFrame) { - return FrameDims(resizedAVFrame.height, resizedAVFrame.width); -} - -FrameDims getHeightAndWidthFromOptionsOrMetadata( - const VideoStreamOptions& videoStreamOptions, - const StreamMetadata& streamMetadata) { - return FrameDims( - videoStreamOptions.height.value_or(*streamMetadata.height), - videoStreamOptions.width.value_or(*streamMetadata.width)); -} - -FrameDims getHeightAndWidthFromOptionsOrAVFrame( - const VideoStreamOptions& videoStreamOptions, - const UniqueAVFrame& avFrame) { - return FrameDims( - videoStreamOptions.height.value_or(avFrame->height), - videoStreamOptions.width.value_or(avFrame->width)); -} - } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index aee48440d..1af1fabd7 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -17,6 +17,7 @@ #include "src/torchcodec/_core/FFMPEGCommon.h" #include "src/torchcodec/_core/Frame.h" #include "src/torchcodec/_core/StreamOptions.h" +#include "src/torchcodec/_core/Transform.h" namespace facebook::torchcodec { @@ -83,6 +84,7 @@ class SingleStreamDecoder { void addVideoStream( int streamIndex, + std::vector& transforms, const VideoStreamOptions& videoStreamOptions = VideoStreamOptions(), std::optional customFrameMappings = std::nullopt); void addAudioStream( @@ -226,17 +228,8 @@ class SingleStreamDecoder { std::vector keyFrames; std::vector allFrames; - // TODO since the decoder is single-stream, these should be decoder fields, - // not streamInfo fields. And they should be defined right next to - // `cursor_`, with joint documentation. - int64_t lastDecodedAvFramePts = 0; - int64_t lastDecodedAvFrameDuration = 0; VideoStreamOptions videoStreamOptions; AudioStreamOptions audioStreamOptions; - - // color-conversion fields. Only one of FilterGraphContext and - // UniqueSwsContext should be non-null. - UniqueSwrContext swrContext; }; // -------------------------------------------------------------------------- @@ -357,16 +350,49 @@ class SingleStreamDecoder { const int NO_ACTIVE_STREAM = -2; int activeStreamIndex_ = NO_ACTIVE_STREAM; - bool cursorWasJustSet_ = false; // The desired position of the cursor in the stream. We send frames >= this // pts to the user when they request a frame. int64_t cursor_ = INT64_MIN; + bool cursorWasJustSet_ = false; + int64_t lastDecodedAvFramePts_ = 0; + int64_t lastDecodedAvFrameDuration_ = 0; + + // Audio only. We cache it for performance. The video equivalents live in + // deviceInterface_. We store swrContext_ here because we only handle audio + // on the CPU. + UniqueSwrContext swrContext_; + // Stores various internal decoding stats. DecodeStats decodeStats_; + // Stores the AVIOContext for the input buffer. std::unique_ptr avioContextHolder_; + + // We will receive a vector of transforms upon adding a stream and store it + // here. 
However, we need to know if any of those operations change the + // dimensions of the output frame. If they do, we need to figure out what the + // final dimensions of the output frame are after ALL transformations. We + // figure this out as soon as we receive the transforms. If any of the + // transforms change the final output frame dimensions, we store that in + // resizedOutputDims_. If resizedOutputDims_ has no value, that means there + // are no transforms that change the output frame dimensions. + // + // The priority order for output frame dimensions is: + // + // 1. resizedOutputDims_; the resize requested by the user always takes + // priority. + // 2. The dimensions of the actual decoded AVFrame. This can change + // per-decoded frame, and is unknown in SingleStreamDecoder. Only the + // DeviceInterface learns it immediately after decoding a raw frame but + // before the color transformation. + // 3. metadataDims_; the dimensions we learned from the metadata. + std::vector<std::unique_ptr<Transform>> transforms_; + std::optional<FrameDims> resizedOutputDims_; + FrameDims metadataDims_; + // Whether or not we have already scanned all streams to update the metadata. bool scannedAllStreams_ = false; + // Tracks that we've already been initialized. bool initialized_ = false; }; diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 65f2782a8..9b02cceca 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -14,7 +14,6 @@ namespace facebook::torchcodec { enum ColorConversionLibrary { - // TODO: Add an AUTO option later. // Use the libavfilter library for color conversion. FILTERGRAPH, // Use the libswscale library for color conversion. @@ -29,14 +28,17 @@ struct VideoStreamOptions { // utilize all cores. If not set, it will be the default FFMPEG behavior for // the given codec. std::optional<int> ffmpegThreadCount; + // Currently the dimension order can be either NHWC or NCHW. // H=height, W=width, C=channel. std::string dimensionOrder = "NCHW"; - // The output height and width of the frame. If not specified, the output - // is the same as the original video. - std::optional<int> width; - std::optional<int> height; - std::optional<ColorConversionLibrary> colorConversionLibrary; + + // By default we have to use filtergraph, as it is more general. We can only + // use swscale when we have met strict requirements. See + // CpuDeviceInterface::initializeVideo() for the logic. + ColorConversionLibrary colorConversionLibrary = + ColorConversionLibrary::FILTERGRAPH; + // By default we use CPU for decoding for both C++ and Python users. torch::Device device = torch::kCPU; // Device variant (e.g., "default", "beta", etc.) diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp new file mode 100644 index 000000000..d0a5104f3 --- /dev/null +++ b/src/torchcodec/_core/Transform.cpp @@ -0,0 +1,60 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree.
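The priority order documented above reads as a simple fallback chain. A sketch with a hypothetical resolveOutputDims() helper (not part of this change; in practice the decoded frame's dimensions are only known to the DeviceInterface):

FrameDims resolveOutputDims(
    const std::optional<FrameDims>& resizedOutputDims,
    const std::optional<FrameDims>& decodedFrameDims,
    const FrameDims& metadataDims) {
  if (resizedOutputDims.has_value()) {
    return *resizedOutputDims; // 1. user-requested resize always wins
  }
  if (decodedFrameDims.has_value()) {
    return *decodedFrameDims; // 2. actual dimensions of the decoded AVFrame
  }
  return metadataDims; // 3. dimensions learned from the stream metadata
}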
+ +#include "src/torchcodec/_core/Transform.h" +#include +#include "src/torchcodec/_core/FFMPEGCommon.h" + +namespace facebook::torchcodec { + +namespace { + +std::string toFilterGraphInterpolation( + ResizeTransform::InterpolationMode mode) { + switch (mode) { + case ResizeTransform::InterpolationMode::BILINEAR: + return "bilinear"; + default: + TORCH_CHECK( + false, + "Unknown interpolation mode: " + + std::to_string(static_cast(mode))); + } +} + +int toSwsInterpolation(ResizeTransform::InterpolationMode mode) { + switch (mode) { + case ResizeTransform::InterpolationMode::BILINEAR: + return SWS_BILINEAR; + default: + TORCH_CHECK( + false, + "Unknown interpolation mode: " + + std::to_string(static_cast(mode))); + } +} + +} // namespace + +std::string ResizeTransform::getFilterGraphCpu() const { + return "scale=" + std::to_string(outputDims_.width) + ":" + + std::to_string(outputDims_.height) + + ":sws_flags=" + toFilterGraphInterpolation(interpolationMode_); +} + +std::optional ResizeTransform::getOutputFrameDims() const { + return outputDims_; +} + +bool ResizeTransform::isResize() const { + return true; +} + +int ResizeTransform::getSwsFlags() const { + return toSwsInterpolation(interpolationMode_); +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h new file mode 100644 index 000000000..6aea255ab --- /dev/null +++ b/src/torchcodec/_core/Transform.h @@ -0,0 +1,59 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include "src/torchcodec/_core/Frame.h" + +namespace facebook::torchcodec { + +class Transform { + public: + virtual std::string getFilterGraphCpu() const = 0; + virtual ~Transform() = default; + + // If the transformation does not change the output frame dimensions, then + // there is no need to override this member function. The default + // implementation returns an empty optional, indicating that the output frame + // has the same dimensions as the input frame. + // + // If the transformation does change the output frame dimensions, then it + // must override this member function and return the output frame dimensions. + virtual std::optional getOutputFrameDims() const { + return std::nullopt; + } + + // The ResizeTransform is special, because it is the only transform that + // swscale can handle. 
+ virtual bool isResize() const { + return false; + } +}; + +class ResizeTransform : public Transform { + public: + enum class InterpolationMode { BILINEAR }; + + ResizeTransform(const FrameDims& dims) + : outputDims_(dims), interpolationMode_(InterpolationMode::BILINEAR) {} + + ResizeTransform(const FrameDims& dims, InterpolationMode interpolationMode) + : outputDims_(dims), interpolationMode_(interpolationMode) {} + + std::string getFilterGraphCpu() const override; + std::optional<FrameDims> getOutputFrameDims() const override; + bool isResize() const override; + + int getSwsFlags() const; + + private: + FrameDims outputDims_; + InterpolationMode interpolationMode_; +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 699685641..c72f719a3 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -263,10 +263,25 @@ void _add_video_stream( custom_frame_mappings = std::nullopt, std::optional color_conversion_library = std::nullopt) { VideoStreamOptions videoStreamOptions; - videoStreamOptions.width = width; - videoStreamOptions.height = height; videoStreamOptions.ffmpegThreadCount = num_threads; + // TODO: Eliminate this temporary bridge code. This exists because we have + // not yet exposed the transforms API on the Python side. We also want + // to remove the `width` and `height` arguments from the Python API. + // + // TEMPORARY BRIDGE CODE START + TORCH_CHECK( + width.has_value() == height.has_value(), + "width and height must both be set or unset."); + std::vector<Transform*> transforms; + if (width.has_value()) { + transforms.push_back( + new ResizeTransform(FrameDims(height.value(), width.value()))); + width.reset(); + height.reset(); + } + // TEMPORARY BRIDGE CODE END + if (dimension_order.has_value()) { std::string stdDimensionOrder{dimension_order.value()}; TORCH_CHECK(stdDimensionOrder == "NHWC" || stdDimensionOrder == "NCHW"); @@ -300,7 +315,10 @@ void _add_video_stream( : std::nullopt; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addVideoStream( - stream_index.value_or(-1), videoStreamOptions, converted_mappings); + stream_index.value_or(-1), + transforms, + videoStreamOptions, + converted_mappings); } // Add a new video stream at `stream_index` using the provided options.
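Putting the pieces together, a C++ caller of the new API builds the transforms vector itself, much like the temporary bridge code above and the updated tests below. The Transform* element type is inferred from push_back(new ResizeTransform(...)); the decoder is assumed to take ownership of the raw pointers, matching its std::vector<std::unique_ptr<Transform>> transforms_ member:

std::unique_ptr<SingleStreamDecoder> decoder =
    std::make_unique<SingleStreamDecoder>("nasa_13013.mp4");
VideoStreamOptions videoStreamOptions;
std::vector<Transform*> transforms;
// Resize to height=120, width=100; FrameDims takes (height, width).
transforms.push_back(new ResizeTransform(FrameDims(120, 100)));
decoder->addVideoStream(-1, transforms, videoStreamOptions);
// With the default NCHW dimension order, frames now come out as {3, 120, 100}.
torch::Tensor frame = decoder->getNextFrame().data;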
diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 241a638b4..1481d3a2a 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -146,33 +146,14 @@ double computeAverageCosineSimilarity( return averageCosineSimilarity; } -// TEST(DecoderOptionsTest, ConvertsFromStringToOptions) { -// std::string optionsString = -// "ffmpeg_thread_count=3,dimension_order=NCHW,width=100,height=120"; -// SingleStreamDecoder::DecoderOptions options = -// SingleStreamDecoder::DecoderOptions(optionsString); -// EXPECT_EQ(options.ffmpegThreadCount, 3); -// } - -TEST(SingleStreamDecoderTest, RespectsWidthAndHeightFromOptions) { - std::string path = getResourcePath("nasa_13013.mp4"); - std::unique_ptr decoder = - std::make_unique(path); - VideoStreamOptions videoStreamOptions; - videoStreamOptions.width = 100; - videoStreamOptions.height = 120; - decoder->addVideoStream(-1, videoStreamOptions); - torch::Tensor tensor = decoder->getNextFrame().data; - EXPECT_EQ(tensor.sizes(), std::vector({3, 120, 100})); -} - TEST(SingleStreamDecoderTest, RespectsOutputTensorDimensionOrderFromOptions) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr decoder = std::make_unique(path); VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; - decoder->addVideoStream(-1, videoStreamOptions); + std::vector transforms; + decoder->addVideoStream(-1, transforms, videoStreamOptions); torch::Tensor tensor = decoder->getNextFrame().data; EXPECT_EQ(tensor.sizes(), std::vector({270, 480, 3})); } @@ -181,7 +162,8 @@ TEST_P(SingleStreamDecoderTest, ReturnsFirstTwoFramesOfVideo) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); auto output = ourDecoder->getNextFrame(); torch::Tensor tensor0FromOurDecoder = output.data; EXPECT_EQ(tensor0FromOurDecoder.sizes(), std::vector({3, 270, 480})); @@ -220,7 +202,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) { ourDecoder->scanFileAndUpdateMetadataAndIndex(); int bestVideoStreamIndex = *ourDecoder->getContainerMetadata().bestVideoStreamIndex; - ourDecoder->addVideoStream(bestVideoStreamIndex); + std::vector transforms; + ourDecoder->addVideoStream(bestVideoStreamIndex, transforms); // Frame with index 180 corresponds to timestamp 6.006. auto frameIndices = torch::tensor({0, 180}); auto output = ourDecoder->getFramesAtIndices(frameIndices); @@ -245,7 +228,9 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; VideoStreamOptions videoStreamOptions; videoStreamOptions.dimensionOrder = "NHWC"; - ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); + std::vector transforms; + ourDecoder->addVideoStream( + bestVideoStreamIndex, transforms, videoStreamOptions); // Frame with index 180 corresponds to timestamp 6.006. auto frameIndices = torch::tensor({0, 180}); auto output = ourDecoder->getFramesAtIndices(frameIndices); @@ -266,7 +251,8 @@ TEST_P(SingleStreamDecoderTest, SeeksCloseToEof) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); ourDecoder->setCursorPtsInSeconds(388388. / 30'000); auto output = ourDecoder->getNextFrame(); EXPECT_EQ(output.ptsSeconds, 388'388. 
/ 30'000); @@ -279,7 +265,8 @@ TEST_P(SingleStreamDecoderTest, GetsFramePlayedAtTimestamp) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); auto output = ourDecoder->getFramePlayedAt(6.006); EXPECT_EQ(output.ptsSeconds, 6.006); // The frame's duration is 0.033367 according to ffprobe, @@ -309,7 +296,8 @@ TEST_P(SingleStreamDecoderTest, SeeksToFrameWithSpecificPts) { std::string path = getResourcePath("nasa_13013.mp4"); std::unique_ptr ourDecoder = createDecoderFromPath(path, GetParam()); - ourDecoder->addVideoStream(-1); + std::vector transforms; + ourDecoder->addVideoStream(-1, transforms); ourDecoder->setCursorPtsInSeconds(6.0); auto output = ourDecoder->getNextFrame(); torch::Tensor tensor6FromOurDecoder = output.data; @@ -412,7 +400,9 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorFilterGraph) { VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = ColorConversionLibrary::FILTERGRAPH; - ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); + std::vector transforms; + ourDecoder->addVideoStream( + bestVideoStreamIndex, transforms, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); @@ -429,7 +419,9 @@ TEST_P(SingleStreamDecoderTest, PreAllocatedTensorSwscale) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; VideoStreamOptions videoStreamOptions; videoStreamOptions.colorConversionLibrary = ColorConversionLibrary::SWSCALE; - ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); + std::vector transforms; + ourDecoder->addVideoStream( + bestVideoStreamIndex, transforms, videoStreamOptions); auto output = ourDecoder->getFrameAtIndexInternal(0, preAllocatedOutputTensor); EXPECT_EQ(output.data.data_ptr(), preAllocatedOutputTensor.data_ptr()); diff --git a/test/test_ops.py b/test/test_ops.py index 715687afe..b50aec88b 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -626,6 +626,16 @@ def test_color_conversion_library_with_scaling( ) swscale_frame0, _, _ = get_next_frame(swscale_decoder) assert_frames_equal(filtergraph_frame0, swscale_frame0) + assert filtergraph_frame0.shape == (3, target_height, target_width) + + @needs_cuda + def test_scaling_on_cuda_fails(self): + decoder = create_from_file(str(NASA_VIDEO.path)) + with pytest.raises( + RuntimeError, + match="Transforms are only supported for CPU devices.", + ): + add_video_stream(decoder, device="cuda", width=100, height=100) @pytest.mark.parametrize("dimension_order", ("NHWC", "NCHW")) @pytest.mark.parametrize("color_conversion_library", ("filtergraph", "swscale"))
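For reference, the scaling test above exercises both color-conversion paths on the same resize. Per Transform.cpp, the CPU filtergraph string puts width first, then height, then the interpolation flag, while the swscale path maps the same interpolation mode to an SWS_* flag:

ResizeTransform resize(FrameDims(270, 480)); // height=270, width=480
std::string filter = resize.getFilterGraphCpu();
// filter == "scale=480:270:sws_flags=bilinear"
int flags = resize.getSwsFlags(); // SWS_BILINEAR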