From fbcea468e7d8a52a343dca9900b6f21f73f921c9 Mon Sep 17 00:00:00 2001 From: Molly Xu Date: Wed, 1 Oct 2025 20:54:01 -0700 Subject: [PATCH] Added proper tensor support for get_frames_at() (#915) Summary: Modified get_frames_at in _video_decoder to accept tensors and updated all downstream functions to natively accept tensors rather than converting them to lists. Reviewed By: NicolasHug Differential Revision: D83506846 --- src/torchcodec/_core/SingleStreamDecoder.cpp | 30 ++++++++++++------- src/torchcodec/_core/SingleStreamDecoder.h | 2 +- src/torchcodec/_core/custom_ops.cpp | 8 ++--- src/torchcodec/_core/ops.py | 20 ++++++++++--- .../_samplers/video_clip_sampler.py | 2 ++ src/torchcodec/decoders/_video_decoder.py | 10 ++----- test/VideoDecoderTest.cpp | 6 ++-- test/test_decoders.py | 4 ++- test/test_ops.py | 2 +- 9 files changed, 53 insertions(+), 31 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 81e3e4474..052f01164 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -602,11 +602,18 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal( } FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( - const std::vector& frameIndices) { + const torch::Tensor& frameIndices) { validateActiveStream(AVMEDIA_TYPE_VIDEO); - auto indicesAreSorted = - std::is_sorted(frameIndices.begin(), frameIndices.end()); + auto frameIndicesAccessor = frameIndices.accessor(); + + bool indicesAreSorted = true; + for (int64_t i = 1; i < frameIndices.numel(); ++i) { + if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) { + indicesAreSorted = false; + break; + } + } std::vector argsort; if (!indicesAreSorted) { @@ -614,13 +621,15 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( // when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want // to use to decode the frames // and argsort is [ 1, 3, 2, 0] - argsort.resize(frameIndices.size()); + argsort.resize(frameIndices.numel()); for (size_t i = 0; i < argsort.size(); ++i) { argsort[i] = i; } std::sort( - argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) { - return frameIndices[a] < frameIndices[b]; + argsort.begin(), + argsort.end(), + [&frameIndicesAccessor](size_t a, size_t b) { + return frameIndicesAccessor[a] < frameIndicesAccessor[b]; }); } @@ -629,12 +638,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices( const auto& streamInfo = streamInfos_[activeStreamIndex_]; const auto& videoStreamOptions = streamInfo.videoStreamOptions; FrameBatchOutput frameBatchOutput( - frameIndices.size(), videoStreamOptions, streamMetadata); + frameIndices.numel(), videoStreamOptions, streamMetadata); auto previousIndexInVideo = -1; - for (size_t f = 0; f < frameIndices.size(); ++f) { + for (int64_t f = 0; f < frameIndices.numel(); ++f) { auto indexInOutput = indicesAreSorted ? f : argsort[f]; - auto indexInVideo = frameIndices[indexInOutput]; + auto indexInVideo = frameIndicesAccessor[indexInOutput]; if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice @@ -776,7 +785,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt( frameIndices[i] = secondsToIndexLowerBound(frameSeconds); } - return getFramesAtIndices(frameIndices); + // TODO: Support tensors natively instead of a vector to avoid a copy. + return getFramesAtIndices(torch::tensor(frameIndices)); } FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange( diff --git a/src/torchcodec/_core/SingleStreamDecoder.h b/src/torchcodec/_core/SingleStreamDecoder.h index 56bb8bb58..927796a57 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.h +++ b/src/torchcodec/_core/SingleStreamDecoder.h @@ -106,7 +106,7 @@ class SingleStreamDecoder { // Returns frames at the given indices for a given stream as a single stacked // Tensor. - FrameBatchOutput getFramesAtIndices(const std::vector& frameIndices); + FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices); // Returns frames within a given range. The range is defined by [start, stop). // The values retrieved from the range are: [start, start+step, diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index a865bdaed..d4e0fa705 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -55,7 +55,7 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.def( "get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)"); m.def( - "get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)"); + "get_frames_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor)"); m.def( "get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)"); m.def( @@ -378,11 +378,9 @@ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) { // Return the frames at given indices for a given stream OpsFrameBatchOutput get_frames_at_indices( at::Tensor& decoder, - at::IntArrayRef frame_indices) { + const at::Tensor& frame_indices) { auto videoDecoder = unwrapTensorToGetDecoder(decoder); - std::vector frameIndicesVec( - frame_indices.begin(), frame_indices.end()); - auto result = videoDecoder->getFramesAtIndices(frameIndicesVec); + auto result = videoDecoder->getFramesAtIndices(frame_indices); return makeOpsFrameBatchOutput(result); } diff --git a/src/torchcodec/_core/ops.py b/src/torchcodec/_core/ops.py index d618b8d9f..24b61c780 100644 --- a/src/torchcodec/_core/ops.py +++ b/src/torchcodec/_core/ops.py @@ -114,7 +114,9 @@ def load_torchcodec_shared_libraries(): get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default -get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default +_get_frames_at_indices_tensor_input = ( + torch.ops.torchcodec_ns.get_frames_at_indices.default +) get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default @@ -198,6 +200,18 @@ def encode_audio_to_file_like( ) +def get_frames_at_indices( + decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]] +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + if isinstance(frame_indices, torch.Tensor): + # Ensure indices is the correct dtype (int64) + frame_indices = frame_indices.to(torch.int64) + else: + # Convert list to tensor for dispatch + frame_indices = torch.tensor(frame_indices) + return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices) + + # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== @@ -371,9 +385,7 @@ def get_frame_at_index_abstract( @register_fake("torchcodec_ns::get_frames_at_indices") def get_frames_at_indices_abstract( - decoder: torch.Tensor, - *, - frame_indices: List[int], + decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, List[int]] ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] return ( diff --git a/src/torchcodec/_samplers/video_clip_sampler.py b/src/torchcodec/_samplers/video_clip_sampler.py index f05f61650..06976ba78 100644 --- a/src/torchcodec/_samplers/video_clip_sampler.py +++ b/src/torchcodec/_samplers/video_clip_sampler.py @@ -227,6 +227,8 @@ def _get_clips_for_index_based_sampling( clip_start_idx + i * index_based_sampler_args.video_frame_dilation for i in range(index_based_sampler_args.frames_per_clip) ] + # Need torch.stack to convert List[Tensor[int]] into 1D Tensor[int] + batch_indexes = torch.stack(batch_indexes) frames, *_ = get_frames_at_indices( video_decoder, frame_indices=batch_indexes, diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 05c391766..8b067c3d4 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -240,24 +240,20 @@ def get_frame_at(self, index: int) -> Frame: duration_seconds=duration_seconds.item(), ) - def get_frames_at(self, indices: list[int]) -> FrameBatch: + def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch: """Return frames at the given indices. Args: - indices (list of int): The indices of the frames to retrieve. + indices (torch.Tensor or list of int): The indices of the frames to retrieve. Returns: FrameBatch: The frames at the given indices. """ - if isinstance(indices, torch.Tensor): - # TODO we should avoid converting tensors to lists and just let the - # core ops and C++ code natively accept tensors. See - # https://github.com/pytorch/torchcodec/issues/879 - indices = indices.to(torch.int).tolist() data, pts_seconds, duration_seconds = core.get_frames_at_indices( self._decoder, frame_indices=indices ) + return FrameBatch( data=data, pts_seconds=pts_seconds, diff --git a/test/VideoDecoderTest.cpp b/test/VideoDecoderTest.cpp index 0c21f0d46..241a638b4 100644 --- a/test/VideoDecoderTest.cpp +++ b/test/VideoDecoderTest.cpp @@ -222,7 +222,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) { *ourDecoder->getContainerMetadata().bestVideoStreamIndex; ourDecoder->addVideoStream(bestVideoStreamIndex); // Frame with index 180 corresponds to timestamp 6.006. - auto output = ourDecoder->getFramesAtIndices({0, 180}); + auto frameIndices = torch::tensor({0, 180}); + auto output = ourDecoder->getFramesAtIndices(frameIndices); auto tensor = output.data; EXPECT_EQ(tensor.sizes(), std::vector({2, 3, 270, 480})); @@ -246,7 +247,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) { videoStreamOptions.dimensionOrder = "NHWC"; ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions); // Frame with index 180 corresponds to timestamp 6.006. - auto output = ourDecoder->getFramesAtIndices({0, 180}); + auto frameIndices = torch::tensor({0, 180}); + auto output = ourDecoder->getFramesAtIndices(frameIndices); auto tensor = output.data; EXPECT_EQ(tensor.sizes(), std::vector({2, 270, 480, 3})); diff --git a/test/test_decoders.py b/test/test_decoders.py index 5f128e3e0..ad66f45d5 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -569,7 +569,9 @@ def test_get_frames_at_fails(self, device, seek_mode): with pytest.raises(IndexError, match="Invalid frame index=390"): decoder.get_frames_at([390]) - with pytest.raises(RuntimeError, match="Expected a value of type"): + with pytest.raises( + RuntimeError, match="expected scalar type Long but found Float" + ): decoder.get_frames_at([0.3]) @pytest.mark.parametrize("device", all_supported_devices()) diff --git a/test/test_ops.py b/test/test_ops.py index 01bad1ae0..715687afe 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1209,7 +1209,7 @@ def seek(self, offset: int, whence: int) -> int: torch.manual_seed(0) indices = torch.randint( 0, len(NASA_VIDEO.frames[NASA_VIDEO.default_stream_index]), size=(50,) - ).tolist() + ) frames_file_like, *_ = get_frames_at_indices( decoder_file_like, frame_indices=indices