diff --git a/src/torchcodec/decoders/_core/CMakeLists.txt b/src/torchcodec/decoders/_core/CMakeLists.txt index ed8e8ef36..a36b94089 100644 --- a/src/torchcodec/decoders/_core/CMakeLists.txt +++ b/src/torchcodec/decoders/_core/CMakeLists.txt @@ -32,6 +32,8 @@ find_package(Python3 ${PYTHON_VERSION} EXACT COMPONENTS Development) function(make_torchcodec_library library_name ffmpeg_target) set( sources + CUDACommon.h + CUDACommon.cpp FFMPEGCommon.h FFMPEGCommon.cpp VideoDecoder.h diff --git a/src/torchcodec/decoders/_core/CUDACommon.cpp b/src/torchcodec/decoders/_core/CUDACommon.cpp new file mode 100644 index 000000000..f0582293a --- /dev/null +++ b/src/torchcodec/decoders/_core/CUDACommon.cpp @@ -0,0 +1,146 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "src/torchcodec/decoders/_core/CUDACommon.h" + +// This source file is organized in the following way: +// +// | +// | +// | +// | +// | +// | +// | +// | +// +// If code needs to access definitions in the CUDA specific includes, then it is CUDA +// specific code, and belongs inside of the guard. If that behavior needs to be +// accessible to general code, then it should be added to the API for general code. + +#ifdef ENABLE_CUDA + +#include +#include +#include + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { +namespace { + +AVBufferRef* getCudaContext() { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + int err = 0; + AVBufferRef* hw_device_ctx; + err = av_hwdevice_ctx_create( + &hw_device_ctx, + type, + nullptr, + nullptr, + // Introduced in 58.26.100: + // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 +#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) + AV_CUDA_USE_CURRENT_CONTEXT +#else + 0 +#endif + ); + if (err < 0) { + TORCH_CHECK( + false, + "Failed to create specified HW device", + getFFMPEGErrorStringFromErrorCode(err)); + } + return hw_device_ctx; +} + +torch::Tensor allocateDeviceTensor( + at::IntArrayRef shape, + torch::Device device, + const torch::Dtype dtype = torch::kUInt8) { + return torch::empty( + shape, + torch::TensorOptions() + .dtype(dtype) + .layout(torch::kStrided) + .device(device)); +} + +} // namespace + +} // facebook::torchcodec + +#endif // ENABLE_CUDA + +// API implementations for general code to access CUDA specific behaviors. + +namespace facebook::torchcodec { + +AVBufferRef* initializeCudaContext(const torch::Device& device) { +#ifdef ENABLE_CUDA + TORCH_CHECK(device.type() == torch::DeviceType::CUDA, "Invalid device type."); + + // We create a small tensor using pytorch to initialize the cuda context. + torch::Tensor dummyTensorForCudaInitialization = torch::zeros( + {1}, + torch::TensorOptions().dtype(torch::kUInt8).device(device)); + return av_buffer_ref(getCudaContext()); +#else + throw std::runtime_error( + "CUDA support is not enabled in this build of TorchCodec."); +#endif +} + +torch::Tensor convertFrameToTensorUsingCuda( + const AVCodecContext* codecContext, + const VideoDecoder::VideoStreamDecoderOptions& options, + const AVFrame* src) { +#ifdef ENABLE_CUDA + NVTX_SCOPED_RANGE("convertFrameUsingCuda"); + TORCH_CHECK( + src->format == AV_PIX_FMT_CUDA, + "Expected format to be AV_PIX_FMT_CUDA, got " + + std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); + int width = options.width.value_or(codecContext->width); + int height = options.height.value_or(codecContext->height); + NppStatus status; + NppiSize oSizeROI; + oSizeROI.width = width; + oSizeROI.height = height; + Npp8u* input[2]; + input[0] = (Npp8u*)src->data[0]; + input[1] = (Npp8u*)src->data[1]; + torch::Tensor dst = allocateDeviceTensor({height, width, 3}, options.device); + auto start = std::chrono::high_resolution_clock::now(); + status = nppiNV12ToRGB_8u_P2C3R( + input, + src->linesize[0], + static_cast(dst.data_ptr()), + dst.stride(0), + oSizeROI); + TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); + auto end = std::chrono::high_resolution_clock::now(); + std::chrono::duration duration = end - start; + VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width + << " took: " << duration.count() << "us" << std::endl; + if (options.dimensionOrder == "NCHW") { + // The docs guaranty this to return a view: + // https://pytorch.org/docs/stable/generated/torch.permute.html + dst = dst.permute({2, 0, 1}); + } + return dst; +#else + throw std::runtime_error( + "CUDA support is not enabled in this build of TorchCodec."); +#endif +} + +} // facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/CUDACommon.h b/src/torchcodec/decoders/_core/CUDACommon.h new file mode 100644 index 000000000..eba388924 --- /dev/null +++ b/src/torchcodec/decoders/_core/CUDACommon.h @@ -0,0 +1,41 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include "src/torchcodec/decoders/_core/VideoDecoder.h" +#include +#include + +#ifdef ENABLE_NVTX +#include +#endif + +// The API for general code to access CUDA specific behaviors. CUDA specific behaviors +// require CUDA specific definitions which are only available on systems with CUDA +// installed. Hence, CUDA specific behaviors have to be guarded with ifdefs. +// +// In order to prevent ifdefs in general code, we create an API with a function for +// each behavior we need. General code can call the API, as the correct guards happen +// internally. General code still needs to check in general code if CUDA is being used, +// as the functions will throw an exception if CUDA is not available. + +namespace facebook::torchcodec { + +#ifdef ENABLE_NVTX +#define NVTX_SCOPED_RANGE(Annotation) nvtx3::scoped_range loop{Annotation} +#else +#define NVTX_SCOPED_RANGE(Annotation) do {} while (0) +#endif + +AVBufferRef* initializeCudaContext(const torch::Device& device); + +torch::Tensor convertFrameToTensorUsingCuda( + const AVCodecContext* codecContext, + const VideoDecoder::VideoStreamDecoderOptions& options, + const AVFrame* src); + +} // facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index a5c0fddfb..389b34d6f 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -5,6 +5,7 @@ // LICENSE file in the root directory of this source tree. #include "src/torchcodec/decoders/_core/VideoDecoder.h" +#include "src/torchcodec/decoders/_core/CUDACommon.h" #include #include #include @@ -13,25 +14,12 @@ #include #include -#ifdef ENABLE_CUDA -#include -#include -#include -#ifdef ENABLE_NVTX -#include -#endif -#endif - extern "C" { #include #include #include #include #include -#include -#ifdef ENABLE_CUDA -#include -#endif } namespace facebook::torchcodec { @@ -107,87 +95,6 @@ std::vector splitStringWithDelimiters( return result; } -#ifdef ENABLE_CUDA - -AVBufferRef* getCudaContext() { - enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); - TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); - int err = 0; - AVBufferRef* hw_device_ctx; - err = av_hwdevice_ctx_create( - &hw_device_ctx, - type, - nullptr, - nullptr, - // Introduced in 58.26.100: - // https://github.com/FFmpeg/FFmpeg/blob/4acb9b7d1046944345ae506165fb55883d04d8a6/doc/APIchanges#L265 -#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) - AV_CUDA_USE_CURRENT_CONTEXT -#else - 0 -#endif - ); - if (err < 0) { - TORCH_CHECK( - false, - "Failed to create specified HW device", - getFFMPEGErrorStringFromErrorCode(err)); - } - return hw_device_ctx; -} - -torch::Tensor allocateDeviceTensor( - at::IntArrayRef shape, - torch::Device device, - const torch::Dtype dtype = torch::kUInt8) { - return torch::empty( - shape, - torch::TensorOptions() - .dtype(dtype) - .layout(torch::kStrided) - .device(device)); -} - -torch::Tensor convertFrameToTensorUsingCUDA( - const AVCodecContext* codecContext, - const VideoDecoder::VideoStreamDecoderOptions& options, - const AVFrame* src) { - TORCH_CHECK( - src->format == AV_PIX_FMT_CUDA, - "Expected format to be AV_PIX_FMT_CUDA, got " + - std::string(av_get_pix_fmt_name((AVPixelFormat)src->format))); - int width = options.width.value_or(codecContext->width); - int height = options.height.value_or(codecContext->height); - NppStatus status; - NppiSize oSizeROI; - oSizeROI.width = width; - oSizeROI.height = height; - Npp8u* input[2]; - input[0] = (Npp8u*)src->data[0]; - input[1] = (Npp8u*)src->data[1]; - torch::Tensor dst = allocateDeviceTensor({height, width, 3}, options.device); - auto start = std::chrono::high_resolution_clock::now(); - status = nppiNV12ToRGB_8u_P2C3R( - input, - src->linesize[0], - static_cast(dst.data_ptr()), - dst.stride(0), - oSizeROI); - TORCH_CHECK(status == NPP_SUCCESS, "Failed to convert NV12 frame."); - auto end = std::chrono::high_resolution_clock::now(); - std::chrono::duration duration = end - start; - VLOG(9) << "NPP Conversion of frame height=" << height << " width=" << width - << " took: " << duration.count() << "us" << std::endl; - if (options.dimensionOrder == "NCHW") { - // The docs guaranty this to return a view: - // https://pytorch.org/docs/stable/generated/torch.permute.html - dst = dst.permute({2, 0, 1}); - } - return dst; -} - -#endif - } // namespace VideoDecoder::VideoStreamDecoderOptions::VideoStreamDecoderOptions( @@ -490,21 +397,12 @@ void VideoDecoder::addVideoStreamDecoder( TORCH_CHECK_EQ(retVal, AVSUCCESS); if (options.device.type() == torch::DeviceType::CUDA) { -#ifdef ENABLE_CUDA - // We create a small tensor using pytorch to initialize the cuda context. - torch::Tensor dummyTensorForCudaInitialization = torch::zeros( - {1}, - torch::TensorOptions().dtype(torch::kUInt8).device(options.device)); - codecContext->hw_device_ctx = av_buffer_ref(getCudaContext()); + codecContext->hw_device_ctx = initializeCudaContext(options.device); TORCH_INTERNAL_ASSERT( codecContext->hw_device_ctx, "Failed to create/reference the CUDA HW device context for index=" + std::to_string(options.device.index()) + "."); -#else - throw std::runtime_error( - "CUDA support is not enabled in this build of TorchCodec."); -#endif } retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); @@ -765,9 +663,8 @@ void VideoDecoder::maybeSeekToBeforeDesiredPts() { VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter( std::function filterFunction) { -#ifdef ENABLE_NVTX - nvtx3::scoped_range loop{"decodeOneFrame"}; -#endif + NVTX_SCOPED_RANGE("decodeOneFrame"); + if (activeStreamIndices_.size() == 0) { throw std::runtime_error("No active streams configured."); } @@ -856,9 +753,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getDecodedOutputWithFilter( continue; } { -#ifdef ENABLE_NVTX - nvtx3::scoped_range loop{"avcodec_send_packet"}; -#endif + NVTX_SCOPED_RANGE("avcodec_send_packet"); ffmpegStatus = avcodec_send_packet( streams_[packet->stream_index].codecContext.get(), packet.get()); } @@ -912,17 +807,8 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput( output.frame = convertFrameToTensorUsingFilterGraph(streamIndex, frame.get()); } else if (streamInfo.options.device.is_cuda()) { -#ifdef ENABLE_CUDA - { -#ifdef ENABLE_NVTX - nvtx3::scoped_range loop{"convertFrameUsingCuda"}; -#endif - output.frame = convertFrameToTensorUsingCUDA( + output.frame = convertFrameToTensorUsingCuda( streamInfo.codecContext.get(), streamInfo.options, frame.get()); - } -#else - throw std::runtime_error("CUDA is not enabled in this build."); -#endif // ENABLE_CUDA } } else if (output.streamType == AVMEDIA_TYPE_AUDIO) { // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement