From 9f3a8043d201e547b37bc155ae5872d029637fdb Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 16 Dec 2024 08:07:36 -0800 Subject: [PATCH] Refactor VideoDecoder C++ initialization --- .../decoders/_core/VideoDecoder.cpp | 41 ++++++++++++------- src/torchcodec/decoders/_core/VideoDecoder.h | 2 + 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 4c98980c5..ab3d3f9d1 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -234,36 +234,41 @@ VideoDecoder::VideoDecoder(const void* buffer, size_t length) { } void VideoDecoder::initializeDecoder() { - // Some formats don't store enough info in the header so we read/decode a few - // frames to grab that. This is needed for the filter graph. Note: If this - // takes a long time, consider initializing the filter graph after the first - // frame decode. + TORCH_CHECK(!initialized_, "Attempted double initialization."); + int ffmpegStatus = avformat_find_stream_info(formatContext_.get(), nullptr); if (ffmpegStatus < 0) { throw std::runtime_error( "Failed to find stream info: " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } - containerMetadata_.streams.resize(0); + for (int i = 0; i < formatContext_->nb_streams; i++) { AVStream* stream = formatContext_->streams[i]; - containerMetadata_.streams.resize(containerMetadata_.streams.size() + 1); - auto& curr = containerMetadata_.streams.back(); - curr.streamIndex = i; - curr.mediaType = stream->codecpar->codec_type; - curr.codecName = avcodec_get_name(stream->codecpar->codec_id); - curr.bitRate = stream->codecpar->bit_rate; + StreamMetadata meta; + + TORCH_CHECK( + i == stream->index, + "Our stream index, " + std::to_string(i) + + ", does not match AVStream's index, " + + std::to_string(stream->index) + "."); + meta.streamIndex = i; + meta.mediaType = stream->codecpar->codec_type; + meta.codecName = avcodec_get_name(stream->codecpar->codec_id); + meta.bitRate = stream->codecpar->bit_rate; int64_t frameCount = stream->nb_frames; if (frameCount > 0) { - curr.numFrames = frameCount; + meta.numFrames = frameCount; } + if (stream->duration > 0 && stream->time_base.den > 0) { - curr.durationSeconds = av_q2d(stream->time_base) * stream->duration; + meta.durationSeconds = av_q2d(stream->time_base) * stream->duration; } + double fps = av_q2d(stream->r_frame_rate); if (fps > 0) { - curr.averageFps = fps; + meta.averageFps = fps; } if (stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { @@ -271,22 +276,30 @@ void VideoDecoder::initializeDecoder() { } else if (stream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) { containerMetadata_.numAudioStreams++; } + + containerMetadata_.streams.push_back(meta); } + if (formatContext_->duration > 0) { containerMetadata_.durationSeconds = ptsToSeconds(formatContext_->duration, AV_TIME_BASE); } + if (formatContext_->bit_rate > 0) { containerMetadata_.bitRate = formatContext_->bit_rate; } + int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO); if (bestVideoStream >= 0) { containerMetadata_.bestVideoStreamIndex = bestVideoStream; } + int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO); if (bestAudioStream >= 0) { containerMetadata_.bestAudioStreamIndex = bestAudioStream; } + + initialized_ = true; } std::unique_ptr VideoDecoder::createFromFilePath( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index fcd1b17ca..6ad2ab5e0 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -420,6 +420,8 @@ class VideoDecoder { std::unique_ptr ioBytesContext_; // Whether or not we have already scanned all streams to update the metadata. bool scanned_all_streams_ = false; + // Tracks that we've already been initialized. + bool initialized_ = false; }; // --------------------------------------------------------------------------