diff --git a/src/torchcodec/_core/Metadata.cpp b/src/torchcodec/_core/Metadata.cpp index 58a115dcf..c717815df 100644 --- a/src/torchcodec/_core/Metadata.cpp +++ b/src/torchcodec/_core/Metadata.cpp @@ -29,6 +29,9 @@ std::optional StreamMetadata::getDurationSeconds( return static_cast(numFramesFromHeader.value()) / averageFpsFromHeader.value(); } + if (durationSecondsFromContainer.has_value()) { + return durationSecondsFromContainer.value(); + } return std::nullopt; default: TORCH_CHECK(false, "Unknown SeekMode"); @@ -80,13 +83,13 @@ std::optional StreamMetadata::getNumFrames(SeekMode seekMode) const { numFramesFromContent.has_value(), "Missing numFramesFromContent"); return numFramesFromContent.value(); case SeekMode::approximate: { + auto durationSeconds = getDurationSeconds(seekMode); if (numFramesFromHeader.has_value()) { return numFramesFromHeader.value(); } - if (averageFpsFromHeader.has_value() && - durationSecondsFromHeader.has_value()) { + if (averageFpsFromHeader.has_value() && durationSeconds.has_value()) { return static_cast( - averageFpsFromHeader.value() * durationSecondsFromHeader.value()); + averageFpsFromHeader.value() * durationSeconds.value()); } return std::nullopt; } diff --git a/src/torchcodec/_core/Metadata.h b/src/torchcodec/_core/Metadata.h index e138d5dc0..c3289868d 100644 --- a/src/torchcodec/_core/Metadata.h +++ b/src/torchcodec/_core/Metadata.h @@ -35,6 +35,9 @@ struct StreamMetadata { std::optional averageFpsFromHeader; std::optional bitRate; + // Used as fallback in approximate mode when stream duration is unavailable. + std::optional durationSecondsFromContainer; + // More accurate duration, obtained by scanning the file. // These presentation timestamps are in time base. std::optional beginStreamPtsFromContent; diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 0b30f1fff..22aca7bcd 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -100,6 +100,26 @@ void SingleStreamDecoder::initializeDecoder() { "Failed to find stream info: ", getFFMPEGErrorStringFromErrorCode(status)); + if (formatContext_->duration > 0) { + AVRational defaultTimeBase{1, AV_TIME_BASE}; + containerMetadata_.durationSecondsFromHeader = + ptsToSeconds(formatContext_->duration, defaultTimeBase); + } + + if (formatContext_->bit_rate > 0) { + containerMetadata_.bitRate = formatContext_->bit_rate; + } + + int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO); + if (bestVideoStream >= 0) { + containerMetadata_.bestVideoStreamIndex = bestVideoStream; + } + + int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO); + if (bestAudioStream >= 0) { + containerMetadata_.bestAudioStreamIndex = bestAudioStream; + } + for (unsigned int i = 0; i < formatContext_->nb_streams; i++) { AVStream* avStream = formatContext_->streams[i]; StreamMetadata streamMetadata; @@ -149,27 +169,10 @@ void SingleStreamDecoder::initializeDecoder() { containerMetadata_.numAudioStreams++; } - containerMetadata_.allStreamMetadata.push_back(streamMetadata); - } - - if (formatContext_->duration > 0) { - AVRational defaultTimeBase{1, AV_TIME_BASE}; - containerMetadata_.durationSecondsFromHeader = - ptsToSeconds(formatContext_->duration, defaultTimeBase); - } - - if (formatContext_->bit_rate > 0) { - containerMetadata_.bitRate = formatContext_->bit_rate; - } - - int bestVideoStream = getBestStreamIndex(AVMEDIA_TYPE_VIDEO); - if (bestVideoStream >= 0) { - containerMetadata_.bestVideoStreamIndex = bestVideoStream; - } + streamMetadata.durationSecondsFromContainer = + containerMetadata_.durationSecondsFromHeader; - int bestAudioStream = getBestStreamIndex(AVMEDIA_TYPE_AUDIO); - if (bestAudioStream >= 0) { - containerMetadata_.bestAudioStreamIndex = bestAudioStream; + containerMetadata_.allStreamMetadata.push_back(streamMetadata); } if (seekMode_ == SeekMode::exact) { diff --git a/src/torchcodec/_core/_metadata.py b/src/torchcodec/_core/_metadata.py index 08bcf2b55..1d5a7d103 100644 --- a/src/torchcodec/_core/_metadata.py +++ b/src/torchcodec/_core/_metadata.py @@ -44,7 +44,8 @@ class StreamMetadata: from the actual frames if a :term:`scan` was performed. Otherwise we fall back to ``duration_seconds_from_header``. If that value is also None, we instead calculate the duration from ``num_frames_from_header`` and - ``average_fps_from_header``. + ``average_fps_from_header``. If all of those are unavailable, we fall back + to the container-level ``duration_seconds_from_header``. """ begin_stream_seconds: Optional[float] """Beginning of the stream, in seconds (float). Conceptually, this diff --git a/test/test_decoders.py b/test/test_decoders.py index 39e2d7b93..9dba22d63 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -44,7 +44,6 @@ SINE_MONO_S32, SINE_MONO_S32_44100, SINE_MONO_S32_8000, - supports_approximate_mode, TEST_SRC_2_720P, TEST_SRC_2_720P_H265, TEST_SRC_2_720P_MPEG4, @@ -1465,8 +1464,6 @@ def test_get_frames_at_tensor_indices(self): def test_beta_cuda_interface_get_frame_at( self, asset, contiguous_indices, seek_mode ): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") if in_fbcode() and asset is AV1_VIDEO: pytest.skip("AV1 CUDA not supported internally") @@ -1513,8 +1510,6 @@ def test_beta_cuda_interface_get_frame_at( def test_beta_cuda_interface_get_frames_at( self, asset, contiguous_indices, seek_mode ): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") if in_fbcode() and asset is AV1_VIDEO: pytest.skip("AV1 CUDA not supported internally") @@ -1558,8 +1553,6 @@ def test_beta_cuda_interface_get_frames_at( ) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") if in_fbcode() and asset is AV1_VIDEO: pytest.skip("AV1 CUDA not supported internally") @@ -1600,8 +1593,6 @@ def test_beta_cuda_interface_get_frame_played_at(self, asset, seek_mode): ) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") if in_fbcode() and asset is AV1_VIDEO: pytest.skip("AV1 CUDA not supported internally") @@ -1643,8 +1634,6 @@ def test_beta_cuda_interface_get_frames_played_at(self, asset, seek_mode): ) @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) def test_beta_cuda_interface_backwards(self, asset, seek_mode): - if seek_mode == "approximate" and not supports_approximate_mode(asset): - pytest.skip("asset doesn't work with approximate mode") if in_fbcode() and asset is AV1_VIDEO: pytest.skip("AV1 CUDA not supported internally")