From c090e44b787fb11c0c91528592b6fe8432b11669 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 03:49:39 -0700 Subject: [PATCH 1/5] Add Random index-based clip sampler --- src/torchcodec/__init__.py | 2 +- src/torchcodec/samplers/__init__.py | 1 + src/torchcodec/samplers/_implem.py | 62 +++++++++++++++++++++++ test/samplers/test_samplers.py | 76 +++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 src/torchcodec/samplers/__init__.py create mode 100644 src/torchcodec/samplers/_implem.py create mode 100644 test/samplers/test_samplers.py diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index a97c32ce0..069f2163a 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from . import decoders # noqa +from . import decoders, samplers # noqa # noqa __version__ = "0.0.2.dev" diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py new file mode 100644 index 000000000..b5aa1261e --- /dev/null +++ b/src/torchcodec/samplers/__init__.py @@ -0,0 +1 @@ +from ._implem import clips_at_random_indices diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py new file mode 100644 index 000000000..832cd4404 --- /dev/null +++ b/src/torchcodec/samplers/_implem.py @@ -0,0 +1,62 @@ +from typing import List + +import torch + +from torchcodec.decoders import ( # TODO: move FrameBatch to torchcodec.FrameBatch? 
+ FrameBatch, + SimpleVideoDecoder, +) + + +def clips_at_random_indices( + decoder: SimpleVideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, +) -> List[FrameBatch]: + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + if num_frames_per_clip <= 0: + raise ValueError( + f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" + ) + if num_indices_between_frames <= 0: + raise ValueError( + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + ) + + # Determine the span of a clip, i.e. the number of frames (or indices) + # between the first and last frame in the clip, both included. This isn't + # the same as the number of frames in a clip! + # Example: f means a frame in the clip, x means a frame excluded from the clip + # num_frames_per_clip = 4 + # num_indices_between_frames = 1, clip = ffff , span = 4 + # num_indices_between_frames = 2, clip = fxfxfxf , span = 7 + # num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10 + clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1 + + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + + last_clip_start_index = len(decoder) - clip_span + clip_start_indices = torch.randint( + low=0, high=last_clip_start_index + 1, size=(num_clips,) + ) + + # TODO: This is inefficient as we are potentially seeking backwards. + # We should sort by clip start before querying, and re-shuffle. + # Note: we may still have to seek backwards if we have overlapping clips. 
+ clips = [ + decoder.get_frames_at( + start=clip_start_index, + stop=clip_start_index + clip_span, + step=num_indices_between_frames, + ) + for clip_start_index in clip_start_indices + ] + + return clips diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py new file mode 100644 index 000000000..2bf056b2b --- /dev/null +++ b/test/samplers/test_samplers.py @@ -0,0 +1,76 @@ +import re + +import pytest +import torch +from torchcodec.decoders import FrameBatch, SimpleVideoDecoder +from torchcodec.samplers import clips_at_random_indices + +from ..utils import NASA_VIDEO + + +@pytest.mark.parametrize("num_indices_between_frames", [1, 5]) +def test_random_sampler(num_indices_between_frames): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + num_clips = 2 + num_frames_per_clip = 3 + + clips = clips_at_random_indices( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + ) + + assert isinstance(clips, list) + assert len(clips) == num_clips + assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clip_data_shape = ( + num_frames_per_clip, + 3, + NASA_VIDEO.height, + NASA_VIDEO.width, + ) + assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + + # Check the num_indices_between_frames parameter by asserting that the + # "time" difference between frames in a clip is the same as the "index" + # distance. 
+ avg_distance_between_frames_seconds = torch.concat( + [clip.pts_seconds.diff() for clip in clips] + ).mean() + assert avg_distance_between_frames_seconds == pytest.approx( + num_indices_between_frames / decoder.metadata.average_fps + ) + + +def test_random_sampler_errors(): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + with pytest.raises( + ValueError, match=re.escape("num_clips (0) must be strictly positive") + ): + clips_at_random_indices(decoder, num_clips=0) + + with pytest.raises( + ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") + ): + clips_at_random_indices(decoder, num_frames_per_clip=0) + + with pytest.raises( + ValueError, + match=re.escape("num_indices_between_frames (0) must be strictly positive"), + ): + clips_at_random_indices(decoder, num_indices_between_frames=0) + + with pytest.raises( + ValueError, + match=re.escape("Clip span (1000) is larger than the number of frames"), + ): + clips_at_random_indices(decoder, num_frames_per_clip=1000) + + with pytest.raises( + ValueError, + match=re.escape("Clip span (1001) is larger than the number of frames"), + ): + clips_at_random_indices( + decoder, num_frames_per_clip=2, num_indices_between_frames=1000 + ) From c61816006b8026e0f1c4844e5d47ed8b801de5e4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 09:01:14 -0700 Subject: [PATCH 2/5] Basic linear sampler --- src/torchcodec/samplers/__init__.py | 2 +- src/torchcodec/samplers/_implem.py | 47 +++++++++++++++++++++++++++++ test/samplers/test_samplers.py | 38 ++++++++++++++++++++++- 3 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py index b5aa1261e..5a173c218 100644 --- a/src/torchcodec/samplers/__init__.py +++ b/src/torchcodec/samplers/__init__.py @@ -1 +1 @@ -from ._implem import clips_at_random_indices +from ._implem import clips_at_random_indices, clips_at_regular_indices diff --git 
a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 832cd4404..1a66de0e5 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -60,3 +60,50 @@ def clips_at_random_indices( ] return clips + + +def clips_at_regular_indices( + decoder: SimpleVideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, +) -> List[FrameBatch]: + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + if num_frames_per_clip <= 0: + raise ValueError( + f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" + ) + if num_indices_between_frames <= 0: + raise ValueError( + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + ) + + clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1 + + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + + # TODO what if num_clips > (len(decoder) - clip_span)? We can: + # - wrap around + # - as a real wrap around [0, 1, 2, 3, 0, 1] + # - the way linspace natively does it i.g. [0, 0, 1, 2, 2, 3] + # - truncate and return less than num_clips. 
+ # - error + + clip_start_indices = torch.linspace(0, len(decoder) - clip_span, steps=num_clips, dtype=torch.int) + + clips = [ + decoder.get_frames_at( + start=clip_start_index, + stop=clip_start_index + clip_span, + step=num_indices_between_frames, + ) + for clip_start_index in clip_start_indices + ] + + return clips diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 2bf056b2b..b84aa3cf3 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -3,7 +3,7 @@ import pytest import torch from torchcodec.decoders import FrameBatch, SimpleVideoDecoder -from torchcodec.samplers import clips_at_random_indices +from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices from ..utils import NASA_VIDEO @@ -74,3 +74,39 @@ def test_random_sampler_errors(): clips_at_random_indices( decoder, num_frames_per_clip=2, num_indices_between_frames=1000 ) + + +@pytest.mark.parametrize("num_indices_between_frames", [1])#, 5]) +def test_clips_at_regular_indices(num_indices_between_frames): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + num_clips = 4 + num_frames_per_clip = 3 + + clips = clips_at_regular_indices( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + ) + print(clips) + + assert isinstance(clips, list) + assert len(clips) == num_clips + assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clip_data_shape = ( + num_frames_per_clip, + 3, + NASA_VIDEO.height, + NASA_VIDEO.width, + ) + assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + + # # Check the num_indices_between_frames parameter by asserting that the + # # "time" difference between frames in a clip is the same as the "index" + # # distance. 
+ # avg_distance_between_frames_seconds = torch.concat( + # [clip.pts_seconds.diff() for clip in clips] + # ).mean() + # assert avg_distance_between_frames_seconds == pytest.approx( + # num_indices_between_frames / decoder.metadata.average_fps + # ) From 2c5a559fc73c6bf54765096df674fdb5389242b4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 03:59:06 -0700 Subject: [PATCH 3/5] Add tests --- src/torchcodec/samplers/_implem.py | 17 +++++- test/samplers/test_samplers.py | 91 +++++++++++------------------- 2 files changed, 48 insertions(+), 60 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index b6c20df20..9e390b123 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -148,6 +148,8 @@ def clips_at_regular_indices( num_clips: int = 1, num_frames_per_clip: int = 1, num_indices_between_frames: int = 1, + sampling_range_start: int = 0, + sampling_range_end: Optional[int] = None, # interval is [start, end). ) -> List[FrameBatch]: _validate_params( @@ -162,6 +164,12 @@ def clips_at_regular_indices( num_frames_per_clip=num_frames_per_clip, ) + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + # TODO what if num_clips > (len(decoder) - clip_span)? We can: # - wrap around # - as a real wrap around [0, 1, 2, 3, 0, 1] @@ -169,8 +177,15 @@ def clips_at_regular_indices( # - truncate and return less than num_clips. 
# - error + sampling_range_start, sampling_range_end = _validate_sampling_range( + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + num_frames=len(decoder), + clip_span=clip_span, + ) + clip_start_indices = torch.linspace( - 0, len(decoder) - clip_span, steps=num_clips, dtype=torch.int + sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) clips = [ diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 194c4b941..097f830aa 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -11,13 +11,14 @@ from ..utils import assert_tensor_equal, NASA_VIDEO +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize("num_indices_between_frames", [1, 5]) -def test_random_sampler(num_indices_between_frames): +def test_sampler(sampler, num_indices_between_frames): decoder = VideoDecoder(NASA_VIDEO.path) - num_clips = 2 + num_clips = 5 num_frames_per_clip = 3 - clips = clips_at_random_indices( + clips = sampler( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, @@ -35,6 +36,15 @@ def test_random_sampler(num_indices_between_frames): ) assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + if sampler is clips_at_regular_indices: + # assert regular spacing between sampled clips + # Note: need approximate check as actual values typically look like [3.2032, 3.2366, 3.2366, 3.2366] + seconds_between_clip_starts = torch.tensor( + [clip.pts_seconds[0] for clip in clips] + ).diff() + for diff in seconds_between_clip_starts: + assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) + # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. 
@@ -46,6 +56,7 @@ def test_random_sampler(num_indices_between_frames): ) +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize( "sampling_range_start, sampling_range_end, assert_all_equal", ( @@ -53,8 +64,8 @@ def test_random_sampler(num_indices_between_frames): (10, 12, False), ), ) -def test_random_sampler_range( - sampling_range_start, sampling_range_end, assert_all_equal +def test_sampling_range( + sampler, sampling_range_start, sampling_range_end, assert_all_equal ): # Test the sampling_range_start and sampling_range_end parameters by # asserting that all clips are equal if the sampling range is of size 1, @@ -66,7 +77,7 @@ def test_random_sampler_range( decoder = VideoDecoder(NASA_VIDEO.path) - clips = clips_at_random_indices( + clips = sampler( decoder, num_clips=10, num_frames_per_clip=2, @@ -87,13 +98,14 @@ def test_random_sampler_range( assert_tensor_equal(clip.data, clips[0].data) -def test_random_sampler_range_negative(): +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_sampling_range_negative(sampler): # Test the passing negative values for sampling_range_start and # sampling_range_end is the same as passing `len(decoder) - val` decoder = VideoDecoder(NASA_VIDEO.path) - clips_1 = clips_at_random_indices( + clips_1 = sampler( decoder, num_clips=10, num_frames_per_clip=2, @@ -101,7 +113,7 @@ def test_random_sampler_range_negative(): sampling_range_end=len(decoder) - 99, ) - clips_2 = clips_at_random_indices( + clips_2 = sampler( decoder, num_clips=10, num_frames_per_clip=2, @@ -151,97 +163,58 @@ def test_random_sampler_randomness(): assert builtin_random_state_start == builtin_random_state_end -def test_random_sampler_errors(): +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_random_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( ValueError, 
match=re.escape("num_clips (0) must be strictly positive") ): - clips_at_random_indices(decoder, num_clips=0) + sampler(decoder, num_clips=0) with pytest.raises( ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") ): - clips_at_random_indices(decoder, num_frames_per_clip=0) + sampler(decoder, num_frames_per_clip=0) with pytest.raises( ValueError, match=re.escape("num_indices_between_frames (0) must be strictly positive"), ): - clips_at_random_indices(decoder, num_indices_between_frames=0) + sampler(decoder, num_indices_between_frames=0) with pytest.raises( ValueError, match=re.escape("Clip span (1000) is larger than the number of frames"), ): - clips_at_random_indices(decoder, num_frames_per_clip=1000) + sampler(decoder, num_frames_per_clip=1000) with pytest.raises( ValueError, match=re.escape("Clip span (1001) is larger than the number of frames"), ): - clips_at_random_indices( - decoder, num_frames_per_clip=2, num_indices_between_frames=1000 - ) + sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000) with pytest.raises( ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") ): - clips_at_random_indices(decoder, sampling_range_start=1000) + sampler(decoder, sampling_range_start=1000) with pytest.raises( ValueError, match=re.escape("sampling_range_start (4) must be smaller than") ): - clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4) + sampler(decoder, sampling_range_start=4, sampling_range_end=4) with pytest.raises( ValueError, match=re.escape("sampling_range_start (290) must be smaller than") ): - clips_at_random_indices( - decoder, sampling_range_start=-100, sampling_range_end=-100 - ) + sampler(decoder, sampling_range_start=-100, sampling_range_end=-100) with pytest.raises( ValueError, match="We determined that sampling_range_end should" ): - clips_at_random_indices( + sampler( decoder, num_frames_per_clip=10, sampling_range_start=len(decoder) - 1, 
sampling_range_end=None, ) - - -@pytest.mark.parametrize("num_indices_between_frames", [1]) # , 5]) -def test_clips_at_regular_indices(num_indices_between_frames): - decoder = VideoDecoder(NASA_VIDEO.path) - num_clips = 4 - num_frames_per_clip = 3 - - clips = clips_at_regular_indices( - decoder, - num_clips=num_clips, - num_frames_per_clip=num_frames_per_clip, - num_indices_between_frames=num_indices_between_frames, - ) - print(clips) - - assert isinstance(clips, list) - assert len(clips) == num_clips - assert all(isinstance(clip, FrameBatch) for clip in clips) - expected_clip_data_shape = ( - num_frames_per_clip, - 3, - NASA_VIDEO.height, - NASA_VIDEO.width, - ) - assert all(clip.data.shape == expected_clip_data_shape for clip in clips) - - # # Check the num_indices_between_frames parameter by asserting that the - # # "time" difference between frames in a clip is the same as the "index" - # # distance. - # avg_distance_between_frames_seconds = torch.concat( - # [clip.pts_seconds.diff() for clip in clips] - # ).mean() - # assert avg_distance_between_frames_seconds == pytest.approx( - # num_indices_between_frames / decoder.metadata.average_fps - # ) From ce4619640679d92862a763f039f8fba815b8b054 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 05:12:45 -0700 Subject: [PATCH 4/5] Handle edge case when num_clips is larger than available sampling range --- src/torchcodec/samplers/_implem.py | 14 ++++++------- test/samplers/test_samplers.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 9e390b123..75fd6db55 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -170,13 +170,6 @@ def clips_at_regular_indices( f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" ) - # TODO what if num_clips > (len(decoder) - clip_span)? 
We can: - # - wrap around - # - as a real wrap around [0, 1, 2, 3, 0, 1] - # - the way linspace natively does it i.g. [0, 0, 1, 2, 2, 3] - # - truncate and return less than num_clips. - # - error - sampling_range_start, sampling_range_end = _validate_sampling_range( sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, @@ -184,10 +177,17 @@ def clips_at_regular_indices( clip_span=clip_span, ) + # Note [num clips larger than sampling range] + # If we ask for more clips than there are frames in the sampling range (or + # in the video), we rely on torch.linspace behavior which will return + # duplicated indices. E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) + # returns 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 clip_start_indices = torch.linspace( sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) + # Similarly to clips_at_random_indices, there may be backward seeks if clips overlap. + # See other TODO over there, and apply similar changes here. clips = [ decoder.get_frames_at( start=clip_start_index, diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 097f830aa..6fb3c912a 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -45,6 +45,8 @@ def test_sampler(sampler, num_indices_between_frames): for diff in seconds_between_clip_starts: assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) + assert (diff > 0).all() # Also assert clips are sorted by start time + # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. @@ -163,6 +165,37 @@ def test_random_sampler_randomness(): assert builtin_random_state_start == builtin_random_state_end +@pytest.mark.parametrize( + "num_clips, sampling_range_size", + ( + # Ask for 50 clips while the sampling range is 10 frames wide + # expect 10 clips with 10 unique starting points. 
+ (10, 10), + # Ask for 50 clips while the sampling range is only 10 frames wide + # expect 50 clips with only 10 unique starting points. + (50, 10), + ), +) +def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size): + # Test for expected behavior described in Note [num clips larger than sampling range] + decoder = VideoDecoder(NASA_VIDEO.path) + clips = clips_at_regular_indices( + decoder, + num_clips=num_clips, + sampling_range_start=0, + sampling_range_end=sampling_range_size, # because sampling_range_start=0 + ) + + assert len(clips) == num_clips + + clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips]) + assert len(torch.unique(clip_starts_seconds)) == sampling_range_size + + # Assert clip starts are ordered, i.e. the start indices don't just "wrap + # around". They're duplicated *and* ordered. + assert (clip_starts_seconds.diff() >= 0).all() + + @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) def test_random_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) From 1873174a84c37993e5e7583a67d1a22621286a98 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 06:40:41 -0700 Subject: [PATCH 5/5] Address comments --- src/torchcodec/samplers/_implem.py | 11 +++++++---- test/samplers/test_samplers.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 75fd6db55..3a8c46291 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -178,10 +178,13 @@ def clips_at_regular_indices( ) # Note [num clips larger than sampling range] - # If we ask for more clips than there are frames in the sampling range (or - # in the video), we rely on torch.linspace behavior which will return - # duplicated indices. E.g. 
torch.linspace(0, 10, steps=20, dtype=torch.int) - # returns 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 + # If we ask for more clips than there are frames in the sampling range or + # in the video, we rely on torch.linspace behavior which will return + # duplicated indices. + # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns + # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 + # Alternatively we could wrap around, but the current behavior is closer to + # the expected "equally spaced indices" sampling. clip_start_indices = torch.linspace( sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 6fb3c912a..2df7bf79f 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -168,7 +168,7 @@ def test_random_sampler_randomness(): @pytest.mark.parametrize( "num_clips, sampling_range_size", ( - # Ask for 50 clips while the sampling range is 10 frames wide + # Ask for 10 clips while the sampling range is 10 frames wide # expect 10 clips with 10 unique starting points. (10, 10), # Ask for 50 clips while the sampling range is only 10 frames wide