Merged

29 commits
c090e44
Add Random index-based clip sampler
NicolasHug Sep 24, 2024
6b52cfd
Merge branch 'main' of github.com:pytorch/torchcodec into samplerz
NicolasHug Sep 24, 2024
c618160
Basic linear sampler
NicolasHug Sep 24, 2024
b5545a9
Merge branch 'main' of github.com:pytorch/torchcodec into samplerzz
NicolasHug Oct 4, 2024
2c5a559
Add tests
NicolasHug Oct 4, 2024
ce46196
Handle edge case when num_clips is larger than available sampling range
NicolasHug Oct 4, 2024
5fed662
WIP
NicolasHug Oct 4, 2024
1873174
Address comments
NicolasHug Oct 4, 2024
e86e017
Merge branch 'samplerzz' into samplers_edge_case
NicolasHug Oct 4, 2024
63462e9
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_ed…
NicolasHug Oct 4, 2024
a97afe7
Samplers: add support for edge-case policies
NicolasHug Oct 4, 2024
83c6763
Refactor + comments
NicolasHug Oct 4, 2024
71a839a
Add tests
NicolasHug Oct 7, 2024
a333e9b
Minor clip_span refactoring
NicolasHug Oct 7, 2024
08833b0
Don't add defaults to private abstract sampler
NicolasHug Oct 7, 2024
3dcbe1e
Speed-up samplers by avoiding backwards seeks
NicolasHug Oct 7, 2024
054b72e
abstract -> generic
NicolasHug Oct 7, 2024
2bb1d58
typo fix
NicolasHug Oct 7, 2024
9b93214
Typo
NicolasHug Oct 7, 2024
7a500de
Merge branch 'samplers_edge_case' into samplers_fast
NicolasHug Oct 7, 2024
75945a8
Comment
NicolasHug Oct 7, 2024
0c5c537
Comment
NicolasHug Oct 7, 2024
5585b46
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_fast
NicolasHug Oct 7, 2024
523dff9
Merge branch 'main' of github.com:pytorch/torchcodec into samplers_fast
NicolasHug Oct 7, 2024
5814439
Fix merge
NicolasHug Oct 7, 2024
6e18e07
slightly better index name
NicolasHug Oct 7, 2024
6378a34
Add note
NicolasHug Oct 7, 2024
9ae87cc
Add sampler benchmarking code
NicolasHug Oct 7, 2024
5eba3e4
Revert "Add sampler benchmarking code"
NicolasHug Oct 7, 2024
44 changes: 35 additions & 9 deletions src/torchcodec/samplers/_implem.py
@@ -160,28 +160,54 @@ def _build_all_clips_indices(
 def _decode_all_clips_indices(
     decoder: VideoDecoder, all_clips_indices: list[int], num_frames_per_clip: int
 ) -> list[FrameBatch]:
-    # This takes the list of all the frames to decode, decode all the frames,
-    # and then packs them into clips of length num_frames_per_clip.
-    # This is slow, unoptimized, and u.g.l.y. It is not meant to stay.
-    # TODO:
-    # - sort the frames to avoid backward seeks, dedup, decode, and re-organize frames.
-    # - write most of this in C++
+    # This takes the list of all the frames to decode (in arbitrary order),
+    # decodes all the frames, and then packs them into clips of length
+    # num_frames_per_clip.
+    #
+    # To avoid backwards seeks (which are slow), we:
+    # - sort all the frame indices to be decoded
+    # - dedup them
+    # - decode all unique frames in sorted order
+    # - re-assemble the decoded frames back to their original order
+    #
+    # TODO: Write this in C++ so we can avoid the copies that happen in `to_framebatch`

     def chunk_list(lst, chunk_size):
         # return list of sublists of length chunk_size
         return [lst[i : i + chunk_size] for i in range(0, len(lst), chunk_size)]

     def to_framebatch(frames: list[Frame]) -> FrameBatch:
+        # IMPORTANT: see other IMPORTANT note below
         data = torch.stack([frame.data for frame in frames])
         pts_seconds = torch.tensor([frame.pts_seconds for frame in frames])
         duration_seconds = torch.tensor([frame.duration_seconds for frame in frames])
         return FrameBatch(
             data=data, pts_seconds=pts_seconds, duration_seconds=duration_seconds
         )

-    all_decoded_frames: list[Frame] = [
-        decoder.get_frame_at(index) for index in all_clips_indices
-    ]
+    all_clips_indices_sorted, argsort = zip(
+        *sorted((frame_index, i) for (i, frame_index) in enumerate(all_clips_indices))
+    )
+    previous_decoded_frame = None
+    all_decoded_frames = [None] * len(all_clips_indices)
+    for i, j in enumerate(argsort):
+        frame_index = all_clips_indices_sorted[i]
+        if (
+            previous_decoded_frame is not None  # then we know i > 0
+            and frame_index == all_clips_indices_sorted[i - 1]
+        ):
+            # Avoid decoding the same frame twice.
+            # IMPORTANT: this is only correct because a copy of the frame will
+            # happen within `to_framebatch` when we call torch.stack.
+            # If a copy isn't made, the same underlying memory will be used for
+            # the 2 consecutive frames. When we re-write this, we should make
+            # sure to explicitly copy the data.
+            decoded_frame = previous_decoded_frame
Contributor:

This is setting it to the same python object, right?

Will there be any issues with that? Example, if the user modifies that tensor or something else in FrameBatch -- they will modify both entries in the list, right?
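The concern is easy to demonstrate in isolation. In the sketch below a bare tensor stands in for Frame (nothing here is a torchcodec API): two list entries pointing at the same object do see each other's mutations, but torch.stack copies its inputs into fresh memory, which is what makes the reuse above safe.

import torch

frame = torch.zeros(3, 2, 2)      # stand-in for a decoded frame's data
frames = [frame, frame]           # both entries are the SAME python object
assert frames[0] is frames[1]

# A mutation through one entry is visible through the other:
frames[0][0, 0, 0] = 1.0
assert frames[1][0, 0, 0] == 1.0

# torch.stack allocates a new tensor and copies its inputs into it, so
# the batched data no longer shares storage with the original frames:
batch = torch.stack(frames)
batch[0].zero_()
assert frames[1][0, 0, 0] == 1.0  # original frame is untouched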

Contributor (Author):

To be slightly safer w.r.t. future changes this should be

    decoded_frame = copy(previous_decoded_frame)

but we don't implement __copy__ on Frame.

Note that a copy still happens within to_framebatch, so this is currently safe, but admittedly subject to an implementation detail that will change.

We can either:

  • be OK with this since we'll re-implement it in C++ anyway
  • implement __copy__

LMK.
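For illustration only, a __copy__ along those lines could look like the sketch below; this Frame dataclass is a hypothetical stand-in rather than torchcodec's actual definition, and the field types are assumed.

from copy import copy
from dataclasses import dataclass

import torch

@dataclass
class Frame:  # hypothetical stand-in for torchcodec's Frame
    data: torch.Tensor
    pts_seconds: float
    duration_seconds: float

    def __copy__(self):
        # Clone the tensor so the copy owns its own memory.
        return Frame(
            data=self.data.clone(),
            pts_seconds=self.pts_seconds,
            duration_seconds=self.duration_seconds,
        )

f1 = Frame(data=torch.zeros(3, 2, 2), pts_seconds=0.0, duration_seconds=0.04)
f2 = copy(f1)  # dispatches to Frame.__copy__
assert f2.data.data_ptr() != f1.data.data_ptr()  # no shared storage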

Contributor:

I am OK with this as-is.

Contributor (Author):

OK, I'll add a comment in to_framebatch() so we don't accidentally mess it up.

+        else:
+            decoded_frame = decoder.get_frame_at(index=frame_index)
+        previous_decoded_frame = decoded_frame
+        all_decoded_frames[j] = decoded_frame

     all_clips: list[list[Frame]] = chunk_list(
         all_decoded_frames, chunk_size=num_frames_per_clip
     )
Contributor (Author):

Note that we don't have to chunk the clips. The implementation already allows us to return a single 5D FrameBatch instead of a list[4D FrameBatch]. I'll just leave this for another PR so we can discuss.
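As a sketch of that follow-up idea, the flat decoded tensors could simply be reshaped into one 5D batch instead of chunked into a list; the FrameBatch below is a minimal stand-in assumed to wrap the same three tensors as in to_framebatch above.

from dataclasses import dataclass

import torch

@dataclass
class FrameBatch:  # minimal stand-in, assumed shape of the real class
    data: torch.Tensor
    pts_seconds: torch.Tensor
    duration_seconds: torch.Tensor

def to_5d_framebatch(data, pts, dur, num_frames_per_clip):
    # data: (num_clips * num_frames_per_clip, C, H, W); pts/dur: 1D
    num_clips = data.shape[0] // num_frames_per_clip
    return FrameBatch(
        data=data.reshape(num_clips, num_frames_per_clip, *data.shape[1:]),
        pts_seconds=pts.reshape(num_clips, num_frames_per_clip),
        duration_seconds=dur.reshape(num_clips, num_frames_per_clip),
    )

batch = to_5d_framebatch(
    data=torch.zeros(6, 3, 4, 4),
    pts=torch.arange(6.0),
    dur=torch.full((6,), 0.04),
    num_frames_per_clip=3,
)
assert batch.data.shape == (2, 3, 3, 4, 4)  # (num_clips, T, C, H, W)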

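To close, here is a standalone, runnable sketch of the sort-dedup-reassemble pattern the diff implements, with a fake decoder that counts backward seeks; FakeDecoder and every name in it are illustrative, not part of torchcodec.

class FakeDecoder:
    def __init__(self):
        self.last_index = -1
        self.backward_seeks = 0

    def get_frame_at(self, index):
        if index < self.last_index:
            self.backward_seeks += 1  # expensive on a real video decoder
        self.last_index = index
        return f"frame_{index}"

def decode_in_sorted_order(decoder, indices):
    # Sort the requested indices, remember where each request came from,
    # decode each unique frame once, then restore the original order.
    indices_sorted, argsort = zip(*sorted((idx, i) for i, idx in enumerate(indices)))
    out = [None] * len(indices)
    previous = None
    for i, j in enumerate(argsort):
        if i > 0 and indices_sorted[i] == indices_sorted[i - 1]:
            frame = previous  # duplicate index: reuse, don't re-decode
        else:
            frame = decoder.get_frame_at(indices_sorted[i])
        previous = frame
        out[j] = frame
    return out

decoder = FakeDecoder()
frames = decode_in_sorted_order(decoder, [5, 2, 2, 9, 0])
assert frames == ["frame_5", "frame_2", "frame_2", "frame_9", "frame_0"]
assert decoder.backward_seeks == 0  # sorted access never seeks backwards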