From 04b2ead8f62c0c74a13e2fc0c72c4238d900fda0 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 30 May 2025 19:02:51 -0700 Subject: [PATCH 1/2] Allow num_frames and duration to be absent in C++ decoder --- src/torchcodec/_core/SingleStreamDecoder.cpp | 89 +++++++++++++------- src/torchcodec/_core/SingleStreamDecoder.h | 4 +- test/test_decoders.py | 4 +- 3 files changed, 63 insertions(+), 34 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 9bc003a9b..0ae87aa05 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -602,16 +602,22 @@ FrameBatchOutput SingleStreamDecoder::getFramesInRange( const auto& streamMetadata = containerMetadata_.allStreamMetadata[activeStreamIndex_]; const auto& streamInfo = streamInfos_[activeStreamIndex_]; - int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); - TORCH_CHECK( - stop <= numFrames, - "Range stop, " + std::to_string(stop) + - ", is more than the number of frames, " + std::to_string(numFrames)); TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); + // Note that if we do not have the number of frames available in our metadata, + // then we assume that the upper part of the range is valid. 
+ std::optional<int64_t> numFrames = getNumFrames(streamMetadata); + if (numFrames.has_value()) { + TORCH_CHECK( + stop <= numFrames.value(), + "Range stop, " + std::to_string(stop) + + ", is more than the number of frames, " + + std::to_string(numFrames.value())); + } + int64_t numOutputFrames = std::ceil((stop - start) / double(step)); const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( @@ -676,7 +682,7 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( containerMetadata_.allStreamMetadata[activeStreamIndex_]; double minSeconds = getMinSeconds(streamMetadata); - double maxSeconds = getMaxSeconds(streamMetadata); + std::optional<double> maxSeconds = getMaxSeconds(streamMetadata); // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to @@ -687,10 +693,20 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( for (size_t i = 0; i < timestamps.size(); ++i) { auto frameSeconds = timestamps[i]; TORCH_CHECK( - frameSeconds >= minSeconds && frameSeconds < maxSeconds, + frameSeconds >= minSeconds, "frame pts is " + std::to_string(frameSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); + "; must be greater than or equal to " + std::to_string(minSeconds) + + "."); + + // Note that if we can't determine the maximum number of seconds from the + // metadata, then we assume the frame's pts is valid. 
+ if (maxSeconds.has_value()) { + TORCH_CHECK( + frameSeconds < maxSeconds.value(), + "frame pts is " + std::to_string(frameSeconds) + + "; must be less than " + std::to_string(maxSeconds.value()) + + "."); + } frameIndices[i] = secondsToIndexLowerBound(frameSeconds); } @@ -737,17 +753,26 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( } double minSeconds = getMinSeconds(streamMetadata); - double maxSeconds = getMaxSeconds(streamMetadata); TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, + startSeconds >= minSeconds, "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); + "; must be greater than or equal to " + std::to_string(minSeconds) + + "."); + + // Note that if we can't determine the maximum seconds from the metadata, then + // we assume upper range is valid. + std::optional<double> maxSeconds = getMaxSeconds(streamMetadata); + if (maxSeconds.has_value()) { + TORCH_CHECK( + startSeconds < maxSeconds.value(), + "Start seconds is " + std::to_string(startSeconds) + + "; must be less than " + std::to_string(maxSeconds.value()) + "."); + TORCH_CHECK( + stopSeconds <= maxSeconds.value(), + "Stop seconds (" + std::to_string(stopSeconds) + + "; must be less than or equal to " + + std::to_string(maxSeconds.value()) + ")."); + } // Note that we look at nextPts for a frame, and not its pts or duration. 
// Our abstract player displays frames starting at the pts for that frame @@ -1456,15 +1481,12 @@ int64_t SingleStreamDecoder::getPts(int64_t frameIndex) { // STREAM AND METADATA APIS // -------------------------------------------------------------------------- -int64_t SingleStreamDecoder::getNumFrames( +std::optional<int64_t> SingleStreamDecoder::getNumFrames( const StreamMetadata& streamMetadata) { switch (seekMode_) { case SeekMode::exact: return streamMetadata.numFramesFromScan.value(); case SeekMode::approximate: { - TORCH_CHECK( - streamMetadata.numFrames.has_value(), - "Cannot use approximate mode since we couldn't find the number of frames from the metadata."); return streamMetadata.numFrames.value(); } default: @@ -1484,16 +1506,13 @@ double SingleStreamDecoder::getMinSeconds( } } -double SingleStreamDecoder::getMaxSeconds( +std::optional<double> SingleStreamDecoder::getMaxSeconds( const StreamMetadata& streamMetadata) { switch (seekMode_) { case SeekMode::exact: return streamMetadata.maxPtsSecondsFromScan.value(); case SeekMode::approximate: { - TORCH_CHECK( - streamMetadata.durationSeconds.has_value(), - "Cannot use approximate mode since we couldn't find the duration from the metadata."); - return streamMetadata.durationSeconds.value(); + return streamMetadata.durationSeconds; } default: throw std::runtime_error("Unknown SeekMode"); @@ -1539,12 +1558,22 @@ void SingleStreamDecoder::validateScannedAllStreams(const std::string& msg) { void SingleStreamDecoder::validateFrameIndex( const StreamMetadata& streamMetadata, int64_t frameIndex) { - int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( - frameIndex >= 0 && frameIndex < numFrames, + frameIndex >= 0, "Invalid frame index=" + std::to_string(frameIndex) + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + - " numFrames=" + std::to_string(numFrames)); + "; must be greater than or equal to 0"); + + // Note that if we do not have the number of frames available in our metadata, + // then we 
assume that the frameIndex is valid. + std::optional<int64_t> numFrames = getNumFrames(streamMetadata); + if (numFrames.has_value()) { + TORCH_CHECK( + frameIndex < numFrames.value(), + "Invalid frame index=" + std::to_string(frameIndex) + + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + + "; must be less than " + std::to_string(numFrames.value())); + } } // -------------------------------------------------------------------------- diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index cbacb8477..cf46494c7 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -304,9 +304,9 @@ class SingleStreamDecoder { // index. Note that this index may be truncated for some files. int getBestStreamIndex(AVMediaType mediaType); - int64_t getNumFrames(const StreamMetadata& streamMetadata); + std::optional<int64_t> getNumFrames(const StreamMetadata& streamMetadata); double getMinSeconds(const StreamMetadata& streamMetadata); - double getMaxSeconds(const StreamMetadata& streamMetadata); + std::optional<double> getMaxSeconds(const StreamMetadata& streamMetadata); // -------------------------------------------------------------------------- // VALIDATION UTILS diff --git a/test/test_decoders.py b/test/test_decoders.py index 133b3cb65..988abf947 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -597,10 +597,10 @@ def test_get_frames_played_at(self, device, seek_mode): def test_get_frames_played_at_fails(self, device, seek_mode): decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) - with pytest.raises(RuntimeError, match="must be in range"): + with pytest.raises(RuntimeError, match="must be greater than or equal to"): decoder.get_frames_played_at([-1]) - with pytest.raises(RuntimeError, match="must be in range"): + with pytest.raises(RuntimeError, match="must be less than"): decoder.get_frames_played_at([14]) with pytest.raises(RuntimeError, 
match="Expected a value of type"): From 7a6e6ca5d6e4fa5c12325c028a5c029897834c40 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 2 Jun 2025 07:05:09 -0700 Subject: [PATCH 2/2] getNumFrames() should return optional in approximate mode --- src/torchcodec/_core/SingleStreamDecoder.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 0ae87aa05..f4a285ec6 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -1487,7 +1487,7 @@ std::optional<int64_t> SingleStreamDecoder::getNumFrames( case SeekMode::exact: return streamMetadata.numFramesFromScan.value(); case SeekMode::approximate: { - return streamMetadata.numFrames.value(); + return streamMetadata.numFrames; } default: throw std::runtime_error("Unknown SeekMode");