diff --git a/src/torchcodec/_core/__init__.py b/src/torchcodec/_core/__init__.py index eb8dd9697..e04fcb8bd 100644 --- a/src/torchcodec/_core/__init__.py +++ b/src/torchcodec/_core/__init__.py @@ -26,6 +26,7 @@ encode_audio_to_file_like, encode_audio_to_tensor, encode_video_to_file, + encode_video_to_file_like, encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 94a3fba1b..32c140d9e 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -40,6 +40,8 @@ TORCH_LIBRARY(torchcodec_ns, m) { "encode_video_to_file(Tensor frames, int frame_rate, str filename, int? crf=None) -> ()"); m.def( "encode_video_to_tensor(Tensor frames, int frame_rate, str format, int? crf=None) -> Tensor"); + m.def( + "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, int? crf=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -606,6 +608,30 @@ at::Tensor encode_video_to_tensor( .encodeToTensor(); } +void _encode_video_to_file_like( + const at::Tensor& frames, + int64_t frame_rate, + std::string_view format, + int64_t file_like_context, + std::optional crf = std::nullopt) { + auto fileLikeContext = + reinterpret_cast(file_like_context); + TORCH_CHECK( + fileLikeContext != nullptr, "file_like_context must be a valid pointer"); + std::unique_ptr avioContextHolder(fileLikeContext); + + VideoStreamOptions videoStreamOptions; + videoStreamOptions.crf = crf; + + VideoEncoder encoder( + frames, + validateInt64ToInt(frame_rate, "frame_rate"), + format, + std::move(avioContextHolder), + videoStreamOptions); + encoder.encode(); +} + // For testing only. We need to implement this operation as a core library // function because what we're testing is round-tripping pts values as // double-precision floating point numbers from C++ to Python and back to C++. @@ -870,6 +896,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) { m.impl("_encode_audio_to_file_like", &_encode_audio_to_file_like); m.impl("encode_video_to_file", &encode_video_to_file); m.impl("encode_video_to_tensor", &encode_video_to_tensor); + m.impl("_encode_video_to_file_like", &_encode_video_to_file_like); m.impl("seek_to_pts", &seek_to_pts); m.impl("add_video_stream", &add_video_stream); m.impl("_add_video_stream", &_add_video_stream); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index 03cf8cf6d..7123c83da 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -104,6 +104,9 @@ def load_torchcodec_shared_libraries(): encode_video_to_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.encode_video_to_tensor.default ) +_encode_video_to_file_like = torch._dynamo.disallow_in_graph( + torch.ops.torchcodec_ns._encode_video_to_file_like.default +) create_from_tensor = torch._dynamo.disallow_in_graph( torch.ops.torchcodec_ns.create_from_tensor.default ) @@ -203,6 +206,33 @@ def encode_audio_to_file_like( ) +def encode_video_to_file_like( + frames: torch.Tensor, + frame_rate: int, + format: str, + file_like: Union[io.RawIOBase, io.BufferedIOBase], + crf: Optional[int] = None, +) -> None: + """Encode video frames to a file-like object. + + Args: + frames: Video frames tensor + frame_rate: Frame rate in frames per second + format: Video format (e.g., "mp4", "mov", "mkv") + file_like: File-like object that supports write() and seek() methods + crf: Optional constant rate factor for encoding quality + """ + assert _pybind_ops is not None + + _encode_video_to_file_like( + frames, + frame_rate, + format, + _pybind_ops.create_file_like_context(file_like, True), # True means for writing + crf, + ) + + def get_frames_at_indices( decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -302,6 +332,17 @@ def encode_video_to_tensor_abstract( return torch.empty([], dtype=torch.long) +@register_fake("torchcodec_ns::_encode_video_to_file_like") +def _encode_video_to_file_like_abstract( + frames: torch.Tensor, + frame_rate: int, + format: str, + file_like_context: int, + crf: Optional[int] = None, +) -> None: + return + + @register_fake("torchcodec_ns::create_from_tensor") def create_from_tensor_abstract( video_tensor: torch.Tensor, seek_mode: Optional[str] diff --git a/test/test_ops.py b/test/test_ops.py index 31afbdd14..b2fe45b50 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -29,6 +29,7 @@ create_from_tensor, encode_audio_to_file, encode_video_to_file, + encode_video_to_file_like, encode_video_to_tensor, get_ffmpeg_library_versions, get_frame_at_index, @@ -1329,7 +1330,7 @@ def test_bad_input(self, tmp_path): class TestVideoEncoderOps: - + # TODO-VideoEncoder: Test encoding against different memory layouts (ex. test_contiguity) # TODO-VideoEncoder: Parametrize test after moving to test_encoders def test_bad_input(self, tmp_path): output_file = str(tmp_path / ".mp4") @@ -1397,7 +1398,7 @@ def decode(self, source=None) -> torch.Tensor: @pytest.mark.parametrize( "format", ("mov", "mp4", "mkv", pytest.param("webm", marks=pytest.mark.slow)) ) - @pytest.mark.parametrize("method", ("to_file", "to_tensor")) + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_video_encoder_round_trip(self, tmp_path, format, method): # Test that decode(encode(decode(frames))) == decode(frames) ffmpeg_version = get_ffmpeg_major_version() @@ -1424,11 +1425,22 @@ def test_video_encoder_round_trip(self, tmp_path, format, method): **params, ) round_trip_frames = self.decode(encoded_path).data - else: # to_tensor + elif method == "to_tensor": encoded_tensor = encode_video_to_tensor( source_frames, format=format, **params ) round_trip_frames = self.decode(encoded_tensor).data + elif method == "to_file_like": + file_like = io.BytesIO() + encode_video_to_file_like( + frames=source_frames, + format=format, + file_like=file_like, + **params, + ) + round_trip_frames = self.decode(file_like.getvalue()).data + else: + raise ValueError(f"Unknown method: {method}") assert source_frames.shape == round_trip_frames.shape assert source_frames.dtype == round_trip_frames.dtype @@ -1457,8 +1469,9 @@ def test_video_encoder_round_trip(self, tmp_path, format, method): pytest.param("webm", marks=pytest.mark.slow), ), ) - def test_against_to_file(self, tmp_path, format): - # Test that to_file and to_tensor produce the same results + @pytest.mark.parametrize("method", ("to_tensor", "to_file_like")) + def test_against_to_file(self, tmp_path, format, method): + # Test that to_file, to_tensor, and to_file_like produce the same results ffmpeg_version = get_ffmpeg_major_version() if format == "webm" and ( ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) @@ -1470,11 +1483,24 @@ def test_against_to_file(self, tmp_path, format): encoded_file = tmp_path / f"output.{format}" encode_video_to_file(frames=source_frames, filename=str(encoded_file), **params) - encoded_tensor = encode_video_to_tensor(source_frames, format=format, **params) + + if method == "to_tensor": + encoded_output = encode_video_to_tensor( + source_frames, format=format, **params + ) + else: # to_file_like + file_like = io.BytesIO() + encode_video_to_file_like( + frames=source_frames, + file_like=file_like, + format=format, + **params, + ) + encoded_output = file_like.getvalue() torch.testing.assert_close( self.decode(encoded_file).data, - self.decode(encoded_tensor).data, + self.decode(encoded_output).data, atol=0, rtol=0, ) @@ -1557,6 +1583,82 @@ def test_video_encoder_against_ffmpeg_cli(self, tmp_path, format): ff_frame, enc_frame, percentage=percentage, atol=2 ) + def test_to_file_like_custom_file_object(self): + """Test with a custom file-like object that implements write and seek.""" + + class CustomFileObject: + def __init__(self): + self._file = io.BytesIO() + + def write(self, data): + return self._file.write(data) + + def seek(self, offset, whence=0): + return self._file.seek(offset, whence) + + def get_encoded_data(self): + return self._file.getvalue() + + source_frames = self.decode(TEST_SRC_2_720P.path).data + file_like = CustomFileObject() + encode_video_to_file_like( + source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + ) + decoded_samples = self.decode(file_like.get_encoded_data()) + + torch.testing.assert_close( + decoded_samples.data, + source_frames, + atol=2, + rtol=0, + ) + + def test_to_file_like_real_file(self, tmp_path): + """Test to_file_like with a real file opened in binary write mode.""" + source_frames = self.decode(TEST_SRC_2_720P.path).data + file_path = tmp_path / "test_file_like.mp4" + + with open(file_path, "wb") as file_like: + encode_video_to_file_like( + source_frames, frame_rate=30, crf=0, format="mp4", file_like=file_like + ) + decoded_samples = self.decode(str(file_path)) + + torch.testing.assert_close( + decoded_samples.data, + source_frames, + atol=2, + rtol=0, + ) + + def test_to_file_like_bad_methods(self): + source_frames = self.decode(TEST_SRC_2_720P.path).data + + class NoWriteMethod: + def seek(self, offset, whence=0): + return 0 + + with pytest.raises( + RuntimeError, match="File like object must implement a write method" + ): + encode_video_to_file_like( + source_frames, + frame_rate=30, + format="mp4", + file_like=NoWriteMethod(), + ) + + class NoSeekMethod: + def write(self, data): + return len(data) + + with pytest.raises( + RuntimeError, match="File like object must implement a seek method" + ): + encode_video_to_file_like( + source_frames, frame_rate=30, format="mp4", file_like=NoSeekMethod() + ) + if __name__ == "__main__": pytest.main()