From b5ca81445efd3a577edd56040d275bb9f0f7ec4c Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 8 Sep 2025 14:01:46 +0100 Subject: [PATCH] Let `get_frames_at` and `get_frames_played_at` accept tensor indices (#880) --- src/torchcodec/decoders/_video_decoder.py | 12 ++++++++++++ test/test_decoders.py | 11 +++++++++++ 2 files changed, 23 insertions(+) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index e0d771685..7cda71ede 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -226,6 +226,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch: Returns: FrameBatch: The frames at the given indices. """ + if isinstance(indices, torch.Tensor): + # TODO we should avoid converting tensors to lists and just let the + # core ops and C++ code natively accept tensors. See + # https://github.com/pytorch/torchcodec/issues/879 + indices = indices.to(torch.int).tolist() + data, pts_seconds, duration_seconds = core.get_frames_at_indices( self._decoder, frame_indices=indices ) @@ -301,6 +307,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch: Returns: FrameBatch: The frames that are played at ``seconds``. """ + if isinstance(seconds, torch.Tensor): + # TODO we should avoid converting tensors to lists and just let the + # core ops and C++ code natively accept tensors. See + # https://github.com/pytorch/torchcodec/issues/879 + seconds = seconds.to(torch.float).tolist() + data, pts_seconds, duration_seconds = core.get_frames_by_pts( self._decoder, timestamps=seconds ) diff --git a/test/test_decoders.py b/test/test_decoders.py index ffa18d3a0..50b731506 100644 --- a/test/test_decoders.py +++ b/test/test_decoders.py @@ -1389,6 +1389,17 @@ def test_10bit_videos_cpu(self, asset): # custom_frame_mappings=custom_frame_mappings, # ) + def test_get_frames_at_tensor_indices(self): + # Non-regression test for tensor support in get_frames_at() and + # get_frames_played_at() + decoder = VideoDecoder(NASA_VIDEO.path) + + decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.int)) + decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.float)) + + decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.int)) + decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.float)) + class TestAudioDecoder: @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))