diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index f451ee580..3b0e5b019 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1522,17 +1522,22 @@ std::optional VideoDecoder::maybeFlushSwrBuffers() { return std::nullopt; } - torch::Tensor lastSamples = torch::empty( - {getNumChannels(streamInfo.codecContext), numRemainingSamples}, - torch::kFloat32); - uint8_t* lastSamplesData = static_cast(lastSamples.data_ptr()); + auto numChannels = getNumChannels(streamInfo.codecContext); + torch::Tensor lastSamples = + torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); + + std::vector outputBuffers(numChannels); + for (auto i = 0; i < numChannels; i++) { + outputBuffers[i] = static_cast(lastSamples[i].data_ptr()); + } auto actualNumRemainingSamples = swr_convert( streamInfo.swrContext.get(), - &lastSamplesData, + outputBuffers.data(), numRemainingSamples, nullptr, 0); + return lastSamples.narrow( /*dim=*/1, /*start=*/0, /*length=*/actualNumRemainingSamples); } diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 6885b51a7..e63e4cd1d 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -1157,6 +1157,14 @@ def test_sample_rate_conversion(self, start_seconds, stop_seconds): rtol=rtol, ) + def test_sample_rate_conversion_stereo(self): + # Non-regression test for https://github.com/pytorch/torchcodec/pull/584 + asset = NASA_AUDIO_MP3 + assert asset.sample_rate == 8000 + assert asset.num_channels == 2 + decoder = AudioDecoder(asset.path, sample_rate=44_100) + decoder.get_samples_played_in_range(start_seconds=0) + def test_s16_ffmpeg4_bug(self): # s16 fails on FFmpeg4 but can be decoded on other versions. # Debugging logs show that we're hitting: