diff --git a/examples/audio_decoding.py b/examples/audio_decoding.py index 62b26c554..01509f291 100644 --- a/examples/audio_decoding.py +++ b/examples/audio_decoding.py @@ -62,7 +62,7 @@ def play_audio(samples): # :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method, # which returns an :class:`~torchcodec.AudioSamples` object: -samples = decoder.get_samples_played_in_range(start_seconds=0) +samples = decoder.get_samples_played_in_range() print(samples) play_audio(samples) diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index bd91c4e22..cac968e5d 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -70,9 +70,8 @@ def __init__( sample_rate if sample_rate is not None else self.metadata.sample_rate ) - # TODO-AUDIO: start_seconds should be 0 by default def get_samples_played_in_range( - self, start_seconds: float, stop_seconds: Optional[float] = None + self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None ) -> AudioSamples: """Returns audio samples in the given range. @@ -80,7 +79,7 @@ def get_samples_played_in_range( Args: start_seconds (float): Time, in seconds, of the start of the - range. + range. Default: 0. stop_seconds (float): Time, in seconds, of the end of the range. As a half open range, the end is excluded. diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 3e82106c0..0442e1df7 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -983,9 +983,7 @@ def test_get_all_samples(self, asset, stop_seconds): if stop_seconds == "duration": stop_seconds = asset.duration_seconds - samples = decoder.get_samples_played_in_range( - start_seconds=0, stop_seconds=stop_seconds - ) + samples = decoder.get_samples_played_in_range(stop_seconds=stop_seconds) reference_frames = asset.get_frame_data_by_range( start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1 @@ -1078,7 +1076,7 @@ def test_single_channel(self): asset = SINE_MONO_S32 decoder = AudioDecoder(asset.path) - samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=2) + samples = decoder.get_samples_played_in_range(stop_seconds=2) assert samples.data.shape[0] == asset.num_channels == 1 def test_format_conversion(self): @@ -1086,7 +1084,7 @@ def test_format_conversion(self): decoder = AudioDecoder(asset.path) assert decoder.metadata.sample_format == asset.sample_format == "s32" - all_samples = decoder.get_samples_played_in_range(start_seconds=0) + all_samples = decoder.get_samples_played_in_range() assert all_samples.data.dtype == torch.float32 reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames) @@ -1163,7 +1161,7 @@ def test_sample_rate_conversion_stereo(self): assert asset.sample_rate == 8000 assert asset.num_channels == 2 decoder = AudioDecoder(asset.path, sample_rate=44_100) - decoder.get_samples_played_in_range(start_seconds=0) + decoder.get_samples_played_in_range() def test_downsample_empty_frame(self): # Non-regression test for @@ -1183,13 +1181,13 @@ def test_downsample_empty_frame(self): asset = NASA_AUDIO_MP3_44100 assert asset.sample_rate == 44_100 decoder = AudioDecoder(asset.path, sample_rate=8_000) - frames_44100_to_8000 = decoder.get_samples_played_in_range(start_seconds=0) + frames_44100_to_8000 = decoder.get_samples_played_in_range() # Just checking correctness now asset = NASA_AUDIO_MP3 assert asset.sample_rate == 8_000 decoder = AudioDecoder(asset.path) - frames_8000 = decoder.get_samples_played_in_range(start_seconds=0) + frames_8000 = decoder.get_samples_played_in_range() torch.testing.assert_close( frames_44100_to_8000.data, frames_8000.data, atol=0.03, rtol=0 ) @@ -1213,7 +1211,7 @@ def test_s16_ffmpeg4_bug(self): else contextlib.nullcontext() ) with cm: - decoder.get_samples_played_in_range(start_seconds=0) + decoder.get_samples_played_in_range() @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) @pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_1000))