From c090e44b787fb11c0c91528592b6fe8432b11669 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 03:49:39 -0700 Subject: [PATCH 01/22] Add Random index-based clip sampler --- src/torchcodec/__init__.py | 2 +- src/torchcodec/samplers/__init__.py | 1 + src/torchcodec/samplers/_implem.py | 62 +++++++++++++++++++++++ test/samplers/test_samplers.py | 76 +++++++++++++++++++++++++++++ 4 files changed, 140 insertions(+), 1 deletion(-) create mode 100644 src/torchcodec/samplers/__init__.py create mode 100644 src/torchcodec/samplers/_implem.py create mode 100644 test/samplers/test_samplers.py diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index a97c32ce0..069f2163a 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -4,6 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from . import decoders # noqa +from . import decoders, samplers # noqa # noqa __version__ = "0.0.2.dev" diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py new file mode 100644 index 000000000..b5aa1261e --- /dev/null +++ b/src/torchcodec/samplers/__init__.py @@ -0,0 +1 @@ +from ._implem import clips_at_random_indices diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py new file mode 100644 index 000000000..832cd4404 --- /dev/null +++ b/src/torchcodec/samplers/_implem.py @@ -0,0 +1,62 @@ +from typing import List + +import torch + +from torchcodec.decoders import ( # TODO: move FrameBatch to torchcodec.FrameBatch? + FrameBatch, + SimpleVideoDecoder, +) + + +def clips_at_random_indices( + decoder: SimpleVideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, +) -> List[FrameBatch]: + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + if num_frames_per_clip <= 0: + raise ValueError( + f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" + ) + if num_indices_between_frames <= 0: + raise ValueError( + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + ) + + # Determine the span of a clip, i.e. the number of frames (or indices) + # between the first and last frame in the clip, both included. This isn't + # the same as the number of frames in a clip! + # Example: f means a frame in the clip, x means a frame excluded from the clip + # num_frames_per_clip = 4 + # num_indices_between_frames = 1, clip = ffff , span = 4 + # num_indices_between_frames = 2, clip = fxfxfxf , span = 7 + # num_indices_between_frames = 3, clip = fxxfxxfxxf, span = 10 + clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1 + + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + + last_clip_start_index = len(decoder) - clip_span + clip_start_indices = torch.randint( + low=0, high=last_clip_start_index + 1, size=(num_clips,) + ) + + # TODO: This is inefficient as we are potentially seeking backwards. + # We should sort by clip start before querying, and re-shuffle. + # Note: we may still have to seek backwards if we have overlapping clips. + clips = [ + decoder.get_frames_at( + start=clip_start_index, + stop=clip_start_index + clip_span, + step=num_indices_between_frames, + ) + for clip_start_index in clip_start_indices + ] + + return clips diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py new file mode 100644 index 000000000..2bf056b2b --- /dev/null +++ b/test/samplers/test_samplers.py @@ -0,0 +1,76 @@ +import re + +import pytest +import torch +from torchcodec.decoders import FrameBatch, SimpleVideoDecoder +from torchcodec.samplers import clips_at_random_indices + +from ..utils import NASA_VIDEO + + +@pytest.mark.parametrize("num_indices_between_frames", [1, 5]) +def test_random_sampler(num_indices_between_frames): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + num_clips = 2 + num_frames_per_clip = 3 + + clips = clips_at_random_indices( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + ) + + assert isinstance(clips, list) + assert len(clips) == num_clips + assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clip_data_shape = ( + num_frames_per_clip, + 3, + NASA_VIDEO.height, + NASA_VIDEO.width, + ) + assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + + # Check the num_indices_between_frames parameter by asserting that the + # "time" difference between frames in a clip is the same as the "index" + # distance. + avg_distance_between_frames_seconds = torch.concat( + [clip.pts_seconds.diff() for clip in clips] + ).mean() + assert avg_distance_between_frames_seconds == pytest.approx( + num_indices_between_frames / decoder.metadata.average_fps + ) + + +def test_random_sampler_errors(): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + with pytest.raises( + ValueError, match=re.escape("num_clips (0) must be strictly positive") + ): + clips_at_random_indices(decoder, num_clips=0) + + with pytest.raises( + ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") + ): + clips_at_random_indices(decoder, num_frames_per_clip=0) + + with pytest.raises( + ValueError, + match=re.escape("num_indices_between_frames (0) must be strictly positive"), + ): + clips_at_random_indices(decoder, num_indices_between_frames=0) + + with pytest.raises( + ValueError, + match=re.escape("Clip span (1000) is larger than the number of frames"), + ): + clips_at_random_indices(decoder, num_frames_per_clip=1000) + + with pytest.raises( + ValueError, + match=re.escape("Clip span (1001) is larger than the number of frames"), + ): + clips_at_random_indices( + decoder, num_frames_per_clip=2, num_indices_between_frames=1000 + ) From c61816006b8026e0f1c4844e5d47ed8b801de5e4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 24 Sep 2024 09:01:14 -0700 Subject: [PATCH 02/22] Basic linear sampler --- src/torchcodec/samplers/__init__.py | 2 +- src/torchcodec/samplers/_implem.py | 47 +++++++++++++++++++++++++++++ test/samplers/test_samplers.py | 38 ++++++++++++++++++++++- 3 files changed, 85 insertions(+), 2 deletions(-) diff --git a/src/torchcodec/samplers/__init__.py b/src/torchcodec/samplers/__init__.py index b5aa1261e..5a173c218 100644 --- a/src/torchcodec/samplers/__init__.py +++ b/src/torchcodec/samplers/__init__.py @@ -1 +1 @@ -from ._implem import clips_at_random_indices +from ._implem import clips_at_random_indices, clips_at_regular_indices diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 832cd4404..1a66de0e5 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -60,3 +60,50 @@ def clips_at_random_indices( ] return clips + + +def clips_at_regular_indices( + decoder: SimpleVideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, +) -> List[FrameBatch]: + if num_clips <= 0: + raise ValueError(f"num_clips ({num_clips}) must be strictly positive") + if num_frames_per_clip <= 0: + raise ValueError( + f"num_frames_per_clip ({num_frames_per_clip}) must be strictly positive" + ) + if num_indices_between_frames <= 0: + raise ValueError( + f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" + ) + + clip_span = num_indices_between_frames * (num_frames_per_clip - 1) + 1 + + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + + # TODO what if num_clips > (len(decoder) - clip_span)? We can: + # - wrap around + # - as a real wrap around [0, 1, 2, 3, 0, 1] + # - the way linspace natively does it i.g. [0, 0, 1, 2, 2, 3] + # - truncate and return less than num_clips. + # - error + + clip_start_indices = torch.linspace(0, len(decoder) - clip_span, steps=num_clips, dtype=torch.int) + + clips = [ + decoder.get_frames_at( + start=clip_start_index, + stop=clip_start_index + clip_span, + step=num_indices_between_frames, + ) + for clip_start_index in clip_start_indices + ] + + return clips diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 2bf056b2b..b84aa3cf3 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -3,7 +3,7 @@ import pytest import torch from torchcodec.decoders import FrameBatch, SimpleVideoDecoder -from torchcodec.samplers import clips_at_random_indices +from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices from ..utils import NASA_VIDEO @@ -74,3 +74,39 @@ def test_random_sampler_errors(): clips_at_random_indices( decoder, num_frames_per_clip=2, num_indices_between_frames=1000 ) + + +@pytest.mark.parametrize("num_indices_between_frames", [1])#, 5]) +def test_clips_at_regular_indices(num_indices_between_frames): + decoder = SimpleVideoDecoder(NASA_VIDEO.path) + num_clips = 4 + num_frames_per_clip = 3 + + clips = clips_at_regular_indices( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + ) + print(clips) + + assert isinstance(clips, list) + assert len(clips) == num_clips + assert all(isinstance(clip, FrameBatch) for clip in clips) + expected_clip_data_shape = ( + num_frames_per_clip, + 3, + NASA_VIDEO.height, + NASA_VIDEO.width, + ) + assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + + # # Check the num_indices_between_frames parameter by asserting that the + # # "time" difference between frames in a clip is the same as the "index" + # # distance. + # avg_distance_between_frames_seconds = torch.concat( + # [clip.pts_seconds.diff() for clip in clips] + # ).mean() + # assert avg_distance_between_frames_seconds == pytest.approx( + # num_indices_between_frames / decoder.metadata.average_fps + # ) From 2c5a559fc73c6bf54765096df674fdb5389242b4 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 03:59:06 -0700 Subject: [PATCH 03/22] Add tests --- src/torchcodec/samplers/_implem.py | 17 +++++- test/samplers/test_samplers.py | 91 +++++++++++------------------- 2 files changed, 48 insertions(+), 60 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index b6c20df20..9e390b123 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -148,6 +148,8 @@ def clips_at_regular_indices( num_clips: int = 1, num_frames_per_clip: int = 1, num_indices_between_frames: int = 1, + sampling_range_start: int = 0, + sampling_range_end: Optional[int] = None, # interval is [start, end). ) -> List[FrameBatch]: _validate_params( @@ -162,6 +164,12 @@ def clips_at_regular_indices( num_frames_per_clip=num_frames_per_clip, ) + # TODO: We should probably not error. + if clip_span > len(decoder): + raise ValueError( + f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + ) + # TODO what if num_clips > (len(decoder) - clip_span)? We can: # - wrap around # - as a real wrap around [0, 1, 2, 3, 0, 1] @@ -169,8 +177,15 @@ def clips_at_regular_indices( # - truncate and return less than num_clips. # - error + sampling_range_start, sampling_range_end = _validate_sampling_range( + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + num_frames=len(decoder), + clip_span=clip_span, + ) + clip_start_indices = torch.linspace( - 0, len(decoder) - clip_span, steps=num_clips, dtype=torch.int + sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) clips = [ diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 194c4b941..097f830aa 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -11,13 +11,14 @@ from ..utils import assert_tensor_equal, NASA_VIDEO +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize("num_indices_between_frames", [1, 5]) -def test_random_sampler(num_indices_between_frames): +def test_sampler(sampler, num_indices_between_frames): decoder = VideoDecoder(NASA_VIDEO.path) - num_clips = 2 + num_clips = 5 num_frames_per_clip = 3 - clips = clips_at_random_indices( + clips = sampler( decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, @@ -35,6 +36,15 @@ def test_random_sampler(num_indices_between_frames): ) assert all(clip.data.shape == expected_clip_data_shape for clip in clips) + if sampler is clips_at_regular_indices: + # assert regular spacing between sampled clips + # Note: need approximate check as actual values typically look like [3.2032, 3.2366, 3.2366, 3.2366] + seconds_between_clip_starts = torch.tensor( + [clip.pts_seconds[0] for clip in clips] + ).diff() + for diff in seconds_between_clip_starts: + assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) + # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. @@ -46,6 +56,7 @@ def test_random_sampler(num_indices_between_frames): ) +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) @pytest.mark.parametrize( "sampling_range_start, sampling_range_end, assert_all_equal", ( @@ -53,8 +64,8 @@ def test_random_sampler(num_indices_between_frames): (10, 12, False), ), ) -def test_random_sampler_range( - sampling_range_start, sampling_range_end, assert_all_equal +def test_sampling_range( + sampler, sampling_range_start, sampling_range_end, assert_all_equal ): # Test the sampling_range_start and sampling_range_end parameters by # asserting that all clips are equal if the sampling range is of size 1, @@ -66,7 +77,7 @@ def test_random_sampler_range( decoder = VideoDecoder(NASA_VIDEO.path) - clips = clips_at_random_indices( + clips = sampler( decoder, num_clips=10, num_frames_per_clip=2, @@ -87,13 +98,14 @@ def test_random_sampler_range( assert_tensor_equal(clip.data, clips[0].data) -def test_random_sampler_range_negative(): +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_sampling_range_negative(sampler): # Test the passing negative values for sampling_range_start and # sampling_range_end is the same as passing `len(decoder) - val` decoder = VideoDecoder(NASA_VIDEO.path) - clips_1 = clips_at_random_indices( + clips_1 = sampler( decoder, num_clips=10, num_frames_per_clip=2, @@ -101,7 +113,7 @@ def test_random_sampler_range_negative(): sampling_range_end=len(decoder) - 99, ) - clips_2 = clips_at_random_indices( + clips_2 = sampler( decoder, num_clips=10, num_frames_per_clip=2, @@ -151,97 +163,58 @@ def test_random_sampler_randomness(): assert builtin_random_state_start == builtin_random_state_end -def test_random_sampler_errors(): +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_random_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( ValueError, match=re.escape("num_clips (0) must be strictly positive") ): - clips_at_random_indices(decoder, num_clips=0) + sampler(decoder, num_clips=0) with pytest.raises( ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive") ): - clips_at_random_indices(decoder, num_frames_per_clip=0) + sampler(decoder, num_frames_per_clip=0) with pytest.raises( ValueError, match=re.escape("num_indices_between_frames (0) must be strictly positive"), ): - clips_at_random_indices(decoder, num_indices_between_frames=0) + sampler(decoder, num_indices_between_frames=0) with pytest.raises( ValueError, match=re.escape("Clip span (1000) is larger than the number of frames"), ): - clips_at_random_indices(decoder, num_frames_per_clip=1000) + sampler(decoder, num_frames_per_clip=1000) with pytest.raises( ValueError, match=re.escape("Clip span (1001) is larger than the number of frames"), ): - clips_at_random_indices( - decoder, num_frames_per_clip=2, num_indices_between_frames=1000 - ) + sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000) with pytest.raises( ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") ): - clips_at_random_indices(decoder, sampling_range_start=1000) + sampler(decoder, sampling_range_start=1000) with pytest.raises( ValueError, match=re.escape("sampling_range_start (4) must be smaller than") ): - clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4) + sampler(decoder, sampling_range_start=4, sampling_range_end=4) with pytest.raises( ValueError, match=re.escape("sampling_range_start (290) must be smaller than") ): - clips_at_random_indices( - decoder, sampling_range_start=-100, sampling_range_end=-100 - ) + sampler(decoder, sampling_range_start=-100, sampling_range_end=-100) with pytest.raises( ValueError, match="We determined that sampling_range_end should" ): - clips_at_random_indices( + sampler( decoder, num_frames_per_clip=10, sampling_range_start=len(decoder) - 1, sampling_range_end=None, ) - - -@pytest.mark.parametrize("num_indices_between_frames", [1]) # , 5]) -def test_clips_at_regular_indices(num_indices_between_frames): - decoder = VideoDecoder(NASA_VIDEO.path) - num_clips = 4 - num_frames_per_clip = 3 - - clips = clips_at_regular_indices( - decoder, - num_clips=num_clips, - num_frames_per_clip=num_frames_per_clip, - num_indices_between_frames=num_indices_between_frames, - ) - print(clips) - - assert isinstance(clips, list) - assert len(clips) == num_clips - assert all(isinstance(clip, FrameBatch) for clip in clips) - expected_clip_data_shape = ( - num_frames_per_clip, - 3, - NASA_VIDEO.height, - NASA_VIDEO.width, - ) - assert all(clip.data.shape == expected_clip_data_shape for clip in clips) - - # # Check the num_indices_between_frames parameter by asserting that the - # # "time" difference between frames in a clip is the same as the "index" - # # distance. - # avg_distance_between_frames_seconds = torch.concat( - # [clip.pts_seconds.diff() for clip in clips] - # ).mean() - # assert avg_distance_between_frames_seconds == pytest.approx( - # num_indices_between_frames / decoder.metadata.average_fps - # ) From ce4619640679d92862a763f039f8fba815b8b054 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 05:12:45 -0700 Subject: [PATCH 04/22] Handle edge case when num_clips is larger than available sampling range --- src/torchcodec/samplers/_implem.py | 14 ++++++------- test/samplers/test_samplers.py | 33 ++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 9e390b123..75fd6db55 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -170,13 +170,6 @@ def clips_at_regular_indices( f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" ) - # TODO what if num_clips > (len(decoder) - clip_span)? We can: - # - wrap around - # - as a real wrap around [0, 1, 2, 3, 0, 1] - # - the way linspace natively does it i.g. [0, 0, 1, 2, 2, 3] - # - truncate and return less than num_clips. - # - error - sampling_range_start, sampling_range_end = _validate_sampling_range( sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, @@ -184,10 +177,17 @@ def clips_at_regular_indices( clip_span=clip_span, ) + # Note [num clips larger than sampling range] + # If we ask for more clips than there are frames in the sampling range (or + # in the video), we rely on torch.linspace behavior which will return + # duplicated indices. E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) + # returns 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 clip_start_indices = torch.linspace( sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) + # Similarly to clip_at_random_indices, there may be backward seeks if clips overlap. + # See other TODO over there, and apply similar changes here. clips = [ decoder.get_frames_at( start=clip_start_index, diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 097f830aa..6fb3c912a 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -45,6 +45,8 @@ def test_sampler(sampler, num_indices_between_frames): for diff in seconds_between_clip_starts: assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05) + assert (diff > 0).all() # Also assert clips are sorted by start time + # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. @@ -163,6 +165,37 @@ def test_random_sampler_randomness(): assert builtin_random_state_start == builtin_random_state_end +@pytest.mark.parametrize( + "num_clips, sampling_range_size", + ( + # Ask for 50 clips while the sampling range is 10 frames wide + # expect 10 clips with 10 unique starting points. + (10, 10), + # Ask for 50 clips while the sampling range is only 10 frames wide + # expect 50 clips with only 10 unique starting points. + (50, 10), + ), +) +def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size): + # Test for expected behavior described in Note [num clips larger than sampling range] + decoder = VideoDecoder(NASA_VIDEO.path) + clips = clips_at_regular_indices( + decoder, + num_clips=num_clips, + sampling_range_start=0, + sampling_range_end=sampling_range_size, # because sampling_range_start=0 + ) + + assert len(clips) == num_clips + + clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips]) + assert len(torch.unique(clip_starts_seconds)) == sampling_range_size + + # Assert clips starts are ordered, i.e. the start indices don't just "wrap + # around". They're duplicated *and* ordered. + assert (clip_starts_seconds.diff() >= 0).all() + + @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) def test_random_sampler_errors(sampler): decoder = VideoDecoder(NASA_VIDEO.path) From 5fed662ec0f7f775cb0355ecee13559af6057501 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 06:34:19 -0700 Subject: [PATCH 05/22] WIP --- src/torchcodec/samplers/_implem.py | 48 ++++++++++++++++++++++++++++++ 1 file changed, 48 insertions(+) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 75fd6db55..352364558 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -74,6 +74,11 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): return num_indices_between_frames * (num_frames_per_clip - 1) + 1 +# TODO: What is sampling_range_end? +# - The upper bound of where a clip can *start* +# - The upper bound of where a clip can *end* +# ? +# Assuming this is where it can start... for now. def clips_at_random_indices( decoder: VideoDecoder, *, @@ -82,6 +87,7 @@ def clips_at_random_indices( num_indices_between_frames: int = 1, sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). + policy: str = "repeat_last" # can also be: "wrap", "error" TODO: use Literal ) -> List[FrameBatch]: _validate_params( @@ -113,6 +119,48 @@ def clips_at_random_indices( low=sampling_range_start, high=sampling_range_end, size=(num_clips,) ) + all_clip_indices : list[int] = [] + + def repeat_last_policy(clip_indices): + clip_indices += [clip_indices[-1]] * (num_frames_per_clip - len(clip_indices)) + return clip_indices + + def wrap_policy(clip_indices): + clip_indices += list( + range( + clip_indices[0], + clip_indices[0] + clip_span, + num_indices_between_frames, + ) + ) + return clip_indices + + def error_policy(clip_indices): + raise ValueError(f"TODO") + + print(f"{len(decoder) = }") + print(f"{clip_span = }") + for start_index in clip_start_indices.tolist(): + print(f"{start_index = }") + upper_bound = min(start_index + clip_span, sampling_range_end) + clip_indices = list(range(start_index, upper_bound, num_indices_between_frames)) + print(f"{clip_indices = }") + if len(clip_indices) < num_frames_per_clip: + # TODO clean up this mess + policy_fun = { + "repeat_last": repeat_last_policy, + "wrap": wrap_policy, + "error": error_policy, + }[policy] + clip_indices = policy_fun(clip_indices) + print(f"{clip_indices = }") + all_clip_indices += clip_indices + # if start_index + clip_span < sampling_range_end: + # clip_indices += list(range(start_index + 1, start_index + clip_span, num_indices_between_frames)) + # clip_indices.append() + # print(f"{index = }") + # print(f"{index + clip_span = }") + # We want to avoid seeking backwards, so we sort the clip start indices # before decoding the frames, and then re-shuffle the clips afterwards. # Backward seeks may still happen if there are overlapping clips, i.e. if a From 1873174a84c37993e5e7583a67d1a22621286a98 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 06:40:41 -0700 Subject: [PATCH 06/22] Address comments --- src/torchcodec/samplers/_implem.py | 11 +++++++---- test/samplers/test_samplers.py | 2 +- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 75fd6db55..3a8c46291 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -178,10 +178,13 @@ def clips_at_regular_indices( ) # Note [num clips larger than sampling range] - # If we ask for more clips than there are frames in the sampling range (or - # in the video), we rely on torch.linspace behavior which will return - # duplicated indices. E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) - # returns 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 + # If we ask for more clips than there are frames in the sampling range or + # in the video, we rely on torch.linspace behavior which will return + # duplicated indices. + # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns + # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 + # Alternatively we could wrap around, but the current behavior is closer to + # the expected "equally spaced indices" sampling. clip_start_indices = torch.linspace( sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 6fb3c912a..2df7bf79f 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -168,7 +168,7 @@ def test_random_sampler_randomness(): @pytest.mark.parametrize( "num_clips, sampling_range_size", ( - # Ask for 50 clips while the sampling range is 10 frames wide + # Ask for 10 clips while the sampling range is 10 frames wide # expect 10 clips with 10 unique starting points. (10, 10), # Ask for 50 clips while the sampling range is only 10 frames wide From a97afe7a76fac144e4f6962a7ab16007b5c48df2 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 08:36:22 -0700 Subject: [PATCH 07/22] Samplers: add support for edge-case policies --- src/torchcodec/samplers/_implem.py | 267 +++++++++++++---------------- test/samplers/test_samplers.py | 15 +- 2 files changed, 119 insertions(+), 163 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index b3dea8030..7790682e9 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -1,13 +1,12 @@ -import random -from typing import Any, Callable, List, Optional +from typing import Callable, List, Literal, Optional import torch -from torchcodec.decoders import FrameBatch, VideoDecoder +from torchcodec.decoders import Frame, FrameBatch, VideoDecoder def _validate_params( - *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames + *, decoder, num_clips, num_frames_per_clip, num_indices_between_frames, policy ): if len(decoder) < 1: raise ValueError( @@ -25,21 +24,26 @@ def _validate_params( f"num_indices_between_frames ({num_indices_between_frames}) must be strictly positive" ) + if policy not in _POLICY_FUNCTIONS.keys(): + raise ValueError( + f"Invalid policy ({policy}). Supported values are {_POLICY_FUN.keys()}." + ) + def _validate_sampling_range( - *, sampling_range_start, sampling_range_end, num_frames, clip_span + *, sampling_range_start, sampling_range_end, num_frames_in_video, clip_span ): if sampling_range_start < 0: - sampling_range_start = num_frames + sampling_range_start + sampling_range_start = num_frames_in_video + sampling_range_start - if sampling_range_start >= num_frames: + if sampling_range_start >= num_frames_in_video: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be smaller than " - f"the number of frames ({num_frames})." + f"the number of frames ({num_frames_in_video})." ) if sampling_range_end is None: - sampling_range_end = num_frames - clip_span + 1 + sampling_range_end = max(num_frames_in_video - clip_span + 1, 1) if sampling_range_start >= sampling_range_end: raise ValueError( f"We determined that sampling_range_end should be {sampling_range_end}, " @@ -49,8 +53,8 @@ def _validate_sampling_range( else: if sampling_range_end < 0: # Support negative values so that -1 means last frame. - sampling_range_end = num_frames + sampling_range_end - sampling_range_end = min(sampling_range_end, num_frames) + sampling_range_end = num_frames_in_video + sampling_range_end + sampling_range_end = min(sampling_range_end, num_frames_in_video) if sampling_range_start >= sampling_range_end: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be smaller than " @@ -74,126 +78,105 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): return num_indices_between_frames * (num_frames_per_clip - 1) + 1 -def _repeat_last_policy(clip_indices, *, num_frames_per_clip): - clip_indices += [clip_indices[-1]] * (num_frames_per_clip - len(clip_indices)) - return clip_indices +def _repeat_last_policy( + frame_indices: list[int], *, num_frames_per_clip: int +) -> list[int]: + # frame_indices = [1, 2, 3], num_frames_per_clip = 5 + # output = [1, 2, 3, 3, 3] + frame_indices += [frame_indices[-1]] * (num_frames_per_clip - len(frame_indices)) + return frame_indices -def _wrap_policy(clip_indices, *, num_frames_per_clip): - return (clip_indices * (num_frames_per_clip // len(clip_indices) + 1))[:num_frames_per_clip] +def _wrap_policy(frame_indices: list[int], *, num_frames_per_clip: int) -> list[int]: + # frame_indices = [1, 2, 3], num_frames_per_clip = 5 + # output = [1, 2, 3, 1, 2] + return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[ + :num_frames_per_clip + ] -def _error_policy(clip_indices, **kwargs): - raise ValueError("TODO") +def _error_policy(**kwargs): + raise ValueError("TODO nice error message here") -_POLICY_FUNCTIONS: dict[str, Callable[Any, list[int]]] = { +_POLICY_FUNCTION_TYPE = Callable[[list[int], int], list[int]] +_POLICY_FUNCTIONS: dict[str, _POLICY_FUNCTION_TYPE] = { "repeat_last": _repeat_last_policy, "wrap": _wrap_policy, "error": _error_policy, } -# TODO: What is sampling_range_end? -# - The upper bound of where a clip can *start* -# - The upper bound of where a clip can *end* -# ? -# Assuming this is where it can start... for now. -def clips_at_random_indices( - decoder: VideoDecoder, +def _build_all_clips_indices( *, - num_clips: int = 1, - num_frames_per_clip: int = 1, - num_indices_between_frames: int = 1, - sampling_range_start: int = 0, - sampling_range_end: Optional[int] = None, # interval is [start, end). - policy: str = "repeat_last", # can also be: "wrap", "error" TODO: use Literal -) -> List[FrameBatch]: - - _validate_params( - decoder=decoder, - num_clips=num_clips, - num_frames_per_clip=num_frames_per_clip, - num_indices_between_frames=num_indices_between_frames, - ) + clip_start_indices: list[int], + num_frames_per_clip: int, + num_indices_between_frames: int, + num_frames_in_video: int, + policy_fun: _POLICY_FUNCTION_TYPE, +) -> list[int]: + # From the clip start indices (f_00, f10, f20, ...) + # and from the rest of the parameters, return the list of all the frame + # indices, within all clips. + # I.e. the output is [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...] + # where f_01 is the index of frame 1 in clip 0. + # + # All clips in the output are of length num_frames_per_clip (=4 in example + # above). When the frame indices go beyond the video, we force the frame + # indices back to valid values by applying the user's policy (wrap, repeat, + # etc.). + all_clips_indices: list[int] = [] clip_span = _get_clip_span( num_indices_between_frames=num_indices_between_frames, num_frames_per_clip=num_frames_per_clip, ) - # TODO: We should probably not error. - if clip_span > len(decoder): - raise ValueError( - f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" + for start_index in clip_start_indices: + frame_index_upper_bound = min(start_index + clip_span, num_frames_in_video) + frame_indices = list( + range(start_index, frame_index_upper_bound, num_indices_between_frames) ) + if len(frame_indices) < num_frames_per_clip: + frame_indices = policy_fun( + frame_indices, num_frames_per_clip=num_frames_per_clip + ) + all_clips_indices += frame_indices + return all_clips_indices - sampling_range_start, sampling_range_end = _validate_sampling_range( - sampling_range_start=sampling_range_start, - sampling_range_end=sampling_range_end, - num_frames=len(decoder), - clip_span=clip_span, - ) - clip_start_indices = torch.randint( - low=sampling_range_start, high=sampling_range_end, size=(num_clips,) - ) +def _decode_all_clips_indices( + decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int +) -> list[FrameBatch]: - if policy not in _POLICY_FUNCTIONS.keys(): - raise ValueError( - f"Invalid policy ({policy}). Supported values are {_POLICY_FUN.keys()}." - ) + def chunk_list(lst, chunk_size): + # return list of sublists of length chunk_size + return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] - all_clip_indices: list[int] = [] - - print(f"{len(decoder) = }") - print(f"{clip_span = }") - for start_index in clip_start_indices.tolist(): - print(f"{start_index = }") - upper_bound = min(start_index + clip_span, sampling_range_end) - clip_indices = list(range(start_index, upper_bound, num_indices_between_frames)) - print(f"{clip_indices = }") - if len(clip_indices) < num_frames_per_clip: - policy_fun = _POLICY_FUNCTIONS[policy] - clip_indices = policy_fun(clip_indices, num_frames_per_clip=num_frames_per_clip) - print(f"{clip_indices = }") - all_clip_indices += clip_indices - # if start_index + clip_span < sampling_range_end: - # clip_indices += list(range(start_index + 1, start_index + clip_span, num_indices_between_frames)) - # clip_indices.append() - # print(f"{index = }") - # print(f"{index + clip_span = }") - - # We want to avoid seeking backwards, so we sort the clip start indices - # before decoding the frames, and then re-shuffle the clips afterwards. - # Backward seeks may still happen if there are overlapping clips, i.e. if a - # clip ends after the next one starts. - # TODO: We should use a different strategy to avoid backward seeks: - # - flatten all frames indices, irrespective of their clip - # - sort the indices and dedup - # - decode all frames in index order - # - re-arrange the frames back into their original clips - clip_start_indices = torch.sort(clip_start_indices).values - clips = [ - decoder.get_frames_at( - start=clip_start_index, - stop=clip_start_index + clip_span, - step=num_indices_between_frames, + def to_framebatch(frames: list[Frame]) -> FrameBatch: + data = torch.stack([frame.data for frame in frames]) + pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) + duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) + return FrameBatch( + data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds ) - for clip_start_index in clip_start_indices - ] - # This an ugly way to shuffle the clips using pytorch RNG *without* - # affecting the python builtin RNG. - builtin_random_state = random.getstate() - random.seed(torch.randint(0, 2**32, (1,)).item()) - random.shuffle(clips) - random.setstate(builtin_random_state) + all_frames: list[Frame] = [ + decoder.get_frame_at(index) for index in all_clips_indices + ] + all_clips: list[list[Frame]] = chunk_list( + all_frames, chunk_size=num_frames_per_clip + ) - return clips + return [to_framebatch(clip) for clip in all_clips] -def clips_at_regular_indices( +# TODO: What is sampling_range_end? +# - The upper bound of where a clip can *start* +# - The upper bound of where a clip can *end* +# ? +# This has to be the upper bound of where a clip can start... right? +def clips_at_random_indices( decoder: VideoDecoder, *, num_clips: int = 1, @@ -201,6 +184,7 @@ def clips_at_regular_indices( num_indices_between_frames: int = 1, sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: _validate_params( @@ -208,6 +192,7 @@ def clips_at_regular_indices( num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, + policy=policy, ) clip_span = _get_clip_span( @@ -215,43 +200,29 @@ def clips_at_regular_indices( num_frames_per_clip=num_frames_per_clip, ) - # TODO: We should probably not error. - if clip_span > len(decoder): - raise ValueError( - f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" - ) - sampling_range_start, sampling_range_end = _validate_sampling_range( sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, - num_frames=len(decoder), + num_frames_in_video=len(decoder), clip_span=clip_span, ) - # Note [num clips larger than sampling range] - # If we ask for more clips than there are frames in the sampling range or - # in the video, we rely on torch.linspace behavior which will return - # duplicated indices. - # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns - # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 - # Alternatively we could wrap around, but the current behavior is closer to - # the expected "equally spaced indices" sampling. - clip_start_indices = torch.linspace( - sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int + clip_start_indices = torch.randint( + low=sampling_range_start, high=sampling_range_end, size=(num_clips,) ) - # Similarly to clip_at_random_indices, there may be backward seeks if clips overlap. - # See other TODO over there, and apply similar changes here. - clips = [ - decoder.get_frames_at( - start=clip_start_index, - stop=clip_start_index + clip_span, - step=num_indices_between_frames, - ) - for clip_start_index in clip_start_indices - ] - - return clips + all_clips_indices = _build_all_clips_indices( + clip_start_indices=clip_start_indices, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + num_frames_in_video=len(decoder), + policy_fun=_POLICY_FUNCTIONS[policy], + ) + return _decode_all_clips_indices( + decoder, + all_clips_indices=all_clips_indices, + num_frames_per_clip=num_frames_per_clip, + ) def clips_at_regular_indices( @@ -262,6 +233,7 @@ def clips_at_regular_indices( num_indices_between_frames: int = 1, sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: _validate_params( @@ -269,6 +241,7 @@ def clips_at_regular_indices( num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, + policy=policy, ) clip_span = _get_clip_span( @@ -276,16 +249,10 @@ def clips_at_regular_indices( num_frames_per_clip=num_frames_per_clip, ) - # TODO: We should probably not error. - if clip_span > len(decoder): - raise ValueError( - f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})" - ) - sampling_range_start, sampling_range_end = _validate_sampling_range( sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, - num_frames=len(decoder), + num_frames_in_video=len(decoder), clip_span=clip_span, ) @@ -301,15 +268,15 @@ def clips_at_regular_indices( sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int ) - # Similarly to clip_at_random_indices, there may be backward seeks if clips overlap. - # See other TODO over there, and apply similar changes here. - clips = [ - decoder.get_frames_at( - start=clip_start_index, - stop=clip_start_index + clip_span, - step=num_indices_between_frames, - ) - for clip_start_index in clip_start_indices - ] - - return clips + all_clips_indices = _build_all_clips_indices( + clip_start_indices=clip_start_indices, + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, + num_frames_in_video=len(decoder), + policy_fun=_POLICY_FUNCTIONS[policy], + ) + return _decode_all_clips_indices( + decoder, + all_clips_indices=all_clips_indices, + num_frames_per_clip=num_frames_per_clip, + ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 2df7bf79f..dae8bfa56 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -50,11 +50,12 @@ def test_sampler(sampler, num_indices_between_frames): # Check the num_indices_between_frames parameter by asserting that the # "time" difference between frames in a clip is the same as the "index" # distance. + avg_distance_between_frames_seconds = torch.concat( [clip.pts_seconds.diff() for clip in clips] ).mean() assert avg_distance_between_frames_seconds == pytest.approx( - num_indices_between_frames / decoder.metadata.average_fps + num_indices_between_frames / decoder.metadata.average_fps, abs=1e-5 ) @@ -215,18 +216,6 @@ def test_random_sampler_errors(sampler): ): sampler(decoder, num_indices_between_frames=0) - with pytest.raises( - ValueError, - match=re.escape("Clip span (1000) is larger than the number of frames"), - ): - sampler(decoder, num_frames_per_clip=1000) - - with pytest.raises( - ValueError, - match=re.escape("Clip span (1001) is larger than the number of frames"), - ): - sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000) - with pytest.raises( ValueError, match=re.escape("sampling_range_start (1000) must be smaller than") ): From 83c6763e3477f1a9372d8213d10dafd64f49e7f7 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 4 Oct 2024 08:53:15 -0700 Subject: [PATCH 08/22] Refactor + comments --- src/torchcodec/samplers/_implem.py | 108 ++++++++++++++++------------- 1 file changed, 58 insertions(+), 50 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 7790682e9..83c05de37 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -115,16 +115,16 @@ def _build_all_clips_indices( num_frames_in_video: int, policy_fun: _POLICY_FUNCTION_TYPE, ) -> list[int]: - # From the clip start indices (f_00, f10, f20, ...) + # From the clip_start_indices [f_00, f10, f20, ...] # and from the rest of the parameters, return the list of all the frame - # indices, within all clips. + # indices that make up all the clips. # I.e. the output is [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...] # where f_01 is the index of frame 1 in clip 0. # # All clips in the output are of length num_frames_per_clip (=4 in example - # above). When the frame indices go beyond the video, we force the frame - # indices back to valid values by applying the user's policy (wrap, repeat, - # etc.). + # above). When the frame indices go beyond num_frames_in_video, we force the + # frame indices back to valid values by applying the user's policy (wrap, + # repeat, etc.). all_clips_indices: list[int] = [] clip_span = _get_clip_span( @@ -148,6 +148,12 @@ def _build_all_clips_indices( def _decode_all_clips_indices( decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int ) -> list[FrameBatch]: + # This takes the list of all the frames to decode, decode all the frames, + # and then packs them into clips of length num_frames_per_clip. + # This is slow, unoptimized, and u.g.l.y. It is not meant to stay. + # TODO: + # - sort the frames to avoid backward seeks, dedup, decode, and re-organize frames. + # - write most of this in C++ def chunk_list(lst, chunk_size): # return list of sublists of length chunk_size @@ -161,22 +167,18 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds ) - all_frames: list[Frame] = [ + all_decoded_frames: list[Frame] = [ decoder.get_frame_at(index) for index in all_clips_indices ] all_clips: list[list[Frame]] = chunk_list( - all_frames, chunk_size=num_frames_per_clip + all_decoded_frames, chunk_size=num_frames_per_clip ) return [to_framebatch(clip) for clip in all_clips] -# TODO: What is sampling_range_end? -# - The upper bound of where a clip can *start* -# - The upper bound of where a clip can *end* -# ? -# This has to be the upper bound of where a clip can start... right? -def clips_at_random_indices( +def _abstract_sampler( + kind: Literal["random", "regular"], decoder: VideoDecoder, *, num_clips: int = 1, @@ -184,6 +186,8 @@ def clips_at_random_indices( num_indices_between_frames: int = 1, sampling_range_start: int = 0, sampling_range_end: Optional[int] = None, # interval is [start, end). + # Important note: sampling_range_end defines the upper bound of where a clip + # can *start*, not where a clip can end. policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: @@ -207,9 +211,25 @@ def clips_at_random_indices( clip_span=clip_span, ) - clip_start_indices = torch.randint( - low=sampling_range_start, high=sampling_range_end, size=(num_clips,) - ) + if kind == "random": + clip_start_indices = torch.randint( + low=sampling_range_start, high=sampling_range_end, size=(num_clips,) + ) + else: + # Note [num clips larger than sampling range] + # If we ask for more clips than there are frames in the sampling range or + # in the video, we rely on torch.linspace behavior which will return + # duplicated indices. + # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns + # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 + # Alternatively we could wrap around, but the current behavior is closer to + # the expected "equally spaced indices" sampling. + clip_start_indices = torch.linspace( + sampling_range_start, + sampling_range_end - 1, + steps=num_clips, + dtype=torch.int, + ) all_clips_indices = _build_all_clips_indices( clip_start_indices=clip_start_indices, @@ -225,7 +245,7 @@ def clips_at_random_indices( ) -def clips_at_regular_indices( +def clips_at_random_indices( decoder: VideoDecoder, *, num_clips: int = 1, @@ -235,48 +255,36 @@ def clips_at_regular_indices( sampling_range_end: Optional[int] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: - - _validate_params( + return _abstract_sampler( + kind="random", decoder=decoder, num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, - policy=policy, - ) - - clip_span = _get_clip_span( - num_indices_between_frames=num_indices_between_frames, - num_frames_per_clip=num_frames_per_clip, - ) - - sampling_range_start, sampling_range_end = _validate_sampling_range( sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, - num_frames_in_video=len(decoder), - clip_span=clip_span, + policy=policy, ) - # Note [num clips larger than sampling range] - # If we ask for more clips than there are frames in the sampling range or - # in the video, we rely on torch.linspace behavior which will return - # duplicated indices. - # E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns - # 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10 - # Alternatively we could wrap around, but the current behavior is closer to - # the expected "equally spaced indices" sampling. - clip_start_indices = torch.linspace( - sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int - ) - all_clips_indices = _build_all_clips_indices( - clip_start_indices=clip_start_indices, +def clips_at_regular_indices( + decoder: VideoDecoder, + *, + num_clips: int = 1, + num_frames_per_clip: int = 1, + num_indices_between_frames: int = 1, + sampling_range_start: int = 0, + sampling_range_end: Optional[int] = None, # interval is [start, end). + policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", +) -> List[FrameBatch]: + + return _abstract_sampler( + kind="regular", + decoder=decoder, + num_clips=num_clips, num_frames_per_clip=num_frames_per_clip, num_indices_between_frames=num_indices_between_frames, - num_frames_in_video=len(decoder), - policy_fun=_POLICY_FUNCTIONS[policy], - ) - return _decode_all_clips_indices( - decoder, - all_clips_indices=all_clips_indices, - num_frames_per_clip=num_frames_per_clip, + sampling_range_start=sampling_range_start, + sampling_range_end=sampling_range_end, + policy=policy, ) From 71a839a40bb0e39ce64b272ea565aea04ed34c33 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 03:21:21 -0700 Subject: [PATCH 09/22] Add tests --- src/torchcodec/samplers/_implem.py | 20 ++-- test/samplers/test_samplers.py | 155 +++++++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 83c05de37..99034b277 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -26,7 +26,7 @@ def _validate_params( if policy not in _POLICY_FUNCTIONS.keys(): raise ValueError( - f"Invalid policy ({policy}). Supported values are {_POLICY_FUN.keys()}." + f"Invalid policy ({policy}). Supported values are {_POLICY_FUNCTIONS.keys()}." ) @@ -79,7 +79,7 @@ def _get_clip_span(*, num_indices_between_frames, num_frames_per_clip): def _repeat_last_policy( - frame_indices: list[int], *, num_frames_per_clip: int + frame_indices: list[int], num_frames_per_clip: int ) -> list[int]: # frame_indices = [1, 2, 3], num_frames_per_clip = 5 # output = [1, 2, 3, 3, 3] @@ -87,7 +87,7 @@ def _repeat_last_policy( return frame_indices -def _wrap_policy(frame_indices: list[int], *, num_frames_per_clip: int) -> list[int]: +def _wrap_policy(frame_indices: list[int], num_frames_per_clip: int) -> list[int]: # frame_indices = [1, 2, 3], num_frames_per_clip = 5 # output = [1, 2, 3, 1, 2] return (frame_indices * (num_frames_per_clip // len(frame_indices) + 1))[ @@ -95,8 +95,12 @@ def _wrap_policy(frame_indices: list[int], *, num_frames_per_clip: int) -> list[ ] -def _error_policy(**kwargs): - raise ValueError("TODO nice error message here") +def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[int]: + raise ValueError( + "You set the 'error' policy, and the sampler tried the decode a frame " + "that is beyond the number of frames in the video. " + "Try to leave sampling_range_end to its default value?" + ) _POLICY_FUNCTION_TYPE = Callable[[list[int], int], list[int]] @@ -109,7 +113,7 @@ def _error_policy(**kwargs): def _build_all_clips_indices( *, - clip_start_indices: list[int], + clip_start_indices: torch.Tensor, # 1D int tensor num_frames_per_clip: int, num_indices_between_frames: int, num_frames_in_video: int, @@ -138,9 +142,7 @@ def _build_all_clips_indices( range(start_index, frame_index_upper_bound, num_indices_between_frames) ) if len(frame_indices) < num_frames_per_clip: - frame_indices = policy_fun( - frame_indices, num_frames_per_clip=num_frames_per_clip - ) + frame_indices = policy_fun(frame_indices, num_frames_per_clip) all_clips_indices += frame_indices return all_clips_indices diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index dae8bfa56..fdf446096 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -1,12 +1,14 @@ import contextlib import random import re +from collections import Counter import pytest import torch from torchcodec.decoders import FrameBatch, VideoDecoder from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices +from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS from ..utils import assert_tensor_equal, NASA_VIDEO @@ -132,6 +134,80 @@ def test_sampling_range_negative(sampler): assert_tensor_equal(clip.data, clips_1[0].data) +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_sampling_range_default_behavior(sampler): + # This is a functional test for the default behavior of the + # sampling_range_end parameter. By default it's None, which means the + # sampler automatically sets its value such that we never sample "beyond" + # the number of frames in the video. That means that the last few frames of + # the video are less likely to be part of a clip. + # When sampling_range_end is set manually to e.g. len(decoder), the last + # frames are way more likely to be part of a clip, since there is no + # restriction on the sampling range (and the user-defined policy comes into + # action, potentially repeating that last frame). + # + # In this test we assert that the last sampled frame occurs significantly + # more often when sampling_range_end=len(decoder) than when it's None. + # This is only a proxy, for lack of better testing oppportunities. + + torch.manual_seed(0) + + decoder = VideoDecoder(NASA_VIDEO.path) + + num_clips = 10 + num_frames_per_clip = 5 + sampling_range_start = -10 + + # with default sampling_range_end value + clips_default = sampler( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + sampling_range_start=sampling_range_start, + sampling_range_end=None, + ) + + all_pts_default = torch.concat( + [clip.pts_seconds for clip in clips_default] + ).tolist() + largest_pts_default = max(all_pts_default) + largest_pts_counts_default = Counter(all_pts_default)[largest_pts_default] + + # with manual sampling_range_end value set to last frame + clips_manual = sampler( + decoder, + num_clips=num_clips, + num_frames_per_clip=num_frames_per_clip, + sampling_range_start=sampling_range_start, + sampling_range_end=len(decoder), + ) + all_pts_manual = torch.concat([clip.pts_seconds for clip in clips_manual]).tolist() + largest_pts_manual = max(all_pts_manual) + largest_pts_counts_manual = Counter(all_pts_manual)[largest_pts_manual] + + # Assert that the probability of occurence of the "last" sampled frame is + # way higher when setting sampling_range_end=len(decoder) than when relying + # on the default behavior. + # Note: ideally we would directly assert the number of occurrences of the + # last frame based on its index, but our APIs don't return indices, only pts + # (yet?) + assert largest_pts_counts_default < 2 + assert largest_pts_counts_manual > 10 + + +@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) +def test_sampling_range_error_policy(sampler): + decoder = VideoDecoder(NASA_VIDEO.path) + with pytest.raises(ValueError, match="beyond the number of frames"): + sampler( + decoder, + num_frames_per_clip=10, + sampling_range_start=-1, + sampling_range_end=len(decoder), + policy="error", + ) + + def test_random_sampler_randomness(): decoder = VideoDecoder(NASA_VIDEO.path) num_clips = 5 @@ -199,6 +275,8 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) def test_random_sampler_errors(sampler): + torch.manual_seed(0) + decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( ValueError, match=re.escape("num_clips (0) must be strictly positive") @@ -240,3 +318,80 @@ def test_random_sampler_errors(sampler): sampling_range_start=len(decoder) - 1, sampling_range_end=None, ) + + with pytest.raises(ValueError, match="Invalid policy"): + sampler(decoder, policy="BAD") + + +class TestPolicy: + @pytest.mark.parametrize( + "policy, frame_indices, expected_frame_indices", + ( + ("repeat_last", [1, 2, 3], [1, 2, 3, 3, 3]), + ("repeat_last", [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]), + ("wrap", [1, 2, 3], [1, 2, 3, 1, 2]), + ("wrap", [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]), + ), + ) + def test_policy(self, policy, frame_indices, expected_frame_indices): + policy_fun = _POLICY_FUNCTIONS[policy] + assert ( + policy_fun(frame_indices, num_frames_per_clip=5) == expected_frame_indices + ) + + def test_error_policy(self): + with pytest.raises(ValueError, match="beyond the number of frames"): + _POLICY_FUNCTIONS["error"]([1, 2, 3], num_frames_per_clip=5) + + +@pytest.mark.parametrize( + "clip_start_indices, num_indices_between_frames, policy, expected_all_clips_indices", + ( + ( + [0, 1, 2], # clip_start_indices + 1, # num_indices_between_frames + "repeat_last", # policy + # expected_all_clips_indices = + [0, 1, 2, 3, 4] + [1, 2, 3, 4, 4] + [2, 3, 4, 4, 4], + ), + # Same as above but with num_indices_between_frames=2 + ( + [0, 1, 2], # clip_start_indices + 2, # num_indices_between_frames + "repeat_last", # policy + # expected_all_clips_indices = + [0, 2, 4, 4, 4] + [1, 3, 3, 3, 3] + [2, 4, 4, 4, 4], + ), + # Same tests as above, for wrap policy + ( + [0, 1, 2], # clip_start_indices + 1, # num_indices_between_frames + "wrap", # policy + # expected_all_clips_indices = + [0, 1, 2, 3, 4] + [1, 2, 3, 4, 1] + [2, 3, 4, 2, 3], + ), + ( + [0, 1, 2], # clip_start_indices + 2, # num_indices_between_frames + "wrap", # policy + # expected_all_clips_indices = + [0, 2, 4, 0, 2] + [1, 3, 1, 3, 1] + [2, 4, 2, 4, 2], + ), + ), +) +def test_build_all_clips_indices( + clip_start_indices, num_indices_between_frames, policy, expected_all_clips_indices +): + NUM_FRAMES_PER_CLIP = 5 + all_clips_indices = _build_all_clips_indices( + clip_start_indices=clip_start_indices, + num_frames_per_clip=5, + num_indices_between_frames=num_indices_between_frames, + num_frames_in_video=5, + policy_fun=_POLICY_FUNCTIONS[policy], + ) + + assert isinstance(all_clips_indices, list) + assert all(isinstance(index, int) for index in all_clips_indices) + assert len(all_clips_indices) == len(clip_start_indices) * NUM_FRAMES_PER_CLIP + assert all_clips_indices == expected_all_clips_indices From a333e9b76376cbd9de997dc97033382f1c2bbd71 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 05:07:20 -0700 Subject: [PATCH 10/22] Minor clip_span refactoring --- src/torchcodec/samplers/_implem.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 99034b277..9b2ebbd9b 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -31,7 +31,12 @@ def _validate_params( def _validate_sampling_range( - *, sampling_range_start, sampling_range_end, num_frames_in_video, clip_span + *, + num_indices_between_frames, + num_frames_per_clip, + sampling_range_start, + sampling_range_end, + num_frames_in_video, ): if sampling_range_start < 0: sampling_range_start = num_frames_in_video + sampling_range_start @@ -42,6 +47,11 @@ def _validate_sampling_range( f"the number of frames ({num_frames_in_video})." ) + clip_span = _get_clip_span( + num_indices_between_frames=num_indices_between_frames, + num_frames_per_clip=num_frames_per_clip, + ) + if sampling_range_end is None: sampling_range_end = max(num_frames_in_video - clip_span + 1, 1) if sampling_range_start >= sampling_range_end: @@ -201,16 +211,12 @@ def _abstract_sampler( policy=policy, ) - clip_span = _get_clip_span( - num_indices_between_frames=num_indices_between_frames, - num_frames_per_clip=num_frames_per_clip, - ) - sampling_range_start, sampling_range_end = _validate_sampling_range( + num_frames_per_clip=num_frames_per_clip, + num_indices_between_frames=num_indices_between_frames, sampling_range_start=sampling_range_start, sampling_range_end=sampling_range_end, num_frames_in_video=len(decoder), - clip_span=clip_span, ) if kind == "random": From 08833b04c7939362bbca821dd8a11b588c260059 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 05:10:20 -0700 Subject: [PATCH 11/22] Don't add defaults to private abstract sampler --- src/torchcodec/samplers/_implem.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 9b2ebbd9b..4fb86201b 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -193,14 +193,14 @@ def _abstract_sampler( kind: Literal["random", "regular"], decoder: VideoDecoder, *, - num_clips: int = 1, - num_frames_per_clip: int = 1, - num_indices_between_frames: int = 1, - sampling_range_start: int = 0, - sampling_range_end: Optional[int] = None, # interval is [start, end). + num_clips: int, + num_frames_per_clip: int, + num_indices_between_frames: int, + sampling_range_start: int, + sampling_range_end: Optional[int], # interval is [start, end). # Important note: sampling_range_end defines the upper bound of where a clip # can *start*, not where a clip can end. - policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", + policy: Literal["repeat_last", "wrap", "error"], ) -> List[FrameBatch]: _validate_params( From 3dcbe1e4b52980fadbd8fa4a2922fecd8304c9c9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 07:17:47 -0700 Subject: [PATCH 12/22] Speed-up samplers by avoiding backwards seeks --- src/torchcodec/samplers/_implem.py | 36 ++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 4fb86201b..ebb909669 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -160,12 +160,15 @@ def _build_all_clips_indices( def _decode_all_clips_indices( decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int ) -> list[FrameBatch]: - # This takes the list of all the frames to decode, decode all the frames, - # and then packs them into clips of length num_frames_per_clip. - # This is slow, unoptimized, and u.g.l.y. It is not meant to stay. - # TODO: - # - sort the frames to avoid backward seeks, dedup, decode, and re-organize frames. - # - write most of this in C++ + # This takes the list of all the frames to decode (in arbitrary order), + # decode all the frames, and then packs them into clips of length + # num_frames_per_clip. + # + # To avoid backwards seeks (which are slow), we: + # - sort all the frame indices to be decoded + # - dedup them + # - decode all unique frames in sorted order + # - re-assemble the decoded frames back to their original order def chunk_list(lst, chunk_size): # return list of sublists of length chunk_size @@ -179,9 +182,24 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds ) - all_decoded_frames: list[Frame] = [ - decoder.get_frame_at(index) for index in all_clips_indices - ] + all_clips_indices_sorted, argsort = zip( + *sorted((j, i) for (i, j) in enumerate(all_clips_indices)) + ) + previous_decoded_frame = None + all_decoded_frames = [None] * len(all_clips_indices) + for i, j in enumerate(argsort): + frame_index = all_clips_indices_sorted[i] + if ( + previous_decoded_frame is not None + and frame_index == all_clips_indices_sorted[i - 1] + ): + # Avoid decoding the same frame twice. + decoded_frame = previous_decoded_frame + else: + decoded_frame = decoder.get_frame_at(index=frame_index) + previous_decoded_frame = decoded_frame + all_decoded_frames[j] = decoded_frame + all_clips: list[list[Frame]] = chunk_list( all_decoded_frames, chunk_size=num_frames_per_clip ) From 054b72efdba2a5ed5ece447760bcda431acf6d13 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 07:50:54 -0700 Subject: [PATCH 13/22] abstract -> generic --- src/torchcodec/samplers/_implem.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 4fb86201b..3ad07a71d 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -189,7 +189,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: return [to_framebatch(clip) for clip in all_clips] -def _abstract_sampler( +def _generic_sampler( kind: Literal["random", "regular"], decoder: VideoDecoder, *, @@ -263,7 +263,7 @@ def clips_at_random_indices( sampling_range_end: Optional[int] = None, # interval is [start, end). policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: - return _abstract_sampler( + return _generic_sampler( kind="random", decoder=decoder, num_clips=num_clips, @@ -286,7 +286,7 @@ def clips_at_regular_indices( policy: Literal["repeat_last", "wrap", "error"] = "repeat_last", ) -> List[FrameBatch]: - return _abstract_sampler( + return _generic_sampler( kind="regular", decoder=decoder, num_clips=num_clips, From 2bb1d58c2a20a7255d760843d575b2416e139f9f Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 07:51:10 -0700 Subject: [PATCH 14/22] typo fix --- src/torchcodec/samplers/_implem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 3ad07a71d..333bc7f80 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -107,7 +107,7 @@ def _wrap_policy(frame_indices: list[int], num_frames_per_clip: int) -> list[int def _error_policy(frames_indices: list[int], num_frames_per_clip: int) -> list[int]: raise ValueError( - "You set the 'error' policy, and the sampler tried the decode a frame " + "You set the 'error' policy, and the sampler tried to decode a frame " "that is beyond the number of frames in the video. " "Try to leave sampling_range_end to its default value?" ) From 9b932149a78b456bff00c77e3af31bb6825217a0 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 07:55:20 -0700 Subject: [PATCH 15/22] Typo --- src/torchcodec/samplers/_implem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 333bc7f80..0844c5fd2 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -129,7 +129,7 @@ def _build_all_clips_indices( num_frames_in_video: int, policy_fun: _POLICY_FUNCTION_TYPE, ) -> list[int]: - # From the clip_start_indices [f_00, f10, f20, ...] + # From the clip_start_indices [f_00, f_10, f_20, ...] # and from the rest of the parameters, return the list of all the frame # indices that make up all the clips. # I.e. the output is [f_00, f_01, f_02, f_03, f_10, f_11, f_12, f_13, ...] From 75945a86a5009b58da3709cbf08d04e51de60599 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 08:05:11 -0700 Subject: [PATCH 16/22] Comment --- src/torchcodec/samplers/_implem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 222d6a665..b1758a665 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -190,7 +190,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: for i, j in enumerate(argsort): frame_index = all_clips_indices_sorted[i] if ( - previous_decoded_frame is not None + previous_decoded_frame is not None # then we know i > 0 and frame_index == all_clips_indices_sorted[i - 1] ): # Avoid decoding the same frame twice. From 0c5c53735fb233c4da209794cf1dd548ec3bf804 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 08:27:16 -0700 Subject: [PATCH 17/22] Comment --- src/torchcodec/samplers/_implem.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index b1758a665..179251453 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -169,6 +169,8 @@ def _decode_all_clips_indices( # - dedup them # - decode all unique frames in sorted order # - re-assemble the decoded frames back to their original order + # + # TODO: Write this in C++ so we can avoid the copies that happen in `to_framebatch` def chunk_list(lst, chunk_size): # return list of sublists of length chunk_size From 5814439b8d66d7eb8d405114ce9276eeaf13ab47 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 09:16:39 -0700 Subject: [PATCH 18/22] Fix merge --- test/samplers/test_samplers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index db592e10e..3496a002d 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -1,7 +1,6 @@ import contextlib import random import re -from collections import Counter import pytest import torch @@ -262,8 +261,6 @@ def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_siz @pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices)) def test_random_sampler_errors(sampler): - torch.manual_seed(0) - decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises( ValueError, match=re.escape("num_clips (0) must be strictly positive") From 6e18e0790c33892a6804e6dd46ca95489aac402a Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 09:39:28 -0700 Subject: [PATCH 19/22] slightly better index name --- src/torchcodec/samplers/_implem.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 179251453..3ae9eecf9 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -185,7 +185,7 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: ) all_clips_indices_sorted, argsort = zip( - *sorted((j, i) for (i, j) in enumerate(all_clips_indices)) + *sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices)) ) previous_decoded_frame = None all_decoded_frames = [None] * len(all_clips_indices) From 6378a3408cd3ba92ce52e33eedaf37d53412bc94 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 09:43:09 -0700 Subject: [PATCH 20/22] Add note --- src/torchcodec/samplers/_implem.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 3ae9eecf9..2c5cccde2 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -177,6 +177,7 @@ def chunk_list(lst, chunk_size): return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)] def to_framebatch(frames: list[Frame]) -> FrameBatch: + # IMPORTANT: see other IMPORTANT note below data = torch.stack([frame.data for frame in frames]) pts_seconds = torch.tensor([frame.pts_seconds for frame in frames]) duration_seconds = torch.tensor([frame.duration_seconds for frame in frames]) @@ -196,6 +197,11 @@ def to_framebatch(frames: list[Frame]) -> FrameBatch: and frame_index == all_clips_indices_sorted[i - 1] ): # Avoid decoding the same frame twice. + # IMPORTANT: this is only correct because a copy of the frame will + # happen within `to_framebatch` when we call torch.stack. + # If a copy isn't made, the same underlying memory will be used for + # the 2 consecutive frames. When we re-write this, we should make + # sure to explicitly copy the data. decoded_frame = previous_decoded_frame else: decoded_frame = decoder.get_frame_at(index=frame_index) From 9ae87ccd6010dbf4d346cd6991e86022e64f4e36 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 09:48:51 -0700 Subject: [PATCH 21/22] Add sampler benchmarking code --- benchmarks/samplers/_benchmark_samplers.py | 230 ++++++++++++++++++ benchmarks/samplers/benchmark_samplers.py | 266 ++++----------------- 2 files changed, 274 insertions(+), 222 deletions(-) create mode 100644 benchmarks/samplers/_benchmark_samplers.py diff --git a/benchmarks/samplers/_benchmark_samplers.py b/benchmarks/samplers/_benchmark_samplers.py new file mode 100644 index 000000000..a64339ecc --- /dev/null +++ b/benchmarks/samplers/_benchmark_samplers.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +# This is an old benchmark using the old, deprecated and private sampler API. + +import abc +import argparse +import importlib +import os + +import decord +import numpy as np +import torch + +import torch.utils.benchmark as benchmark +from torchcodec.samplers import ( + IndexBasedSamplerArgs, + TimeBasedSamplerArgs, + VideoArgs, + VideoClipSampler, +) +from torchmultimodal.fb.utils.video_utils import ( + ClipSamplerType, + VideoClipSampler as tmm_vcs, +) +from torchvision.datasets.video_clip_sampler import ( # @manual=//pytorch/vision:internal_datasets + TVVideoClipDecoder, + UniformClipSamplingStrategy, + VideoClipSampler as ta_vcs, +) + + +class AbstractSampler: + def __init__(self): + pass + + @abc.abstractmethod + def sample_frames_uniformly(self, video_file, clips_per_video): + pass + + +class TorchCodecTimeBasedSampler(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + video_input = VideoArgs() + sampler_input = TimeBasedSamplerArgs( + sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 + ) + sampler = VideoClipSampler(video_input, sampler_input) + return sampler(video_tensor) + + +class TorchCodecIndexBasedSampler(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + video_input = VideoArgs() + sampler_input = IndexBasedSamplerArgs( + sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 + ) + sampler = VideoClipSampler(video_input, sampler_input) + return sampler(video_tensor) + + +class TorchCodecIndexBasedSamplerWithStackedOutput(AbstractSampler): + """ + On large batch, torch stack has impact on performance, but it's not obvious locally. + """ + + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + video_input = VideoArgs() + sampler_input = IndexBasedSamplerArgs( + sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 + ) + sampler = VideoClipSampler(video_input, sampler_input) + clips = sampler(video_tensor) + return torch.stack([clip[0] for clip in clips]) + + +class DecordSampler(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + decord.bridge.set_bridge("torch") + av_reader = decord.VideoReader(video_file) + num_frames = len(av_reader) + frame_indices = np.linspace(0, num_frames - 1, clips_per_video, dtype=int) + frames = av_reader.get_batch(frame_indices) + return frames + + +class TorchMMSamplerWithTorchVisionBackend(AbstractSampler): + """ + Here we use TorchMultimodal sampler as it's updated version on top of torchvision decoder. + """ + + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + sampler = tmm_vcs( + clip_sampler_type=ClipSamplerType("UNIFORM"), + clips_per_video=clips_per_video, + frames_per_clip=1, + frame_dilation=1, + ) + return sampler(video_tensor) + + +class TorchVisionNewSamplerWithTorchVisionBackend(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + clip_sampling_strategy = UniformClipSamplingStrategy( + clips_per_video=clips_per_video + ) + decoder = TVVideoClipDecoder(clip_length_in_frames=1, read_audio_stream=False) + sampler = ta_vcs(clip_sampling_strategy, decoder) + return sampler(str(video_file)) + + +def main(): + """Benchmarks the performance of different samplers""" + + parser = argparse.ArgumentParser() + parser.add_argument( + "--bm_small_video_speed", + help="Benchmark small video decoding speed", + default=True, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--bm_large_video_speed", + help="Benchmark large video decoding speed", + default=True, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--bm_video_speed_min_run_seconds", + help="Benchmark minimum run time, in seconds, to wait per datapoint", + type=float, + default=5.0, + ) + args = parser.parse_args() + + small_video_path = importlib.resources.path(__package__, "nasa_13013.mp4") + small_video_path = os.fspath(str(small_video_path)) + + large_video_path = importlib.resources.path(__package__, "853.mp4") + large_video_path = os.fspath(str(large_video_path)) + + clips_per_video = 8 + + sampler_dict = {} + sampler_dict["TorchCodecTimeBasedSampler"] = TorchCodecTimeBasedSampler() + sampler_dict["TorchCodecIndexBasedSampler"] = TorchCodecIndexBasedSampler() + sampler_dict["TorchCodecIndexBasedSamplerWithStackedOutput"] = ( + TorchCodecIndexBasedSamplerWithStackedOutput() + ) + sampler_dict["DecordSampler"] = DecordSampler() + sampler_dict["TorchMMSamplerWithTorchVisionBackend"] = ( + TorchMMSamplerWithTorchVisionBackend() + ) + sampler_dict["TorchVisionNewSamplerWithTorchVisionBackend"] = ( + TorchVisionNewSamplerWithTorchVisionBackend() + ) + + results = [] + + for sampler_name, sampler in sampler_dict.items(): + if args.bm_small_video_speed: + sampler_result = benchmark.Timer( + stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", + globals={ + "video_file": small_video_path, + "clips_per_video": clips_per_video, + "sampler": sampler, + }, + label="uniform sampling latency for 700KB video", + sub_label=sampler_name, + description=f"uniform sampling {clips_per_video} frames", + ) + results.append( + sampler_result.blocked_autorange( + min_run_time=args.bm_video_speed_min_run_seconds + ) + ) + + if args.bm_large_video_speed: + if sampler_name == "TorchMMSamplerWithTorchVisionBackend": + continue + sampler_result = benchmark.Timer( + stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", + globals={ + "video_file": large_video_path, + "clips_per_video": clips_per_video, + "sampler": sampler, + }, + label="uniform sampling latency for 50MB video", + sub_label=sampler_name, + description=f"uniform sampling {clips_per_video} frames", + ) + results.append( + sampler_result.blocked_autorange( + min_run_time=args.bm_video_speed_min_run_seconds + ) + ) + + compare = benchmark.Compare(results) + compare.print() diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py index ed31a79a9..f34acea66 100644 --- a/benchmarks/samplers/benchmark_samplers.py +++ b/benchmarks/samplers/benchmark_samplers.py @@ -1,227 +1,49 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. +from time import perf_counter_ns -import abc -import argparse -import importlib -import os - -import decord -import numpy as np import torch - -import torch.utils.benchmark as benchmark -from torchcodec.samplers import ( - IndexBasedSamplerArgs, - TimeBasedSamplerArgs, - VideoArgs, - VideoClipSampler, -) -from torchmultimodal.fb.utils.video_utils import ( - ClipSamplerType, - VideoClipSampler as tmm_vcs, -) -from torchvision.datasets.video_clip_sampler import ( # @manual=//pytorch/vision:internal_datasets - TVVideoClipDecoder, - UniformClipSamplingStrategy, - VideoClipSampler as ta_vcs, -) - - -class AbstractSampler: - def __init__(self): - pass - - @abc.abstractmethod - def sample_frames_uniformly(self, video_file, clips_per_video): - pass - - -class TorchCodecTimeBasedSampler(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - video_input = VideoArgs() - sampler_input = TimeBasedSamplerArgs( - sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 - ) - sampler = VideoClipSampler(video_input, sampler_input) - return sampler(video_tensor) - - -class TorchCodecIndexBasedSampler(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - video_input = VideoArgs() - sampler_input = IndexBasedSamplerArgs( - sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 - ) - sampler = VideoClipSampler(video_input, sampler_input) - return sampler(video_tensor) - - -class TorchCodecIndexBasedSamplerWithStackedOutput(AbstractSampler): - """ - On large batch, torch stack has impact on performance, but it's not obvious locally. - """ - - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - video_input = VideoArgs() - sampler_input = IndexBasedSamplerArgs( - sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 - ) - sampler = VideoClipSampler(video_input, sampler_input) - clips = sampler(video_tensor) - return torch.stack([clip[0] for clip in clips]) - - -class DecordSampler(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - decord.bridge.set_bridge("torch") - av_reader = decord.VideoReader(video_file) - num_frames = len(av_reader) - frame_indices = np.linspace(0, num_frames - 1, clips_per_video, dtype=int) - frames = av_reader.get_batch(frame_indices) - return frames - - -class TorchMMSamplerWithTorchVisionBackend(AbstractSampler): - """ - Here we use TorchMultimodal sampler as it's updated version on top of torchvision decoder. - """ - - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - sampler = tmm_vcs( - clip_sampler_type=ClipSamplerType("UNIFORM"), - clips_per_video=clips_per_video, - frames_per_clip=1, - frame_dilation=1, - ) - return sampler(video_tensor) - - -class TorchVisionNewSamplerWithTorchVisionBackend(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - clip_sampling_strategy = UniformClipSamplingStrategy( - clips_per_video=clips_per_video - ) - decoder = TVVideoClipDecoder(clip_length_in_frames=1, read_audio_stream=False) - sampler = ta_vcs(clip_sampling_strategy, decoder) - return sampler(str(video_file)) - - -def main(): - """Benchmarks the performance of different samplers""" - - parser = argparse.ArgumentParser() - parser.add_argument( - "--bm_small_video_speed", - help="Benchmark small video decoding speed", - default=True, - action=argparse.BooleanOptionalAction, - ) - parser.add_argument( - "--bm_large_video_speed", - help="Benchmark large video decoding speed", - default=True, - action=argparse.BooleanOptionalAction, - ) - parser.add_argument( - "--bm_video_speed_min_run_seconds", - help="Benchmark minimum run time, in seconds, to wait per datapoint", - type=float, - default=5.0, +from torchcodec.decoders import VideoDecoder +from torchcodec.samplers import clips_at_random_indices + + +def bench(f, *args, num_exp=100, warmup=0, **kwargs): + + for _ in range(warmup): + f(*args, **kwargs) + + times = [] + for _ in range(num_exp): + start = perf_counter_ns() + f(*args, **kwargs) + end = perf_counter_ns() + times.append(end - start) + return torch.tensor(times).float() + + +def report_stats(times, unit="ms"): + mul = { + "ns": 1, + "µs": 1e-3, + "ms": 1e-6, + "s": 1e-9, + }[unit] + times = times * mul + std = times.std().item() + med = times.median().item() + print(f"{med = :.2f}{unit} +- {std:.2f}") + return med + + +def sample(num_clips): + decoder = VideoDecoder("test/resources/nasa_13013.mp4") + clips_at_random_indices( + decoder, + num_clips=num_clips, + num_frames_per_clip=10, + num_indices_between_frames=2, ) - args = parser.parse_args() - - small_video_path = importlib.resources.path(__package__, "nasa_13013.mp4") - small_video_path = os.fspath(str(small_video_path)) - - large_video_path = importlib.resources.path(__package__, "853.mp4") - large_video_path = os.fspath(str(large_video_path)) - - clips_per_video = 8 - - sampler_dict = {} - sampler_dict["TorchCodecTimeBasedSampler"] = TorchCodecTimeBasedSampler() - sampler_dict["TorchCodecIndexBasedSampler"] = TorchCodecIndexBasedSampler() - sampler_dict["TorchCodecIndexBasedSamplerWithStackedOutput"] = ( - TorchCodecIndexBasedSamplerWithStackedOutput() - ) - sampler_dict["DecordSampler"] = DecordSampler() - sampler_dict["TorchMMSamplerWithTorchVisionBackend"] = ( - TorchMMSamplerWithTorchVisionBackend() - ) - sampler_dict["TorchVisionNewSamplerWithTorchVisionBackend"] = ( - TorchVisionNewSamplerWithTorchVisionBackend() - ) - - results = [] - - for sampler_name, sampler in sampler_dict.items(): - if args.bm_small_video_speed: - sampler_result = benchmark.Timer( - stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", - globals={ - "video_file": small_video_path, - "clips_per_video": clips_per_video, - "sampler": sampler, - }, - label="uniform sampling latency for 700KB video", - sub_label=sampler_name, - description=f"uniform sampling {clips_per_video} frames", - ) - results.append( - sampler_result.blocked_autorange( - min_run_time=args.bm_video_speed_min_run_seconds - ) - ) - if args.bm_large_video_speed: - if sampler_name == "TorchMMSamplerWithTorchVisionBackend": - continue - sampler_result = benchmark.Timer( - stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", - globals={ - "video_file": large_video_path, - "clips_per_video": clips_per_video, - "sampler": sampler, - }, - label="uniform sampling latency for 50MB video", - sub_label=sampler_name, - description=f"uniform sampling {clips_per_video} frames", - ) - results.append( - sampler_result.blocked_autorange( - min_run_time=args.bm_video_speed_min_run_seconds - ) - ) - compare = benchmark.Compare(results) - compare.print() +times = bench(sample, num_clips=1, num_exp=30, warmup=2) +report_stats(times, unit="ms") +times = bench(sample, num_clips=50, num_exp=30, warmup=2) +report_stats(times, unit="ms") From 5eba3e4fe5e91f9d9aa20dccfe1c65bf6bbe7f95 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 7 Oct 2024 09:50:39 -0700 Subject: [PATCH 22/22] Revert "Add sampler benchmarking code" This reverts commit 9ae87ccd6010dbf4d346cd6991e86022e64f4e36. --- benchmarks/samplers/_benchmark_samplers.py | 230 ------------------ benchmarks/samplers/benchmark_samplers.py | 266 +++++++++++++++++---- 2 files changed, 222 insertions(+), 274 deletions(-) delete mode 100644 benchmarks/samplers/_benchmark_samplers.py diff --git a/benchmarks/samplers/_benchmark_samplers.py b/benchmarks/samplers/_benchmark_samplers.py deleted file mode 100644 index a64339ecc..000000000 --- a/benchmarks/samplers/_benchmark_samplers.py +++ /dev/null @@ -1,230 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - - -# This is an old benchmark using the old, deprecated and private sampler API. - -import abc -import argparse -import importlib -import os - -import decord -import numpy as np -import torch - -import torch.utils.benchmark as benchmark -from torchcodec.samplers import ( - IndexBasedSamplerArgs, - TimeBasedSamplerArgs, - VideoArgs, - VideoClipSampler, -) -from torchmultimodal.fb.utils.video_utils import ( - ClipSamplerType, - VideoClipSampler as tmm_vcs, -) -from torchvision.datasets.video_clip_sampler import ( # @manual=//pytorch/vision:internal_datasets - TVVideoClipDecoder, - UniformClipSamplingStrategy, - VideoClipSampler as ta_vcs, -) - - -class AbstractSampler: - def __init__(self): - pass - - @abc.abstractmethod - def sample_frames_uniformly(self, video_file, clips_per_video): - pass - - -class TorchCodecTimeBasedSampler(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - video_input = VideoArgs() - sampler_input = TimeBasedSamplerArgs( - sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 - ) - sampler = VideoClipSampler(video_input, sampler_input) - return sampler(video_tensor) - - -class TorchCodecIndexBasedSampler(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - video_input = VideoArgs() - sampler_input = IndexBasedSamplerArgs( - sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 - ) - sampler = VideoClipSampler(video_input, sampler_input) - return sampler(video_tensor) - - -class TorchCodecIndexBasedSamplerWithStackedOutput(AbstractSampler): - """ - On large batch, torch stack has impact on performance, but it's not obvious locally. - """ - - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - video_input = VideoArgs() - sampler_input = IndexBasedSamplerArgs( - sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 - ) - sampler = VideoClipSampler(video_input, sampler_input) - clips = sampler(video_tensor) - return torch.stack([clip[0] for clip in clips]) - - -class DecordSampler(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - decord.bridge.set_bridge("torch") - av_reader = decord.VideoReader(video_file) - num_frames = len(av_reader) - frame_indices = np.linspace(0, num_frames - 1, clips_per_video, dtype=int) - frames = av_reader.get_batch(frame_indices) - return frames - - -class TorchMMSamplerWithTorchVisionBackend(AbstractSampler): - """ - Here we use TorchMultimodal sampler as it's updated version on top of torchvision decoder. - """ - - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - arr = np.fromfile(video_file, dtype=np.uint8) - video_tensor = torch.from_numpy(arr) - sampler = tmm_vcs( - clip_sampler_type=ClipSamplerType("UNIFORM"), - clips_per_video=clips_per_video, - frames_per_clip=1, - frame_dilation=1, - ) - return sampler(video_tensor) - - -class TorchVisionNewSamplerWithTorchVisionBackend(AbstractSampler): - def __init__(self): - pass - - def sample_frames_uniformly(self, video_file, clips_per_video): - clip_sampling_strategy = UniformClipSamplingStrategy( - clips_per_video=clips_per_video - ) - decoder = TVVideoClipDecoder(clip_length_in_frames=1, read_audio_stream=False) - sampler = ta_vcs(clip_sampling_strategy, decoder) - return sampler(str(video_file)) - - -def main(): - """Benchmarks the performance of different samplers""" - - parser = argparse.ArgumentParser() - parser.add_argument( - "--bm_small_video_speed", - help="Benchmark small video decoding speed", - default=True, - action=argparse.BooleanOptionalAction, - ) - parser.add_argument( - "--bm_large_video_speed", - help="Benchmark large video decoding speed", - default=True, - action=argparse.BooleanOptionalAction, - ) - parser.add_argument( - "--bm_video_speed_min_run_seconds", - help="Benchmark minimum run time, in seconds, to wait per datapoint", - type=float, - default=5.0, - ) - args = parser.parse_args() - - small_video_path = importlib.resources.path(__package__, "nasa_13013.mp4") - small_video_path = os.fspath(str(small_video_path)) - - large_video_path = importlib.resources.path(__package__, "853.mp4") - large_video_path = os.fspath(str(large_video_path)) - - clips_per_video = 8 - - sampler_dict = {} - sampler_dict["TorchCodecTimeBasedSampler"] = TorchCodecTimeBasedSampler() - sampler_dict["TorchCodecIndexBasedSampler"] = TorchCodecIndexBasedSampler() - sampler_dict["TorchCodecIndexBasedSamplerWithStackedOutput"] = ( - TorchCodecIndexBasedSamplerWithStackedOutput() - ) - sampler_dict["DecordSampler"] = DecordSampler() - sampler_dict["TorchMMSamplerWithTorchVisionBackend"] = ( - TorchMMSamplerWithTorchVisionBackend() - ) - sampler_dict["TorchVisionNewSamplerWithTorchVisionBackend"] = ( - TorchVisionNewSamplerWithTorchVisionBackend() - ) - - results = [] - - for sampler_name, sampler in sampler_dict.items(): - if args.bm_small_video_speed: - sampler_result = benchmark.Timer( - stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", - globals={ - "video_file": small_video_path, - "clips_per_video": clips_per_video, - "sampler": sampler, - }, - label="uniform sampling latency for 700KB video", - sub_label=sampler_name, - description=f"uniform sampling {clips_per_video} frames", - ) - results.append( - sampler_result.blocked_autorange( - min_run_time=args.bm_video_speed_min_run_seconds - ) - ) - - if args.bm_large_video_speed: - if sampler_name == "TorchMMSamplerWithTorchVisionBackend": - continue - sampler_result = benchmark.Timer( - stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", - globals={ - "video_file": large_video_path, - "clips_per_video": clips_per_video, - "sampler": sampler, - }, - label="uniform sampling latency for 50MB video", - sub_label=sampler_name, - description=f"uniform sampling {clips_per_video} frames", - ) - results.append( - sampler_result.blocked_autorange( - min_run_time=args.bm_video_speed_min_run_seconds - ) - ) - - compare = benchmark.Compare(results) - compare.print() diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py index f34acea66..ed31a79a9 100644 --- a/benchmarks/samplers/benchmark_samplers.py +++ b/benchmarks/samplers/benchmark_samplers.py @@ -1,49 +1,227 @@ -from time import perf_counter_ns +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import abc +import argparse +import importlib +import os + +import decord +import numpy as np import torch -from torchcodec.decoders import VideoDecoder -from torchcodec.samplers import clips_at_random_indices - - -def bench(f, *args, num_exp=100, warmup=0, **kwargs): - - for _ in range(warmup): - f(*args, **kwargs) - - times = [] - for _ in range(num_exp): - start = perf_counter_ns() - f(*args, **kwargs) - end = perf_counter_ns() - times.append(end - start) - return torch.tensor(times).float() - - -def report_stats(times, unit="ms"): - mul = { - "ns": 1, - "µs": 1e-3, - "ms": 1e-6, - "s": 1e-9, - }[unit] - times = times * mul - std = times.std().item() - med = times.median().item() - print(f"{med = :.2f}{unit} +- {std:.2f}") - return med - - -def sample(num_clips): - decoder = VideoDecoder("test/resources/nasa_13013.mp4") - clips_at_random_indices( - decoder, - num_clips=num_clips, - num_frames_per_clip=10, - num_indices_between_frames=2, + +import torch.utils.benchmark as benchmark +from torchcodec.samplers import ( + IndexBasedSamplerArgs, + TimeBasedSamplerArgs, + VideoArgs, + VideoClipSampler, +) +from torchmultimodal.fb.utils.video_utils import ( + ClipSamplerType, + VideoClipSampler as tmm_vcs, +) +from torchvision.datasets.video_clip_sampler import ( # @manual=//pytorch/vision:internal_datasets + TVVideoClipDecoder, + UniformClipSamplingStrategy, + VideoClipSampler as ta_vcs, +) + + +class AbstractSampler: + def __init__(self): + pass + + @abc.abstractmethod + def sample_frames_uniformly(self, video_file, clips_per_video): + pass + + +class TorchCodecTimeBasedSampler(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + video_input = VideoArgs() + sampler_input = TimeBasedSamplerArgs( + sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 + ) + sampler = VideoClipSampler(video_input, sampler_input) + return sampler(video_tensor) + + +class TorchCodecIndexBasedSampler(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + video_input = VideoArgs() + sampler_input = IndexBasedSamplerArgs( + sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 + ) + sampler = VideoClipSampler(video_input, sampler_input) + return sampler(video_tensor) + + +class TorchCodecIndexBasedSamplerWithStackedOutput(AbstractSampler): + """ + On large batch, torch stack has impact on performance, but it's not obvious locally. + """ + + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + video_input = VideoArgs() + sampler_input = IndexBasedSamplerArgs( + sampler_type="uniform", clips_per_video=clips_per_video, frames_per_clip=1 + ) + sampler = VideoClipSampler(video_input, sampler_input) + clips = sampler(video_tensor) + return torch.stack([clip[0] for clip in clips]) + + +class DecordSampler(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + decord.bridge.set_bridge("torch") + av_reader = decord.VideoReader(video_file) + num_frames = len(av_reader) + frame_indices = np.linspace(0, num_frames - 1, clips_per_video, dtype=int) + frames = av_reader.get_batch(frame_indices) + return frames + + +class TorchMMSamplerWithTorchVisionBackend(AbstractSampler): + """ + Here we use TorchMultimodal sampler as it's updated version on top of torchvision decoder. + """ + + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + arr = np.fromfile(video_file, dtype=np.uint8) + video_tensor = torch.from_numpy(arr) + sampler = tmm_vcs( + clip_sampler_type=ClipSamplerType("UNIFORM"), + clips_per_video=clips_per_video, + frames_per_clip=1, + frame_dilation=1, + ) + return sampler(video_tensor) + + +class TorchVisionNewSamplerWithTorchVisionBackend(AbstractSampler): + def __init__(self): + pass + + def sample_frames_uniformly(self, video_file, clips_per_video): + clip_sampling_strategy = UniformClipSamplingStrategy( + clips_per_video=clips_per_video + ) + decoder = TVVideoClipDecoder(clip_length_in_frames=1, read_audio_stream=False) + sampler = ta_vcs(clip_sampling_strategy, decoder) + return sampler(str(video_file)) + + +def main(): + """Benchmarks the performance of different samplers""" + + parser = argparse.ArgumentParser() + parser.add_argument( + "--bm_small_video_speed", + help="Benchmark small video decoding speed", + default=True, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--bm_large_video_speed", + help="Benchmark large video decoding speed", + default=True, + action=argparse.BooleanOptionalAction, + ) + parser.add_argument( + "--bm_video_speed_min_run_seconds", + help="Benchmark minimum run time, in seconds, to wait per datapoint", + type=float, + default=5.0, ) + args = parser.parse_args() + + small_video_path = importlib.resources.path(__package__, "nasa_13013.mp4") + small_video_path = os.fspath(str(small_video_path)) + + large_video_path = importlib.resources.path(__package__, "853.mp4") + large_video_path = os.fspath(str(large_video_path)) + + clips_per_video = 8 + + sampler_dict = {} + sampler_dict["TorchCodecTimeBasedSampler"] = TorchCodecTimeBasedSampler() + sampler_dict["TorchCodecIndexBasedSampler"] = TorchCodecIndexBasedSampler() + sampler_dict["TorchCodecIndexBasedSamplerWithStackedOutput"] = ( + TorchCodecIndexBasedSamplerWithStackedOutput() + ) + sampler_dict["DecordSampler"] = DecordSampler() + sampler_dict["TorchMMSamplerWithTorchVisionBackend"] = ( + TorchMMSamplerWithTorchVisionBackend() + ) + sampler_dict["TorchVisionNewSamplerWithTorchVisionBackend"] = ( + TorchVisionNewSamplerWithTorchVisionBackend() + ) + + results = [] + + for sampler_name, sampler in sampler_dict.items(): + if args.bm_small_video_speed: + sampler_result = benchmark.Timer( + stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", + globals={ + "video_file": small_video_path, + "clips_per_video": clips_per_video, + "sampler": sampler, + }, + label="uniform sampling latency for 700KB video", + sub_label=sampler_name, + description=f"uniform sampling {clips_per_video} frames", + ) + results.append( + sampler_result.blocked_autorange( + min_run_time=args.bm_video_speed_min_run_seconds + ) + ) + if args.bm_large_video_speed: + if sampler_name == "TorchMMSamplerWithTorchVisionBackend": + continue + sampler_result = benchmark.Timer( + stmt="sampler.sample_frames_uniformly(video_file, clips_per_video)", + globals={ + "video_file": large_video_path, + "clips_per_video": clips_per_video, + "sampler": sampler, + }, + label="uniform sampling latency for 50MB video", + sub_label=sampler_name, + description=f"uniform sampling {clips_per_video} frames", + ) + results.append( + sampler_result.blocked_autorange( + min_run_time=args.bm_video_speed_min_run_seconds + ) + ) -times = bench(sample, num_clips=1, num_exp=30, warmup=2) -report_stats(times, unit="ms") -times = bench(sample, num_clips=50, num_exp=30, warmup=2) -report_stats(times, unit="ms") + compare = benchmark.Compare(results) + compare.print()