From c090e44b787fb11c0c91528592b6fe8432b11669 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 03:49:39 -0700 Subject: [PATCH 1/9] Add Random index-based clip sampler --- src/torchcodec/__init__.py | 2 +- src/torchcodec/samplers/__init__.py | 1 + src/torchcodec/samplers/_implem.py | 62 +++++++++++++++++++++++ test/samplers/test_samplers.py | 76 +++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 src/torchcodec/samplers/__init__.py create mode 100644 src/torchcodec/samplers/_implem.py create mode 100644 test/samplers/test_samplers.py diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index a97c32ce0..069f2163a 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from . import decoders # noqa +from . import decoders, samplers # noqa # noqa __version__ = "0.0.2.dev" diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py new file mode 100644 index 000000000..b5aa1261e --- /dev/null +++ b/src/torchcodec/samplers/__init__.py @@ -0,0 +1 @@ +from ._implem import clips_at_random_indices diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py new file mode 100644 index 000000000..832cd4404 --- /dev/null +++ b/src/torchcodec/samplers/_implem.py @@ -0,0 +1,62 @@ +from typing import List + +import torch + +from torchcodec.decoders import ( # TODO: move FrameBatch to torchcodec.FrameBatch? + FrameBatch, + SimpleVideoDecoder, +) + + +def clips_at_random_indices( + decoder: SimpleVideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, +) -> List[FrameBatch]: + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + if num_frames_per_clip <= 0: + raise ValueError( + f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" + ) + if num_indices_between_frames <= 0: + raise ValueError( + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + ) + + # Determine the span of a clip, i.e. the number of frames (or indices) + # between the first and last frame in the clip, both included. This isn't + # the same as the number of frames in a clip! + # Example: f means a frame in the clip, x means a frame excluded from the clip + # num_frames_per_clip = 4 + # num_indices_between_frames = 1, clip = ffff , span = 4 + # num_indices_between_frames = 2, clip = fxfxfxf , span = 7 + # num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10 + clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1 + + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + + last_clip_start_index = len(decoder) - clip_span + clip_start_indices = torch.randint( + low=0, high=last_clip_start_index + 1, size=(num_clips,) + ) + + # TODO: This is inefficient as we are potentially seeking backwards. + # We should sort by clip start before querying, and re-shuffle. + # Note: we may still have to seek backwards if we have overlapping clips. + clips = [ + decoder.get_frames_at( + start=clip_start_index, + stop=clip_start_index + clip_span, + step=num_indices_between_frames, + ) + for clip_start_index in clip_start_indices + ] + + return clips diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py new file mode 100644 index 000000000..2bf056b2b --- /dev/null +++ b/test/samplers/test_samplers.py @@ -0,0 +1,76 @@ +import re + +import pytest +import torch +from torchcodec.decoders import FrameBatch, SimpleVideoDecoder +from torchcodec.samplers import clips_at_random_indices + +from ..utils import NASA_VIDEO + + +@pytest.mark.parametrize("num_indices_between_frames", [1, 5]) +def test_random_sampler(num_indices_between_frames): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + num_clips = 2 + num_frames_per_clip = 3 + + clips = clips_at_random_indices( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + ) + + assert isinstance(clips, list) + assert len(clips) == num_clips + assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clip_data_shape = ( + num_frames_per_clip, + 3, + NASA_VIDEO.height, + NASA_VIDEO.width, + ) + assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + + # Check the num_indices_between_frames parameter by asserting that the + # "time" difference between frames in a clip is the same as the "index" + # distance. + avg_distance_between_frames_seconds = torch.concat( + [clip.pts_seconds.diff() for clip in clips] + ).mean() + assert avg_distance_between_frames_seconds == pytest.approx( + num_indices_between_frames / decoder.metadata.average_fps + ) + + +def test_random_sampler_errors(): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + with pytest.raises( + ValueError, match=re.escape("num_clips (0) must be strictly positive") + ): + clips_at_random_indices(decoder, num_clips=0) + + with pytest.raises( + ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") + ): + clips_at_random_indices(decoder, num_frames_per_clip=0) + + with pytest.raises( + ValueError, + match=re.escape("num_indices_between_frames (0) must be strictly positive"), + ): + clips_at_random_indices(decoder, num_indices_between_frames=0) + + with pytest.raises( + ValueError, + match=re.escape("Clip span (1000) is larger than the number of frames"), + ): + clips_at_random_indices(decoder, num_frames_per_clip=1000) + + with pytest.raises( + ValueError, + match=re.escape("Clip span (1001) is larger than the number of frames"), + ): + clips_at_random_indices( + decoder, num_frames_per_clip=2, num_indices_between_frames=1000 + ) From b1073827da37c9a4b5ab2ec77df2231383390738 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 09:58:51 -0700 Subject: [PATCH 2/9] Sort by clip starts before decoding --- src/torchcodec/samplers/_implem.py | 26 ++++++++++++++++++--- test/conftest.py | 22 ++++++++++++++++++ test/samplers/test_samplers.py | 37 +++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 4 deletions(-) create mode 100644 test/conftest.py diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 832cd4404..b2783ec62 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -1,3 +1,4 @@ +import random from typing import List import torch @@ -26,6 +27,11 @@ def clips_at_random_indices( f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" ) + if len(decoder) < 1: + raise ValueError( + f"Decoder must have at least one frame, found {len(decoder)} frames." + ) + # Determine the span of a clip, i.e. the number of frames (or indices) # between the first and last frame in the clip, both included. This isn't # the same as the number of frames in a clip! @@ -47,9 +53,16 @@ def clips_at_random_indices( low=0, high=last_clip_start_index + 1, size=(num_clips,) ) - # TODO: This is inefficient as we are potentially seeking backwards. - # We should sort by clip start before querying, and re-shuffle. - # Note: we may still have to seek backwards if we have overlapping clips. + # We want to avoid seeking backwards, so we sort the clip start indices + # before decoding the frames, and then re-shuffle the clips afterwards. + # Backward seeks may still happen if there are overlapping clips, i.e. if a + # clip ends after the next one starts. + # TODO: We should use a different strategy to avoid backward seeks: + # - flatten all frames indices, irrespective of their clip + # - sort the indices and dedup + # - decode all frames in index order + # - re-arrange the frames back into their original clips + clip_start_indices = torch.sort(clip_start_indices).values clips = [ decoder.get_frames_at( start=clip_start_index, @@ -59,4 +72,11 @@ def clips_at_random_indices( for clip_start_index in clip_start_indices ] + # This an ugly way to shuffle the clips using pytorch RNG *without* + # affecting the python builtin RNG. + builtin_random_state = random.getstate() + random.seed(torch.randint(0, 2**32, (1,)).item()) + random.shuffle(clips) + random.setstate(builtin_random_state) + return clips diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 000000000..6ca08807a --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,22 @@ +import random + +import pytest +import torch + + +@pytest.fixture(autouse=True) +def prevent_leaking_rng(): + # Prevent each test from leaking the rng to all other test when they call + # torch.manual_seed() or random.seed(). + + torch_rng_state = torch.get_rng_state() + builtin_rng_state = random.getstate() + if torch.cuda.is_available(): + cuda_rng_state = torch.cuda.get_rng_state() + + yield + + torch.set_rng_state(torch_rng_state) + random.setstate(builtin_rng_state) + if torch.cuda.is_available(): + torch.cuda.set_rng_state(cuda_rng_state) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 2bf056b2b..0023e9c4d 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -1,3 +1,4 @@ +import random import re import pytest @@ -5,7 +6,7 @@ from torchcodec.decoders import FrameBatch, SimpleVideoDecoder from torchcodec.samplers import clips_at_random_indices -from ..utils import NASA_VIDEO +from ..utils import assert_tensor_equal, NASA_VIDEO @pytest.mark.parametrize("num_indices_between_frames", [1, 5]) @@ -43,6 +44,40 @@ def test_random_sampler(num_indices_between_frames): ) +def test_random_sampler_randomness(): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + num_clips = 5 + + builtin_random_state_start = random.getstate() + + torch.manual_seed(0) + clips_1 = clips_at_random_indices(decoder, num_clips=num_clips) + + # Assert the clip starts aren't sorted, to make sure we haven't messed up + # the implementation. (This may fail if we're unlucky, but we hard-coded a + # seed, so it will always pass.) + clip_starts = [clip.pts_seconds.item() for clip in clips_1] + assert sorted(clip_starts) != clip_starts + + # Call the same sampler again with the same seed, expect same results + torch.manual_seed(0) + clips_2 = clips_at_random_indices(decoder, num_clips=num_clips) + for clip_1, clip_2 in zip(clips_1, clips_2): + assert_tensor_equal(clip_1.data, clip_2.data) + assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds) + assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds) + + # Call with a different seed, expect different results + torch.manual_seed(1) + clips_3 = clips_at_random_indices(decoder, num_clips=num_clips) + with pytest.raises(AssertionError, match="not equal"): + assert_tensor_equal(clips_1[0].data, clips_3[0].data) + + # Make sure we didn't alter the builting Python RNG + builtin_random_state_end = random.getstate() + assert builtin_random_state_start == builtin_random_state_end + + def test_random_sampler_errors(): decoder = SimpleVideoDecoder(NASA_VIDEO.path) with pytest.raises( From 63bbcfaf99c454c3371416b67cb29125c69446ef Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 11:00:00 -0700 Subject: [PATCH 3/9] Add sampling range parameter, needs testing --- src/torchcodec/samplers/_implem.py | 46 ++++++++++++++++++++++++++++-- test/samplers/test_samplers.py | 20 +++++++++++++ 2 files changed, 63 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index b2783ec62..b07be4d77 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -1,5 +1,5 @@ import random -from typing import List +from typing import List, Optional import torch @@ -9,12 +9,46 @@ ) +def _validate_sampling_range( + *, sampling_range_start, sampling_range_end, num_frames, clip_span +): + if sampling_range_start < 0: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be non-negative." + ) + + # TODO: or max(sampling_range_start, num_frames - 1)? + sampling_range_start = sampling_range_start % num_frames + + if sampling_range_end is None: + sampling_range_end = num_frames - clip_span + 1 + if sampling_range_start > sampling_range_end: + raise ValueError( + f"We determined that sampling_range_end should be {sampling_range_end}, " + f"but it is smaller than sampling_range_start ({sampling_range_start})." + ) + else: + if sampling_range_end < 0: + # Support negative values so that -1 means last frame. + # TODO: do we want to wrap around if sampling_range_end < -num_frames ? + sampling_range_end = num_frames + sampling_range_end + 1 + if sampling_range_start > sampling_range_end: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be smaller than " + f"sampling_range_end ({sampling_range_end})." + ) + + return sampling_range_start, sampling_range_end + + def clips_at_random_indices( decoder: SimpleVideoDecoder, *, num_clips: int = 1, num_frames_per_clip: int = 1, num_indices_between_frames: int = 1, + sampling_range_start: int = 0, + sampling_range_end: Optional[int] = None, # interval is [start, end). ) -> List[FrameBatch]: if num_clips <= 0: raise ValueError(f"num_clips ({num_clips}) must be strictly positive") @@ -48,9 +82,15 @@ def clips_at_random_indices( f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" ) - last_clip_start_index = len(decoder) - clip_span + sampling_range_start, sampling_range_end = _validate_sampling_range( + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + num_frames=len(decoder), + clip_span=clip_span, + ) + clip_start_indices = torch.randint( - low=0, high=last_clip_start_index + 1, size=(num_clips,) + low=sampling_range_start, high=sampling_range_end, size=(num_clips,) ) # We want to avoid seeking backwards, so we sort the clip start indices diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 0023e9c4d..02e3a6d9e 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -109,3 +109,23 @@ def test_random_sampler_errors(): clips_at_random_indices( decoder, num_frames_per_clip=2, num_indices_between_frames=1000 ) + + with pytest.raises( + ValueError, match=re.escape("sampling_range_start (-1) must be non-negative") + ): + clips_at_random_indices(decoder, sampling_range_start=-1) + + with pytest.raises( + ValueError, match=re.escape("sampling_range_start (4) must be smaller than") + ): + clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=0) + + with pytest.raises( + ValueError, match="We determined that sampling_range_end should" + ): + clips_at_random_indices( + decoder, + num_frames_per_clip=10, + sampling_range_start=len(decoder) - 1, + sampling_range_end=None, + ) From b3bb8958e96d673d980fd6717f16b7e2efadb128 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Sep 2024 07:48:05 -0700 Subject: [PATCH 4/9] Added tests for sampling range parameters --- src/torchcodec/samplers/_implem.py | 14 +++--- test/samplers/test_samplers.py | 81 ++++++++++++++++++++++++++++-- 2 files changed, 82 insertions(+), 13 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index b07be4d77..2beb56e36 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -13,26 +13,24 @@ def _validate_sampling_range( *, sampling_range_start, sampling_range_end, num_frames, clip_span ): if sampling_range_start < 0: - raise ValueError( - f"sampling_range_start ({sampling_range_start}) must be non-negative." - ) + sampling_range_start = num_frames + sampling_range_start # TODO: or max(sampling_range_start, num_frames - 1)? sampling_range_start = sampling_range_start % num_frames if sampling_range_end is None: sampling_range_end = num_frames - clip_span + 1 - if sampling_range_start > sampling_range_end: + if sampling_range_start >= sampling_range_end: raise ValueError( f"We determined that sampling_range_end should be {sampling_range_end}, " - f"but it is smaller than sampling_range_start ({sampling_range_start})." + "but it is smaller than or equal to sampling_range_start " + f"({sampling_range_start})." ) else: if sampling_range_end < 0: # Support negative values so that -1 means last frame. - # TODO: do we want to wrap around if sampling_range_end < -num_frames ? - sampling_range_end = num_frames + sampling_range_end + 1 - if sampling_range_start > sampling_range_end: + sampling_range_end = num_frames + sampling_range_end + if sampling_range_start >= sampling_range_end: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be smaller than " f"sampling_range_end ({sampling_range_end})." diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 02e3a6d9e..0d0c463c1 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -1,3 +1,4 @@ +import contextlib import random import re @@ -44,6 +45,74 @@ def test_random_sampler(num_indices_between_frames): ) +@pytest.mark.parametrize( + "sampling_range_start, sampling_range_end, assert_all_equal", + ( + (10, 11, True), + (10, 12, False), + ), +) +def test_random_sampler_range( + sampling_range_start, sampling_range_end, assert_all_equal +): + # Test the sampling_range_start and sampling_range_end parameters by + # asserting that all clips are equal if the sampling range is of size 1, + # and that they are not all equal if the sampling range is of size 2. + + # Since this has a low but non-zero probability of failing, we hard-code a + # seed that works. + torch.manual_seed(0) + + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + + clips = clips_at_random_indices( + decoder, + num_clips=10, + num_frames_per_clip=2, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + ) + + cm = ( + contextlib.nullcontext() + if assert_all_equal + else pytest.raises(AssertionError, match="Tensor-likes are not equal!") + ) + with cm: + for clip in clips: + assert_tensor_equal(clip.data, clips[0].data) + + +def test_random_sampler_range_negative(): + # Test the passing negative values for sampling_range_start and + # sampling_range_end is the same as passing `len(decoder) - val` + + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + + clips_1 = clips_at_random_indices( + decoder, + num_clips=10, + num_frames_per_clip=2, + sampling_range_start=len(decoder) - 100, + sampling_range_end=len(decoder) - 99, + ) + + clips_2 = clips_at_random_indices( + decoder, + num_clips=10, + num_frames_per_clip=2, + sampling_range_start=-100, + sampling_range_end=-99, + ) + + # There is only one unique clip in clips_1... + for clip in clips_1: + assert_tensor_equal(clip.data, clips_1[0].data) + # ... and it's the same that's in clips_2 + for clip in clips_2: + assert_tensor_equal(clip.data, clips_1[0].data) + + def test_random_sampler_randomness(): decoder = SimpleVideoDecoder(NASA_VIDEO.path) num_clips = 5 @@ -73,7 +142,7 @@ def test_random_sampler_randomness(): with pytest.raises(AssertionError, match="not equal"): assert_tensor_equal(clips_1[0].data, clips_3[0].data) - # Make sure we didn't alter the builting Python RNG + # Make sure we didn't alter the builtin Python RNG builtin_random_state_end = random.getstate() assert builtin_random_state_start == builtin_random_state_end @@ -111,14 +180,16 @@ def test_random_sampler_errors(): ) with pytest.raises( - ValueError, match=re.escape("sampling_range_start (-1) must be non-negative") + ValueError, match=re.escape("sampling_range_start (4) must be smaller than") ): - clips_at_random_indices(decoder, sampling_range_start=-1) + clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4) with pytest.raises( - ValueError, match=re.escape("sampling_range_start (4) must be smaller than") + ValueError, match=re.escape("sampling_range_start (290) must be smaller than") ): - clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=0) + clips_at_random_indices( + decoder, sampling_range_start=-100, sampling_range_end=-100 + ) with pytest.raises( ValueError, match="We determined that sampling_range_end should" From 4d5da2038046a847379d0a3495161d9621784098 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Sep 2024 08:01:56 -0700 Subject: [PATCH 5/9] macos fix --- test/samplers/test_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 0d0c463c1..4a43071fa 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -76,7 +76,7 @@ def test_random_sampler_range( cm = ( contextlib.nullcontext() if assert_all_equal - else pytest.raises(AssertionError, match="Tensor-likes are not equal!") + else pytest.raises(AssertionError, match="Tensor-likes are not") ) with cm: for clip in clips: From 78eebd5ed73ad659f7e293df740872ec1685a77e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Sep 2024 08:08:16 -0700 Subject: [PATCH 6/9] Refactoring --- src/torchcodec/samplers/_implem.py | 72 +++++++++++++++++++----------- 1 file changed, 45 insertions(+), 27 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 2beb56e36..f76554014 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -3,10 +3,27 @@ import torch -from torchcodec.decoders import ( # TODO: move FrameBatch to torchcodec.FrameBatch? - FrameBatch, - SimpleVideoDecoder, -) +from torchcodec.decoders import FrameBatch, SimpleVideoDecoder + + +def _validate_params( + *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames +): + if len(decoder) < 1: + raise ValueError( + f"Decoder must have at least one frame, found {len(decoder)} frames." + ) + + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + if num_frames_per_clip <= 0: + raise ValueError( + f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" + ) + if num_indices_between_frames <= 0: + raise ValueError( + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + ) def _validate_sampling_range( @@ -39,6 +56,20 @@ def _validate_sampling_range( return sampling_range_start, sampling_range_end +def get_clip_span(*, num_indices_between_frames, num_frames_per_clip): + """Return the span of a clip, i.e. the number of frames (or indices) + between the first and last frame in the clip, both included. + + This isn't the same as the number of frames in a clip! + Example: f means a frame in the clip, x means a frame excluded from the clip + num_frames_per_clip = 4 + num_indices_between_frames = 1, clip = ffff , span = 4 + num_indices_between_frames = 2, clip = fxfxfxf , span = 7 + num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10 + """ + return num_indices_between_frames * (num_frames_per_clip - 1) + 1 + + def clips_at_random_indices( decoder: SimpleVideoDecoder, *, @@ -48,31 +79,18 @@ def clips_at_random_indices( sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). ) -> List[FrameBatch]: - if num_clips <= 0: - raise ValueError(f"num_clips ({num_clips}) must be strictly positive") - if num_frames_per_clip <= 0: - raise ValueError( - f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" - ) - if num_indices_between_frames <= 0: - raise ValueError( - f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" - ) - if len(decoder) < 1: - raise ValueError( - f"Decoder must have at least one frame, found {len(decoder)} frames." - ) + _validate_params( + decoder=decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + ) - # Determine the span of a clip, i.e. the number of frames (or indices) - # between the first and last frame in the clip, both included. This isn't - # the same as the number of frames in a clip! - # Example: f means a frame in the clip, x means a frame excluded from the clip - # num_frames_per_clip = 4 - # num_indices_between_frames = 1, clip = ffff , span = 4 - # num_indices_between_frames = 2, clip = fxfxfxf , span = 7 - # num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10 - clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1 + clip_span = get_clip_span( + num_indices_between_frames=num_indices_between_frames, + num_frames_per_clip=num_frames_per_clip, + ) # TODO: We should probably not error. if clip_span > len(decoder): From 623f3295b2eb5ebcc22242d9f1d6e93b776b4193 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Sep 2024 08:11:21 -0700 Subject: [PATCH 7/9] _get_clip_span should be private --- src/torchcodec/samplers/_implem.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index f76554014..3883711cd 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -56,7 +56,7 @@ def _validate_sampling_range( return sampling_range_start, sampling_range_end -def get_clip_span(*, num_indices_between_frames, num_frames_per_clip): +def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): """Return the span of a clip, i.e. the number of frames (or indices) between the first and last frame in the clip, both included. @@ -87,7 +87,7 @@ def clips_at_random_indices( num_indices_between_frames=num_indices_between_frames, ) - clip_span = get_clip_span( + clip_span = _get_clip_span( num_indices_between_frames=num_indices_between_frames, num_frames_per_clip=num_frames_per_clip, ) From 179de70ffef6e178d75c17b46e65d4d3bd4f89ec Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Sep 2024 08:22:24 -0700 Subject: [PATCH 8/9] macos again --- test/samplers/test_samplers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 4a43071fa..40e3dc4d2 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -139,7 +139,7 @@ def test_random_sampler_randomness(): # Call with a different seed, expect different results torch.manual_seed(1) clips_3 = clips_at_random_indices(decoder, num_clips=num_clips) - with pytest.raises(AssertionError, match="not equal"): + with pytest.raises(AssertionError, match="Tensor-likes are not"): assert_tensor_equal(clips_1[0].data, clips_3[0].data) # Make sure we didn't alter the builtin Python RNG From 310da2b86cc4de5247a80375eb07d13ef6ec619e Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 1 Oct 2024 02:59:57 -0700 Subject: [PATCH 9/9] Address comments --- src/torchcodec/samplers/_implem.py | 8 ++++++-- test/samplers/test_samplers.py | 12 ++++++++++-- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 3883711cd..45cc36421 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -32,8 +32,11 @@ def _validate_sampling_range( if sampling_range_start < 0: sampling_range_start = num_frames + sampling_range_start - # TODO: or max(sampling_range_start, num_frames - 1)? - sampling_range_start = sampling_range_start % num_frames + if sampling_range_start >= num_frames: + raise ValueError( + f"sampling_range_start ({sampling_range_start}) must be smaller than " + f"the number of frames ({num_frames})." + ) if sampling_range_end is None: sampling_range_end = num_frames - clip_span + 1 @@ -47,6 +50,7 @@ def _validate_sampling_range( if sampling_range_end < 0: # Support negative values so that -1 means last frame. sampling_range_end = num_frames + sampling_range_end + sampling_range_end = min(sampling_range_end, num_frames) if sampling_range_start >= sampling_range_end: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be smaller than " diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 40e3dc4d2..57b9da717 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -59,8 +59,8 @@ def test_random_sampler_range( # asserting that all clips are equal if the sampling range is of size 1, # and that they are not all equal if the sampling range is of size 2. - # Since this has a low but non-zero probability of failing, we hard-code a - # seed that works. + # When size=2 there's still a (small) non-zero probability of sampling the + # same indices for clip starts, so we hard-code a seed that works. torch.manual_seed(0) decoder = SimpleVideoDecoder(NASA_VIDEO.path) @@ -73,6 +73,9 @@ def test_random_sampler_range( sampling_range_end=sampling_range_end, ) + # This context manager is used to ensure that the call to + # assert_tensor_equal() below either passes (nullcontext) or fails + # (pytest.raises) cm = ( contextlib.nullcontext() if assert_all_equal @@ -179,6 +182,11 @@ def test_random_sampler_errors(): decoder, num_frames_per_clip=2, num_indices_between_frames=1000 ) + with pytest.raises( + ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") + ): + clips_at_random_indices(decoder, sampling_range_start=1000) + with pytest.raises( ValueError, match=re.escape("sampling_range_start (4) must be smaller than") ):