-
Notifications
You must be signed in to change notification settings - Fork 71
Add regular index-based sampler - part 1 #240
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
c090e44
6b52cfd
c618160
b5545a9
2c5a559
ce46196
1873174
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1 +1 @@ | ||
| from ._implem import clips_at_random_indices | ||
| from ._implem import clips_at_random_indices, clips_at_regular_indices |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,19 +4,21 @@ | |
|
|
||
| import pytest | ||
| import torch | ||
|
|
||
| from torchcodec.decoders import FrameBatch, VideoDecoder | ||
| from torchcodec.samplers import clips_at_random_indices | ||
| from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices | ||
|
|
||
| from ..utils import assert_tensor_equal, NASA_VIDEO | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) | ||
| @pytest.mark.parametrize("num_indices_between_frames", [1, 5]) | ||
| def test_random_sampler(num_indices_between_frames): | ||
| def test_sampler(sampler, num_indices_between_frames): | ||
| decoder = VideoDecoder(NASA_VIDEO.path) | ||
| num_clips = 2 | ||
| num_clips = 5 | ||
| num_frames_per_clip = 3 | ||
|
|
||
| clips = clips_at_random_indices( | ||
| clips = sampler( | ||
| decoder, | ||
| num_clips=num_clips, | ||
| num_frames_per_clip=num_frames_per_clip, | ||
|
|
@@ -34,6 +36,17 @@ def test_random_sampler(num_indices_between_frames): | |
| ) | ||
| assert all(clip.data.shape == expected_clip_data_shape for clip in clips) | ||
|
|
||
| if sampler is clips_at_regular_indices: | ||
| # assert regular spacing between sampled clips | ||
| # Note: need approximate check as actual values typically look like [3.2032, 3.2366, 3.2366, 3.2366] | ||
| seconds_between_clip_starts = torch.tensor( | ||
| [clip.pts_seconds[0] for clip in clips] | ||
| ).diff() | ||
| for diff in seconds_between_clip_starts: | ||
| assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) | ||
|
|
||
| assert (diff > 0).all() # Also assert clips are sorted by start time | ||
|
|
||
| # Check the num_indices_between_frames parameter by asserting that the | ||
| # "time" difference between frames in a clip is the same as the "index" | ||
| # distance. | ||
|
|
@@ -45,15 +58,16 @@ def test_random_sampler(num_indices_between_frames): | |
| ) | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) | ||
| @pytest.mark.parametrize( | ||
| "sampling_range_start, sampling_range_end, assert_all_equal", | ||
| ( | ||
| (10, 11, True), | ||
| (10, 12, False), | ||
| ), | ||
| ) | ||
| def test_random_sampler_range( | ||
| sampling_range_start, sampling_range_end, assert_all_equal | ||
| def test_sampling_range( | ||
| sampler, sampling_range_start, sampling_range_end, assert_all_equal | ||
| ): | ||
| # Test the sampling_range_start and sampling_range_end parameters by | ||
| # asserting that all clips are equal if the sampling range is of size 1, | ||
|
|
@@ -65,7 +79,7 @@ def test_random_sampler_range( | |
|
|
||
| decoder = VideoDecoder(NASA_VIDEO.path) | ||
|
|
||
| clips = clips_at_random_indices( | ||
| clips = sampler( | ||
| decoder, | ||
| num_clips=10, | ||
| num_frames_per_clip=2, | ||
|
|
@@ -86,21 +100,22 @@ def test_random_sampler_range( | |
| assert_tensor_equal(clip.data, clips[0].data) | ||
|
|
||
|
|
||
| def test_random_sampler_range_negative(): | ||
| @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) | ||
| def test_sampling_range_negative(sampler): | ||
| # Test the passing negative values for sampling_range_start and | ||
| # sampling_range_end is the same as passing `len(decoder) - val` | ||
|
|
||
| decoder = VideoDecoder(NASA_VIDEO.path) | ||
|
|
||
| clips_1 = clips_at_random_indices( | ||
| clips_1 = sampler( | ||
| decoder, | ||
| num_clips=10, | ||
| num_frames_per_clip=2, | ||
| sampling_range_start=len(decoder) - 100, | ||
| sampling_range_end=len(decoder) - 99, | ||
| ) | ||
|
|
||
| clips_2 = clips_at_random_indices( | ||
| clips_2 = sampler( | ||
| decoder, | ||
| num_clips=10, | ||
| num_frames_per_clip=2, | ||
|
|
@@ -150,59 +165,87 @@ def test_random_sampler_randomness(): | |
| assert builtin_random_state_start == builtin_random_state_end | ||
|
|
||
|
|
||
| def test_random_sampler_errors(): | ||
| @pytest.mark.parametrize( | ||
| "num_clips, sampling_range_size", | ||
| ( | ||
| # Ask for 10 clips while the sampling range is 10 frames wide | ||
| # expect 10 clips with 10 unique starting points. | ||
| (10, 10), | ||
| # Ask for 50 clips while the sampling range is only 10 frames wide | ||
| # expect 50 clips with only 10 unique starting points. | ||
| (50, 10), | ||
| ), | ||
| ) | ||
| def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size): | ||
| # Test for expected behavior described in Note [num clips larger than sampling range] | ||
| decoder = VideoDecoder(NASA_VIDEO.path) | ||
| clips = clips_at_regular_indices( | ||
| decoder, | ||
| num_clips=num_clips, | ||
| sampling_range_start=0, | ||
| sampling_range_end=sampling_range_size, # because sampling_range_start=0 | ||
| ) | ||
|
|
||
| assert len(clips) == num_clips | ||
|
|
||
| clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips]) | ||
| assert len(torch.unique(clip_starts_seconds)) == sampling_range_size | ||
|
|
||
| # Assert clips starts are ordered, i.e. the start indices don't just "wrap | ||
| # around". They're duplicated *and* ordered. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Mention reason for not wrapping around?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There's no particular reason honestly, it's just an assertion of the behavior of |
||
| assert (clip_starts_seconds.diff() >= 0).all() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) | ||
| def test_random_sampler_errors(sampler): | ||
| decoder = VideoDecoder(NASA_VIDEO.path) | ||
| with pytest.raises( | ||
| ValueError, match=re.escape("num_clips (0) must be strictly positive") | ||
| ): | ||
| clips_at_random_indices(decoder, num_clips=0) | ||
| sampler(decoder, num_clips=0) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") | ||
| ): | ||
| clips_at_random_indices(decoder, num_frames_per_clip=0) | ||
| sampler(decoder, num_frames_per_clip=0) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=re.escape("num_indices_between_frames (0) must be strictly positive"), | ||
| ): | ||
| clips_at_random_indices(decoder, num_indices_between_frames=0) | ||
| sampler(decoder, num_indices_between_frames=0) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=re.escape("Clip span (1000) is larger than the number of frames"), | ||
| ): | ||
| clips_at_random_indices(decoder, num_frames_per_clip=1000) | ||
| sampler(decoder, num_frames_per_clip=1000) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, | ||
| match=re.escape("Clip span (1001) is larger than the number of frames"), | ||
| ): | ||
| clips_at_random_indices( | ||
| decoder, num_frames_per_clip=2, num_indices_between_frames=1000 | ||
| ) | ||
| sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") | ||
| ): | ||
| clips_at_random_indices(decoder, sampling_range_start=1000) | ||
| sampler(decoder, sampling_range_start=1000) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("sampling_range_start (4) must be smaller than") | ||
| ): | ||
| clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4) | ||
| sampler(decoder, sampling_range_start=4, sampling_range_end=4) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match=re.escape("sampling_range_start (290) must be smaller than") | ||
| ): | ||
| clips_at_random_indices( | ||
| decoder, sampling_range_start=-100, sampling_range_end=-100 | ||
| ) | ||
| sampler(decoder, sampling_range_start=-100, sampling_range_end=-100) | ||
|
|
||
| with pytest.raises( | ||
| ValueError, match="We determined that sampling_range_end should" | ||
| ): | ||
| clips_at_random_indices( | ||
| sampler( | ||
| decoder, | ||
| num_frames_per_clip=10, | ||
| sampling_range_start=len(decoder) - 1, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I feel like
clips_at_regular_indices(num_clips=1)is an anti pattern. Same with 2. Maybe we should make 3 the default, and enforcenum_clips>=3? Or not enforce, but still set the default to 3?I don't feel too strongly about it either way - just being pedantic.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is what other libraries do we can do the same here.
I was thinking of sampling the middle of the video if num_clips=1 but that would make it slower than existing libraries
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK. Other libs allow
num_clips=1. Let's revisit if needed