diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 4fc7af75e..9b7f18b8f 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -100,7 +100,7 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp GpuEncoder.cpp) endif() set(core_library_dependencies diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 19ac9220d..2569a72ab 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -724,6 +724,9 @@ VideoEncoder::VideoEncoder( void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { + if (frames_.device().is_cuda()) { + gpuEncoder_ = std::make_unique(frames_.device()); + } const AVCodec* avCodec = nullptr; // If codec arg is provided, find codec using logic similar to FFmpeg: // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835 @@ -820,6 +823,12 @@ void VideoEncoder::initializeEncoder( videoStreamOptions.preset.value().c_str(), 0); } + + if (gpuEncoder_) { + gpuEncoder_->registerHardwareDeviceWithCodec(avCodecContext_.get()); + gpuEncoder_->setupHardwareFrameContext(avCodecContext_.get()); + } + int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions); av_dict_free(&avCodecOptions); @@ -860,7 +869,13 @@ void VideoEncoder::encode() { int numFrames = static_cast(frames_.sizes()[0]); for (int i = 0; i < numFrames; ++i) { torch::Tensor currFrame = frames_[i]; - UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i); + UniqueAVFrame avFrame; + if (gpuEncoder_) { + avFrame = gpuEncoder_->convertTensorToAVFrame( + currFrame, outPixelFormat_, i, avCodecContext_.get()); + } else { + avFrame = convertTensorToAVFrame(currFrame, i); + } encodeFrame(autoAVPacket, avFrame); } diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 1bdc1e443..fe3284737 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -3,7 +3,9 @@ #include #include #include "AVIOContextHolder.h" +#include "DeviceInterface.h" #include "FFMPEGCommon.h" +#include "GpuEncoder.h" #include "StreamOptions.h" extern "C" { @@ -183,6 +185,7 @@ class VideoEncoder { AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; std::unique_ptr avioContextHolder_; + std::unique_ptr gpuEncoder_; bool encodeWasCalled_ = false; AVDictionary* avFormatOptions_ = nullptr; diff --git a/src/torchcodec/_core/GpuEncoder.cpp b/src/torchcodec/_core/GpuEncoder.cpp new file mode 100644 index 000000000..5e80a1e06 --- /dev/null +++ b/src/torchcodec/_core/GpuEncoder.cpp @@ -0,0 +1,219 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include "GpuEncoder.h" + +#include +#include +#include +#include + +#include "CUDACommon.h" +#include "FFMPEGCommon.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { +namespace { + +// Redefinition from CudaDeviceInterface.cpp anonymous namespace +int getFlagsAVHardwareDeviceContextCreate() { +#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) + return AV_CUDA_USE_CURRENT_CONTEXT; +#else + return 0; +#endif +} + +// Redefinition from CudaDeviceInterface.cpp anonymous namespace +// TODO-VideoEncoder: unify device context creation, add caching to encoder +UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + + int deviceIndex = getDeviceIndex(device); + + c10::cuda::CUDAGuard deviceGuard(device); + // We set the device because we may be called from a different thread than + // the one that initialized the cuda context. + TORCH_CHECK( + cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device"); + + AVBufferRef* hardwareDeviceCtxRaw = nullptr; + std::string deviceOrdinal = std::to_string(deviceIndex); + + int err = av_hwdevice_ctx_create( + &hardwareDeviceCtxRaw, + type, + deviceOrdinal.c_str(), + nullptr, + getFlagsAVHardwareDeviceContextCreate()); + + if (err < 0) { + /* clang-format off */ + TORCH_CHECK( + false, + "Failed to create specified HW device. This typically happens when ", + "your installed FFmpeg doesn't support CUDA (see ", + "https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec", + "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err)); + /* clang-format on */ + } + + return UniqueAVBufferRef(hardwareDeviceCtxRaw); +} + +// RGB to NV12 color conversion matrices (inverse of YUV to RGB) +// Note: NPP's ColorTwist function apparently expects "limited range" +// coefficient format even when producing full range output. The matrix below +// uses the limited range coefficient format (Y with +16 offset) for NPP +// compatibility. + +// BT.601 limited range (matches FFmpeg default behavior) +const Npp32f defaultLimitedRangeRgbToNv12[3][4] = { + // Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B) + {0.257f, 0.504f, 0.098f, 16.0f}, + // U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients) + {-0.148f, -0.291f, 0.439f, 128.0f}, + // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients) + {0.439f, -0.368f, -0.071f, 128.0f}}; +} // namespace + +GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) { + TORCH_CHECK( + device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + + initializeCudaContextWithPytorch(device_); + initializeHardwareContext(); +} + +GpuEncoder::~GpuEncoder() {} + +void GpuEncoder::initializeHardwareContext() { + hardwareDeviceCtx_ = createHardwareDeviceContext(device_); + nppCtx_ = getNppStreamContext(device_); +} + +void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) { + TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); +} + +// Allocates and initializes AVHWFramesContext, and sets pixel format fields +// to enable encoding with CUDA device. The hw_frames_ctx field is needed by +// FFmpeg to allocate frames on GPU's memory. +void GpuEncoder::setupHardwareFrameContext(AVCodecContext* codecContext) { + TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + + AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); + TORCH_CHECK( + hwFramesCtxRef != nullptr, + "Failed to allocate hardware frames context for codec"); + + // Always set pixel formats to options that support CUDA encoding. + // TODO-VideoEncoder: Enable user set pixel formats to be set and properly + // handled with NPP functions below + codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; + codecContext->pix_fmt = AV_PIX_FMT_CUDA; + + AVHWFramesContext* hwFramesCtx = + reinterpret_cast(hwFramesCtxRef->data); + hwFramesCtx->format = codecContext->pix_fmt; + hwFramesCtx->sw_format = codecContext->sw_pix_fmt; + hwFramesCtx->width = codecContext->width; + hwFramesCtx->height = codecContext->height; + + int ret = av_hwframe_ctx_init(hwFramesCtxRef); + if (ret < 0) { + av_buffer_unref(&hwFramesCtxRef); + TORCH_CHECK( + false, + "Failed to initialize CUDA frames context for codec: ", + getFFMPEGErrorStringFromErrorCode(ret)); + } + + codecContext->hw_frames_ctx = hwFramesCtxRef; +} + +UniqueAVFrame GpuEncoder::convertTensorToAVFrame( + const torch::Tensor& tensor, + [[maybe_unused]] AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) { + TORCH_CHECK( + tensor.dim() == 3 && tensor.size(0) == 3, + "Expected 3D RGB tensor (CHW format), got shape: ", + tensor.sizes()); + + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + int height = static_cast(tensor.size(1)); + int width = static_cast(tensor.size(2)); + + // TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method + avFrame->format = AV_PIX_FMT_CUDA; + avFrame->height = height; + avFrame->width = width; + avFrame->pts = frameIndex; + + // FFmpeg's av_hwframe_get_buffer is used to allocate memory on CUDA device. + // TODO-VideoEncoder: Consider using pytorch to allocate CUDA memory for + // efficiency + int ret = + av_hwframe_get_buffer(codecContext->hw_frames_ctx, avFrame.get(), 0); + TORCH_CHECK( + ret >= 0, + "Failed to allocate hardware frame: ", + getFFMPEGErrorStringFromErrorCode(ret)); + + TORCH_CHECK( + avFrame != nullptr && avFrame->data[0] != nullptr, + "avFrame must be pre-allocated with CUDA memory"); + + torch::Tensor hwcFrame = tensor.permute({1, 2, 0}).contiguous(); + + at::cuda::CUDAStream currentStream = + at::cuda::getCurrentCUDAStream(device_.index()); + + nppCtx_->hStream = currentStream.stream(); + cudaError_t cudaErr = + cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags); + TORCH_CHECK( + cudaErr == cudaSuccess, + "cudaStreamGetFlags failed: ", + cudaGetErrorString(cudaErr)); + + NppiSize oSizeROI = {width, height}; + NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( + static_cast(hwcFrame.data_ptr()), + hwcFrame.stride(0) * hwcFrame.element_size(), + avFrame->data, + avFrame->linesize, + oSizeROI, + defaultLimitedRangeRgbToNv12, + *nppCtx_); + + TORCH_CHECK( + status == NPP_SUCCESS, + "Failed to convert RGB to NV12: NPP error code ", + status); + + // TODO-VideoEncoder: Enable configuration of color properties, similar to + // FFmpeg. Below are the default color properties used by FFmpeg. + avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 + avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range + + return avFrame; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/GpuEncoder.h b/src/torchcodec/_core/GpuEncoder.h new file mode 100644 index 000000000..a5a6ad68c --- /dev/null +++ b/src/torchcodec/_core/GpuEncoder.h @@ -0,0 +1,51 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#pragma once + +#include +#include +#include + +#include "CUDACommon.h" +#include "FFMPEGCommon.h" +#include "StreamOptions.h" + +extern "C" { +#include +#include +#include +} + +namespace facebook::torchcodec { + +class GpuEncoder { + public: + explicit GpuEncoder(const torch::Device& device); + ~GpuEncoder(); + + void registerHardwareDeviceWithCodec(AVCodecContext* codecContext); + void setupHardwareFrameContext(AVCodecContext* codecContext); + + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext); + + const torch::Device& device() const { + return device_; + } + + private: + torch::Device device_; + UniqueAVBufferRef hardwareDeviceCtx_; + UniqueNppContext nppCtx_; + + void initializeHardwareContext(); +}; + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index ce0f27d3b..9faafb502 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -41,6 +41,8 @@ struct VideoStreamOptions { ColorConversionLibrary::FILTERGRAPH; // By default we use CPU for decoding for both C++ and python users. + // Note: For video encoding, device is determined by the location of the input + // frame tensor. torch::Device device = torch::kCPU; // Device variant (e.g., "ffmpeg", "beta", etc.) std::string_view deviceVariant = "ffmpeg"; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 4ec72974d..7030928a5 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -1019,6 +1019,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("_create_from_file_like", &_create_from_file_like); m.impl( "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions); + m.impl("encode_video_to_file", &encode_video_to_file); + m.impl("encode_video_to_tensor", &encode_video_to_tensor); + m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 921f5ee54..8dd50c99d 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -222,7 +222,7 @@ def encode_video_to_file_like( """Encode video frames to a file-like object. Args: - frames: Video frames tensor + frames: Video frames tensor. The device of the frames tensor will be used for encoding. frame_rate: Frame rate in frames per second format: Video format (e.g., "mp4", "mov", "mkv") file_like: File-like object that supports write() and seek() methods diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 3fede6b8e..a240052a6 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -15,6 +15,7 @@ class VideoEncoder: tensor of shape ``(N, C, H, W)`` where N is the number of frames, C is 3 channels (RGB), H is height, and W is width. Values must be uint8 in the range ``[0, 255]``. + The device of the frames tensor will be used for encoding. frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. """ diff --git a/test/test_encoders.py b/test/test_encoders.py index fef19ac99..157616a38 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -17,6 +17,7 @@ assert_tensor_close_on_at_least, get_ffmpeg_major_version, get_ffmpeg_minor_version, + in_fbcode, IS_WINDOWS, NASA_AUDIO_MP3, needs_ffmpeg_cli, @@ -819,15 +820,22 @@ def test_extra_options_errors(self, method, tmp_path, extra_options, error): getattr(encoder, method)(**valid_params, extra_options=extra_options) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_contiguity(self, method, tmp_path): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_contiguity(self, method, tmp_path, device): # Ensure that 2 sets of video frames with the same pixel values are encoded # in the same way, regardless of their memory layout. Here we encode 2 equal # frame tensors, one is contiguous while the other is non-contiguous. - num_frames, channels, height, width = 5, 3, 64, 64 - contiguous_frames = torch.randint( - 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 - ).contiguous() + num_frames, channels, height, width = 5, 3, 256, 256 + contiguous_frames = ( + torch.randint( + 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 + ) + .contiguous() + .to(device) + ) assert contiguous_frames.is_contiguous() # Permute NCHW to NHWC, then update the memory layout, then permute back @@ -843,7 +851,11 @@ def test_contiguity(self, method, tmp_path): ) def encode_to_tensor(frames): - common_params = dict(crf=0, pixel_format="yuv444p") + common_params = dict( + crf=0, + pixel_format="yuv444p", + codec="h264_nvenc" if device != "cpu" else None, + ) if method == "to_file": dest = str(tmp_path / "output.mp4") VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params) @@ -1291,3 +1303,105 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range assert metadata["profile"].lower() == expected_profile assert metadata["color_space"] == colorspace assert metadata["color_range"] == color_range + + @pytest.mark.needs_cuda + @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") + @pytest.mark.parametrize( + "format_codec", + [ + ("mov", "h264_nvenc"), + ("mp4", "hevc_nvenc"), + ("avi", "h264_nvenc"), + # ("mkv", "av1_nvenc"), # av1_nvenc is not supported on CI + ], + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + # TODO-VideoEncoder: Enable additional pixel formats ("yuv420p", "yuv444p") + def test_nvenc_against_ffmpeg_cli(self, tmp_path, format_codec, method): + # Encode with FFmpeg CLI using nvenc codecs + format, codec = format_codec + device = "cuda" + pixel_format = "nv12" + qp = 1 # Lossless (qp=0) is not supported on av1_nvenc, so we use 1 + source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device) + + temp_raw_path = str(tmp_path / "temp_input.raw") + with open(temp_raw_path, "wb") as f: + f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) + + ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_nvenc_output.{format}") + frame_rate = 30 + + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", # Input format + "-s", + f"{source_frames.shape[3]}x{source_frames.shape[2]}", + "-r", + str(frame_rate), + "-i", + temp_raw_path, + "-c:v", + codec, # Use specified NVENC hardware encoder + ] + + ffmpeg_cmd.extend(["-pix_fmt", pixel_format]) # Output format + if codec == "av1_nvenc": + ffmpeg_cmd.extend(["-rc", "constqp"]) # Set rate control mode for AV1 + ffmpeg_cmd.extend(["-qp", str(qp)]) # Use lossless qp for other codecs + ffmpeg_cmd.extend([ffmpeg_encoded_path]) + + # TODO-VideoEncoder: Ensure CI does not skip this test, as we know NVENC is available. + try: + subprocess.run(ffmpeg_cmd, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + if b"No NVENC capable devices found" in e.stderr: + pytest.skip("NVENC not available on this system") + else: + raise + + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) + + encoder_extra_options = {"qp": qp} + if codec == "av1_nvenc": + encoder_extra_options["rc"] = 0 # constqp mode + if method == "to_file": + encoder_output_path = str(tmp_path / f"nvenc_output.{format}") + encoder.to_file( + dest=encoder_output_path, + codec=codec, + pixel_format=pixel_format, + extra_options=encoder_extra_options, + ) + encoder_output = encoder_output_path + elif method == "to_tensor": + encoder_output = encoder.to_tensor( + format=format, + codec=codec, + pixel_format=pixel_format, + extra_options=encoder_extra_options, + ) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like=file_like, + format=format, + codec=codec, + pixel_format=pixel_format, + extra_options=encoder_extra_options, + ) + encoder_output = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") + + ffmpeg_frames = self.decode(ffmpeg_encoded_path).data + encoder_frames = self.decode(encoder_output).data + + assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] + for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): + assert psnr(ff_frame, enc_frame) > 25 + assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=95, atol=2)