Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/audio_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def play_audio(samples):
# :meth:`~torchcodec.decoders.AudioDecoder.get_samples_played_in_range` method,
# which returns an :class:`~torchcodec.AudioSamples` object:

samples = decoder.get_samples_played_in_range(start_seconds=0)
samples = decoder.get_samples_played_in_range()

print(samples)
play_audio(samples)
Expand Down
5 changes: 2 additions & 3 deletions src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,17 +70,16 @@ def __init__(
sample_rate if sample_rate is not None else self.metadata.sample_rate
)

# TODO-AUDIO: start_seconds should be 0 by default
def get_samples_played_in_range(
self, start_seconds: float, stop_seconds: Optional[float] = None
self, start_seconds: float = 0.0, stop_seconds: Optional[float] = None
) -> AudioSamples:
"""Returns audio samples in the given range.

Samples are in the half open range [start_seconds, stop_seconds).

Args:
start_seconds (float): Time, in seconds, of the start of the
range.
range. Default: 0.
stop_seconds (float): Time, in seconds, of the end of the
range. As a half open range, the end is excluded.

Expand Down
16 changes: 7 additions & 9 deletions test/decoders/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -983,9 +983,7 @@ def test_get_all_samples(self, asset, stop_seconds):
if stop_seconds == "duration":
stop_seconds = asset.duration_seconds

samples = decoder.get_samples_played_in_range(
start_seconds=0, stop_seconds=stop_seconds
)
samples = decoder.get_samples_played_in_range(stop_seconds=stop_seconds)

reference_frames = asset.get_frame_data_by_range(
start=0, stop=asset.get_frame_index(pts_seconds=asset.duration_seconds) + 1
Expand Down Expand Up @@ -1078,15 +1076,15 @@ def test_single_channel(self):
asset = SINE_MONO_S32
decoder = AudioDecoder(asset.path)

samples = decoder.get_samples_played_in_range(start_seconds=0, stop_seconds=2)
samples = decoder.get_samples_played_in_range(stop_seconds=2)
assert samples.data.shape[0] == asset.num_channels == 1

def test_format_conversion(self):
asset = SINE_MONO_S32
decoder = AudioDecoder(asset.path)
assert decoder.metadata.sample_format == asset.sample_format == "s32"

all_samples = decoder.get_samples_played_in_range(start_seconds=0)
all_samples = decoder.get_samples_played_in_range()
assert all_samples.data.dtype == torch.float32

reference_frames = asset.get_frame_data_by_range(start=0, stop=asset.num_frames)
Expand Down Expand Up @@ -1163,7 +1161,7 @@ def test_sample_rate_conversion_stereo(self):
assert asset.sample_rate == 8000
assert asset.num_channels == 2
decoder = AudioDecoder(asset.path, sample_rate=44_100)
decoder.get_samples_played_in_range(start_seconds=0)
decoder.get_samples_played_in_range()

def test_downsample_empty_frame(self):
# Non-regression test for
Expand All @@ -1183,13 +1181,13 @@ def test_downsample_empty_frame(self):
asset = NASA_AUDIO_MP3_44100
assert asset.sample_rate == 44_100
decoder = AudioDecoder(asset.path, sample_rate=8_000)
frames_44100_to_8000 = decoder.get_samples_played_in_range(start_seconds=0)
frames_44100_to_8000 = decoder.get_samples_played_in_range()

# Just checking correctness now
asset = NASA_AUDIO_MP3
assert asset.sample_rate == 8_000
decoder = AudioDecoder(asset.path)
frames_8000 = decoder.get_samples_played_in_range(start_seconds=0)
frames_8000 = decoder.get_samples_played_in_range()
torch.testing.assert_close(
frames_44100_to_8000.data, frames_8000.data, atol=0.03, rtol=0
)
Expand All @@ -1213,7 +1211,7 @@ def test_s16_ffmpeg4_bug(self):
else contextlib.nullcontext()
)
with cm:
decoder.get_samples_played_in_range(start_seconds=0)
decoder.get_samples_played_in_range()

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
@pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_1000))
Expand Down
Loading