Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/torchcodec/_core/SingleStreamDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -602,25 +602,34 @@ FrameOutput SingleStreamDecoder::getFrameAtIndexInternal(
}

FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
const std::vector<int64_t>& frameIndices) {
const torch::Tensor& frameIndices) {
validateActiveStream(AVMEDIA_TYPE_VIDEO);

auto indicesAreSorted =
std::is_sorted(frameIndices.begin(), frameIndices.end());
auto frameIndicesAccessor = frameIndices.accessor<int64_t, 1>();

bool indicesAreSorted = true;
for (int64_t i = 1; i < frameIndices.numel(); ++i) {
if (frameIndicesAccessor[i] < frameIndicesAccessor[i - 1]) {
indicesAreSorted = false;
break;
}
}

std::vector<size_t> argsort;
if (!indicesAreSorted) {
// if frameIndices is [13, 10, 12, 11]
// when sorted, it's [10, 11, 12, 13] <-- this is the sorted order we want
// to use to decode the frames
// and argsort is [ 1, 3, 2, 0]
argsort.resize(frameIndices.size());
argsort.resize(frameIndices.numel());
for (size_t i = 0; i < argsort.size(); ++i) {
argsort[i] = i;
}
std::sort(
argsort.begin(), argsort.end(), [&frameIndices](size_t a, size_t b) {
return frameIndices[a] < frameIndices[b];
argsort.begin(),
argsort.end(),
[&frameIndicesAccessor](size_t a, size_t b) {
return frameIndicesAccessor[a] < frameIndicesAccessor[b];
});
}

Expand All @@ -629,12 +638,12 @@ FrameBatchOutput SingleStreamDecoder::getFramesAtIndices(
const auto& streamInfo = streamInfos_[activeStreamIndex_];
const auto& videoStreamOptions = streamInfo.videoStreamOptions;
FrameBatchOutput frameBatchOutput(
frameIndices.size(), videoStreamOptions, streamMetadata);
frameIndices.numel(), videoStreamOptions, streamMetadata);

auto previousIndexInVideo = -1;
for (size_t f = 0; f < frameIndices.size(); ++f) {
for (int64_t f = 0; f < frameIndices.numel(); ++f) {
auto indexInOutput = indicesAreSorted ? f : argsort[f];
auto indexInVideo = frameIndices[indexInOutput];
auto indexInVideo = frameIndicesAccessor[indexInOutput];

if ((f > 0) && (indexInVideo == previousIndexInVideo)) {
// Avoid decoding the same frame twice
Expand Down Expand Up @@ -776,7 +785,8 @@ FrameBatchOutput SingleStreamDecoder::getFramesPlayedAt(
frameIndices[i] = secondsToIndexLowerBound(frameSeconds);
}

return getFramesAtIndices(frameIndices);
// TODO: Support tensors natively instead of a vector to avoid a copy.
return getFramesAtIndices(torch::tensor(frameIndices));
}

FrameBatchOutput SingleStreamDecoder::getFramesPlayedInRange(
Expand Down
2 changes: 1 addition & 1 deletion src/torchcodec/_core/SingleStreamDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class SingleStreamDecoder {

// Returns frames at the given indices for a given stream as a single stacked
// Tensor.
FrameBatchOutput getFramesAtIndices(const std::vector<int64_t>& frameIndices);
FrameBatchOutput getFramesAtIndices(const torch::Tensor& frameIndices);

// Returns frames within a given range. The range is defined by [start, stop).
// The values retrieved from the range are: [start, start+step,
Expand Down
8 changes: 3 additions & 5 deletions src/torchcodec/_core/custom_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
"get_frames_at_indices(Tensor(a!) decoder, *, Tensor frame_indices) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
Expand Down Expand Up @@ -378,11 +378,9 @@ OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
// Return the frames at given indices for a given stream
OpsFrameBatchOutput get_frames_at_indices(
at::Tensor& decoder,
at::IntArrayRef frame_indices) {
const at::Tensor& frame_indices) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<int64_t> frameIndicesVec(
frame_indices.begin(), frame_indices.end());
auto result = videoDecoder->getFramesAtIndices(frameIndicesVec);
auto result = videoDecoder->getFramesAtIndices(frame_indices);
return makeOpsFrameBatchOutput(result);
}

Expand Down
20 changes: 16 additions & 4 deletions src/torchcodec/_core/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def load_torchcodec_shared_libraries():
get_next_frame = torch.ops.torchcodec_ns.get_next_frame.default
get_frame_at_pts = torch.ops.torchcodec_ns.get_frame_at_pts.default
get_frame_at_index = torch.ops.torchcodec_ns.get_frame_at_index.default
get_frames_at_indices = torch.ops.torchcodec_ns.get_frames_at_indices.default
_get_frames_at_indices_tensor_input = (
torch.ops.torchcodec_ns.get_frames_at_indices.default
)
get_frames_by_pts = torch.ops.torchcodec_ns.get_frames_by_pts.default
get_frames_in_range = torch.ops.torchcodec_ns.get_frames_in_range.default
get_frames_by_pts_in_range = torch.ops.torchcodec_ns.get_frames_by_pts_in_range.default
Expand Down Expand Up @@ -198,6 +200,18 @@ def encode_audio_to_file_like(
)


def get_frames_at_indices(
    decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, list[int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fetch the frames at ``frame_indices``, given as a tensor or a list.

    Python front-end for the ``torchcodec_ns::get_frames_at_indices`` op,
    whose schema takes ``frame_indices`` as a Tensor. The C++ side reads the
    indices through an int64, 1-D accessor, so the input is normalized here
    before dispatch.

    Args:
        decoder: The decoder handle tensor passed through to the op.
        frame_indices: Indices of the frames to retrieve. A tensor is cast to
            int64; a list is converted to a tensor.

    Returns:
        The three tensors returned by the underlying op.
    """
    if isinstance(frame_indices, torch.Tensor):
        # Ensure the indices have the dtype the op expects (int64).
        frame_indices = frame_indices.to(torch.int64)
    else:
        # Convert the list to a tensor for dispatch. The dtype is inferred on
        # purpose: a list holding non-integer values stays floating point and
        # is rejected by the op instead of being silently truncated.
        frame_indices = torch.tensor(frame_indices)
        if frame_indices.numel() == 0:
            # torch.tensor([]) infers the default floating dtype, which the
            # op would reject. An empty request carries no values that could
            # be truncated, so force int64 and let it yield an empty batch.
            frame_indices = frame_indices.to(torch.int64)
    return _get_frames_at_indices_tensor_input(decoder, frame_indices=frame_indices)


# ==============================
# Abstract impl for the operators. Needed by torch.compile.
# ==============================
Expand Down Expand Up @@ -371,9 +385,7 @@ def get_frame_at_index_abstract(

@register_fake("torchcodec_ns::get_frames_at_indices")
def get_frames_at_indices_abstract(
decoder: torch.Tensor,
*,
frame_indices: List[int],
decoder: torch.Tensor, *, frame_indices: Union[torch.Tensor, List[int]]
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
image_size = [get_ctx().new_dynamic_size() for _ in range(4)]
return (
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/_samplers/video_clip_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,8 @@ def _get_clips_for_index_based_sampling(
clip_start_idx + i * index_based_sampler_args.video_frame_dilation
for i in range(index_based_sampler_args.frames_per_clip)
]
# Need torch.stack to convert List[Tensor[int]] into 1D Tensor[int]
batch_indexes = torch.stack(batch_indexes)
frames, *_ = get_frames_at_indices(
video_decoder,
frame_indices=batch_indexes,
Expand Down
10 changes: 3 additions & 7 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,24 +240,20 @@ def get_frame_at(self, index: int) -> Frame:
duration_seconds=duration_seconds.item(),
)

def get_frames_at(self, indices: list[int]) -> FrameBatch:
def get_frames_at(self, indices: Union[torch.Tensor, list[int]]) -> FrameBatch:
"""Return frames at the given indices.

Args:
indices (list of int): The indices of the frames to retrieve.
indices (torch.Tensor or list of int): The indices of the frames to retrieve.

Returns:
FrameBatch: The frames at the given indices.
"""
if isinstance(indices, torch.Tensor):
# TODO we should avoid converting tensors to lists and just let the
# core ops and C++ code natively accept tensors. See
# https://github.com/pytorch/torchcodec/issues/879
indices = indices.to(torch.int).tolist()

data, pts_seconds, duration_seconds = core.get_frames_at_indices(
self._decoder, frame_indices=indices
)

return FrameBatch(
data=data,
pts_seconds=pts_seconds,
Expand Down
6 changes: 4 additions & 2 deletions test/VideoDecoderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNCHW) {
*ourDecoder->getContainerMetadata().bestVideoStreamIndex;
ourDecoder->addVideoStream(bestVideoStreamIndex);
// Frame with index 180 corresponds to timestamp 6.006.
auto output = ourDecoder->getFramesAtIndices({0, 180});
auto frameIndices = torch::tensor({0, 180});
auto output = ourDecoder->getFramesAtIndices(frameIndices);
auto tensor = output.data;
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 3, 270, 480}));

Expand All @@ -246,7 +247,8 @@ TEST_P(SingleStreamDecoderTest, DecodesFramesInABatchInNHWC) {
videoStreamOptions.dimensionOrder = "NHWC";
ourDecoder->addVideoStream(bestVideoStreamIndex, videoStreamOptions);
// Frame with index 180 corresponds to timestamp 6.006.
auto output = ourDecoder->getFramesAtIndices({0, 180});
auto frameIndices = torch::tensor({0, 180});
auto output = ourDecoder->getFramesAtIndices(frameIndices);
auto tensor = output.data;
EXPECT_EQ(tensor.sizes(), std::vector<long>({2, 270, 480, 3}));

Expand Down
4 changes: 3 additions & 1 deletion test/test_decoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -569,7 +569,9 @@ def test_get_frames_at_fails(self, device, seek_mode):
with pytest.raises(IndexError, match="Invalid frame index=390"):
decoder.get_frames_at([390])

with pytest.raises(RuntimeError, match="Expected a value of type"):
with pytest.raises(
RuntimeError, match="expected scalar type Long but found Float"
):
decoder.get_frames_at([0.3])

@pytest.mark.parametrize("device", all_supported_devices())
Expand Down
2 changes: 1 addition & 1 deletion test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,7 +1209,7 @@ def seek(self, offset: int, whence: int) -> int:
torch.manual_seed(0)
indices = torch.randint(
0, len(NASA_VIDEO.frames[NASA_VIDEO.default_stream_index]), size=(50,)
).tolist()
)

frames_file_like, *_ = get_frames_at_indices(
decoder_file_like, frame_indices=indices
Expand Down
Loading