From fec8a702a437fdf2854f25266c38ab1fff0a0277 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 Mar 2025 14:32:01 +0000 Subject: [PATCH 1/4] WELL THIS WORKS --- .../decoders/_core/VideoDecoder.cpp | 28 +++++++------- test/decoders/test_ops.py | 37 +++++++++---------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c0738e570..6a70c8aaf 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -850,7 +850,7 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( startSeconds <= stopSeconds, "Start seconds (" + std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + - std::to_string(stopSeconds) + "."); + std::to_string(stopSeconds) + ")."); if (startSeconds == stopSeconds) { // For consistency with video @@ -859,29 +859,29 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - // TODO-AUDIO This essentially enforce that we don't need to seek (backwards). - // We should remove it and seek back to the stream's beginning when needed. - // See test_multiple_calls - TORCH_CHECK( - streamInfo.lastDecodedAvFramePts + - streamInfo.lastDecodedAvFrameDuration <= - secondsToClosestPts(startSeconds, streamInfo.timeBase), - "Audio decoder cannot seek backwards, or start from the last decoded frame."); + // TORCH_CHECK( + // streamInfo.lastDecodedAvFramePts + + // streamInfo.lastDecodedAvFrameDuration <= + // secondsToClosestPts(startSeconds, streamInfo.timeBase), + // "Audio decoder cannot seek backwards, or start from the last decoded + // frame."); - setCursorPtsInSeconds(startSeconds); + setCursorPtsInSeconds(INT64_MIN); // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + // cat(). This would save a copy. 
We know the duration of the output and the // sample rate, so in theory we know the number of output samples. std::vector tensors; + auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase); auto finished = false; while (!finished) { try { - AVFrameStream avFrameStream = decodeAVFrame([this](AVFrame* avFrame) { - return cursor_ < avFrame->pts + getDuration(avFrame); - }); + AVFrameStream avFrameStream = + decodeAVFrame([this, startPts](AVFrame* avFrame) { + return startPts < avFrame->pts + getDuration(avFrame); + }); auto frameOutput = convertAVFrameToFrameOutput(avFrameStream); tensors.push_back(frameOutput.data); } catch (const EndOfFileException& e) { @@ -938,7 +938,7 @@ I P P P I P P P I P P I P P I P bool VideoDecoder::canWeAvoidSeeking() const { const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - return true; + return false; } int64_t lastDecodedAvFramePts = streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index e33b9941d..a6c0cce22 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -741,11 +741,9 @@ def test_decode_start_equal_stop(self, asset): @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) def test_multiple_calls(self, asset): - # Ensure that multiple calls are OK as long as we're decoding - # "sequentially", i.e. we don't require a backwards seek. - # And ensure a proper error is raised in such case. - # TODO-AUDIO We shouldn't error, we should just implement the seeking - # back to the beginning of the stream. + # Ensure that multiple calls to get_frames_by_pts_in_range_audio on the + # same decoder are supported, whether it involves forward seeks or + # backwards seeks. 
def get_reference_frames(start_seconds, stop_seconds): # This stateless helper exists for convenience, to avoid @@ -794,23 +792,22 @@ def get_reference_frames(start_seconds, stop_seconds): frames, get_reference_frames(start_seconds, stop_seconds) ) - # but starting immediately on the same frame raises - expected_match = "Audio decoder cannot seek backwards" - with pytest.raises(RuntimeError, match=expected_match): - get_frames_by_pts_in_range_audio( - decoder, start_seconds=stop_seconds, stop_seconds=6 - ) + # starting immediately on the same frame is OK + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=stop_seconds, stop_seconds=6 + ) + torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6)) - with pytest.raises(RuntimeError, match=expected_match): - get_frames_by_pts_in_range_audio( - decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6 - ) + get_frames_by_pts_in_range_audio( + decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6 + ) + torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6)) - # and seeking backwards doesn't work either - with pytest.raises(RuntimeError, match=expected_match): - frames = get_frames_by_pts_in_range_audio( - decoder, start_seconds=0, stop_seconds=2 - ) + # seeking backwards + frames = get_frames_by_pts_in_range_audio( + decoder, start_seconds=0, stop_seconds=2 + ) + torch.testing.assert_close(frames, get_reference_frames(0, 2)) if __name__ == "__main__": From 39e74145b5cdedda70029d856c704f041ec01cb7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 Mar 2025 14:42:05 +0000 Subject: [PATCH 2/4] Enable backwards seeks --- .../decoders/_core/VideoDecoder.cpp | 21 ++++++++++--------- test/decoders/test_ops.py | 20 ++++++++++++------ 2 files changed, 25 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 6a70c8aaf..e8843f4d8 100644 --- 
a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -859,21 +859,20 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( StreamInfo& streamInfo = streamInfos_[activeStreamIndex_]; - // TORCH_CHECK( - // streamInfo.lastDecodedAvFramePts + - // streamInfo.lastDecodedAvFrameDuration <= - // secondsToClosestPts(startSeconds, streamInfo.timeBase), - // "Audio decoder cannot seek backwards, or start from the last decoded - // frame."); - - setCursorPtsInSeconds(INT64_MIN); + auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); + if (startPts < streamInfo.lastDecodedAvFramePts + + streamInfo.lastDecodedAvFrameDuration) { + // If we need to seek backwards, then we have to seek back to the beginning + // of the stream. + // TODO-AUDIO: document why this is needed in a big comment. + setCursorPtsInSeconds(INT64_MIN); + } // TODO-AUDIO Pre-allocate a long-enough tensor instead of creating a vec + // cat(). This would save a copy. We know the duration of the output and the // sample rate, so in theory we know the number of output samples. std::vector tensors; - auto startPts = secondsToClosestPts(startSeconds, streamInfo.timeBase); auto stopPts = secondsToClosestPts(stopSeconds, streamInfo.timeBase); auto finished = false; while (!finished) { @@ -938,7 +937,9 @@ I P P P I P P P I P P I P P I P bool VideoDecoder::canWeAvoidSeeking() const { const StreamInfo& streamInfo = streamInfos_.at(activeStreamIndex_); if (streamInfo.avMediaType == AVMEDIA_TYPE_AUDIO) { - return false; + // For audio, we only need to seek if a backwards seek was requested within + // getFramesPlayedInRangeAudio(), when setCursorPtsInSeconds() was called. 
+ return !cursorWasJustSet_; } int64_t lastDecodedAvFramePts = streamInfos_.at(activeStreamIndex_).lastDecodedAvFramePts; diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index a6c0cce22..1407d9471 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -793,21 +793,29 @@ def get_reference_frames(start_seconds, stop_seconds): ) # starting immediately on the same frame is OK + start_seconds, stop_seconds = stop_seconds, 6 frames = get_frames_by_pts_in_range_audio( - decoder, start_seconds=stop_seconds, stop_seconds=6 + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) ) - torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6)) get_frames_by_pts_in_range_audio( - decoder, start_seconds=stop_seconds + 1e-4, stop_seconds=6 + decoder, start_seconds=start_seconds + 1e-4, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) ) - torch.testing.assert_close(frames, get_reference_frames(stop_seconds, 6)) # seeking backwards + start_seconds, stop_seconds = 0, 2 frames = get_frames_by_pts_in_range_audio( - decoder, start_seconds=0, stop_seconds=2 + decoder, start_seconds=start_seconds, stop_seconds=stop_seconds + ) + torch.testing.assert_close( + frames, get_reference_frames(start_seconds, stop_seconds) ) - torch.testing.assert_close(frames, get_reference_frames(0, 2)) if __name__ == "__main__": From f512912f3e03483a06eacd592ad73a1b0b8bab02 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 Mar 2025 14:46:28 +0000 Subject: [PATCH 3/4] Comment --- test/decoders/test_ops.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/test/decoders/test_ops.py b/test/decoders/test_ops.py index 1407d9471..ce74243f7 100644 --- a/test/decoders/test_ops.py +++ b/test/decoders/test_ops.py @@ -742,16 +742,17 @@ def 
test_decode_start_equal_stop(self, asset): @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) def test_multiple_calls(self, asset): # Ensure that multiple calls to get_frames_by_pts_in_range_audio on the - # same decoder are supported, whether it involves forward seeks or - # backwards seeks. + # same decoder are supported and correct, whether it involves forward + # seeks or backwards seeks. def get_reference_frames(start_seconds, stop_seconds): - # This stateless helper exists for convenience, to avoid - # complicating this test with pts-to-index conversions. Eventually - # we should remove it and just rely on the asset's methods. - # Using this helper is OK for now: we're comparing a decoder which - # seeks multiple times with a decoder which seeks only once (the one - # here, treated as the reference) + # Usually we get the reference frames from the asset's methods, but + # for this specific test, this helper is more convenient, because + # relying on the asset would force us to convert all timestamps into + # indices. 
+ # Ultimately, this test compares a "stateful decoder" which calls + # `get_frames_by_pts_in_range_audio()` multiple times with a + # "stateless decoder" (the one here, treated as the reference) decoder = create_from_file(str(asset.path), seek_mode="approximate") add_audio_stream(decoder) From 31d025bdf8791617fb7da0c322c70942ca404c21 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 12 Mar 2025 15:01:59 +0000 Subject: [PATCH 4/4] Fix compilation --- src/torchcodec/decoders/_core/VideoDecoder.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index e8843f4d8..aeefa3b52 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -877,10 +877,9 @@ torch::Tensor VideoDecoder::getFramesPlayedInRangeAudio( auto finished = false; while (!finished) { try { - AVFrameStream avFrameStream = - decodeAVFrame([this, startPts](AVFrame* avFrame) { - return startPts < avFrame->pts + getDuration(avFrame); - }); + AVFrameStream avFrameStream = decodeAVFrame([startPts](AVFrame* avFrame) { + return startPts < avFrame->pts + getDuration(avFrame); + }); auto frameOutput = convertAVFrameToFrameOutput(avFrameStream); tensors.push_back(frameOutput.data); } catch (const EndOfFileException& e) {