diff --git a/src/torchcodec/_core/Encoder.cpp b/src/torchcodec/_core/Encoder.cpp index b4d2c5609..8c2ce85e1 100644 --- a/src/torchcodec/_core/Encoder.cpp +++ b/src/torchcodec/_core/Encoder.cpp @@ -33,21 +33,22 @@ torch::Tensor validateSamples(const torch::Tensor& samples) { } void validateSampleRate(const AVCodec& avCodec, int sampleRate) { - if (avCodec.supported_samplerates == nullptr) { + const int* supportedSampleRates = getSupportedSampleRates(avCodec); + if (supportedSampleRates == nullptr) { return; } - for (auto i = 0; avCodec.supported_samplerates[i] != 0; ++i) { - if (sampleRate == avCodec.supported_samplerates[i]) { + for (auto i = 0; supportedSampleRates[i] != 0; ++i) { + if (sampleRate == supportedSampleRates[i]) { return; } } std::stringstream supportedRates; - for (auto i = 0; avCodec.supported_samplerates[i] != 0; ++i) { + for (auto i = 0; supportedSampleRates[i] != 0; ++i) { if (i > 0) { supportedRates << ", "; } - supportedRates << avCodec.supported_samplerates[i]; + supportedRates << supportedSampleRates[i]; } TORCH_CHECK( @@ -73,19 +74,22 @@ static const std::vector preferredFormatsOrder = { AV_SAMPLE_FMT_U8}; AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { + const AVSampleFormat* supportedSampleFormats = + getSupportedOutputSampleFormats(avCodec); + // Find a sample format that the encoder supports. We prefer using FLT[P], // since this is the format of the input samples. If FLTP isn't supported // then we'll need to convert the AVFrame's format. Our heuristic is to encode // into the format with the highest resolution. - if (avCodec.sample_fmts == nullptr) { + if (supportedSampleFormats == nullptr) { // Can't really validate anything in this case, best we can do is hope that // FLTP is supported by the encoder. If not, FFmpeg will raise. return AV_SAMPLE_FMT_FLTP; } for (AVSampleFormat preferredFormat : preferredFormatsOrder) { - for (int i = 0; avCodec.sample_fmts[i] != -1; ++i) { - if (avCodec.sample_fmts[i] == preferredFormat) { + for (int i = 0; supportedSampleFormats[i] != -1; ++i) { + if (supportedSampleFormats[i] == preferredFormat) { return preferredFormat; } } @@ -93,7 +97,7 @@ AVSampleFormat findBestOutputSampleFormat(const AVCodec& avCodec) { // We should always find a match in preferredFormatsOrder, so we should always // return earlier. But in the event that a future FFmpeg version defines an // additional sample format that isn't in preferredFormatsOrder, we fallback: - return avCodec.sample_fmts[0]; + return supportedSampleFormats[0]; } } // namespace diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp index 094a1be0a..f0143ddd3 100644 --- a/src/torchcodec/_core/FFMPEGCommon.cpp +++ b/src/torchcodec/_core/FFMPEGCommon.cpp @@ -56,6 +56,46 @@ int64_t getDuration(const UniqueAVFrame& avFrame) { #endif } +const int* getSupportedSampleRates(const AVCodec& avCodec) { + const int* supportedSampleRates = nullptr; +#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) + int numSampleRates = 0; + int ret = avcodec_get_supported_config( + nullptr, + &avCodec, + AV_CODEC_CONFIG_SAMPLE_RATE, + 0, + reinterpret_cast & supportedSampleRates, + &numSampleRates); + if (ret < 0 || supportedSampleRates == nullptr) { + TORCH_CHECK(false, "Couldn't get supported sample rates from encoder."); + } +#else + supportedSampleRates = avCodec.supported_samplerates; +#endif + return supportedSampleRates; +} + +const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec) { + const AVSampleFormat* supportedSampleFormats = nullptr; +#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7 + int numSampleFormats = 0; + int ret = avcodec_get_supported_config( + nullptr, + &avCodec, + AV_CODEC_CONFIG_SAMPLE_FORMAT, + 0, + reinterpret_cast & supportedSampleFormats, + &numSampleFormats); + if (ret < 0 || supportedSampleFormats == nullptr) { + TORCH_CHECK(false, "Couldn't get supported sample formats from encoder."); + } +#else + supportedSampleFormats = avCodec.sample_fmts; +#endif + return supportedSampleFormats; +} + int getNumChannels(const UniqueAVFrame& avFrame) { #if LIBAVFILTER_VERSION_MAJOR > 8 || \ (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44) @@ -109,7 +149,31 @@ void setDefaultChannelLayout(UniqueAVFrame& avFrame, int numChannels) { } void validateNumChannels(const AVCodec& avCodec, int numChannels) { -#if LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 +#if LIBAVCODEC_VERSION_INT >= AV_VERSION_INT(61, 13, 100) // FFmpeg >= 7 + std::stringstream supportedNumChannels; + const AVChannelLayout* supported_layouts = nullptr; + int num_layouts = 0; + int ret = avcodec_get_supported_config( + nullptr, + &avCodec, + AV_CODEC_CONFIG_CHANNEL_LAYOUT, + 0, + reinterpret_cast & supported_layouts, + &num_layouts); + if (ret < 0 || supported_layouts == nullptr) { + TORCH_CHECK(false, "Couldn't get supported channel layouts from encoder."); + return; + } + for (int i = 0; supported_layouts[i].nb_channels != 0; ++i) { + if (i > 0) { + supportedNumChannels << ", "; + } + supportedNumChannels << supported_layouts[i].nb_channels; + if (numChannels == supported_layouts[i].nb_channels) { + return; + } + } +#elif LIBAVFILTER_VERSION_MAJOR > 7 // FFmpeg > 4 if (avCodec.ch_layouts == nullptr) { // If we can't validate, we must assume it'll be fine. If not, FFmpeg will // eventually raise. @@ -131,7 +195,7 @@ void validateNumChannels(const AVCodec& avCodec, int numChannels) { } supportedNumChannels << avCodec.ch_layouts[i].nb_channels; } -#else +#else // FFmpeg <= 4 if (avCodec.channel_layouts == nullptr) { // can't validate, same as above. return; diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h index b8c9e621c..179c7464b 100644 --- a/src/torchcodec/_core/FFMPEGCommon.h +++ b/src/torchcodec/_core/FFMPEGCommon.h @@ -162,6 +162,9 @@ std::string getFFMPEGErrorStringFromErrorCode(int errorCode); // support. int64_t getDuration(const UniqueAVFrame& frame); +const int* getSupportedSampleRates(const AVCodec& avCodec); +const AVSampleFormat* getSupportedOutputSampleFormats(const AVCodec& avCodec); + int getNumChannels(const UniqueAVFrame& avFrame); int getNumChannels(const UniqueAVCodecContext& avCodecContext);