Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/_core/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function(make_torchcodec_libraries
)

if(ENABLE_CUDA)
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp)
list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp GpuEncoder.cpp)
endif()

set(core_library_dependencies
Expand Down
17 changes: 16 additions & 1 deletion src/torchcodec/_core/Encoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,9 @@ VideoEncoder::VideoEncoder(

void VideoEncoder::initializeEncoder(
const VideoStreamOptions& videoStreamOptions) {
if (frames_.device().is_cuda()) {
gpuEncoder_ = std::make_unique<GpuEncoder>(frames_.device());
}
const AVCodec* avCodec = nullptr;
// If codec arg is provided, find codec using logic similar to FFmpeg:
// https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835
Expand Down Expand Up @@ -820,6 +823,12 @@ void VideoEncoder::initializeEncoder(
videoStreamOptions.preset.value().c_str(),
0);
}

if (gpuEncoder_) {
gpuEncoder_->registerHardwareDeviceWithCodec(avCodecContext_.get());
gpuEncoder_->setupHardwareFrameContext(avCodecContext_.get());
}

int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions);
av_dict_free(&avCodecOptions);

Expand Down Expand Up @@ -860,7 +869,13 @@ void VideoEncoder::encode() {
int numFrames = static_cast<int>(frames_.sizes()[0]);
for (int i = 0; i < numFrames; ++i) {
torch::Tensor currFrame = frames_[i];
UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i);
UniqueAVFrame avFrame;
if (gpuEncoder_) {
avFrame = gpuEncoder_->convertTensorToAVFrame(
currFrame, outPixelFormat_, i, avCodecContext_.get());
} else {
avFrame = convertTensorToAVFrame(currFrame, i);
}
encodeFrame(autoAVPacket, avFrame);
}

Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/_core/Encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
#include <map>
#include <string>
#include "AVIOContextHolder.h"
#include "DeviceInterface.h"
#include "FFMPEGCommon.h"
#include "GpuEncoder.h"
#include "StreamOptions.h"

extern "C" {
Expand Down Expand Up @@ -183,6 +185,7 @@ class VideoEncoder {
AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE;

std::unique_ptr<AVIOContextHolder> avioContextHolder_;
std::unique_ptr<GpuEncoder> gpuEncoder_;

bool encodeWasCalled_ = false;
AVDictionary* avFormatOptions_ = nullptr;
Expand Down
219 changes: 219 additions & 0 deletions src/torchcodec/_core/GpuEncoder.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,219 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "GpuEncoder.h"

#include <ATen/cuda/CUDAEvent.h>
#include <c10/cuda/CUDAStream.h>
#include <cuda_runtime.h>
#include <torch/types.h>

#include "CUDACommon.h"
#include "FFMPEGCommon.h"

extern "C" {
#include <libavutil/hwcontext_cuda.h>
#include <libavutil/pixdesc.h>
}

namespace facebook::torchcodec {
namespace {

// Returns the flags argument passed to av_hwdevice_ctx_create() below.
// Redefinition of the helper with the same name that lives in the anonymous
// namespace of CudaDeviceInterface.cpp.
int getFlagsAVHardwareDeviceContextCreate() {
  // Default for libavutil versions older than 58.26.100: no flags.
  int flags = 0;
#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100)
  // NOTE(review): presumably tells FFmpeg to reuse the CUDA context that is
  // already current instead of creating its own -- confirm against
  // libavutil/hwcontext_cuda.h.
  flags = AV_CUDA_USE_CURRENT_CONTEXT;
#endif
  return flags;
}

// Redefinition from CudaDeviceInterface.cpp anonymous namespace
// TODO-VideoEncoder: unify device context creation, add caching to encoder
Comment on lines +34 to +35
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed on the TODOs. Seeing this makes me think we will probably want to follow up quickly and implement GpuEncoder.cpp as part of CudaDeviceInterface.cpp. But this is better done as a follow-up, considering this encoder itself is already a big change.

// Creates an FFmpeg hardware device context of type "cuda" for `device` and
// returns it wrapped in a UniqueAVBufferRef.
//
// Fails (via TORCH_CHECK) when FFmpeg does not know the "cuda" device type,
// when the CUDA device cannot be selected, or when FFmpeg cannot create the
// device context.
UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) {
  const enum AVHWDeviceType cudaType = av_hwdevice_find_type_by_name("cuda");
  TORCH_CHECK(
      cudaType != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device");

  const int deviceIndex = getDeviceIndex(device);

  c10::cuda::CUDAGuard deviceGuard(device);
  // The calling thread may differ from the one that initialized the CUDA
  // context, so select the device explicitly here.
  TORCH_CHECK(
      cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device");

  // FFmpeg takes the device ordinal as a string.
  const std::string deviceOrdinal = std::to_string(deviceIndex);
  AVBufferRef* rawDeviceCtx = nullptr;
  const int status = av_hwdevice_ctx_create(
      &rawDeviceCtx,
      cudaType,
      deviceOrdinal.c_str(),
      nullptr,
      getFlagsAVHardwareDeviceContextCreate());

  /* clang-format off */
  TORCH_CHECK(
      status >= 0,
      "Failed to create specified HW device. This typically happens when ",
      "your installed FFmpeg doesn't support CUDA (see ",
      "https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec",
      "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(status));
  /* clang-format on */

  return UniqueAVBufferRef(rawDeviceCtx);
}

// RGB to NV12 color conversion matrix (inverse of YUV to RGB), passed as the
// color-twist argument to nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx.
// Note: NPP's ColorTwist function apparently expects "limited range"
// coefficient format even when producing full range output. The matrix below
// uses the limited range coefficient format (Y with +16 offset) for NPP
// compatibility.

// BT.601 limited range (matches FFmpeg default behavior).
// Each row is {R coefficient, G coefficient, B coefficient, constant offset}
// for one output component, in the order Y, U, V.
const Npp32f defaultLimitedRangeRgbToNv12[3][4] = {
// Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B)
{0.257f, 0.504f, 0.098f, 16.0f},
// U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients)
{-0.148f, -0.291f, 0.439f, 128.0f},
// V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients)
{0.439f, -0.368f, -0.071f, 128.0f}};
} // namespace

// Builds a GPU encoder helper bound to `device`, which must be a CUDA device.
// Initializes the CUDA context through PyTorch first, then creates the
// FFmpeg/NPP contexts used by the other methods.
GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) {
  TORCH_CHECK(device_.is_cuda(), "Unsupported device: ", device_.str());

  initializeCudaContextWithPytorch(device_);
  initializeHardwareContext();
}

// Nothing to release by hand: the owning member wrappers clean up after
// themselves.
GpuEncoder::~GpuEncoder() = default;

// Populates the per-encoder contexts for device_: the FFmpeg CUDA hardware
// device context (hardwareDeviceCtx_) and the NPP stream context (nppCtx_)
// used for the RGB -> NV12 conversion. Called once from the constructor.
void GpuEncoder::initializeHardwareContext() {
hardwareDeviceCtx_ = createHardwareDeviceContext(device_);
nppCtx_ = getNppStreamContext(device_);
}

// Gives `codecContext` its own reference to this encoder's CUDA hardware
// device context by setting `codecContext->hw_device_ctx`.
//
// Fails (via TORCH_CHECK) if the device context was never created, if
// `codecContext` is null, or if FFmpeg cannot allocate the new reference.
void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) {
  TORCH_CHECK(
      hardwareDeviceCtx_, "Hardware device context has not been initialized");
  TORCH_CHECK(codecContext != nullptr, "codecContext is null");

  // av_buffer_ref() returns nullptr when it cannot allocate the new reference.
  // Storing that unchecked would silently leave the codec without a hardware
  // device, so fail loudly here instead.
  AVBufferRef* deviceCtxRef = av_buffer_ref(hardwareDeviceCtx_.get());
  TORCH_CHECK(
      deviceCtxRef != nullptr,
      "Failed to create a reference to the hardware device context");

  codecContext->hw_device_ctx = deviceCtxRef;
}

// Creates an AVHWFramesContext on top of the CUDA device context, configures
// it from `codecContext` (width/height) and attaches it as
// `codecContext->hw_frames_ctx`. Also forces the codec's pixel formats to the
// CUDA/NV12 pair. FFmpeg needs hw_frames_ctx in order to allocate frames in
// GPU memory.
void GpuEncoder::setupHardwareFrameContext(AVCodecContext* codecContext) {
  TORCH_CHECK(
      hardwareDeviceCtx_, "Hardware device context has not been initialized");
  TORCH_CHECK(codecContext != nullptr, "codecContext is null");

  AVBufferRef* framesRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get());
  TORCH_CHECK(
      framesRef != nullptr,
      "Failed to allocate hardware frames context for codec");

  // Pixel formats are always overridden with ones that support CUDA encoding.
  // TODO-VideoEncoder: Enable user set pixel formats to be set and properly
  // handled with NPP functions below
  codecContext->pix_fmt = AV_PIX_FMT_CUDA;
  codecContext->sw_pix_fmt = AV_PIX_FMT_NV12;

  auto* framesCtx = reinterpret_cast<AVHWFramesContext*>(framesRef->data);
  framesCtx->width = codecContext->width;
  framesCtx->height = codecContext->height;
  framesCtx->format = codecContext->pix_fmt;
  framesCtx->sw_format = codecContext->sw_pix_fmt;

  const int initStatus = av_hwframe_ctx_init(framesRef);
  if (initStatus >= 0) {
    // Success: the codec context now holds the frames context reference.
    codecContext->hw_frames_ctx = framesRef;
    return;
  }

  // Initialization failed: drop our reference before reporting the error.
  av_buffer_unref(&framesRef);
  TORCH_CHECK(
      false,
      "Failed to initialize CUDA frames context for codec: ",
      getFFMPEGErrorStringFromErrorCode(initStatus));
}

// Converts a single RGB frame into a CUDA-backed NV12 AVFrame ready to be sent
// to the encoder.
//
// Args:
//   tensor: uint8 CUDA tensor of shape (3, H, W) holding one RGB frame.
//   targetFormat: currently unused; the output is always NV12 in an
//     AV_PIX_FMT_CUDA frame.
//   frameIndex: used as the frame's pts.
//   codecContext: codec context whose hw_frames_ctx was configured by
//     setupHardwareFrameContext(); the frame buffer is allocated from it.
//
// Returns the newly allocated frame. Fails (via TORCH_CHECK) on invalid input,
// allocation failure, or NPP conversion failure.
UniqueAVFrame GpuEncoder::convertTensorToAVFrame(
    const torch::Tensor& tensor,
    [[maybe_unused]] AVPixelFormat targetFormat,
    int frameIndex,
    AVCodecContext* codecContext) {
  TORCH_CHECK(
      tensor.dim() == 3 && tensor.size(0) == 3,
      "Expected 3D RGB tensor (CHW format), got shape: ",
      tensor.sizes());
  // The data pointer is handed to NPP as Npp8u device memory below, so any
  // other dtype or a non-CUDA tensor would be silently misread.
  TORCH_CHECK(
      tensor.scalar_type() == torch::kUInt8,
      "Expected uint8 tensor, got dtype: ",
      tensor.scalar_type());
  TORCH_CHECK(
      tensor.is_cuda(),
      "Expected CUDA tensor, got device: ",
      tensor.device().str());
  TORCH_CHECK(codecContext != nullptr, "codecContext is null");
  TORCH_CHECK(
      codecContext->hw_frames_ctx != nullptr,
      "Hardware frames context has not been set up on the codec context");

  int height = static_cast<int>(tensor.size(1));
  int width = static_cast<int>(tensor.size(2));

  // The destination buffer is allocated with the frames context's dimensions
  // while the NPP ROI uses the tensor's dimensions. Reject tensors that would
  // make NPP write past the allocated planes.
  const auto* hwFramesCtx = reinterpret_cast<const AVHWFramesContext*>(
      codecContext->hw_frames_ctx->data);
  TORCH_CHECK(
      width <= hwFramesCtx->width && height <= hwFramesCtx->height,
      "Frame of size ",
      width,
      "x",
      height,
      " does not fit in hardware frames of size ",
      hwFramesCtx->width,
      "x",
      hwFramesCtx->height);

  UniqueAVFrame avFrame(av_frame_alloc());
  TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame");

  // TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method
  avFrame->format = AV_PIX_FMT_CUDA;
  avFrame->height = height;
  avFrame->width = width;
  avFrame->pts = frameIndex;

  // FFmpeg's av_hwframe_get_buffer is used to allocate memory on CUDA device.
  // TODO-VideoEncoder: Consider using pytorch to allocate CUDA memory for
  // efficiency
  int ret =
      av_hwframe_get_buffer(codecContext->hw_frames_ctx, avFrame.get(), 0);
  TORCH_CHECK(
      ret >= 0,
      "Failed to allocate hardware frame: ",
      getFFMPEGErrorStringFromErrorCode(ret));

  TORCH_CHECK(
      avFrame->data[0] != nullptr,
      "avFrame must be pre-allocated with CUDA memory");

  // NPP consumes packed (interleaved) RGB, so go from CHW to contiguous HWC.
  torch::Tensor hwcFrame = tensor.permute({1, 2, 0}).contiguous();

  // Run the conversion on PyTorch's current stream for this device.
  // NOTE(review): no explicit synchronization is performed here before the
  // frame is handed to the encoder -- confirm the encoder's reads are ordered
  // after this conversion.
  at::cuda::CUDAStream currentStream =
      at::cuda::getCurrentCUDAStream(device_.index());

  nppCtx_->hStream = currentStream.stream();
  cudaError_t cudaErr =
      cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags);
  TORCH_CHECK(
      cudaErr == cudaSuccess,
      "cudaStreamGetFlags failed: ",
      cudaGetErrorString(cudaErr));

  NppiSize oSizeROI = {width, height};
  // Source row pitch in bytes.
  const int srcStep =
      static_cast<int>(hwcFrame.stride(0) * hwcFrame.element_size());
  NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx(
      static_cast<const Npp8u*>(hwcFrame.data_ptr()),
      srcStep,
      avFrame->data,
      avFrame->linesize,
      oSizeROI,
      defaultLimitedRangeRgbToNv12,
      *nppCtx_);

  TORCH_CHECK(
      status == NPP_SUCCESS,
      "Failed to convert RGB to NV12: NPP error code ",
      status);

  // TODO-VideoEncoder: Enable configuration of color properties, similar to
  // FFmpeg. Below are the default color properties used by FFmpeg.
  avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601
  avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range

  return avFrame;
}

} // namespace facebook::torchcodec
51 changes: 51 additions & 0 deletions src/torchcodec/_core/GpuEncoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Copyright (c) Meta Platforms, Inc. and affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <torch/types.h>
#include <memory>
#include <optional>

#include "CUDACommon.h"
#include "FFMPEGCommon.h"
#include "StreamOptions.h"

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavutil/buffer.h>
#include <libavutil/hwcontext.h>
}

