73 changes: 71 additions & 2 deletions src/torchcodec/_core/FFMPEGCommon.cpp
@@ -116,16 +116,17 @@ void setChannelLayout(
#endif
}

SwrContext* allocateSwrContext(
SwrContext* createSwrContext(
UniqueAVCodecContext& avCodecContext,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
SwrContext* swrContext = nullptr;
int status = AVSUCCESS;
#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4
AVChannelLayout layout = avCodecContext->ch_layout;
auto status = swr_alloc_set_opts2(
status = swr_alloc_set_opts2(
&swrContext,
&layout,
desiredSampleFormat,
@@ -155,9 +156,77 @@ SwrContext* allocateSwrContext(
#endif

TORCH_CHECK(swrContext != nullptr, "Couldn't create swrContext");
status = swr_init(swrContext);
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't initialize SwrContext: ",
getFFMPEGErrorStringFromErrorCode(status),
". If the error says 'Invalid argument', it's likely that you are using "
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
"valid scenarios. Try to upgrade FFmpeg?");
return swrContext;
}

UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueSwrContext& swrContext,
const UniqueAVFrame& srcAVFrame,
Review comment (Contributor): Nit: we should be consistent about src and source. I have a preference for src, as it's a universal abbreviation, particularly when paired with dst. But if we say source in a lot of other places, we should stick with that.

Review comment (Contributor): Double nit, and I recognize this name already existed: convertAudioAVFrameSampleFormatAndSampleRate() is very long, and I feel like we're encoding parameter names that modify the operation into the name. I feel like it's clearer as just convertAudioAVFrame().

Review comment (Contributor Author): Sounds good, I'll merge as-is and follow-up with a PR to address these
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
UniqueAVFrame convertedAVFrame(av_frame_alloc());
TORCH_CHECK(
convertedAVFrame,
"Could not allocate frame for sample format conversion.");

setChannelLayout(convertedAVFrame, srcAVFrame);
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
convertedAVFrame->sample_rate = desiredSampleRate;
if (sourceSampleRate != desiredSampleRate) {
// Note that this is an upper bound on the number of output samples.
// `swr_convert()` will likely not fill convertedAVFrame with that many
// samples if sample rate conversion is needed. It will buffer the last few
// ones because those require future samples. That's also why we reset
// nb_samples after the call to `swr_convert()`.
// We could also use `swr_get_out_samples()` to determine the number of
// output samples, but empirically `av_rescale_rnd()` seems to provide a
// tighter bound.
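// Illustrative example (numbers not taken from any real stream): resampling a
// 1024-sample frame from 44100 Hz to 16000 Hz with no buffered delay gives
// av_rescale_rnd(0 + 1024, 16000, 44100, AV_ROUND_UP) =
// ceil(1024 * 16000 / 44100) = 372, i.e. at most 372 output samples.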
convertedAVFrame->nb_samples = av_rescale_rnd(
swr_get_delay(swrContext.get(), sourceSampleRate) +
srcAVFrame->nb_samples,
desiredSampleRate,
sourceSampleRate,
AV_ROUND_UP);
} else {
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
}

auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
TORCH_CHECK(
status == AVSUCCESS,
"Could not allocate frame buffers for sample format conversion: ",
getFFMPEGErrorStringFromErrorCode(status));

auto numConvertedSamples = swr_convert(
swrContext.get(),
convertedAVFrame->data,
convertedAVFrame->nb_samples,
static_cast<const uint8_t**>(
const_cast<const uint8_t**>(srcAVFrame->data)),
srcAVFrame->nb_samples);
// numConvertedSamples can be 0 if we're downsampling by a great factor and
// the first frame doesn't contain a lot of samples. It should be handled
// properly by the caller.
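// For instance (illustrative numbers only): a 64-sample frame going from
// 192000 Hz down to 8000 Hz maps to at most ceil(64 * 8000 / 192000) = 3
// output samples, and the resampler may buffer all of them and return 0 here.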
TORCH_CHECK(
numConvertedSamples >= 0,
"Error in swr_convert: ",
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));

// See comment above about nb_samples
convertedAVFrame->nb_samples = numConvertedSamples;

return convertedAVFrame;
}

void setFFmpegLogLevel() {
auto logLevel = AV_LOG_QUIET;
const char* logLevelEnvPtr = std::getenv("TORCHCODEC_FFMPEG_LOG_LEVEL");
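For orientation, here is a minimal sketch (not part of the PR) of how a caller is expected to use the two new free functions together: the SwrContext is created lazily, once per stream, and then reused for every decoded frame. It assumes UniqueSwrContext is the smart-pointer wrapper already declared in FFMPEGCommon.h, and the member and variable names other than the two helpers are hypothetical.

// Minimal usage sketch; names other than createSwrContext() and
// convertAudioAVFrameSampleFormatAndSampleRate() are assumptions.
UniqueSwrContext swrContext;  // typically stored per stream and reused

UniqueAVFrame convertOneFrame(
    const UniqueAVFrame& srcAVFrame,
    UniqueAVCodecContext& codecContext) {
  AVSampleFormat srcFormat = static_cast<AVSampleFormat>(srcAVFrame->format);
  AVSampleFormat dstFormat = AV_SAMPLE_FMT_FLTP;
  int srcRate = srcAVFrame->sample_rate;
  int dstRate = 16000;  // arbitrary example target rate

  if (!swrContext) {
    // Create and initialize the resampler once; createSwrContext() already
    // calls swr_init() and TORCH_CHECKs the result.
    swrContext.reset(createSwrContext(
        codecContext, srcFormat, dstFormat, srcRate, dstRate));
  }
  // Reuse the same context for every subsequent frame of the stream.
  return convertAudioAVFrameSampleFormatAndSampleRate(
      swrContext, srcAVFrame, dstFormat, srcRate, dstRate);
}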
9 changes: 8 additions & 1 deletion src/torchcodec/_core/FFMPEGCommon.h
@@ -158,13 +158,20 @@ void setChannelLayout(
void setChannelLayout(
UniqueAVFrame& dstAVFrame,
const UniqueAVFrame& srcAVFrame);
SwrContext* allocateSwrContext(
SwrContext* createSwrContext(
UniqueAVCodecContext& avCodecContext,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueSwrContext& swrContext,
const UniqueAVFrame& srcAVFrame,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// Returns true if sws_scale can handle unaligned data.
bool canSwsScaleHandleUnalignedData();

110 changes: 12 additions & 98 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -1345,20 +1345,29 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
static_cast<AVSampleFormat>(srcAVFrame->format);
AVSampleFormat desiredSampleFormat = AV_SAMPLE_FMT_FLTP;

StreamInfo& streamInfo = streamInfos_[activeStreamIndex_];
int sourceSampleRate = srcAVFrame->sample_rate;
int desiredSampleRate =
streamInfos_[activeStreamIndex_].audioStreamOptions.sampleRate.value_or(
sourceSampleRate);
streamInfo.audioStreamOptions.sampleRate.value_or(sourceSampleRate);

bool mustConvert =
(sourceSampleFormat != desiredSampleFormat ||
sourceSampleRate != desiredSampleRate);

UniqueAVFrame convertedAVFrame;
if (mustConvert) {
if (!streamInfo.swrContext) {
streamInfo.swrContext.reset(createSwrContext(
streamInfo.codecContext,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate));
}

convertedAVFrame = convertAudioAVFrameSampleFormatAndSampleRate(
streamInfo.swrContext,
srcAVFrame,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);
@@ -1393,77 +1402,6 @@ void SingleStreamDecoder::convertAudioAVFrameToFrameOutputOnCPU(
}
}

UniqueAVFrame SingleStreamDecoder::convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueAVFrame& srcAVFrame,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
auto& streamInfo = streamInfos_[activeStreamIndex_];

if (!streamInfo.swrContext) {
createSwrContext(
streamInfo,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);
}

UniqueAVFrame convertedAVFrame(av_frame_alloc());
TORCH_CHECK(
convertedAVFrame,
"Could not allocate frame for sample format conversion.");

setChannelLayout(convertedAVFrame, srcAVFrame);
convertedAVFrame->format = static_cast<int>(desiredSampleFormat);
convertedAVFrame->sample_rate = desiredSampleRate;
if (sourceSampleRate != desiredSampleRate) {
// Note that this is an upper bound on the number of output samples.
// `swr_convert()` will likely not fill convertedAVFrame with that many
// samples if sample rate conversion is needed. It will buffer the last few
// ones because those require future samples. That's also why we reset
// nb_samples after the call to `swr_convert()`.
// We could also use `swr_get_out_samples()` to determine the number of
// output samples, but empirically `av_rescale_rnd()` seems to provide a
// tighter bound.
convertedAVFrame->nb_samples = av_rescale_rnd(
swr_get_delay(streamInfo.swrContext.get(), sourceSampleRate) +
srcAVFrame->nb_samples,
desiredSampleRate,
sourceSampleRate,
AV_ROUND_UP);
} else {
convertedAVFrame->nb_samples = srcAVFrame->nb_samples;
}

auto status = av_frame_get_buffer(convertedAVFrame.get(), 0);
TORCH_CHECK(
status == AVSUCCESS,
"Could not allocate frame buffers for sample format conversion: ",
getFFMPEGErrorStringFromErrorCode(status));

auto numConvertedSamples = swr_convert(
streamInfo.swrContext.get(),
convertedAVFrame->data,
convertedAVFrame->nb_samples,
static_cast<const uint8_t**>(
const_cast<const uint8_t**>(srcAVFrame->data)),
srcAVFrame->nb_samples);
// numConvertedSamples can be 0 if we're downsampling by a great factor and
// the first frame doesn't contain a lot of samples. It should be handled
// properly by the caller.
TORCH_CHECK(
numConvertedSamples >= 0,
"Error in swr_convert: ",
getFFMPEGErrorStringFromErrorCode(numConvertedSamples));

// See comment above about nb_samples
convertedAVFrame->nb_samples = numConvertedSamples;

return convertedAVFrame;
}

std::optional<torch::Tensor> SingleStreamDecoder::maybeFlushSwrBuffers() {
// When sample rate conversion is involved, swresample buffers some of the
// samples in-between calls to swr_convert (see the libswresample docs).
@@ -1735,30 +1673,6 @@ void SingleStreamDecoder::createSwsContext(
streamInfo.swsContext.reset(swsContext);
}

void SingleStreamDecoder::createSwrContext(
StreamInfo& streamInfo,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate) {
auto swrContext = allocateSwrContext(
streamInfo.codecContext,
sourceSampleFormat,
desiredSampleFormat,
sourceSampleRate,
desiredSampleRate);

auto status = swr_init(swrContext);
TORCH_CHECK(
status == AVSUCCESS,
"Couldn't initialize SwrContext: ",
getFFMPEGErrorStringFromErrorCode(status),
". If the error says 'Invalid argument', it's likely that you are using "
"a buggy FFmpeg version. FFmpeg4 is known to fail here in some "
"valid scenarios. Try to upgrade FFmpeg?");
streamInfo.swrContext.reset(swrContext);
}

// --------------------------------------------------------------------------
// PTS <-> INDEX CONVERSIONS
// --------------------------------------------------------------------------
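As context for maybeFlushSwrBuffers(), which this PR keeps in SingleStreamDecoder: when sample rate conversion is active, libswresample holds back the last few samples across swr_convert() calls, and they are drained at end of stream by passing a null input to swr_convert(). The snippet below is a minimal sketch of that drain step, not code from this PR; it assumes an already-initialized SwrContext and a frame whose format, channel layout, and sample_rate have been set as in the conversion helper above.

// Sketch only: drain samples buffered by libswresample.
int remaining = swr_get_out_samples(swrContext.get(), /*in_samples=*/0);
if (remaining > 0) {
  UniqueAVFrame flushed(av_frame_alloc());
  TORCH_CHECK(flushed, "Could not allocate flush frame.");
  // ... set format, channel layout, and sample_rate on `flushed` here ...
  flushed->nb_samples = remaining;
  int status = av_frame_get_buffer(flushed.get(), 0);
  TORCH_CHECK(
      status == AVSUCCESS,
      "Could not allocate flush buffers: ",
      getFFMPEGErrorStringFromErrorCode(status));
  // A null input tells swr_convert() to output only the buffered samples.
  int numFlushed = swr_convert(
      swrContext.get(), flushed->data, flushed->nb_samples, nullptr, 0);
  TORCH_CHECK(
      numFlushed >= 0,
      "Error in swr_convert while flushing: ",
      getFFMPEGErrorStringFromErrorCode(numFlushed));
  flushed->nb_samples = numFlushed;
}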
14 changes: 0 additions & 14 deletions src/torchcodec/_core/SingleStreamDecoder.h
@@ -287,13 +287,6 @@ class SingleStreamDecoder {
const UniqueAVFrame& avFrame,
torch::Tensor& outputTensor);

UniqueAVFrame convertAudioAVFrameSampleFormatAndSampleRate(
const UniqueAVFrame& srcAVFrame,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

std::optional<torch::Tensor> maybeFlushSwrBuffers();

// --------------------------------------------------------------------------
@@ -310,13 +303,6 @@
const DecodedFrameContext& frameContext,
const enum AVColorSpace colorspace);

void createSwrContext(
StreamInfo& streamInfo,
AVSampleFormat sourceSampleFormat,
AVSampleFormat desiredSampleFormat,
int sourceSampleRate,
int desiredSampleRate);

// --------------------------------------------------------------------------
// PTS <-> INDEX CONVERSIONS
// --------------------------------------------------------------------------