From 84fd7a5081d8b26925f536b2fcd0b9c99c078a39 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Mon, 17 Nov 2025 18:25:03 -0800
Subject: [PATCH 1/3] Refactor order of getting metadata and adding a stream

---
 src/torchcodec/_core/FFMPEGCommon.cpp        | 10 +++++++
 src/torchcodec/_core/FFMPEGCommon.h          |  1 +
 src/torchcodec/_core/SingleStreamDecoder.cpp | 19 ++++++++++++-
 src/torchcodec/decoders/_audio_decoder.py    | 24 ++++++++++------
 src/torchcodec/decoders/_video_decoder.py    | 29 +++++++++++---------
 test/test_decoders.py                        |  4 +--
 test/test_metadata.py                        |  5 +---
 7 files changed, 64 insertions(+), 28 deletions(-)

diff --git a/src/torchcodec/_core/FFMPEGCommon.cpp b/src/torchcodec/_core/FFMPEGCommon.cpp
index 2256c9098..e1b88b36a 100644
--- a/src/torchcodec/_core/FFMPEGCommon.cpp
+++ b/src/torchcodec/_core/FFMPEGCommon.cpp
@@ -158,6 +158,16 @@ int getNumChannels(const SharedAVCodecContext& avCodecContext) {
 #endif
 }
 
+int getNumChannels(const AVCodecParameters* codecpar) {
+  TORCH_CHECK(codecpar != nullptr, "codecpar is null");
+#if LIBAVFILTER_VERSION_MAJOR > 8 || \
+    (LIBAVFILTER_VERSION_MAJOR == 8 && LIBAVFILTER_VERSION_MINOR >= 44)
+  return codecpar->ch_layout.nb_channels;
+#else
+  return codecpar->channels;
+#endif
+}
+
 void setDefaultChannelLayout(
     UniqueAVCodecContext& avCodecContext,
     int numChannels) {
diff --git a/src/torchcodec/_core/FFMPEGCommon.h b/src/torchcodec/_core/FFMPEGCommon.h
index 39994967d..5afa97db5 100644
--- a/src/torchcodec/_core/FFMPEGCommon.h
+++ b/src/torchcodec/_core/FFMPEGCommon.h
@@ -180,6 +180,7 @@ const AVPixelFormat* getSupportedPixelFormats(const AVCodec& avCodec);
 
 int getNumChannels(const UniqueAVFrame& avFrame);
 int getNumChannels(const SharedAVCodecContext& avCodecContext);
+int getNumChannels(const AVCodecParameters* codecpar);
 
 void setDefaultChannelLayout(
     UniqueAVCodecContext& avCodecContext,
diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index 32d6f9d99..6296fc156 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -110,8 +110,8 @@ void SingleStreamDecoder::initializeDecoder() {
         ", does not match AVStream's index, " +
             std::to_string(avStream->index) + ".");
     streamMetadata.streamIndex = i;
-    streamMetadata.mediaType = avStream->codecpar->codec_type;
     streamMetadata.codecName = avcodec_get_name(avStream->codecpar->codec_id);
+    streamMetadata.mediaType = avStream->codecpar->codec_type;
     streamMetadata.bitRate = avStream->codecpar->bit_rate;
 
     int64_t frameCount = avStream->nb_frames;
@@ -133,10 +133,18 @@ void SingleStreamDecoder::initializeDecoder() {
       if (fps > 0) {
         streamMetadata.averageFpsFromHeader = fps;
       }
+      streamMetadata.width = avStream->codecpar->width;
+      streamMetadata.height = avStream->codecpar->height;
+      streamMetadata.sampleAspectRatio =
+          avStream->codecpar->sample_aspect_ratio;
       containerMetadata_.numVideoStreams++;
     } else if (avStream->codecpar->codec_type == AVMEDIA_TYPE_AUDIO) {
       AVSampleFormat format =
          static_cast<AVSampleFormat>(avStream->codecpar->format);
+      streamMetadata.sampleRate =
+          static_cast<int64_t>(avStream->codecpar->sample_rate);
+      streamMetadata.numChannels =
+          static_cast<int64_t>(getNumChannels(avStream->codecpar));
 
       // If the AVSampleFormat is not recognized, we get back nullptr. We have
       // to make sure we don't initialize a std::string with nullptr. There's
@@ -516,6 +524,10 @@ void SingleStreamDecoder::addVideoStream(
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   streamInfo.videoStreamOptions = videoStreamOptions;
 
+  // This metadata was already set in initializeDecoder() from the
+  // AVCodecParameters that are part of the AVStream. But we consider the
+  // AVCodecContext to be more authoritative, so we use that for our decoding
+  // stream.
   streamMetadata.width = streamInfo.codecContext->width;
   streamMetadata.height = streamInfo.codecContext->height;
   streamMetadata.sampleAspectRatio =
       streamInfo.codecContext->sample_aspect_ratio;
@@ -568,6 +580,11 @@ void SingleStreamDecoder::addAudioStream(
 
   auto& streamMetadata =
       containerMetadata_.allStreamMetadata[activeStreamIndex_];
+
+  // This metadata was already set in initializeDecoder() from the
+  // AVCodecParameters that are part of the AVStream. But we consider the
+  // AVCodecContext to be more authoritative, so we use that for our decoding
+  // stream.
   streamMetadata.sampleRate =
       static_cast<int64_t>(streamInfo.codecContext->sample_rate);
   streamMetadata.numChannels =
diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py
index d1e42c196..9d9f13717 100644
--- a/src/torchcodec/decoders/_audio_decoder.py
+++ b/src/torchcodec/decoders/_audio_decoder.py
@@ -63,13 +63,6 @@ def __init__(
         torch._C._log_api_usage_once("torchcodec.decoders.AudioDecoder")
         self._decoder = create_decoder(source=source, seek_mode="approximate")
 
-        core.add_audio_stream(
-            self._decoder,
-            stream_index=stream_index,
-            sample_rate=sample_rate,
-            num_channels=num_channels,
-        )
-
         container_metadata = core.get_container_metadata(self._decoder)
         self.stream_index = (
             container_metadata.best_audio_stream_index
@@ -81,13 +74,28 @@ def __init__(
                 "The best audio stream is unknown and there is no specified stream. "
                 + ERROR_REPORTING_INSTRUCTIONS
            )
+        if self.stream_index >= len(container_metadata.streams):
+            raise ValueError(
+                f"The stream at index {stream_index} is not a valid stream."
+            )
+
         self.metadata = container_metadata.streams[self.stream_index]
-        assert isinstance(self.metadata, core.AudioStreamMetadata)  # mypy
+        if not isinstance(self.metadata, core._metadata.AudioStreamMetadata):
+            raise ValueError(
+                f"The stream at index {stream_index} is not an audio stream. "
+            )
         self._desired_sample_rate = (
             sample_rate if sample_rate is not None else self.metadata.sample_rate
         )
 
+        core.add_audio_stream(
+            self._decoder,
+            stream_index=stream_index,
+            sample_rate=sample_rate,
+            num_channels=num_channels,
+        )
+
     def get_all_samples(self) -> AudioSamples:
         """Returns all the audio samples from the source.
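The reordering above means AudioDecoder now resolves and validates the requested
stream against the container metadata before core.add_audio_stream() is called,
so a bad stream index fails with the ValueError messages introduced here rather
than a core-level error. A minimal usage sketch of that behavior, assuming an
arbitrary local media file at the placeholder path "video.mp4" (the path and
index are illustrative, not part of the patch):

    from torchcodec.decoders import AudioDecoder

    try:
        # An out-of-range index is rejected while validating container
        # metadata, before the audio stream is ever added to the core decoder.
        AudioDecoder("video.mp4", stream_index=40)
    except ValueError as e:
        print(e)  # e.g. "The stream at index 40 is not a valid stream."
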
diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 8659ab05b..1b4d4706d 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -141,6 +141,16 @@ def __init__(
 
         self._decoder = create_decoder(source=source, seek_mode=seek_mode)
 
+        (
+            self.metadata,
+            self.stream_index,
+            self._begin_stream_seconds,
+            self._end_stream_seconds,
+            self._num_frames,
+        ) = _get_and_validate_stream_metadata(
+            decoder=self._decoder, stream_index=stream_index
+        )
+
         allowed_dimension_orders = ("NCHW", "NHWC")
         if dimension_order not in allowed_dimension_orders:
             raise ValueError(
@@ -157,12 +167,11 @@ def __init__(
             device = str(device)
 
         device_variant = _get_cuda_backend()
-
         transform_specs = _make_transform_specs(transforms)
 
         core.add_video_stream(
             self._decoder,
-            stream_index=stream_index,
+            stream_index=self.stream_index,
             dimension_order=dimension_order,
             num_threads=num_ffmpeg_threads,
             device=device,
@@ -171,16 +180,6 @@ def __init__(
             custom_frame_mappings=custom_frame_mappings_data,
         )
 
-        (
-            self.metadata,
-            self.stream_index,
-            self._begin_stream_seconds,
-            self._end_stream_seconds,
-            self._num_frames,
-        ) = _get_and_validate_stream_metadata(
-            decoder=self._decoder, stream_index=stream_index
-        )
-
     def __len__(self) -> int:
         return self._num_frames
 
@@ -413,8 +412,12 @@ def _get_and_validate_stream_metadata(
             + ERROR_REPORTING_INSTRUCTIONS
         )
 
+    if stream_index >= len(container_metadata.streams):
+        raise ValueError(f"The stream index {stream_index} is not a valid stream.")
+
     metadata = container_metadata.streams[stream_index]
-    assert isinstance(metadata, core._metadata.VideoStreamMetadata)  # mypy
+    if not isinstance(metadata, core._metadata.VideoStreamMetadata):
+        raise ValueError(f"The stream at index {stream_index} is not a video stream. ")
 
     if metadata.begin_stream_seconds is None:
         raise ValueError(
diff --git a/test/test_decoders.py b/test/test_decoders.py
index 39e2d7b93..3cbdf70e3 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -116,11 +116,11 @@ def test_create_fails(self, Decoder):
             Decoder(123)
 
         # stream index that does not exist
-        with pytest.raises(ValueError, match="No valid stream found"):
+        with pytest.raises(ValueError, match="40 is not a valid stream"):
             Decoder(NASA_VIDEO.path, stream_index=40)
 
         # stream index that does exist, but it's not audio or video
-        with pytest.raises(ValueError, match="No valid stream found"):
+        with pytest.raises(ValueError, match=r"not (a|an) (video|audio) stream"):
             Decoder(NASA_VIDEO.path, stream_index=2)
 
         # user mistakenly forgets to specify binary reading when creating a file
diff --git a/test/test_metadata.py b/test/test_metadata.py
index a4f6da341..f0fba5d2a 100644
--- a/test/test_metadata.py
+++ b/test/test_metadata.py
@@ -59,7 +59,6 @@ def test_get_metadata(metadata_getter):
     )
     if (seek_mode == "custom_frame_mappings") and get_ffmpeg_major_version() in (4, 5):
         pytest.skip(reason="ffprobe isn't accurate on ffmpeg 4 and 5")
-    with_added_video_stream = seek_mode == "custom_frame_mappings"
 
     metadata = metadata_getter(NASA_VIDEO.path)
     with_scan = (
@@ -99,9 +98,7 @@ def test_get_metadata(metadata_getter):
     assert best_video_stream_metadata.begin_stream_seconds_from_header == 0
     assert best_video_stream_metadata.bit_rate == 128783
     assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001)
-    assert best_video_stream_metadata.pixel_aspect_ratio == (
-        Fraction(1, 1) if with_added_video_stream else None
-    )
+    assert best_video_stream_metadata.pixel_aspect_ratio == Fraction(1, 1)
     assert best_video_stream_metadata.codec == "h264"
     assert best_video_stream_metadata.num_frames_from_content == (
         390 if with_scan else None

From 8e0b75611f7d1235bddb11dd219524be0814810e Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Tue, 18 Nov 2025 06:46:30 -0800
Subject: [PATCH 2/3] Remove re-setting of metadata

---
 src/torchcodec/_core/SingleStreamDecoder.cpp | 21 --------------------
 1 file changed, 21 deletions(-)

diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp
index 6296fc156..55232bda7 100644
--- a/src/torchcodec/_core/SingleStreamDecoder.cpp
+++ b/src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -524,15 +524,6 @@ void SingleStreamDecoder::addVideoStream(
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   streamInfo.videoStreamOptions = videoStreamOptions;
 
-  // This metadata was already set in initializeDecoder() from the
-  // AVCodecParameters that are part of the AVStream. But we consider the
-  // AVCodecContext to be more authoritative, so we use that for our decoding
-  // stream.
-  streamMetadata.width = streamInfo.codecContext->width;
-  streamMetadata.height = streamInfo.codecContext->height;
-  streamMetadata.sampleAspectRatio =
-      streamInfo.codecContext->sample_aspect_ratio;
-
   if (seekMode_ == SeekMode::custom_frame_mappings) {
     TORCH_CHECK(
         customFrameMappings.has_value(),
@@ -568,18 +559,6 @@ void SingleStreamDecoder::addAudioStream(
   auto& streamInfo = streamInfos_[activeStreamIndex_];
   streamInfo.audioStreamOptions = audioStreamOptions;
 
-  auto& streamMetadata =
-      containerMetadata_.allStreamMetadata[activeStreamIndex_];
-
-  // This metadata was already set in initializeDecoder() from the
-  // AVCodecParameters that are part of the AVStream. But we consider the
-  // AVCodecContext to be more authoritative, so we use that for our decoding
-  // stream.
-  streamMetadata.sampleRate =
-      static_cast<int64_t>(streamInfo.codecContext->sample_rate);
-  streamMetadata.numChannels =
-      static_cast<int64_t>(getNumChannels(streamInfo.codecContext));
-
   // FFmpeg docs say that the decoder will try to decode natively in this
   // format, if it can. Docs don't say what the decoder does when it doesn't
   // support that format, but it looks like it does nothing, so this probably

From 9b14d17c22ecd5a2eab60be1b9ff91bb60845b6e Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Tue, 18 Nov 2025 20:38:57 -0800
Subject: [PATCH 3/3] Deal with custom frame mapping

---
 src/torchcodec/_core/Metadata.h     | 10 ++++-
 src/torchcodec/_core/custom_ops.cpp | 62 ++++++++++++++++++++---------
 2 files changed, 53 insertions(+), 19 deletions(-)

diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h
index e138d5dc0..22cc0cb48 100644
--- a/src/torchcodec/_core/Metadata.h
+++ b/src/torchcodec/_core/Metadata.h
@@ -23,9 +23,11 @@ enum class SeekMode { exact, approximate, custom_frame_mappings };
 struct StreamMetadata {
   // Common (video and audio) fields derived from the AVStream.
   int streamIndex;
+
   // See this link for what various values are available:
   // https://ffmpeg.org/doxygen/trunk/group__lavu__misc.html#ga9a84bba4713dfced21a1a56163be1f48
   AVMediaType mediaType;
+
   std::optional<AVCodecID> codecId;
   std::optional<std::string> codecName;
   std::optional<double> durationSecondsFromHeader;
@@ -39,13 +41,15 @@ struct StreamMetadata {
   // These presentation timestamps are in time base.
   std::optional<int64_t> beginStreamPtsFromContent;
   std::optional<int64_t> endStreamPtsFromContent;
+
   // These presentation timestamps are in seconds.
   std::optional<double> beginStreamPtsSecondsFromContent;
   std::optional<double> endStreamPtsSecondsFromContent;
+
   // This can be useful for index-based seeking.
   std::optional<int64_t> numFramesFromContent;
 
-  // Video-only fields derived from the AVCodecContext.
+  // Video-only fields
   std::optional<int64_t> width;
   std::optional<int64_t> height;
   std::optional<AVRational> sampleAspectRatio;
@@ -67,13 +71,17 @@ struct ContainerMetadata {
   std::vector<StreamMetadata> allStreamMetadata;
   int numAudioStreams = 0;
   int numVideoStreams = 0;
+
   // Note that this is the container-level duration, which is usually the max
   // of all stream durations available in the container.
   std::optional<double> durationSecondsFromHeader;
+
   // Total BitRate level information at the container level in bit/s
   std::optional<double> bitRate;
+
   // If set, this is the index to the default audio stream.
   std::optional<int> bestAudioStreamIndex;
+
   // If set, this is the index to the default video stream.
   std::optional<int> bestVideoStreamIndex;
 };
 
diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp
index 3836e52da..d514b2777 100644
--- a/src/torchcodec/_core/custom_ops.cpp
+++ b/src/torchcodec/_core/custom_ops.cpp
@@ -198,6 +198,34 @@ SeekMode seekModeFromString(std::string_view seekMode) {
   }
 }
 
+void writeFallbackBasedMetadata(
+    std::map<std::string, std::string>& map,
+    const StreamMetadata& streamMetadata,
+    SeekMode seekMode) {
+  auto durationSeconds = streamMetadata.getDurationSeconds(seekMode);
+  if (durationSeconds.has_value()) {
+    map["durationSeconds"] = std::to_string(durationSeconds.value());
+  }
+
+  auto numFrames = streamMetadata.getNumFrames(seekMode);
+  if (numFrames.has_value()) {
+    map["numFrames"] = std::to_string(numFrames.value());
+  }
+
+  double beginStreamSeconds = streamMetadata.getBeginStreamSeconds(seekMode);
+  map["beginStreamSeconds"] = std::to_string(beginStreamSeconds);
+
+  auto endStreamSeconds = streamMetadata.getEndStreamSeconds(seekMode);
+  if (endStreamSeconds.has_value()) {
+    map["endStreamSeconds"] = std::to_string(endStreamSeconds.value());
+  }
+
+  auto averageFps = streamMetadata.getAverageFps(seekMode);
+  if (averageFps.has_value()) {
+    map["averageFps"] = std::to_string(averageFps.value());
+  }
+}
+
 int checkedToPositiveInt(const std::string& str) {
   int ret = 0;
   try {
@@ -917,30 +945,28 @@ std::string get_stream_json_metadata(
   // In approximate mode: content-based metadata does not exist for any stream.
   // In custom_frame_mappings: content-based metadata exists only for the active
   // stream.
+
   // Our fallback logic assumes content-based metadata is available.
   // It is available for decoding on the active stream, but would break
   // when getting metadata from non-active streams.
   if ((seekMode != SeekMode::custom_frame_mappings) ||
       (seekMode == SeekMode::custom_frame_mappings &&
        stream_index == activeStreamIndex)) {
-    if (streamMetadata.getDurationSeconds(seekMode).has_value()) {
-      map["durationSeconds"] =
-          std::to_string(streamMetadata.getDurationSeconds(seekMode).value());
-    }
-    if (streamMetadata.getNumFrames(seekMode).has_value()) {
-      map["numFrames"] =
-          std::to_string(streamMetadata.getNumFrames(seekMode).value());
-    }
-    map["beginStreamSeconds"] =
-        std::to_string(streamMetadata.getBeginStreamSeconds(seekMode));
-    if (streamMetadata.getEndStreamSeconds(seekMode).has_value()) {
-      map["endStreamSeconds"] =
-          std::to_string(streamMetadata.getEndStreamSeconds(seekMode).value());
-    }
-    if (streamMetadata.getAverageFps(seekMode).has_value()) {
-      map["averageFps"] =
-          std::to_string(streamMetadata.getAverageFps(seekMode).value());
-    }
+    writeFallbackBasedMetadata(map, streamMetadata, seekMode);
+  } else if (seekMode == SeekMode::custom_frame_mappings) {
+    // If this is not the active stream, then we don't have content-based
+    // metadata for custom frame mappings. In that case, we want the same
+    // behavior as we would get with approximate mode. Encoding this behavior in
+    // the fallback logic itself is tricky and not worth it for this corner
+    // case. So we hardcode in approximate mode.
+    //
+    // TODO: This hacky behavior is only necessary because the custom frame
+    // mapping is supplied in SingleStreamDecoder::addVideoStream() rather
+    // than in the constructor. And it's supplied to addVideoStream() and
+    // not the constructor because we need to know the stream index. If we
+    // can encode the relevant stream indices into custom frame mappings
+    // itself, then we can put it in the constructor.
+    writeFallbackBasedMetadata(map, streamMetadata, SeekMode::approximate);
   }
 
   return mapToJson(map);
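
The new else-if branch means that, under custom_frame_mappings, non-active
streams are serialized exactly as they would be in approximate mode, since
their content-based metadata was never collected. A minimal sketch of that
decision as a standalone Python model (illustrative only; not the C++
implementation and not a torchcodec API):

    def effective_seek_mode(
        seek_mode: str, stream_index: int, active_stream_index: int
    ) -> str:
        # Content-based metadata only exists for the active stream when using
        # custom frame mappings, so every other stream falls back to the
        # approximate-mode behavior.
        if seek_mode != "custom_frame_mappings" or stream_index == active_stream_index:
            return seek_mode
        return "approximate"

    # With an active stream index of 3:
    assert effective_seek_mode("custom_frame_mappings", 3, 3) == "custom_frame_mappings"
    assert effective_seek_mode("custom_frame_mappings", 0, 3) == "approximate"
    assert effective_seek_mode("exact", 0, 3) == "exact"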