diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py
index b5aa1261e..5a173c218 100644
--- a/src/torchcodec/samplers/__init__.py
+++ b/src/torchcodec/samplers/__init__.py
@@ -1 +1 @@
-from ._implem import clips_at_random_indices
+from ._implem import clips_at_random_indices, clips_at_regular_indices
diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py
index 4e6b37df9..3a8c46291 100644
--- a/src/torchcodec/samplers/_implem.py
+++ b/src/torchcodec/samplers/_implem.py
@@ -140,3 +140,89 @@ def clips_at_random_indices(
     random.setstate(builtin_random_state)
 
     return clips
+
+
+def clips_at_regular_indices(
+    decoder: VideoDecoder,
+    *,
+    num_clips: int = 1,
+    num_frames_per_clip: int = 1,
+    num_indices_between_frames: int = 1,
+    sampling_range_start: int = 0,
+    sampling_range_end: Optional[int] = None,  # interval is [start, end).
+) -> List[FrameBatch]:
+    """Sample clips whose start indices are evenly spaced over a sampling range.
+
+    Start indices are computed with ``torch.linspace`` between the validated
+    ``sampling_range_start`` and ``sampling_range_end - 1``. If ``num_clips`` is
+    larger than the number of available start indices, start indices are
+    duplicated (and remain ordered) instead of wrapping around.
+
+    Args:
+        decoder: Decoder to sample clips from.
+        num_clips: Number of clips to return.
+        num_frames_per_clip: Number of frames in each clip.
+        num_indices_between_frames: Index step between two consecutive
+            frames of a clip.
+        sampling_range_start: Start (inclusive) of the range of clip start
+            indices.
+        sampling_range_end: End (exclusive) of the range of clip start
+            indices.
+
+    Returns:
+        One ``FrameBatch`` per clip, ordered by clip start index.
+
+    Raises:
+        ValueError: If the clip span is larger than the number of frames,
+            or if parameter or sampling-range validation fails.
+    """
+
+    _validate_params(
+        decoder=decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+    )
+
+    clip_span = _get_clip_span(
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_per_clip=num_frames_per_clip,
+    )
+
+    # TODO: We should probably not error.
+    if clip_span > len(decoder):
+        raise ValueError(
+            f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
+        )
+
+    sampling_range_start, sampling_range_end = _validate_sampling_range(
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        num_frames=len(decoder),
+        clip_span=clip_span,
+    )
+
+    # Note [num clips larger than sampling range]
+    # If we ask for more clips than there are frames in the sampling range or
+    # in the video, we rely on torch.linspace behavior which will return
+    # duplicated indices.
+    # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
+    # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
+    # Alternatively we could wrap around, but the current behavior is closer to
+    # the expected "equally spaced indices" sampling.
+    clip_start_indices = torch.linspace(
+        sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int
+    )
+
+    # Similarly to clips_at_random_indices, there may be backward seeks if clips overlap.
+    # See other TODO over there, and apply similar changes here.
+    clips = [
+        decoder.get_frames_at(
+            start=clip_start_index,
+            stop=clip_start_index + clip_span,
+            step=num_indices_between_frames,
+        )
+        for clip_start_index in clip_start_indices
+    ]
+
+    return clips
diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py
index 9525fcba7..2df7bf79f 100644
--- a/test/samplers/test_samplers.py
+++ b/test/samplers/test_samplers.py
@@ -4,19 +4,21 @@
 
 import pytest
 import torch
+
 from torchcodec.decoders import FrameBatch, VideoDecoder
-from torchcodec.samplers import clips_at_random_indices
+from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices
 
 from ..utils import assert_tensor_equal, NASA_VIDEO
 
 
+@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
 @pytest.mark.parametrize("num_indices_between_frames", [1, 5])
-def test_random_sampler(num_indices_between_frames):
+def test_sampler(sampler, num_indices_between_frames):
     decoder = VideoDecoder(NASA_VIDEO.path)
-    num_clips = 2
+    num_clips = 5
     num_frames_per_clip = 3
 
-    clips = clips_at_random_indices(
+    clips = sampler(
         decoder,
         num_clips=num_clips,
         num_frames_per_clip=num_frames_per_clip,
@@ -34,6 +36,17 @@ def test_random_sampler(num_indices_between_frames):
     )
     assert all(clip.data.shape == expected_clip_data_shape for clip in clips)
 
+    if sampler is clips_at_regular_indices:
+        # assert regular spacing between sampled clips
+        # Note: need approximate check as actual values typically look like [3.2032, 3.2366, 3.2366, 3.2366]
+        seconds_between_clip_starts = torch.tensor(
+            [clip.pts_seconds[0] for clip in clips]
+        ).diff()
+        for diff in seconds_between_clip_starts:
+            assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05)
+
+        assert (seconds_between_clip_starts > 0).all()  # Also assert clips are sorted by start time
+
     # Check the num_indices_between_frames parameter by asserting that the
     # "time" difference between frames in a clip is the same as the "index"
     # distance.
@@ -45,6 +58,7 @@ def test_random_sampler(num_indices_between_frames):
     )
 
 
+@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
 @pytest.mark.parametrize(
     "sampling_range_start, sampling_range_end, assert_all_equal",
     (
@@ -52,8 +66,8 @@ def test_random_sampler(num_indices_between_frames):
         (10, 12, False),
     ),
 )
