From 0e54f3e4ba76749fb262532b6130f2cf10f09f60 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 29 Apr 2025 10:01:04 +0100 Subject: [PATCH 1/2] Audio: Fix output shape when start==stop --- src/torchcodec/_core/SingleStreamDecoder.cpp | 7 ++++--- test/test_decoders.py | 2 +- test/test_ops.py | 2 +- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 17e1301d8..8fc5f1d6e 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -848,13 +848,14 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( std::to_string(*stopSecondsOptional) + ")."); } + StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) { // For consistency with video - return AudioFramesOutput{torch::empty({0, 0}), 0.0}; + auto numChannels = getNumChannels(streamInfo.codecContext); + return AudioFramesOutput{torch::empty({numChannels, 0}), 0.0}; } - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); if (startPts < streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration) { diff --git a/test/test_decoders.py b/test/test_decoders.py index c68e1ace6..a5bc3b8d8 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1108,7 +1108,7 @@ def test_not_at_frame_boundaries(self, asset): def test_start_equals_stop(self, asset): decoder = AudioDecoder(asset.path) samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3) - assert samples.data.shape == (0, 0) + assert samples.data.shape == (asset.num_channels, 0) def test_frame_start_is_not_zero(self): # For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125. diff --git a/test/test_ops.py b/test/test_ops.py index 158e3d082..4512301b3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -782,7 +782,7 @@ def test_decode_start_equal_stop(self, asset): frames, pts_seconds = get_frames_by_pts_in_range_audio( decoder, start_seconds=1, stop_seconds=1 ) - assert frames.shape == (0, 0) + assert frames.shape == (asset.num_channels, 0) assert pts_seconds == 0 @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) From 1765f098edf9724aede63b78bae08d3926f11196 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 29 Apr 2025 14:18:50 +0100 Subject: [PATCH 2/2] Use int --- src/torchcodec/_core/SingleStreamDecoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 8fc5f1d6e..95bdfd093 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -852,7 +852,7 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) { // For consistency with video - auto numChannels = getNumChannels(streamInfo.codecContext); + int numChannels = getNumChannels(streamInfo.codecContext); return AudioFramesOutput{torch::empty({numChannels, 0}), 0.0}; }