namespace facebook::torchcodec {

// CUDA-side helper for video encoding. Holds the FFmpeg CUDA hardware device
// context and the NPP stream context for one device, wires them into an
// AVCodecContext, and converts RGB tensors into CUDA-backed NV12 AVFrames.
class GpuEncoder {
public:
// `device` must be a CUDA device; construction fails otherwise. Creates the
// hardware device context and the NPP stream context.
explicit GpuEncoder(const torch::Device& device);
~GpuEncoder();

// Sets codecContext->hw_device_ctx to a new reference to this encoder's
// hardware device context.
void registerHardwareDeviceWithCodec(AVCodecContext* codecContext);
// Allocates and initializes a hardware frames context sized from
// codecContext's width/height, sets codecContext's pix_fmt/sw_pix_fmt to
// AV_PIX_FMT_CUDA/AV_PIX_FMT_NV12, and stores the frames context in
// codecContext->hw_frames_ctx.
void setupHardwareFrameContext(AVCodecContext* codecContext);

// Converts one (3, H, W) RGB tensor into an NV12 frame allocated from
// codecContext->hw_frames_ctx. `frameIndex` becomes the frame's pts.
// `targetFormat` is currently unused by the implementation.
UniqueAVFrame convertTensorToAVFrame(
const torch::Tensor& tensor,
AVPixelFormat targetFormat,
int frameIndex,
AVCodecContext* codecContext);

// The device this encoder was constructed with.
const torch::Device& device() const {
return device_;
}

private:
// Device passed to the constructor.
torch::Device device_;
// FFmpeg CUDA hardware device context; set by initializeHardwareContext().
UniqueAVBufferRef hardwareDeviceCtx_;
// NPP stream context; its stream fields are refreshed on every conversion.
UniqueNppContext nppCtx_;

// Creates hardwareDeviceCtx_ and nppCtx_ for device_.
void initializeHardwareContext();
};

} // namespace facebook::torchcodec
2 changes: 2 additions & 0 deletions src/torchcodec/_core/StreamOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ struct VideoStreamOptions {
ColorConversionLibrary::FILTERGRAPH;

// By default we use CPU for decoding for both C++ and python users.
// Note: For video encoding, device is determined by the location of the input
// frame tensor.
torch::Device device = torch::kCPU;
// Device variant (e.g., "ffmpeg", "beta", etc.)
std::string_view deviceVariant = "ffmpeg";
Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1019,6 +1019,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) {
m.impl("_create_from_file_like", &_create_from_file_like);
m.impl(
"_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions);
m.impl("encode_video_to_file", &encode_video_to_file);
m.impl("encode_video_to_tensor", &encode_video_to_tensor);
m.impl("_encode_video_to_file_like", &_encode_video_to_file_like);
}

TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ def encode_video_to_file_like(
"""Encode video frames to a file-like object.
Args:
frames: Video frames tensor
frames: Video frames tensor. The device of the frames tensor will be used for encoding.
frame_rate: Frame rate in frames per second
format: Video format (e.g., "mp4", "mov", "mkv")
file_like: File-like object that supports write() and seek() methods
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/encoders/_video_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class VideoEncoder:
tensor of shape ``(N, C, H, W)`` where N is the number of frames,
C is 3 channels (RGB), H is height, and W is width.
Values must be uint8 in the range ``[0, 255]``.
The device of the frames tensor will be used for encoding.
frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate.
"""

Expand Down
Loading
Loading