diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py
index 58a8a9916..e13c9f08b 100644
--- a/src/torchcodec/decoders/_video_decoder.py
+++ b/src/torchcodec/decoders/_video_decoder.py
@@ -181,6 +181,32 @@ def get_frame_at(self, index: int) -> Frame:
             duration_seconds=duration_seconds.item(),
         )
 
+    def get_frames_at(self, indices: list[int]) -> FrameBatch:
+        """Return frames at the given indices.
+
+        .. note::
+
+            Calling this method is more efficient than repeated individual calls
+            to :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at`. This
+            method makes sure not to decode the same frame twice, and also
+            avoids "backwards seek" operations, which are slow.
+
+        Args:
+            indices (list of int): The indices of the frames to retrieve.
+
+        Returns:
+            FrameBatch: The frames at the given indices.
+        """
+
+        data, pts_seconds, duration_seconds = core.get_frames_at_indices(
+            self._decoder, stream_index=self.stream_index, frame_indices=indices
+        )
+        return FrameBatch(
+            data=data,
+            pts_seconds=pts_seconds,
+            duration_seconds=duration_seconds,
+        )
+
     def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch:
         """Return multiple frames at the given index range.
 
@@ -238,6 +264,31 @@ def get_frame_displayed_at(self, seconds: float) -> Frame:
             duration_seconds=duration_seconds.item(),
         )
 
+    def get_frames_displayed_at(self, seconds: list[float]) -> FrameBatch:
+        """Return frames displayed at the given timestamps in seconds.
+
+        .. note::
+
+            Calling this method is more efficient than repeated individual calls
+            to :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at`.
+            This method makes sure not to decode the same frame twice, and also
+            avoids "backwards seek" operations, which are slow.
+
+        Args:
+            seconds (list of float): The timestamps in seconds when the frames are displayed.
+
+        Returns:
+            FrameBatch: The frames that are displayed at ``seconds``.
+        """
+        data, pts_seconds, duration_seconds = core.get_frames_by_pts(
+            self._decoder, timestamps=seconds, stream_index=self.stream_index
+        )
+        return FrameBatch(
+            data=data,
+            pts_seconds=pts_seconds,
+            duration_seconds=duration_seconds,
+        )
+
     def get_frames_displayed_in_range(
         self, start_seconds: float, stop_seconds: float
     ) -> FrameBatch:
diff --git a/src/torchcodec/samplers/_common.py b/src/torchcodec/samplers/_common.py
index 93dceb750..abf42ffff 100644
--- a/src/torchcodec/samplers/_common.py
+++ b/src/torchcodec/samplers/_common.py
@@ -1,6 +1,5 @@
 from typing import Callable, Union
 
-from torch import Tensor
 from torchcodec import FrameBatch
 
 _LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
@@ -58,17 +57,15 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
         )
 
 
-def _make_5d_framebatch(
+def _reshape_4d_framebatch_into_5d(
     *,
-    data: Tensor,
-    pts_seconds: Tensor,
-    duration_seconds: Tensor,
+    frames: FrameBatch,
     num_clips: int,
     num_frames_per_clip: int,
 ) -> FrameBatch:
-    last_3_dims = data.shape[-3:]
+    last_3_dims = frames.data.shape[-3:]
     return FrameBatch(
-        data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
-        pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
-        duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
+        data=frames.data.view(num_clips, num_frames_per_clip, *last_3_dims),
+        pts_seconds=frames.pts_seconds.view(num_clips, num_frames_per_clip),
+        duration_seconds=frames.duration_seconds.view(num_clips, num_frames_per_clip),
     )
diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py
index a16a1292f..d528f8019 100644
--- a/src/torchcodec/samplers/_index_based.py
+++ b/src/torchcodec/samplers/_index_based.py
@@ -4,11 +4,10 @@
 
 from torchcodec import FrameBatch
 from torchcodec.decoders import VideoDecoder
-from torchcodec.decoders._core import get_frames_at_indices
 from torchcodec.samplers._common import (
-    _make_5d_framebatch,
     _POLICY_FUNCTION_TYPE,
     _POLICY_FUNCTIONS,
+    _reshape_4d_framebatch_into_5d,
     _validate_common_params,
 )
 
@@ -177,16 +176,9 @@ def _generic_index_based_sampler(
         policy_fun=_POLICY_FUNCTIONS[policy],
    )
 
-    # TODO: Use public method of decoder, when it exists
-    frames, pts_seconds, duration_seconds = get_frames_at_indices(
-        decoder._decoder,
-        stream_index=decoder.stream_index,
-        frame_indices=all_clips_indices,
-    )
-    return _make_5d_framebatch(
-        data=frames,
-        pts_seconds=pts_seconds,
-        duration_seconds=duration_seconds,
+    frames = decoder.get_frames_at(indices=all_clips_indices)
+    return _reshape_4d_framebatch_into_5d(
+        frames=frames,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
     )
diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py
index 8fca584c0..6a9b0dd89 100644
--- a/src/torchcodec/samplers/_time_based.py
+++ b/src/torchcodec/samplers/_time_based.py
@@ -3,11 +3,10 @@
 import torch
 
 from torchcodec import FrameBatch
-from torchcodec.decoders._core import get_frames_by_pts
 from torchcodec.samplers._common import (
-    _make_5d_framebatch,
     _POLICY_FUNCTION_TYPE,
     _POLICY_FUNCTIONS,
+    _reshape_4d_framebatch_into_5d,
     _validate_common_params,
 )
 
@@ -210,16 +209,9 @@ def _generic_time_based_sampler(
         policy_fun=_POLICY_FUNCTIONS[policy],
     )
 
-    # TODO: Use public method of decoder, when it exists
-    frames, pts_seconds, duration_seconds = get_frames_by_pts(
-        decoder._decoder,
-        stream_index=decoder.stream_index,
-        timestamps=all_clips_timestamps,
-    )
-    return _make_5d_framebatch(
-        data=frames,
-        pts_seconds=pts_seconds,
-        duration_seconds=duration_seconds,
+    frames = decoder.get_frames_displayed_at(seconds=all_clips_timestamps)
+    return _reshape_4d_framebatch_into_5d(
+        frames=frames,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
     )
diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py
index f57314657..46c6f9fa3 100644
--- a/test/decoders/test_video_decoder.py
+++ b/test/decoders/test_video_decoder.py
@@ -7,6 +7,7 @@
 import numpy
 import pytest
 import torch
+from torchcodec import FrameBatch
 
 from torchcodec.decoders import _core, VideoDecoder
 
@@ -301,9 +302,12 @@ def test_get_frame_at(self):
 
         assert_tensor_equal(ref_frame9, frame9.data)
         assert isinstance(frame9.pts_seconds, float)
-        assert frame9.pts_seconds == pytest.approx(0.3003)
+        expected_frame_info = NASA_VIDEO.get_frame_info(9)
+        assert frame9.pts_seconds == pytest.approx(expected_frame_info.pts_seconds)
         assert isinstance(frame9.duration_seconds, float)
-        assert frame9.duration_seconds == pytest.approx(0.03337, rel=1e-3)
+        assert frame9.duration_seconds == pytest.approx(
+            expected_frame_info.duration_seconds, rel=1e-3
+        )
 
         # test numpy.int64
         frame9 = decoder.get_frame_at(numpy.int64(9))
@@ -340,6 +344,50 @@ def test_get_frame_at_fails(self):
         with pytest.raises(IndexError, match="out of bounds"):
             frame = decoder.get_frame_at(10000)  # noqa
 
+    def test_get_frames_at(self):
+        decoder = VideoDecoder(NASA_VIDEO.path)
+
+        frames = decoder.get_frames_at([35, 25])
+
+        assert isinstance(frames, FrameBatch)
+
+        assert_tensor_equal(frames[0].data, NASA_VIDEO.get_frame_data_by_index(35))
+        assert_tensor_equal(frames[1].data, NASA_VIDEO.get_frame_data_by_index(25))
+
+        expected_pts_seconds = torch.tensor(
+            [
+                NASA_VIDEO.get_frame_info(35).pts_seconds,
+                NASA_VIDEO.get_frame_info(25).pts_seconds,
+            ],
+            dtype=torch.float64,
+        )
+        torch.testing.assert_close(
+            frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
+        )
+
+        expected_duration_seconds = torch.tensor(
+            [
+                NASA_VIDEO.get_frame_info(35).duration_seconds,
+                NASA_VIDEO.get_frame_info(25).duration_seconds,
+            ],
+            dtype=torch.float64,
+        )
+        torch.testing.assert_close(
+            frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
+        )
+
+    def test_get_frames_at_fails(self):
+        decoder = VideoDecoder(NASA_VIDEO.path)
+
+        with pytest.raises(RuntimeError, match="Invalid frame index=-1"):
+            decoder.get_frames_at([-1])
+
+        with pytest.raises(RuntimeError, match="Invalid frame index=390"):
+            decoder.get_frames_at([390])
+
+        with pytest.raises(RuntimeError, match="Expected a value of type"):
+            decoder.get_frames_at([0.3])
+
     def test_get_frame_displayed_at(self):
         decoder = VideoDecoder(NASA_VIDEO.path)
 
@@ -365,6 +413,51 @@ def test_get_frame_displayed_at_fails(self):
         with pytest.raises(IndexError, match="Invalid pts in seconds"):
             frame = decoder.get_frame_displayed_at(100.0)  # noqa
 
+    def test_get_frames_displayed_at(self):
+
+        decoder = VideoDecoder(NASA_VIDEO.path)
+
+        # Note: We know the frame at ~0.84s has index 25, the one at ~1.17s has
+        # index 35. We use those indices as reference to test against.
+        seconds = [0.84, 1.17, 0.85]
+        reference_indices = [25, 35, 25]
+        frames = decoder.get_frames_displayed_at(seconds)
+
+        assert isinstance(frames, FrameBatch)
+
+        for i in range(len(reference_indices)):
+            assert_tensor_equal(
+                frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i])
+            )
+
+        expected_pts_seconds = torch.tensor(
+            [NASA_VIDEO.get_frame_info(i).pts_seconds for i in reference_indices],
+            dtype=torch.float64,
+        )
+        torch.testing.assert_close(
+            frames.pts_seconds, expected_pts_seconds, atol=1e-4, rtol=0
+        )
+
+        expected_duration_seconds = torch.tensor(
+            [NASA_VIDEO.get_frame_info(i).duration_seconds for i in reference_indices],
+            dtype=torch.float64,
+        )
+        torch.testing.assert_close(
+            frames.duration_seconds, expected_duration_seconds, atol=1e-4, rtol=0
+        )
+
+    def test_get_frames_displayed_at_fails(self):
+        decoder = VideoDecoder(NASA_VIDEO.path)
+
+        with pytest.raises(RuntimeError, match="must be in range"):
+            decoder.get_frames_displayed_at([-1])
+
+        with pytest.raises(RuntimeError, match="must be in range"):
+            decoder.get_frames_displayed_at([14])
+
+        with pytest.raises(RuntimeError, match="Expected a value of type"):
+            decoder.get_frames_displayed_at(["bad"])
+
     @pytest.mark.parametrize("stream_index", [0, 3, None])
     def test_get_frames_in_range(self, stream_index):
         decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index)
@@ -456,10 +549,11 @@ def test_get_frames_in_range(self, stream_index):
         (
             lambda decoder: decoder[0],
             lambda decoder: decoder.get_frame_at(0).data,
+            lambda decoder: decoder.get_frames_at([0, 1]).data,
             lambda decoder: decoder.get_frames_in_range(0, 4).data,
             lambda decoder: decoder.get_frame_displayed_at(0).data,
-            # TODO: uncomment once D60001893 lands
-            # lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
+            lambda decoder: decoder.get_frames_displayed_at([0, 1]).data,
+            lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data,
         ),
     )
     def test_dimension_order(self, dimension_order, frame_getter):
diff --git a/test/utils.py b/test/utils.py
index a32ede2b9..2cd10bf17 100644
--- a/test/utils.py
+++ b/test/utils.py
@@ -265,6 +265,8 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
             8: TestFrameInfo(pts_seconds=0.266933, duration_seconds=0.033367),
             9: TestFrameInfo(pts_seconds=0.300300, duration_seconds=0.033367),
             10: TestFrameInfo(pts_seconds=0.333667, duration_seconds=0.033367),
+            25: TestFrameInfo(pts_seconds=0.8342, duration_seconds=0.033367),
+            35: TestFrameInfo(pts_seconds=1.1678, duration_seconds=0.033367),
         },
     },
 )
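Usage sketch (reviewer note, not part of the patch): the snippet below shows how the two new public batch methods are expected to be called, based on the docstrings and tests in this diff. The video path is a placeholder, and the exact data shape depends on the decoder's dimension-order setting.

```python
from torchcodec.decoders import VideoDecoder

# Placeholder path; any local video file will do.
decoder = VideoDecoder("video.mp4")

# Batched index-based access: frames are decoded in one pass, without
# decoding the same frame twice or seeking backwards. Returns a FrameBatch.
frames = decoder.get_frames_at([35, 25])
print(frames.data.shape)        # batched frame data, e.g. (2, C, H, W)
print(frames.pts_seconds)       # 1D float64 tensor of presentation timestamps
print(frames.duration_seconds)  # 1D float64 tensor of frame durations

# Batched time-based access mirrors the index-based API.
frames = decoder.get_frames_displayed_at([0.84, 1.17])
print(frames.pts_seconds)
```

This mirrors the sampler changes above: `_generic_index_based_sampler` and `_generic_time_based_sampler` now go through these public methods instead of the private `_core` functions, and only reshape the resulting 4D `FrameBatch` into 5D clips.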