Enable CUDA device for video encoder #1008
Draft
Dan-Flores wants to merge 20 commits into meta-pytorch:main from Dan-Flores:encode_gpu
+417 −9
Commits
ca1f538 changes (Dan-Flores)
50cdb21 lint (Dan-Flores)
54d1a1f BT.601, test_nvenc_against_ffmpeg_cli (Dan-Flores)
88e1299 remove cuda header from Encoder.cpp (Dan-Flores)
926b7ea separate encoding frame ctx init (Dan-Flores)
eee8889 Merge branch 'main' of https://github.com/meta-pytorch/torchcodec int… (Dan-Flores)
43c6221 lint (Dan-Flores)
4af1f53 parametrize other nvenc (Dan-Flores)
bdd133f disable av1_nvenc (Dan-Flores)
d5f2637 Merge branch 'main' into encode_gpu (Dan-Flores)
9c7bae7 reduce files affected, add GpuEncoder.cpp (Dan-Flores)
bf78468 actually add GpuEncoder.cpp (Dan-Flores)
7e5e6d4 move more encoding to gpuEncoder.cpp, reduce diff (Dan-Flores)
ffbdf4e remove device arg instead use frames device (Dan-Flores)
15e23f6 nits (Dan-Flores)
2342837 remove repeat torch_check (Dan-Flores)
f8b64ea feedback (Dan-Flores)
6d88905 additional small fixes (Dan-Flores)
fc05cb8 Merge branch 'main' of https://github.com/meta-pytorch/torchcodec int… (Dan-Flores)
6b8c1fe add in_fbcode (Dan-Flores)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,219 @@ | ||
| // Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| // All rights reserved. | ||
| // | ||
| // This source code is licensed under the BSD-style license found in the | ||
| // LICENSE file in the root directory of this source tree. | ||
| | ||
| #include "GpuEncoder.h" | ||
| | ||
| #include <ATen/cuda/CUDAEvent.h> | ||
| #include <c10/cuda/CUDAStream.h> | ||
| #include <cuda_runtime.h> | ||
| #include <torch/types.h> | ||
| | ||
| #include "CUDACommon.h" | ||
| #include "FFMPEGCommon.h" | ||
| | ||
| extern "C" { | ||
| #include <libavutil/hwcontext_cuda.h> | ||
| #include <libavutil/pixdesc.h> | ||
| } | ||
| | ||
| namespace facebook::torchcodec { | ||
| namespace { | ||
| | ||
| // Redefinition from CudaDeviceInterface.cpp anonymous namespace | ||
| int getFlagsAVHardwareDeviceContextCreate() { | ||
| #if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) | ||
| return AV_CUDA_USE_CURRENT_CONTEXT; | ||
| #else | ||
| return 0; | ||
| #endif | ||
| } | ||
| | ||
| // Redefinition from CudaDeviceInterface.cpp anonymous namespace | ||
| // TODO-VideoEncoder: unify device context creation, add caching to encoder | ||
| UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) { | ||
| enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); | ||
| TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); | ||
| | ||
| int deviceIndex = getDeviceIndex(device); | ||
| | ||
| c10::cuda::CUDAGuard deviceGuard(device); | ||
| // We set the device because we may be called from a different thread than | ||
| // the one that initialized the cuda context. | ||
| TORCH_CHECK( | ||
| cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device"); | ||
| | ||
| AVBufferRef* hardwareDeviceCtxRaw = nullptr; | ||
| std::string deviceOrdinal = std::to_string(deviceIndex); | ||
| | ||
| int err = av_hwdevice_ctx_create( | ||
| &hardwareDeviceCtxRaw, | ||
| type, | ||
| deviceOrdinal.c_str(), | ||
| nullptr, | ||
| getFlagsAVHardwareDeviceContextCreate()); | ||
| | ||
| if (err < 0) { | ||
| /* clang-format off */ | ||
| TORCH_CHECK( | ||
| false, | ||
| "Failed to create specified HW device. This typically happens when ", | ||
| "your installed FFmpeg doesn't support CUDA (see ", | ||
| "https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec", | ||
| "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err)); | ||
| /* clang-format on */ | ||
| } | ||
| | ||
| return UniqueAVBufferRef(hardwareDeviceCtxRaw); | ||
| } | ||
| | ||
| // RGB to NV12 color conversion matrices (inverse of YUV to RGB) | ||
| // Note: NPP's ColorTwist function apparently expects "limited range" | ||
| // coefficient format even when producing full range output. The matrix below | ||
| // uses the limited range coefficient format (Y with +16 offset) for NPP | ||
| // compatibility. | ||
| | ||
| // BT.601 limited range (matches FFmpeg default behavior) | ||
| const Npp32f defaultLimitedRangeRgbToNv12[3][4] = { | ||
| // Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B) | ||
| {0.257f, 0.504f, 0.098f, 16.0f}, | ||
| // U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients) | ||
| {-0.148f, -0.291f, 0.439f, 128.0f}, | ||
| // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients) | ||
| {0.439f, -0.368f, -0.071f, 128.0f}}; | ||
| } // namespace | ||
| | ||
| GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) { | ||
| TORCH_CHECK( | ||
| device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); | ||
| | ||
| initializeCudaContextWithPytorch(device_); | ||
| initializeHardwareContext(); | ||
| } | ||
| | ||
| GpuEncoder::~GpuEncoder() {} | ||
| | ||
| void GpuEncoder::initializeHardwareContext() { | ||
| hardwareDeviceCtx_ = createHardwareDeviceContext(device_); | ||
| nppCtx_ = getNppStreamContext(device_); | ||
| } | ||
| | ||
| void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) { | ||
| TORCH_CHECK( | ||
| hardwareDeviceCtx_, "Hardware device context has not been initialized"); | ||
| TORCH_CHECK(codecContext != nullptr, "codecContext is null"); | ||
| codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); | ||
| } | ||
| | ||
| // Allocates and initializes AVHWFramesContext, and sets pixel format fields | ||
| // to enable encoding with CUDA device. The hw_frames_ctx field is needed by | ||
| // FFmpeg to allocate frames on GPU's memory. | ||
| void GpuEncoder::setupHardwareFrameContext(AVCodecContext* codecContext) { | ||
| TORCH_CHECK( | ||
| hardwareDeviceCtx_, "Hardware device context has not been initialized"); | ||
| TORCH_CHECK(codecContext != nullptr, "codecContext is null"); | ||
| | ||
| AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); | ||
| TORCH_CHECK( | ||
| hwFramesCtxRef != nullptr, | ||
| "Failed to allocate hardware frames context for codec"); | ||
| | ||
| // Always set pixel formats to options that support CUDA encoding. | ||
| // TODO-VideoEncoder: Enable user set pixel formats to be set and properly | ||
| // handled with NPP functions below | ||
| codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; | ||
| codecContext->pix_fmt = AV_PIX_FMT_CUDA; | ||
| | ||
| AVHWFramesContext* hwFramesCtx = | ||
| reinterpret_cast<AVHWFramesContext*>(hwFramesCtxRef->data); | ||
| hwFramesCtx->format = codecContext->pix_fmt; | ||
| hwFramesCtx->sw_format = codecContext->sw_pix_fmt; | ||
| hwFramesCtx->width = codecContext->width; | ||
| hwFramesCtx->height = codecContext->height; | ||
| | ||
| int ret = av_hwframe_ctx_init(hwFramesCtxRef); | ||
| if (ret < 0) { | ||
| av_buffer_unref(&hwFramesCtxRef); | ||
| TORCH_CHECK( | ||
| false, | ||
| "Failed to initialize CUDA frames context for codec: ", | ||
| getFFMPEGErrorStringFromErrorCode(ret)); | ||
| } | ||
| | ||
| codecContext->hw_frames_ctx = hwFramesCtxRef; | ||
| } | ||
| | ||
| UniqueAVFrame GpuEncoder::convertTensorToAVFrame( | ||
| const torch::Tensor& tensor, | ||
| [[maybe_unused]] AVPixelFormat targetFormat, | ||
| int frameIndex, | ||
| AVCodecContext* codecContext) { | ||
| TORCH_CHECK( | ||
| tensor.dim() == 3 && tensor.size(0) == 3, | ||
| "Expected 3D RGB tensor (CHW format), got shape: ", | ||
| tensor.sizes()); | ||
| | ||
| UniqueAVFrame avFrame(av_frame_alloc()); | ||
| TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); | ||
| int height = static_cast<int>(tensor.size(1)); | ||
| int width = static_cast<int>(tensor.size(2)); | ||
| | ||
| // TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method | ||
| avFrame->format = AV_PIX_FMT_CUDA; | ||
| avFrame->height = height; | ||
| avFrame->width = width; | ||
| avFrame->pts = frameIndex; | ||
| | ||
| // FFmpeg's av_hwframe_get_buffer is used to allocate memory on CUDA device. | ||
| // TODO-VideoEncoder: Consider using pytorch to allocate CUDA memory for | ||
| // efficiency | ||
| int ret = | ||
| av_hwframe_get_buffer(codecContext->hw_frames_ctx, avFrame.get(), 0); | ||
| TORCH_CHECK( | ||
| ret >= 0, | ||
| "Failed to allocate hardware frame: ", | ||
| getFFMPEGErrorStringFromErrorCode(ret)); | ||
| | ||
| TORCH_CHECK( | ||
| avFrame != nullptr && avFrame->data[0] != nullptr, | ||
| "avFrame must be pre-allocated with CUDA memory"); | ||
| | ||
| torch::Tensor hwcFrame = tensor.permute({1, 2, 0}).contiguous(); | ||
| | ||
| at::cuda::CUDAStream currentStream = | ||
| at::cuda::getCurrentCUDAStream(device_.index()); | ||
| | ||
| nppCtx_->hStream = currentStream.stream(); | ||
| cudaError_t cudaErr = | ||
| cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags); | ||
| TORCH_CHECK( | ||
| cudaErr == cudaSuccess, | ||
| "cudaStreamGetFlags failed: ", | ||
| cudaGetErrorString(cudaErr)); | ||
| | ||
| NppiSize oSizeROI = {width, height}; | ||
| NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( | ||
| static_cast<const Npp8u*>(hwcFrame.data_ptr()), | ||
| hwcFrame.stride(0) * hwcFrame.element_size(), | ||
| avFrame->data, | ||
| avFrame->linesize, | ||
| oSizeROI, | ||
| defaultLimitedRangeRgbToNv12, | ||
| *nppCtx_); | ||
| | ||
| TORCH_CHECK( | ||
| status == NPP_SUCCESS, | ||
| "Failed to convert RGB to NV12: NPP error code ", | ||
| status); | ||
| | ||
| // TODO-VideoEncoder: Enable configuration of color properties, similar to | ||
| // FFmpeg. Below are the default color properties used by FFmpeg. | ||
| avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 | ||
| avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range | ||
| | ||
| return avFrame; | ||
| } | ||
| | ||
| } // namespace facebook::torchcodec | ||
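
The BT.601 limited-range matrix in the diff above can be sanity-checked on the CPU without CUDA or NPP. The following is a minimal sketch, not code from the PR: `applyRow`, `rgbToYuvBt601Limited`, and `checkLimitedRangeEndpoints` are hypothetical names, and round-to-nearest with clamping is an assumption, so NPP's output may differ by one code value on some inputs. Only the coefficients are copied from `defaultLimitedRangeRgbToNv12`.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>

// Coefficients copied from defaultLimitedRangeRgbToNv12 in the diff
// (BT.601, limited range). Each row is {R, G, B, offset}.
constexpr float kBt601Limited[3][4] = {
    {0.257f, 0.504f, 0.098f, 16.0f},
    {-0.148f, -0.291f, 0.439f, 128.0f},
    {0.439f, -0.368f, -0.071f, 128.0f}};

// Hypothetical helper (not in the PR): applies one matrix row to an RGB
// pixel, then rounds to nearest and clamps to [0, 255]. The rounding mode
// is an assumption, not NPP's documented behavior.
inline std::uint8_t applyRow(
    const float (&row)[4],
    std::uint8_t r,
    std::uint8_t g,
    std::uint8_t b) {
  float value = row[0] * r + row[1] * g + row[2] * b + row[3];
  long rounded = std::lround(value);
  if (rounded < 0) {
    rounded = 0;
  }
  if (rounded > 255) {
    rounded = 255;
  }
  return static_cast<std::uint8_t>(rounded);
}

// Hypothetical helper: converts one RGB pixel to {Y, U, V}.
inline std::array<std::uint8_t, 3> rgbToYuvBt601Limited(
    std::uint8_t r,
    std::uint8_t g,
    std::uint8_t b) {
  return {
      applyRow(kBt601Limited[0], r, g, b),
      applyRow(kBt601Limited[1], r, g, b),
      applyRow(kBt601Limited[2], r, g, b)};
}

// Limited range means black maps to Y=16 and white to Y=235, with neutral
// chroma (U=V=128) for any gray input.
inline void checkLimitedRangeEndpoints() {
  const auto black = rgbToYuvBt601Limited(0, 0, 0);
  const auto white = rgbToYuvBt601Limited(255, 255, 255);
  assert(black[0] == 16 && black[1] == 128 && black[2] == 128);
  assert(white[0] == 235 && white[1] == 128 && white[2] == 128);
}
```

Calling `checkLimitedRangeEndpoints()` from a test confirms that the matrix keeps luma inside the 16 to 235 range that `AVCOL_RANGE_MPEG` signals to the encoder. A reference like this could also be compared pixel by pixel against GPU output, allowing a tolerance of one code value for rounding differences.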
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,51 @@ | ||
| // Copyright (c) Meta Platforms, Inc. and affiliates. | ||
| // All rights reserved. | ||
| // | ||
| // This source code is licensed under the BSD-style license found in the | ||
| // LICENSE file in the root directory of this source tree. | ||
| | ||
| #pragma once | ||
| | ||
| #include <torch/types.h> | ||
| #include <memory> | ||
| #include <optional> | ||
| | ||
| #include "CUDACommon.h" | ||
| #include "FFMPEGCommon.h" | ||
| #include "StreamOptions.h" | ||
| | ||
| extern "C" { | ||
| #include <libavcodec/avcodec.h> | ||
| #include <libavutil/buffer.h> | ||
| #include <libavutil/hwcontext.h> | ||
| } | ||
| | ||
| namespace facebook::torchcodec { | ||
| | ||
| class GpuEncoder { | ||
| public: | ||
| explicit GpuEncoder(const torch::Device& device); | ||
| ~GpuEncoder(); | ||
| | ||
| void registerHardwareDeviceWithCodec(AVCodecContext* codecContext); | ||
| void setupHardwareFrameContext(AVCodecContext* codecContext); | ||
| | ||
| UniqueAVFrame convertTensorToAVFrame( | ||
| const torch::Tensor& tensor, | ||
| AVPixelFormat targetFormat, | ||
| int frameIndex, | ||
| AVCodecContext* codecContext); | ||
| | ||
| const torch::Device& device() const { | ||
| return device_; | ||
| } | ||
| | ||
| private: | ||
| torch::Device device_; | ||
| UniqueAVBufferRef hardwareDeviceCtx_; | ||
| UniqueNppContext nppCtx_; | ||
| | ||
| void initializeHardwareContext(); | ||
| }; | ||
| | ||
| } // namespace facebook::torchcodec | ||
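
`convertTensorToAVFrame`, declared above, takes a CHW RGB tensor, permutes it to contiguous HWC, and writes NV12 into the frame's two planes. The CPU sketch below illustrates that memory layout only. It is not part of the PR: `chwToHwc`, `hwcRowPitchBytes`, `Nv12Layout`, and `nv12Layout` are hypothetical names, and the NV12 sizes assume tightly packed rows, whereas frames allocated by FFmpeg may pad each row (see `linesize`).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper (not in the PR): repacks planar CHW bytes into
// interleaved HWC, the CPU analogue of tensor.permute({1, 2, 0}).contiguous().
inline std::vector<std::uint8_t> chwToHwc(
    const std::vector<std::uint8_t>& chw,
    int channels,
    int height,
    int width) {
  assert(
      chw.size() ==
      static_cast<std::size_t>(channels) * height * width);
  std::vector<std::uint8_t> hwc(chw.size());
  for (int c = 0; c < channels; ++c) {
    for (int y = 0; y < height; ++y) {
      for (int x = 0; x < width; ++x) {
        const std::size_t src =
            (static_cast<std::size_t>(c) * height + y) * width + x;
        const std::size_t dst =
            (static_cast<std::size_t>(y) * width + x) * channels + c;
        hwc[dst] = chw[src];
      }
    }
  }
  return hwc;
}

// Row pitch in bytes of a contiguous HWC uint8 image. For such a tensor this
// is the value of hwcFrame.stride(0) * hwcFrame.element_size().
inline int hwcRowPitchBytes(int width, int channels) {
  return width * channels;
}

// Sizes of the two NV12 planes, assuming no row padding.
struct Nv12Layout {
  std::size_t yBytes; // full-resolution luma plane
  std::size_t uvBytes; // interleaved U/V plane at half resolution
};

inline Nv12Layout nv12Layout(int width, int height) {
  // 4:2:0 subsampling needs even dimensions.
  assert(width % 2 == 0 && height % 2 == 0);
  const std::size_t w = static_cast<std::size_t>(width);
  const std::size_t h = static_cast<std::size_t>(height);
  // Each UV row holds width/2 (U, V) pairs, which is width bytes.
  return {w * h, w * (h / 2)};
}
```

For a 1920x1080 frame, `nv12Layout(1920, 1080)` gives a Y plane of 2,073,600 bytes and a UV plane of 1,036,800 bytes, which is why NV12 takes half the memory of packed 8-bit RGB at the same resolution.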
Agreed on the TODOs. Seeing this makes me think we will probably want to follow up quickly and implement `GpuEncoder.cpp` as part of `CudaDeviceInterface.cpp`. But this is better done as a follow-up, considering this encoder itself is already a big change.