From 30a4754e799ca515dbfa5c70d606f1ac3071f1e5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 13 Jun 2025 15:14:43 +0100 Subject: [PATCH 01/10] Encoder: validate pts of encoded frames --- src/torchcodec/_frame.py | 2 +- test/test_encoders.py | 53 +++++++++++++++++++++++++++++++++++++++- 2 files changed, 53 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 525c7ac8b..50eb12923 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -125,7 +125,7 @@ class AudioSamples(Iterable): pts_seconds: float """The :term:`pts` of the first sample, in seconds.""" duration_seconds: float - """The duration of the sampleas, in seconds.""" + """The duration of the samples, in seconds.""" sample_rate: int """The sample rate of the samples, in Hz.""" diff --git a/test/test_encoders.py b/test/test_encoders.py index bf5f9cc6b..927f321e2 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,5 +1,7 @@ +import json import re import subprocess +from pathlib import Path import pytest import torch @@ -16,6 +18,49 @@ ) +def validate_frames_properties(*, actual: Path, expected: Path): + + frames_actual, frames_expected = ( + json.loads( + subprocess.run( + [ + "ffprobe", + "-v", + "error", + "-hide_banner", + "-select_streams", + "a:0", + "-show_frames", + "-of", + "json", + f"{f}", + ], + check=True, + capture_output=True, + text=True, + ).stdout + )["frames"] + for f in (actual, expected) + ) + + # frames_actual and frames_expected are both a list of dicts, each dict + # corresponds to a frame and each key-value pair corresponds to a frame + # property like pts, nb_samples, etc., similar to the AVFrame fields. + assert isinstance(frames_actual, list) + assert all(isinstance(d, dict) for d in frames_actual) + + assert len(frames_actual) == len(frames_expected) + for frame_index, (d_actual, d_expected) in enumerate( + zip(frames_actual, frames_expected) + ): + for prop in d_actual: + if prop == "pkt_pos": + continue # TODO this probably matters + assert ( + d_actual[prop] == d_expected[prop] + ), f"{prop} value is different for frame {frame_index}:" + + class TestAudioEncoder: def decode(self, source) -> torch.Tensor: @@ -162,13 +207,19 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa rtol, atol = 0, 1e-3 else: rtol, atol = None, None + # TODO should validate `.pts_seconds` and `duration_seconds` as well torch.testing.assert_close( - self.decode(encoded_by_ffmpeg), self.decode(encoded_by_us), + self.decode(encoded_by_ffmpeg), rtol=rtol, atol=atol, ) + if method == "to_file": + validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg) + else: + assert method == "to_tensor", "wrong test parametrization!" + @pytest.mark.parametrize("asset", (NASA_AUDIO_MP3, SINE_MONO_S32)) @pytest.mark.parametrize("bit_rate", (None, 0, 44_100, 999_999_999)) @pytest.mark.parametrize("num_channels", (None, 1, 2)) From 9d48ff6ad2e0963af802ba4e6c186ec568590b5f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 13 Jun 2025 15:22:36 +0100 Subject: [PATCH 02/10] Also check fields of AudioSamples --- test/test_encoders.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index 927f321e2..da2a185dc 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -53,7 +53,7 @@ def validate_frames_properties(*, actual: Path, expected: Path): for frame_index, (d_actual, d_expected) in enumerate( zip(frames_actual, frames_expected) ): - for prop in d_actual: + for prop in d_expected: if prop == "pkt_pos": continue # TODO this probably matters assert ( @@ -66,7 +66,7 @@ class TestAudioEncoder: def decode(self, source) -> torch.Tensor: if isinstance(source, TestContainerFile): source = str(source.path) - return AudioDecoder(source).get_all_samples().data + return AudioDecoder(source).get_all_samples() def test_bad_input(self): with pytest.raises(ValueError, match="Expected samples to be a Tensor"): @@ -108,12 +108,12 @@ def test_bad_input_parametrized(self, method, tmp_path): else dict(format="mp3") ) - decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3), sample_rate=10) + decoder = AudioEncoder(self.decode(NASA_AUDIO_MP3).data, sample_rate=10) with pytest.raises(RuntimeError, match="invalid sample rate=10"): getattr(decoder, method)(**valid_params) decoder = AudioEncoder( - self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate + self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate ) with pytest.raises(RuntimeError, match="bit_rate=-1 must be >= 0"): getattr(decoder, method)(**valid_params, bit_rate=-1) @@ -126,7 +126,7 @@ def test_bad_input_parametrized(self, method, tmp_path): getattr(decoder, method)(**valid_params) decoder = AudioEncoder( - self.decode(NASA_AUDIO_MP3), sample_rate=NASA_AUDIO_MP3.sample_rate + self.decode(NASA_AUDIO_MP3).data, sample_rate=NASA_AUDIO_MP3.sample_rate ) for num_channels in (0, 3): with pytest.raises( @@ -146,7 +146,7 @@ def test_round_trip(self, method, format, tmp_path): pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") asset = NASA_AUDIO_MP3 - source_samples = self.decode(asset) + source_samples = self.decode(asset).data encoder = AudioEncoder(source_samples, sample_rate=asset.sample_rate) @@ -161,7 +161,7 @@ def test_round_trip(self, method, format, tmp_path): rtol, atol = (0, 1e-4) if format == "wav" else (None, None) torch.testing.assert_close( - self.decode(encoded_source), source_samples, rtol=rtol, atol=atol + self.decode(encoded_source).data, source_samples, rtol=rtol, atol=atol ) @pytest.mark.skipif(in_fbcode(), reason="TODO: enable ffmpeg CLI") @@ -189,7 +189,7 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa check=True, ) - encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate) + encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) params = dict(bit_rate=bit_rate, num_channels=num_channels) if method == "to_file": encoded_by_us = tmp_path / f"output.{format}" @@ -207,13 +207,17 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa rtol, atol = 0, 1e-3 else: rtol, atol = None, None - # TODO should validate `.pts_seconds` and `duration_seconds` as well + samples_by_us = self.decode(encoded_by_us) + samples_by_ffmpeg = self.decode(encoded_by_ffmpeg) torch.testing.assert_close( - self.decode(encoded_by_us), - self.decode(encoded_by_ffmpeg), + samples_by_us.data, + samples_by_ffmpeg.data, rtol=rtol, atol=atol, ) + assert samples_by_us.pts_seconds == samples_by_ffmpeg.pts_seconds + assert samples_by_us.duration_seconds == samples_by_ffmpeg.duration_seconds + assert samples_by_us.sample_rate == samples_by_ffmpeg.sample_rate if method == "to_file": validate_frames_properties(actual=encoded_by_us, expected=encoded_by_ffmpeg) @@ -230,7 +234,7 @@ def test_to_tensor_against_to_file( if get_ffmpeg_major_version() == 4 and format == "wav": pytest.skip("Swresample with FFmpeg 4 doesn't work on wav files") - encoder = AudioEncoder(self.decode(asset), sample_rate=asset.sample_rate) + encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) params = dict(bit_rate=bit_rate, num_channels=num_channels) encoded_file = tmp_path / f"output.{format}" @@ -240,7 +244,7 @@ def test_to_tensor_against_to_file( ) torch.testing.assert_close( - self.decode(encoded_file), self.decode(encoded_tensor) + self.decode(encoded_file).data, self.decode(encoded_tensor).data ) def test_encode_to_tensor_long_output(self): @@ -256,7 +260,7 @@ def test_encode_to_tensor_long_output(self): INITIAL_TENSOR_SIZE = 10_000_000 assert encoded_tensor.numel() > INITIAL_TENSOR_SIZE - torch.testing.assert_close(self.decode(encoded_tensor), samples) + torch.testing.assert_close(self.decode(encoded_tensor).data, samples) def test_contiguity(self): # Ensure that 2 waveforms with the same values are encoded in the same @@ -313,4 +317,4 @@ def test_num_channels( if num_channels_output is None: num_channels_output = num_channels_input - assert self.decode(encoded_source).shape[0] == num_channels_output + assert self.decode(encoded_source).data.shape[0] == num_channels_output From 020556c67f6db5e1b5d8be8d6dbf926d6b378f51 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 13 Jun 2025 15:29:35 +0100 Subject: [PATCH 03/10] edit message on failure --- test/test_encoders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index da2a185dc..fc57b62ba 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -58,7 +58,7 @@ def validate_frames_properties(*, actual: Path, expected: Path): continue # TODO this probably matters assert ( d_actual[prop] == d_expected[prop] - ), f"{prop} value is different for frame {frame_index}:" + ), f"\nComparing: {actual}\nagainst reference: {expected},\nthe {prop} property is different at frame {frame_index}:" class TestAudioEncoder: From 16b902daf0d1216827059c25b05a5f0e32ef2eb5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 11:08:00 +0100 Subject: [PATCH 04/10] Add tests --- test/test_encoders.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/test/test_encoders.py b/test/test_encoders.py index fc57b62ba..579e7f336 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -49,10 +49,24 @@ def validate_frames_properties(*, actual: Path, expected: Path): assert isinstance(frames_actual, list) assert all(isinstance(d, dict) for d in frames_actual) + assert len(frames_actual) > 3 # arbitrary sanity check assert len(frames_actual) == len(frames_expected) + + # non-exhaustive list of the props we want to test for: + required_props = ( + "pts", + "pts_time", + "sample_fmt", + "nb_samples", + "channels", + "duration", + "duration_time", + ) + for frame_index, (d_actual, d_expected) in enumerate( zip(frames_actual, frames_expected) ): + assert all(required_prop in d_actual for required_prop in required_props) for prop in d_expected: if prop == "pkt_pos": continue # TODO this probably matters From 2c9eb1f2b2166f2e8abee3831fa0b5261c4e3872 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 11:33:27 +0100 Subject: [PATCH 05/10] Explicitly set frames pts values --- src/torchcodec/_core/Encoder.cpp | 3 +-- src/torchcodec/_core/FFMPEGCommon.cpp | 1 + 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index d23ecd5f3..905794328 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -293,8 +293,7 @@ void AudioEncoder::encode() { encodeInnerLoop(autoAVPacket, convertedAVFrame); numEncodedSamples += numSamplesToEncode; - // TODO-ENCODING set frame pts correctly, and test against it. - // avFrame->pts += static_cast(numSamplesToEncode); + avFrame->pts += static_cast(numSamplesToEncode); } TORCH_CHECK(numEncodedSamples == numSamples, "Hmmmmmm something went wrong."); diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 2609caf3e..f47b42754 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -274,6 +274,7 @@ UniqueAVFrame convertAudioAVFrameSamples( convertedAVFrame, "Could not allocate frame for sample format conversion."); + convertedAVFrame->pts = srcAVFrame->pts; convertedAVFrame->format = static_cast(outSampleFormat); convertedAVFrame->sample_rate = outSampleRate; From 079de415f99829cd4be026b051d0d65035017d9f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 12:05:13 +0100 Subject: [PATCH 06/10] Add comment --- test/test_encoders.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index 579e7f336..cfc936cb3 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -69,7 +69,13 @@ def validate_frames_properties(*, actual: Path, expected: Path): assert all(required_prop in d_actual for required_prop in required_props) for prop in d_expected: if prop == "pkt_pos": - continue # TODO this probably matters + # pkt_pos is the position of the packet *in bytes* in its + # stream. We don't always match FFmpeg exactly on this, + # typically on compressed formats like mp3. It's probably + # because we are not writing the exact same headers, or + # something like this. In any case, this doesn't seem to be + # critical. + continue assert ( d_actual[prop] == d_expected[prop] ), f"\nComparing: {actual}\nagainst reference: {expected},\nthe {prop} property is different at frame {frame_index}:" From b90927b76195a3420a3aa273f4a766dd0179ecad Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 12:33:33 +0100 Subject: [PATCH 07/10] Add test for FFmpeg logs --- test/test_encoders.py | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index cfc936cb3..fb87b6b1b 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -1,4 +1,5 @@ import json +import os import re import subprocess from pathlib import Path @@ -18,6 +19,15 @@ ) +@pytest.fixture +def with_ffmpeg_debug_logs(): + # Fixture that sets the ffmpeg logs to DEBUG mode + previous_log_level = os.environ.get("TORCHCODEC_FFMPEG_LOG_LEVEL", "QUIET") + os.environ["TORCHCODEC_FFMPEG_LOG_LEVEL"] = "DEBUG" + yield + os.environ["TORCHCODEC_FFMPEG_LOG_LEVEL"] = previous_log_level + + def validate_frames_properties(*, actual: Path, expected: Path): frames_actual, frames_expected = ( @@ -190,7 +200,17 @@ def test_round_trip(self, method, format, tmp_path): @pytest.mark.parametrize("num_channels", (None, 1, 2)) @pytest.mark.parametrize("format", ("mp3", "wav", "flac")) @pytest.mark.parametrize("method", ("to_file", "to_tensor")) - def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_path): + def test_against_cli( + self, + asset, + bit_rate, + num_channels, + format, + method, + tmp_path, + capfd, + with_ffmpeg_debug_logs, + ): # Encodes samples with our encoder and with the FFmpeg CLI, and checks # that both decoded outputs are equal @@ -210,6 +230,7 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa ) encoder = AudioEncoder(self.decode(asset).data, sample_rate=asset.sample_rate) + params = dict(bit_rate=bit_rate, num_channels=num_channels) if method == "to_file": encoded_by_us = tmp_path / f"output.{format}" @@ -217,6 +238,16 @@ def test_against_cli(self, asset, bit_rate, num_channels, format, method, tmp_pa else: encoded_by_us = encoder.to_tensor(format=format, **params) + captured = capfd.readouterr() + if format == "wav": + assert "Timestamps are unset in a packet" not in captured.err + if format == "mp3": + assert "Queue input is backward in time" not in captured.err + if format in ("flac", "wav"): + assert "Encoder did not produce proper pts" not in captured.err + if format in ("flac", "mp3"): + assert "Application provided invalid" not in captured.err + if format == "wav": rtol, atol = 0, 1e-4 elif format == "mp3" and asset is SINE_MONO_S32 and num_channels == 2: From 225ac51ab0a1f58166a2f691ae5ed0c88ffab5d5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 12:38:28 +0100 Subject: [PATCH 08/10] comments --- test/test_encoders.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index fb87b6b1b..81b657cbb 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -29,6 +29,9 @@ def with_ffmpeg_debug_logs(): def validate_frames_properties(*, actual: Path, expected: Path): + # actual and expected are files containing encoded audio data. We call + # `ffprobe` on both, and assert that the frame properties match (pts, + # duration, etc.) frames_actual, frames_expected = ( json.loads( @@ -76,7 +79,7 @@ def validate_frames_properties(*, actual: Path, expected: Path): for frame_index, (d_actual, d_expected) in enumerate( zip(frames_actual, frames_expected) ): - assert all(required_prop in d_actual for required_prop in required_props) + assert all(required_prop in d_expected for required_prop in required_props) for prop in d_expected: if prop == "pkt_pos": # pkt_pos is the position of the packet *in bytes* in its From 9500ebbb80ba5f524b1a188c954c5eb9a2415535 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 12:50:16 +0100 Subject: [PATCH 09/10] Debug --- test/test_encoders.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index 81b657cbb..1eddb7c58 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -79,7 +79,10 @@ def validate_frames_properties(*, actual: Path, expected: Path): for frame_index, (d_actual, d_expected) in enumerate( zip(frames_actual, frames_expected) ): - assert all(required_prop in d_expected for required_prop in required_props) + # assert all(required_prop in d_expected for required_prop in required_props) + for prop in required_props: + assert prop in d_expected, f"{prop} not in {d_expected.keys()}" + for prop in d_expected: if prop == "pkt_pos": # pkt_pos is the position of the packet *in bytes* in its From 7e6b9b48bc41d63678294929e3dc5e1d5a470b2c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 3 Jul 2025 12:59:28 +0100 Subject: [PATCH 10/10] Fix ffmpeg version stuff --- test/test_encoders.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/test/test_encoders.py b/test/test_encoders.py index 1eddb7c58..bac5d91b9 100644 --- a/test/test_encoders.py +++ b/test/test_encoders.py @@ -79,9 +79,8 @@ def validate_frames_properties(*, actual: Path, expected: Path): for frame_index, (d_actual, d_expected) in enumerate( zip(frames_actual, frames_expected) ): - # assert all(required_prop in d_expected for required_prop in required_props) - for prop in required_props: - assert prop in d_expected, f"{prop} not in {d_expected.keys()}" + if get_ffmpeg_major_version() >= 6: + assert all(required_prop in d_expected for required_prop in required_props) for prop in d_expected: if prop == "pkt_pos":