diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 17e1301d8..95bdfd093 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -848,13 +848,14 @@ AudioFramesOutput SingleStreamDecoder::getFramesPlayedInRangeAudio( std::to_string(*stopSecondsOptional) + ")."); } + StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; + if (stopSecondsOptional.has_value() && startSeconds == *stopSecondsOptional) { // For consistency with video - return AudioFramesOutput{torch::empty({0, 0}), 0.0}; + int numChannels = getNumChannels(streamInfo.codecContext); + return AudioFramesOutput{torch::empty({numChannels, 0}), 0.0}; } - StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); if (startPts < streamInfo.lastDecodedAvFramePts + streamInfo.lastDecodedAvFrameDuration) { diff --git a/test/test_decoders.py b/test/test_decoders.py index c68e1ace6..a5bc3b8d8 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1108,7 +1108,7 @@ def test_not_at_frame_boundaries(self, asset): def test_start_equals_stop(self, asset): decoder = AudioDecoder(asset.path) samples = decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=3) - assert samples.data.shape == (0, 0) + assert samples.data.shape == (asset.num_channels, 0) def test_frame_start_is_not_zero(self): # For NASA_AUDIO_MP3, the first frame is not at 0, it's at 0.138125. diff --git a/test/test_ops.py b/test/test_ops.py index 158e3d082..4512301b3 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -782,7 +782,7 @@ def test_decode_start_equal_stop(self, asset): frames, pts_seconds = get_frames_by_pts_in_range_audio( decoder, start_seconds=1, stop_seconds=1 ) - assert frames.shape == (0, 0) + assert frames.shape == (asset.num_channels, 0) assert pts_seconds == 0 @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3))