diff --git a/src/torchcodec/samplers/_common.py b/src/torchcodec/samplers/_common.py
index 46bf3b180..93dceb750 100644
--- a/src/torchcodec/samplers/_common.py
+++ b/src/torchcodec/samplers/_common.py
@@ -1,7 +1,7 @@
 from typing import Callable, Union
 
-import torch
-from torchcodec import Frame, FrameBatch
+from torch import Tensor
+from torchcodec import FrameBatch
 
 _LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]]
 
@@ -42,22 +42,6 @@ def _error_policy(
 }
 
 
-def _chunk_list(lst, chunk_size):
-    # return list of sublists of length chunk_size
-    return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]
-
-
-def _to_framebatch(frames: list[Frame]) -> FrameBatch:
-    # IMPORTANT: see other IMPORTANT note in _decode_all_clips_indices and
-    # _decode_all_clips_timestamps
-    data = torch.stack([frame.data for frame in frames])
-    pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
-    duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
-    return FrameBatch(
-        data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
-    )
-
-
 def _validate_common_params(*, decoder, num_frames_per_clip, policy):
     if len(decoder) < 1:
         raise ValueError(
@@ -72,3 +56,19 @@ def _validate_common_params(*, decoder, num_frames_per_clip, policy):
         raise ValueError(
             f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}."
         )
+
+
+def _make_5d_framebatch(
+    *,
+    data: Tensor,
+    pts_seconds: Tensor,
+    duration_seconds: Tensor,
+    num_clips: int,
+    num_frames_per_clip: int,
+) -> FrameBatch:
+    last_3_dims = data.shape[-3:]
+    return FrameBatch(
+        data=data.view(num_clips, num_frames_per_clip, *last_3_dims),
+        pts_seconds=pts_seconds.view(num_clips, num_frames_per_clip),
+        duration_seconds=duration_seconds.view(num_clips, num_frames_per_clip),
+    )
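Note: `_make_5d_framebatch` relies on the core ops returning all decoded frames stacked along a single leading dimension of length `num_clips * num_frames_per_clip`, in clip order, so a pair of `view` calls is enough to batch them. A minimal sketch of that reshape, with made-up shapes (2 clips of 3 frames from a 270x480 RGB video):

    import torch

    num_clips, num_frames_per_clip = 2, 3
    # Stand-ins for the tensors returned by the core decoding ops:
    data = torch.rand(num_clips * num_frames_per_clip, 3, 270, 480)
    pts_seconds = torch.rand(num_clips * num_frames_per_clip, dtype=torch.float64)

    # Same reshape as _make_5d_framebatch: (N*K, C, H, W) -> (N, K, C, H, W).
    clips_data = data.view(num_clips, num_frames_per_clip, *data.shape[-3:])
    clips_pts = pts_seconds.view(num_clips, num_frames_per_clip)
    assert clips_data.shape == (2, 3, 3, 270, 480)
    assert clips_pts.shape == (2, 3)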
diff --git a/src/torchcodec/samplers/_index_based.py b/src/torchcodec/samplers/_index_based.py
index 25e4bd32c..a16a1292f 100644
--- a/src/torchcodec/samplers/_index_based.py
+++ b/src/torchcodec/samplers/_index_based.py
@@ -1,14 +1,14 @@
-from typing import List, Literal, Optional
+from typing import Literal, Optional
 
 import torch
-from torchcodec import Frame, FrameBatch
+from torchcodec import FrameBatch
 from torchcodec.decoders import VideoDecoder
+from torchcodec.decoders._core import get_frames_at_indices
 from torchcodec.samplers._common import (
-    _chunk_list,
+    _make_5d_framebatch,
     _POLICY_FUNCTION_TYPE,
     _POLICY_FUNCTIONS,
-    _to_framebatch,
     _validate_common_params,
 )
 
@@ -117,51 +117,6 @@ def _build_all_clips_indices(
     return all_clips_indices
 
 
-def _decode_all_clips_indices(
-    decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
-) -> list[FrameBatch]:
-    # This takes the list of all the frames to decode (in arbitrary order),
-    # decode all the frames, and then packs them into clips of length
-    # num_frames_per_clip.
-    #
-    # To avoid backwards seeks (which are slow), we:
-    # - sort all the frame indices to be decoded
-    # - dedup them
-    # - decode all unique frames in sorted order
-    # - re-assemble the decoded frames back to their original order
-    #
-    # TODO: Write this in C++ so we can avoid the copies that happen in `_to_framebatch`
-
-    all_clips_indices_sorted, argsort = zip(
-        *sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices))
-    )
-    previous_decoded_frame = None
-    all_decoded_frames = [None] * len(all_clips_indices)
-    for i, j in enumerate(argsort):
-        frame_index = all_clips_indices_sorted[i]
-        if (
-            previous_decoded_frame is not None  # then we know i > 0
-            and frame_index == all_clips_indices_sorted[i - 1]
-        ):
-            # Avoid decoding the same frame twice.
-            # IMPORTANT: this is only correct because a copy of the frame will
-            # happen within `_to_framebatch` when we call torch.stack.
-            # If a copy isn't made, the same underlying memory will be used for
-            # the 2 consecutive frames. When we re-write this, we should make
-            # sure to explicitly copy the data.
-            decoded_frame = previous_decoded_frame
-        else:
-            decoded_frame = decoder.get_frame_at(index=frame_index)
-        previous_decoded_frame = decoded_frame
-        all_decoded_frames[j] = decoded_frame
-
-    all_clips: list[list[Frame]] = _chunk_list(
-        all_decoded_frames, chunk_size=num_frames_per_clip
-    )
-
-    return [_to_framebatch(clip) for clip in all_clips]
-
-
 def _generic_index_based_sampler(
     kind: Literal["random", "regular"],
     decoder: VideoDecoder,
     num_clips: int,
     num_frames_per_clip: int,
     num_indices_between_frames: int,
     sampling_range_start: int,
     sampling_range_end: Optional[int],
     # Important note: sampling_range_end defines the upper bound of where a clip
     # can *start*, not where a clip can end.
     policy: Literal["repeat_last", "wrap", "error"],
-) -> List[FrameBatch]:
+) -> FrameBatch:
 
     _validate_common_params(
         decoder=decoder,
@@ -221,9 +176,18 @@ def _generic_index_based_sampler(
         num_frames_in_video=len(decoder),
         policy_fun=_POLICY_FUNCTIONS[policy],
     )
-    return _decode_all_clips_indices(
-        decoder,
-        all_clips_indices=all_clips_indices,
+
+    # TODO: Use public method of decoder, when it exists
+    frames, pts_seconds, duration_seconds = get_frames_at_indices(
+        decoder._decoder,
+        stream_index=decoder.stream_index,
+        frame_indices=all_clips_indices,
+    )
+    return _make_5d_framebatch(
+        data=frames,
+        pts_seconds=pts_seconds,
+        duration_seconds=duration_seconds,
+        num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
     )
 
@@ -237,7 +201,7 @@ def clips_at_random_indices(
     decoder: VideoDecoder,
     *,
     num_clips: int = 1,
     num_frames_per_clip: int = 1,
     num_indices_between_frames: int = 1,
     sampling_range_start: int = 0,
     sampling_range_end: Optional[int] = None,  # interval is [start, end).
     policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
-) -> List[FrameBatch]:
+) -> FrameBatch:
     return _generic_index_based_sampler(
         kind="random",
         decoder=decoder,
@@ -259,7 +223,7 @@ def clips_at_regular_indices(
     decoder: VideoDecoder,
     *,
     num_clips: int = 1,
     num_frames_per_clip: int = 1,
     num_indices_between_frames: int = 1,
     sampling_range_start: int = 0,
     sampling_range_end: Optional[int] = None,  # interval is [start, end).
     policy: Literal["repeat_last", "wrap", "error"] = "repeat_last",
-) -> List[FrameBatch]:
+) -> FrameBatch:
 
     return _generic_index_based_sampler(
         kind="regular",
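For context on what `get_frames_at_indices` presumably replicates in C++ (per the removed TODO): the deleted Python loop avoided backwards seeks by decoding each unique frame index once, in increasing order, then scattering the decoded frames back to their original positions. A toy sketch of that sort/dedup/scatter pattern, with hypothetical indices and a string standing in for a decoded frame:

    # Frame indices for 2 clips of 2 frames, containing a duplicate (5).
    all_clips_indices = [5, 6, 2, 5]

    sorted_indices, argsort = zip(
        *sorted((idx, pos) for pos, idx in enumerate(all_clips_indices))
    )
    decoded = [None] * len(all_clips_indices)
    previous_index, previous_frame = None, None
    for idx, pos in zip(sorted_indices, argsort):
        if idx != previous_index:  # decode each unique index only once
            previous_frame = f"frame#{idx}"  # stand-in for a real decode call
            previous_index = idx
        decoded[pos] = previous_frame  # scatter back to the original order

    assert decoded == ["frame#5", "frame#6", "frame#2", "frame#5"]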
diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py
index f890d2165..e9d485aad 100644
--- a/src/torchcodec/samplers/_time_based.py
+++ b/src/torchcodec/samplers/_time_based.py
@@ -1,14 +1,13 @@
-from typing import List, Literal, Optional
+from typing import Literal, Optional
 
 import torch
-from torchcodec import Frame, FrameBatch
-from torchcodec.decoders import VideoDecoder
+from torchcodec import FrameBatch
+from torchcodec.decoders._core import get_frames_by_pts
 from torchcodec.samplers._common import (
-    _chunk_list,
+    _make_5d_framebatch,
     _POLICY_FUNCTION_TYPE,
     _POLICY_FUNCTIONS,
-    _to_framebatch,
     _validate_common_params,
 )
 
@@ -147,51 +146,6 @@ def _build_all_clips_timestamps(
     return all_clips_timestamps
 
 
-def _decode_all_clips_timestamps(
-    decoder: VideoDecoder, all_clips_timestamps: list[float], num_frames_per_clip: int
-) -> list[FrameBatch]:
-    # This is 99% the same as _decode_all_clips_indices. The only change is the
-    # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx)
-
-    all_clips_timestamps_sorted, argsort = zip(
-        *sorted(
-            (frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps)
-        )
-    )
-    previous_decoded_frame = None
-    all_decoded_frames = [None] * len(all_clips_timestamps)
-    for i, j in enumerate(argsort):
-        frame_pts_seconds = all_clips_timestamps_sorted[i]
-        if (
-            previous_decoded_frame is not None  # then we know i > 0
-            and frame_pts_seconds == all_clips_timestamps_sorted[i - 1]
-        ):
-            # Avoid decoding the same frame twice.
-            # Unfortunatly this is unlikely to lead to speed-up as-is: it's
-            # pretty unlikely that 2 pts will be the same since pts are float
-            # contiguous values. Theoretically the dedup can still happen, but
-            # it would be much more efficient to implement it at the frame index
-            # level. We should do that once we implement that in C++.
-            # See also https://github.com/pytorch/torchcodec/issues/256.
-            #
-            # IMPORTANT: this is only correct because a copy of the frame will
-            # happen within `_to_framebatch` when we call torch.stack.
-            # If a copy isn't made, the same underlying memory will be used for
-            # the 2 consecutive frames. When we re-write this, we should make
-            # sure to explicitly copy the data.
-            decoded_frame = previous_decoded_frame
-        else:
-            decoded_frame = decoder.get_frame_displayed_at(seconds=frame_pts_seconds)
-        previous_decoded_frame = decoded_frame
-        all_decoded_frames[j] = decoded_frame
-
-    all_clips: list[list[Frame]] = _chunk_list(
-        all_decoded_frames, chunk_size=num_frames_per_clip
-    )
-
-    return [_to_framebatch(clip) for clip in all_clips]
-
-
 def _generic_time_based_sampler(
     kind: Literal["random", "regular"],
     decoder,
     num_clips: Optional[int],
     seconds_between_clip_starts: Optional[float],
     num_frames_per_clip: int,
     seconds_between_frames: Optional[float],
     sampling_range_start: Optional[float],
     sampling_range_end: Optional[float],  # interval is [start, end).
     policy: str = "repeat_last",
-) -> List[FrameBatch]:
+) -> FrameBatch:
     # Note: *everywhere*, sampling_range_end denotes the upper bound of where a
     # clip can start. This is an *open* upper bound, i.e. we will make sure no
     # clip starts exactly at (or above) sampling_range_end.
@@ -246,6 +200,7 @@ def _generic_time_based_sampler(
         sampling_range_end,  # excluded
         seconds_between_clip_starts,
     )
+    num_clips = len(clip_start_seconds)
 
     all_clips_timestamps = _build_all_clips_timestamps(
         clip_start_seconds=clip_start_seconds,
@@ -255,9 +210,17 @@ def _generic_time_based_sampler(
         policy_fun=_POLICY_FUNCTIONS[policy],
     )
 
-    return _decode_all_clips_timestamps(
-        decoder,
-        all_clips_timestamps=all_clips_timestamps,
+    # TODO: Use public method of decoder, when it exists
+    frames, pts_seconds, duration_seconds = get_frames_by_pts(
+        decoder._decoder,
+        stream_index=decoder.stream_index,
+        timestamps=all_clips_timestamps,
+    )
+    return _make_5d_framebatch(
+        data=frames,
+        pts_seconds=pts_seconds,
+        duration_seconds=duration_seconds,
+        num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
     )
 
@@ -272,7 +235,7 @@ def clips_at_random_timestamps(
     sampling_range_start: Optional[float] = None,
     sampling_range_end: Optional[float] = None,  # interval is [start, end).
     policy: str = "repeat_last",
-) -> List[FrameBatch]:
+) -> FrameBatch:
     return _generic_time_based_sampler(
         kind="random",
         decoder=decoder,
@@ -296,7 +259,7 @@ def clips_at_regular_timestamps(
     sampling_range_start: Optional[float] = None,
     sampling_range_end: Optional[float] = None,  # interval is [start, end).
     policy: str = "repeat_last",
-) -> List[FrameBatch]:
+) -> FrameBatch:
     return _generic_time_based_sampler(
         kind="regular",
         decoder=decoder,
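The caller-facing effect for both sampler families, sketched as a hypothetical usage (assumes a local video.mp4; function and argument names taken from this diff, not an official example):

    from torchcodec.decoders import VideoDecoder
    from torchcodec.samplers import clips_at_regular_timestamps

    decoder = VideoDecoder("video.mp4")  # assumed input file
    clips = clips_at_regular_timestamps(
        decoder,
        seconds_between_clip_starts=1.0,
        num_frames_per_clip=4,
    )
    # clips is now a single 5D FrameBatch rather than a list of 4D FrameBatches:
    # clips.data.shape         -> (num_clips, 4, C, H, W)
    # clips.pts_seconds.shape  -> (num_clips, 4)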
diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py
index 3149a5410..4a12d93c4 100644
--- a/test/samplers/test_samplers.py
+++ b/test/samplers/test_samplers.py
@@ -25,23 +25,26 @@
 def _assert_output_type_and_shapes(
     video, clips, expected_num_clips, num_frames_per_clip
 ):
-    assert isinstance(clips, list)
-    assert len(clips) == expected_num_clips
-    assert all(isinstance(clip, FrameBatch) for clip in clips)
-    expected_clip_data_shape = (
+    assert isinstance(clips, FrameBatch)
+    assert clips.data.shape == (
+        expected_num_clips,
         num_frames_per_clip,
         3,
         video.height,
         video.width,
     )
-    assert all(clip.data.shape == expected_clip_data_shape for clip in clips)
+    assert clips.pts_seconds.shape == (
+        expected_num_clips,
+        num_frames_per_clip,
+    )
+    assert clips.duration_seconds.shape == (
+        expected_num_clips,
+        num_frames_per_clip,
+    )
 
 
 def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None):
-    # assert regular spacing between sampled clips
-    seconds_between_clip_starts = torch.tensor(
-        [clip.pts_seconds[0] for clip in clips]
-    ).diff()
+    seconds_between_clip_starts = clips.pts_seconds[:, 0].diff()
 
     if expected_seconds_between_clip_starts is not None:
         # This can only be asserted with the time-based sampler, where
@@ -88,13 +91,11 @@ def test_index_based_sampler(sampler, num_indices_between_frames):
     # Check the num_indices_between_frames parameter by asserting that the
     # "time" difference between frames in a clip is the same as the "index"
     # distance.
-
-    avg_distance_between_frames_seconds = torch.concat(
-        [clip.pts_seconds.diff() for clip in clips]
-    ).mean()
-    assert avg_distance_between_frames_seconds == pytest.approx(
-        num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5
-    )
+    for clip in clips:
+        avg_distance_between_frames_seconds = clip.pts_seconds.diff().mean()
+        assert avg_distance_between_frames_seconds == pytest.approx(
+            num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5
+        )
 
 
 @pytest.mark.parametrize(
@@ -130,7 +131,7 @@ def test_time_based_sampler(sampler, seconds_between_frames):
     if sampler.func is clips_at_regular_timestamps:
         seconds_between_clip_starts = sampler.keywords["seconds_between_clip_starts"]
         expected_seconds_between_clip_starts = torch.tensor(
-            [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float
+            [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float64
         )
         _assert_regular_sampler(
             clips=clips,
@@ -140,12 +141,11 @@ def test_time_based_sampler(sampler, seconds_between_frames):
     expected_seconds_between_frames = (
         seconds_between_frames or 1 / decoder.metadata.average_fps
     )
-    avg_seconds_between_frames_seconds = torch.concat(
-        [clip.pts_seconds.diff() for clip in clips]
-    ).mean()
-    assert avg_seconds_between_frames_seconds == pytest.approx(
-        expected_seconds_between_frames, abs=0.05
-    )
+    for clip in clips:
+        avg_seconds_between_frames = clip.pts_seconds.diff().mean()
+        assert avg_seconds_between_frames == pytest.approx(
+            expected_seconds_between_frames, abs=0.05
+        )
 
 
 @pytest.mark.parametrize(
@@ -284,7 +284,7 @@ def test_sampling_range_default_behavior_random_sampler(sampler):
         policy="error",
     )
 
-    last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default])
+    last_clip_start_default = clips_default.pts_seconds[:, 0].max()
 
     # with manual sampling_range_end value set to last frame / end of video
     clips_manual = sampler(
@@ -294,7 +294,7 @@ def test_sampling_range_default_behavior_random_sampler(sampler):
         sampling_range_start=sampling_range_start,
         sampling_range_end=1000,
    )
-    last_clip_start_manual = max([clip.pts_seconds[0] for clip in clips_manual])
+    last_clip_start_manual = clips_manual.pts_seconds[:, 0].max()
 
     assert last_clip_start_manual - last_clip_start_default > 0.3
 
@@ -382,12 +382,13 @@ def test_random_sampler_randomness(sampler):
     # Assert the clip starts aren't sorted, to make sure we haven't messed up
     # the implementation. (This may fail if we're unlucky, but we hard-coded a
     # seed, so it will always pass.)
-    clip_starts = [clip.pts_seconds.item() for clip in clips_1]
+    clip_starts = clips_1.pts_seconds[:, 0].tolist()
     assert sorted(clip_starts) != clip_starts
 
     # Call the same sampler again with the same seed, expect same results
     torch.manual_seed(0)
     clips_2 = sampler(decoder, num_clips=num_clips)
+
     for clip_1, clip_2 in zip(clips_1, clips_2):
         assert_tensor_equal(clip_1.data, clip_2.data)
         assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds)
@@ -427,7 +428,7 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size):
 
     assert len(clips) == num_clips
 
-    clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips])
+    clip_starts_seconds = clips.pts_seconds[:, 0]
     assert len(torch.unique(clip_starts_seconds)) == sampling_range_size
 
     # Assert clips starts are ordered, i.e. the start indices don't just "wrap
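The test updates above all follow the same translation: per-clip list comprehensions over a list[FrameBatch] become tensor slices on the two leading batch dimensions. A toy illustration with a hand-built FrameBatch (all shapes assumed):

    import torch
    from torchcodec import FrameBatch

    clips = FrameBatch(
        data=torch.rand(5, 2, 3, 64, 64),  # 5 clips of 2 frames each
        pts_seconds=torch.rand(5, 2, dtype=torch.float64),
        duration_seconds=torch.full((5, 2), 0.04, dtype=torch.float64),
    )

    # Old: torch.tensor([clip.pts_seconds[0] for clip in clips]).
    # New: one slice over the clip dimension.
    clip_starts_seconds = clips.pts_seconds[:, 0]
    assert clip_starts_seconds.shape == (5,)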