
Conversation

@NicolasHug (Contributor) commented Sep 24, 2024

This PR adds a random index-based clip sampler.

Leaving the following for other PRs:

  • Edge case policy support, i.e. what to do when there aren't enough frames to sample? Right now, we raise a loud error.
  • Docs

@facebook-github-bot added the CLA Signed label on Sep 24, 2024

import torch

from torchcodec.decoders import ( # TODO: move FrameBatch to torchcodec.FrameBatch?
@NicolasHug (Contributor, Author) commented Sep 24, 2024

@ahmadsharif1 @scotts Thoughts? FrameBatch isn't just used in the decoder anymore, so I think we can expose it at the highest level.

@NicolasHug (Contributor, Author) replied:

I'll do it. We'll have to move the FrameBatch implementation into torchcodec._something, otherwise this creates cyclic imports. I'll create a PR to tackle that immediately after this one is merged. I prefer not doing it here, to avoid complicating the PR and the reviews.

clip_start_indices = torch.randint(
    low=0, high=last_clip_start_index + 1, size=(num_clips,)
)

@scotts (Contributor) commented Sep 24, 2024

(I'm putting the comment here so that we can still read the logic without a giant textbox in the middle of it.)

I'm not sure if we want to compute and then use last_clip_start_index. From a probability perspective, I think we're actually making it much less likely that the frames in the range [last_clip_start_index, len(decoder)) are ever returned to the user. For all other frames, consider that they could be the start of a clip, or included in a clip. The last set of frames is only included if we happen to select last_clip_start_index as a clip start.

An alternative approach is that we don't do this, and just set low=0, high=len(decoder). Then we let our clip-too-small policy from above handle the case when the clip start is in the range [last_clip_start_index, len(decoder)).

Both options result in some quirky behavior. What is currently implemented makes it less likely (I think) that we'll ever return frames in the range [last_clip_start_index, len(decoder)). My suggestion above means that selecting those frames is just as likely as selecting frames in the rest of the video, but we'll always have some duplicates - and maybe biasing toward the earlier part of the video is better than having duplicates. I'm not sure, sampling is actually a hard problem. ¯\_(ツ)_/¯
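
A small illustrative sketch of the two options, with toy numbers (not code from the PR; it also ignores num_indices_between_frames for simplicity):

import torch

# Toy numbers, for illustration only.
num_frames, frames_per_clip, num_clips = 100, 16, 8
last_clip_start_index = num_frames - frames_per_clip  # last start that yields a full clip

# Option currently in the PR: only sample starts that yield full clips.
# Frames in [last_clip_start_index, num_frames) can only appear if
# last_clip_start_index itself is drawn, so they are under-represented.
starts_current = torch.randint(low=0, high=last_clip_start_index + 1, size=(num_clips,))

# Suggested alternative: sample starts over the whole video and let the
# clip-too-small policy handle starts near the end (may produce duplicate frames).
starts_alternative = torch.randint(low=0, high=num_frames, size=(num_clips,))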

Contributor replied:

I think uniformly sampling the clip_starts in the range where they are valid matches my own intuition.

Duplicate frames may not be good for training.

If @scotts has a strong opinion, maybe we can expose an option, but the current behavior seems good to me as the default.

@NicolasHug (Contributor, Author) replied:

After chatting offline, we concluded the best way to handle this was to let the user choose the sampling range, similar to the "equaly_spaced_in_time" sampler.


from torchcodec.decoders import FrameBatch, SimpleVideoDecoder


@NicolasHug (Contributor, Author) commented:

Below I factored out some validation logic into functions. This isn't strictly necessary for this PR, but we'll need the exact same logic for the clips_at_regular_indices() sampler, so I prefer extracting it now to minimize the diff of the clips_at_regular_indices() PR.

@NicolasHug marked this pull request as ready for review on September 30, 2024 at 15:27
Comment on lines 35 to 36
# TODO: or max(sampling_range_start, num_frames - 1)?
sampling_range_start = sampling_range_start % num_frames
@NicolasHug (Contributor, Author) commented:

Looking for input on this. I'm in a "let's try to avoid mid-train crashes as much as possible" mindset.

Contributor replied:

why aren't we doing the same for sampling_range_end if it's not None?

Maybe you can extract out a function like _coerce_index() or something and apply it for both start and end?

Contributor replied:

The question as I understand it is: what should the behavior be when the sampling start is greater than the total number of frames? I'm inclined to make this an error for now, for the following reasons:

  1. None of the non-failure options (mod it or max it with number of frames) seem intuitive to me.
  2. We already have a resilient way to do the right thing: don't provide a start and end.
  3. It's easy to relax what was once an error condition, but it's hard to change the behavior of non-error conditions.
  4. This feels tied to the problem of what to do when there are not enough frames to meet the required frames per clip. We know we're going to do something there, but we're currently making it an error. When we have a mechanism to solve that problem, we could use it here, too.

@NicolasHug (Contributor, Author) replied:

OK, let's just error for now and stay open to adding an automated behavior upon user request.

@NicolasHug (Contributor, Author) replied:

> why aren't we doing the same for sampling_range_end if it's not None?

Good catch, thanks. I think we should cap it to num_frames, as this is a fairly common strategy for the upper bound (as in slice indexing).
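
Putting the decisions from this thread together, a hypothetical validation helper could look roughly like this (names and details are illustrative, not the PR's actual code):

def _validate_sampling_range(sampling_range_start, sampling_range_end, num_frames):
    # Hypothetical sketch of the decisions discussed above.
    if sampling_range_start < 0:
        # Negative values index from the end, as in Python slicing.
        sampling_range_start = num_frames + sampling_range_start
    if sampling_range_start < 0 or sampling_range_start >= num_frames:
        # Error rather than wrap around or clamp (see discussion above).
        raise ValueError(
            f"sampling_range_start ({sampling_range_start}) must be in [0, {num_frames})."
        )
    if sampling_range_end is None:
        sampling_range_end = num_frames
    else:
        if sampling_range_end < 0:
            sampling_range_end = num_frames + sampling_range_end
        # Cap the (exclusive) upper bound, as in slice indexing.
        sampling_range_end = min(sampling_range_end, num_frames)
    if sampling_range_end <= sampling_range_start:
        raise ValueError("sampling_range_end must be greater than sampling_range_start.")
    return sampling_range_start, sampling_range_end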

@@ -0,0 +1,138 @@
import random
Contributor commented:

random nit: why not call it implementation.py? It's not user-facing, so we can have long descriptive names here

@NicolasHug (Contributor, Author) replied:

I personally don't find implementation substantially more descriptive than implem.
I'm OK with addressing it, but only at the very end before merging, because the renaming would cause GitHub to mark ongoing comments as "outdated".

    *, sampling_range_start, sampling_range_end, num_frames, clip_span
):
    if sampling_range_start < 0:
        sampling_range_start = num_frames + sampling_range_start
Contributor commented:

what if it's still negative after this?

@NicolasHug (Contributor, Author) replied:

It would result in an error, because sampling_range_end <= sampling_range_start.

If we don't error, we can either:

  • decide that it's equivalent to 1 (so that it can work when sampling_range_start is 0)
  • keep wrapping-around (modulo).

Neither sounds like a great intuitive option, so it might be best to error, similar to what we decided in https://github.com/pytorch/torchcodec/pull/221/files#r1781347650

Comment on lines +133 to +136
builtin_random_state = random.getstate()
random.seed(torch.randint(0, 2**32, (1,)).item())
random.shuffle(clips)
random.setstate(builtin_random_state)
Contributor commented:

I wonder if there is a context manager for this? I couldn't find one myself, but maybe there is a way using a with statement here that auto-saves/restores?

@NicolasHug (Contributor, Author) commented Oct 1, 2024

That's indeed the kind of scenario where a CM (context manager) can be useful. I thought about writing one but decided against it, considering this logic is just 2 lines of code (the CM would be more), and we'll remove it soon when we address the TODO.
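
For reference, a minimal sketch of what such a context manager could look like, reusing the seed-from-torch approach from the snippet above (not code from the PR):

import contextlib
import random

import torch

@contextlib.contextmanager
def _builtin_rng_seeded_from_torch():
    """Temporarily seed Python's `random` module from torch's RNG, then restore its state."""
    saved_state = random.getstate()
    random.seed(torch.randint(0, 2**32, (1,)).item())
    try:
        yield
    finally:
        random.setstate(saved_state)

# Usage: shuffle without permanently altering the user's `random` state.
clips = list(range(10))
with _builtin_rng_seeded_from_torch():
    random.shuffle(clips)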

# Make sure we didn't alter the builtin Python RNG
builtin_random_state_end = random.getstate()
assert builtin_random_state_start == builtin_random_state_end

Contributor commented:

I wonder if we can fuzz test this to make sure we don't return errors mid-training.

Can pytest help with passing random values to the parameters (with certain constraints, like positive-only values for things like frames_per_clip) to make sure we are robust to errors?

@NicolasHug (Contributor, Author) replied:

There is the hypothesis package that I think does what you're suggesting, but I'm not sure how useful it would be in this specific instance. What do you have in mind with

> make sure we don't return errors mid-training.

?

Contributor replied:

I was requesting fuzz testing so that we're robust and don't return errors in the middle of training (which is what your original philosophy was).


Now, if we are switching to an error-raising mindset, you do seem to have tests for the exact errors, so you can resolve this comment.
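
For reference, a property-based test with hypothesis could look roughly like the sketch below; it exercises a hypothetical pure helper (pick_clip_starts, defined here only for illustration) rather than the real sampler, which needs a decoder and a video file:

import torch
from hypothesis import assume, given, strategies as st

def pick_clip_starts(num_frames, frames_per_clip, num_clips):
    # Hypothetical stand-in for the sampler's index logic.
    last_start = num_frames - frames_per_clip
    if last_start < 0:
        raise ValueError("Not enough frames for even one clip.")
    return torch.randint(low=0, high=last_start + 1, size=(num_clips,))

@given(
    num_frames=st.integers(min_value=1, max_value=500),
    frames_per_clip=st.integers(min_value=1, max_value=64),
    num_clips=st.integers(min_value=1, max_value=16),
)
def test_clip_starts_stay_in_bounds(num_frames, frames_per_clip, num_clips):
    # The "not enough frames" policy is covered by explicit error tests.
    assume(frames_per_clip <= num_frames)
    starts = pick_clip_starts(num_frames, frames_per_clip, num_clips)
    assert starts.shape == (num_clips,)
    assert ((starts >= 0) & (starts + frames_per_clip <= num_frames)).all()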

    num_indices_between_frames: int = 1,
    sampling_range_start: int = 0,
    sampling_range_end: Optional[int] = None,  # interval is [start, end).
) -> List[FrameBatch]:
Contributor commented:

Did we agree to return a List here? I thought some users had a need to get a single tensor back? Would those users just stack it manually? Note that stacking could be an expensive operation.

@NicolasHug (Contributor, Author) replied:

When we discussed this, we decided to "let the implementation decide".
I do agree that stacked tensors make more sense in general, but we also don't want to stack on behalf of the user as it may introduce an unnecessary copy.
The current implementation leads to a list, but it's possible that, by implementing the "better frame shuffling" strategy (left as a TODO at the bottom of this function), we'll end up with a stack. This is still TBD.

Contributor replied:

Decoding frames into a single tensor doesn't involve a copy -- and that can be done for max speed if we want.

Stacking them after the fact is inefficient.
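
A small illustration of the trade-off being discussed (not from the PR): torch.stack on a list always allocates and copies, whereas a decoder writing frames straight into a preallocated batch tensor avoids that second pass:

import torch

frames = [torch.rand(3, 4, 4) for _ in range(5)]  # stand-ins for decoded frames

# Stacking after the fact allocates a new contiguous tensor and copies every frame.
stacked = torch.stack(frames)
assert stacked.data_ptr() != frames[0].data_ptr()

# A decoder could instead write each decoded frame directly into its slot of a
# preallocated batch tensor; the copy_ below stands in for that decode-and-write step.
batch = torch.empty(5, 3, 4, 4)
for i, frame in enumerate(frames):
    batch[i].copy_(frame)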

random.seed(torch.randint(0, 2**32, (1,)).item())
random.shuffle(clips)
random.setstate(builtin_random_state)

Contributor commented:

Why is it necessary that we don't affect the random module's state? I assume we're trying not to affect the random values seen by models during actual training, but why do we want to do that?

@NicolasHug (Contributor, Author) commented Oct 1, 2024

Users may want to have strict control over the RNG stream of the builtin random module, and we don't want to alter that in any way. Think of a user running different experiments with random.seed(1) and random.seed(2). We'd be altering the RNG streams of these executions to something completely different, rendering their experiments invalid without them knowing.

In general, a library hard-coding a seed for a global RNG stream is a big no-no (whether it's the Python RNG, the PyTorch RNG, NumPy's, ...). If a library hard-codes a seed, it should be for a local, non-global stream.

Contributor replied:

Right, I'm on board with not hardcoding a seed. I guess what I'm confused about is why we can't just call random.shuffle(clips) on its own. Why do we need to set a seed before shuffling? (Without that need, I assume it's okay to consume values in the RNG stream.)

@NicolasHug (Contributor, Author) replied:

We're a PyTorch library, so in general we want our RNG stream to come from PyTorch, not from other RNGs.
In other words, we want torch.manual_seed(123) to affect the RNG of the sampler.

Contributor replied:

Ohhhh, got it, so we're porting over torch's RNG state to random. Since your last response was the aha moment for me, let's say something to that effect in a comment. :)
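
A sketch of the kind of comment that could accompany the snippet, reusing its variable names (illustrative only):

import random

import torch

clips = ["clip_0", "clip_1", "clip_2"]  # stand-in for the sampled clips

# We shuffle with the builtin `random` module, but we want torch's RNG to drive
# the shuffle (so that torch.manual_seed() controls it) without clobbering the
# user's own `random` state. So: save the state, seed `random` from a value
# drawn from torch's RNG, shuffle, then restore the state.
builtin_random_state = random.getstate()
random.seed(torch.randint(0, 2**32, (1,)).item())
random.shuffle(clips)
random.setstate(builtin_random_state)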

@NicolasHug changed the title from "Add Random index-based clip sampler" to "Add Random index-based clip sampler - part 1" on Oct 1, 2024
@NicolasHug merged commit 7bff6a7 into meta-pytorch:main on Oct 1, 2024
@NicolasHug deleted the samplerz branch on October 1, 2024 at 14:42