diff --git a/benchmarks/decoders/benchmark_audio_decoders.py b/benchmarks/decoders/benchmark_audio_decoders.py index f10da1204..bfffbf12f 100644 --- a/benchmarks/decoders/benchmark_audio_decoders.py +++ b/benchmarks/decoders/benchmark_audio_decoders.py @@ -71,7 +71,7 @@ def get_duration(path: Path) -> str: def decode_with_torchcodec(path: Path) -> None: - AudioDecoder(path).get_samples_played_in_range(start_seconds=0, stop_seconds=None) + AudioDecoder(path).get_all_samples() def decode_with_torchaudio_StreamReader(path: Path) -> None: diff --git a/examples/audio_decoding.py b/examples/audio_decoding.py index 01509f291..c2762a104 100644 --- a/examples/audio_decoding.py +++ b/examples/audio_decoding.py @@ -59,10 +59,10 @@ def play_audio(samples): # ---------------- # # To get decoded samples, we just need to call the -# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method, +# :meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` method, # which returns an :class:`~torchcodec.AudioSamples` object: -samples = decoder.get_samples_played_in_range() +samples = decoder.get_all_samples() print(samples) play_audio(samples) @@ -80,9 +80,9 @@ def play_audio(samples): # Specifying a range # ------------------ # -# By default, -# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` decodes -# the entire audio stream, but we can specify a custom range: +# If we don't need all the samples, we can use +# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` to +# decode the samples within a custom range: samples = decoder.get_samples_played_in_range(start_seconds=10, stop_seconds=70) @@ -99,7 +99,7 @@ def play_audio(samples): # increased: decoder = AudioDecoder(raw_audio_bytes, sample_rate=16_000) -samples = decoder.get_samples_played_in_range(start_seconds=0) +samples = decoder.get_all_samples() print(samples) play_audio(samples) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index d0e7ede00..80fd87e54 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -75,6 +75,17 @@ def __init__( sample_rate if sample_rate is not None else self.metadata.sample_rate ) + def get_all_samples(self) -> AudioSamples: + """Returns all the audio samples from the source. + + To decode samples in a specific range, use + :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range`. + + Returns: + AudioSamples: The samples within the file. + """ + return self.get_samples_played_in_range() + def get_samples_played_in_range( self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None ) -> AudioSamples: @@ -82,11 +93,18 @@ def get_samples_played_in_range( Samples are in the half open range [start_seconds, stop_seconds). + To decode all the samples from beginning to end, you can call this + method while leaving ``start_seconds`` and ``stop_seconds`` to their + default values, or use + :meth:`~torchcodec.decoders.AudioDecoder.get_all_samples` as a more + convenient alias. + Args: start_seconds (float): Time, in seconds, of the start of the range. Default: 0. - stop_seconds (float): Time, in seconds, of the end of the - range. As a half open range, the end is excluded. + stop_seconds (float or None): Time, in seconds, of the end of the + range. As a half open range, the end is excluded. Default: None, + which decodes samples until the end. Returns: AudioSamples: The samples within the specified range. diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index c480ed3ea..4d3e2f2ce 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -239,7 +239,6 @@ class VideoDecoder { double startSeconds, double stopSeconds); - // TODO-AUDIO: Should accept sampleRate AudioFramesOutput getFramesPlayedInRangeAudio( double startSeconds, std::optional stopSecondsOptional = std::nullopt); diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index e8dfc675c..cc47e116d 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -982,7 +982,7 @@ def test_negative_start(self, asset): @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) @pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999)) - def test_get_all_samples(self, asset, stop_seconds): + def test_get_all_samples_with_range(self, asset, stop_seconds): decoder = AudioDecoder(asset.path) if stop_seconds == "duration": @@ -998,6 +998,14 @@ def test_get_all_samples(self, asset, stop_seconds): assert samples.sample_rate == asset.sample_rate assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_get_all_samples(self, asset): + decoder = AudioDecoder(asset.path) + torch.testing.assert_close( + decoder.get_all_samples().data, + decoder.get_samples_played_in_range().data, + ) + @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) def test_at_frame_boundaries(self, asset): decoder = AudioDecoder(asset.path)