From d068737e4f088ea352b7295b31cfd3741a3bdd9c Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Mon, 8 Sep 2025 10:31:44 +0100
Subject: [PATCH] Let get_frames_at and get_frames_played_at accept tensor indices

---
Review note: tensor inputs are widened to torch.int64 / torch.float64 before
.tolist() so that 64-bit indices and double-precision timestamps are not
narrowed to 32 bits (torch.int / torch.float) before reaching the core ops.

 src/torchcodec/decoders/_video_decoder.py | 12 ++++++++++++
 test/test_decoders.py                     | 11 +++++++++++
 2 files changed, 23 insertions(+)

diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 21de5496e..3bf7a6ac2 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -247,6 +247,12 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
         Returns:
             FrameBatch: The frames at the given indices.
         """
+        if isinstance(indices, torch.Tensor):
+            # TODO we should avoid converting tensors to lists and just let the
+            # core ops and C++ code natively accept tensors. See
+            # https://github.com/pytorch/torchcodec/issues/879
+            indices = indices.to(torch.int64).tolist()
+
         data, pts_seconds, duration_seconds = core.get_frames_at_indices(
             self._decoder, frame_indices=indices
         )
@@ -322,6 +328,12 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
         Returns:
             FrameBatch: The frames that are played at ``seconds``.
         """
+        if isinstance(seconds, torch.Tensor):
+            # TODO we should avoid converting tensors to lists and just let the
+            # core ops and C++ code natively accept tensors. See
+            # https://github.com/pytorch/torchcodec/issues/879
+            seconds = seconds.to(torch.float64).tolist()
+
         data, pts_seconds, duration_seconds = core.get_frames_by_pts(
             self._decoder, timestamps=seconds
         )
diff --git a/test/test_decoders.py b/test/test_decoders.py
index 9e28d2f72..c8d12f6c7 100644
--- a/test/test_decoders.py
+++ b/test/test_decoders.py
@@ -1390,6 +1390,17 @@ def test_custom_frame_mappings_init_fails_invalid_json(self, tmp_path, device):
                 custom_frame_mappings=custom_frame_mappings,
             )
 
+    def test_get_frames_at_tensor_indices(self):
+        # Non-regression test for tensor support in get_frames_at() and
+        # get_frames_played_at()
+        decoder = VideoDecoder(NASA_VIDEO.path)
+
+        decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.int))
+        decoder.get_frames_at(torch.tensor([0, 10], dtype=torch.float))
+
+        decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.int))
+        decoder.get_frames_played_at(torch.tensor([0, 1], dtype=torch.float))
+
 
 class TestAudioDecoder:
     @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3, SINE_MONO_S32))