diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py
index a97c32ce0..069f2163a 100644
--- a/src/torchcodec/__init__.py
+++ b/src/torchcodec/__init__.py
@@ -4,6 +4,6 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from . import decoders  # noqa
+from . import decoders, samplers  # noqa
 
 __version__ = "0.0.2.dev"
diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py
new file mode 100644
index 000000000..b5aa1261e
--- /dev/null
+++ b/src/torchcodec/samplers/__init__.py
@@ -0,0 +1 @@
+from ._implem import clips_at_random_indices
diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py
new file mode 100644
index 000000000..45cc36421
--- /dev/null
+++ b/src/torchcodec/samplers/_implem.py
@@ -0,0 +1,142 @@
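+"""Clip samplers built on top of SimpleVideoDecoder.
+
+clips_at_random_indices() is the public entry point, re-exported from
+torchcodec.samplers; the underscore-prefixed helpers below validate the
+sampler parameters and compute the index span covered by a clip.
+"""
+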
+import random
+from typing import List, Optional
+
+import torch
+
+from torchcodec.decoders import FrameBatch, SimpleVideoDecoder
+
+
+def _validate_params(
+    *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames
+):
+    if len(decoder) < 1:
+        raise ValueError(
+            f"Decoder must have at least one frame, found {len(decoder)} frames."
+        )
+
+    if num_clips <= 0:
+        raise ValueError(f"num_clips ({num_clips}) must be strictly positive")
+    if num_frames_per_clip <= 0:
+        raise ValueError(
+            f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive"
+        )
+    if num_indices_between_frames <= 0:
+        raise ValueError(
+            f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive"
+        )
+
+
+def _validate_sampling_range(
+    *, sampling_range_start, sampling_range_end, num_frames, clip_span
+):
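+    """Normalize and validate the sampling range.
+
+    The range is the half-open interval
+    [sampling_range_start, sampling_range_end). Negative bounds are
+    interpreted relative to num_frames (e.g. -1 refers to the end of the
+    video). When sampling_range_end is None, it defaults to
+    num_frames - clip_span + 1 so that any clip starting within the range
+    fits entirely inside the video. Returns the normalized
+    (sampling_range_start, sampling_range_end) pair.
+    """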
+    if sampling_range_start < 0:
+        sampling_range_start = num_frames + sampling_range_start
+
+    if sampling_range_start >= num_frames:
+        raise ValueError(
+            f"sampling_range_start ({sampling_range_start}) must be smaller than "
+            f"the number of frames ({num_frames})."
+        )
+
+    if sampling_range_end is None:
+        sampling_range_end = num_frames - clip_span + 1
+        if sampling_range_start >= sampling_range_end:
+            raise ValueError(
+                f"We determined that sampling_range_end should be {sampling_range_end}, "
+                "but it is smaller than or equal to sampling_range_start "
+                f"({sampling_range_start})."
+            )
+    else:
+        if sampling_range_end < 0:
+            # Support negative values so that -1 means last frame.
+            sampling_range_end = num_frames + sampling_range_end
+        sampling_range_end = min(sampling_range_end, num_frames)
+        if sampling_range_start >= sampling_range_end:
+            raise ValueError(
+                f"sampling_range_start ({sampling_range_start}) must be smaller than "
+                f"sampling_range_end ({sampling_range_end})."
+            )
+
+    return sampling_range_start, sampling_range_end
+
+
+def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip):
+    """Return the span of a clip, i.e. the number of frames (or indices)
+    between the first and last frame in the clip, both included.
+
+    This isn't the same as the number of frames in a clip!
+    Example: f means a frame in the clip, x means a frame excluded from the clip
+    num_frames_per_clip = 4
+    num_indices_between_frames = 1, clip = ffff      , span = 4
+    num_indices_between_frames = 2, clip = fxfxfxf   , span = 7
+    num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10
+    """
+    return num_indices_between_frames * (num_frames_per_clip - 1) + 1
+
+
+def clips_at_random_indices(
+    decoder: SimpleVideoDecoder,
+    *,
+    num_clips: int = 1,
+    num_frames_per_clip: int = 1,
+    num_indices_between_frames: int = 1,
+    sampling_range_start: int = 0,
+    sampling_range_end: Optional[int] = None,  # interval is [start, end).
+) -> List[FrameBatch]:
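+    """Return a list of random clips, sampled from the given decoder.
+
+    Each clip is a FrameBatch of num_frames_per_clip frames, spaced
+    num_indices_between_frames indices apart. Clip start indices are drawn
+    uniformly at random (with the PyTorch RNG) from the half-open interval
+    [sampling_range_start, sampling_range_end); negative bounds count from
+    the end of the video, and when sampling_range_end is None it is chosen
+    so that every clip fits entirely within the video. Clips may overlap,
+    and the returned list is in random (not chronological) order.
+
+    Illustrative usage (the path below is just a placeholder):
+
+        decoder = SimpleVideoDecoder("video.mp4")
+        clips = clips_at_random_indices(decoder, num_clips=5, num_frames_per_clip=4)
+    """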
+
+    _validate_params(
+        decoder=decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+    )
+
+    clip_span = _get_clip_span(
+        num_indices_between_frames=num_indices_between_frames,
+        num_frames_per_clip=num_frames_per_clip,
+    )
+
+    # TODO: We should probably not error.
+    if clip_span > len(decoder):
+        raise ValueError(
+            f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
+        )
+
+    sampling_range_start, sampling_range_end = _validate_sampling_range(
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+        num_frames=len(decoder),
+        clip_span=clip_span,
+    )
+
+    clip_start_indices = torch.randint(
+        low=sampling_range_start, high=sampling_range_end, size=(num_clips,)
+    )
+
+    # We want to avoid seeking backwards, so we sort the clip start indices
+    # before decoding the frames, and then re-shuffle the clips afterwards.
+    # Backward seeks may still happen if there are overlapping clips, i.e. if a
+    # clip ends after the next one starts.
+    # TODO: We should use a different strategy to avoid backward seeks:
+    # - flatten all frame indices, irrespective of their clip
+    # - sort the indices and dedup
+    # - decode all frames in index order
+    # - re-arrange the frames back into their original clips
+    clip_start_indices = torch.sort(clip_start_indices).values
+    clips = [
+        decoder.get_frames_at(
+            start=clip_start_index,
+            stop=clip_start_index + clip_span,
+            step=num_indices_between_frames,
+        )
+        for clip_start_index in clip_start_indices
+    ]
+
+    # This is an ugly way to shuffle the clips using the PyTorch RNG *without*
+    # affecting the Python builtin RNG.
+    builtin_random_state = random.getstate()
+    random.seed(torch.randint(0, 2**32, (1,)).item())
+    random.shuffle(clips)
+    random.setstate(builtin_random_state)
+
+    return clips
diff --git a/test/conftest.py b/test/conftest.py
new file mode 100644
index 000000000..6ca08807a
--- /dev/null
+++ b/test/conftest.py
@@ -0,0 +1,22 @@
+import random
+
+import pytest
+import torch
+
+
+@pytest.fixture(autouse=True)
+def prevent_leaking_rng():
+    # Prevent each test from leaking the RNG state to all other tests when it
+    # calls torch.manual_seed() or random.seed().
+
+    torch_rng_state = torch.get_rng_state()
+    builtin_rng_state = random.getstate()
+    if torch.cuda.is_available():
+        cuda_rng_state = torch.cuda.get_rng_state()
+
+    yield
+
+    torch.set_rng_state(torch_rng_state)
+    random.setstate(builtin_rng_state)
+    if torch.cuda.is_available():
+        torch.cuda.set_rng_state(cuda_rng_state)
diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py
new file mode 100644
index 000000000..57b9da717
--- /dev/null
+++ b/test/samplers/test_samplers.py
@@ -0,0 +1,210 @@
+import contextlib
+import random
+import re
+
+import pytest
+import torch
+from torchcodec.decoders import FrameBatch, SimpleVideoDecoder
+from torchcodec.samplers import clips_at_random_indices
+
+from ..utils import assert_tensor_equal, NASA_VIDEO
+
+
+@pytest.mark.parametrize("num_indices_between_frames", [1, 5])
+def test_random_sampler(num_indices_between_frames):
+    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
+    num_clips = 2
+    num_frames_per_clip = 3
+
+    clips = clips_at_random_indices(
+        decoder,
+        num_clips=num_clips,
+        num_frames_per_clip=num_frames_per_clip,
+        num_indices_between_frames=num_indices_between_frames,
+    )
+
+    assert isinstance(clips, list)
+    assert len(clips) == num_clips
+    assert all(isinstance(clip, FrameBatch) for clip in clips)
+    expected_clip_data_shape = (
+        num_frames_per_clip,
+        3,
+        NASA_VIDEO.height,
+        NASA_VIDEO.width,
+    )
+    assert all(clip.data.shape == expected_clip_data_shape for clip in clips)
+
+    # Check the num_indices_between_frames parameter by asserting that the
+    # "time" difference between frames in a clip is the same as the "index"
+    # distance.
+    avg_distance_between_frames_seconds = torch.concat(
+        [clip.pts_seconds.diff() for clip in clips]
+    ).mean()
+    assert avg_distance_between_frames_seconds == pytest.approx(
+        num_indices_between_frames / decoder.metadata.average_fps
+    )
+
+
+@pytest.mark.parametrize(
+    "sampling_range_start, sampling_range_end, assert_all_equal",
+    (
+        (10, 11, True),
+        (10, 12, False),
+    ),
+)
+def test_random_sampler_range(
+    sampling_range_start, sampling_range_end, assert_all_equal
+):
+    # Test the sampling_range_start and sampling_range_end parameters by
+    # asserting that all clips are equal if the sampling range is of size 1,
+    # and that they are not all equal if the sampling range is of size 2.
+
+    # When size=2 there's still a (small) non-zero probability of sampling the
+    # same start indices for all clips, so we hard-code a seed that works.
+    torch.manual_seed(0)
+
+    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
+
+    clips = clips_at_random_indices(
+        decoder,
+        num_clips=10,
+        num_frames_per_clip=2,
+        sampling_range_start=sampling_range_start,
+        sampling_range_end=sampling_range_end,
+    )
+
+    # This context manager is used to ensure that the call to
+    # assert_tensor_equal() below either passes (nullcontext) or fails
+    # (pytest.raises).
+    cm = (
+        contextlib.nullcontext()
+        if assert_all_equal
+        else pytest.raises(AssertionError, match="Tensor-likes are not")
+    )
+    with cm:
+        for clip in clips:
+            assert_tensor_equal(clip.data, clips[0].data)
+
+
+def test_random_sampler_range_negative():
+    # Test that passing negative values for sampling_range_start and
+    # sampling_range_end is the same as passing `len(decoder) - val`.
+
+    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
+
+    clips_1 = clips_at_random_indices(
+        decoder,
+        num_clips=10,
+        num_frames_per_clip=2,
+        sampling_range_start=len(decoder) - 100,
+        sampling_range_end=len(decoder) - 99,
+    )
+
+    clips_2 = clips_at_random_indices(
+        decoder,
+        num_clips=10,
+        num_frames_per_clip=2,
+        sampling_range_start=-100,
+        sampling_range_end=-99,
+    )
+
+    # There is only one unique clip in clips_1...
+    for clip in clips_1:
+        assert_tensor_equal(clip.data, clips_1[0].data)
+    # ... and it's the same one that's in clips_2.
+    for clip in clips_2:
+        assert_tensor_equal(clip.data, clips_1[0].data)
+
+
+def test_random_sampler_randomness():
+    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
+    num_clips = 5
+
+    builtin_random_state_start = random.getstate()
+
+    torch.manual_seed(0)
+    clips_1 = clips_at_random_indices(decoder, num_clips=num_clips)
+
+    # Assert the clip starts aren't sorted, to make sure we haven't messed up
+    # the implementation. (This could fail if we were unlucky, but since we
+    # hard-coded a seed, it will always pass.)
+    clip_starts = [clip.pts_seconds.item() for clip in clips_1]
+    assert sorted(clip_starts) != clip_starts
+
+    # Call the same sampler again with the same seed, expect the same results.
+    torch.manual_seed(0)
+    clips_2 = clips_at_random_indices(decoder, num_clips=num_clips)
+    for clip_1, clip_2 in zip(clips_1, clips_2):
+        assert_tensor_equal(clip_1.data, clip_2.data)
+        assert_tensor_equal(clip_1.pts_seconds, clip_2.pts_seconds)
+        assert_tensor_equal(clip_1.duration_seconds, clip_2.duration_seconds)
+
+    # Call with a different seed, expect different results.
+    torch.manual_seed(1)
+    clips_3 = clips_at_random_indices(decoder, num_clips=num_clips)
+    with pytest.raises(AssertionError, match="Tensor-likes are not"):
+        assert_tensor_equal(clips_1[0].data, clips_3[0].data)
+
+    # Make sure we didn't alter the builtin Python RNG.
+    builtin_random_state_end = random.getstate()
+    assert builtin_random_state_start == builtin_random_state_end
+
+
+def test_random_sampler_errors():
+    decoder = SimpleVideoDecoder(NASA_VIDEO.path)
+    with pytest.raises(
+        ValueError, match=re.escape("num_clips (0) must be strictly positive")
+    ):
+        clips_at_random_indices(decoder, num_clips=0)
+
+    with pytest.raises(
+        ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive")
+    ):
+        clips_at_random_indices(decoder, num_frames_per_clip=0)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("num_indices_between_frames (0) must be strictly positive"),
+    ):
+        clips_at_random_indices(decoder, num_indices_between_frames=0)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Clip span (1000) is larger than the number of frames"),
+    ):
+        clips_at_random_indices(decoder, num_frames_per_clip=1000)
+
+    with pytest.raises(
+        ValueError,
+        match=re.escape("Clip span (1001) is larger than the number of frames"),
+    ):
+        clips_at_random_indices(
+            decoder, num_frames_per_clip=2, num_indices_between_frames=1000
+        )
+
+    with pytest.raises(
+        ValueError, match=re.escape("sampling_range_start (1000) must be smaller than")
+    ):
+        clips_at_random_indices(decoder, sampling_range_start=1000)
+
+    with pytest.raises(
+        ValueError, match=re.escape("sampling_range_start (4) must be smaller than")
+    ):
+        clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4)
+
+    with pytest.raises(
+        ValueError, match=re.escape("sampling_range_start (290) must be smaller than")
+    ):
+        clips_at_random_indices(
+            decoder, sampling_range_start=-100, sampling_range_end=-100
+        )
+
+    with pytest.raises(
+        ValueError, match="We determined that sampling_range_end should"
+    ):
+        clips_at_random_indices(
+            decoder,
+            num_frames_per_clip=10,
+            sampling_range_start=len(decoder) - 1,
+            sampling_range_end=None,
+        )