diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py
index 050466c16..c8c5a50d0 100644
--- a/benchmarks/samplers/benchmark_samplers.py
+++ b/benchmarks/samplers/benchmark_samplers.py
@@ -3,7 +3,12 @@
 import torch
 
 from torchcodec.decoders import VideoDecoder
-from torchcodec.samplers import clips_at_random_indices
+from torchcodec.samplers import (
+    clips_at_random_indices,
+    clips_at_random_timestamps,
+    clips_at_regular_indices,
+    clips_at_regular_timestamps,
+)
 
 
 def bench(f, *args, num_exp=100, warmup=0, **kwargs):
@@ -34,19 +39,51 @@ def report_stats(times, unit="ms"):
     return med
 
 
-def sample(num_clips):
+def sample(sampler, **kwargs):
     decoder = VideoDecoder(VIDEO_PATH)
-    clips_at_random_indices(
+    sampler(
         decoder,
-        num_clips=num_clips,
         num_frames_per_clip=10,
-        num_indices_between_frames=2,
+        **kwargs,
     )
 
 
 VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
+NUM_EXP = 30
+
+for num_clips in (1, 50):
+    print("-" * 10)
+    print(f"{num_clips = }")
+
+    print("clips_at_random_indices ", end="")
+    times = bench(
+        sample, clips_at_random_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
+    )
+    report_stats(times, unit="ms")
+
+    print("clips_at_regular_indices ", end="")
+    times = bench(
+        sample, clips_at_regular_indices, num_clips=num_clips, num_exp=NUM_EXP, warmup=2
+    )
+    report_stats(times, unit="ms")
 
-times = bench(sample, num_clips=1, num_exp=30, warmup=2)
-report_stats(times, unit="ms")
-times = bench(sample, num_clips=50, num_exp=30, warmup=2)
-report_stats(times, unit="ms")
+    print("clips_at_random_timestamps ", end="")
+    times = bench(
+        sample,
+        clips_at_random_timestamps,
+        num_clips=num_clips,
+        num_exp=NUM_EXP,
+        warmup=2,
+    )
+    report_stats(times, unit="ms")
+
+    print("clips_at_regular_timestamps ", end="")
+    seconds_between_clip_starts = 13 / num_clips  # approximate. video is 13s long
+    times = bench(
+        sample,
+        clips_at_regular_timestamps,
+        seconds_between_clip_starts=seconds_between_clip_starts,
+        num_exp=NUM_EXP,
+        warmup=2,
+    )
+    report_stats(times, unit="ms")
diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py
index 66cd9c91d..8616ae35b 100644
--- a/src/torchcodec/samplers/__init__.py
+++ b/src/torchcodec/samplers/__init__.py
@@ -1,5 +1,6 @@
 from ._implem import (
     clips_at_random_indices,
+    clips_at_random_timestamps,
     clips_at_regular_indices,
     clips_at_regular_timestamps,
 )
diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py
index 954231408..c0dbbf208 100644
--- a/src/torchcodec/samplers/_implem.py
+++ b/src/torchcodec/samplers/_implem.py
@@ -78,7 +78,7 @@ def _validate_params(*, decoder, num_frames_per_clip, policy):
 
 def _validate_params_index_based(*, num_clips, num_indices_between_frames):
     if num_clips <= 0:
-        raise ValueError(f"num_clips ({num_clips}) must be strictly positive")
+        raise ValueError(f"num_clips ({num_clips}) must be > 0")
 
     if num_indices_between_frames <= 0:
         raise ValueError(
@@ -339,14 +339,24 @@ def clips_at_regular_indices(
 
 def _validate_params_time_based(
     *,
     decoder,
+    num_clips,
     seconds_between_clip_starts,
     seconds_between_frames,
 ):
-    if seconds_between_clip_starts <= 0:
+
+    if (num_clips is None and seconds_between_clip_starts is None) or (
+        num_clips is not None and seconds_between_clip_starts is not None
+    ):
+        raise ValueError("This is internal only and should never happen.")
+
+    if seconds_between_clip_starts is not None and seconds_between_clip_starts <= 0:
         raise ValueError(
             f"seconds_between_clip_starts ({seconds_between_clip_starts}) must be > 0"
         )
+    if num_clips is not None and num_clips <= 0:
+        raise ValueError(f"num_clips ({num_clips}) must be > 0")
+
     if decoder.metadata.average_fps is None:
         raise ValueError(
             "Could not infer average fps from video metadata. "
@@ -480,6 +490,13 @@ def _decode_all_clips_timestamps(
             and frame_pts_seconds == all_clips_timestamps_sorted[i - 1]
         ):
             # Avoid decoding the same frame twice.
+            # Unfortunately this is unlikely to lead to a speed-up as-is: it's
+            # pretty unlikely that 2 pts will be the same since pts are
+            # continuous float values. Theoretically the dedup can still happen,
+            # but it would be much more efficient to implement it at the frame
+            # index level. We should do that once we implement that in C++.
+            # See also https://github.com/pytorch/torchcodec/issues/256.
+            #
             # IMPORTANT: this is only correct because a copy of the frame will
             # happen within `_to_framebatch` when we call torch.stack.
             # If a copy isn't made, the same underlying memory will be used for
@@ -498,15 +515,17 @@ def _decode_all_clips_timestamps(
     return [_to_framebatch(clip) for clip in all_clips]
 
 
-def clips_at_regular_timestamps(
+def _generic_time_based_sampler(
+    kind: Literal["random", "regular"],
     decoder,
     *,
-    seconds_between_clip_starts: float,
-    num_frames_per_clip: int = 1,
-    seconds_between_frames: Optional[float] = None,
+    num_clips: Optional[int],  # mutually exclusive with seconds_between_clip_starts
+    seconds_between_clip_starts: Optional[float],
+    num_frames_per_clip: int,
+    seconds_between_frames: Optional[float],
     # None means "begining", which may not always be 0
-    sampling_range_start: Optional[float] = None,
-    sampling_range_end: Optional[float] = None,  # interval is [start, end).
+    sampling_range_start: Optional[float],
+    sampling_range_end: Optional[float],  # interval is [start, end).
     policy: str = "repeat_last",
 ) -> List[FrameBatch]:
     # Note: *everywhere*, sampling_range_end denotes the upper bound of where a
@@ -521,6 +540,7 @@ def clips_at_regular_timestamps(
 
     seconds_between_frames = _validate_params_time_based(
         decoder=decoder,
+        num_clips=num_clips,
         seconds_between_clip_starts=seconds_between_clip_starts,
         seconds_between_frames=seconds_between_frames,
     )
@@ -534,11 +554,21 @@ def clips_at_regular_timestamps(
         end_stream_seconds=decoder.metadata.end_stream_seconds,
     )
 
-    clip_start_seconds = torch.arange(
-        sampling_range_start,
-        sampling_range_end,  # excluded
-        seconds_between_clip_starts,
-    )
+    if kind == "random":
+        assert num_clips is not None  # appease type-checker
+        sampling_range_width = sampling_range_end - sampling_range_start
+        # torch.rand() returns in [0, 1)
+        # which ensures all clip starts are < sampling_range_end
+        clip_start_seconds = (
+            torch.rand(num_clips) * sampling_range_width + sampling_range_start
+        )
+    else:
+        assert seconds_between_clip_starts is not None  # appease type-checker
+        clip_start_seconds = torch.arange(
+            sampling_range_start,
+            sampling_range_end,  # excluded
+            seconds_between_clip_starts,
+        )
 
     all_clips_timestamps = _build_all_clips_timestamps(
         clip_start_seconds=clip_start_seconds,
@@ -553,3 +583,51 @@ def clips_at_regular_timestamps(
         all_clips_timestamps=all_clips_timestamps,
         num_frames_per_clip=num_frames_per_clip,
     )
+
+
+def clips_at_random_timestamps(
+    decoder,
+    *,
+    num_clips: int = 1,
+    num_frames_per_clip: int = 1,
+    seconds_between_frames: Optional[float] = None,
+    # None means "beginning", which may not always be 0
+    sampling_range_start: Optional[float] = None,
+    sampling_range_end: Optional[float] = None,  # interval is [start, end).
+    policy: str = "repeat_last",
+) -> List[FrameBatch]:
+    return _generic_time_based_sampler(
+        kind="random",
+        decoder=decoder,
+        num_clips=num_clips,
+        seconds_between_clip_starts=None,
+        num_frames_per_clip=num_frames_per_clip,
+        seconds_between_frames=seconds_between_frames,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        policy=policy,
+    )
+
+
+def clips_at_regular_timestamps(
+    decoder,
+    *,
+    seconds_between_clip_starts: float,
+    num_frames_per_clip: int = 1,
+    seconds_between_frames: Optional[float] = None,
+    # None means "beginning", which may not always be 0
+    sampling_range_start: Optional[float] = None,
+    sampling_range_end: Optional[float] = None,  # interval is [start, end).
+    policy: str = "repeat_last",
+) -> List[FrameBatch]:
+    return _generic_time_based_sampler(
+        kind="regular",
+        decoder=decoder,
+        num_clips=None,
+        seconds_between_clip_starts=seconds_between_clip_starts,
+        num_frames_per_clip=num_frames_per_clip,
+        seconds_between_frames=seconds_between_frames,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        policy=policy,
+    )
diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py
index 1aead0fee..8bf252240 100644
--- a/test/samplers/test_samplers.py
+++ b/test/samplers/test_samplers.py
@@ -11,6 +11,7 @@
 from torchcodec.decoders import VideoDecoder
 from torchcodec.samplers import (
     clips_at_random_indices,
+    clips_at_random_timestamps,
     clips_at_regular_indices,
     clips_at_regular_timestamps,
 )
@@ -98,34 +99,45 @@ def test_index_based_sampler(sampler, num_indices_between_frames):
     )
 
 
+@pytest.mark.parametrize(
+    "sampler",
+    (
+        partial(clips_at_random_timestamps, num_clips=5),
+        partial(clips_at_regular_timestamps, seconds_between_clip_starts=2),
+    ),
+)
 @pytest.mark.parametrize("seconds_between_frames", [None, 3])
-def test_time_based_sampler(seconds_between_frames):
+def test_time_based_sampler(sampler, seconds_between_frames):
     decoder = VideoDecoder(NASA_VIDEO.path)
 
     num_frames_per_clip = 3
-    seconds_between_clip_starts = 2
-    clips = clips_at_regular_timestamps(
+    clips = sampler(
         decoder,
-        seconds_between_clip_starts=seconds_between_clip_starts,
         num_frames_per_clip=num_frames_per_clip,
         seconds_between_frames=seconds_between_frames,
     )
 
-    expeted_num_clips = len(clips)  # no-op check, it's just hard to assert
+    expected_num_clips = (
+        len(clips)  # No-op check, we can't assert with regular sampler
+        if sampler.func is clips_at_regular_timestamps
+        else sampler.keywords["num_clips"]
+    )
     _assert_output_type_and_shapes(
         video=NASA_VIDEO,
         clips=clips,
-        expected_num_clips=expeted_num_clips,
+        expected_num_clips=expected_num_clips,
         num_frames_per_clip=num_frames_per_clip,
     )
 
-    expected_seconds_between_clip_starts = torch.tensor(
-        [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float
-    )
-    _assert_regular_sampler(
-        clips=clips,
-        expected_seconds_between_clip_starts=expected_seconds_between_clip_starts,
-    )
+    if sampler.func is clips_at_regular_timestamps:
+        seconds_between_clip_starts = sampler.keywords["seconds_between_clip_starts"]
+        expected_seconds_between_clip_starts = torch.tensor(
+            [seconds_between_clip_starts] * (len(clips) - 1), dtype=torch.float
+        )
+        _assert_regular_sampler(
+            clips=clips,
+            expected_seconds_between_clip_starts=expected_seconds_between_clip_starts,
+        )
 
     expected_seconds_between_frames = (
         seconds_between_frames or 1 / decoder.metadata.average_fps
@@ -157,6 +169,8 @@ def test_time_based_sampler(seconds_between_frames):
             12.0,
             False,
         ),
+        (partial(clips_at_random_timestamps, num_clips=10), 10, 11, True),
+        (partial(clips_at_random_timestamps, num_clips=10), 10, 12, False),
     ),
 )
 def test_sampling_range(
@@ -169,9 +183,10 @@ def test_sampling_range(
     #
     # For time-based:
     # The test is similar but with different semantics. We set the sampling
-    # range to be 1 second or 2 seconds. Since we set seconds_between_clip_start
-    # to 1 we expect exactly one clip with the sampling range is of size 1, and
-    # 2 different clips when teh sampling range is 2 seconds.
+    # range to be 1 second or 2 seconds. Since we set
+    # seconds_between_clip_starts to 1 we expect exactly one clip when the
+    # sampling range is of size 1, and 2 different clips when the sampling
+    # range is 2 seconds.
     # When size=2 there's still a (small) non-zero probability of sampling the
     # same indices for clip starts, so we hard-code a seed that works.
 
@@ -230,7 +245,14 @@ def test_sampling_range_negative(sampler):
         assert_tensor_equal(clip.data, clips_1[0].data)
 
 
-def test_sampling_range_default_behavior_random_sampler():
+@pytest.mark.parametrize(
+    "sampler",
+    (
+        clips_at_random_indices,
+        clips_at_random_timestamps,
+    ),
+)
+def test_sampling_range_default_behavior_random_sampler(sampler):
     # This is a functional test for the default behavior of the
     # sampling_range_end parameter, for the random sampler.
     # By default it's None, which means the
@@ -252,10 +274,10 @@ def test_sampling_range_default_behavior_random_sampler():
 
     num_clips = 20
     num_frames_per_clip = 15
-    sampling_range_start = -20
+    sampling_range_start = -20 if sampler is clips_at_random_indices else 11
 
     # with default sampling_range_end value
-    clips_default = clips_at_random_indices(
+    clips_default = sampler(
         decoder,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
@@ -266,13 +288,13 @@ def test_sampling_range_default_behavior_random_sampler():
 
     last_clip_start_default = max([clip.pts_seconds[0] for clip in clips_default])
 
-    # with manual sampling_range_end value set to last frame
-    clips_manual = clips_at_random_indices(
+    # with manual sampling_range_end value set to last frame / end of video
+    clips_manual = sampler(
         decoder,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
         sampling_range_start=sampling_range_start,
-        sampling_range_end=len(decoder),
+        sampling_range_end=1000,
     )
 
     last_clip_start_manual = max([clip.pts_seconds[0] for clip in clips_manual])
@@ -321,14 +343,20 @@ def test_sampling_range_default_regular_sampler(sampler):
         partial(
             clips_at_regular_indices, sampling_range_start=-1, sampling_range_end=1000
        ),
-        # Note: the hard-coded value of sampling_range_start=12 is because we know
-        # the NASA_VIDEO is 13.01s seconds long
+        # Note: the hard-coded value of sampling_range_start=13 is because we know
+        # the NASA_VIDEO is ~13.01 seconds long. We just need the clip to start
+        # on, or close to, the last frame.
         partial(
             clips_at_regular_timestamps,
             seconds_between_clip_starts=0.1,
             sampling_range_start=13,
             sampling_range_end=1000,
         ),
+        partial(
+            clips_at_random_timestamps,
+            sampling_range_start=13,
+            sampling_range_end=1000,
+        ),
     ),
 )
 def test_sampling_range_error_policy(sampler):
@@ -341,14 +369,17 @@ def test_sampling_range_error_policy(sampler):
         )
 
 
-def test_random_sampler_randomness():
+@pytest.mark.parametrize(
+    "sampler", (clips_at_random_indices, clips_at_random_timestamps)
+)
+def test_random_sampler_randomness(sampler):
     decoder = VideoDecoder(NASA_VIDEO.path)
     num_clips = 5
 
     builtin_random_state_start = random.getstate()
 
     torch.manual_seed(0)
-    clips_1 = clips_at_random_indices(decoder, num_clips=num_clips)
+    clips_1 = sampler(decoder, num_clips=num_clips)
     # Assert the clip starts aren't sorted, to make sure we haven't messed up
     # the implementation. (This may fail if we're unlucky, but we hard-coded a
     # seed, so it will always pass.)
@@ -358,7 +389,7 @@ def test_random_sampler_randomness():
 
     # Call the same sampler again with the same seed, expect same results
     torch.manual_seed(0)
-    clips_2 = clips_at_random_indices(decoder, num_clips=num_clips)
+    clips_2 = sampler(decoder, num_clips=num_clips)
     for clip_1, clip_2 in zip(clips_1, clips_2):
         assert_tensor_equal(clip_1.data, clip_2.data)
         assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds)
@@ -366,7 +397,7 @@ def test_random_sampler_randomness():
 
     # Call with a different seed, expect different results
     torch.manual_seed(1)
-    clips_3 = clips_at_random_indices(decoder, num_clips=num_clips)
+    clips_3 = sampler(decoder, num_clips=num_clips)
 
     with pytest.raises(AssertionError, match="Tensor-likes are not"):
         assert_tensor_equal(clips_1[0].data, clips_3[0].data)
@@ -412,6 +443,7 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz
         clips_at_random_indices,
         clips_at_regular_indices,
         partial(clips_at_regular_timestamps, seconds_between_clip_starts=1),
+        clips_at_random_timestamps,
     ),
 )
 def test_sampler_errors(sampler):
@@ -441,9 +473,7 @@
 @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
 def test_index_based_samplers_errors(sampler):
     decoder = VideoDecoder(NASA_VIDEO.path)
-    with pytest.raises(
-        ValueError, match=re.escape("num_clips (0) must be strictly positive")
-    ):
+    with pytest.raises(ValueError, match=re.escape("num_clips (0) must be > 0")):
         sampler(decoder, num_clips=0)
 
     with pytest.raises(
@@ -468,9 +498,15 @@
         )
 
 
-def test_time_based_sampler_errors():
+@pytest.mark.parametrize(
+    "sampler",
+    (
+        clips_at_random_timestamps,
+        partial(clips_at_regular_timestamps, seconds_between_clip_starts=1),
+    ),
+)
+def test_time_based_sampler_errors(sampler):
     decoder = VideoDecoder(NASA_VIDEO.path)
-    sampler = partial(clips_at_regular_timestamps, seconds_between_clip_starts=1)
 
     with pytest.raises(
         ValueError, match=re.escape("sampling_range_start (-1) must be at least 0.0")
@@ -482,10 +518,17 @@ def test_time_based_sampler_errors():
     ):
         sampler(decoder, sampling_range_end=-1)
 
-    with pytest.raises(
-        ValueError, match=re.escape("seconds_between_clip_starts (-1) must be > 0")
-    ):
-        sampler(decoder, seconds_between_clip_starts=-1)
+    if sampler is clips_at_random_timestamps:
+        with pytest.raises(
+            ValueError,
+            match=re.escape("num_clips (0) must be > 0"),
+        ):
+            sampler(decoder, num_clips=0)
+    else:
+        with pytest.raises(
+            ValueError, match=re.escape("seconds_between_clip_starts (-1) must be > 0")
+        ):
+            sampler(decoder, seconds_between_clip_starts=-1)
 
 @contextlib.contextmanager
 def restore_metadata():
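
Usage note (illustrative only, not part of the patch): the sketch below shows how the new clips_at_random_timestamps sampler introduced above could be called, mirroring the benchmark and tests in this diff. The decoder and sampler calls are taken from the signatures shown in the patch; the video path is an assumption (it points at the repo's test clip and presumes the script is run from the repository root).

from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_random_timestamps

# Test clip used throughout this patch (assumes the repo root as cwd).
decoder = VideoDecoder("test/resources/nasa_13013.mp4")

# Draw 5 clips of 10 frames each; clip starts are sampled uniformly at random
# over the sampling range, which defaults to (roughly) the whole video.
clips = clips_at_random_timestamps(
    decoder,
    num_clips=5,
    num_frames_per_clip=10,
)

for clip in clips:  # each clip is a FrameBatch
    print(clip.data.shape, clip.pts_seconds)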