diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index 1c876f4ef..8a29d0654 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -293,18 +293,20 @@ void AudioEncoder::encodeInnerLoop( if (mustConvert) { if (!swrContext_) { swrContext_.reset(createSwrContext( - avCodecContext_, AV_SAMPLE_FMT_FLTP, avCodecContext_->sample_fmt, srcAVFrame->sample_rate, // No sample rate conversion - srcAVFrame->sample_rate)); + srcAVFrame->sample_rate, + srcAVFrame, + getNumChannels(srcAVFrame) // No num_channel conversion + )); } - convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( + convertedAVFrame = convertAudioAVFrameSamples( swrContext_, srcAVFrame, avCodecContext_->sample_fmt, srcAVFrame->sample_rate, // No sample rate conversion - srcAVFrame->sample_rate); + getNumChannels(srcAVFrame)); // No num_channel conversion TORCH_CHECK( convertedAVFrame->nb_samples == srcAVFrame->nb_samples, "convertedAVFrame->nb_samples=", diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index a8da49e85..262b67dfd 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -81,7 +81,6 @@ void setDefaultChannelLayout( AVChannelLayout channel_layout; av_channel_layout_default(&channel_layout, numChannels); avCodecContext->ch_layout = channel_layout; - #else uint64_t channel_layout = av_get_default_channel_layout(numChannels); avCodecContext->channel_layout = channel_layout; @@ -106,32 +105,79 @@ void setChannelLayout( #endif } +namespace { +#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 + +// Returns: +// - the srcAVFrame's channel layout if srcAVFrame has desiredNumChannels +// - the default channel layout with desiredNumChannels otherwise. +AVChannelLayout getDesiredChannelLayout( + int desiredNumChannels, + const UniqueAVFrame& srcAVFrame) { + AVChannelLayout desiredLayout; + if (desiredNumChannels == getNumChannels(srcAVFrame)) { + desiredLayout = srcAVFrame->ch_layout; + } else { + av_channel_layout_default(&desiredLayout, desiredNumChannels); + } + return desiredLayout; +} + +#else + +// Same as above +int64_t getDesiredChannelLayout( + int desiredNumChannels, + const UniqueAVFrame& srcAVFrame) { + int64_t desiredLayout; + if (desiredNumChannels == getNumChannels(srcAVFrame)) { + desiredLayout = srcAVFrame->channel_layout; + } else { + desiredLayout = av_get_default_channel_layout(desiredNumChannels); + } + return desiredLayout; +} +#endif +} // namespace + +// Sets dstAVFrame' channel layout to getDesiredChannelLayout(): see doc above void setChannelLayout( UniqueAVFrame& dstAVFrame, - const UniqueAVFrame& srcAVFrame) { + const UniqueAVFrame& srcAVFrame, + int desiredNumChannels) { #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 - dstAVFrame->ch_layout = srcAVFrame->ch_layout; + AVChannelLayout desiredLayout = + getDesiredChannelLayout(desiredNumChannels, srcAVFrame); + auto status = av_channel_layout_copy(&dstAVFrame->ch_layout, &desiredLayout); + TORCH_CHECK( + status == AVSUCCESS, + "Couldn't copy channel layout to avFrame: ", + getFFMPEGErrorStringFromErrorCode(status)); #else - dstAVFrame->channel_layout = srcAVFrame->channel_layout; + dstAVFrame->channel_layout = + getDesiredChannelLayout(desiredNumChannels, srcAVFrame); + dstAVFrame->channels = desiredNumChannels; #endif } SwrContext* createSwrContext( - UniqueAVCodecContext& avCodecContext, AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, - int desiredSampleRate) { + int desiredSampleRate, + const UniqueAVFrame& srcAVFrame, + int desiredNumChannels) { SwrContext* swrContext = nullptr; int status = AVSUCCESS; #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 - AVChannelLayout layout = avCodecContext->ch_layout; + AVChannelLayout desiredLayout = + getDesiredChannelLayout(desiredNumChannels, srcAVFrame); status = swr_alloc_set_opts2( &swrContext, - &layout, + &desiredLayout, desiredSampleFormat, desiredSampleRate, - &layout, + &srcAVFrame->ch_layout, sourceSampleFormat, sourceSampleRate, 0, @@ -142,13 +188,14 @@ SwrContext* createSwrContext( "Couldn't create SwrContext: ", getFFMPEGErrorStringFromErrorCode(status)); #else - int64_t layout = static_cast(avCodecContext->channel_layout); + int64_t desiredLayout = + getDesiredChannelLayout(desiredNumChannels, srcAVFrame); swrContext = swr_alloc_set_opts( nullptr, - layout, + desiredLayout, desiredSampleFormat, desiredSampleRate, - layout, + srcAVFrame->channel_layout, sourceSampleFormat, sourceSampleRate, 0, @@ -167,20 +214,21 @@ SwrContext* createSwrContext( return swrContext; } -UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( +UniqueAVFrame convertAudioAVFrameSamples( const UniqueSwrContext& swrContext, const UniqueAVFrame& srcAVFrame, AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate) { + int desiredSampleRate, + int desiredNumChannels) { UniqueAVFrame convertedAVFrame(av_frame_alloc()); TORCH_CHECK( convertedAVFrame, "Could not allocate frame for sample format conversion."); - setChannelLayout(convertedAVFrame, srcAVFrame); convertedAVFrame->format = static_cast(desiredSampleFormat); + convertedAVFrame->sample_rate = desiredSampleRate; + int sourceSampleRate = srcAVFrame->sample_rate; if (sourceSampleRate != desiredSampleRate) { // Note that this is an upper bound on the number of output samples. // `swr_convert()` will likely not fill convertedAVFrame with that many @@ -200,6 +248,8 @@ UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( convertedAVFrame->nb_samples = srcAVFrame->nb_samples; } + setChannelLayout(convertedAVFrame, srcAVFrame, desiredNumChannels); + auto status = av_frame_get_buffer(convertedAVFrame.get(), 0); TORCH_CHECK( status == AVSUCCESS, diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index 308dec484..4281689e2 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -157,20 +157,28 @@ void setChannelLayout( void setChannelLayout( UniqueAVFrame& dstAVFrame, - const UniqueAVFrame& srcAVFrame); + const UniqueAVFrame& srcAVFrame, + int desiredNumChannels); + SwrContext* createSwrContext( - UniqueAVCodecContext& avCodecContext, AVSampleFormat sourceSampleFormat, AVSampleFormat desiredSampleFormat, int sourceSampleRate, - int desiredSampleRate); - -UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate( + int desiredSampleRate, + const UniqueAVFrame& srcAVFrame, + int desiredNumChannels); + +// Converts, if needed: +// - sample format +// - sample rate +// - number of channels. +// createSwrContext must have been previously called with matching parameters. +UniqueAVFrame convertAudioAVFrameSamples( const UniqueSwrContext& swrContext, const UniqueAVFrame& srcAVFrame, AVSampleFormat desiredSampleFormat, - int sourceSampleRate, - int desiredSampleRate); + int desiredSampleRate, + int desiredNumChannels); // Returns true if sws_scale can handle unaligned data. bool canSwsScaleHandleUnalignedData(); diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index cafbc70ea..e2c55ef29 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -453,6 +453,13 @@ void SingleStreamDecoder::addAudioStream( TORCH_CHECK( seekMode_ == SeekMode::approximate, "seek_mode must be 'approximate' for audio streams."); + if (audioStreamOptions.numChannels.has_value()) { + TORCH_CHECK( + *audioStreamOptions.numChannels > 0 && + *audioStreamOptions.numChannels <= AV_NUM_DATA_POINTERS, + "num_channels must be > 0 and <= AV_NUM_DATA_POINTERS (usually 8). Got: ", + *audioStreamOptions.numChannels); + } addStream(streamIndex, AVMEDIA_TYPE_AUDIO); @@ -1171,27 +1178,42 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( int desiredSampleRate = streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate); + int sourceNumChannels = getNumChannels(streamInfo.codecContext); + TORCH_CHECK( + sourceNumChannels == getNumChannels(srcAVFrame), + "The frame has ", + getNumChannels(srcAVFrame), + " channels, expected ", + sourceNumChannels, + ". If you are hitting this, it may be because you are using " + "a buggy FFmpeg version. FFmpeg4 is known to fail here in some " + "valid scenarios. Try to upgrade FFmpeg?"); + int desiredNumChannels = + streamInfo.audioStreamOptions.numChannels.value_or(sourceNumChannels); + bool mustConvert = (sourceSampleFormat != desiredSampleFormat || - sourceSampleRate != desiredSampleRate); + sourceSampleRate != desiredSampleRate || + sourceNumChannels != desiredNumChannels); UniqueAVFrame convertedAVFrame; if (mustConvert) { if (!streamInfo.swrContext) { streamInfo.swrContext.reset(createSwrContext( - streamInfo.codecContext, sourceSampleFormat, desiredSampleFormat, sourceSampleRate, - desiredSampleRate)); + desiredSampleRate, + srcAVFrame, + desiredNumChannels)); } - convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate( + convertedAVFrame = convertAudioAVFrameSamples( streamInfo.swrContext, srcAVFrame, desiredSampleFormat, - sourceSampleRate, - desiredSampleRate); + desiredSampleRate, + desiredNumChannels); } const UniqueAVFrame& avFrame = mustConvert ? convertedAVFrame : srcAVFrame; @@ -1204,8 +1226,17 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU( "source format = ", av_get_sample_fmt_name(format)); + int numChannels = getNumChannels(avFrame); + TORCH_CHECK( + numChannels == desiredNumChannels, + "Something went wrong, the frame didn't get converted to the desired ", + "number of channels = ", + desiredNumChannels, + ". Got ", + numChannels, + " instead."); + auto numSamples = avFrame->nb_samples; // per channel - auto numChannels = getNumChannels(avFrame); frameOutput.data = torch::empty({numChannels, numSamples}, torch::kFloat32); @@ -1240,7 +1271,8 @@ std::optional SingleStreamDecoder::maybeFlushSwrBuffers() { return std::nullopt; } - auto numChannels = getNumChannels(streamInfo.codecContext); + int numChannels = streamInfo.audioStreamOptions.numChannels.value_or( + getNumChannels(streamInfo.codecContext)); torch::Tensor lastSamples = torch::empty({numChannels, numRemainingSamples}, torch::kFloat32); diff --git a/src/torchcodec/_core/StreamOptions.h b/src/torchcodec/_core/StreamOptions.h index 38e51209c..ef250da09 100644 --- a/src/torchcodec/_core/StreamOptions.h +++ b/src/torchcodec/_core/StreamOptions.h @@ -44,6 +44,7 @@ struct AudioStreamOptions { AudioStreamOptions() {} std::optional sampleRate; + std::optional numChannels; }; } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 813c53a7f..1355045a5 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -40,7 +40,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None) -> ()"); m.def( - "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None) -> ()"); + "add_audio_stream(Tensor(a!) decoder, *, int? stream_index=None, int? sample_rate=None, int? num_channels=None) -> ()"); m.def("seek_to_pts(Tensor(a!) decoder, float seconds) -> ()"); m.def("get_next_frame(Tensor(a!) decoder) -> (Tensor, Tensor, Tensor)"); m.def( @@ -280,9 +280,11 @@ void add_video_stream( void add_audio_stream( at::Tensor& decoder, std::optional stream_index = std::nullopt, - std::optional sample_rate = std::nullopt) { + std::optional sample_rate = std::nullopt, + std::optional num_channels = std::nullopt) { AudioStreamOptions audioStreamOptions; audioStreamOptions.sampleRate = sample_rate; + audioStreamOptions.numChannels = num_channels; auto videoDecoder = unwrapTensorToGetDecoder(decoder); videoDecoder->addAudioStream(stream_index.value_or(-1), audioStreamOptions); diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index e9b4faecf..1240d2d63 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -221,6 +221,8 @@ def add_audio_stream_abstract( decoder: torch.Tensor, *, stream_index: Optional[int] = None, + sample_rate: Optional[int] = None, + num_channels: Optional[int] = None, ) -> None: return diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index 0fcab7008..54d7e4583 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -39,7 +39,9 @@ class AudioDecoder: Note that this index is absolute across all media types. If left unspecified, then the :term:`best stream` is used. sample_rate (int, optional): The desired output sample rate of the decoded samples. - By default, the samples are returned in their original sample rate. + By default, the sample rate of the source is used. + num_channels (int, optional): The desired number of channels of the decoded samples. + By default, the number of channels of the source is used. Attributes: metadata (AudioStreamMetadata): Metadata of the audio stream. @@ -54,11 +56,15 @@ def __init__( *, stream_index: Optional[int] = None, sample_rate: Optional[int] = None, + num_channels: Optional[int] = None, ): self._decoder = create_decoder(source=source, seek_mode="approximate") core.add_audio_stream( - self._decoder, stream_index=stream_index, sample_rate=sample_rate + self._decoder, + stream_index=stream_index, + sample_rate=sample_rate, + num_channels=num_channels, ) container_metadata = core.get_container_metadata(self._decoder) diff --git a/test/test_decoders.py b/test/test_decoders.py index a0269c3f1..ddd35ff3f 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1292,7 +1292,7 @@ def test_s16_ffmpeg4_bug(self): assert decoder.metadata.sample_format == asset.sample_format cm = ( - pytest.raises(RuntimeError, match="Invalid argument") + pytest.raises(RuntimeError, match="The frame has 0 channels, expected 1.") if get_ffmpeg_major_version() == 4 else contextlib.nullcontext() ) @@ -1305,3 +1305,33 @@ def test_samples_duration(self, asset, sample_rate): decoder = AudioDecoder(asset.path, sample_rate=sample_rate) samples = decoder.get_samples_played_in_range(start_seconds=1, stop_seconds=2) assert samples.duration_seconds == 1 + + @pytest.mark.parametrize("asset", (SINE_MONO_S32, NASA_AUDIO_MP3)) + # Note that we parametrize over sample_rate as well, so that we can ensure + # that the extra tensor allocation that happens within + # maybeFlushSwrBuffers() is correct. + @pytest.mark.parametrize("sample_rate", (None, 16_000)) + # FFmpeg can handle up to AV_NUM_DATA_POINTERS=8 channels + @pytest.mark.parametrize("num_channels", (1, 2, 8, None)) + def test_num_channels(self, asset, sample_rate, num_channels): + decoder = AudioDecoder( + asset.path, sample_rate=sample_rate, num_channels=num_channels + ) + samples = decoder.get_all_samples() + + if num_channels is None: + num_channels = asset.num_channels + + assert samples.data.shape[0] == num_channels + + @pytest.mark.parametrize("asset", (SINE_MONO_S32, NASA_AUDIO_MP3)) + def test_num_channels_errors(self, asset): + with pytest.raises( + RuntimeError, match="num_channels must be > 0 and <= AV_NUM_DATA_POINTERS" + ): + AudioDecoder(asset.path, num_channels=0) + with pytest.raises( + RuntimeError, match="num_channels must be > 0 and <= AV_NUM_DATA_POINTERS" + ): + # FFmpeg can handle up to AV_NUM_DATA_POINTERS=8 channels + AudioDecoder(asset.path, num_channels=9)