From 6471c2add671e6e67293e1a035f078b3563d38d0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Oct 2024 06:10:36 -0700 Subject: [PATCH 01/15] Add time-based regular sampler --- src/torchcodec/samplers/__init__.py | 6 +- src/torchcodec/samplers/_implem.py | 228 +++++++++++++++++++++++++++- 2 files changed, 230 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py index 5a173c218..66cd9c91d 100644 --- a/src/torchcodec/samplers/__init__.py +++ b/src/torchcodec/samplers/__init__.py @@ -1 +1,5 @@ -from ._implem import clips_at_random_indices, clips_at_regular_indices +from ._implem import ( + clips_at_random_indices, + clips_at_regular_indices, + clips_at_regular_timestamps, +) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 884f91f4b..cfb88914e 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -6,6 +6,9 @@ from torchcodec.decoders import VideoDecoder +_EPS = 1e-4 + + def _validate_params( *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames, policy ): @@ -216,7 +219,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: return [to_framebatch(clip) for clip in all_clips] -def _generic_sampler( +def _generic_index_based_sampler( kind: Literal["random", "regular"], decoder: VideoDecoder, *, @@ -290,7 +293,7 @@ def clips_at_random_indices( sampling_range_end: Optional[int] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: - return _generic_sampler( + return _generic_index_based_sampler( kind="random", decoder=decoder, num_clips=num_clips, @@ -313,7 +316,7 @@ def clips_at_regular_indices( policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: - return _generic_sampler( + return _generic_index_based_sampler( kind="regular", decoder=decoder, num_clips=num_clips, @@ -323,3 +326,222 @@ def clips_at_regular_indices( sampling_range_end=sampling_range_end, policy=policy, ) + + +def _get_approximate_clip_span_seconds( + *, + decoder, + num_frames_per_clip, + seconds_between_frames, +): + + # Compute clip span, in seconds. We can only compute an approximate value: + # we assume the fps are constant. Computing the real value requires + # accounting for variable fps. + + assert decoder.metadata.average_fps is not None + average_frame_duration_seconds = 1 / decoder.metadata.average_fps + if seconds_between_frames is None: + approximate_clip_span_seconds = ( + num_frames_per_clip * average_frame_duration_seconds + ) + else: + # aaa, bbb, ccc, ddd are 4 frames within a clip. + # + # seconds_between_frames + # | + # v + # < ---- > + # clip = [aaa....bbb....ccc....ddd] + # < ----------------- > + # ^ + # | + # (num_frames_per_clip - 1) * seconds_between_frames + # + # Now to compute the clip span, we need to add the duration of the last + # frame. The formula is fairly approximate, as we assume fps are + # constant, and that + # seconds_between_frames > average_frame_duration_seconds. It's good + # enough for what we need to do. + approximate_clip_span_seconds = ( + num_frames_per_clip - 1 + ) * seconds_between_frames + average_frame_duration_seconds + + return approximate_clip_span_seconds + + +def _validate_sampling_range_time_based( + *, + decoder, + num_frames_per_clip, + seconds_between_frames, + sampling_range_start, + sampling_range_end, +): + assert decoder.metadata.end_stream_seconds is not None + if sampling_range_start is None: + assert decoder.metadata.begin_stream_seconds is not None + sampling_range_start = decoder.metadata.begin_stream_seconds + + if sampling_range_end is None: + approximate_clip_span_seconds = _get_approximate_clip_span_seconds( + decoder=decoder, + seconds_between_frames=seconds_between_frames, + num_frames_per_clip=num_frames_per_clip, + ) + sampling_range_end = ( + decoder.metadata.end_stream_seconds - approximate_clip_span_seconds + ) + sampling_range_end = min( + sampling_range_end, decoder.metadata.end_stream_seconds - _EPS + ) + + return sampling_range_start, sampling_range_end + + +def _build_all_clips_timestamps( + *, + decoder: VideoDecoder, + clip_start_seconds: torch.Tensor, # 1D float tensor + num_frames_per_clip: int, + seconds_between_frames: Optional[float], + end_video_seconds: float, + policy_fun: _POLICY_FUNCTION_TYPE, +) -> list[int]: + all_clips_timestamps: list[float] = [] + + approximate_clip_span_seconds = _get_approximate_clip_span_seconds( + decoder=decoder, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + ) + + if seconds_between_frames is None: + average_frame_duration_seconds = 1 / decoder.metadata.average_fps + seconds_between_frames = average_frame_duration_seconds + # TODO: What we're doing above defeats the purpose of having a + # time-based sampler because we are assuming constant fps. We won't + # accurately get consecutive frames for variable fps, while this is the + # desired behavior when seconds_between_frames is None. I think we need + # an API like next_pts = decoder._get_next_pts(current_pts) that returns + # the pts of the *next* frame. i.e. if frame i is the one displayed at + # current_pts, we want the pts of frame i+1. + + for start_seconds in clip_start_seconds: + frame_pts_upper_bound = min( + start_seconds + approximate_clip_span_seconds, end_video_seconds - _EPS + ) + # This is correct when seconds_between_frames is specified by the user, + # but not quite correct when it's None if fps are variable. See note + # above. + frame_pts = torch.arange( + start_seconds, frame_pts_upper_bound, step=seconds_between_frames + ).tolist() + if len(frame_pts) < num_frames_per_clip: + frame_pts = policy_fun(frame_pts, num_frames_per_clip) + all_clips_timestamps += frame_pts + + return all_clips_timestamps + + +def _decode_all_clips_timestamps( + decoder: VideoDecoder, all_clips_timestamps: list[int], num_frames_per_clip: int +) -> list[FrameBatch]: + # This is 99% the same as _decode_all_clips_indices. The only change is the + # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx) + + def chunk_list(lst, chunk_size): + # return list of sublists of length chunk_size + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + def to_framebatch(frames: list[Frame]) -> FrameBatch: + # IMPORTANT: see other IMPORTANT note below + data = torch.stack([frame.data for frame in frames]) + pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) + duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) + return FrameBatch( + data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds + ) + + all_clips_timestamps_sorted, argsort = zip( + *sorted( + (frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps) + ) + ) + previous_decoded_frame = None + all_decoded_frames = [None] * len(all_clips_timestamps) + for i, j in enumerate(argsort): + frame_pts_seconds = all_clips_timestamps_sorted[i] + if ( + previous_decoded_frame is not None # then we know i > 0 + and frame_pts_seconds == all_clips_timestamps_sorted[i - 1] + ): + # Avoid decoding the same frame twice. + # IMPORTANT: this is only correct because a copy of the frame will + # happen within `to_framebatch` when we call torch.stack. + # If a copy isn't made, the same underlying memory will be used for + # the 2 consecutive frames. When we re-write this, we should make + # sure to explicitly copy the data. + decoded_frame = previous_decoded_frame + else: + decoded_frame = decoder.get_frame_displayed_at(seconds=frame_pts_seconds) + previous_decoded_frame = decoded_frame + all_decoded_frames[j] = decoded_frame + + all_clips: list[list[Frame]] = chunk_list( + all_decoded_frames, chunk_size=num_frames_per_clip + ) + + return [to_framebatch(clip) for clip in all_clips] + + +def clips_at_regular_timestamps( + decoder, + *, + seconds_between_clip_starts: int, # TODO or its inverse: num_clips_per_seconds? + num_frames_per_clip: int = 1, + seconds_between_frames: Optional[float] = None, + # None means "begining", which may not always be 0 + sampling_range_start: Optional[float] = None, + sampling_range_end: Optional[float] = None, + policy: str = "repeat_last", +) -> List[FrameBatch]: + + # TODO: better validation + assert seconds_between_clip_starts > 0 + assert num_frames_per_clip > 0 + assert seconds_between_frames is None or seconds_between_frames > 0 + assert sampling_range_start is None or sampling_range_start >= 0 + assert sampling_range_end is None or sampling_range_end >= 0 + + sampling_range_start, sampling_range_end = _validate_sampling_range_time_based( + decoder=decoder, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + ) + + sampling_range_seconds = sampling_range_end - sampling_range_start + num_clips = int(round(sampling_range_seconds / seconds_between_clip_starts)) + + clip_start_seconds = torch.linspace( + sampling_range_start, + sampling_range_end, + steps=num_clips, + ) + + all_clips_timestamps = _build_all_clips_timestamps( + decoder=decoder, + clip_start_seconds=clip_start_seconds, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + end_video_seconds=decoder.metadata.end_stream_seconds, + policy_fun=_POLICY_FUNCTIONS[policy], + ) + + return _decode_all_clips_timestamps( + decoder, + all_clips_timestamps=all_clips_timestamps, + num_frames_per_clip=num_frames_per_clip, + ) From e00fce7a5fdc652738880da0df6dbffc14707ab3 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 8 Oct 2024 11:23:38 -0700 Subject: [PATCH 02/15] Refac, some comments --- src/torchcodec/samplers/_implem.py | 224 ++++++++++++++++------------- 1 file changed, 124 insertions(+), 100 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index cfb88914e..28039cc42 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -9,32 +9,33 @@ _EPS = 1e-4 -def _validate_params( - *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames, policy -): +def _validate_params(*, decoder, num_frames_per_clip, policy): if len(decoder) < 1: raise ValueError( f"Decoder must have at least one frame, found {len(decoder)} frames." ) - if num_clips <= 0: - raise ValueError(f"num_clips ({num_clips}) must be strictly positive") if num_frames_per_clip <= 0: raise ValueError( f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" ) - if num_indices_between_frames <= 0: + if policy not in _POLICY_FUNCTIONS.keys(): raise ValueError( - f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}." ) - if policy not in _POLICY_FUNCTIONS.keys(): + +def _validate_params_index_based(*, num_clips, num_indices_between_frames): + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + + if num_indices_between_frames <= 0: raise ValueError( - f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}." + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" ) -def _validate_sampling_range( +def _validate_sampling_range_index_based( *, num_indices_between_frames, num_frames_per_clip, @@ -235,13 +236,15 @@ def _generic_index_based_sampler( _validate_params( decoder=decoder, - num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, - num_indices_between_frames=num_indices_between_frames, policy=policy, ) + _validate_params_index_based( + num_clips=num_clips, + num_indices_between_frames=num_indices_between_frames, + ) - sampling_range_start, sampling_range_end = _validate_sampling_range( + sampling_range_start, sampling_range_end = _validate_sampling_range_index_based( num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, sampling_range_start=sampling_range_start, @@ -328,124 +331,134 @@ def clips_at_regular_indices( ) -def _get_approximate_clip_span_seconds( +def _validate_params_time_based( *, decoder, - num_frames_per_clip, + seconds_between_clip_starts, seconds_between_frames, ): + if seconds_between_clip_starts <= 0: + raise ValueError( + f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0, got" + ) - # Compute clip span, in seconds. We can only compute an approximate value: - # we assume the fps are constant. Computing the real value requires - # accounting for variable fps. + if decoder.metadata.average_fps is None: + raise ValueError( + "Could not infer average fps from video metadata. " + "Try using an index-based sampler instead." + ) + if ( + decoder.metadata.end_stream_seconds is None + or decoder.metadata.begin_stream_seconds is None + ): + raise ValueError( + "Could not infer stream end and start from video metadata. " + "Try using an index-based sampler instead." + ) - assert decoder.metadata.average_fps is not None average_frame_duration_seconds = 1 / decoder.metadata.average_fps if seconds_between_frames is None: - approximate_clip_span_seconds = ( - num_frames_per_clip * average_frame_duration_seconds + seconds_between_frames = average_frame_duration_seconds + elif seconds_between_frames <= 0: + raise ValueError( + f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0, got" ) - else: - # aaa, bbb, ccc, ddd are 4 frames within a clip. - # - # seconds_between_frames - # | - # v - # < ---- > - # clip = [aaa....bbb....ccc....ddd] - # < ----------------- > - # ^ - # | - # (num_frames_per_clip - 1) * seconds_between_frames - # - # Now to compute the clip span, we need to add the duration of the last - # frame. The formula is fairly approximate, as we assume fps are - # constant, and that - # seconds_between_frames > average_frame_duration_seconds. It's good - # enough for what we need to do. - approximate_clip_span_seconds = ( - num_frames_per_clip - 1 - ) * seconds_between_frames + average_frame_duration_seconds - return approximate_clip_span_seconds + return seconds_between_frames def _validate_sampling_range_time_based( *, - decoder, num_frames_per_clip, seconds_between_frames, sampling_range_start, sampling_range_end, + begin_stream_seconds, + end_stream_seconds, ): - assert decoder.metadata.end_stream_seconds is not None + if sampling_range_start is None: - assert decoder.metadata.begin_stream_seconds is not None - sampling_range_start = decoder.metadata.begin_stream_seconds + sampling_range_start = begin_stream_seconds + elif sampling_range_start <= begin_stream_seconds: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}" + ) if sampling_range_end is None: - approximate_clip_span_seconds = _get_approximate_clip_span_seconds( - decoder=decoder, - seconds_between_frames=seconds_between_frames, - num_frames_per_clip=num_frames_per_clip, - ) + # We allow a clip to start anywhere within + # [sampling_range_start, sampling_range_end) + # When sampling_range_end is None, we want to automatically set it to + # the largest possible value such that the sampled frames in any clip + # are within the bounds of the video duration (in other words, we don't + # want to have to resort to the `policy`). + # I.e. we want to guarantee that for all frames in any clip we have + # pts < end_stream_seconds. + # + # The frames of clip will be sampled at the following pts: + # clip_timestamps = [ + # clip_start + 0 * seconds_between_frames, + # clip_start + 1 * seconds_between_frames, + # clip_start + 2 * seconds_between_frames, + # ... + # clip_start + (num_frames_per_clip - 1) * seconds_between_frames, + # ] + # To guarantee that any such value is < end_stream_seconds, we only need + # to guarantee that + # clip_start < end_stream_seconds - (num_frames_per_clip - 1) * seconds_between_frames + # + # So that's the value of sampling_range_end we want to use. sampling_range_end = ( - decoder.metadata.end_stream_seconds - approximate_clip_span_seconds + end_stream_seconds - (num_frames_per_clip - 1) * seconds_between_frames + ) + elif sampling_range_end <= begin_stream_seconds: + raise ValueError( + f"sampling_range_end ({sampling_range_end}) must be at least {begin_stream_seconds}" ) - sampling_range_end = min( - sampling_range_end, decoder.metadata.end_stream_seconds - _EPS - ) + + if sampling_range_start >= sampling_range_end: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be less than sampling_range_end ({sampling_range_end})" + ) + + sampling_range_end = min(sampling_range_end, end_stream_seconds) return sampling_range_start, sampling_range_end def _build_all_clips_timestamps( *, - decoder: VideoDecoder, clip_start_seconds: torch.Tensor, # 1D float tensor num_frames_per_clip: int, - seconds_between_frames: Optional[float], - end_video_seconds: float, + seconds_between_frames: float, + end_stream_seconds: float, policy_fun: _POLICY_FUNCTION_TYPE, ) -> list[int]: - all_clips_timestamps: list[float] = [] - - approximate_clip_span_seconds = _get_approximate_clip_span_seconds( - decoder=decoder, - num_frames_per_clip=num_frames_per_clip, - seconds_between_frames=seconds_between_frames, - ) - - if seconds_between_frames is None: - average_frame_duration_seconds = 1 / decoder.metadata.average_fps - seconds_between_frames = average_frame_duration_seconds - # TODO: What we're doing above defeats the purpose of having a - # time-based sampler because we are assuming constant fps. We won't - # accurately get consecutive frames for variable fps, while this is the - # desired behavior when seconds_between_frames is None. I think we need - # an API like next_pts = decoder._get_next_pts(current_pts) that returns - # the pts of the *next* frame. i.e. if frame i is the one displayed at - # current_pts, we want the pts of frame i+1. + all_clips_timestamps: list[float] = [] for start_seconds in clip_start_seconds: - frame_pts_upper_bound = min( - start_seconds + approximate_clip_span_seconds, end_video_seconds - _EPS + # clip_timestamps = [ + # start_seconds + 0 * seconds_between_frames, + # start_seconds + 1 * seconds_between_frames, + # start_seconds + 2 * seconds_between_frames, + # ... + # ] + clip_timestamps = torch.full( + size=(num_frames_per_clip,), fill_value=start_seconds ) - # This is correct when seconds_between_frames is specified by the user, - # but not quite correct when it's None if fps are variable. See note - # above. - frame_pts = torch.arange( - start_seconds, frame_pts_upper_bound, step=seconds_between_frames - ).tolist() - if len(frame_pts) < num_frames_per_clip: - frame_pts = policy_fun(frame_pts, num_frames_per_clip) - all_clips_timestamps += frame_pts + clip_timestamps += torch.arange(num_frames_per_clip) * seconds_between_frames + + # We clip to valid values, so that we can call the same policies as + # for index-based samplers. + clip_timestamps = clip_timestamps[clip_timestamps < end_stream_seconds].tolist() + if len(clip_timestamps) < num_frames_per_clip: + clip_timestamps = policy_fun(clip_timestamps, num_frames_per_clip) + all_clips_timestamps += clip_timestamps return all_clips_timestamps def _decode_all_clips_timestamps( - decoder: VideoDecoder, all_clips_timestamps: list[int], num_frames_per_clip: int + decoder: VideoDecoder, all_clips_timestamps: list[float], num_frames_per_clip: int ) -> list[FrameBatch]: # This is 99% the same as _decode_all_clips_indices. The only change is the # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx) @@ -503,40 +516,51 @@ def clips_at_regular_timestamps( seconds_between_frames: Optional[float] = None, # None means "begining", which may not always be 0 sampling_range_start: Optional[float] = None, - sampling_range_end: Optional[float] = None, + sampling_range_end: Optional[float] = None, # interval is [start, end). policy: str = "repeat_last", ) -> List[FrameBatch]: - # TODO: better validation - assert seconds_between_clip_starts > 0 - assert num_frames_per_clip > 0 - assert seconds_between_frames is None or seconds_between_frames > 0 - assert sampling_range_start is None or sampling_range_start >= 0 - assert sampling_range_end is None or sampling_range_end >= 0 + _validate_params( + decoder=decoder, + num_frames_per_clip=num_frames_per_clip, + policy=policy, + ) - sampling_range_start, sampling_range_end = _validate_sampling_range_time_based( + seconds_between_frames = _validate_params_time_based( decoder=decoder, + seconds_between_clip_starts=seconds_between_clip_starts, + seconds_between_frames=seconds_between_frames, + ) + + sampling_range_start, sampling_range_end = _validate_sampling_range_time_based( num_frames_per_clip=num_frames_per_clip, seconds_between_frames=seconds_between_frames, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, + begin_stream_seconds=decoder.metadata.begin_stream_seconds, + end_stream_seconds=decoder.metadata.end_stream_seconds, ) sampling_range_seconds = sampling_range_end - sampling_range_start num_clips = int(round(sampling_range_seconds / seconds_between_clip_starts)) + # Note on sampling_range_end: we want the sampling range to be + # [sampling_range_start, sampling_range_end) where the upper bound is open. + # This is for consistency with the index-based case. + # Because in torch.linspace the upper bound is inclusive, we would risk + # getting exactly sampling_range_end as a clip start value. To avoid that, + # we substract a small value. clip_start_seconds = torch.linspace( sampling_range_start, - sampling_range_end, + sampling_range_end - _EPS, steps=num_clips, ) all_clips_timestamps = _build_all_clips_timestamps( - decoder=decoder, clip_start_seconds=clip_start_seconds, num_frames_per_clip=num_frames_per_clip, seconds_between_frames=seconds_between_frames, - end_video_seconds=decoder.metadata.end_stream_seconds, + end_stream_seconds=decoder.metadata.end_stream_seconds, policy_fun=_POLICY_FUNCTIONS[policy], ) From cca1ff964cf63166b3028e8891cb7e839ae8a8bc Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Oct 2024 02:05:08 -0700 Subject: [PATCH 03/15] More tests --- src/torchcodec/samplers/_implem.py | 42 +++--- test/samplers/test_samplers.py | 204 ++++++++++++++++++++++++----- 2 files changed, 189 insertions(+), 57 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 28039cc42..33204826e 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -6,9 +6,6 @@ from torchcodec.decoders import VideoDecoder -_EPS = 1e-4 - - def _validate_params(*, decoder, num_frames_per_clip, policy): if len(decoder) < 1: raise ValueError( @@ -339,7 +336,7 @@ def _validate_params_time_based( ): if seconds_between_clip_starts <= 0: raise ValueError( - f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0, got" + f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0" ) if decoder.metadata.average_fps is None: @@ -379,10 +376,15 @@ def _validate_sampling_range_time_based( if sampling_range_start is None: sampling_range_start = begin_stream_seconds - elif sampling_range_start <= begin_stream_seconds: - raise ValueError( - f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}" - ) + else: + if sampling_range_start <= begin_stream_seconds: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}" + ) + if sampling_range_start >= end_stream_seconds: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be smaller than {end_stream_seconds}" + ) if sampling_range_end is None: # We allow a clip to start anywhere within @@ -394,7 +396,7 @@ def _validate_sampling_range_time_based( # I.e. we want to guarantee that for all frames in any clip we have # pts < end_stream_seconds. # - # The frames of clip will be sampled at the following pts: + # The frames of a clip will be sampled at the following pts: # clip_timestamps = [ # clip_start + 0 * seconds_between_frames, # clip_start + 1 * seconds_between_frames, @@ -417,7 +419,7 @@ def _validate_sampling_range_time_based( if sampling_range_start >= sampling_range_end: raise ValueError( - f"sampling_range_start ({sampling_range_start}) must be less than sampling_range_end ({sampling_range_end})" + f"sampling_range_start ({sampling_range_start}) must be smaller than sampling_range_end ({sampling_range_end})" ) sampling_range_end = min(sampling_range_end, end_stream_seconds) @@ -511,7 +513,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: def clips_at_regular_timestamps( decoder, *, - seconds_between_clip_starts: int, # TODO or its inverse: num_clips_per_seconds? + seconds_between_clip_starts: float, num_frames_per_clip: int = 1, seconds_between_frames: Optional[float] = None, # None means "begining", which may not always be 0 @@ -519,6 +521,9 @@ def clips_at_regular_timestamps( sampling_range_end: Optional[float] = None, # interval is [start, end). policy: str = "repeat_last", ) -> List[FrameBatch]: + # Note: *everywhere*, sampling_range_end denotes the upper bound of where a + # clip can start. This is an *open* upper bound, i.e. we will make sure no + # clip start exactly at (or above) sampling_range_end. _validate_params( decoder=decoder, @@ -541,19 +546,10 @@ def clips_at_regular_timestamps( end_stream_seconds=decoder.metadata.end_stream_seconds, ) - sampling_range_seconds = sampling_range_end - sampling_range_start - num_clips = int(round(sampling_range_seconds / seconds_between_clip_starts)) - - # Note on sampling_range_end: we want the sampling range to be - # [sampling_range_start, sampling_range_end) where the upper bound is open. - # This is for consistency with the index-based case. - # Because in torch.linspace the upper bound is inclusive, we would risk - # getting exactly sampling_range_end as a clip start value. To avoid that, - # we substract a small value. - clip_start_seconds = torch.linspace( + clip_start_seconds = torch.arange( sampling_range_start, - sampling_range_end - _EPS, - steps=num_clips, + sampling_range_end, # excluded + seconds_between_clip_starts, ) all_clips_timestamps = _build_all_clips_timestamps( diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 42303b33f..2fca06a90 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -1,21 +1,61 @@ import contextlib import random import re +from copy import copy import pytest import torch from torchcodec import FrameBatch from torchcodec.decoders import VideoDecoder -from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices +from torchcodec.samplers import ( + clips_at_random_indices, + clips_at_regular_indices, + clips_at_regular_timestamps, +) from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS from ..utils import assert_tensor_equal, NASA_VIDEO +def _assert_output_type_and_shapes( + video, clips, expected_num_clips, num_frames_per_clip +): + assert isinstance(clips, list) + assert len(clips) == expected_num_clips + assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clip_data_shape = ( + num_frames_per_clip, + 3, + video.height, + video.width, + ) + assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + + +def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None): + # assert regular spacing between sampled clips + seconds_between_clip_starts = torch.tensor( + [clip.pts_seconds[0] for clip in clips] + ).diff() + for diff in seconds_between_clip_starts: + # Note: need approximate check as actual diff values typically look like + # [3.2032, 3.2366, 3.2366, 3.2366] + assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) + + if expected_seconds_between_clip_starts is not None: + # This can only be asserted with the time-based sampler, where + # seconds_between_clip_starts is known since it's passed by the user. + torch.testing.assert_close( + diff, expected_seconds_between_clip_starts, atol=0.01, rtol=0 + ) + + assert (diff > 0).all() # Also assert clips are sorted by start time + + @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize("num_indices_between_frames", [1, 5]) -def test_sampler(sampler, num_indices_between_frames): +def test_index_based_sampler(sampler, num_indices_between_frames): decoder = VideoDecoder(NASA_VIDEO.path) num_clips = 5 num_frames_per_clip = 3 @@ -27,27 +67,15 @@ def test_sampler(sampler, num_indices_between_frames): num_indices_between_frames=num_indices_between_frames, ) - assert isinstance(clips, list) - assert len(clips) == num_clips - assert all(isinstance(clip, FrameBatch) for clip in clips) - expected_clip_data_shape = ( - num_frames_per_clip, - 3, - NASA_VIDEO.height, - NASA_VIDEO.width, + _assert_output_type_and_shapes( + video=NASA_VIDEO, + clips=clips, + expected_num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, ) - assert all(clip.data.shape == expected_clip_data_shape for clip in clips) if sampler is clips_at_regular_indices: - # assert regular spacing between sampled clips - # Note: need approximate check as actual values typically look like [3.2032, 3.2366, 3.2366, 3.2366] - seconds_between_clip_starts = torch.tensor( - [clip.pts_seconds[0] for clip in clips] - ).diff() - for diff in seconds_between_clip_starts: - assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) - - assert (diff > 0).all() # Also assert clips are sorted by start time + _assert_regular_sampler(clips=clips, expected_seconds_between_clip_starts=None) # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" @@ -61,6 +89,46 @@ def test_sampler(sampler, num_indices_between_frames): ) +@pytest.mark.parametrize("seconds_between_frames", [None, 3]) +def test_time_based_sampler(seconds_between_frames): + decoder = VideoDecoder(NASA_VIDEO.path) + num_frames_per_clip = 3 + seconds_between_clip_starts = 2 + + clips = clips_at_regular_timestamps( + decoder, + seconds_between_clip_starts=seconds_between_clip_starts, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + ) + + expeted_num_clips = len(clips) # no-op check, it's just hard to assert + _assert_output_type_and_shapes( + video=NASA_VIDEO, + clips=clips, + expected_num_clips=expeted_num_clips, + num_frames_per_clip=num_frames_per_clip, + ) + + expected_seconds_between_clip_starts = torch.tensor( + seconds_between_clip_starts, dtype=torch.float + ) + _assert_regular_sampler( + clips=clips, + expected_seconds_between_clip_starts=expected_seconds_between_clip_starts, + ) + + expected_seconds_between_frames = ( + seconds_between_frames or 1 / decoder.metadata.average_fps + ) + avg_seconds_between_frames_seconds = torch.concat( + [clip.pts_seconds.diff() for clip in clips] + ).mean() + assert avg_seconds_between_frames_seconds == pytest.approx( + expected_seconds_between_frames, abs=0.05 + ) + + @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize( "sampling_range_start, sampling_range_end, assert_all_equal", @@ -260,34 +328,54 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz assert (clip_starts_seconds.diff() >= 0).all() -@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) -def test_random_sampler_errors(sampler): +from functools import partial + + +@pytest.mark.parametrize( + "sampler", + ( + clips_at_random_indices, + clips_at_regular_indices, + partial(clips_at_regular_timestamps, seconds_between_clip_starts=1), + ), +) +def test_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( - ValueError, match=re.escape("num_clips (0) must be strictly positive") + ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") ): - sampler(decoder, num_clips=0) + sampler(decoder, num_frames_per_clip=0) + + with pytest.raises(ValueError, match="Invalid policy"): + sampler(decoder, policy="BAD") with pytest.raises( - ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") + ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") ): - sampler(decoder, num_frames_per_clip=0) + sampler(decoder, sampling_range_start=1000) with pytest.raises( ValueError, - match=re.escape("num_indices_between_frames (0) must be strictly positive"), + match=re.escape( + "sampling_range_start (4) must be smaller than sampling_range_end" + ), ): - sampler(decoder, num_indices_between_frames=0) + sampler(decoder, sampling_range_start=4, sampling_range_end=4) + +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_index_based_samplers_errors(sampler): + decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( - ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") + ValueError, match=re.escape("num_clips (0) must be strictly positive") ): - sampler(decoder, sampling_range_start=1000) + sampler(decoder, num_clips=0) with pytest.raises( - ValueError, match=re.escape("sampling_range_start (4) must be smaller than") + ValueError, + match=re.escape("num_indices_between_frames (0) must be strictly positive"), ): - sampler(decoder, sampling_range_start=4, sampling_range_end=4) + sampler(decoder, num_indices_between_frames=0) with pytest.raises( ValueError, match=re.escape("sampling_range_start (290) must be smaller than") @@ -304,8 +392,56 @@ def test_random_sampler_errors(sampler): sampling_range_end=None, ) - with pytest.raises(ValueError, match="Invalid policy"): - sampler(decoder, policy="BAD") + +def test_time_based_sampler_errors(): + decoder = VideoDecoder(NASA_VIDEO.path) + sampler = partial(clips_at_regular_timestamps, seconds_between_clip_starts=1) + + with pytest.raises( + ValueError, match=re.escape("sampling_range_start (-1) must be at least 0.0") + ): + sampler(decoder, sampling_range_start=-1) + + with pytest.raises( + ValueError, match=re.escape("sampling_range_end (-1) must be at least 0.0") + ): + sampler(decoder, sampling_range_end=-1) + + with pytest.raises( + ValueError, match=re.escape("seconds_between_clip_starts (-1) must be > 0") + ): + sampler(decoder, seconds_between_clip_starts=-1) + + @contextlib.contextmanager + def restore_metadata(): + # Context manager helper that restores the decoder's metadata to its + # original state upon exit. + try: + original_metadata = copy(decoder.metadata) + yield + finally: + decoder.metadata = original_metadata + + with restore_metadata(): + decoder.metadata.begin_stream_seconds = None + with pytest.raises( + ValueError, match="Could not infer stream end and start from video metadata" + ): + sampler(decoder) + + with restore_metadata(): + decoder.metadata.end_stream_seconds = None + with pytest.raises( + ValueError, match="Could not infer stream end and start from video metadata" + ): + sampler(decoder) + + with restore_metadata(): + decoder.metadata.begin_stream_seconds = None + decoder.metadata.end_stream_seconds = None + decoder.metadata.average_fps_from_header = None + with pytest.raises(ValueError, match="Could not infer average fps"): + sampler(decoder) class TestPolicy: From 0afbb97bcd87aa858ffd3cc081ffcd1a31078d26 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Oct 2024 03:23:24 -0700 Subject: [PATCH 04/15] More tests --- src/torchcodec/samplers/_implem.py | 15 ++- test/samplers/test_samplers.py | 148 +++++++++++++++++++++++++---- 2 files changed, 143 insertions(+), 20 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 33204826e..849063685 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -445,9 +445,12 @@ def _build_all_clips_timestamps( # ... # ] clip_timestamps = torch.full( - size=(num_frames_per_clip,), fill_value=start_seconds + size=(num_frames_per_clip,), fill_value=start_seconds, dtype=torch.float + ) + clip_timestamps += ( + torch.arange(num_frames_per_clip, dtype=torch.float) + * seconds_between_frames ) - clip_timestamps += torch.arange(num_frames_per_clip) * seconds_between_frames # We clip to valid values, so that we can call the same policies as # for index-based samplers. @@ -546,6 +549,14 @@ def clips_at_regular_timestamps( end_stream_seconds=decoder.metadata.end_stream_seconds, ) + # from math import ceil + # num_clips = int(ceil((sampling_range_end - sampling_range_start) // seconds_between_clip_starts)) + # clip_start_seconds = torch.linspace( + # sampling_range_start, + # sampling_range_end - 1e-4, + # steps=num_clips, + # ) + # print(clip_start_seconds) clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 2fca06a90..1eec90ebf 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -2,6 +2,7 @@ import random import re from copy import copy +from functools import partial import pytest import torch @@ -13,7 +14,11 @@ clips_at_regular_indices, clips_at_regular_timestamps, ) -from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS +from torchcodec.samplers._implem import ( + _build_all_clips_indices, + _build_all_clips_timestamps, + _POLICY_FUNCTIONS, +) from ..utils import assert_tensor_equal, NASA_VIDEO @@ -129,20 +134,41 @@ def test_time_based_sampler(seconds_between_frames): ) -@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +# @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize( - "sampling_range_start, sampling_range_end, assert_all_equal", + "sampler, sampling_range_start, sampling_range_end, assert_all_equal", ( - (10, 11, True), - (10, 12, False), + (partial(clips_at_random_indices, num_clips=10), 10, 11, True), + (partial(clips_at_random_indices, num_clips=10), 10, 12, False), + (partial(clips_at_regular_indices, num_clips=10), 10, 11, True), + (partial(clips_at_regular_indices, num_clips=10), 10, 12, False), + ( + partial(clips_at_regular_timestamps, seconds_between_clip_starts=1), + 10.0, + 11.0, + True, + ), + ( + partial(clips_at_regular_timestamps, seconds_between_clip_starts=1), + 10.0, + 12.0, + False, + ), ), ) def test_sampling_range( sampler, sampling_range_start, sampling_range_end, assert_all_equal ): + # For index-based: # Test the sampling_range_start and sampling_range_end parameters by # asserting that all clips are equal if the sampling range is of size 1, # and that they are not all equal if the sampling range is of size 2. + # + # For time-based: + # The test is similar but with different semantics. We set the sampling + # range to be 1 second or 2 seconds. Since we set seconds_between_clip_start + # to 1 we expect exactly one clip with the sampling range is of size 1, and + # 2 different clips when teh sampling range is 2 seconds. # When size=2 there's still a (small) non-zero probability of sampling the # same indices for clip starts, so we hard-code a seed that works. @@ -152,7 +178,6 @@ def test_sampling_range( clips = sampler( decoder, - num_clips=10, num_frames_per_clip=2, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, @@ -202,10 +227,10 @@ def test_sampling_range_negative(sampler): assert_tensor_equal(clip.data, clips_1[0].data) -@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) -def test_sampling_range_default_behavior(sampler): +def test_sampling_range_default_behavior_random_sampler(): # This is a functional test for the default behavior of the - # sampling_range_end parameter. By default it's None, which means the + # sampling_range_end parameter, for the random sampler. + # By default it's None, which means the # sampler automatically sets its value such that we never sample "beyond" # the number of frames in the video. That means that the last few frames of # the video are less likely to be part of a clip. @@ -216,7 +241,7 @@ def test_sampling_range_default_behavior(sampler): # # In this test we assert that the last clip starts significantly earlier # when sampling_range_end=None than when sampling_range_end=len(decoder). - # This is only a proxy, for lack of better testing oppportunities. + # This is only a proxy, for lack of better testing opportunities. torch.manual_seed(0) @@ -227,18 +252,19 @@ def test_sampling_range_default_behavior(sampler): sampling_range_start = -20 # with default sampling_range_end value - clips_default = sampler( + clips_default = clips_at_random_indices( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, sampling_range_start=sampling_range_start, sampling_range_end=None, + policy="error", ) last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default]) # with manual sampling_range_end value set to last frame - clips_manual = sampler( + clips_manual = clips_at_random_indices( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, @@ -250,15 +276,64 @@ def test_sampling_range_default_behavior(sampler): assert last_clip_start_manual - last_clip_start_default > 0.3 -@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +@pytest.mark.parametrize( + "sampler", + ( + partial( + clips_at_regular_indices, + num_clips=5, + sampling_range_start=-30, + ), + partial( + clips_at_regular_timestamps, + seconds_between_clip_starts=0.02, + sampling_range_start=12.3, + ), + ), +) +def test_sampling_range_default_regular_sampler(sampler): + # For a regular sampler, checking the default behavior of sampling_range_end + # is slightly easier: we can assert that the last frame of the last clip + # *is* the last frame of the video. + # Note that this doesn't always happen. It depends on specific values passed + # for num_clips / seconds_between_clip_starts, and where the sampling range + # starts. We just need to assert that it *can* happen. + + decoder = VideoDecoder(NASA_VIDEO.path) + last_frame = decoder.get_frame_at(len(decoder) - 1) + + clips = sampler(decoder, num_frames_per_clip=5, policy="error") + + last_clip = clips[-1] + assert last_clip.pts_seconds[-1] == last_frame.pts_seconds + assert last_clip.pts_seconds[-2] != last_frame.pts_seconds + + +@pytest.mark.parametrize( + "sampler", + ( + partial( + clips_at_random_indices, sampling_range_start=-1, sampling_range_end=1000 + ), + partial( + clips_at_regular_indices, sampling_range_start=-1, sampling_range_end=1000 + ), + # Note: the hard-coded value of sampling_range_start=12 is because we know + # the NASA_VIDEO is 13.01s seconds long + partial( + clips_at_regular_timestamps, + seconds_between_clip_starts=0.1, + sampling_range_start=13, + sampling_range_end=1000, + ), + ), +) def test_sampling_range_error_policy(sampler): decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises(ValueError, match="beyond the number of frames"): sampler( decoder, num_frames_per_clip=10, - sampling_range_start=-1, - sampling_range_end=len(decoder), policy="error", ) @@ -328,9 +403,6 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz assert (clip_starts_seconds.diff() >= 0).all() -from functools import partial - - @pytest.mark.parametrize( "sampler", ( @@ -516,3 +588,43 @@ def test_build_all_clips_indices( assert all(isinstance(index, int) for index in all_clips_indices) assert len(all_clips_indices) == len(clip_start_indices) * NUM_FRAMES_PER_CLIP assert all_clips_indices == expected_all_clips_indices + + +@pytest.mark.parametrize( + "clip_start_seconds, seconds_between_frames, policy, expected_all_clips_timestamps", + ( + ( + [0, 1, 2], # clip_start_seconds + 1.5, # seconds_between_frames + "repeat_last", # policy + # expected_all_clips_seconds = + [0, 1.5, 3, 4.5, 4.5] + [1, 2.5, 4, 4, 4] + [2, 3.5, 3.5, 3.5, 3.5], + # Note how 5 isn't in the last clip, as it's not a seekable pts + # since we set end_stream_seconds=5 + ), + # Same as above with wrap policy + ( + [0, 1, 2], # clip_start_seconds + 1.5, # seconds_between_frames + "wrap", # policy + # expected_all_clips_seconds = + [0, 1.5, 3, 4.5, 0] + [1, 2.5, 4, 1, 2.5] + [2, 3.5, 2, 3.5, 2], + ), + ), +) +def test_build_all_clips_timestamps( + clip_start_seconds, seconds_between_frames, policy, expected_all_clips_timestamps +): + NUM_FRAMES_PER_CLIP = 5 + all_clips_timestamps = _build_all_clips_timestamps( + clip_start_seconds=clip_start_seconds, + num_frames_per_clip=5, + seconds_between_frames=seconds_between_frames, + end_stream_seconds=5.0, + policy_fun=_POLICY_FUNCTIONS[policy], + ) + + assert isinstance(all_clips_timestamps, list) + assert all(isinstance(timestamp, float) for timestamp in all_clips_timestamps) + assert len(all_clips_timestamps) == len(clip_start_seconds) * NUM_FRAMES_PER_CLIP + assert all_clips_timestamps == expected_all_clips_timestamps From 58bbd7f041263d35ba9ff5c771d1c561557cd75f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Oct 2024 03:42:29 -0700 Subject: [PATCH 05/15] Fix mypy --- src/torchcodec/samplers/_implem.py | 43 ++++++++++++++---------------- test/samplers/test_samplers.py | 6 ++--- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 849063685..5c7dac804 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -1,4 +1,4 @@ -from typing import Callable, List, Literal, Optional +from typing import Callable, List, Literal, Optional, Union import torch @@ -90,24 +90,29 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): return num_indices_between_frames * (num_frames_per_clip - 1) + 1 +_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]] + + def _repeat_last_policy( - frame_indices: list[int], num_frames_per_clip: int -) -> list[int]: - # frame_indices = [1, 2, 3], num_frames_per_clip = 5 + values: _LIST_OF_INT_OR_FLOAT, desired_len: int +) -> _LIST_OF_INT_OR_FLOAT: + # values = [1, 2, 3], desired_len = 5 # output = [1, 2, 3, 3, 3] - frame_indices += [frame_indices[-1]] * (num_frames_per_clip - len(frame_indices)) - return frame_indices + values += [values[-1]] * (desired_len - len(values)) + return values -def _wrap_policy(frame_indices: list[int], num_frames_per_clip: int) -> list[int]: - # frame_indices = [1, 2, 3], num_frames_per_clip = 5 +def _wrap_policy( + values: _LIST_OF_INT_OR_FLOAT, desired_len: int +) -> _LIST_OF_INT_OR_FLOAT: + # values = [1, 2, 3], desired_len = 5 # output = [1, 2, 3, 1, 2] - return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[ - :num_frames_per_clip - ] + return (values * (desired_len // len(values) + 1))[:desired_len] -def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[int]: +def _error_policy( + frames_indices: _LIST_OF_INT_OR_FLOAT, desired_len: int +) -> _LIST_OF_INT_OR_FLOAT: raise ValueError( "You set the 'error' policy, and the sampler tried to decode a frame " "that is beyond the number of frames in the video. " @@ -115,7 +120,7 @@ def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[i ) -_POLICY_FUNCTION_TYPE = Callable[[list[int], int], list[int]] +_POLICY_FUNCTION_TYPE = Callable[[_LIST_OF_INT_OR_FLOAT, int], _LIST_OF_INT_OR_FLOAT] _POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = { "repeat_last": _repeat_last_policy, "wrap": _wrap_policy, @@ -154,7 +159,7 @@ def _build_all_clips_indices( range(start_index, frame_index_upper_bound, num_indices_between_frames) ) if len(frame_indices) < num_frames_per_clip: - frame_indices = policy_fun(frame_indices, num_frames_per_clip) + frame_indices = policy_fun(frame_indices, num_frames_per_clip) # type: ignore[assignment] all_clips_indices += frame_indices return all_clips_indices @@ -434,7 +439,7 @@ def _build_all_clips_timestamps( seconds_between_frames: float, end_stream_seconds: float, policy_fun: _POLICY_FUNCTION_TYPE, -) -> list[int]: +) -> list[float]: all_clips_timestamps: list[float] = [] for start_seconds in clip_start_seconds: @@ -549,14 +554,6 @@ def clips_at_regular_timestamps( end_stream_seconds=decoder.metadata.end_stream_seconds, ) - # from math import ceil - # num_clips = int(ceil((sampling_range_end - sampling_range_start) // seconds_between_clip_starts)) - # clip_start_seconds = torch.linspace( - # sampling_range_start, - # sampling_range_end - 1e-4, - # steps=num_clips, - # ) - # print(clip_start_seconds) clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 1eec90ebf..5474c692e 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -528,13 +528,11 @@ class TestPolicy: ) def test_policy(self, policy, frame_indices, expected_frame_indices): policy_fun = _POLICY_FUNCTIONS[policy] - assert ( - policy_fun(frame_indices, num_frames_per_clip=5) == expected_frame_indices - ) + assert policy_fun(frame_indices, desired_len=5) == expected_frame_indices def test_error_policy(self): with pytest.raises(ValueError, match="beyond the number of frames"): - _POLICY_FUNCTIONS["error"]([1, 2, 3], num_frames_per_clip=5) + _POLICY_FUNCTIONS["error"]([1, 2, 3], desired_len=5) @pytest.mark.parametrize( From 24cebf8e1a95a8c880e99a85244a89750bb0cc20 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Oct 2024 03:48:47 -0700 Subject: [PATCH 06/15] Typo --- src/torchcodec/samplers/_implem.py | 2 +- test/samplers/test_samplers.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 5c7dac804..7107ceca4 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -531,7 +531,7 @@ def clips_at_regular_timestamps( ) -> List[FrameBatch]: # Note: *everywhere*, sampling_range_end denotes the upper bound of where a # clip can start. This is an *open* upper bound, i.e. we will make sure no - # clip start exactly at (or above) sampling_range_end. + # clip starts exactly at (or above) sampling_range_end. _validate_params( decoder=decoder, diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 5474c692e..343555a75 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -134,7 +134,6 @@ def test_time_based_sampler(seconds_between_frames): ) -# @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize( "sampler, sampling_range_start, sampling_range_end, assert_all_equal", ( From 5d44147c579583eaa3f45ddfe933555c59997e4d Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Oct 2024 04:48:49 -0700 Subject: [PATCH 07/15] Minor simpification --- src/torchcodec/samplers/_implem.py | 23 +++++++---------------- 1 file changed, 7 insertions(+), 16 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 7107ceca4..8018da690 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -443,23 +443,14 @@ def _build_all_clips_timestamps( all_clips_timestamps: list[float] = [] for start_seconds in clip_start_seconds: - # clip_timestamps = [ - # start_seconds + 0 * seconds_between_frames, - # start_seconds + 1 * seconds_between_frames, - # start_seconds + 2 * seconds_between_frames, - # ... - # ] - clip_timestamps = torch.full( - size=(num_frames_per_clip,), fill_value=start_seconds, dtype=torch.float - ) - clip_timestamps += ( - torch.arange(num_frames_per_clip, dtype=torch.float) - * seconds_between_frames - ) + clip_timestamps = [ + start_seconds + i * seconds_between_frames + for i in range(num_frames_per_clip) + ] + clip_timestamps = [ + timestamp for timestamp in clip_timestamps if timestamp < end_stream_seconds + ] - # We clip to valid values, so that we can call the same policies as - # for index-based samplers. - clip_timestamps = clip_timestamps[clip_timestamps < end_stream_seconds].tolist() if len(clip_timestamps) < num_frames_per_clip: clip_timestamps = policy_fun(clip_timestamps, num_frames_per_clip) all_clips_timestamps += clip_timestamps From da5d3be47794136c8cbfc60d91e9134cad21b760 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 9 Oct 2024 04:59:00 -0700 Subject: [PATCH 08/15] Typo --- test/samplers/test_samplers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 343555a75..ebba8cf65 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -575,7 +575,7 @@ def test_build_all_clips_indices( NUM_FRAMES_PER_CLIP = 5 all_clips_indices = _build_all_clips_indices( clip_start_indices=clip_start_indices, - num_frames_per_clip=5, + num_frames_per_clip=NUM_FRAMES_PER_CLIP, num_indices_between_frames=num_indices_between_frames, num_frames_in_video=5, policy_fun=_POLICY_FUNCTIONS[policy], @@ -615,7 +615,7 @@ def test_build_all_clips_timestamps( NUM_FRAMES_PER_CLIP = 5 all_clips_timestamps = _build_all_clips_timestamps( clip_start_seconds=clip_start_seconds, - num_frames_per_clip=5, + num_frames_per_clip=NUM_FRAMES_PER_CLIP, seconds_between_frames=seconds_between_frames, end_stream_seconds=5.0, policy_fun=_POLICY_FUNCTIONS[policy], From 6550af932359f4589833592934f5729e0be39a42 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Oct 2024 01:37:55 -0700 Subject: [PATCH 09/15] Fix _assert_regular_sampler --- test/samplers/test_samplers.py | 18 +++++++++++------- 1 file changed, 11 insertions(+), 7 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index ebba8cf65..1aead0fee 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -43,19 +43,23 @@ def _assert_regular_sampler(clips, expected_seconds_between_clip_starts=None): seconds_between_clip_starts = torch.tensor( [clip.pts_seconds[0] for clip in clips] ).diff() - for diff in seconds_between_clip_starts: - # Note: need approximate check as actual diff values typically look like - # [3.2032, 3.2366, 3.2366, 3.2366] - assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) if expected_seconds_between_clip_starts is not None: # This can only be asserted with the time-based sampler, where # seconds_between_clip_starts is known since it's passed by the user. torch.testing.assert_close( - diff, expected_seconds_between_clip_starts, atol=0.01, rtol=0 + seconds_between_clip_starts, + expected_seconds_between_clip_starts, + atol=0.05, + rtol=0, ) - assert (diff > 0).all() # Also assert clips are sorted by start time + # Also assert clips are sorted by start time + assert (seconds_between_clip_starts > 0).all() + # For all samplers, we can at least assert that seconds_between_clip_starts + # is constant. + for diff in seconds_between_clip_starts: + assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @@ -116,7 +120,7 @@ def test_time_based_sampler(seconds_between_frames): ) expected_seconds_between_clip_starts = torch.tensor( - seconds_between_clip_starts, dtype=torch.float + [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float ) _assert_regular_sampler( clips=clips, From bdf08a1a403c138c467cee2e9e8da8b5305c7e64 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Oct 2024 02:02:20 -0700 Subject: [PATCH 10/15] Address remaining comments --- src/torchcodec/samplers/_implem.py | 139 +++++++++++++---------------- 1 file changed, 64 insertions(+), 75 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 8018da690..954231408 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -6,6 +6,60 @@ from torchcodec.decoders import VideoDecoder +def _chunk_list(lst, chunk_size): + # return list of sublists of length chunk_size + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] + + +def _to_framebatch(frames: list[Frame]) -> FrameBatch: + # IMPORTANT: see other IMPORTANT note in _decode_all_clips_indices and + # _decode_all_clips_timestamps + data = torch.stack([frame.data for frame in frames]) + pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) + duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) + return FrameBatch( + data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds + ) + + +_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]] + + +def _repeat_last_policy( + values: _LIST_OF_INT_OR_FLOAT, desired_len: int +) -> _LIST_OF_INT_OR_FLOAT: + # values = [1, 2, 3], desired_len = 5 + # output = [1, 2, 3, 3, 3] + values += [values[-1]] * (desired_len - len(values)) + return values + + +def _wrap_policy( + values: _LIST_OF_INT_OR_FLOAT, desired_len: int +) -> _LIST_OF_INT_OR_FLOAT: + # values = [1, 2, 3], desired_len = 5 + # output = [1, 2, 3, 1, 2] + return (values * (desired_len // len(values) + 1))[:desired_len] + + +def _error_policy( + frames_indices: _LIST_OF_INT_OR_FLOAT, desired_len: int +) -> _LIST_OF_INT_OR_FLOAT: + raise ValueError( + "You set the 'error' policy, and the sampler tried to decode a frame " + "that is beyond the number of frames in the video. " + "Try to leave sampling_range_end to its default value?" + ) + + +_POLICY_FUNCTION_TYPE = Callable[[_LIST_OF_INT_OR_FLOAT, int], _LIST_OF_INT_OR_FLOAT] +_POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = { + "repeat_last": _repeat_last_policy, + "wrap": _wrap_policy, + "error": _error_policy, +} + + def _validate_params(*, decoder, num_frames_per_clip, policy): if len(decoder) < 1: raise ValueError( @@ -90,44 +144,6 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): return num_indices_between_frames * (num_frames_per_clip - 1) + 1 -_LIST_OF_INT_OR_FLOAT = Union[list[int], list[float]] - - -def _repeat_last_policy( - values: _LIST_OF_INT_OR_FLOAT, desired_len: int -) -> _LIST_OF_INT_OR_FLOAT: - # values = [1, 2, 3], desired_len = 5 - # output = [1, 2, 3, 3, 3] - values += [values[-1]] * (desired_len - len(values)) - return values - - -def _wrap_policy( - values: _LIST_OF_INT_OR_FLOAT, desired_len: int -) -> _LIST_OF_INT_OR_FLOAT: - # values = [1, 2, 3], desired_len = 5 - # output = [1, 2, 3, 1, 2] - return (values * (desired_len // len(values) + 1))[:desired_len] - - -def _error_policy( - frames_indices: _LIST_OF_INT_OR_FLOAT, desired_len: int -) -> _LIST_OF_INT_OR_FLOAT: - raise ValueError( - "You set the 'error' policy, and the sampler tried to decode a frame " - "that is beyond the number of frames in the video. " - "Try to leave sampling_range_end to its default value?" - ) - - -_POLICY_FUNCTION_TYPE = Callable[[_LIST_OF_INT_OR_FLOAT, int], _LIST_OF_INT_OR_FLOAT] -_POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = { - "repeat_last": _repeat_last_policy, - "wrap": _wrap_policy, - "error": _error_policy, -} - - def _build_all_clips_indices( *, clip_start_indices: torch.Tensor, # 1D int tensor @@ -177,20 +193,7 @@ def _decode_all_clips_indices( # - decode all unique frames in sorted order # - re-assemble the decoded frames back to their original order # - # TODO: Write this in C++ so we can avoid the copies that happen in `to_framebatch` - - def chunk_list(lst, chunk_size): - # return list of sublists of length chunk_size - return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] - - def to_framebatch(frames: list[Frame]) -> FrameBatch: - # IMPORTANT: see other IMPORTANT note below - data = torch.stack([frame.data for frame in frames]) - pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) - duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) - return FrameBatch( - data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds - ) + # TODO: Write this in C++ so we can avoid the copies that happen in `_to_framebatch` all_clips_indices_sorted, argsort = zip( *sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices)) @@ -205,7 +208,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: ): # Avoid decoding the same frame twice. # IMPORTANT: this is only correct because a copy of the frame will - # happen within `to_framebatch` when we call torch.stack. + # happen within `_to_framebatch` when we call torch.stack. # If a copy isn't made, the same underlying memory will be used for # the 2 consecutive frames. When we re-write this, we should make # sure to explicitly copy the data. @@ -215,11 +218,11 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: previous_decoded_frame = decoded_frame all_decoded_frames[j] = decoded_frame - all_clips: list[list[Frame]] = chunk_list( + all_clips: list[list[Frame]] = _chunk_list( all_decoded_frames, chunk_size=num_frames_per_clip ) - return [to_framebatch(clip) for clip in all_clips] + return [_to_framebatch(clip) for clip in all_clips] def _generic_index_based_sampler( @@ -444,11 +447,10 @@ def _build_all_clips_timestamps( all_clips_timestamps: list[float] = [] for start_seconds in clip_start_seconds: clip_timestamps = [ - start_seconds + i * seconds_between_frames + timestamp for i in range(num_frames_per_clip) - ] - clip_timestamps = [ - timestamp for timestamp in clip_timestamps if timestamp < end_stream_seconds + if (timestamp := start_seconds + i * seconds_between_frames) + < end_stream_seconds ] if len(clip_timestamps) < num_frames_per_clip: @@ -464,19 +466,6 @@ def _decode_all_clips_timestamps( # This is 99% the same as _decode_all_clips_indices. The only change is the # call to .get_frame_displayed_at(pts) instead of .get_frame_at(idx) - def chunk_list(lst, chunk_size): - # return list of sublists of length chunk_size - return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] - - def to_framebatch(frames: list[Frame]) -> FrameBatch: - # IMPORTANT: see other IMPORTANT note below - data = torch.stack([frame.data for frame in frames]) - pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) - duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) - return FrameBatch( - data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds - ) - all_clips_timestamps_sorted, argsort = zip( *sorted( (frame_index, i) for (i, frame_index) in enumerate(all_clips_timestamps) @@ -492,7 +481,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: ): # Avoid decoding the same frame twice. # IMPORTANT: this is only correct because a copy of the frame will - # happen within `to_framebatch` when we call torch.stack. + # happen within `_to_framebatch` when we call torch.stack. # If a copy isn't made, the same underlying memory will be used for # the 2 consecutive frames. When we re-write this, we should make # sure to explicitly copy the data. @@ -502,11 +491,11 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: previous_decoded_frame = decoded_frame all_decoded_frames[j] = decoded_frame - all_clips: list[list[Frame]] = chunk_list( + all_clips: list[list[Frame]] = _chunk_list( all_decoded_frames, chunk_size=num_frames_per_clip ) - return [to_framebatch(clip) for clip in all_clips] + return [_to_framebatch(clip) for clip in all_clips] def clips_at_regular_timestamps( From db48a87f9fd9b864367d4133ca7c91d9c51a8556 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Oct 2024 02:48:51 -0700 Subject: [PATCH 11/15] Add random time-based sampler --- src/torchcodec/samplers/__init__.py | 1 + src/torchcodec/samplers/_implem.py | 96 ++++++++++++++++++++---- test/samplers/test_samplers.py | 110 +++++++++++++++++++--------- 3 files changed, 160 insertions(+), 47 deletions(-) diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py index 66cd9c91d..8616ae35b 100644 --- a/src/torchcodec/samplers/__init__.py +++ b/src/torchcodec/samplers/__init__.py @@ -1,5 +1,6 @@ from ._implem import ( clips_at_random_indices, + clips_at_random_timestamps, clips_at_regular_indices, clips_at_regular_timestamps, ) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 954231408..6860f82a0 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -78,7 +78,7 @@ def _validate_params(*, decoder, num_frames_per_clip, policy): def _validate_params_index_based(*, num_clips, num_indices_between_frames): if num_clips <= 0: - raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + raise ValueError(f"num_clips ({num_clips}) must be > 0") if num_indices_between_frames <= 0: raise ValueError( @@ -339,14 +339,25 @@ def clips_at_regular_indices( def _validate_params_time_based( *, decoder, + num_clips, seconds_between_clip_starts, seconds_between_frames, ): - if seconds_between_clip_starts <= 0: + + if (num_clips is None and seconds_between_clip_starts is None) or ( + num_clips is not None and seconds_between_clip_starts is not None + ): + # This is internal only and should never happen + raise ValueError("Bad, bad programmer!") + + if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0: raise ValueError( f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0" ) + if num_clips is not None and num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be > 0") + if decoder.metadata.average_fps is None: raise ValueError( "Could not infer average fps from video metadata. " @@ -498,15 +509,17 @@ def _decode_all_clips_timestamps( return [_to_framebatch(clip) for clip in all_clips] -def clips_at_regular_timestamps( +def _generic_time_based_sampler( + kind: Literal["random", "regular"], decoder, *, - seconds_between_clip_starts: float, - num_frames_per_clip: int = 1, - seconds_between_frames: Optional[float] = None, + num_clips: Optional[int], # mutually exclusive with seconds_between_clip_starts + seconds_between_clip_starts: Optional[float], + num_frames_per_clip: int, + seconds_between_frames: Optional[float], # None means "begining", which may not always be 0 - sampling_range_start: Optional[float] = None, - sampling_range_end: Optional[float] = None, # interval is [start, end). + sampling_range_start: Optional[float], + sampling_range_end: Optional[float], # interval is [start, end). policy: str = "repeat_last", ) -> List[FrameBatch]: # Note: *everywhere*, sampling_range_end denotes the upper bound of where a @@ -521,6 +534,7 @@ def clips_at_regular_timestamps( seconds_between_frames = _validate_params_time_based( decoder=decoder, + num_clips=num_clips, seconds_between_clip_starts=seconds_between_clip_starts, seconds_between_frames=seconds_between_frames, ) @@ -534,11 +548,19 @@ def clips_at_regular_timestamps( end_stream_seconds=decoder.metadata.end_stream_seconds, ) - clip_start_seconds = torch.arange( - sampling_range_start, - sampling_range_end, # excluded - seconds_between_clip_starts, - ) + if kind == "random": + sampling_range_width = sampling_range_end - sampling_range_start + # torch.rand() returns in [0, 1) + # which ensures all clip starts are < sampling_range_end + clip_start_seconds = ( + torch.rand(num_clips) * sampling_range_width + sampling_range_start + ) + else: + clip_start_seconds = torch.arange( + sampling_range_start, + sampling_range_end, # excluded + seconds_between_clip_starts, + ) all_clips_timestamps = _build_all_clips_timestamps( clip_start_seconds=clip_start_seconds, @@ -553,3 +575,51 @@ def clips_at_regular_timestamps( all_clips_timestamps=all_clips_timestamps, num_frames_per_clip=num_frames_per_clip, ) + + +def clips_at_random_timestamps( + decoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + seconds_between_frames: Optional[float] = None, + # None means "begining", which may not always be 0 + sampling_range_start: Optional[float] = None, + sampling_range_end: Optional[float] = None, # interval is [start, end). + policy: str = "repeat_last", +) -> List[FrameBatch]: + return _generic_time_based_sampler( + kind="random", + decoder=decoder, + num_clips=num_clips, + seconds_between_clip_starts=None, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + policy=policy, + ) + + +def clips_at_regular_timestamps( + decoder, + *, + seconds_between_clip_starts: float, + num_frames_per_clip: int = 1, + seconds_between_frames: Optional[float] = None, + # None means "begining", which may not always be 0 + sampling_range_start: Optional[float] = None, + sampling_range_end: Optional[float] = None, # interval is [start, end). + policy: str = "repeat_last", +) -> List[FrameBatch]: + return _generic_time_based_sampler( + kind="regular", + decoder=decoder, + num_clips=None, + seconds_between_clip_starts=seconds_between_clip_starts, + num_frames_per_clip=num_frames_per_clip, + seconds_between_frames=seconds_between_frames, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + policy=policy, + ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 1aead0fee..1bafdff78 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -11,6 +11,7 @@ from torchcodec.decoders import VideoDecoder from torchcodec.samplers import ( clips_at_random_indices, + clips_at_random_timestamps, clips_at_regular_indices, clips_at_regular_timestamps, ) @@ -98,34 +99,45 @@ def test_index_based_sampler(sampler, num_indices_between_frames): ) +@pytest.mark.parametrize( + "sampler", + ( + partial(clips_at_random_timestamps, num_clips=5), + partial(clips_at_regular_timestamps, seconds_between_clip_starts=2), + ), +) @pytest.mark.parametrize("seconds_between_frames", [None, 3]) -def test_time_based_sampler(seconds_between_frames): +def test_time_based_sampler(sampler, seconds_between_frames): decoder = VideoDecoder(NASA_VIDEO.path) num_frames_per_clip = 3 - seconds_between_clip_starts = 2 - clips = clips_at_regular_timestamps( + clips = sampler( decoder, - seconds_between_clip_starts=seconds_between_clip_starts, num_frames_per_clip=num_frames_per_clip, seconds_between_frames=seconds_between_frames, ) - expeted_num_clips = len(clips) # no-op check, it's just hard to assert + expected_num_clips = ( + len(clips) # No-op check, we can't assert with regular sampler + if sampler.func is clips_at_regular_timestamps + else sampler.keywords["num_clips"] + ) _assert_output_type_and_shapes( video=NASA_VIDEO, clips=clips, - expected_num_clips=expeted_num_clips, + expected_num_clips=expected_num_clips, num_frames_per_clip=num_frames_per_clip, ) - expected_seconds_between_clip_starts = torch.tensor( - [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float - ) - _assert_regular_sampler( - clips=clips, - expected_seconds_between_clip_starts=expected_seconds_between_clip_starts, - ) + if sampler.func is clips_at_regular_timestamps: + seconds_between_clip_starts = sampler.keywords["seconds_between_clip_starts"] + expected_seconds_between_clip_starts = torch.tensor( + [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float + ) + _assert_regular_sampler( + clips=clips, + expected_seconds_between_clip_starts=expected_seconds_between_clip_starts, + ) expected_seconds_between_frames = ( seconds_between_frames or 1 / decoder.metadata.average_fps @@ -157,6 +169,8 @@ def test_time_based_sampler(seconds_between_frames): 12.0, False, ), + (partial(clips_at_random_indices, num_clips=10), 10, 11, True), + (partial(clips_at_random_indices, num_clips=10), 10, 12, False), ), ) def test_sampling_range( @@ -169,9 +183,10 @@ def test_sampling_range( # # For time-based: # The test is similar but with different semantics. We set the sampling - # range to be 1 second or 2 seconds. Since we set seconds_between_clip_start - # to 1 we expect exactly one clip with the sampling range is of size 1, and - # 2 different clips when teh sampling range is 2 seconds. + # range to be 1 second or 2 seconds. Since we set + # seconds_between_clip_starts to 1 we expect exactly one clip with the + # sampling range is of size 1, and 2 different clips when teh sampling range + # is 2 seconds. # When size=2 there's still a (small) non-zero probability of sampling the # same indices for clip starts, so we hard-code a seed that works. @@ -230,7 +245,14 @@ def test_sampling_range_negative(sampler): assert_tensor_equal(clip.data, clips_1[0].data) -def test_sampling_range_default_behavior_random_sampler(): +@pytest.mark.parametrize( + "sampler", + ( + clips_at_random_indices, + clips_at_random_timestamps, + ), +) +def test_sampling_range_default_behavior_random_sampler(sampler): # This is a functional test for the default behavior of the # sampling_range_end parameter, for the random sampler. # By default it's None, which means the @@ -252,10 +274,10 @@ def test_sampling_range_default_behavior_random_sampler(): num_clips = 20 num_frames_per_clip = 15 - sampling_range_start = -20 + sampling_range_start = -20 if sampler is clips_at_random_indices else 11.0 # with default sampling_range_end value - clips_default = clips_at_random_indices( + clips_default = sampler( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, @@ -267,7 +289,7 @@ def test_sampling_range_default_behavior_random_sampler(): last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default]) # with manual sampling_range_end value set to last frame - clips_manual = clips_at_random_indices( + clips_manual = sampler( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, @@ -321,7 +343,7 @@ def test_sampling_range_default_regular_sampler(sampler): partial( clips_at_regular_indices, sampling_range_start=-1, sampling_range_end=1000 ), - # Note: the hard-coded value of sampling_range_start=12 is because we know + # Note: the hard-coded value of sampling_range_start=13 is because we know # the NASA_VIDEO is 13.01s seconds long partial( clips_at_regular_timestamps, @@ -329,6 +351,11 @@ def test_sampling_range_default_regular_sampler(sampler): sampling_range_start=13, sampling_range_end=1000, ), + partial( + clips_at_random_timestamps, + sampling_range_start=13, + sampling_range_end=1000, + ), ), ) def test_sampling_range_error_policy(sampler): @@ -341,14 +368,17 @@ def test_sampling_range_error_policy(sampler): ) -def test_random_sampler_randomness(): +@pytest.mark.parametrize( + "sampler", (clips_at_random_indices, clips_at_random_timestamps) +) +def test_random_sampler_randomness(sampler): decoder = VideoDecoder(NASA_VIDEO.path) num_clips = 5 builtin_random_state_start = random.getstate() torch.manual_seed(0) - clips_1 = clips_at_random_indices(decoder, num_clips=num_clips) + clips_1 = sampler(decoder, num_clips=num_clips) # Assert the clip starts aren't sorted, to make sure we haven't messed up # the implementation. (This may fail if we're unlucky, but we hard-coded a @@ -358,7 +388,7 @@ def test_random_sampler_randomness(): # Call the same sampler again with the same seed, expect same results torch.manual_seed(0) - clips_2 = clips_at_random_indices(decoder, num_clips=num_clips) + clips_2 = sampler(decoder, num_clips=num_clips) for clip_1, clip_2 in zip(clips_1, clips_2): assert_tensor_equal(clip_1.data, clip_2.data) assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds) @@ -366,7 +396,7 @@ def test_random_sampler_randomness(): # Call with a different seed, expect different results torch.manual_seed(1) - clips_3 = clips_at_random_indices(decoder, num_clips=num_clips) + clips_3 = sampler(decoder, num_clips=num_clips) with pytest.raises(AssertionError, match="Tensor-likes are not"): assert_tensor_equal(clips_1[0].data, clips_3[0].data) @@ -412,6 +442,7 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz clips_at_random_indices, clips_at_regular_indices, partial(clips_at_regular_timestamps, seconds_between_clip_starts=1), + clips_at_random_timestamps, ), ) def test_sampler_errors(sampler): @@ -441,9 +472,7 @@ def test_sampler_errors(sampler): @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) def test_index_based_samplers_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) - with pytest.raises( - ValueError, match=re.escape("num_clips (0) must be strictly positive") - ): + with pytest.raises(ValueError, match=re.escape("num_clips (0) must be > 0")): sampler(decoder, num_clips=0) with pytest.raises( @@ -468,9 +497,15 @@ def test_index_based_samplers_errors(sampler): ) -def test_time_based_sampler_errors(): +@pytest.mark.parametrize( + "sampler", + ( + clips_at_random_timestamps, + partial(clips_at_regular_timestamps, seconds_between_clip_starts=1), + ), +) +def test_time_based_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) - sampler = partial(clips_at_regular_timestamps, seconds_between_clip_starts=1) with pytest.raises( ValueError, match=re.escape("sampling_range_start (-1) must be at least 0.0") @@ -482,10 +517,17 @@ def test_time_based_sampler_errors(): ): sampler(decoder, sampling_range_end=-1) - with pytest.raises( - ValueError, match=re.escape("seconds_between_clip_starts (-1) must be > 0") - ): - sampler(decoder, seconds_between_clip_starts=-1) + if sampler is clips_at_random_timestamps: + with pytest.raises( + ValueError, + match=re.escape("num_clips (0) must be > 0"), + ): + sampler(decoder, num_clips=0) + else: + with pytest.raises( + ValueError, match=re.escape("seconds_between_clip_starts (-1) must be > 0") + ): + sampler(decoder, seconds_between_clip_starts=-1) @contextlib.contextmanager def restore_metadata(): From 29e4d94539a45a6c61d6b804dd7a4b4c074a97ad Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Oct 2024 02:55:57 -0700 Subject: [PATCH 12/15] fix mypy --- src/torchcodec/samplers/_implem.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 6860f82a0..31416b25d 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -549,6 +549,7 @@ def _generic_time_based_sampler( ) if kind == "random": + assert num_clips is not None # appease type-checker sampling_range_width = sampling_range_end - sampling_range_start # torch.rand() returns in [0, 1) # which ensures all clip starts are < sampling_range_end @@ -556,6 +557,7 @@ def _generic_time_based_sampler( torch.rand(num_clips) * sampling_range_width + sampling_range_start ) else: + assert seconds_between_clip_starts is not None # appease type-checker clip_start_seconds = torch.arange( sampling_range_start, sampling_range_end, # excluded From 4ef46f6c8246e8d8916f9f3db9cbd8c5d737e11f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Oct 2024 03:08:44 -0700 Subject: [PATCH 13/15] Fix --- test/samplers/test_samplers.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 8c9458f2a..04c57bff6 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -274,10 +274,10 @@ def test_sampling_range_default_behavior_random_sampler(sampler): num_clips = 20 num_frames_per_clip = 15 - sampling_range_start = -20 if sampler is clips_at_random_indices else 11.0 + sampling_range_start = -20 if sampler is clips_at_random_indices else 11 # with default sampling_range_end value - clips_default = clips_at_random_indices( + clips_default = sampler( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, @@ -288,13 +288,13 @@ def test_sampling_range_default_behavior_random_sampler(sampler): last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default]) - # with manual sampling_range_end value set to last frame - clips_manual = clips_at_random_indices( + # with manual sampling_range_end value set to last frame / end of video + clips_manual = sampler( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, sampling_range_start=sampling_range_start, - sampling_range_end=len(decoder), + sampling_range_end=1000, ) last_clip_start_manual = max([clip.pts_seconds[0] for clip in clips_manual]) From ff97ea71095a0d982943c70ccb00b9a98130d812 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Thu, 10 Oct 2024 04:34:14 -0700 Subject: [PATCH 14/15] Update benchmark --- benchmarks/samplers/benchmark_samplers.py | 55 +++++++++++++++++++---- src/torchcodec/samplers/_implem.py | 7 +++ 2 files changed, 53 insertions(+), 9 deletions(-) diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py index 050466c16..c8c5a50d0 100644 --- a/benchmarks/samplers/benchmark_samplers.py +++ b/benchmarks/samplers/benchmark_samplers.py @@ -3,7 +3,12 @@ import torch from torchcodec.decoders import VideoDecoder -from torchcodec.samplers import clips_at_random_indices +from torchcodec.samplers import ( + clips_at_random_indices, + clips_at_random_timestamps, + clips_at_regular_indices, + clips_at_regular_timestamps, +) def bench(f, *args, num_exp=100, warmup=0, **kwargs): @@ -34,19 +39,51 @@ def report_stats(times, unit="ms"): return med -def sample(num_clips): +def sample(sampler, **kwargs): decoder = VideoDecoder(VIDEO_PATH) - clips_at_random_indices( + sampler( decoder, - num_clips=num_clips, num_frames_per_clip=10, - num_indices_between_frames=2, + **kwargs, ) VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4" +NUM_EXP = 30 + +for num_clips in (1, 50): + print("-" * 10) + print(f"{num_clips = }") + + print("clips_at_random_indices ", end="") + times = bench( + sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2 + ) + report_stats(times, unit="ms") + + print("clips_at_regular_indices ", end="") + times = bench( + sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2 + ) + report_stats(times, unit="ms") -times = bench(sample, num_clips=1, num_exp=30, warmup=2) -report_stats(times, unit="ms") -times = bench(sample, num_clips=50, num_exp=30, warmup=2) -report_stats(times, unit="ms") + print("clips_at_random_timestamps ", end="") + times = bench( + sample, + clips_at_random_timestamps, + num_clips=num_clips, + num_exp=NUM_EXP, + warmup=2, + ) + report_stats(times, unit="ms") + + print("clips_at_regular_timestamps ", end="") + seconds_between_clip_starts = 13 / num_clips # approximate. video is 13s long + times = bench( + sample, + clips_at_regular_timestamps, + seconds_between_clip_starts=seconds_between_clip_starts, + num_exp=NUM_EXP, + warmup=2, + ) + report_stats(times, unit="ms") diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 31416b25d..d10f770ed 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -491,6 +491,13 @@ def _decode_all_clips_timestamps( and frame_pts_seconds == all_clips_timestamps_sorted[i - 1] ): # Avoid decoding the same frame twice. + # Unfortunatly this is unlikely to lead to speed-up as-is: it's + # pretty unlikely that 2 pts will be the same since pts are float + # contiguous values. Theoretically the dedup can still happen, but + # it would be much more efficient to implement it at the frame index + # level. We should do that once we implement that in C++. + # See also https://github.com/pytorch/torchcodec/issues/256. + # # IMPORTANT: this is only correct because a copy of the frame will # happen within `_to_framebatch` when we call torch.stack. # If a copy isn't made, the same underlying memory will be used for From 61832347f896f846363e8340c34398309ef08a56 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 11 Oct 2024 15:43:48 +0100 Subject: [PATCH 15/15] Address comments --- src/torchcodec/samplers/_implem.py | 3 +-- test/samplers/test_samplers.py | 5 +++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index d10f770ed..c0dbbf208 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -347,8 +347,7 @@ def _validate_params_time_based( if (num_clips is None and seconds_between_clip_starts is None) or ( num_clips is not None and seconds_between_clip_starts is not None ): - # This is internal only and should never happen - raise ValueError("Bad, bad programmer!") + raise ValueError("This is internal only and should never happen.") if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0: raise ValueError( diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 04c57bff6..8bf252240 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -185,7 +185,7 @@ def test_sampling_range( # The test is similar but with different semantics. We set the sampling # range to be 1 second or 2 seconds. Since we set # seconds_between_clip_starts to 1 we expect exactly one clip with the - # sampling range is of size 1, and 2 different clips when teh sampling range + # sampling range is of size 1, and 2 different clips when the sampling range # is 2 seconds. # When size=2 there's still a (small) non-zero probability of sampling the @@ -344,7 +344,8 @@ def test_sampling_range_default_regular_sampler(sampler): clips_at_regular_indices, sampling_range_start=-1, sampling_range_end=1000 ), # Note: the hard-coded value of sampling_range_start=13 is because we know - # the NASA_VIDEO is 13.01s seconds long + # the NASA_VIDEO is ~13.01s seconds long. We just need to clip to start + # on, or close to the last frame. partial( clips_at_regular_timestamps, seconds_between_clip_starts=0.1,