diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp
index aad3c23c1..197221086 100644
--- a/src/torchcodec/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/_core/FFMPEGCommon.cpp
@@ -116,16 +116,17 @@ void setChannelLayout(
 #endif
 }
 
-SwrContext* allocateSwrContext(
+SwrContext* createSwrContext(
     UniqueAVCodecContext& avCodecContext,
     AVSampleFormat sourceSampleFormat,
     AVSampleFormat desiredSampleFormat,
     int sourceSampleRate,
     int desiredSampleRate) {
   SwrContext* swrContext = nullptr;
+  int status = AVSUCCESS;
 #if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
   AVChannelLayout layout = avCodecContext->ch_layout;
-  auto status = swr_alloc_set_opts2(
+  status = swr_alloc_set_opts2(
       &swrContext,
       &layout,
       desiredSampleFormat,
@@ -155,9 +156,77 @@ SwrContext* allocateSwrContext(
 #endif
 
   TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext");
+  status = swr_init(swrContext);
+  TORCH_CHECK(
+      status == AVSUCCESS,
+      "Couldn't initialize SwrContext: ",
+      getFFMPEGErrorStringFromErrorCode(status),
+      ". If the error says 'Invalid argument', it's likely that you are using "
+      "a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
+      "valid scenarios. Try to upgrade FFmpeg?");
   return swrContext;
 }
 
+UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
+    const UniqueSwrContext& swrContext,
+    const UniqueAVFrame& srcAVFrame,
+    AVSampleFormat desiredSampleFormat,
+    int sourceSampleRate,
+    int desiredSampleRate) {
+  UniqueAVFrame convertedAVFrame(av_frame_alloc());
+  TORCH_CHECK(
+      convertedAVFrame,
+      "Could not allocate frame for sample format conversion.");
+
+  setChannelLayout(convertedAVFrame, srcAVFrame);
+  convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
+  convertedAVFrame->sample_rate = desiredSampleRate;
+  if (sourceSampleRate != desiredSampleRate) {
+    // Note that this is an upper bound on the number of output samples.
+    // `swr_convert()` will likely not fill convertedAVFrame with that many
+    // samples if sample rate conversion is needed. It will buffer the last few
+    // ones because those require future samples. That's also why we reset
+    // nb_samples after the call to `swr_convert()`.
+    // We could also use `swr_get_out_samples()` to determine the number of
+    // output samples, but empirically `av_rescale_rnd()` seems to provide a
+    // tighter bound.
+    convertedAVFrame->nb_samples = av_rescale_rnd(
+        swr_get_delay(swrContext.get(), sourceSampleRate) +
+            srcAVFrame->nb_samples,
+        desiredSampleRate,
+        sourceSampleRate,
+        AV_ROUND_UP);
+  } else {
+    convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
+  }
+
+  auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
+  TORCH_CHECK(
+      status == AVSUCCESS,
+      "Could not allocate frame buffers for sample format conversion: ",
+      getFFMPEGErrorStringFromErrorCode(status));
+
+  auto numConvertedSamples = swr_convert(
+      swrContext.get(),
+      convertedAVFrame->data,
+      convertedAVFrame->nb_samples,
+      static_cast<const uint8_t**>(
+          const_cast<const uint8_t**>(srcAVFrame->data)),
+      srcAVFrame->nb_samples);
+  // numConvertedSamples can be 0 if we're downsampling by a great factor and
+  // the first frame doesn't contain a lot of samples. It should be handled
+  // properly by the caller.
+  TORCH_CHECK(
+      numConvertedSamples >= 0,
+      "Error in swr_convert: ",
+      getFFMPEGErrorStringFromErrorCode(numConvertedSamples));
+
+  // See comment above about nb_samples
+  convertedAVFrame->nb_samples = numConvertedSamples;
+
+  return convertedAVFrame;
+}
+
 void setFFmpegLogLevel() {
   auto logLevel = AV_LOG_QUIET;
   const char* logLevelEnvPtr = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL");
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index 81a9fb8f2..8c4abd13e 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -158,13 +158,20 @@ void setChannelLayout(
 void setChannelLayout(
     UniqueAVFrame& dstAVFrame,
     const UniqueAVFrame& srcAVFrame);
 
-SwrContext* allocateSwrContext(
+SwrContext* createSwrContext(
     UniqueAVCodecContext& avCodecContext,
     AVSampleFormat sourceSampleFormat,
     AVSampleFormat desiredSampleFormat,
     int sourceSampleRate,
     int desiredSampleRate);
 
+UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
+    const UniqueSwrContext& swrContext,
+    const UniqueAVFrame& srcAVFrame,
+    AVSampleFormat desiredSampleFormat,
+    int sourceSampleRate,
+    int desiredSampleRate);
+
 // Returns true if sws_scale can handle unaligned data.
 bool canSwsScaleHandleUnalignedData();
diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index 02ea5ee1e..17e1301d8 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -1345,10 +1345,10 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
       static_cast<AVSampleFormat>(srcAVFrame->format);
   AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;
 
+  StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
   int sourceSampleRate = srcAVFrame->sample_rate;
   int desiredSampleRate =
-      streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
-          sourceSampleRate);
+      streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate);
 
   bool mustConvert =
       (sourceSampleFormat != desiredSampleFormat ||
@@ -1356,9 +1356,18 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
 
   UniqueAVFrame convertedAVFrame;
   if (mustConvert) {
+    if (!streamInfo.swrContext) {
+      streamInfo.swrContext.reset(createSwrContext(
+          streamInfo.codecContext,
+          sourceSampleFormat,
+          desiredSampleFormat,
+          sourceSampleRate,
+          desiredSampleRate));
+    }
+
     convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
+        streamInfo.swrContext,
         srcAVFrame,
-        sourceSampleFormat,
         desiredSampleFormat,
         sourceSampleRate,
         desiredSampleRate);
@@ -1393,77 +1402,6 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
   }
 }
 
-UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
-    const UniqueAVFrame& srcAVFrame,
-    AVSampleFormat sourceSampleFormat,
-    AVSampleFormat desiredSampleFormat,
-    int sourceSampleRate,
-    int desiredSampleRate) {
-  auto& streamInfo = streamInfos_[activeStreamIndex_];
-
-  if (!streamInfo.swrContext) {
-    createSwrContext(
-        streamInfo,
-        sourceSampleFormat,
-        desiredSampleFormat,
-        sourceSampleRate,
-        desiredSampleRate);
-  }
-
-  UniqueAVFrame convertedAVFrame(av_frame_alloc());
-  TORCH_CHECK(
-      convertedAVFrame,
-      "Could not allocate frame for sample format conversion.");
-
-  setChannelLayout(convertedAVFrame, srcAVFrame);
-  convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
-  convertedAVFrame->sample_rate = desiredSampleRate;
-  if (sourceSampleRate != desiredSampleRate) {
-    // Note that this is an upper bound on the number of output samples.
-    // `swr_convert()` will likely not fill convertedAVFrame with that many
-    // samples if sample rate conversion is needed. It will buffer the last few
-    // ones because those require future samples. That's also why we reset
-    // nb_samples after the call to `swr_convert()`.
-    // We could also use `swr_get_out_samples()` to determine the number of
-    // output samples, but empirically `av_rescale_rnd()` seems to provide a
-    // tighter bound.
-    convertedAVFrame->nb_samples = av_rescale_rnd(
-        swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) +
-            srcAVFrame->nb_samples,
-        desiredSampleRate,
-        sourceSampleRate,
-        AV_ROUND_UP);
-  } else {
-    convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
-  }
-
-  auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
-  TORCH_CHECK(
-      status == AVSUCCESS,
-      "Could not allocate frame buffers for sample format conversion: ",
-      getFFMPEGErrorStringFromErrorCode(status));
-
-  auto numConvertedSamples = swr_convert(
-      streamInfo.swrContext.get(),
-      convertedAVFrame->data,
-      convertedAVFrame->nb_samples,
-      static_cast<const uint8_t**>(
-          const_cast<const uint8_t**>(srcAVFrame->data)),
-      srcAVFrame->nb_samples);
-  // numConvertedSamples can be 0 if we're downsampling by a great factor and
-  // the first frame doesn't contain a lot of samples. It should be handled
-  // properly by the caller.
-  TORCH_CHECK(
-      numConvertedSamples >= 0,
-      "Error in swr_convert: ",
-      getFFMPEGErrorStringFromErrorCode(numConvertedSamples));
-
-  // See comment above about nb_samples
-  convertedAVFrame->nb_samples = numConvertedSamples;
-
-  return convertedAVFrame;
-}
-
 std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
   // When sample rate conversion is involved, swresample buffers some of the
   // samples in-between calls to swr_convert (see the libswresample docs).
