Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions examples/audio_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ def play_audio(samples):
# all streams start exactly at 0! This is not a bug in TorchCodec, this is a
# property of the file that was defined when it was encoded.
#
# We only output the *start* of the samples, not the end or the duration. Those can
# be easily derived from the number of samples and the sample rate:

duration_seconds = samples.data.shape[1] / samples.sample_rate
print(f"Duration = {int(duration_seconds // 60)}m{int(duration_seconds % 60)}s.")

# %%
# Specifying a range
# ------------------
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ class AudioSamples(Iterable):
"""The sample data (``torch.Tensor`` of float in [-1, 1], shape is ``(num_channels, num_samples)``)."""
pts_seconds: float
"""The :term:`pts` of the first sample, in seconds."""
duration_seconds: float
"""The duration of the samples, in seconds."""
sample_rate: int
"""The sample rate of the samples, in Hz."""

Expand Down
4 changes: 3 additions & 1 deletion src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,8 +139,10 @@ def get_samples_played_in_range(
else:
offset_end = num_samples

data = frames[:, offset_beginning:offset_end]
return AudioSamples(
data=frames[:, offset_beginning:offset_end],
data=data,
pts_seconds=output_pts_seconds,
duration_seconds=data.shape[1] / sample_rate,
sample_rate=sample_rate,
)
8 changes: 7 additions & 1 deletion test/decoders/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -993,7 +993,6 @@ def test_get_all_samples(self, asset, stop_seconds):

torch.testing.assert_close(samples.data, reference_frames)
assert samples.sample_rate == asset.sample_rate

assert samples.pts_seconds == asset.get_frame_info(idx=0).pts_seconds

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
Expand Down Expand Up @@ -1215,3 +1214,10 @@ def test_s16_ffmpeg4_bug(self):
)
with cm:
decoder.get_samples_played_in_range(start_seconds=0)

@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))
# 44_100 (44.1 kHz) replaces the original 44_1000, which evaluates to
# 441000 Hz because of a misplaced digit separator.
@pytest.mark.parametrize("sample_rate", (None, 8000, 16_000, 44_100))
def test_samples_duration(self, asset, sample_rate):
    """Check that a one-second range reports a one-second duration.

    Decodes the samples played in [1s, 2s) for each asset, with and
    without a requested output sample rate, and asserts that the
    returned ``duration_seconds`` is exactly 1.
    """
    decoder = AudioDecoder(asset.path, sample_rate=sample_rate)
    samples = decoder.get_samples_played_in_range(start_seconds=1, stop_seconds=2)
    assert samples.duration_seconds == 1
6 changes: 5 additions & 1 deletion test/test_frame_dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@

def test_unpacking():
data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa
data, pts_seconds, sample_rate = AudioSamples(torch.rand(2, 4), 2, 16_000)
data, pts_seconds, duration_seconds, sample_rate = AudioSamples(
torch.rand(2, 4), 2, 3, 16_000
)


def test_frame_error():
Expand Down Expand Up @@ -147,11 +149,13 @@ def test_audio_samples_error():
AudioSamples(
data=torch.rand(1),
pts_seconds=1,
duration_seconds=1,
sample_rate=16_000,
)
with pytest.raises(ValueError, match="data must be 2-dimensional"):
AudioSamples(
data=torch.rand(1, 2, 3),
pts_seconds=1,
duration_seconds=1,
sample_rate=16_000,
)
Loading