diff --git a/examples/audio_decoding.py b/examples/audio_decoding.py index 89c5f34ed..62b26c554 100644 --- a/examples/audio_decoding.py +++ b/examples/audio_decoding.py @@ -76,12 +76,6 @@ def play_audio(samples): # all streams start exactly at 0! This is not a bug in TorchCodec, this is a # property of the file that was defined when it was encoded. # -# We only output the *start* of the samples, not the end or the duration. Those can -# be easily derived from the number of samples and the sample rate: - -duration_seconds = samples.data.shape[1] / samples.sample_rate -print(f"Duration = {int(duration_seconds // 60)}m{int(duration_seconds % 60)}s.") - # %% # Specifying a range # ------------------ diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index 958db82fa..525c7ac8b 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -124,6 +124,8 @@ class AudioSamples(Iterable): """The sample data (``torch.Tensor`` of float in [-1, 1], shape is ``(num_channels, num_samples)``).""" pts_seconds: float """The :term:`pts` of the first sample, in seconds.""" + duration_seconds: float + """The duration of the sampleas, in seconds.""" sample_rate: int """The sample rate of the samples, in Hz.""" diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index cdf16dc82..bd91c4e22 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -139,8 +139,10 @@ def get_samples_played_in_range( else: offset_end = num_samples + data = frames[:, offset_beginning:offset_end] return AudioSamples( - data=frames[:, offset_beginning:offset_end], + data=data, pts_seconds=output_pts_seconds, + duration_seconds=data.shape[1] / sample_rate, sample_rate=sample_rate, ) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 0f19d680f..3e82106c0 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -993,7 +993,6 @@ def test_get_all_samples(self, asset, stop_seconds): torch.testing.assert_close(samples.data, reference_frames) assert samples.sample_rate == asset.sample_rate - assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) @@ -1215,3 +1214,10 @@ def test_s16_ffmpeg4_bug(self): ) with cm: decoder.get_samples_played_in_range(start_seconds=0) + + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + @pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_1000)) + def test_samples_duration(self, asset, sample_rate): + decoder = AudioDecoder(asset.path, sample_rate=sample_rate) + samples = decoder.get_samples_played_in_range(start_seconds=1, stop_seconds=2) + assert samples.duration_seconds == 1 diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py index e014cd4d3..73e06d208 100644 --- a/test/test_frame_dataclasses.py +++ b/test/test_frame_dataclasses.py @@ -5,7 +5,9 @@ def test_unpacking(): data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa - data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000) + data, pts_seconds, duration_seconds, sample_rate = AudioSamples( + torch.rand(2, 4), 2, 3, 16_000 + ) def test_frame_error(): @@ -147,11 +149,13 @@ def test_audio_samples_error(): AudioSamples( data=torch.rand(1), pts_seconds=1, + duration_seconds=1, sample_rate=16_000, ) with pytest.raises(ValueError, match="data must be 2-dimensional"): AudioSamples( data=torch.rand(1, 2, 3), pts_seconds=1, + duration_seconds=1, sample_rate=16_000, )