Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/samplers/_time_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def _validate_sampling_range_time_based(
if sampling_range_start is None:
sampling_range_start = begin_stream_seconds
else:
if sampling_range_start <= begin_stream_seconds:
if sampling_range_start < begin_stream_seconds:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Caught this minor bug in the process!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am glad this caught a bug and the test is present to prevent regressions.

Thanks @NicolasHug

raise ValueError(
f"sampling_range_start ({sampling_range_start}) must be at least {begin_stream_seconds}"
)
Expand Down
45 changes: 45 additions & 0 deletions test/samplers/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,51 @@ def test_time_based_sampler(sampler, seconds_between_frames):
)


@pytest.mark.parametrize(
"sampler",
(
partial(
clips_at_regular_indices,
num_clips=1,
sampling_range_start=0,
sampling_range_end=1,
),
partial(
clips_at_random_indices,
num_clips=1,
sampling_range_start=0,
sampling_range_end=1,
),
partial(
clips_at_random_timestamps,
num_clips=1,
sampling_range_start=0,
sampling_range_end=0.01,
),
partial(
clips_at_regular_timestamps,
seconds_between_clip_starts=1,
seconds_between_frames=0.0335, # forces consecutive frames
sampling_range_start=0,
sampling_range_end=0.01,
),
),
)
def test_against_ref(sampler):
# Force the sampler to sample a clip containing the first 5 frames of the
# video. We can then assert the exact frame values against our existing test
# resource reference.
decoder = VideoDecoder(NASA_VIDEO.path)

num_frames_per_clip = 5
expected_clip_data = NASA_VIDEO.get_frame_data_by_range(
start=0, stop=num_frames_per_clip
)

clip = sampler(decoder, num_frames_per_clip=num_frames_per_clip)[0]
assert_tensor_equal(clip.data, expected_clip_data)


@pytest.mark.parametrize(
"sampler, sampling_range_start, sampling_range_end, assert_all_equal",
(
Expand Down
Loading