diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 90e326087..89ad380d8 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -687,9 +687,33 @@ VideoEncoder::VideoEncoder( void VideoEncoder::initializeEncoder( const VideoStreamOptions& videoStreamOptions) { - const AVCodec* avCodec = - avcodec_find_encoder(avFormatContext_->oformat->video_codec); - TORCH_CHECK(avCodec != nullptr, "Video codec not found"); + const AVCodec* avCodec = nullptr; + // If codec arg is provided, find codec using logic similar to FFmpeg: + // https://github.com/FFmpeg/FFmpeg/blob/master/fftools/ffmpeg_opt.c#L804-L835 + if (videoStreamOptions.codec.has_value()) { + const std::string& codec = videoStreamOptions.codec.value(); + // Try to find codec by name ("libx264", "libsvtav1") + avCodec = avcodec_find_encoder_by_name(codec.c_str()); + // Try to find by codec descriptor ("h264", "av1") + if (!avCodec) { + const AVCodecDescriptor* desc = + avcodec_descriptor_get_by_name(codec.c_str()); + if (desc) { + avCodec = avcodec_find_encoder(desc->id); + } + } + TORCH_CHECK( + avCodec != nullptr, + "Video codec ", + codec, + " not found. To see available codecs, run: ffmpeg -encoders"); + } else { + TORCH_CHECK( + avFormatContext_->oformat != nullptr, + "Output format is null, unable to find default codec."); + avCodec = avcodec_find_encoder(avFormatContext_->oformat->video_codec); + TORCH_CHECK(avCodec != nullptr, "Video codec not found"); + } AVCodecContext* avCodecContext = avcodec_alloc_context3(avCodec); TORCH_CHECK(avCodecContext != nullptr, "Couldn't allocate codec context."); diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 01af6846c..fca33855c 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -45,6 +45,7 @@ struct VideoStreamOptions { std::string_view deviceVariant = "ffmpeg"; // Encoding options + std::optional codec; // Optional pixel format for video encoding (e.g., "yuv420p", "yuv444p") // If not specified, uses codec's default format. std::optional pixelFormat; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 2ef5d0f49..c2ec3f2af 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? pixel_format=None, float? crf=None, str? preset=None) -> ()"); + "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? pixel_format=None, float? crf=None, str? preset=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? pixel_format=None, float? crf=None, str? preset=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -603,10 +603,12 @@ void encode_video_to_file( const at::Tensor& frames, int64_t frame_rate, std::string_view file_name, + std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, std::optional preset = std::nullopt) { VideoStreamOptions videoStreamOptions; + videoStreamOptions.codec = codec; videoStreamOptions.pixelFormat = pixel_format; videoStreamOptions.crf = crf; videoStreamOptions.preset = preset; @@ -622,11 +624,13 @@ at::Tensor encode_video_to_tensor( const at::Tensor& frames, int64_t frame_rate, std::string_view format, + std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, std::optional preset = std::nullopt) { auto avioContextHolder = std::make_unique(); VideoStreamOptions videoStreamOptions; + videoStreamOptions.codec = codec; videoStreamOptions.pixelFormat = pixel_format; videoStreamOptions.crf = crf; videoStreamOptions.preset = preset; @@ -644,6 +648,7 @@ void _encode_video_to_file_like( int64_t frame_rate, std::string_view format, int64_t file_like_context, + std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, std::optional crf = std::nullopt, std::optional preset = std::nullopt) { @@ -654,6 +659,7 @@ void _encode_video_to_file_like( std::unique_ptr avioContextHolder(fileLikeContext); VideoStreamOptions videoStreamOptions; + videoStreamOptions.codec = codec; videoStreamOptions.pixelFormat = pixel_format; videoStreamOptions.crf = crf; videoStreamOptions.preset = preset; diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index a2f1fa0a3..fda84c7e6 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -213,8 +213,9 @@ def encode_video_to_file_like( frame_rate: int, format: str, file_like: Union[io.RawIOBase, io.BufferedIOBase], - crf: Optional[Union[int, float]] = None, + codec: Optional[str] = None, pixel_format: Optional[str] = None, + crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, ) -> None: """Encode video frames to a file-like object. @@ -224,8 +225,9 @@ def encode_video_to_file_like( frame_rate: Frame rate in frames per second format: Video format (e.g., "mp4", "mov", "mkv") file_like: File-like object that supports write() and seek() methods - crf: Optional constant rate factor for encoding quality + codec: Optional codec name (e.g., "libx264", "h264") pixel_format: Optional pixel format (e.g., "yuv420p", "yuv444p") + crf: Optional constant rate factor for encoding quality preset: Optional encoder preset as string (e.g., "ultrafast", "medium") """ assert _pybind_ops is not None @@ -235,6 +237,7 @@ def encode_video_to_file_like( frame_rate, format, _pybind_ops.create_file_like_context(file_like, True), # True means for writing + codec, pixel_format, crf, preset, @@ -325,6 +328,7 @@ def encode_video_to_file_abstract( frames: torch.Tensor, frame_rate: int, filename: str, + codec: Optional[str], pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, @@ -337,6 +341,7 @@ def encode_video_to_tensor_abstract( frames: torch.Tensor, frame_rate: int, format: str, + codec: Optional[str], pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, @@ -350,6 +355,7 @@ def _encode_video_to_file_like_abstract( frame_rate: int, format: str, file_like_context: int, + codec: Optional[str], pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[str] = None, diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index d812d4a11..4788801c1 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -36,6 +36,7 @@ def to_file( self, dest: Union[str, Path], *, + codec: Optional[str] = None, pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[Union[str, int]] = None, @@ -46,6 +47,9 @@ def to_file( dest (str or ``pathlib.Path``): The path to the output file, e.g. ``video.mp4``. The extension of the file determines the video container format. + codec (str, optional): The codec to use for encoding (e.g., "libx264", + "h264"). If not specified, the default codec + for the container format will be used. pixel_format (str, optional): The pixel format for encoding (e.g., "yuv420p", "yuv444p"). If not specified, uses codec's default format. crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values @@ -61,6 +65,7 @@ def to_file( frames=self._frames, frame_rate=self._frame_rate, filename=str(dest), + codec=codec, pixel_format=pixel_format, crf=crf, preset=preset, @@ -70,6 +75,7 @@ def to_tensor( self, format: str, *, + codec: Optional[str] = None, pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[Union[str, int]] = None, @@ -78,7 +84,10 @@ def to_tensor( Args: format (str): The container format of the encoded frames, e.g. "mp4", "mov", - "mkv", "avi", "webm", "flv", etc. + "mkv", "avi", "webm", "flv", etc. + codec (str, optional): The codec to use for encoding (e.g., "libx264", + "h264"). If not specified, the default codec + for the container format will be used. pixel_format (str, optional): The pixel format to encode frames into (e.g., "yuv420p", "yuv444p"). If not specified, uses codec's default format. crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values @@ -90,13 +99,14 @@ def to_tensor( (which will use encoder's default). Returns: - Tensor: The raw encoded bytes as 4D uint8 Tensor. + Tensor: The raw encoded bytes as 1D uint8 Tensor. """ preset_value = str(preset) if isinstance(preset, int) else preset return _core.encode_video_to_tensor( frames=self._frames, frame_rate=self._frame_rate, format=format, + codec=codec, pixel_format=pixel_format, crf=crf, preset=preset_value, @@ -107,6 +117,7 @@ def to_file_like( file_like, format: str, *, + codec: Optional[str] = None, pixel_format: Optional[str] = None, crf: Optional[Union[int, float]] = None, preset: Optional[Union[str, int]] = None, @@ -121,6 +132,9 @@ def to_file_like( int = 0) -> int``. format (str): The container format of the encoded frames, e.g. "mp4", "mov", "mkv", "avi", "webm", "flv", etc. + codec (str, optional): The codec to use for encoding (e.g., "libx264", + "h264"). If not specified, the default codec + for the container format will be used. pixel_format (str, optional): The pixel format for encoding (e.g., "yuv420p", "yuv444p"). If not specified, uses codec's default format. crf (int or float, optional): Constant Rate Factor for encoding quality. Lower values @@ -137,6 +151,7 @@ def to_file_like( frame_rate=self._frame_rate, format=format, file_like=file_like, + codec=codec, pixel_format=pixel_format, crf=crf, preset=preset, diff --git a/test/test_encoders.py b/test/test_encoders.py index 303f3307c..9fb02f1ed 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -572,6 +572,27 @@ class TestVideoEncoder: def decode(self, source=None) -> torch.Tensor: return VideoDecoder(source).get_frames_in_range(start=0, stop=60) + def _get_codec_spec(self, file_path): + """Helper function to get codec name from a video file using ffprobe.""" + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + "stream=codec_name", + "-of", + "default=noprint_wrappers=1:nokey=1", + str(file_path), + ], + capture_output=True, + check=True, + text=True, + ) + return result.stdout.strip() + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parameterized(self, tmp_path, method): if method == "to_file": @@ -610,6 +631,16 @@ def test_bad_input_parameterized(self, tmp_path, method): ) getattr(encoder, method)(**valid_params) + with pytest.raises( + RuntimeError, + match=r"Video codec invalid_codec_name not found.", + ): + encoder = VideoEncoder( + frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), + frame_rate=30, + ) + encoder.to_file(str(tmp_path / "output.mp4"), codec="invalid_codec_name") + with pytest.raises(RuntimeError, match=r"crf=-10 is out of valid range"): encoder = VideoEncoder( frames=torch.zeros((5, 3, 64, 64), dtype=torch.uint8), @@ -990,3 +1021,72 @@ def write(self, data): RuntimeError, match="File like object must implement a seek method" ): encoder.to_file_like(NoSeekMethod(), format="mp4") + + @pytest.mark.skipif( + in_fbcode(), + reason="ffprobe not available internally", + ) + @pytest.mark.parametrize( + "format,codec_spec", + [ + ("mp4", "h264"), + ("mp4", "hevc"), + ("mkv", "av1"), + ("avi", "mpeg4"), + pytest.param( + "webm", + "vp9", + marks=pytest.mark.skipif( + IS_WINDOWS, reason="vp9 codec not available on Windows" + ), + ), + ], + ) + def test_codec_parameter_utilized(self, tmp_path, format, codec_spec): + # Test the codec parameter is utilized by using ffprobe to check the encoded file's codec spec + frames = torch.zeros((10, 3, 64, 64), dtype=torch.uint8) + dest = str(tmp_path / f"output.{format}") + + VideoEncoder(frames=frames, frame_rate=30).to_file(dest=dest, codec=codec_spec) + actual_codec_spec = self._get_codec_spec(dest) + assert actual_codec_spec == codec_spec + + @pytest.mark.skipif( + in_fbcode(), + reason="ffprobe not available internally", + ) + @pytest.mark.parametrize( + "codec_spec,codec_impl", + [ + ("h264", "libx264"), + ("av1", "libaom-av1"), + pytest.param( + "vp9", + "libvpx-vp9", + marks=pytest.mark.skipif( + IS_WINDOWS, reason="vp9 codec not available on Windows" + ), + ), + ], + ) + def test_codec_spec_vs_impl_equivalence(self, tmp_path, codec_spec, codec_impl): + # Test that using codec spec gives the same result as using default codec implementation + # We cannot directly check codec impl used, so we assert frame equality + frames = torch.randint(0, 256, (10, 3, 64, 64), dtype=torch.uint8) + + spec_output = str(tmp_path / "spec_output.mp4") + VideoEncoder(frames=frames, frame_rate=30).to_file( + dest=spec_output, codec=codec_spec + ) + + impl_output = str(tmp_path / "impl_output.mp4") + VideoEncoder(frames=frames, frame_rate=30).to_file( + dest=impl_output, codec=codec_impl + ) + + assert self._get_codec_spec(spec_output) == codec_spec + assert self._get_codec_spec(impl_output) == codec_spec + + frames_spec = self.decode(spec_output).data + frames_impl = self.decode(impl_output).data + torch.testing.assert_close(frames_spec, frames_impl, rtol=0, atol=0)