diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 362a02a95..19ac9220d 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -662,7 +662,7 @@ VideoEncoder::~VideoEncoder() { VideoEncoder::VideoEncoder( const torch::Tensor& frames, - int frameRate, + double frameRate, std::string_view fileName, const VideoStreamOptions& videoStreamOptions) : frames_(validateFrames(frames)), inFrameRate_(frameRate) { @@ -694,7 +694,7 @@ VideoEncoder::VideoEncoder( VideoEncoder::VideoEncoder( const torch::Tensor& frames, - int frameRate, + double frameRate, std::string_view formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions) @@ -787,9 +787,9 @@ void VideoEncoder::initializeEncoder( avCodecContext_->width = outWidth_; avCodecContext_->height = outHeight_; avCodecContext_->pix_fmt = outPixelFormat_; - // TODO-VideoEncoder: Verify that frame_rate and time_base are correct - avCodecContext_->time_base = {1, inFrameRate_}; - avCodecContext_->framerate = {inFrameRate_, 1}; + // TODO-VideoEncoder: Add and utilize output frame_rate option + avCodecContext_->framerate = av_d2q(inFrameRate_, INT_MAX); + avCodecContext_->time_base = av_inv_q(avCodecContext_->framerate); // Set flag for containers that require extradata to be in the codec context if (avFormatContext_->oformat->flags & AVFMT_GLOBALHEADER) { @@ -833,6 +833,10 @@ void VideoEncoder::initializeEncoder( // Set the stream time base to encode correct frame timestamps avStream_->time_base = avCodecContext_->time_base; + // Set the stream frame rate to store correct frame durations for some + // containers (webm, mkv) + avStream_->r_frame_rate = avCodecContext_->framerate; + status = avcodec_parameters_from_context( avStream_->codecpar, avCodecContext_.get()); TORCH_CHECK( diff --git a/src/torchcodec/_core/Encoder.h b/src/torchcodec/_core/Encoder.h index 3d59eb6f6..1bdc1e443 100644 --- a/src/torchcodec/_core/Encoder.h +++ b/src/torchcodec/_core/Encoder.h @@ -143,13 +143,13 @@ class VideoEncoder { VideoEncoder( const torch::Tensor& frames, - int frameRate, + double frameRate, std::string_view fileName, const VideoStreamOptions& videoStreamOptions); VideoEncoder( const torch::Tensor& frames, - int frameRate, + double frameRate, std::string_view formatName, std::unique_ptr avioContextHolder, const VideoStreamOptions& videoStreamOptions); @@ -172,7 +172,7 @@ class VideoEncoder { UniqueSwsContext swsContext_; const torch::Tensor frames_; - int inFrameRate_; + double inFrameRate_; int inWidth_ = -1; int inHeight_ = -1; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 14ca48a7b..4a75ae82e 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -37,11 +37,11 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "_encode_audio_to_file_like(Tensor samples, int sample_rate, str format, int file_like_context, int? bit_rate=None, int? num_channels=None, int? desired_sample_rate=None) -> ()"); m.def( - "encode_video_to_file(Tensor frames, int frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + "encode_video_to_file(Tensor frames, float frame_rate, str filename, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( - "encode_video_to_tensor(Tensor frames, int frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); + "encode_video_to_tensor(Tensor frames, float frame_rate, str format, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> Tensor"); m.def( - "_encode_video_to_file_like(Tensor frames, int frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); + "_encode_video_to_file_like(Tensor frames, float frame_rate, str format, int file_like_context, str? codec=None, str? pixel_format=None, float? crf=None, str? preset=None, str[]? extra_options=None) -> ()"); m.def( "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( @@ -611,7 +611,7 @@ void _encode_audio_to_file_like( void encode_video_to_file( const at::Tensor& frames, - int64_t frame_rate, + double frame_rate, std::string_view file_name, std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, @@ -629,17 +629,12 @@ void encode_video_to_file( unflattenExtraOptions(extra_options.value()); } - VideoEncoder( - frames, - validateInt64ToInt(frame_rate, "frame_rate"), - file_name, - videoStreamOptions) - .encode(); + VideoEncoder(frames, frame_rate, file_name, videoStreamOptions).encode(); } at::Tensor encode_video_to_tensor( const at::Tensor& frames, - int64_t frame_rate, + double frame_rate, std::string_view format, std::optional codec = std::nullopt, std::optional pixel_format = std::nullopt, @@ -660,7 +655,7 @@ at::Tensor encode_video_to_tensor( return VideoEncoder( frames, - validateInt64ToInt(frame_rate, "frame_rate"), + frame_rate, format, std::move(avioContextHolder), videoStreamOptions) @@ -669,7 +664,7 @@ at::Tensor encode_video_to_tensor( void _encode_video_to_file_like( const at::Tensor& frames, - int64_t frame_rate, + double frame_rate, std::string_view format, int64_t file_like_context, std::optional codec = std::nullopt, @@ -696,7 +691,7 @@ void _encode_video_to_file_like( VideoEncoder encoder( frames, - validateInt64ToInt(frame_rate, "frame_rate"), + frame_rate, format, std::move(avioContextHolder), videoStreamOptions); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index c3562f679..160e273bb 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -210,7 +210,7 @@ def encode_audio_to_file_like( def encode_video_to_file_like( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, format: str, file_like: Union[io.RawIOBase, io.BufferedIOBase], codec: Optional[str] = None, @@ -329,7 +329,7 @@ def _encode_audio_to_file_like_abstract( @register_fake("torchcodec_ns::encode_video_to_file") def encode_video_to_file_abstract( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, filename: str, codec: Optional[str] = None, pixel_format: Optional[str] = None, @@ -343,7 +343,7 @@ def encode_video_to_file_abstract( @register_fake("torchcodec_ns::encode_video_to_tensor") def encode_video_to_tensor_abstract( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, format: str, codec: Optional[str] = None, pixel_format: Optional[str] = None, @@ -357,7 +357,7 @@ def encode_video_to_tensor_abstract( @register_fake("torchcodec_ns::_encode_video_to_file_like") def _encode_video_to_file_like_abstract( frames: torch.Tensor, - frame_rate: int, + frame_rate: float, format: str, file_like_context: int, codec: Optional[str] = None, diff --git a/src/torchcodec/encoders/_video_encoder.py b/src/torchcodec/encoders/_video_encoder.py index 0bb754025..49ece70b6 100644 --- a/src/torchcodec/encoders/_video_encoder.py +++ b/src/torchcodec/encoders/_video_encoder.py @@ -15,10 +15,10 @@ class VideoEncoder: tensor of shape ``(N, C, H, W)`` where N is the number of frames, C is 3 channels (RGB), H is height, and W is width. Values must be uint8 in the range ``[0, 255]``. - frame_rate (int): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. + frame_rate (float): The frame rate of the **input** ``frames``. Also defines the encoded **output** frame rate. """ - def __init__(self, frames: Tensor, *, frame_rate: int): + def __init__(self, frames: Tensor, *, frame_rate: float): torch._C._log_api_usage_once("torchcodec.encoders.VideoEncoder") if not isinstance(frames, Tensor): raise ValueError(f"Expected frames to be a Tensor, got {type(frames) = }.") diff --git a/test/test_encoders.py b/test/test_encoders.py index ad2f0cefe..543025599 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -570,7 +570,14 @@ def write(self, data): class TestVideoEncoder: def decode(self, source=None) -> torch.Tensor: - return VideoDecoder(source).get_frames_in_range(start=0, stop=60) + return VideoDecoder(source).get_frames_in_range(start=0, stop=30).data + + # TODO: add average_fps field to TestVideo asset + def decode_and_get_frame_rate(self, source=None): + decoder = VideoDecoder(source) + frames = decoder.get_frames_in_range(start=0, stop=30).data + frame_rate = decoder.metadata.average_fps + return frames, frame_rate def _get_video_metadata(self, file_path, fields): """Helper function to get video metadata from a file using ffprobe.""" @@ -596,8 +603,32 @@ def _get_video_metadata(self, file_path, fields): if "=" in line: key, value = line.split("=", 1) metadata[key] = value + assert all(field in metadata for field in fields) return metadata + def _get_frames_info(self, file_path, fields): + """Helper function to get frame info (pts, dts, etc.) using ffprobe.""" + result = subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-select_streams", + "v:0", + "-show_entries", + f"frame={','.join(fields)}", + "-of", + "json", + str(file_path), + ], + capture_output=True, + check=True, + text=True, + ) + frames = json.loads(result.stdout)["frames"] + assert all(field in frame for field in fields for frame in frames) + return frames + @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) def test_bad_input_parameterized(self, tmp_path, method): if method == "to_file": @@ -826,26 +857,25 @@ def test_round_trip(self, tmp_path, format, method): ffmpeg_version == 4 or (IS_WINDOWS and ffmpeg_version in (6, 7)) ): pytest.skip("Codec for webm is not available in this FFmpeg installation.") - source_frames = self.decode(TEST_SRC_2_720P.path).data + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) - # Frame rate is fixed with num frames decoded - encoder = VideoEncoder(frames=source_frames, frame_rate=30) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) if method == "to_file": encoded_path = str(tmp_path / f"encoder_output.{format}") encoder.to_file(dest=encoded_path, pixel_format="yuv444p", crf=0) - round_trip_frames = self.decode(encoded_path).data + round_trip_frames = self.decode(encoded_path) elif method == "to_tensor": encoded_tensor = encoder.to_tensor( format=format, pixel_format="yuv444p", crf=0 ) - round_trip_frames = self.decode(encoded_tensor).data + round_trip_frames = self.decode(encoded_tensor) elif method == "to_file_like": file_like = io.BytesIO() encoder.to_file_like( file_like=file_like, format=format, pixel_format="yuv444p", crf=0 ) - round_trip_frames = self.decode(file_like.getvalue()).data + round_trip_frames = self.decode(file_like.getvalue()) else: raise ValueError(f"Unknown method: {method}") @@ -878,8 +908,8 @@ def test_against_to_file(self, tmp_path, format, method): ): pytest.skip("Codec for webm is not available in this FFmpeg installation.") - source_frames = self.decode(TEST_SRC_2_720P.path).data - encoder = VideoEncoder(frames=source_frames, frame_rate=30) + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) encoded_file = tmp_path / f"output.{format}" encoder.to_file(dest=encoded_file, crf=0) @@ -892,8 +922,8 @@ def test_against_to_file(self, tmp_path, format, method): encoded_output = file_like.getvalue() torch.testing.assert_close( - self.decode(encoded_file).data, - self.decode(encoded_output).data, + self.decode(encoded_file), + self.decode(encoded_output), atol=0, rtol=0, ) @@ -920,8 +950,9 @@ def test_against_to_file(self, tmp_path, format, method): ], ) @pytest.mark.parametrize("method", ("to_file", "to_tensor", "to_file_like")) + @pytest.mark.parametrize("frame_rate", [30, 29.97]) def test_video_encoder_against_ffmpeg_cli( - self, tmp_path, format, encode_params, method + self, tmp_path, format, encode_params, method, frame_rate ): ffmpeg_version = get_ffmpeg_major_version() if format == "webm" and ( @@ -936,7 +967,7 @@ def test_video_encoder_against_ffmpeg_cli( if format in ("avi", "flv") and pixel_format == "yuv444p": pytest.skip(f"Default codec for {format} does not support {pixel_format}") - source_frames = self.decode(TEST_SRC_2_720P.path).data + source_frames = self.decode(TEST_SRC_2_720P.path) # Encode with FFmpeg CLI temp_raw_path = str(tmp_path / "temp_input.raw") @@ -944,7 +975,6 @@ def test_video_encoder_against_ffmpeg_cli( f.write(source_frames.permute(0, 2, 3, 1).cpu().numpy().tobytes()) ffmpeg_encoded_path = str(tmp_path / f"ffmpeg_output.{format}") - frame_rate = 30 # Some codecs (ex. MPEG4) do not support CRF or preset. # Flags not supported by the selected codec will be ignored. ffmpeg_cmd = [ @@ -983,7 +1013,7 @@ def test_video_encoder_against_ffmpeg_cli( crf=crf, preset=preset, ) - encoder_frames = self.decode(encoder_output_path).data + encoder_frames = self.decode(encoder_output_path) elif method == "to_tensor": encoded_output = encoder.to_tensor( format=format, @@ -991,7 +1021,7 @@ def test_video_encoder_against_ffmpeg_cli( crf=crf, preset=preset, ) - encoder_frames = self.decode(encoded_output).data + encoder_frames = self.decode(encoded_output) elif method == "to_file_like": file_like = io.BytesIO() encoder.to_file_like( @@ -1001,7 +1031,7 @@ def test_video_encoder_against_ffmpeg_cli( crf=crf, preset=preset, ) - encoder_frames = self.decode(file_like.getvalue()).data + encoder_frames = self.decode(file_like.getvalue()) else: raise ValueError(f"Unknown method: {method}") @@ -1018,9 +1048,16 @@ def test_video_encoder_against_ffmpeg_cli( ff_frame, enc_frame, percentage=percentage, atol=2 ) - # Check that video metadata is the same - if method == "to_file": - fields = ["duration", "duration_ts", "r_frame_rate", "nb_frames"] + # Only compare video metadata on ffmpeg versions >= 6, as older versions + # are often missing metadata + if ffmpeg_version >= 6 and method == "to_file": + fields = [ + "duration", + "duration_ts", + "r_frame_rate", + "time_base", + "nb_frames", + ] ffmpeg_metadata = self._get_video_metadata( ffmpeg_encoded_path, fields=fields, @@ -1031,6 +1068,18 @@ def test_video_encoder_against_ffmpeg_cli( ) assert ffmpeg_metadata == encoder_metadata + # Check that frame timestamps and duration are the same + fields = ("pts", "pts_time") + if format != "flv": + fields += ("duration", "duration_time") + ffmpeg_frames_info = self._get_frames_info( + ffmpeg_encoded_path, fields=fields + ) + encoder_frames_info = self._get_frames_info( + encoder_output_path, fields=fields + ) + assert ffmpeg_frames_info == encoder_frames_info + def test_to_file_like_custom_file_object(self): """Test to_file_like with a custom file-like object that implements write and seek.""" @@ -1047,15 +1096,15 @@ def seek(self, offset, whence=0): def get_encoded_data(self): return self._file.getvalue() - source_frames = self.decode(TEST_SRC_2_720P.path).data - encoder = VideoEncoder(frames=source_frames, frame_rate=30) + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) file_like = CustomFileObject() encoder.to_file_like(file_like, format="mp4", pixel_format="yuv444p", crf=0) decoded_frames = self.decode(file_like.get_encoded_data()) torch.testing.assert_close( - decoded_frames.data, + decoded_frames, source_frames, atol=2, rtol=0, @@ -1063,8 +1112,8 @@ def get_encoded_data(self): def test_to_file_like_real_file(self, tmp_path): """Test to_file_like with a real file opened in binary write mode.""" - source_frames = self.decode(TEST_SRC_2_720P.path).data - encoder = VideoEncoder(frames=source_frames, frame_rate=30) + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) file_path = tmp_path / "test_file_like.mp4" @@ -1073,15 +1122,15 @@ def test_to_file_like_real_file(self, tmp_path): decoded_frames = self.decode(str(file_path)) torch.testing.assert_close( - decoded_frames.data, + decoded_frames, source_frames, atol=2, rtol=0, ) def test_to_file_like_bad_methods(self): - source_frames = self.decode(TEST_SRC_2_720P.path).data - encoder = VideoEncoder(frames=source_frames, frame_rate=30) + source_frames, frame_rate = self.decode_and_get_frame_rate(TEST_SRC_2_720P.path) + encoder = VideoEncoder(frames=source_frames, frame_rate=frame_rate) class NoWriteMethod: def seek(self, offset, whence=0): @@ -1174,8 +1223,8 @@ def test_codec_spec_vs_impl_equivalence(self, tmp_path, codec_spec, codec_impl): == codec_spec ) - frames_spec = self.decode(spec_output).data - frames_impl = self.decode(impl_output).data + frames_spec = self.decode(spec_output) + frames_impl = self.decode(impl_output) torch.testing.assert_close(frames_spec, frames_impl, rtol=0, atol=0) @pytest.mark.skipif(in_fbcode(), reason="ffprobe not available")