@@ -1735,30 +1673,6 @@ void SingleStreamDecoder::createSwsContext(
   streamInfo.swsContext.reset(swsContext);
 }
 
-void SingleStreamDecoder::createSwrContext(
-    StreamInfo& streamInfo,
-    AVSampleFormat sourceSampleFormat,
-    AVSampleFormat desiredSampleFormat,
-    int sourceSampleRate,
-    int desiredSampleRate) {
-  auto swrContext = allocateSwrContext(
-      streamInfo.codecContext,
-      sourceSampleFormat,
-      desiredSampleFormat,
-      sourceSampleRate,
-      desiredSampleRate);
-
-  auto status = swr_init(swrContext);
-  TORCH_CHECK(
-      status == AVSUCCESS,
-      "Couldn't initialize SwrContext: ",
-      getFFMPEGErrorStringFromErrorCode(status),
-      ". If the error says 'Invalid argument', it's likely that you are using "
-      "a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
-      "valid scenarios. Try to upgrade FFmpeg?");
-  streamInfo.swrContext.reset(swrContext);
-}
-
 // --------------------------------------------------------------------------
 // PTS <-> INDEX CONVERSIONS
 // --------------------------------------------------------------------------
diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h
index d532675e8..d85151119 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.h
+++ b/src/torchcodec/_core/SingleStreamDecoder.h
@@ -287,13 +287,6 @@ class SingleStreamDecoder {
       const UniqueAVFrame& avFrame,
       torch::Tensor& outputTensor);
 
-  UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
-      const UniqueAVFrame& srcAVFrame,
-      AVSampleFormat sourceSampleFormat,
-      AVSampleFormat desiredSampleFormat,
-      int sourceSampleRate,
-      int desiredSampleRate);
-
   std::optional<torch::Tensor> maybeFlushSwrBuffers();
 
   // --------------------------------------------------------------------------
@@ -310,13 +303,6 @@ class SingleStreamDecoder {
       const DecodedFrameContext& frameContext,
       const enum AVColorSpace colorspace);
 
-  void createSwrContext(
-      StreamInfo& streamInfo,
-      AVSampleFormat sourceSampleFormat,
-      AVSampleFormat desiredSampleFormat,
-      int sourceSampleRate,
-      int desiredSampleRate);
-
   // --------------------------------------------------------------------------
   // PTS <-> INDEX CONVERSIONS
   // --------------------------------------------------------------------------