diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index e9d485aad..8fca584c0 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -71,7 +71,7 @@ def _validate_sampling_range_time_based( if sampling_range_start is None: sampling_range_start = begin_stream_seconds else: - if sampling_range_start <= begin_stream_seconds: + if sampling_range_start < begin_stream_seconds: raise ValueError( f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}" ) diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 4a12d93c4..854efd9ca 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -148,6 +148,51 @@ def test_time_based_sampler(sampler, seconds_between_frames): ) +@pytest.mark.parametrize( + "sampler", + ( + partial( + clips_at_regular_indices, + num_clips=1, + sampling_range_start=0, + sampling_range_end=1, + ), + partial( + clips_at_random_indices, + num_clips=1, + sampling_range_start=0, + sampling_range_end=1, + ), + partial( + clips_at_random_timestamps, + num_clips=1, + sampling_range_start=0, + sampling_range_end=0.01, + ), + partial( + clips_at_regular_timestamps, + seconds_between_clip_starts=1, + seconds_between_frames=0.0335, # forces consecutive frames + sampling_range_start=0, + sampling_range_end=0.01, + ), + ), +) +def test_against_ref(sampler): + # Force the sampler to sample a clip containing the first 5 frames of the + # video. We can then assert the exact frame values against our existing test + # resource reference. + decoder = VideoDecoder(NASA_VIDEO.path) + + num_frames_per_clip = 5 + expected_clip_data = NASA_VIDEO.get_frame_data_by_range( + start=0, stop=num_frames_per_clip + ) + + clip = sampler(decoder, num_frames_per_clip=num_frames_per_clip)[0] + assert_tensor_equal(clip.data, expected_clip_data) + + @pytest.mark.parametrize( "sampler, sampling_range_start, sampling_range_end, assert_all_equal", (