-def test_random_sampler_range(
-    sampling_range_start, sampling_range_end, assert_all_equal
+def test_sampling_range(
+    sampler, sampling_range_start, sampling_range_end, assert_all_equal
 ):
     # Test the sampling_range_start and sampling_range_end parameters by
     # asserting that all clips are equal if the sampling range is of size 1,
@@ -65,7 +79,7 @@ def test_random_sampler_range(
 
     decoder = VideoDecoder(NASA_VIDEO.path)
 
-    clips = clips_at_random_indices(
+    clips = sampler(
         decoder,
         num_clips=10,
         num_frames_per_clip=2,
@@ -86,13 +100,14 @@ def test_random_sampler_range(
             assert_tensor_equal(clip.data, clips[0].data)
 
 
-def test_random_sampler_range_negative():
+@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
+def test_sampling_range_negative(sampler):
     # Test the passing negative values for sampling_range_start and
     # sampling_range_end is the same as passing `len(decoder) - val`
 
     decoder = VideoDecoder(NASA_VIDEO.path)
 
-    clips_1 = clips_at_random_indices(
+    clips_1 = sampler(
         decoder,
         num_clips=10,
         num_frames_per_clip=2,
@@ -100,7 +115,7 @@ def test_random_sampler_range_negative():
         sampling_range_end=len(decoder) - 99,
     )
 
-    clips_2 = clips_at_random_indices(
+    clips_2 = sampler(
         decoder,
         num_clips=10,
         num_frames_per_clip=2,
@@ -150,59 +165,87 @@ def test_random_sampler_randomness():
     assert builtin_random_state_start == builtin_random_state_end
 
 
-def test_random_sampler_errors():
+@pytest.mark.parametrize(
+    "num_clips, sampling_range_size",
+    (
+        # Ask for 10 clips while the sampling range is 10 frames wide
+        # expect 10 clips with 10 unique starting points.
+        (10, 10),
+        # Ask for 50 clips while the sampling range is only 10 frames wide
+        # expect 50 clips with only 10 unique starting points.
+        (50, 10),
+    ),
+)
+def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size):
+    # Test for expected behavior described in Note [num clips larger than sampling range]
+    decoder = VideoDecoder(NASA_VIDEO.path)
+    clips = clips_at_regular_indices(
+        decoder,
+        num_clips=num_clips,
+        sampling_range_start=0,
+        sampling_range_end=sampling_range_size,  # because sampling_range_start=0
+    )
+
+    assert len(clips) == num_clips
+
+    clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips])
+    assert len(torch.unique(clip_starts_seconds)) == sampling_range_size
+
+    # Assert clips starts are ordered, i.e. the start indices don't just "wrap
+    # around". They're duplicated *and* ordered.
+    assert (clip_starts_seconds.diff() >= 0).all()
+
+
+@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
+def test_random_sampler_errors(sampler):
     decoder = VideoDecoder(NASA_VIDEO.path)
     with pytest.raises(
         ValueError, match=re.escape("num_clips (0) must be strictly positive")
     ):
-        clips_at_random_indices(decoder, num_clips=0)
+        sampler(decoder, num_clips=0)
 
     with pytest.raises(
         ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive")
     ):
-        clips_at_random_indices(decoder, num_frames_per_clip=0)
+        sampler(decoder, num_frames_per_clip=0)
 
     with pytest.raises(
         ValueError,
         match=re.escape("num_indices_between_frames (0) must be strictly positive"),
     ):
-        clips_at_random_indices(decoder, num_indices_between_frames=0)
+        sampler(decoder, num_indices_between_frames=0)
 
     with pytest.raises(
         ValueError,
         match=re.escape("Clip span (1000) is larger than the number of frames"),
     ):
-        clips_at_random_indices(decoder, num_frames_per_clip=1000)
+        sampler(decoder, num_frames_per_clip=1000)
 
     with pytest.raises(
         ValueError,
         match=re.escape("Clip span (1001) is larger than the number of frames"),
     ):
-        clips_at_random_indices(
-            decoder, num_frames_per_clip=2, num_indices_between_frames=1000
-        )
+        sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000)
 
     with pytest.raises(
         ValueError, match=re.escape("sampling_range_start (1000) must be smaller than")
     ):
-        clips_at_random_indices(decoder, sampling_range_start=1000)
+        sampler(decoder, sampling_range_start=1000)
 
     with pytest.raises(
         ValueError, match=re.escape("sampling_range_start (4) must be smaller than")
     ):
-        clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4)
+        sampler(decoder, sampling_range_start=4, sampling_range_end=4)
 
     with pytest.raises(
         ValueError, match=re.escape("sampling_range_start (290) must be smaller than")
     ):
-        clips_at_random_indices(
-            decoder, sampling_range_start=-100, sampling_range_end=-100
-        )
+        sampler(decoder, sampling_range_start=-100, sampling_range_end=-100)
 
     with pytest.raises(
         ValueError, match="We determined that sampling_range_end should"
     ):
-        clips_at_random_indices(
+        sampler(
             decoder,
             num_frames_per_clip=10,
             sampling_range_start=len(decoder) - 1,