From ca1f53801ef85f476b8edd3ed3ce7487329c236c Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Mon, 3 Nov 2025 20:56:30 +0000 Subject: [PATCH 01/17] changes --- src/torchcodec/_core/CpuDeviceInterface.h | 5 -- src/torchcodec/_core/CudaDeviceInterface.cpp | 39 +++++++-- src/torchcodec/_core/CudaDeviceInterface.h | 3 +- src/torchcodec/_core/DeviceInterface.h | 7 +- src/torchcodec/_core/Encoder.cpp | 17 ++++ src/torchcodec/_core/Encoder.h | 2 + src/torchcodec/_core/FFMPEGCommon.cpp | 2 +- src/torchcodec/_core/SingleStreamDecoder.cpp | 2 +- src/torchcodec/_core/custom_ops.cpp | 14 +++- src/torchcodec/_core/ops.py | 10 ++- src/torchcodec/encoders/_video_encoder.py | 23 +++++- test/test_encoders.py | 86 ++++++++++++++++++-- 12 files changed, 179 insertions(+), 31 deletions(-) diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 801b83826..55e34c3b6 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -18,11 +18,6 @@ class CpuDeviceInterface : public DeviceInterface { virtual ~CpuDeviceInterface() {} - std::optional findCodec( - [[maybe_unused]] const AVCodecID& codecId) override { - return std::nullopt; - } - virtual void initialize( const AVStream* avStream, const UniqueDecodingAVFormatContext& avFormatCtx, diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 0e20c5e8d..4011c7340 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -329,11 +329,40 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } +namespace { +// Helper function to check if a codec supports CUDA hardware acceleration +bool codecSupportsCudaHardware(const AVCodec* codec) { + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { + return true; + } + } + return false; +} +} // namespace + // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional CudaDeviceInterface::findCodec( + +std::optional CudaDeviceInterface::findEncoder( + const AVCodecID& codecId) { + void* i = nullptr; + const AVCodec* codec = nullptr; + while ((codec = av_codec_iterate(&i)) != nullptr) { + if (codec->id != codecId || !av_codec_is_encoder(codec)) { + continue; + } + if (codecSupportsCudaHardware(codec)) { + return codec; + } + } + return std::nullopt; +} + +std::optional CudaDeviceInterface::findDecoder( const AVCodecID& codecId) { void* i = nullptr; const AVCodec* codec = nullptr; @@ -342,12 +371,8 @@ std::optional CudaDeviceInterface::findCodec( continue; } - const AVCodecHWConfig* config = nullptr; - for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; - ++j) { - if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { - return codec; - } + if (codecSupportsCudaHardware(codec)) { + return codec; } } diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index c892bd49b..9c0c2fdb9 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -18,7 +18,8 @@ class CudaDeviceInterface : public DeviceInterface { virtual ~CudaDeviceInterface(); - std::optional 
findCodec(const AVCodecID& codecId) override;
+  std::optional<const AVCodec*> findEncoder(const AVCodecID& codecId) override;
+  std::optional<const AVCodec*> findDecoder(const AVCodecID& codecId) override;
 
   void initialize(
       const AVStream* avStream,
diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h
index 319fe01a8..3ef956056 100644
--- a/src/torchcodec/_core/DeviceInterface.h
+++ b/src/torchcodec/_core/DeviceInterface.h
@@ -46,7 +46,12 @@ class DeviceInterface {
     return device_;
   };
 
-  virtual std::optional<const AVCodec*> findCodec(
+  virtual std::optional<const AVCodec*> findEncoder(
+      [[maybe_unused]] const AVCodecID& codecId) {
+    return std::nullopt;
+  };
+
+  virtual std::optional<const AVCodec*> findDecoder(
       [[maybe_unused]] const AVCodecID& codecId) {
     return std::nullopt;
   };
diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp
index 362a02a95..9a0525873 100644
--- a/src/torchcodec/_core/Encoder.cpp
+++ b/src/torchcodec/_core/Encoder.cpp
@@ -724,6 +724,12 @@ VideoEncoder::VideoEncoder(
 
 void VideoEncoder::initializeEncoder(
     const VideoStreamOptions& videoStreamOptions) {
+  deviceInterface_ = createDeviceInterface(
+      videoStreamOptions.device, videoStreamOptions.deviceVariant);
+  TORCH_CHECK(
+      deviceInterface_ != nullptr,
+      "Failed to create device interface. This should never happen, please report.");
+
   const AVCodec* avCodec = nullptr;
   // If codec arg is provided, find codec using logic similar to FFmpeg:
   // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835
@@ -749,7 +755,13 @@ void VideoEncoder::initializeEncoder(
         avFormatContext_->oformat != nullptr,
         "Output format is null, unable to find default codec.");
     avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec);
+    // TODO: merge above logic w this logic
+    // Try to find a hardware-accelerated encoder if not using CPU
+    if (videoStreamOptions.device.type() != torch::kCPU) {
+      avCodec = deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec).value_or(avCodec);
     TORCH_CHECK(avCodec != nullptr, "Video codec not found");
+    }
+
   }
 
   AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec);
@@ -820,6 +832,11 @@ void VideoEncoder::initializeEncoder(
         videoStreamOptions.preset.value().c_str(),
         0);
   }
+
+  // Register the hardware device context with the codec
+  // context before calling avcodec_open2().
+ deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get()); + int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions); av_dict_free(&avCodecOptions); diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 3d59eb6f6..ef17f33af 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -3,6 +3,7 @@ #include #include #include "AVIOContextHolder.h" +#include "DeviceInterface.h" #include "FFMPEGCommon.h" #include "StreamOptions.h" @@ -183,6 +184,7 @@ class VideoEncoder { AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; std::unique_ptr avioContextHolder_; + std::unique_ptr deviceInterface_; bool encodeWasCalled_ = false; AVDictionary* avFormatOptions_ = nullptr; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index e1b88b36a..14ddd3e6e 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() { AVCodecOnlyUseForCallingAVFindBestStream makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) { -#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) +#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3 return const_cast(codec); #else return codec; diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index ac7489bbe..dd1d9cbb3 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -462,7 +462,7 @@ void SingleStreamDecoder::addStream( // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) + deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 3c6048187..a1187b657 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + "encode_video_to_file(Tensor frames, int frame_rate, str filename, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str device=\"cpu\",str? codec=None, str? pixel_format=None, float? crf=None, str? 
preset=None, str[]? extra_options=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -641,6 +641,7 @@ void encode_video_to_file( const at::Tensor& frames, int64_t frame_rate, std::string_view file_name, + std::string_view device = "cpu", std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, @@ -650,6 +651,8 @@ void encode_video_to_file( videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; + + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.preset = preset; if (extra_options.has_value()) { @@ -669,6 +672,7 @@ at::Tensor encode_video_to_tensor( const at::Tensor& frames, int64_t frame_rate, std::string_view format, + std::string_view device = "cpu", std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, @@ -679,6 +683,8 @@ at::Tensor encode_video_to_tensor( videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; + + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.preset = preset; if (extra_options.has_value()) { @@ -700,6 +706,7 @@ void _encode_video_to_file_like( int64_t frame_rate, std::string_view format, int64_t file_like_context, + std::string_view device = "cpu", std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, @@ -715,6 +722,7 @@ void _encode_video_to_file_like( videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.preset = preset; if (extra_options.has_value()) { diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index c3562f679..ec1adaea2 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -213,6 +213,7 @@ def encode_video_to_file_like( frame_rate: int, format: str, file_like: Union[io.RawIOBase, io.BufferedIOBase], + device: Optional[str] = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, @@ -226,6 +227,7 @@ def encode_video_to_file_like( frame_rate: Frame rate in frames per second format: Video format (e.g., "mp4", "mov", "mkv") file_like: File-like object that supports write() and seek() methods + device: Device to use for encoding (default: "cpu") codec: Optional codec name (e.g., "libx264", "h264") pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p") crf: Optional constant rate factor for encoding quality @@ -239,6 +241,7 @@ def encode_video_to_file_like( frame_rate, format, _pybind_ops.create_file_like_context(file_like, True), # True means for writing + device, codec, pixel_format, crf, @@ -331,11 +334,12 @@ def encode_video_to_file_abstract( frames: torch.Tensor, frame_rate: int, filename: str, + device: str = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, preset: Optional[str] = None, crf: Optional[Union[int, float]] = None, - extra_options: Optional[list[str]] = None, + extra_options: Optional[list[str]] = None = None, ) -> None: return @@ -345,11 +349,12 @@ def encode_video_to_tensor_abstract( frames: torch.Tensor, frame_rate: int, format: str, + device: str = "cpu", codec: Optional[str] = None, pixel_format: 
Optional[str] = None, preset: Optional[str] = None, crf: Optional[Union[int, float]] = None, - extra_options: Optional[list[str]] = None, + extra_options: Optional[list[str]] = None = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) @@ -360,6 +365,7 @@ def _encode_video_to_file_like_abstract( frame_rate: int, format: str, file_like_context: int, + device: str = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, preset: Optional[str] = None, diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 0bb754025..36a1320e4 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -1,8 +1,8 @@ from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, Optional, Optional, Union import torch -from torch import Tensor +from torch import device as torch_device, Tensor from torchcodec import _core @@ -16,9 +16,18 @@ class VideoEncoder: C is 3 channels (RGB), H is height, and W is width. Values must be uint8 in the range ``[0, 255]``. frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. + device (str or torch.device, optional): The device to use for encoding. Default: "cpu". + If you pass a CUDA device, frames will be encoded on GPU. + Note: The "beta" CUDA backend is not supported for encoding. """ - def __init__(self, frames: Tensor, *, frame_rate: int): + def __init__( + self, + frames: Tensor, + *, + frame_rate: int, + device: Optional[Union[str, torch_device]] = "cpu", + ): torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") if not isinstance(frames, Tensor): raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.") @@ -29,8 +38,13 @@ def __init__(self, frames: Tensor, *, frame_rate: int): if frame_rate <= 0: raise ValueError(f"{frame_rate = } must be > 0.") + # Validate and store device + if isinstance(device, torch_device): + device = str(device) + self._frames = frames self._frame_rate = frame_rate + self._device = device def to_file( self, @@ -69,6 +83,7 @@ def to_file( frames=self._frames, frame_rate=self._frame_rate, filename=str(dest), + device=self._device, codec=codec, pixel_format=pixel_format, crf=crf, @@ -117,6 +132,7 @@ def to_tensor( frames=self._frames, frame_rate=self._frame_rate, format=format, + device=self._device, codec=codec, pixel_format=pixel_format, crf=crf, @@ -169,6 +185,7 @@ def to_file_like( frame_rate=self._frame_rate, format=format, file_like=file_like, + device=self._device, codec=codec, pixel_format=pixel_format, crf=crf, diff --git a/test/test_encoders.py b/test/test_encoders.py index ad2f0cefe..6de34e031 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -11,6 +11,7 @@ import torch from torchcodec.decoders import AudioDecoder, VideoDecoder +from torchcodec.decoders._video_decoder import VideoDecoder from torchcodec.encoders import AudioEncoder, VideoEncoder from .utils import ( @@ -765,12 +766,81 @@ def test_extra_options_errors(self, method, tmp_path, extra_options, error): getattr(encoder, method)(**valid_params, extra_options=extra_options) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_contiguity(self, method, tmp_path): + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_pixel_format_errors(self, method, tmp_path): + frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) + encoder 
= VideoEncoder(frames, frame_rate=30) + + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp4")) + elif method == "to_tensor": + valid_params = dict(format="mp4") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + + with pytest.raises( + RuntimeError, + match=r"Unknown pixel format: invalid_pix_fmt[\s\S]*Supported pixel formats.*yuv420p", + ): + getattr(encoder, method)(**valid_params, pixel_format="invalid_pix_fmt") + + with pytest.raises( + RuntimeError, + match=r"Specified pixel format rgb24 is not supported[\s\S]*Supported pixel formats.*yuv420p", + ): + getattr(encoder, method)(**valid_params, pixel_format="rgb24") + + @pytest.mark.parametrize( + "extra_options,error", + [ + ({"qp": -10}, "qp=-10 is out of valid range"), + ( + {"qp": ""}, + "Option qp expects a numeric value but got", + ), + ( + {"direct-pred": "a"}, + "Option direct-pred expects a numeric value but got 'a'", + ), + ({"tune": "not_a_real_tune"}, "avcodec_open2 failed: Invalid argument"), + ( + {"tune": 10}, + "avcodec_open2 failed: Invalid argument", + ), + ], + ) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_extra_options_errors(self, method, tmp_path, extra_options, error): + frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) + encoder = VideoEncoder(frames, frame_rate=30) + + if method == "to_file": + valid_params = dict(dest=str(tmp_path / "output.mp4")) + elif method == "to_tensor": + valid_params = dict(format="mp4") + elif method == "to_file_like": + valid_params = dict(file_like=io.BytesIO(), format="mp4") + else: + raise ValueError(f"Unknown method: {method}") + + with pytest.raises( + RuntimeError, + match=error, + ): + getattr(encoder, method)(**valid_params, extra_options=extra_options) + + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + @pytest.mark.parametrize( + "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) + ) + def test_contiguity(self, method, tmp_path, device): # Ensure that 2 sets of video frames with the same pixel values are encoded # in the same way, regardless of their memory layout. Here we encode 2 equal # frame tensors, one is contiguous while the other is non-contiguous. 
- num_frames, channels, height, width = 5, 3, 64, 64 + num_frames, channels, height, width = 5, 3, 256, 256 contiguous_frames = torch.randint( 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 ).contiguous() @@ -792,16 +862,18 @@ def encode_to_tensor(frames): common_params = dict(crf=0, pixel_format="yuv444p") if method == "to_file": dest = str(tmp_path / "output.mp4") - VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params) + VideoEncoder(frames, frame_rate=30, device=device).to_file(dest=dest, **common_params) with open(dest, "rb") as f: - return torch.frombuffer(f.read(), dtype=torch.uint8) + return torch.frombuffer(f.read(), dtype=torch.uint8).clone() elif method == "to_tensor": - return VideoEncoder(frames, frame_rate=30).to_tensor( - format="mp4", **common_params + return VideoEncoder(frames, frame_rate=30, device=device).to_tensor( + + format="mp4" + , **common_params ) elif method == "to_file_like": file_like = io.BytesIO() - VideoEncoder(frames, frame_rate=30).to_file_like( + VideoEncoder(frames, frame_rate=30, device=device).to_file_like( file_like, format="mp4", **common_params ) return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) From 50cdb21d93251b87d000b97d6ac7f087c015542b Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Mon, 3 Nov 2025 21:04:14 +0000 Subject: [PATCH 02/17] lint --- src/torchcodec/_core/Encoder.cpp | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 9a0525873..ad6688165 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -754,13 +754,12 @@ void VideoEncoder::initializeEncoder( TORCH_CHECK( avFormatContext_->oformat != nullptr, "Output format is null, unable to find default codec."); - avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); - // TODO: merge above logic w this logic // Try to find a hardware-accelerated encoder if not using CPU + avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); if (videoStreamOptions.device.type() != torch::kCPU) { avCodec = deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec).value_or(avCodec); - TORCH_CHECK(avCodec != nullptr, "Video codec not found"); } + TORCH_CHECK(avCodec != nullptr, "Video codec not found"); } From 54d1a1f942bf8ba7c961bcf6e4ea5d2484a2fa08 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Thu, 20 Nov 2025 05:51:50 +0000 Subject: [PATCH 03/17] BT.601, test_nvenc_against_ffmpeg_cli --- .../_core/BetaCudaDeviceInterface.cpp | 10 + .../_core/BetaCudaDeviceInterface.h | 6 + src/torchcodec/_core/CUDACommon.cpp | 77 ++++++++ src/torchcodec/_core/CUDACommon.h | 7 + src/torchcodec/_core/CpuDeviceInterface.cpp | 78 ++++++++ src/torchcodec/_core/CpuDeviceInterface.h | 6 + src/torchcodec/_core/CudaDeviceInterface.cpp | 68 +++++++ src/torchcodec/_core/CudaDeviceInterface.h | 6 + src/torchcodec/_core/DeviceInterface.h | 8 + src/torchcodec/_core/Encoder.cpp | 94 +++------- src/torchcodec/_core/Encoder.h | 3 - src/torchcodec/_core/custom_ops.cpp | 3 + src/torchcodec/_core/ops.py | 4 +- src/torchcodec/encoders/_video_encoder.py | 2 +- test/test_encoders.py | 173 ++++++++++-------- 15 files changed, 395 insertions(+), 150 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 45f6ba1a5..86c7d5e27 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ 
-833,6 +833,16 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } +UniqueAVFrame BetaCudaDeviceInterface::convertTensorToAVFrame( + [[maybe_unused]] const torch::Tensor& tensor, + [[maybe_unused]] AVPixelFormat targetFormat, + [[maybe_unused]] int frameIndex, + [[maybe_unused]] AVCodecContext* codecContext) { + TORCH_CHECK( + false, + "Beta CUDA device interface does not support video encoding currently."); +} + std::string BetaCudaDeviceInterface::getDetails() { std::string details = "Beta CUDA Device Interface."; if (cpuFallback_) { diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index cefb1a983..fba998a50 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -48,6 +48,12 @@ class BetaCudaDeviceInterface : public DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) override; + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) override; + int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; int receiveFrame(UniqueAVFrame& avFrame) override; diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index bbd17db39..d31acc968 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -156,6 +156,21 @@ const Npp32f bt709FullRangeColorTwist[3][4] = { {1.0f, -0.187324273f, -0.468124273f, -128.0f}, {1.0f, 1.8556f, 0.0f, -128.0f}}; +// RGB to NV12 color conversion matrices (inverse of YUV to RGB) +// Note: NPP's ColorTwist function apparently expects "limited range" +// coefficient format even when producing full range output. All matrices below +// use the limited range coefficient format (Y with +16 offset) for NPP +// compatibility. 
+ +// BT.601 limited range (matches FFmpeg default behavior) +const Npp32f defaultLimitedRangeRgbToNv12[3][4] = { + // Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B) + {0.257f, 0.504f, 0.098f, 16.0f}, + // U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients) + {-0.148f, -0.291f, 0.439f, 128.0f}, + // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients) + {0.439f, -0.368f, -0.071f, 128.0f}}; + torch::Tensor convertNV12FrameToRGB( UniqueAVFrame& avFrame, const torch::Device& device, @@ -246,6 +261,68 @@ torch::Tensor convertNV12FrameToRGB( return dst; } +void convertRGBTensorToNV12Frame( + const torch::Tensor& rgbTensor, + UniqueAVFrame& nv12Frame, + const torch::Device& device, + const UniqueNppContext& nppCtx, + at::cuda::CUDAStream inputStream) { + TORCH_CHECK(rgbTensor.is_cuda(), "RGB tensor must be on CUDA device"); + TORCH_CHECK( + rgbTensor.dim() == 3 && rgbTensor.size(0) == 3, + "Expected 3D RGB tensor in CHW format, got shape: ", + rgbTensor.sizes()); + TORCH_CHECK( + nv12Frame != nullptr && nv12Frame->data[0] != nullptr, + "nv12Frame must be pre-allocated with CUDA memory"); + + // Convert CHW to HWC for NPP processing + int height = static_cast(rgbTensor.size(1)); + int width = static_cast(rgbTensor.size(2)); + torch::Tensor hwcFrame = rgbTensor.permute({1, 2, 0}).contiguous(); + + // Set up stream synchronization - make NPP stream wait for input tensor + // operations + at::cuda::CUDAStream nppStream = + at::cuda::getCurrentCUDAStream(device.index()); + at::cuda::CUDAEvent inputDoneEvent; + inputDoneEvent.record(inputStream); + inputDoneEvent.block(nppStream); + + // Setup NPP context + nppCtx->hStream = nppStream.stream(); + cudaError_t cudaErr = + cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags); + TORCH_CHECK( + cudaErr == cudaSuccess, + "cudaStreamGetFlags failed: ", + cudaGetErrorString(cudaErr)); + + // Always use FFmpeg's default behavior: BT.601 limited range + NppiSize oSizeROI = {width, height}; + + NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( + static_cast(hwcFrame.data_ptr()), + hwcFrame.stride(0) * hwcFrame.element_size(), + nv12Frame->data, + nv12Frame->linesize, + oSizeROI, + defaultLimitedRangeRgbToNv12, + *nppCtx); + + TORCH_CHECK( + status == NPP_SUCCESS, + "Failed to convert RGB to NV12: NPP error code ", + status); + + // Validate CUDA operations completed successfully + cudaError_t memCheck = cudaGetLastError(); + TORCH_CHECK( + memCheck == cudaSuccess, + "CUDA error detected: ", + cudaGetErrorString(memCheck)); +} + UniqueNppContext getNppStreamContext(const torch::Device& device) { int deviceIndex = getDeviceIndex(device); diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index 4cc27c23b..15502540f 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -37,6 +37,13 @@ torch::Tensor convertNV12FrameToRGB( at::cuda::CUDAStream nvdecStream, std::optional preAllocatedOutputTensor = std::nullopt); +void convertRGBTensorToNV12Frame( + const torch::Tensor& rgbTensor, + UniqueAVFrame& nv12Frame, + const torch::Device& device, + const UniqueNppContext& nppCtx, + at::cuda::CUDAStream inputStream); + UniqueNppContext getNppStreamContext(const torch::Device& device); void returnNppStreamContextToCache( const torch::Device& device, diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index 70f46b7e4..d7e58cb45 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ 
b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -429,6 +429,84 @@ std::optional CpuDeviceInterface::maybeFlushAudioBuffers() { /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); } +UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrame( + const torch::Tensor& frame, + AVPixelFormat outPixelFormat, + int frameIndex, + [[maybe_unused]] AVCodecContext* codecContext) { + int inHeight = static_cast(frame.sizes()[1]); + int inWidth = static_cast(frame.sizes()[2]); + + // For now, reuse input dimensions as output dimensions + int outWidth = inWidth; + int outHeight = inHeight; + + // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) + AVPixelFormat inPixelFormat = AV_PIX_FMT_GBRP; + + // Initialize and cache scaling context if it does not exist + if (!swsContext_) { + swsContext_.reset(sws_getContext( + inWidth, + inHeight, + inPixelFormat, + outWidth, + outHeight, + outPixelFormat, + SWS_BICUBIC, // Used by FFmpeg CLI + nullptr, + nullptr, + nullptr)); + TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); + } + + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + + // Set output frame properties + avFrame->format = outPixelFormat; + avFrame->width = outWidth; + avFrame->height = outHeight; + avFrame->pts = frameIndex; + + int status = av_frame_get_buffer(avFrame.get(), 0); + TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); + + // Need to convert/scale the frame + // Create temporary frame with input format + UniqueAVFrame inputFrame(av_frame_alloc()); + TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); + + inputFrame->format = inPixelFormat; + inputFrame->width = inWidth; + inputFrame->height = inHeight; + + uint8_t* tensorData = static_cast(frame.data_ptr()); + + // TODO-VideoEncoder: Reorder tensor if in NHWC format + int channelSize = inHeight * inWidth; + // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format + // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format + inputFrame->data[0] = tensorData + channelSize; + inputFrame->data[1] = tensorData + (2 * channelSize); + inputFrame->data[2] = tensorData; + + inputFrame->linesize[0] = inWidth; + inputFrame->linesize[1] = inWidth; + inputFrame->linesize[2] = inWidth; + + status = sws_scale( + swsContext_.get(), + inputFrame->data, + inputFrame->linesize, + 0, + inputFrame->height, + avFrame->data, + avFrame->linesize); + TORCH_CHECK(status == outHeight, "sws_scale failed"); + return avFrame; +} + std::string CpuDeviceInterface::getDetails() { return std::string("CPU Device Interface."); } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 55e34c3b6..c33d5d051 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -38,6 +38,12 @@ class CpuDeviceInterface : public DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) override; + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) override; + std::string getDetails() override; private: diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 4011c7340..d67773399 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -1,8 +1,10 @@ #include #include +#include #include #include +#include "CUDACommon.h" 
#include "Cache.h" #include "CudaDeviceInterface.h" #include "FFMPEGCommon.h" @@ -142,6 +144,34 @@ void CudaDeviceInterface::registerHardwareDeviceWithCodec( hardwareDeviceCtx_, "Hardware device context has not been initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); + // is there any way to preserve actual desired format? + // codecContext->sw_pix_fmt = codecContext->pix_fmt; + // Should we always produce AV_PIX_FMT_NV12? + codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; + codecContext->pix_fmt = AV_PIX_FMT_CUDA; + + AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); + TORCH_CHECK( + hwFramesCtxRef != nullptr, + "Failed to allocate hardware frames context for codec"); + + AVHWFramesContext* hwFramesCtx = + reinterpret_cast(hwFramesCtxRef->data); + hwFramesCtx->format = codecContext->pix_fmt; + hwFramesCtx->sw_format = codecContext->sw_pix_fmt; + hwFramesCtx->width = codecContext->width; + hwFramesCtx->height = codecContext->height; + + int ret = av_hwframe_ctx_init(hwFramesCtxRef); + if (ret < 0) { + av_buffer_unref(&hwFramesCtxRef); + TORCH_CHECK( + false, + "Failed to initialize CUDA frames context for codec: ", + getFFMPEGErrorStringFromErrorCode(ret)); + } + + codecContext->hw_frames_ctx = hwFramesCtxRef; } UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( @@ -379,6 +409,44 @@ std::optional CudaDeviceInterface::findDecoder( return std::nullopt; } +UniqueAVFrame CudaDeviceInterface::convertTensorToAVFrame( + const torch::Tensor& frame, + [[maybe_unused]] AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) { + TORCH_CHECK(frame.is_cuda(), "CUDA device interface requires CUDA tensors"); + TORCH_CHECK( + frame.dim() == 3 && frame.size(0) == 3, + "Expected 3D RGB tensor (CHW format), got shape: ", + frame.sizes()); + + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + + avFrame->format = AV_PIX_FMT_CUDA; + avFrame->width = static_cast(frame.size(2)); + avFrame->height = static_cast(frame.size(1)); + avFrame->pts = frameIndex; + + int ret = av_hwframe_get_buffer( + codecContext ? 
codecContext->hw_frames_ctx : nullptr, avFrame.get(), 0); + TORCH_CHECK( + ret >= 0, + "Failed to allocate hardware frame: ", + getFFMPEGErrorStringFromErrorCode(ret)); + + at::cuda::CUDAStream currentStream = + at::cuda::getCurrentCUDAStream(device_.index()); + + convertRGBTensorToNV12Frame(frame, avFrame, device_, nppCtx_, currentStream); + + // Set color properties to FFmpeg defaults + avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 + avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range + + return avFrame; +} + std::string CudaDeviceInterface::getDetails() { // Note: for this interface specifically the fallback is only known after a // frame has been decoded, not before: that's when FFmpeg decides to fallback, diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 9c0c2fdb9..aee42043b 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -40,6 +40,12 @@ class CudaDeviceInterface : public DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) override; + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) override; + std::string getDetails() override; private: diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 3ef956056..8f3ee0ae8 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -97,6 +97,14 @@ class DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; + // Convert tensor to AVFrame, implemented per device interface. + // This is similar to convertAVFrameToFrameOutput for encoding + virtual UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) = 0; + // ------------------------------------------ // Extension points for custom decoding paths // ------------------------------------------ diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index ad6688165..391eb4778 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -5,6 +5,8 @@ #include "torch/types.h" extern "C" { +#include +#include #include #include } @@ -523,7 +525,9 @@ void AudioEncoder::flushBuffers() { namespace { -torch::Tensor validateFrames(const torch::Tensor& frames) { +torch::Tensor validateFrames( + const torch::Tensor& frames, + const torch::Device& device) { TORCH_CHECK( frames.dtype() == torch::kUInt8, "frames must have uint8 dtype, got ", @@ -536,6 +540,15 @@ torch::Tensor validateFrames(const torch::Tensor& frames) { frames.sizes()[1] == 3, "frame must have 3 channels (R, G, B), got ", frames.sizes()[1]); + if (device.type() != torch::kCPU) { + TORCH_CHECK( + frames.is_cuda(), + "When using CUDA encoding (device=", + device.str(), + "), frames must be on a CUDA device. Got frames on ", + frames.device().str(), + ". 
Please move frames to a CUDA device: frames.to('cuda')"); + } return frames.contiguous(); } @@ -665,7 +678,8 @@ VideoEncoder::VideoEncoder( int frameRate, std::string_view fileName, const VideoStreamOptions& videoStreamOptions) - : frames_(validateFrames(frames)), inFrameRate_(frameRate) { + : frames_(validateFrames(frames, videoStreamOptions.device)), + inFrameRate_(frameRate) { setFFmpegLogLevel(); // Allocate output format context @@ -698,7 +712,7 @@ VideoEncoder::VideoEncoder( std::string_view formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions) - : frames_(validateFrames(frames)), + : frames_(validateFrames(frames, videoStreamOptions.device)), inFrameRate_(frameRate), avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); @@ -757,10 +771,11 @@ void VideoEncoder::initializeEncoder( // Try to find a hardware-accelerated encoder if not using CPU avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); if (videoStreamOptions.device.type() != torch::kCPU) { - avCodec = deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec).value_or(avCodec); + avCodec = + deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec) + .value_or(avCodec); } TORCH_CHECK(avCodec != nullptr, "Video codec not found"); - } AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); @@ -872,7 +887,8 @@ void VideoEncoder::encode() { int numFrames = static_cast(frames_.sizes()[0]); for (int i = 0; i < numFrames; ++i) { torch::Tensor currFrame = frames_[i]; - UniqueAVFrame avFrame = convertTensorToAVFrame(currFrame, i); + UniqueAVFrame avFrame = deviceInterface_->convertTensorToAVFrame( + currFrame, outPixelFormat_, i, avCodecContext_.get()); encodeFrame(autoAVPacket, avFrame); } @@ -885,72 +901,6 @@ void VideoEncoder::encode() { getFFMPEGErrorStringFromErrorCode(status)); } -UniqueAVFrame VideoEncoder::convertTensorToAVFrame( - const torch::Tensor& frame, - int frameIndex) { - // Initialize and cache scaling context if it does not exist - if (!swsContext_) { - swsContext_.reset(sws_getContext( - inWidth_, - inHeight_, - inPixelFormat_, - outWidth_, - outHeight_, - outPixelFormat_, - SWS_BICUBIC, // Used by FFmpeg CLI - nullptr, - nullptr, - nullptr)); - TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); - } - - UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); - - // Set output frame properties - avFrame->format = outPixelFormat_; - avFrame->width = outWidth_; - avFrame->height = outHeight_; - avFrame->pts = frameIndex; - - int status = av_frame_get_buffer(avFrame.get(), 0); - TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); - - // Need to convert/scale the frame - // Create temporary frame with input format - UniqueAVFrame inputFrame(av_frame_alloc()); - TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); - - inputFrame->format = inPixelFormat_; - inputFrame->width = inWidth_; - inputFrame->height = inHeight_; - - uint8_t* tensorData = static_cast(frame.data_ptr()); - - // TODO-VideoEncoder: Reorder tensor if in NHWC format - int channelSize = inHeight_ * inWidth_; - // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format - // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format - inputFrame->data[0] = tensorData + channelSize; - inputFrame->data[1] = tensorData + (2 * channelSize); - inputFrame->data[2] = tensorData; - - inputFrame->linesize[0] = inWidth_; - inputFrame->linesize[1] = inWidth_; 
- inputFrame->linesize[2] = inWidth_; - - status = sws_scale( - swsContext_.get(), - inputFrame->data, - inputFrame->linesize, - 0, - inputFrame->height, - avFrame->data, - avFrame->linesize); - TORCH_CHECK(status == outHeight_, "sws_scale failed"); - return avFrame; -} - torch::Tensor VideoEncoder::encodeToTensor() { TORCH_CHECK( avioContextHolder_ != nullptr, diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index ef17f33af..c04d57e40 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -161,9 +161,6 @@ class VideoEncoder { private: void initializeEncoder(const VideoStreamOptions& videoStreamOptions); - UniqueAVFrame convertTensorToAVFrame( - const torch::Tensor& frame, - int frameIndex); void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); void flushBuffers(); diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index a1187b657..12bf80fbf 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -1032,6 +1032,9 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, BackendSelect, m) { m.impl("_create_from_file_like", &_create_from_file_like); m.impl( "_get_json_ffmpeg_library_versions", &_get_json_ffmpeg_library_versions); + m.impl("encode_video_to_file", &encode_video_to_file); + m.impl("encode_video_to_tensor", &encode_video_to_tensor); + m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); } TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index ec1adaea2..d241593d5 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -339,7 +339,7 @@ def encode_video_to_file_abstract( pixel_format: Optional[str] = None, preset: Optional[str] = None, crf: Optional[Union[int, float]] = None, - extra_options: Optional[list[str]] = None = None, + extra_options: Optional[list[str]] = None, ) -> None: return @@ -354,7 +354,7 @@ def encode_video_to_tensor_abstract( pixel_format: Optional[str] = None, preset: Optional[str] = None, crf: Optional[Union[int, float]] = None, - extra_options: Optional[list[str]] = None = None, + extra_options: Optional[list[str]] = None, ) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 36a1320e4..1f9bb284f 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -1,5 +1,5 @@ from pathlib import Path -from typing import Any, Dict, Optional, Optional, Union +from typing import Any, Dict, Optional, Union import torch from torch import device as torch_device, Tensor diff --git a/test/test_encoders.py b/test/test_encoders.py index 6de34e031..7ae067166 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -10,8 +10,6 @@ import pytest import torch from torchcodec.decoders import AudioDecoder, VideoDecoder - -from torchcodec.decoders._video_decoder import VideoDecoder from torchcodec.encoders import AudioEncoder, VideoEncoder from .utils import ( @@ -765,72 +763,6 @@ def test_extra_options_errors(self, method, tmp_path, extra_options, error): ): getattr(encoder, method)(**valid_params, extra_options=extra_options) - @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - @pytest.mark.parametrize( - "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) - ) - def test_pixel_format_errors(self, method, tmp_path): - frames = torch.zeros((5, 3, 
64, 64), dtype=torch.uint8) - encoder = VideoEncoder(frames, frame_rate=30) - - if method == "to_file": - valid_params = dict(dest=str(tmp_path / "output.mp4")) - elif method == "to_tensor": - valid_params = dict(format="mp4") - elif method == "to_file_like": - valid_params = dict(file_like=io.BytesIO(), format="mp4") - - with pytest.raises( - RuntimeError, - match=r"Unknown pixel format: invalid_pix_fmt[\s\S]*Supported pixel formats.*yuv420p", - ): - getattr(encoder, method)(**valid_params, pixel_format="invalid_pix_fmt") - - with pytest.raises( - RuntimeError, - match=r"Specified pixel format rgb24 is not supported[\s\S]*Supported pixel formats.*yuv420p", - ): - getattr(encoder, method)(**valid_params, pixel_format="rgb24") - - @pytest.mark.parametrize( - "extra_options,error", - [ - ({"qp": -10}, "qp=-10 is out of valid range"), - ( - {"qp": ""}, - "Option qp expects a numeric value but got", - ), - ( - {"direct-pred": "a"}, - "Option direct-pred expects a numeric value but got 'a'", - ), - ({"tune": "not_a_real_tune"}, "avcodec_open2 failed: Invalid argument"), - ( - {"tune": 10}, - "avcodec_open2 failed: Invalid argument", - ), - ], - ) - @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_extra_options_errors(self, method, tmp_path, extra_options, error): - frames = torch.zeros((5, 3, 64, 64), dtype=torch.uint8) - encoder = VideoEncoder(frames, frame_rate=30) - - if method == "to_file": - valid_params = dict(dest=str(tmp_path / "output.mp4")) - elif method == "to_tensor": - valid_params = dict(format="mp4") - elif method == "to_file_like": - valid_params = dict(file_like=io.BytesIO(), format="mp4") - else: - raise ValueError(f"Unknown method: {method}") - - with pytest.raises( - RuntimeError, - match=error, - ): - getattr(encoder, method)(**valid_params, extra_options=extra_options) - @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) @pytest.mark.parametrize( "device", ("cpu", pytest.param("cuda", marks=pytest.mark.needs_cuda)) @@ -862,14 +794,14 @@ def encode_to_tensor(frames): common_params = dict(crf=0, pixel_format="yuv444p") if method == "to_file": dest = str(tmp_path / "output.mp4") - VideoEncoder(frames, frame_rate=30, device=device).to_file(dest=dest, **common_params) + VideoEncoder(frames, frame_rate=30, device=device).to_file( + dest=dest, **common_params + ) with open(dest, "rb") as f: return torch.frombuffer(f.read(), dtype=torch.uint8).clone() elif method == "to_tensor": return VideoEncoder(frames, frame_rate=30, device=device).to_tensor( - - format="mp4" - , **common_params + format="mp4", **common_params ) elif method == "to_file_like": file_like = io.BytesIO() @@ -1282,3 +1214,100 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range assert metadata["profile"].lower() == expected_profile assert metadata["color_space"] == colorspace assert metadata["color_range"] == color_range + + @pytest.mark.needs_cuda + @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") + @pytest.mark.parametrize("preset", ("slow", "fast")) + @pytest.mark.parametrize("pixel_format", ("nv12", "yuv420p")) + @pytest.mark.parametrize("format", ("mov", "mp4", "avi", "mkv", "flv")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + def test_nvenc_against_ffmpeg_cli( + self, tmp_path, preset, pixel_format, format, method + ): + # Encode with FFmpeg CLI using h264_nvenc + device = "cuda" + source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device) + + 
temp_raw_path = str(tmp_path / "temp_input.raw") + with open(temp_raw_path, "wb") as f: + f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) + + ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_nvenc_output.{format}") + frame_rate = 30 + + ffmpeg_cmd = [ + "ffmpeg", + "-y", + "-f", + "rawvideo", + "-pix_fmt", + "rgb24", # Input format + "-s", + f"{source_frames.shape[3]}x{source_frames.shape[2]}", + "-r", + str(frame_rate), + "-i", + temp_raw_path, + "-c:v", + "h264_nvenc", # Use NVENC hardware encoder + ] + + ffmpeg_cmd.extend(["-pix_fmt", pixel_format]) # Output format + ffmpeg_cmd.extend(["-preset", preset]) # Use parametrized preset + ffmpeg_cmd.extend(["-qp", "0"]) # Use lossless qp for consistency + ffmpeg_cmd.extend([ffmpeg_encoded_path]) + + # Will this prevent CI from treating test as failed if NVENC is not available? + try: + subprocess.run(ffmpeg_cmd, check=True, capture_output=True) + except subprocess.CalledProcessError as e: + if b"No NVENC capable devices found" in e.stderr: + pytest.skip("NVENC not available on this system") + else: + raise + + encoder = VideoEncoder( + frames=source_frames, frame_rate=frame_rate, device=device + ) + + encoder_extra_options = {"qp": 0} + if method == "to_file": + encoder_output_path = str(tmp_path / f"nvenc_output.{format}") + encoder.to_file( + dest=encoder_output_path, + codec="h264_nvenc", + pixel_format=pixel_format, + preset=preset, + extra_options=encoder_extra_options, + ) + encoder_output = encoder_output_path + elif method == "to_tensor": + encoder_output = encoder.to_tensor( + format=format, + codec="h264_nvenc", + pixel_format=pixel_format, + preset=preset, + extra_options=encoder_extra_options, + ) + elif method == "to_file_like": + file_like = io.BytesIO() + encoder.to_file_like( + file_like=file_like, + format=format, + codec="h264_nvenc", + pixel_format=pixel_format, + preset=preset, + extra_options=encoder_extra_options, + ) + encoder_output = file_like.getvalue() + else: + raise ValueError(f"Unknown method: {method}") + + ffmpeg_frames = self.decode(ffmpeg_encoded_path).data + encoder_frames = self.decode(encoder_output).data + + assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] + for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): + assert psnr(ff_frame, enc_frame) > 25 + assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=99, atol=10) + assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=95, atol=2) From 88e1299d1e95c24399002812137430893be8c881 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Thu, 20 Nov 2025 14:40:48 +0000 Subject: [PATCH 04/17] remove cuda header from Encoder.cpp --- src/torchcodec/_core/Encoder.cpp | 1 - src/torchcodec/_core/custom_ops.cpp | 9 +++------ 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 391eb4778..d189533df 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -6,7 +6,6 @@ extern "C" { #include -#include #include #include } diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 12bf80fbf..2aaa0e8c4 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -415,7 +415,6 @@ void _add_video_stream( } validateDeviceInterface(std::string(device), std::string(device_variant)); - videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.deviceVariant = device_variant; @@ -648,11 +647,10 @@ void encode_video_to_file( 
std::optional preset = std::nullopt, std::optional> extra_options = std::nullopt) { VideoStreamOptions videoStreamOptions; + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; - - videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.preset = preset; if (extra_options.has_value()) { @@ -680,11 +678,10 @@ at::Tensor encode_video_to_tensor( std::optional> extra_options = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; - - videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.preset = preset; if (extra_options.has_value()) { @@ -719,10 +716,10 @@ void _encode_video_to_file_like( std::unique_ptr avioContextHolder(fileLikeContext); VideoStreamOptions videoStreamOptions; + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; - videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.preset = preset; if (extra_options.has_value()) { From 926b7ea337c5609e896e27c9c7088dd45d049a86 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Thu, 20 Nov 2025 17:01:32 +0000 Subject: [PATCH 05/17] separate encoding frame ctx init --- src/torchcodec/_core/CudaDeviceInterface.cpp | 6 ++++++ src/torchcodec/_core/CudaDeviceInterface.h | 2 ++ src/torchcodec/_core/DeviceInterface.h | 6 ++++++ src/torchcodec/_core/Encoder.cpp | 3 +++ test/test_encoders.py | 2 +- 5 files changed, 18 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index d67773399..84cefc142 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -144,6 +144,12 @@ void CudaDeviceInterface::registerHardwareDeviceWithCodec( hardwareDeviceCtx_, "Hardware device context has not been initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); +} + +void CudaDeviceInterface::setupEncodingContext(AVCodecContext* codecContext) { + TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); // is there any way to preserve actual desired format? // codecContext->sw_pix_fmt = codecContext->pix_fmt; // Should we always produce AV_PIX_FMT_NV12? 
diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index aee42043b..83761020c 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -35,6 +35,8 @@ class CudaDeviceInterface : public DeviceInterface { void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override; + void setupEncodingContext(AVCodecContext* codecContext) override; + void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 8f3ee0ae8..2b69dbfc9 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -92,6 +92,12 @@ class DeviceInterface { virtual void registerHardwareDeviceWithCodec( [[maybe_unused]] AVCodecContext* codecContext) {} + // Setup device-specific encoding context (e.g., hardware frame contexts). + // Called after registerHardwareDeviceWithCodec for encoders. + // Default implementation does nothing (suitable for CPU and basic cases). + virtual void setupEncodingContext( + [[maybe_unused]] AVCodecContext* codecContext) {} + virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index d189533df..31eb5212d 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -850,6 +850,9 @@ void VideoEncoder::initializeEncoder( // context before calling avcodec_open2(). deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get()); + // Setup device-specific encoding context (e.g., hardware frame contexts) + deviceInterface_->setupEncodingContext(avCodecContext_.get()); + int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions); av_dict_free(&avCodecOptions); diff --git a/test/test_encoders.py b/test/test_encoders.py index 7ae067166..d79b3cae4 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -775,7 +775,7 @@ def test_contiguity(self, method, tmp_path, device): num_frames, channels, height, width = 5, 3, 256, 256 contiguous_frames = torch.randint( 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 - ).contiguous() + ).contiguous().to(device) assert contiguous_frames.is_contiguous() # Permute NCHW to NHWC, then update the memory layout, then permute back From 43c62217640b6591a7b59ecec0688706e6f8e4c3 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Fri, 21 Nov 2025 23:04:21 +0000 Subject: [PATCH 06/17] lint --- test/test_encoders.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index a477a1c55..eb3d8193f 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -804,9 +804,13 @@ def test_contiguity(self, method, tmp_path, device): # frame tensors, one is contiguous while the other is non-contiguous. 
num_frames, channels, height, width = 5, 3, 256, 256 - contiguous_frames = torch.randint( - 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 - ).contiguous().to(device) + contiguous_frames = ( + torch.randint( + 0, 256, size=(num_frames, channels, height, width), dtype=torch.uint8 + ) + .contiguous() + .to(device) + ) assert contiguous_frames.is_contiguous() # Permute NCHW to NHWC, then update the memory layout, then permute back From 4af1f53ce5d3aa90a16a2e787c18957c0ca5b727 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Sat, 22 Nov 2025 20:06:46 +0000 Subject: [PATCH 07/17] parametrize other nvenc --- test/test_encoders.py | 41 +++++++++++++++++++++++++++-------------- 1 file changed, 27 insertions(+), 14 deletions(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index eb3d8193f..a612f6df7 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1270,15 +1270,28 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range @pytest.mark.needs_cuda @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available") - @pytest.mark.parametrize("preset", ("slow", "fast")) @pytest.mark.parametrize("pixel_format", ("nv12", "yuv420p")) - @pytest.mark.parametrize("format", ("mov", "mp4", "avi", "mkv", "flv")) + @pytest.mark.parametrize("format_codec", [ + ("mov", "h264_nvenc"), + ("mp4", "hevc_nvenc"), + ("avi", "h264_nvenc"), + pytest.param( + ("mkv", "av1_nvenc"), + marks=pytest.mark.skipif( + get_ffmpeg_major_version() <= 5, + reason="av1_nvenc not supported in FFmpeg 4 and 5" + ) + ), + ("flv", "h264_nvenc") + ]) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_nvenc_against_ffmpeg_cli( - self, tmp_path, preset, pixel_format, format, method + self, tmp_path, pixel_format, format_codec, method ): - # Encode with FFmpeg CLI using h264_nvenc + # Encode with FFmpeg CLI using nvenc codecs + format, codec = format_codec device = "cuda" + qp = 1 # Lossless (qp=0) is not supported on av1_nvenc, so we use 1 source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device) temp_raw_path = str(tmp_path / "temp_input.raw") @@ -1302,12 +1315,13 @@ def test_nvenc_against_ffmpeg_cli( "-i", temp_raw_path, "-c:v", - "h264_nvenc", # Use NVENC hardware encoder + codec, # Use specified NVENC hardware encoder ] ffmpeg_cmd.extend(["-pix_fmt", pixel_format]) # Output format - ffmpeg_cmd.extend(["-preset", preset]) # Use parametrized preset - ffmpeg_cmd.extend(["-qp", "0"]) # Use lossless qp for consistency + if codec == "av1_nvenc": + ffmpeg_cmd.extend(["-rc", "constqp"]) # Set rate control mode for AV1 else: + ffmpeg_cmd.extend(["-qp", str(qp)]) # Use lossless qp for other codecs ffmpeg_cmd.extend([ffmpeg_encoded_path]) # Will this prevent CI from treating test as failed if NVENC is not available? 
@@ -1323,23 +1337,23 @@ def test_nvenc_against_ffmpeg_cli(
             frames=source_frames, frame_rate=frame_rate, device=device
         )
-        encoder_extra_options = {"qp": 0}
+        encoder_extra_options = {"qp": qp}
+        if codec == "av1_nvenc":
+            encoder_extra_options["rc"] = 0  # constqp mode
 
         if method == "to_file":
             encoder_output_path = str(tmp_path / f"nvenc_output.{format}")
             encoder.to_file(
                 dest=encoder_output_path,
-                codec="h264_nvenc",
+                codec=codec,
                 pixel_format=pixel_format,
-                preset=preset,
                 extra_options=encoder_extra_options,
             )
             encoder_output = encoder_output_path
         elif method == "to_tensor":
             encoder_output = encoder.to_tensor(
                 format=format,
-                codec="h264_nvenc",
+                codec=codec,
                 pixel_format=pixel_format,
-                preset=preset,
                 extra_options=encoder_extra_options,
             )
         elif method == "to_file_like":
@@ -1347,9 +1361,8 @@
             encoder.to_file_like(
                 file_like=file_like,
                 format=format,
-                codec="h264_nvenc",
+                codec=codec,
                 pixel_format=pixel_format,
-                preset=preset,
                 extra_options=encoder_extra_options,
             )
             encoder_output = file_like.getvalue()

From bdd133f732588ba7fe53ac80b9879c81b039263e Mon Sep 17 00:00:00 2001
From: Dan-Flores
Date: Sat, 22 Nov 2025 21:40:31 +0000
Subject: [PATCH 08/17] disable av1_nvenc

---
 test/test_encoders.py | 28 +++++++++++++---------------
 1 file changed, 13 insertions(+), 15 deletions(-)

diff --git a/test/test_encoders.py b/test/test_encoders.py
index a612f6df7..1125fd94c 100644
--- a/test/test_encoders.py
+++ b/test/test_encoders.py
@@ -1271,19 +1271,15 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range
     @pytest.mark.needs_cuda
     @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not available")
     @pytest.mark.parametrize("pixel_format", ("nv12", "yuv420p"))
-    @pytest.mark.parametrize("format_codec", [
-        ("mov", "h264_nvenc"),
-        ("mp4", "hevc_nvenc"),
-        ("avi", "h264_nvenc"),
-        pytest.param(
-            ("mkv", "av1_nvenc"),
-            marks=pytest.mark.skipif(
-                get_ffmpeg_major_version() <= 5,
-                reason="av1_nvenc not supported in FFmpeg 4 and 5"
-            )
-        ),
-        ("flv", "h264_nvenc")
-    ])
+    @pytest.mark.parametrize(
+        "format_codec",
+        [
+            ("mov", "h264_nvenc"),
+            ("mp4", "hevc_nvenc"),
+            ("avi", "h264_nvenc"),
+            # ("mkv", "av1_nvenc"),  # av1_nvenc is not supported on CI
+        ],
+    )
     @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like"))
     def test_nvenc_against_ffmpeg_cli(
         self, tmp_path, pixel_format, format_codec, method
@@ -1291,7 +1287,7 @@ def test_nvenc_against_ffmpeg_cli(
         # Encode with FFmpeg CLI using nvenc codecs
         format, codec = format_codec
         device = "cuda"
-        qp = 1 # Lossless (qp=0) is not supported on av1_nvenc, so we use 1
+        qp = 1  # Lossless (qp=0) is not supported on av1_nvenc, so we use 1
         source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device)
 
         temp_raw_path = str(tmp_path / "temp_input.raw")
@@ -1320,7 +1316,9 @@ def test_nvenc_against_ffmpeg_cli(
         ffmpeg_cmd.extend(["-pix_fmt", pixel_format])  # Output format
         if codec == "av1_nvenc":
-            ffmpeg_cmd.extend(["-rc", "constqp"])  # Set rate control mode for AV1
+            ffmpeg_cmd.extend(
+                ["-rc", "constqp"]
+            )  # Set rate control mode for AV1
         else:
             ffmpeg_cmd.extend(["-qp", str(qp)])  # Use lossless qp for other codecs
         ffmpeg_cmd.extend([ffmpeg_encoded_path])

From 9c7bae7eff4adc4d5f111a82c10c7ab71e0289dd Mon Sep 17 00:00:00 2001
From: Dan-Flores
Date: Wed, 26 Nov 2025 05:18:12 +0000
Subject: [PATCH 09/17] reduce files affected, add GpuEncoder.cpp

---
 .../_core/BetaCudaDeviceInterface.cpp | 10 --
 .../_core/BetaCudaDeviceInterface.h   |  6 -
 src/torchcodec/_core/CMakeLists.txt
| 2 +- src/torchcodec/_core/CpuDeviceInterface.cpp | 78 ------------ src/torchcodec/_core/CpuDeviceInterface.h | 6 - src/torchcodec/_core/CudaDeviceInterface.cpp | 109 +--------------- src/torchcodec/_core/CudaDeviceInterface.h | 9 -- src/torchcodec/_core/DeviceInterface.h | 19 --- src/torchcodec/_core/Encoder.cpp | 118 +++++++++++++++--- src/torchcodec/_core/Encoder.h | 8 ++ 10 files changed, 117 insertions(+), 248 deletions(-) diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp index 86c7d5e27..45f6ba1a5 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.cpp +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.cpp @@ -833,16 +833,6 @@ void BetaCudaDeviceInterface::convertAVFrameToFrameOutput( gpuFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } -UniqueAVFrame BetaCudaDeviceInterface::convertTensorToAVFrame( - [[maybe_unused]] const torch::Tensor& tensor, - [[maybe_unused]] AVPixelFormat targetFormat, - [[maybe_unused]] int frameIndex, - [[maybe_unused]] AVCodecContext* codecContext) { - TORCH_CHECK( - false, - "Beta CUDA device interface does not support video encoding currently."); -} - std::string BetaCudaDeviceInterface::getDetails() { std::string details = "Beta CUDA Device Interface."; if (cpuFallback_) { diff --git a/src/torchcodec/_core/BetaCudaDeviceInterface.h b/src/torchcodec/_core/BetaCudaDeviceInterface.h index fba998a50..cefb1a983 100644 --- a/src/torchcodec/_core/BetaCudaDeviceInterface.h +++ b/src/torchcodec/_core/BetaCudaDeviceInterface.h @@ -48,12 +48,6 @@ class BetaCudaDeviceInterface : public DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) override; - UniqueAVFrame convertTensorToAVFrame( - const torch::Tensor& tensor, - AVPixelFormat targetFormat, - int frameIndex, - AVCodecContext* codecContext) override; - int sendPacket(ReferenceAVPacket& packet) override; int sendEOFPacket() override; int receiveFrame(UniqueAVFrame& avFrame) override; diff --git a/src/torchcodec/_core/CMakeLists.txt b/src/torchcodec/_core/CMakeLists.txt index 4fc7af75e..9b7f18b8f 100644 --- a/src/torchcodec/_core/CMakeLists.txt +++ b/src/torchcodec/_core/CMakeLists.txt @@ -100,7 +100,7 @@ function(make_torchcodec_libraries ) if(ENABLE_CUDA) - list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp) + list(APPEND core_sources CudaDeviceInterface.cpp BetaCudaDeviceInterface.cpp NVDECCache.cpp CUDACommon.cpp NVCUVIDRuntimeLoader.cpp GpuEncoder.cpp) endif() set(core_library_dependencies diff --git a/src/torchcodec/_core/CpuDeviceInterface.cpp b/src/torchcodec/_core/CpuDeviceInterface.cpp index d7e58cb45..70f46b7e4 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.cpp +++ b/src/torchcodec/_core/CpuDeviceInterface.cpp @@ -429,84 +429,6 @@ std::optional CpuDeviceInterface::maybeFlushAudioBuffers() { /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); } -UniqueAVFrame CpuDeviceInterface::convertTensorToAVFrame( - const torch::Tensor& frame, - AVPixelFormat outPixelFormat, - int frameIndex, - [[maybe_unused]] AVCodecContext* codecContext) { - int inHeight = static_cast(frame.sizes()[1]); - int inWidth = static_cast(frame.sizes()[2]); - - // For now, reuse input dimensions as output dimensions - int outWidth = inWidth; - int outHeight = inHeight; - - // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) - AVPixelFormat inPixelFormat = AV_PIX_FMT_GBRP; - - // Initialize and cache 
scaling context if it does not exist - if (!swsContext_) { - swsContext_.reset(sws_getContext( - inWidth, - inHeight, - inPixelFormat, - outWidth, - outHeight, - outPixelFormat, - SWS_BICUBIC, // Used by FFmpeg CLI - nullptr, - nullptr, - nullptr)); - TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); - } - - UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); - - // Set output frame properties - avFrame->format = outPixelFormat; - avFrame->width = outWidth; - avFrame->height = outHeight; - avFrame->pts = frameIndex; - - int status = av_frame_get_buffer(avFrame.get(), 0); - TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); - - // Need to convert/scale the frame - // Create temporary frame with input format - UniqueAVFrame inputFrame(av_frame_alloc()); - TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); - - inputFrame->format = inPixelFormat; - inputFrame->width = inWidth; - inputFrame->height = inHeight; - - uint8_t* tensorData = static_cast(frame.data_ptr()); - - // TODO-VideoEncoder: Reorder tensor if in NHWC format - int channelSize = inHeight * inWidth; - // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format - // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format - inputFrame->data[0] = tensorData + channelSize; - inputFrame->data[1] = tensorData + (2 * channelSize); - inputFrame->data[2] = tensorData; - - inputFrame->linesize[0] = inWidth; - inputFrame->linesize[1] = inWidth; - inputFrame->linesize[2] = inWidth; - - status = sws_scale( - swsContext_.get(), - inputFrame->data, - inputFrame->linesize, - 0, - inputFrame->height, - avFrame->data, - avFrame->linesize); - TORCH_CHECK(status == outHeight, "sws_scale failed"); - return avFrame; -} - std::string CpuDeviceInterface::getDetails() { return std::string("CPU Device Interface."); } diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index c33d5d051..55e34c3b6 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ -38,12 +38,6 @@ class CpuDeviceInterface : public DeviceInterface { FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) override; - UniqueAVFrame convertTensorToAVFrame( - const torch::Tensor& tensor, - AVPixelFormat targetFormat, - int frameIndex, - AVCodecContext* codecContext) override; - std::string getDetails() override; private: diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 84cefc142..26ee20556 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -146,40 +146,6 @@ void CudaDeviceInterface::registerHardwareDeviceWithCodec( codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); } -void CudaDeviceInterface::setupEncodingContext(AVCodecContext* codecContext) { - TORCH_CHECK( - hardwareDeviceCtx_, "Hardware device context has not been initialized"); - TORCH_CHECK(codecContext != nullptr, "codecContext is null"); - // is there any way to preserve actual desired format? - // codecContext->sw_pix_fmt = codecContext->pix_fmt; - // Should we always produce AV_PIX_FMT_NV12? 
- codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; - codecContext->pix_fmt = AV_PIX_FMT_CUDA; - - AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); - TORCH_CHECK( - hwFramesCtxRef != nullptr, - "Failed to allocate hardware frames context for codec"); - - AVHWFramesContext* hwFramesCtx = - reinterpret_cast(hwFramesCtxRef->data); - hwFramesCtx->format = codecContext->pix_fmt; - hwFramesCtx->sw_format = codecContext->sw_pix_fmt; - hwFramesCtx->width = codecContext->width; - hwFramesCtx->height = codecContext->height; - - int ret = av_hwframe_ctx_init(hwFramesCtxRef); - if (ret < 0) { - av_buffer_unref(&hwFramesCtxRef); - TORCH_CHECK( - false, - "Failed to initialize CUDA frames context for codec: ", - getFFMPEGErrorStringFromErrorCode(ret)); - } - - codecContext->hw_frames_ctx = hwFramesCtxRef; -} - UniqueAVFrame CudaDeviceInterface::maybeConvertAVFrameToNV12OrRGB24( UniqueAVFrame& avFrame) { // We need FFmpeg filters to handle those conversion cases which are not @@ -365,39 +331,10 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( avFrame, device_, nppCtx_, nvdecStream, preAllocatedOutputTensor); } -namespace { -// Helper function to check if a codec supports CUDA hardware acceleration -bool codecSupportsCudaHardware(const AVCodec* codec) { - const AVCodecHWConfig* config = nullptr; - for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; ++j) { - if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { - return true; - } - } - return false; -} -} // namespace - // inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device - -std::optional CudaDeviceInterface::findEncoder( - const AVCodecID& codecId) { - void* i = nullptr; - const AVCodec* codec = nullptr; - while ((codec = av_codec_iterate(&i)) != nullptr) { - if (codec->id != codecId || !av_codec_is_encoder(codec)) { - continue; - } - if (codecSupportsCudaHardware(codec)) { - return codec; - } - } - return std::nullopt; -} - std::optional CudaDeviceInterface::findDecoder( const AVCodecID& codecId) { void* i = nullptr; @@ -407,52 +344,18 @@ std::optional CudaDeviceInterface::findDecoder( continue; } - if (codecSupportsCudaHardware(codec)) { - return codec; + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; + ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { + return codec; + } } } return std::nullopt; } -UniqueAVFrame CudaDeviceInterface::convertTensorToAVFrame( - const torch::Tensor& frame, - [[maybe_unused]] AVPixelFormat targetFormat, - int frameIndex, - AVCodecContext* codecContext) { - TORCH_CHECK(frame.is_cuda(), "CUDA device interface requires CUDA tensors"); - TORCH_CHECK( - frame.dim() == 3 && frame.size(0) == 3, - "Expected 3D RGB tensor (CHW format), got shape: ", - frame.sizes()); - - UniqueAVFrame avFrame(av_frame_alloc()); - TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); - - avFrame->format = AV_PIX_FMT_CUDA; - avFrame->width = static_cast(frame.size(2)); - avFrame->height = static_cast(frame.size(1)); - avFrame->pts = frameIndex; - - int ret = av_hwframe_get_buffer( - codecContext ? 
codecContext->hw_frames_ctx : nullptr, avFrame.get(), 0); - TORCH_CHECK( - ret >= 0, - "Failed to allocate hardware frame: ", - getFFMPEGErrorStringFromErrorCode(ret)); - - at::cuda::CUDAStream currentStream = - at::cuda::getCurrentCUDAStream(device_.index()); - - convertRGBTensorToNV12Frame(frame, avFrame, device_, nppCtx_, currentStream); - - // Set color properties to FFmpeg defaults - avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 - avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range - - return avFrame; -} - std::string CudaDeviceInterface::getDetails() { // Note: for this interface specifically the fallback is only known after a // frame has been decoded, not before: that's when FFmpeg decides to fallback, diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 83761020c..3b44e524d 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -18,7 +18,6 @@ class CudaDeviceInterface : public DeviceInterface { virtual ~CudaDeviceInterface(); - std::optional findEncoder(const AVCodecID& codecId) override; std::optional findDecoder(const AVCodecID& codecId) override; void initialize( @@ -35,19 +34,11 @@ class CudaDeviceInterface : public DeviceInterface { void registerHardwareDeviceWithCodec(AVCodecContext* codecContext) override; - void setupEncodingContext(AVCodecContext* codecContext) override; - void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor) override; - UniqueAVFrame convertTensorToAVFrame( - const torch::Tensor& tensor, - AVPixelFormat targetFormat, - int frameIndex, - AVCodecContext* codecContext) override; - std::string getDetails() override; private: diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 2b69dbfc9..52e97c4cd 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -46,11 +46,6 @@ class DeviceInterface { return device_; }; - virtual std::optional findEncoder( - [[maybe_unused]] const AVCodecID& codecId) { - return std::nullopt; - }; - virtual std::optional findDecoder( [[maybe_unused]] const AVCodecID& codecId) { return std::nullopt; @@ -92,25 +87,11 @@ class DeviceInterface { virtual void registerHardwareDeviceWithCodec( [[maybe_unused]] AVCodecContext* codecContext) {} - // Setup device-specific encoding context (e.g., hardware frame contexts). - // Called after registerHardwareDeviceWithCodec for encoders. - // Default implementation does nothing (suitable for CPU and basic cases). - virtual void setupEncodingContext( - [[maybe_unused]] AVCodecContext* codecContext) {} - virtual void convertAVFrameToFrameOutput( UniqueAVFrame& avFrame, FrameOutput& frameOutput, std::optional preAllocatedOutputTensor = std::nullopt) = 0; - // Convert tensor to AVFrame, implemented per device interface. 
- // This is similar to convertAVFrameToFrameOutput for encoding - virtual UniqueAVFrame convertTensorToAVFrame( - const torch::Tensor& tensor, - AVPixelFormat targetFormat, - int frameIndex, - AVCodecContext* codecContext) = 0; - // ------------------------------------------ // Extension points for custom decoding paths // ------------------------------------------ diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 4e6fde742..bdc31c753 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -5,7 +5,6 @@ #include "torch/types.h" extern "C" { -#include #include #include } @@ -737,11 +736,9 @@ VideoEncoder::VideoEncoder( void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { - deviceInterface_ = createDeviceInterface( - videoStreamOptions.device, videoStreamOptions.deviceVariant); - TORCH_CHECK( - deviceInterface_ != nullptr, - "Failed to create device interface. This should never happen, please report."); + if (videoStreamOptions.device.is_cuda()) { + gpuEncoder_ = std::make_unique(videoStreamOptions.device); + } const AVCodec* avCodec = nullptr; // If codec arg is provided, find codec using logic similar to FFmpeg: @@ -769,10 +766,9 @@ void VideoEncoder::initializeEncoder( "Output format is null, unable to find default codec."); // Try to find a hardware-accelerated encoder if not using CPU avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); - if (videoStreamOptions.device.type() != torch::kCPU) { - avCodec = - deviceInterface_->findEncoder(avFormatContext_->oformat->video_codec) - .value_or(avCodec); + if (gpuEncoder_) { + avCodec = gpuEncoder_->findEncoder(avFormatContext_->oformat->video_codec) + .value_or(avCodec); } TORCH_CHECK(avCodec != nullptr, "Video codec not found"); } @@ -848,10 +844,10 @@ void VideoEncoder::initializeEncoder( // Register the hardware device context with the codec // context before calling avcodec_open2(). 
- deviceInterface_->registerHardwareDeviceWithCodec(avCodecContext_.get()); - - // Setup device-specific encoding context (e.g., hardware frame contexts) - deviceInterface_->setupEncodingContext(avCodecContext_.get()); + if (gpuEncoder_) { + gpuEncoder_->registerHardwareDeviceWithCodec(avCodecContext_.get()); + gpuEncoder_->setupEncodingContext(avCodecContext_.get()); + } int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions); av_dict_free(&avCodecOptions); @@ -893,8 +889,14 @@ void VideoEncoder::encode() { int numFrames = static_cast(frames_.sizes()[0]); for (int i = 0; i < numFrames; ++i) { torch::Tensor currFrame = frames_[i]; - UniqueAVFrame avFrame = deviceInterface_->convertTensorToAVFrame( - currFrame, outPixelFormat_, i, avCodecContext_.get()); + UniqueAVFrame avFrame; + if (gpuEncoder_) { + avFrame = gpuEncoder_->convertTensorToAVFrame( + currFrame, outPixelFormat_, i, avCodecContext_.get()); + } else { + // Use direct CPU conversion for CPU devices + avFrame = convertCpuTensorToAVFrame(currFrame, outPixelFormat_, i); + } encodeFrame(autoAVPacket, avFrame); } @@ -907,6 +909,90 @@ void VideoEncoder::encode() { getFFMPEGErrorStringFromErrorCode(status)); } +UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex) { + TORCH_CHECK(tensor.is_cpu(), "CPU encoder requires CPU tensors"); + TORCH_CHECK( + tensor.dim() == 3 && tensor.size(0) == 3, + "Expected 3D RGB tensor (CHW format), got shape: ", + tensor.sizes()); + + int inHeight = static_cast(tensor.sizes()[1]); + int inWidth = static_cast(tensor.sizes()[2]); + + // For now, reuse input dimensions as output dimensions + int outWidth = inWidth; + int outHeight = inHeight; + + // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) + AVPixelFormat inPixelFormat = AV_PIX_FMT_GBRP; + + // Initialize and cache scaling context if it does not exist + if (!swsContext_) { + swsContext_.reset(sws_getContext( + inWidth, + inHeight, + inPixelFormat, + outWidth, + outHeight, + targetFormat, + SWS_BICUBIC, // Used by FFmpeg CLI + nullptr, + nullptr, + nullptr)); + TORCH_CHECK(swsContext_ != nullptr, "Failed to create scaling context"); + } + + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + + // Set output frame properties + avFrame->format = targetFormat; + avFrame->width = outWidth; + avFrame->height = outHeight; + avFrame->pts = frameIndex; + + int status = av_frame_get_buffer(avFrame.get(), 0); + TORCH_CHECK(status >= 0, "Failed to allocate frame buffer"); + + // Need to convert/scale the frame + // Create temporary frame with input format + UniqueAVFrame inputFrame(av_frame_alloc()); + TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); + + inputFrame->format = inPixelFormat; + inputFrame->width = inWidth; + inputFrame->height = inHeight; + + uint8_t* tensorData = static_cast(tensor.data_ptr()); + + // TODO-VideoEncoder: Reorder tensor if in NHWC format + int channelSize = inHeight * inWidth; + // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format + // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format + inputFrame->data[0] = tensorData + channelSize; // G channel + inputFrame->data[1] = tensorData + (2 * channelSize); // B channel + inputFrame->data[2] = tensorData; // R channel + + inputFrame->linesize[0] = inWidth; + inputFrame->linesize[1] = inWidth; + inputFrame->linesize[2] = inWidth; + + status = sws_scale( + swsContext_.get(), 
+ inputFrame->data, + inputFrame->linesize, + 0, + inputFrame->height, + avFrame->data, + avFrame->linesize); + TORCH_CHECK(status == outHeight, "sws_scale failed"); + + return avFrame; +} + torch::Tensor VideoEncoder::encodeToTensor() { TORCH_CHECK( avioContextHolder_ != nullptr, diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index c32e44943..63c19d746 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -5,6 +5,7 @@ #include "AVIOContextHolder.h" #include "DeviceInterface.h" #include "FFMPEGCommon.h" +#include "GpuEncoder.h" #include "StreamOptions.h" extern "C" { @@ -164,6 +165,12 @@ class VideoEncoder { void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); void flushBuffers(); + // CPU tensor-to-frame conversion for CPU encoding + UniqueAVFrame convertCpuTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex); + UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; AVStream* avStream_ = nullptr; @@ -182,6 +189,7 @@ class VideoEncoder { std::unique_ptr avioContextHolder_; std::unique_ptr deviceInterface_; + std::unique_ptr gpuEncoder_; bool encodeWasCalled_ = false; AVDictionary* avFormatOptions_ = nullptr; From bf784688225bbdc635d011d313aea766895cd633 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Wed, 26 Nov 2025 06:05:08 +0000 Subject: [PATCH 10/17] actually add GpuEncoder.cpp --- src/torchcodec/_core/Encoder.cpp | 47 ++++--- src/torchcodec/_core/Encoder.h | 1 - src/torchcodec/_core/GpuEncoder.cpp | 194 ++++++++++++++++++++++++++++ src/torchcodec/_core/GpuEncoder.h | 58 +++++++++ 4 files changed, 275 insertions(+), 25 deletions(-) create mode 100644 src/torchcodec/_core/GpuEncoder.cpp create mode 100644 src/torchcodec/_core/GpuEncoder.h diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index bdc31c753..2f948ee42 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -895,7 +895,7 @@ void VideoEncoder::encode() { currFrame, outPixelFormat_, i, avCodecContext_.get()); } else { // Use direct CPU conversion for CPU devices - avFrame = convertCpuTensorToAVFrame(currFrame, outPixelFormat_, i); + avFrame = convertCpuTensorToAVFrame(currFrame, i); } encodeFrame(autoAVPacket, avFrame); } @@ -911,7 +911,6 @@ void VideoEncoder::encode() { UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( const torch::Tensor& tensor, - AVPixelFormat targetFormat, int frameIndex) { TORCH_CHECK(tensor.is_cpu(), "CPU encoder requires CPU tensors"); TORCH_CHECK( @@ -919,25 +918,25 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( "Expected 3D RGB tensor (CHW format), got shape: ", tensor.sizes()); - int inHeight = static_cast(tensor.sizes()[1]); - int inWidth = static_cast(tensor.sizes()[2]); + inHeight_ = static_cast(tensor.sizes()[1]); + inWidth_ = static_cast(tensor.sizes()[2]); // For now, reuse input dimensions as output dimensions - int outWidth = inWidth; - int outHeight = inHeight; + outWidth_ = inWidth_; + outHeight_ = inHeight_; // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) - AVPixelFormat inPixelFormat = AV_PIX_FMT_GBRP; + inPixelFormat_ = AV_PIX_FMT_GBRP; // Initialize and cache scaling context if it does not exist if (!swsContext_) { swsContext_.reset(sws_getContext( - inWidth, - inHeight, - inPixelFormat, - outWidth, - outHeight, - targetFormat, + inWidth_, + inHeight_, + inPixelFormat_, + outWidth_, + outHeight_, + outPixelFormat_, SWS_BICUBIC, 
// Used by FFmpeg CLI nullptr, nullptr, @@ -949,9 +948,9 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); // Set output frame properties - avFrame->format = targetFormat; - avFrame->width = outWidth; - avFrame->height = outHeight; + avFrame->format = outPixelFormat_; + avFrame->width = outWidth_; + avFrame->height = outHeight_; avFrame->pts = frameIndex; int status = av_frame_get_buffer(avFrame.get(), 0); @@ -962,23 +961,23 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( UniqueAVFrame inputFrame(av_frame_alloc()); TORCH_CHECK(inputFrame != nullptr, "Failed to allocate input AVFrame"); - inputFrame->format = inPixelFormat; - inputFrame->width = inWidth; - inputFrame->height = inHeight; + inputFrame->format = inPixelFormat_; + inputFrame->width = inWidth_; + inputFrame->height = inHeight_; uint8_t* tensorData = static_cast(tensor.data_ptr()); // TODO-VideoEncoder: Reorder tensor if in NHWC format - int channelSize = inHeight * inWidth; + int channelSize = inHeight_ * inWidth_; // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format inputFrame->data[0] = tensorData + channelSize; // G channel inputFrame->data[1] = tensorData + (2 * channelSize); // B channel inputFrame->data[2] = tensorData; // R channel - inputFrame->linesize[0] = inWidth; - inputFrame->linesize[1] = inWidth; - inputFrame->linesize[2] = inWidth; + inputFrame->linesize[0] = inWidth_; + inputFrame->linesize[1] = inWidth_; + inputFrame->linesize[2] = inWidth_; status = sws_scale( swsContext_.get(), @@ -988,7 +987,7 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( inputFrame->height, avFrame->data, avFrame->linesize); - TORCH_CHECK(status == outHeight, "sws_scale failed"); + TORCH_CHECK(status == outHeight_, "sws_scale failed"); return avFrame; } diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 63c19d746..cfb4dc1a6 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -168,7 +168,6 @@ class VideoEncoder { // CPU tensor-to-frame conversion for CPU encoding UniqueAVFrame convertCpuTensorToAVFrame( const torch::Tensor& tensor, - AVPixelFormat targetFormat, int frameIndex); UniqueEncodingAVFormatContext avFormatContext_; diff --git a/src/torchcodec/_core/GpuEncoder.cpp b/src/torchcodec/_core/GpuEncoder.cpp new file mode 100644 index 000000000..5cb1b2741 --- /dev/null +++ b/src/torchcodec/_core/GpuEncoder.cpp @@ -0,0 +1,194 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#include "GpuEncoder.h" + +#include +#include +#include +#include + +#include "CUDACommon.h" +#include "FFMPEGCommon.h" + +extern "C" { +#include +#include +} + +namespace facebook::torchcodec { +namespace { + +// Redefinition from CudaDeviceInterface.cpp anonymous namespace +int getFlagsAVHardwareDeviceContextCreate() { +#if LIBAVUTIL_VERSION_INT >= AV_VERSION_INT(58, 26, 100) + return AV_CUDA_USE_CURRENT_CONTEXT; +#else + return 0; +#endif +} + +// Redefinition from CudaDeviceInterface.cpp anonymous namespace +// TODO-VideoEncoder: unify device context creation, add caching to encoder +UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) { + enum AVHWDeviceType type = av_hwdevice_find_type_by_name("cuda"); + TORCH_CHECK(type != AV_HWDEVICE_TYPE_NONE, "Failed to find cuda device"); + + int deviceIndex = getDeviceIndex(device); + + c10::cuda::CUDAGuard deviceGuard(device); + // We set the device because we may be called from a different thread than + // the one that initialized the cuda context. + TORCH_CHECK( + cudaSetDevice(deviceIndex) == cudaSuccess, "Failed to set CUDA device"); + + AVBufferRef* hardwareDeviceCtxRaw = nullptr; + std::string deviceOrdinal = std::to_string(deviceIndex); + + int err = av_hwdevice_ctx_create( + &hardwareDeviceCtxRaw, + type, + deviceOrdinal.c_str(), + nullptr, + getFlagsAVHardwareDeviceContextCreate()); + + if (err < 0) { + /* clang-format off */ + TORCH_CHECK( + false, + "Failed to create specified HW device. This typically happens when ", + "your installed FFmpeg doesn't support CUDA (see ", + "https://github.com/pytorch/torchcodec#installing-cuda-enabled-torchcodec", + "). FFmpeg error: ", getFFMPEGErrorStringFromErrorCode(err)); + /* clang-format on */ + } + + return UniqueAVBufferRef(hardwareDeviceCtxRaw); +} + +} // anonymous namespace + +GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) { + TORCH_CHECK( + device_.type() == torch::kCUDA, "Unsupported device: ", device_.str()); + + initializeCudaContextWithPytorch(device_); + initializeHardwareContext(); +} + +GpuEncoder::~GpuEncoder() {} + +void GpuEncoder::initializeHardwareContext() { + hardwareDeviceCtx_ = createHardwareDeviceContext(device_); + nppCtx_ = getNppStreamContext(device_); +} + +std::optional GpuEncoder::findEncoder( + const AVCodecID& codecId) { + void* i = nullptr; + const AVCodec* codec = nullptr; + while ((codec = av_codec_iterate(&i)) != nullptr) { + if (codec->id != codecId || !av_codec_is_encoder(codec)) { + continue; + } + + const AVCodecHWConfig* config = nullptr; + for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; + ++j) { + if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { + return codec; + } + } + } + return std::nullopt; +} + +void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) { + TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); +} + +void GpuEncoder::setupEncodingContext(AVCodecContext* codecContext) { + TORCH_CHECK( + hardwareDeviceCtx_, "Hardware device context has not been initialized"); + TORCH_CHECK(codecContext != nullptr, "codecContext is null"); + + codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; + codecContext->pix_fmt = AV_PIX_FMT_CUDA; + + AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); + TORCH_CHECK( + hwFramesCtxRef != nullptr, + "Failed to allocate hardware frames 
context for codec"); + + AVHWFramesContext* hwFramesCtx = + reinterpret_cast(hwFramesCtxRef->data); + hwFramesCtx->format = codecContext->pix_fmt; + hwFramesCtx->sw_format = codecContext->sw_pix_fmt; + hwFramesCtx->width = codecContext->width; + hwFramesCtx->height = codecContext->height; + + int ret = av_hwframe_ctx_init(hwFramesCtxRef); + if (ret < 0) { + av_buffer_unref(&hwFramesCtxRef); + TORCH_CHECK( + false, + "Failed to initialize CUDA frames context for codec: ", + getFFMPEGErrorStringFromErrorCode(ret)); + } + + codecContext->hw_frames_ctx = hwFramesCtxRef; +} + +UniqueAVFrame GpuEncoder::convertTensorToAVFrame( + const torch::Tensor& tensor, + [[maybe_unused]] AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext) { + TORCH_CHECK(tensor.is_cuda(), "GpuEncoder requires CUDA tensors"); + TORCH_CHECK( + tensor.dim() == 3 && tensor.size(0) == 3, + "Expected 3D RGB tensor (CHW format), got shape: ", + tensor.sizes()); + + return convertRGBTensorToNV12Frame(tensor, frameIndex, codecContext); +} + +UniqueAVFrame GpuEncoder::convertRGBTensorToNV12Frame( + const torch::Tensor& tensor, + int frameIndex, + AVCodecContext* codecContext) { + UniqueAVFrame avFrame(av_frame_alloc()); + TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + + avFrame->format = AV_PIX_FMT_CUDA; + avFrame->width = static_cast(tensor.size(2)); + avFrame->height = static_cast(tensor.size(1)); + avFrame->pts = frameIndex; + + int ret = av_hwframe_get_buffer( + codecContext ? codecContext->hw_frames_ctx : nullptr, avFrame.get(), 0); + TORCH_CHECK( + ret >= 0, + "Failed to allocate hardware frame: ", + getFFMPEGErrorStringFromErrorCode(ret)); + + at::cuda::CUDAStream currentStream = + at::cuda::getCurrentCUDAStream(device_.index()); + + facebook::torchcodec::convertRGBTensorToNV12Frame( + tensor, avFrame, device_, nppCtx_, currentStream); + + // Set color properties to FFmpeg defaults + avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 + avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range + + return avFrame; +} + +} // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/GpuEncoder.h b/src/torchcodec/_core/GpuEncoder.h new file mode 100644 index 000000000..7c0940e98 --- /dev/null +++ b/src/torchcodec/_core/GpuEncoder.h @@ -0,0 +1,58 @@ +// Copyright (c) Meta Platforms, Inc. and affiliates. +// All rights reserved. +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. 
+ +#pragma once + +#include +#include +#include + +#include "CUDACommon.h" +#include "FFMPEGCommon.h" +#include "StreamOptions.h" + +extern "C" { +#include +#include +#include +} + +namespace facebook::torchcodec { + +class GpuEncoder { + public: + explicit GpuEncoder(const torch::Device& device); + ~GpuEncoder(); + + std::optional findEncoder(const AVCodecID& codecId); + void registerHardwareDeviceWithCodec(AVCodecContext* codecContext); + void setupEncodingContext(AVCodecContext* codecContext); + + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& tensor, + AVPixelFormat targetFormat, + int frameIndex, + AVCodecContext* codecContext); + + const torch::Device& device() const { + return device_; + } + + private: + torch::Device device_; + UniqueAVBufferRef hardwareDeviceCtx_; + UniqueNppContext nppCtx_; + + void initializeHardwareContext(); + void setupHardwareFrameContext(AVCodecContext* codecContext); + + UniqueAVFrame convertRGBTensorToNV12Frame( + const torch::Tensor& tensor, + int frameIndex, + AVCodecContext* codecContext); +}; + +} // namespace facebook::torchcodec From 7e5e6d4432cbd5e22f6d3503835d46db404a4ca3 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Wed, 26 Nov 2025 23:44:12 +0000 Subject: [PATCH 11/17] move more encoding to gpuEncoder.cpp, reduce diff --- src/torchcodec/_core/CUDACommon.cpp | 77 -------------------- src/torchcodec/_core/CUDACommon.h | 7 -- src/torchcodec/_core/CpuDeviceInterface.h | 5 ++ src/torchcodec/_core/CudaDeviceInterface.cpp | 2 +- src/torchcodec/_core/CudaDeviceInterface.h | 2 +- src/torchcodec/_core/DeviceInterface.h | 2 +- src/torchcodec/_core/Encoder.cpp | 37 +++++----- src/torchcodec/_core/Encoder.h | 9 +-- src/torchcodec/_core/FFMPEGCommon.cpp | 2 +- src/torchcodec/_core/GpuEncoder.cpp | 70 +++++++++++++++--- src/torchcodec/_core/GpuEncoder.h | 6 -- src/torchcodec/_core/SingleStreamDecoder.cpp | 2 +- src/torchcodec/encoders/_video_encoder.py | 1 - 13 files changed, 90 insertions(+), 132 deletions(-) diff --git a/src/torchcodec/_core/CUDACommon.cpp b/src/torchcodec/_core/CUDACommon.cpp index d31acc968..bbd17db39 100644 --- a/src/torchcodec/_core/CUDACommon.cpp +++ b/src/torchcodec/_core/CUDACommon.cpp @@ -156,21 +156,6 @@ const Npp32f bt709FullRangeColorTwist[3][4] = { {1.0f, -0.187324273f, -0.468124273f, -128.0f}, {1.0f, 1.8556f, 0.0f, -128.0f}}; -// RGB to NV12 color conversion matrices (inverse of YUV to RGB) -// Note: NPP's ColorTwist function apparently expects "limited range" -// coefficient format even when producing full range output. All matrices below -// use the limited range coefficient format (Y with +16 offset) for NPP -// compatibility. 
- -// BT.601 limited range (matches FFmpeg default behavior) -const Npp32f defaultLimitedRangeRgbToNv12[3][4] = { - // Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B) - {0.257f, 0.504f, 0.098f, 16.0f}, - // U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients) - {-0.148f, -0.291f, 0.439f, 128.0f}, - // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients) - {0.439f, -0.368f, -0.071f, 128.0f}}; - torch::Tensor convertNV12FrameToRGB( UniqueAVFrame& avFrame, const torch::Device& device, @@ -261,68 +246,6 @@ torch::Tensor convertNV12FrameToRGB( return dst; } -void convertRGBTensorToNV12Frame( - const torch::Tensor& rgbTensor, - UniqueAVFrame& nv12Frame, - const torch::Device& device, - const UniqueNppContext& nppCtx, - at::cuda::CUDAStream inputStream) { - TORCH_CHECK(rgbTensor.is_cuda(), "RGB tensor must be on CUDA device"); - TORCH_CHECK( - rgbTensor.dim() == 3 && rgbTensor.size(0) == 3, - "Expected 3D RGB tensor in CHW format, got shape: ", - rgbTensor.sizes()); - TORCH_CHECK( - nv12Frame != nullptr && nv12Frame->data[0] != nullptr, - "nv12Frame must be pre-allocated with CUDA memory"); - - // Convert CHW to HWC for NPP processing - int height = static_cast(rgbTensor.size(1)); - int width = static_cast(rgbTensor.size(2)); - torch::Tensor hwcFrame = rgbTensor.permute({1, 2, 0}).contiguous(); - - // Set up stream synchronization - make NPP stream wait for input tensor - // operations - at::cuda::CUDAStream nppStream = - at::cuda::getCurrentCUDAStream(device.index()); - at::cuda::CUDAEvent inputDoneEvent; - inputDoneEvent.record(inputStream); - inputDoneEvent.block(nppStream); - - // Setup NPP context - nppCtx->hStream = nppStream.stream(); - cudaError_t cudaErr = - cudaStreamGetFlags(nppCtx->hStream, &nppCtx->nStreamFlags); - TORCH_CHECK( - cudaErr == cudaSuccess, - "cudaStreamGetFlags failed: ", - cudaGetErrorString(cudaErr)); - - // Always use FFmpeg's default behavior: BT.601 limited range - NppiSize oSizeROI = {width, height}; - - NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( - static_cast(hwcFrame.data_ptr()), - hwcFrame.stride(0) * hwcFrame.element_size(), - nv12Frame->data, - nv12Frame->linesize, - oSizeROI, - defaultLimitedRangeRgbToNv12, - *nppCtx); - - TORCH_CHECK( - status == NPP_SUCCESS, - "Failed to convert RGB to NV12: NPP error code ", - status); - - // Validate CUDA operations completed successfully - cudaError_t memCheck = cudaGetLastError(); - TORCH_CHECK( - memCheck == cudaSuccess, - "CUDA error detected: ", - cudaGetErrorString(memCheck)); -} - UniqueNppContext getNppStreamContext(const torch::Device& device) { int deviceIndex = getDeviceIndex(device); diff --git a/src/torchcodec/_core/CUDACommon.h b/src/torchcodec/_core/CUDACommon.h index 15502540f..4cc27c23b 100644 --- a/src/torchcodec/_core/CUDACommon.h +++ b/src/torchcodec/_core/CUDACommon.h @@ -37,13 +37,6 @@ torch::Tensor convertNV12FrameToRGB( at::cuda::CUDAStream nvdecStream, std::optional preAllocatedOutputTensor = std::nullopt); -void convertRGBTensorToNV12Frame( - const torch::Tensor& rgbTensor, - UniqueAVFrame& nv12Frame, - const torch::Device& device, - const UniqueNppContext& nppCtx, - at::cuda::CUDAStream inputStream); - UniqueNppContext getNppStreamContext(const torch::Device& device); void returnNppStreamContextToCache( const torch::Device& device, diff --git a/src/torchcodec/_core/CpuDeviceInterface.h b/src/torchcodec/_core/CpuDeviceInterface.h index 55e34c3b6..801b83826 100644 --- a/src/torchcodec/_core/CpuDeviceInterface.h +++ b/src/torchcodec/_core/CpuDeviceInterface.h @@ 
-18,6 +18,11 @@ class CpuDeviceInterface : public DeviceInterface { virtual ~CpuDeviceInterface() {} + std::optional findCodec( + [[maybe_unused]] const AVCodecID& codecId) override { + return std::nullopt; + } + virtual void initialize( const AVStream* avStream, const UniqueDecodingAVFormatContext& avFormatCtx, diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index 26ee20556..ecb360ba7 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -335,7 +335,7 @@ void CudaDeviceInterface::convertAVFrameToFrameOutput( // we have to do this because of an FFmpeg bug where hardware decoding is not // appropriately set, so we just go off and find the matching codec for the CUDA // device -std::optional CudaDeviceInterface::findDecoder( +std::optional CudaDeviceInterface::findCodec( const AVCodecID& codecId) { void* i = nullptr; const AVCodec* codec = nullptr; diff --git a/src/torchcodec/_core/CudaDeviceInterface.h b/src/torchcodec/_core/CudaDeviceInterface.h index 3b44e524d..c892bd49b 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.h +++ b/src/torchcodec/_core/CudaDeviceInterface.h @@ -18,7 +18,7 @@ class CudaDeviceInterface : public DeviceInterface { virtual ~CudaDeviceInterface(); - std::optional findDecoder(const AVCodecID& codecId) override; + std::optional findCodec(const AVCodecID& codecId) override; void initialize( const AVStream* avStream, diff --git a/src/torchcodec/_core/DeviceInterface.h b/src/torchcodec/_core/DeviceInterface.h index 52e97c4cd..319fe01a8 100644 --- a/src/torchcodec/_core/DeviceInterface.h +++ b/src/torchcodec/_core/DeviceInterface.h @@ -46,7 +46,7 @@ class DeviceInterface { return device_; }; - virtual std::optional findDecoder( + virtual std::optional findCodec( [[maybe_unused]] const AVCodecID& codecId) { return std::nullopt; }; diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 2f948ee42..89511436d 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -894,8 +894,7 @@ void VideoEncoder::encode() { avFrame = gpuEncoder_->convertTensorToAVFrame( currFrame, outPixelFormat_, i, avCodecContext_.get()); } else { - // Use direct CPU conversion for CPU devices - avFrame = convertCpuTensorToAVFrame(currFrame, i); + avFrame = convertTensorToAVFrame(currFrame, i); } encodeFrame(autoAVPacket, avFrame); } @@ -909,24 +908,25 @@ void VideoEncoder::encode() { getFFMPEGErrorStringFromErrorCode(status)); } -UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( - const torch::Tensor& tensor, +UniqueAVFrame VideoEncoder::convertTensorToAVFrame( + const torch::Tensor& frame, int frameIndex) { - TORCH_CHECK(tensor.is_cpu(), "CPU encoder requires CPU tensors"); + TORCH_CHECK(frame.is_cpu(), "CPU encoder requires CPU tensors"); TORCH_CHECK( - tensor.dim() == 3 && tensor.size(0) == 3, + frame.dim() == 3 && frame.size(0) == 3, "Expected 3D RGB tensor (CHW format), got shape: ", - tensor.sizes()); + frame.sizes()); - inHeight_ = static_cast(tensor.sizes()[1]); - inWidth_ = static_cast(tensor.sizes()[2]); + // These are all already set in initializeEncoder? 
+ // inHeight_ = static_cast(tensor.sizes()[1]); + // inWidth_ = static_cast(tensor.sizes()[2]); - // For now, reuse input dimensions as output dimensions - outWidth_ = inWidth_; - outHeight_ = inHeight_; + // // For now, reuse input dimensions as output dimensions + // outWidth_ = inWidth_; + // outHeight_ = inHeight_; - // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) - inPixelFormat_ = AV_PIX_FMT_GBRP; + // // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) + // inPixelFormat_ = AV_PIX_FMT_GBRP; // Initialize and cache scaling context if it does not exist if (!swsContext_) { @@ -965,15 +965,15 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( inputFrame->width = inWidth_; inputFrame->height = inHeight_; - uint8_t* tensorData = static_cast(tensor.data_ptr()); + uint8_t* tensorData = static_cast(frame.data_ptr()); // TODO-VideoEncoder: Reorder tensor if in NHWC format int channelSize = inHeight_ * inWidth_; // Reorder RGB -> GBR for AV_PIX_FMT_GBRP format // TODO-VideoEncoder: Determine if FFmpeg supports planar RGB input format - inputFrame->data[0] = tensorData + channelSize; // G channel - inputFrame->data[1] = tensorData + (2 * channelSize); // B channel - inputFrame->data[2] = tensorData; // R channel + inputFrame->data[0] = tensorData + channelSize; + inputFrame->data[1] = tensorData + (2 * channelSize); + inputFrame->data[2] = tensorData; inputFrame->linesize[0] = inWidth_; inputFrame->linesize[1] = inWidth_; @@ -988,7 +988,6 @@ UniqueAVFrame VideoEncoder::convertCpuTensorToAVFrame( avFrame->data, avFrame->linesize); TORCH_CHECK(status == outHeight_, "sws_scale failed"); - return avFrame; } diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index cfb4dc1a6..fe3284737 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -162,14 +162,12 @@ class VideoEncoder { private: void initializeEncoder(const VideoStreamOptions& videoStreamOptions); + UniqueAVFrame convertTensorToAVFrame( + const torch::Tensor& frame, + int frameIndex); void encodeFrame(AutoAVPacket& autoAVPacket, const UniqueAVFrame& avFrame); void flushBuffers(); - // CPU tensor-to-frame conversion for CPU encoding - UniqueAVFrame convertCpuTensorToAVFrame( - const torch::Tensor& tensor, - int frameIndex); - UniqueEncodingAVFormatContext avFormatContext_; UniqueAVCodecContext avCodecContext_; AVStream* avStream_ = nullptr; @@ -187,7 +185,6 @@ class VideoEncoder { AVPixelFormat outPixelFormat_ = AV_PIX_FMT_NONE; std::unique_ptr avioContextHolder_; - std::unique_ptr deviceInterface_; std::unique_ptr gpuEncoder_; bool encodeWasCalled_ = false; diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 14ddd3e6e..e1b88b36a 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -40,7 +40,7 @@ AVPacket* ReferenceAVPacket::operator->() { AVCodecOnlyUseForCallingAVFindBestStream makeAVCodecOnlyUseForCallingAVFindBestStream(const AVCodec* codec) { -#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) // FFmpeg < 5.0.3 +#if LIBAVCODEC_VERSION_INT < AV_VERSION_INT(59, 18, 100) return const_cast(codec); #else return codec; diff --git a/src/torchcodec/_core/GpuEncoder.cpp b/src/torchcodec/_core/GpuEncoder.cpp index 5cb1b2741..54efa4b4d 100644 --- a/src/torchcodec/_core/GpuEncoder.cpp +++ b/src/torchcodec/_core/GpuEncoder.cpp @@ -69,6 +69,20 @@ UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) { return 
UniqueAVBufferRef(hardwareDeviceCtxRaw); } +// RGB to NV12 color conversion matrices (inverse of YUV to RGB) +// Note: NPP's ColorTwist function apparently expects "limited range" +// coefficient format even when producing full range output. All matrices below +// use the limited range coefficient format (Y with +16 offset) for NPP +// compatibility. + +// BT.601 limited range (matches FFmpeg default behavior) +const Npp32f defaultLimitedRangeRgbToNv12[3][4] = { + // Y = 16 + 0.859 * (0.299*R + 0.587*G + 0.114*B) + {0.257f, 0.504f, 0.098f, 16.0f}, + // U = -0.148*R - 0.291*G + 0.439*B + 128 (BT.601 coefficients) + {-0.148f, -0.291f, 0.439f, 128.0f}, + // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients) + {0.439f, -0.368f, -0.071f, 128.0f}}; } // anonymous namespace GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) { @@ -155,14 +169,6 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame( tensor.dim() == 3 && tensor.size(0) == 3, "Expected 3D RGB tensor (CHW format), got shape: ", tensor.sizes()); - - return convertRGBTensorToNV12Frame(tensor, frameIndex, codecContext); -} - -UniqueAVFrame GpuEncoder::convertRGBTensorToNV12Frame( - const torch::Tensor& tensor, - int frameIndex, - AVCodecContext* codecContext) { UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); @@ -178,13 +184,55 @@ UniqueAVFrame GpuEncoder::convertRGBTensorToNV12Frame( "Failed to allocate hardware frame: ", getFFMPEGErrorStringFromErrorCode(ret)); + // Validate that avFrame was properly allocated with CUDA memory + TORCH_CHECK( + avFrame != nullptr && avFrame->data[0] != nullptr, + "avFrame must be pre-allocated with CUDA memory"); + + // Convert CHW to HWC for NPP processing + int height = static_cast(tensor.size(1)); + int width = static_cast(tensor.size(2)); + torch::Tensor hwcFrame = tensor.permute({1, 2, 0}).contiguous(); + + // Get current CUDA stream for NPP operations at::cuda::CUDAStream currentStream = at::cuda::getCurrentCUDAStream(device_.index()); - facebook::torchcodec::convertRGBTensorToNV12Frame( - tensor, avFrame, device_, nppCtx_, currentStream); + // Setup NPP context with current stream + nppCtx_->hStream = currentStream.stream(); + cudaError_t cudaErr = + cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags); + TORCH_CHECK( + cudaErr == cudaSuccess, + "cudaStreamGetFlags failed: ", + cudaGetErrorString(cudaErr)); + + // Always use FFmpeg's default behavior: BT.601 limited range + NppiSize oSizeROI = {width, height}; + + NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( + static_cast(hwcFrame.data_ptr()), + hwcFrame.stride(0) * hwcFrame.element_size(), + avFrame->data, + avFrame->linesize, + oSizeROI, + defaultLimitedRangeRgbToNv12, + *nppCtx_); + + TORCH_CHECK( + status == NPP_SUCCESS, + "Failed to convert RGB to NV12: NPP error code ", + status); + + // Validate CUDA operations completed successfully + cudaError_t memCheck = cudaGetLastError(); + TORCH_CHECK( + memCheck == cudaSuccess, + "CUDA error detected: ", + cudaGetErrorString(memCheck)); - // Set color properties to FFmpeg defaults + // TODO-VideoEncoder: Enable configuration of color properties, similar to + // FFmpeg Set color properties to FFmpeg defaults avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range diff --git a/src/torchcodec/_core/GpuEncoder.h b/src/torchcodec/_core/GpuEncoder.h index 7c0940e98..f8066f4df 100644 --- a/src/torchcodec/_core/GpuEncoder.h +++ 
b/src/torchcodec/_core/GpuEncoder.h @@ -47,12 +47,6 @@ class GpuEncoder { UniqueNppContext nppCtx_; void initializeHardwareContext(); - void setupHardwareFrameContext(AVCodecContext* codecContext); - - UniqueAVFrame convertRGBTensorToNV12Frame( - const torch::Tensor& tensor, - int frameIndex, - AVCodecContext* codecContext); }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 5177c49f0..6968a4b3f 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -462,7 +462,7 @@ void SingleStreamDecoder::addStream( // addStream() which is supposed to be generic if (mediaType == AVMEDIA_TYPE_VIDEO) { avCodec = makeAVCodecOnlyUseForCallingAVFindBestStream( - deviceInterface_->findDecoder(streamInfo.stream->codecpar->codec_id) + deviceInterface_->findCodec(streamInfo.stream->codecpar->codec_id) .value_or(avCodec)); } diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 754b3bcdc..15791830b 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -18,7 +18,6 @@ class VideoEncoder: frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. device (str or torch.device, optional): The device to use for encoding. Default: "cpu". If you pass a CUDA device, frames will be encoded on GPU. - Note: The "beta" CUDA backend is not supported for encoding. """ def __init__( From ffbdf4e2e62c4b537a7c0c3d1e2a77e3b226c6f3 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Thu, 27 Nov 2025 03:34:34 +0000 Subject: [PATCH 12/17] remove device arg instead use frames device --- src/torchcodec/_core/CudaDeviceInterface.cpp | 2 -- src/torchcodec/_core/custom_ops.cpp | 15 ++++++--------- src/torchcodec/_core/ops.py | 8 +------- src/torchcodec/encoders/_video_encoder.py | 14 ++------------ test/test_encoders.py | 12 ++++-------- 5 files changed, 13 insertions(+), 38 deletions(-) diff --git a/src/torchcodec/_core/CudaDeviceInterface.cpp b/src/torchcodec/_core/CudaDeviceInterface.cpp index ecb360ba7..0e20c5e8d 100644 --- a/src/torchcodec/_core/CudaDeviceInterface.cpp +++ b/src/torchcodec/_core/CudaDeviceInterface.cpp @@ -1,10 +1,8 @@ #include #include -#include #include #include -#include "CUDACommon.h" #include "Cache.h" #include "CudaDeviceInterface.h" #include "FFMPEGCommon.h" diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 750b52ae8..961d78291 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, float frame_rate, str filename, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + "encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, float frame_rate, str format, str device=\"cpu\", str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? 
extra_options=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str device=\"cpu\",str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -640,14 +640,13 @@ void encode_video_to_file( const at::Tensor& frames, double frame_rate, std::string_view file_name, - std::string_view device = "cpu", std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, std::optional preset = std::nullopt, std::optional> extra_options = std::nullopt) { VideoStreamOptions videoStreamOptions; - videoStreamOptions.device = torch::Device(std::string(device)); + videoStreamOptions.device = frames.device(); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; @@ -665,7 +664,6 @@ at::Tensor encode_video_to_tensor( const at::Tensor& frames, double frame_rate, std::string_view format, - std::string_view device = "cpu", std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, @@ -673,7 +671,7 @@ at::Tensor encode_video_to_tensor( std::optional> extra_options = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; - videoStreamOptions.device = torch::Device(std::string(device)); + videoStreamOptions.device = frames.device(); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; @@ -698,7 +696,6 @@ void _encode_video_to_file_like( double frame_rate, std::string_view format, int64_t file_like_context, - std::string_view device = "cpu", std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, @@ -711,7 +708,7 @@ void _encode_video_to_file_like( std::unique_ptr avioContextHolder(fileLikeContext); VideoStreamOptions videoStreamOptions; - videoStreamOptions.device = torch::Device(std::string(device)); + videoStreamOptions.device = frames.device(); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 5d85bd5d6..8dd50c99d 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -213,7 +213,6 @@ def encode_video_to_file_like( frame_rate: float, format: str, file_like: Union[io.RawIOBase, io.BufferedIOBase], - device: Optional[str] = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, @@ -223,11 +222,10 @@ def encode_video_to_file_like( """Encode video frames to a file-like object. Args: - frames: Video frames tensor + frames: Video frames tensor. The device of the frames tensor will be used for encoding. 
frame_rate: Frame rate in frames per second format: Video format (e.g., "mp4", "mov", "mkv") file_like: File-like object that supports write() and seek() methods - device: Device to use for encoding (default: "cpu") codec: Optional codec name (e.g., "libx264", "h264") pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p") crf: Optional constant rate factor for encoding quality @@ -241,7 +239,6 @@ def encode_video_to_file_like( frame_rate, format, _pybind_ops.create_file_like_context(file_like, True), # True means for writing - device, codec, pixel_format, crf, @@ -334,7 +331,6 @@ def encode_video_to_file_abstract( frames: torch.Tensor, frame_rate: float, filename: str, - device: str = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, preset: Optional[str] = None, @@ -349,7 +345,6 @@ def encode_video_to_tensor_abstract( frames: torch.Tensor, frame_rate: float, format: str, - device: str = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, preset: Optional[str] = None, @@ -365,7 +360,6 @@ def _encode_video_to_file_like_abstract( frame_rate: float, format: str, file_like_context: int, - device: str = "cpu", codec: Optional[str] = None, pixel_format: Optional[str] = None, preset: Optional[str] = None, diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 15791830b..cbcd56740 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional, Union import torch -from torch import device as torch_device, Tensor +from torch import Tensor from torchcodec import _core @@ -15,9 +15,8 @@ class VideoEncoder: tensor of shape ``(N, C, H, W)`` where N is the number of frames, C is 3 channels (RGB), H is height, and W is width. Values must be uint8 in the range ``[0, 255]``. + The device of the frames tensor will be used for encoding. frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. - device (str or torch.device, optional): The device to use for encoding. Default: "cpu". - If you pass a CUDA device, frames will be encoded on GPU. 
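For illustration, a minimal usage sketch of the API described above, assuming a torchcodec build where NVENC is available for the CUDA path; the encoding device now follows frames.device() rather than a separate device argument. File paths and frame contents below are placeholders.

    import io
    import torch
    from torchcodec.encoders import VideoEncoder

    # Placeholder input: (N, C, H, W) uint8 RGB frames in [0, 255].
    frames = torch.randint(0, 256, (10, 3, 240, 320), dtype=torch.uint8)

    # CPU path: frames live on the CPU, so encoding runs on the CPU.
    VideoEncoder(frames, frame_rate=30).to_file(dest="out_cpu.mp4")

    # CUDA path: moving the frames to a CUDA device selects GPU encoding;
    # h264_nvenc is the codec exercised by the tests (requires NVENC support).
    if torch.cuda.is_available():
        VideoEncoder(frames.to("cuda"), frame_rate=30).to_file(
            dest="out_cuda.mp4", codec="h264_nvenc"
        )

    # The file-like variant works the same way on either device.
    buffer = io.BytesIO()
    VideoEncoder(frames, frame_rate=30).to_file_like(buffer, format="mp4")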
""" def __init__( @@ -25,7 +24,6 @@ def __init__( frames: Tensor, *, frame_rate: float, - device: Optional[Union[str, torch_device]] = "cpu", ): torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") if not isinstance(frames, Tensor): @@ -37,13 +35,8 @@ def __init__( if frame_rate <= 0: raise ValueError(f"{frame_rate = } must be > 0.") - # Validate and store device - if isinstance(device, torch_device): - device = str(device) - self._frames = frames self._frame_rate = frame_rate - self._device = device def to_file( self, @@ -86,7 +79,6 @@ def to_file( frames=self._frames, frame_rate=self._frame_rate, filename=str(dest), - device=self._device, codec=codec, pixel_format=pixel_format, crf=crf, @@ -139,7 +131,6 @@ def to_tensor( frames=self._frames, frame_rate=self._frame_rate, format=format, - device=self._device, codec=codec, pixel_format=pixel_format, crf=crf, @@ -196,7 +187,6 @@ def to_file_like( frame_rate=self._frame_rate, format=format, file_like=file_like, - device=self._device, codec=codec, pixel_format=pixel_format, crf=crf, diff --git a/test/test_encoders.py b/test/test_encoders.py index 1125fd94c..bb8c777b4 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -829,18 +829,16 @@ def encode_to_tensor(frames): common_params = dict(crf=0, pixel_format="yuv444p") if method == "to_file": dest = str(tmp_path / "output.mp4") - VideoEncoder(frames, frame_rate=30, device=device).to_file( - dest=dest, **common_params - ) + VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params) with open(dest, "rb") as f: return torch.frombuffer(f.read(), dtype=torch.uint8).clone() elif method == "to_tensor": - return VideoEncoder(frames, frame_rate=30, device=device).to_tensor( + return VideoEncoder(frames, frame_rate=30).to_tensor( format="mp4", **common_params ) elif method == "to_file_like": file_like = io.BytesIO() - VideoEncoder(frames, frame_rate=30, device=device).to_file_like( + VideoEncoder(frames, frame_rate=30).to_file_like( file_like, format="mp4", **common_params ) return torch.frombuffer(file_like.getvalue(), dtype=torch.uint8) @@ -1331,9 +1329,7 @@ def test_nvenc_against_ffmpeg_cli( else: raise - encoder = VideoEncoder( - frames=source_frames, frame_rate=frame_rate, device=device - ) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) encoder_extra_options = {"qp": qp} if codec == "av1_nvenc": From 15e23f6827b30dfa414fc545f83628ceb07bf530 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Thu, 27 Nov 2025 04:01:50 +0000 Subject: [PATCH 13/17] nits --- src/torchcodec/_core/Encoder.cpp | 11 ----------- src/torchcodec/_core/custom_ops.cpp | 1 + src/torchcodec/encoders/_video_encoder.py | 7 +------ test/test_encoders.py | 1 + 4 files changed, 3 insertions(+), 17 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 89511436d..8091f61bf 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -917,17 +917,6 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame( "Expected 3D RGB tensor (CHW format), got shape: ", frame.sizes()); - // These are all already set in initializeEncoder? 
- // inHeight_ = static_cast(tensor.sizes()[1]); - // inWidth_ = static_cast(tensor.sizes()[2]); - - // // For now, reuse input dimensions as output dimensions - // outWidth_ = inWidth_; - // outHeight_ = inHeight_; - - // // Input format is RGB planar (AV_PIX_FMT_GBRP after channel reordering) - // inPixelFormat_ = AV_PIX_FMT_GBRP; - // Initialize and cache scaling context if it does not exist if (!swsContext_) { swsContext_.reset(sws_getContext( diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 961d78291..d30e0734d 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -415,6 +415,7 @@ void _add_video_stream( } validateDeviceInterface(std::string(device), std::string(device_variant)); + videoStreamOptions.device = torch::Device(std::string(device)); videoStreamOptions.deviceVariant = device_variant; diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index cbcd56740..a240052a6 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -19,12 +19,7 @@ class VideoEncoder: frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. """ - def __init__( - self, - frames: Tensor, - *, - frame_rate: float, - ): + def __init__(self, frames: Tensor, *, frame_rate: float): torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") if not isinstance(frames, Tensor): raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.") diff --git a/test/test_encoders.py b/test/test_encoders.py index bb8c777b4..9b45b499f 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -10,6 +10,7 @@ import pytest import torch from torchcodec.decoders import AudioDecoder, VideoDecoder + from torchcodec.encoders import AudioEncoder, VideoEncoder from .utils import ( From 2342837e9f2e11e4263ce861d0167f1e7824cb47 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Thu, 27 Nov 2025 04:05:27 +0000 Subject: [PATCH 14/17] remove repeat torch_check --- src/torchcodec/_core/Encoder.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 8091f61bf..a86e32b63 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -912,11 +912,6 @@ UniqueAVFrame VideoEncoder::convertTensorToAVFrame( const torch::Tensor& frame, int frameIndex) { TORCH_CHECK(frame.is_cpu(), "CPU encoder requires CPU tensors"); - TORCH_CHECK( - frame.dim() == 3 && frame.size(0) == 3, - "Expected 3D RGB tensor (CHW format), got shape: ", - frame.sizes()); - // Initialize and cache scaling context if it does not exist if (!swsContext_) { swsContext_.reset(sws_getContext( From f8b64ea58297655f62da3e23d1ca8a6a4dd80ce6 Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Mon, 1 Dec 2025 22:17:15 +0000 Subject: [PATCH 15/17] feedback --- src/torchcodec/_core/Encoder.cpp | 31 +++--------- src/torchcodec/_core/GpuEncoder.cpp | 70 ++++++++++------------------ src/torchcodec/_core/GpuEncoder.h | 3 +- src/torchcodec/_core/StreamOptions.h | 2 + src/torchcodec/_core/custom_ops.cpp | 3 -- test/test_encoders.py | 21 +++++---- 6 files changed, 45 insertions(+), 85 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index a86e32b63..225f7ecd0 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -523,9 +523,7 @@ void AudioEncoder::flushBuffers() { namespace 
{ -torch::Tensor validateFrames( - const torch::Tensor& frames, - const torch::Device& device) { +torch::Tensor validateFrames(const torch::Tensor& frames) { TORCH_CHECK( frames.dtype() == torch::kUInt8, "frames must have uint8 dtype, got ", @@ -538,15 +536,6 @@ torch::Tensor validateFrames( frames.sizes()[1] == 3, "frame must have 3 channels (R, G, B), got ", frames.sizes()[1]); - if (device.type() != torch::kCPU) { - TORCH_CHECK( - frames.is_cuda(), - "When using CUDA encoding (device=", - device.str(), - "), frames must be on a CUDA device. Got frames on ", - frames.device().str(), - ". Please move frames to a CUDA device: frames.to('cuda')"); - } return frames.contiguous(); } @@ -676,8 +665,7 @@ VideoEncoder::VideoEncoder( double frameRate, std::string_view fileName, const VideoStreamOptions& videoStreamOptions) - : frames_(validateFrames(frames, videoStreamOptions.device)), - inFrameRate_(frameRate) { + : frames_(validateFrames(frames)), inFrameRate_(frameRate) { setFFmpegLogLevel(); // Allocate output format context @@ -710,7 +698,7 @@ VideoEncoder::VideoEncoder( std::string_view formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions) - : frames_(validateFrames(frames, videoStreamOptions.device)), + : frames_(validateFrames(frames)), inFrameRate_(frameRate), avioContextHolder_(std::move(avioContextHolder)) { setFFmpegLogLevel(); @@ -736,8 +724,8 @@ VideoEncoder::VideoEncoder( void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { - if (videoStreamOptions.device.is_cuda()) { - gpuEncoder_ = std::make_unique(videoStreamOptions.device); + if (frames_.device().is_cuda()) { + gpuEncoder_ = std::make_unique(frames_.device()); } const AVCodec* avCodec = nullptr; @@ -764,12 +752,7 @@ void VideoEncoder::initializeEncoder( TORCH_CHECK( avFormatContext_->oformat != nullptr, "Output format is null, unable to find default codec."); - // Try to find a hardware-accelerated encoder if not using CPU avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); - if (gpuEncoder_) { - avCodec = gpuEncoder_->findEncoder(avFormatContext_->oformat->video_codec) - .value_or(avCodec); - } TORCH_CHECK(avCodec != nullptr, "Video codec not found"); } @@ -842,11 +825,9 @@ void VideoEncoder::initializeEncoder( 0); } - // Register the hardware device context with the codec - // context before calling avcodec_open2(). 
if (gpuEncoder_) { gpuEncoder_->registerHardwareDeviceWithCodec(avCodecContext_.get()); - gpuEncoder_->setupEncodingContext(avCodecContext_.get()); + gpuEncoder_->setupHardwareFrameContext(avCodecContext_.get()); } int status = avcodec_open2(avCodecContext_.get(), avCodec, &avCodecOptions); diff --git a/src/torchcodec/_core/GpuEncoder.cpp b/src/torchcodec/_core/GpuEncoder.cpp index 54efa4b4d..f36e94fae 100644 --- a/src/torchcodec/_core/GpuEncoder.cpp +++ b/src/torchcodec/_core/GpuEncoder.cpp @@ -100,26 +100,6 @@ void GpuEncoder::initializeHardwareContext() { nppCtx_ = getNppStreamContext(device_); } -std::optional GpuEncoder::findEncoder( - const AVCodecID& codecId) { - void* i = nullptr; - const AVCodec* codec = nullptr; - while ((codec = av_codec_iterate(&i)) != nullptr) { - if (codec->id != codecId || !av_codec_is_encoder(codec)) { - continue; - } - - const AVCodecHWConfig* config = nullptr; - for (int j = 0; (config = avcodec_get_hw_config(codec, j)) != nullptr; - ++j) { - if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { - return codec; - } - } - } - return std::nullopt; -} - void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) { TORCH_CHECK( hardwareDeviceCtx_, "Hardware device context has not been initialized"); @@ -127,19 +107,25 @@ void GpuEncoder::registerHardwareDeviceWithCodec(AVCodecContext* codecContext) { codecContext->hw_device_ctx = av_buffer_ref(hardwareDeviceCtx_.get()); } -void GpuEncoder::setupEncodingContext(AVCodecContext* codecContext) { +// Allocates and initializes AVHWFramesContext, and sets pixel format fields +// to enable encoding with CUDA device. The hw_frames_ctx field is needed by +// FFmpeg to allocate frames on GPU's memory. +void GpuEncoder::setupHardwareFrameContext(AVCodecContext* codecContext) { TORCH_CHECK( hardwareDeviceCtx_, "Hardware device context has not been initialized"); TORCH_CHECK(codecContext != nullptr, "codecContext is null"); - codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; - codecContext->pix_fmt = AV_PIX_FMT_CUDA; - AVBufferRef* hwFramesCtxRef = av_hwframe_ctx_alloc(hardwareDeviceCtx_.get()); TORCH_CHECK( hwFramesCtxRef != nullptr, "Failed to allocate hardware frames context for codec"); + // Always set pixel formats to options that support CUDA encoding. 
+ // TODO-VideoEncoder: Enable user set pixel formats to be set and properly + // converted with npp functions below + codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; + codecContext->pix_fmt = AV_PIX_FMT_CUDA; + AVHWFramesContext* hwFramesCtx = reinterpret_cast(hwFramesCtxRef->data); hwFramesCtx->format = codecContext->pix_fmt; @@ -164,41 +150,44 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame( [[maybe_unused]] AVPixelFormat targetFormat, int frameIndex, AVCodecContext* codecContext) { - TORCH_CHECK(tensor.is_cuda(), "GpuEncoder requires CUDA tensors"); + TORCH_CHECK( + tensor.is_cuda(), + "Frame tensor is not stored on GPU, but the GPU method convertTensorToAVFrame was called."); TORCH_CHECK( tensor.dim() == 3 && tensor.size(0) == 3, "Expected 3D RGB tensor (CHW format), got shape: ", tensor.sizes()); + + // TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); + int height = static_cast(tensor.size(1)); + int width = static_cast(tensor.size(2)); avFrame->format = AV_PIX_FMT_CUDA; - avFrame->width = static_cast(tensor.size(2)); - avFrame->height = static_cast(tensor.size(1)); + avFrame->height = height; + avFrame->width = width; avFrame->pts = frameIndex; - int ret = av_hwframe_get_buffer( - codecContext ? codecContext->hw_frames_ctx : nullptr, avFrame.get(), 0); + // FFmpeg's av_hwframe_get_buffer is used to allocate memory on CUDA device. + // TODO-VideoEncoder: Consider using pytorch to allocate CUDA memory for + // efficiency + int ret = + av_hwframe_get_buffer(codecContext->hw_frames_ctx, avFrame.get(), 0); TORCH_CHECK( ret >= 0, "Failed to allocate hardware frame: ", getFFMPEGErrorStringFromErrorCode(ret)); - // Validate that avFrame was properly allocated with CUDA memory TORCH_CHECK( avFrame != nullptr && avFrame->data[0] != nullptr, "avFrame must be pre-allocated with CUDA memory"); - // Convert CHW to HWC for NPP processing - int height = static_cast(tensor.size(1)); - int width = static_cast(tensor.size(2)); torch::Tensor hwcFrame = tensor.permute({1, 2, 0}).contiguous(); - // Get current CUDA stream for NPP operations at::cuda::CUDAStream currentStream = at::cuda::getCurrentCUDAStream(device_.index()); - // Setup NPP context with current stream nppCtx_->hStream = currentStream.stream(); cudaError_t cudaErr = cudaStreamGetFlags(nppCtx_->hStream, &nppCtx_->nStreamFlags); @@ -207,9 +196,7 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame( "cudaStreamGetFlags failed: ", cudaGetErrorString(cudaErr)); - // Always use FFmpeg's default behavior: BT.601 limited range NppiSize oSizeROI = {width, height}; - NppStatus status = nppiRGBToNV12_8u_ColorTwist32f_C3P2R_Ctx( static_cast(hwcFrame.data_ptr()), hwcFrame.stride(0) * hwcFrame.element_size(), @@ -224,15 +211,8 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame( "Failed to convert RGB to NV12: NPP error code ", status); - // Validate CUDA operations completed successfully - cudaError_t memCheck = cudaGetLastError(); - TORCH_CHECK( - memCheck == cudaSuccess, - "CUDA error detected: ", - cudaGetErrorString(memCheck)); - // TODO-VideoEncoder: Enable configuration of color properties, similar to - // FFmpeg Set color properties to FFmpeg defaults + // FFmpeg. Below are the default color properties used by FFmpeg. 
avFrame->colorspace = AVCOL_SPC_SMPTE170M; // BT.601 avFrame->color_range = AVCOL_RANGE_MPEG; // Limited range diff --git a/src/torchcodec/_core/GpuEncoder.h b/src/torchcodec/_core/GpuEncoder.h index f8066f4df..a5a6ad68c 100644 --- a/src/torchcodec/_core/GpuEncoder.h +++ b/src/torchcodec/_core/GpuEncoder.h @@ -27,9 +27,8 @@ class GpuEncoder { explicit GpuEncoder(const torch::Device& device); ~GpuEncoder(); - std::optional findEncoder(const AVCodecID& codecId); void registerHardwareDeviceWithCodec(AVCodecContext* codecContext); - void setupEncodingContext(AVCodecContext* codecContext); + void setupHardwareFrameContext(AVCodecContext* codecContext); UniqueAVFrame convertTensorToAVFrame( const torch::Tensor& tensor, diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index ce0f27d3b..9faafb502 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -41,6 +41,8 @@ struct VideoStreamOptions { ColorConversionLibrary::FILTERGRAPH; // By default we use CPU for decoding for both C++ and python users. + // Note: For video encoding, device is determined by the location of the input + // frame tensor. torch::Device device = torch::kCPU; // Device variant (e.g., "ffmpeg", "beta", etc.) std::string_view deviceVariant = "ffmpeg"; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index d30e0734d..7030928a5 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -647,7 +647,6 @@ void encode_video_to_file( std::optional preset = std::nullopt, std::optional> extra_options = std::nullopt) { VideoStreamOptions videoStreamOptions; - videoStreamOptions.device = frames.device(); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; @@ -672,7 +671,6 @@ at::Tensor encode_video_to_tensor( std::optional> extra_options = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; - videoStreamOptions.device = frames.device(); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; @@ -709,7 +707,6 @@ void _encode_video_to_file_like( std::unique_ptr avioContextHolder(fileLikeContext); VideoStreamOptions videoStreamOptions; - videoStreamOptions.device = frames.device(); videoStreamOptions.codec = std::move(codec); videoStreamOptions.pixelFormat = std::move(pixel_format); videoStreamOptions.crf = crf; diff --git a/test/test_encoders.py b/test/test_encoders.py index 9b45b499f..6d36eb3de 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -827,12 +827,16 @@ def test_contiguity(self, method, tmp_path, device): ) def encode_to_tensor(frames): - common_params = dict(crf=0, pixel_format="yuv444p") + common_params = dict( + crf=0, + pixel_format="yuv444p", + codec="h264_nvenc" if device != "cpu" else None, + ) if method == "to_file": dest = str(tmp_path / "output.mp4") VideoEncoder(frames, frame_rate=30).to_file(dest=dest, **common_params) with open(dest, "rb") as f: - return torch.frombuffer(f.read(), dtype=torch.uint8).clone() + return torch.frombuffer(f.read(), dtype=torch.uint8) elif method == "to_tensor": return VideoEncoder(frames, frame_rate=30).to_tensor( format="mp4", **common_params @@ -1269,7 +1273,6 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range @pytest.mark.needs_cuda @pytest.mark.skipif(in_fbcode(), reason="ffmpeg CLI not 
available") - @pytest.mark.parametrize("pixel_format", ("nv12", "yuv420p")) @pytest.mark.parametrize( "format_codec", [ @@ -1280,12 +1283,12 @@ def test_extra_options_utilized(self, tmp_path, profile, colorspace, color_range ], ) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) - def test_nvenc_against_ffmpeg_cli( - self, tmp_path, pixel_format, format_codec, method - ): + # TODO-VideoEncoder: Enable additional pixel formats ("yuv420p", "yuv444p") + def test_nvenc_against_ffmpeg_cli(self, tmp_path, format_codec, method): # Encode with FFmpeg CLI using nvenc codecs format, codec = format_codec device = "cuda" + pixel_format = "nv12" qp = 1 # Lossless (qp=0) is not supported on av1_nvenc, so we use 1 source_frames = self.decode(TEST_SRC_2_720P.path).data.to(device) @@ -1315,13 +1318,11 @@ def test_nvenc_against_ffmpeg_cli( ffmpeg_cmd.extend(["-pix_fmt", pixel_format]) # Output format if codec == "av1_nvenc": - ffmpeg_cmd.extend( - ["-rc", "constqp"] - ) # Set rate control mode for AV1 else: + ffmpeg_cmd.extend(["-rc", "constqp"]) # Set rate control mode for AV1 ffmpeg_cmd.extend(["-qp", str(qp)]) # Use lossless qp for other codecs ffmpeg_cmd.extend([ffmpeg_encoded_path]) - # Will this prevent CI from treating test as failed if NVENC is not available? + # TODO-VideoEncoder: Ensure CI does not skip this test, as we know NVENC is available. try: subprocess.run(ffmpeg_cmd, check=True, capture_output=True) except subprocess.CalledProcessError as e: From 6d88905a6e521a4e6dfea328e4fb45c44a5f13be Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Mon, 1 Dec 2025 22:33:39 +0000 Subject: [PATCH 16/17] additional small fixes --- src/torchcodec/_core/Encoder.cpp | 2 -- src/torchcodec/_core/GpuEncoder.cpp | 13 +++++-------- test/test_encoders.py | 1 - 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 225f7ecd0..2569a72ab 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -727,7 +727,6 @@ void VideoEncoder::initializeEncoder( if (frames_.device().is_cuda()) { gpuEncoder_ = std::make_unique(frames_.device()); } - const AVCodec* avCodec = nullptr; // If codec arg is provided, find codec using logic similar to FFmpeg: // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835 @@ -892,7 +891,6 @@ void VideoEncoder::encode() { UniqueAVFrame VideoEncoder::convertTensorToAVFrame( const torch::Tensor& frame, int frameIndex) { - TORCH_CHECK(frame.is_cpu(), "CPU encoder requires CPU tensors"); // Initialize and cache scaling context if it does not exist if (!swsContext_) { swsContext_.reset(sws_getContext( diff --git a/src/torchcodec/_core/GpuEncoder.cpp b/src/torchcodec/_core/GpuEncoder.cpp index f36e94fae..5e80a1e06 100644 --- a/src/torchcodec/_core/GpuEncoder.cpp +++ b/src/torchcodec/_core/GpuEncoder.cpp @@ -71,8 +71,8 @@ UniqueAVBufferRef createHardwareDeviceContext(const torch::Device& device) { // RGB to NV12 color conversion matrices (inverse of YUV to RGB) // Note: NPP's ColorTwist function apparently expects "limited range" -// coefficient format even when producing full range output. All matrices below -// use the limited range coefficient format (Y with +16 offset) for NPP +// coefficient format even when producing full range output. The matrix below +// uses the limited range coefficient format (Y with +16 offset) for NPP // compatibility. 
// BT.601 limited range (matches FFmpeg default behavior) @@ -83,7 +83,7 @@ const Npp32f defaultLimitedRangeRgbToNv12[3][4] = { {-0.148f, -0.291f, 0.439f, 128.0f}, // V = 0.439*R - 0.368*G - 0.071*B + 128 (BT.601 coefficients) {0.439f, -0.368f, -0.071f, 128.0f}}; -} // anonymous namespace +} // namespace GpuEncoder::GpuEncoder(const torch::Device& device) : device_(device) { TORCH_CHECK( @@ -122,7 +122,7 @@ void GpuEncoder::setupHardwareFrameContext(AVCodecContext* codecContext) { // Always set pixel formats to options that support CUDA encoding. // TODO-VideoEncoder: Enable user set pixel formats to be set and properly - // converted with npp functions below + // handled with NPP functions below codecContext->sw_pix_fmt = AV_PIX_FMT_NV12; codecContext->pix_fmt = AV_PIX_FMT_CUDA; @@ -150,20 +150,17 @@ UniqueAVFrame GpuEncoder::convertTensorToAVFrame( [[maybe_unused]] AVPixelFormat targetFormat, int frameIndex, AVCodecContext* codecContext) { - TORCH_CHECK( - tensor.is_cuda(), - "Frame tensor is not stored on GPU, but the GPU method convertTensorToAVFrame was called."); TORCH_CHECK( tensor.dim() == 3 && tensor.size(0) == 3, "Expected 3D RGB tensor (CHW format), got shape: ", tensor.sizes()); - // TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method UniqueAVFrame avFrame(av_frame_alloc()); TORCH_CHECK(avFrame != nullptr, "Failed to allocate AVFrame"); int height = static_cast(tensor.size(1)); int width = static_cast(tensor.size(2)); + // TODO-VideoEncoder: Unify AVFrame creation with CPU version of this method avFrame->format = AV_PIX_FMT_CUDA; avFrame->height = height; avFrame->width = width; diff --git a/test/test_encoders.py b/test/test_encoders.py index 6d36eb3de..7761daf5a 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1371,5 +1371,4 @@ def test_nvenc_against_ffmpeg_cli(self, tmp_path, format_codec, method): assert ffmpeg_frames.shape[0] == encoder_frames.shape[0] for ff_frame, enc_frame in zip(ffmpeg_frames, encoder_frames): assert psnr(ff_frame, enc_frame) > 25 - assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=99, atol=10) assert_tensor_close_on_at_least(ff_frame, enc_frame, percentage=95, atol=2) From 6b8c1fe44e8970fac92b042e5b4dbe7f8048210b Mon Sep 17 00:00:00 2001 From: Dan-Flores Date: Tue, 2 Dec 2025 14:55:39 +0000 Subject: [PATCH 17/17] add in_fbcode --- test/test_encoders.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/test_encoders.py b/test/test_encoders.py index f10c9d695..157616a38 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -17,6 +17,7 @@ assert_tensor_close_on_at_least, get_ffmpeg_major_version, get_ffmpeg_minor_version, + in_fbcode, IS_WINDOWS, NASA_AUDIO_MP3, needs_ffmpeg_cli,