Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/torchcodec/samplers/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from ._implem import clips_at_random_indices
from ._implem import clips_at_random_indices, clips_at_regular_indices
61 changes: 61 additions & 0 deletions src/torchcodec/samplers/_implem.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,3 +140,64 @@ def clips_at_random_indices(
random.setstate(builtin_random_state)

return clips


def clips_at_regular_indices(
decoder: VideoDecoder,
*,
num_clips: int = 1,
Copy link
Contributor Author

@NicolasHug NicolasHug Oct 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like clips_at_regular_indices(num_clips=1) is an anti pattern. Same with 2. Maybe we should make 3 the default, and enforce num_clips>=3? Or not enforce, but still set the default to 3?

I don't feel too strongly about it either way - just being pedantic.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is what other libraries do we can do the same here.

I was thinking of sampling the middle of the video if num_clips=1 but that would make it slower than existing libraries

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. Other libs allow num_clips=1. Let's revisit if needed

num_frames_per_clip: int = 1,
num_indices_between_frames: int = 1,
sampling_range_start: int = 0,
sampling_range_end: Optional[int] = None, # interval is [start, end).
) -> List[FrameBatch]:

_validate_params(
decoder=decoder,
num_clips=num_clips,
num_frames_per_clip=num_frames_per_clip,
num_indices_between_frames=num_indices_between_frames,
)

clip_span = _get_clip_span(
num_indices_between_frames=num_indices_between_frames,
num_frames_per_clip=num_frames_per_clip,
)

# TODO: We should probably not error.
if clip_span > len(decoder):
raise ValueError(
f"Clip span ({clip_span}) is larger than the number of frames ({len(decoder)})"
)

sampling_range_start, sampling_range_end = _validate_sampling_range(
sampling_range_start=sampling_range_start,
sampling_range_end=sampling_range_end,
num_frames=len(decoder),
clip_span=clip_span,
)

# Note [num clips larger than sampling range]
# If we ask for more clips than there are frames in the sampling range or
# in the video, we rely on torch.linspace behavior which will return
# duplicated indices.
# E.g. torch.linspace(0, 10, steps=20, dtype=torch.int) returns
# 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 10
# Alternatively we could wrap around, but the current behavior is closer to
# the expected "equally spaced indices" sampling.
clip_start_indices = torch.linspace(
sampling_range_start, sampling_range_end - 1, steps=num_clips, dtype=torch.int
)

# Similarly to clip_at_random_indices, there may be backward seeks if clips overlap.
# See other TODO over there, and apply similar changes here.
clips = [
decoder.get_frames_at(
start=clip_start_index,
stop=clip_start_index + clip_span,
step=num_indices_between_frames,
)
for clip_start_index in clip_start_indices
]

return clips
91 changes: 67 additions & 24 deletions test/samplers/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,19 +4,21 @@

import pytest
import torch

from torchcodec.decoders import FrameBatch, VideoDecoder
from torchcodec.samplers import clips_at_random_indices
from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices

from ..utils import assert_tensor_equal, NASA_VIDEO


@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
@pytest.mark.parametrize("num_indices_between_frames", [1, 5])
def test_random_sampler(num_indices_between_frames):
def test_sampler(sampler, num_indices_between_frames):
decoder = VideoDecoder(NASA_VIDEO.path)
num_clips = 2
num_clips = 5
num_frames_per_clip = 3

clips = clips_at_random_indices(
clips = sampler(
decoder,
num_clips=num_clips,
num_frames_per_clip=num_frames_per_clip,
Expand All @@ -34,6 +36,17 @@ def test_random_sampler(num_indices_between_frames):
)
assert all(clip.data.shape == expected_clip_data_shape for clip in clips)

if sampler is clips_at_regular_indices:
# assert regular spacing between sampled clips
# Note: need approximate check as actual values typically look like [3.2032, 3.2366, 3.2366, 3.2366]
seconds_between_clip_starts = torch.tensor(
[clip.pts_seconds[0] for clip in clips]
).diff()
for diff in seconds_between_clip_starts:
assert diff == pytest.approx(seconds_between_clip_starts[0], abs=0.05)

assert (diff > 0).all() # Also assert clips are sorted by start time

# Check the num_indices_between_frames parameter by asserting that the
# "time" difference between frames in a clip is the same as the "index"
# distance.
Expand All @@ -45,15 +58,16 @@ def test_random_sampler(num_indices_between_frames):
)


@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
@pytest.mark.parametrize(
"sampling_range_start, sampling_range_end, assert_all_equal",
(
(10, 11, True),
(10, 12, False),
),
)
def test_random_sampler_range(
sampling_range_start, sampling_range_end, assert_all_equal
def test_sampling_range(
sampler, sampling_range_start, sampling_range_end, assert_all_equal
):
# Test the sampling_range_start and sampling_range_end parameters by
# asserting that all clips are equal if the sampling range is of size 1,
Expand All @@ -65,7 +79,7 @@ def test_random_sampler_range(

decoder = VideoDecoder(NASA_VIDEO.path)

clips = clips_at_random_indices(
clips = sampler(
decoder,
num_clips=10,
num_frames_per_clip=2,
Expand All @@ -86,21 +100,22 @@ def test_random_sampler_range(
assert_tensor_equal(clip.data, clips[0].data)


def test_random_sampler_range_negative():
@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
def test_sampling_range_negative(sampler):
# Test the passing negative values for sampling_range_start and
# sampling_range_end is the same as passing `len(decoder) - val`

decoder = VideoDecoder(NASA_VIDEO.path)

clips_1 = clips_at_random_indices(
clips_1 = sampler(
decoder,
num_clips=10,
num_frames_per_clip=2,
sampling_range_start=len(decoder) - 100,
sampling_range_end=len(decoder) - 99,
)

clips_2 = clips_at_random_indices(
clips_2 = sampler(
decoder,
num_clips=10,
num_frames_per_clip=2,
Expand Down Expand Up @@ -150,59 +165,87 @@ def test_random_sampler_randomness():
assert builtin_random_state_start == builtin_random_state_end


def test_random_sampler_errors():
@pytest.mark.parametrize(
"num_clips, sampling_range_size",
(
# Ask for 10 clips while the sampling range is 10 frames wide
# expect 10 clips with 10 unique starting points.
(10, 10),
# Ask for 50 clips while the sampling range is only 10 frames wide
# expect 50 clips with only 10 unique starting points.
(50, 10),
),
)
def test_sample_at_regular_indices_num_clips_large(num_clips, sampling_range_size):
# Test for expected behavior described in Note [num clips larger than sampling range]
decoder = VideoDecoder(NASA_VIDEO.path)
clips = clips_at_regular_indices(
decoder,
num_clips=num_clips,
sampling_range_start=0,
sampling_range_end=sampling_range_size, # because sampling_range_start=0
)

assert len(clips) == num_clips

clip_starts_seconds = torch.tensor([clip.pts_seconds[0] for clip in clips])
assert len(torch.unique(clip_starts_seconds)) == sampling_range_size

# Assert clips starts are ordered, i.e. the start indices don't just "wrap
# around". They're duplicated *and* ordered.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mention reason for not wrapping around?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no particular reason honestly, it's just an assertion of the behavior of torch.linspace. I'll add a comment near the Note to clarify that this is somewhat arbitrary.

assert (clip_starts_seconds.diff() >= 0).all()


@pytest.mark.parametrize("sampler", (clips_at_random_indices, clips_at_regular_indices))
def test_random_sampler_errors(sampler):
decoder = VideoDecoder(NASA_VIDEO.path)
with pytest.raises(
ValueError, match=re.escape("num_clips (0) must be strictly positive")
):
clips_at_random_indices(decoder, num_clips=0)
sampler(decoder, num_clips=0)

with pytest.raises(
ValueError, match=re.escape("num_frames_per_clip (0) must be strictly positive")
):
clips_at_random_indices(decoder, num_frames_per_clip=0)
sampler(decoder, num_frames_per_clip=0)

with pytest.raises(
ValueError,
match=re.escape("num_indices_between_frames (0) must be strictly positive"),
):
clips_at_random_indices(decoder, num_indices_between_frames=0)
sampler(decoder, num_indices_between_frames=0)

with pytest.raises(
ValueError,
match=re.escape("Clip span (1000) is larger than the number of frames"),
):
clips_at_random_indices(decoder, num_frames_per_clip=1000)
sampler(decoder, num_frames_per_clip=1000)

with pytest.raises(
ValueError,
match=re.escape("Clip span (1001) is larger than the number of frames"),
):
clips_at_random_indices(
decoder, num_frames_per_clip=2, num_indices_between_frames=1000
)
sampler(decoder, num_frames_per_clip=2, num_indices_between_frames=1000)

with pytest.raises(
ValueError, match=re.escape("sampling_range_start (1000) must be smaller than")
):
clips_at_random_indices(decoder, sampling_range_start=1000)
sampler(decoder, sampling_range_start=1000)

with pytest.raises(
ValueError, match=re.escape("sampling_range_start (4) must be smaller than")
):
clips_at_random_indices(decoder, sampling_range_start=4, sampling_range_end=4)
sampler(decoder, sampling_range_start=4, sampling_range_end=4)

with pytest.raises(
ValueError, match=re.escape("sampling_range_start (290) must be smaller than")
):
clips_at_random_indices(
decoder, sampling_range_start=-100, sampling_range_end=-100
)
sampler(decoder, sampling_range_start=-100, sampling_range_end=-100)

with pytest.raises(
ValueError, match="We determined that sampling_range_end should"
):
clips_at_random_indices(
sampler(
decoder,
num_frames_per_clip=10,
sampling_range_start=len(decoder) - 1,
Expand Down