From 89b2b1dcc35b1c53cc090cbb2a1936b619befb06 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 25 Oct 2024 13:45:31 +0100 Subject: [PATCH 1/2] Test samplers against reference --- src/torchcodec/samplers/_time_based.py | 2 +- test/samplers/test_samplers.py | 45 ++++++++++++++++++++++++++ 2 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index e9d485aad..8fca584c0 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -71,7 +71,7 @@ def _validate_sampling_range_time_based( if sampling_range_start is None: sampling_range_start = begin_stream_seconds else: - if sampling_range_start <= begin_stream_seconds: + if sampling_range_start < begin_stream_seconds: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}" ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 4a12d93c4..51af6d875 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -148,6 +148,51 @@ def test_time_based_sampler(sampler, seconds_between_frames): ) +@pytest.mark.parametrize( + "sampler", + ( + partial( + clips_at_regular_indices, + num_clips=1, + num_frames_per_clip=5, + sampling_range_start=0, + sampling_range_end=1, + ), + partial( + clips_at_random_indices, + num_clips=1, + num_frames_per_clip=5, + sampling_range_start=0, + sampling_range_end=1, + ), + partial( + clips_at_random_timestamps, + num_clips=1, + num_frames_per_clip=5, + sampling_range_start=0, + sampling_range_end=0.01, + ), + partial( + clips_at_regular_timestamps, + seconds_between_clip_starts=1, + seconds_between_frames=0.0335, + num_frames_per_clip=5, + sampling_range_start=0, + sampling_range_end=0.01, + ), + ), +) +def test_against_ref(sampler): + # Force the sampler to sample a clip containing the first 5 frames of the + # video. We can then assert the exact frame values against our existing test + # resource reference. + decoder = VideoDecoder(NASA_VIDEO.path) + expected_clip_data = NASA_VIDEO.get_frame_data_by_range(start=0, stop=5) + + clip = sampler(decoder)[0] + assert_tensor_equal(clip.data, expected_clip_data) + + @pytest.mark.parametrize( "sampler, sampling_range_start, sampling_range_end, assert_all_equal", ( From 868922c8568e3d8c455601181002677dc1e6fdb5 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Fri, 25 Oct 2024 13:52:09 +0100 Subject: [PATCH 2/2] Nit --- test/samplers/test_samplers.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 51af6d875..854efd9ca 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -154,29 +154,25 @@ def test_time_based_sampler(sampler, seconds_between_frames): partial( clips_at_regular_indices, num_clips=1, - num_frames_per_clip=5, sampling_range_start=0, sampling_range_end=1, ), partial( clips_at_random_indices, num_clips=1, - num_frames_per_clip=5, sampling_range_start=0, sampling_range_end=1, ), partial( clips_at_random_timestamps, num_clips=1, - num_frames_per_clip=5, sampling_range_start=0, sampling_range_end=0.01, ), partial( clips_at_regular_timestamps, seconds_between_clip_starts=1, - seconds_between_frames=0.0335, - num_frames_per_clip=5, + seconds_between_frames=0.0335, # forces consecutive frames sampling_range_start=0, sampling_range_end=0.01, ), @@ -187,9 +183,13 @@ def test_against_ref(sampler): # video. We can then assert the exact frame values against our existing test # resource reference. decoder = VideoDecoder(NASA_VIDEO.path) - expected_clip_data = NASA_VIDEO.get_frame_data_by_range(start=0, stop=5) - clip = sampler(decoder)[0] + num_frames_per_clip = 5 + expected_clip_data = NASA_VIDEO.get_frame_data_by_range( + start=0, stop=num_frames_per_clip + ) + + clip = sampler(decoder, num_frames_per_clip=num_frames_per_clip)[0] assert_tensor_equal(clip.data, expected_clip_data)