Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,6 +538,20 @@ VideoDecoder::ContainerMetadata VideoDecoder::getContainerMetadata() const {
return containerMetadata_;
}

torch::Tensor VideoDecoder::getKeyFrameIndices(int streamIndex) {
validateUserProvidedStreamIndex(streamIndex);
validateScannedAllStreams("getKeyFrameIndices");

const std::vector<FrameInfo>& keyFrames = streamInfos_[streamIndex].keyFrames;
torch::Tensor keyFrameIndices =
torch::empty({static_cast<int64_t>(keyFrames.size())}, {torch::kInt64});
for (size_t i = 0; i < keyFrames.size(); ++i) {
keyFrameIndices[i] = keyFrames[i].frameIndex;
}

return keyFrameIndices;
}

int VideoDecoder::getKeyFrameIndexForPtsUsingEncoderIndex(
AVStream* stream,
int64_t pts) const {
Expand Down Expand Up @@ -654,7 +668,21 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() {
return frameInfo1.pts < frameInfo2.pts;
});

size_t keyIndex = 0;
for (size_t i = 0; i < streamInfo.allFrames.size(); ++i) {
streamInfo.allFrames[i].frameIndex = i;

// For correctly encoded files, we shouldn't need to ensure that keyIndex
// is less than the number of key frames. That is, the relationship
// between the frames in allFrames and keyFrames should be such that
// keyIndex is always a valid index into keyFrames. But we're being
// defensive in case we encounter incorrectly encoded files.
if (keyIndex < streamInfo.keyFrames.size() &&
streamInfo.keyFrames[keyIndex].pts == streamInfo.allFrames[i].pts) {
streamInfo.keyFrames[keyIndex].frameIndex = i;
++keyIndex;
}

if (i + 1 < streamInfo.allFrames.size()) {
streamInfo.allFrames[i].nextPts = streamInfo.allFrames[i + 1].pts;
}
Expand Down
21 changes: 16 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ class VideoDecoder {
// Returns the metadata for the container.
ContainerMetadata getContainerMetadata() const;

// Returns the key frame indices as a tensor. The tensor is 1D and contains
// int64 values, where each value is the frame index for a key frame.
torch::Tensor getKeyFrameIndices(int streamIndex);

// --------------------------------------------------------------------------
// ADDING STREAMS API
// --------------------------------------------------------------------------
Expand Down Expand Up @@ -284,12 +288,19 @@ class VideoDecoder {

struct FrameInfo {
int64_t pts = 0;
// The value of this default is important: the last frame's nextPts will be
// INT64_MAX, which ensures that the allFrames vec contains FrameInfo
// structs with *increasing* nextPts values. That's a necessary condition
// for the binary searches on those values to work properly (as typically
// done during pts -> index conversions.)

// The value of the nextPts default is important: the last frame's nextPts
// will be INT64_MAX, which ensures that the allFrames vec contains
// FrameInfo structs with *increasing* nextPts values. That's a necessary
// condition for the binary searches on those values to work properly (as
// typically done during pts -> index conversions).
int64_t nextPts = INT64_MAX;

// Note that frameIndex is ALWAYS the index into all of the frames in that
// stream, even when the FrameInfo is part of the key frame index. Given a
// FrameInfo for a key frame, the frameIndex allows us to know which frame
// that is in the stream.
int64_t frameIndex = 0;
};

struct FilterGraphContext {
Expand Down
10 changes: 10 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ TORCH_LIBRARY(torchcodec_ns, m) {
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def(
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
Expand Down Expand Up @@ -334,6 +336,13 @@ bool _test_frame_pts_equality(
videoDecoder->getPtsSecondsForFrame(stream_index, frame_index);
}

torch::Tensor _get_key_frame_indices(
at::Tensor& decoder,
int64_t stream_index) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
return videoDecoder->getKeyFrameIndices(stream_index);
}

std::string get_json_metadata(at::Tensor& decoder) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);

Expand Down Expand Up @@ -526,6 +535,7 @@ TORCH_LIBRARY_IMPL(torchcodec_ns, CPU, m) {
m.impl("add_video_stream", &add_video_stream);
m.impl("_add_video_stream", &_add_video_stream);
m.impl("get_next_frame", &get_next_frame);
m.impl("_get_key_frame_indices", &_get_key_frame_indices);
m.impl("get_json_metadata", &get_json_metadata);
m.impl("get_container_json_metadata", &get_container_json_metadata);
m.impl("get_stream_json_metadata", &get_stream_json_metadata);
Expand Down
2 changes: 2 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ bool _test_frame_pts_equality(
int64_t frame_index,
double pts_seconds_to_test);

torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);

// Get the metadata from the video as a string.
std::string get_json_metadata(at::Tensor& decoder);

Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/decoders/_core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from .video_decoder_ops import (
_add_video_stream,
_get_key_frame_indices,
_test_frame_pts_equality,
add_video_stream,
create_from_bytes,
Expand Down
8 changes: 8 additions & 0 deletions src/torchcodec/decoders/_core/video_decoder_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ def load_torchcodec_extension():
_get_container_json_metadata = (
torch.ops.torchcodec_ns.get_container_json_metadata.default
)
_get_key_frame_indices = torch.ops.torchcodec_ns._get_key_frame_indices.default
scan_all_streams_to_update_metadata = (
torch.ops.torchcodec_ns.scan_all_streams_to_update_metadata.default
)
Expand Down Expand Up @@ -255,6 +256,13 @@ def get_frames_by_pts_in_range_abstract(
)


@register_fake("torchcodec_ns::_get_key_frame_indices")
def get_key_frame_indices_abstract(
decoder: torch.Tensor, *, stream_index: int
) -> torch.Tensor:
return torch.empty([], dtype=torch.int)


@register_fake("torchcodec_ns::get_json_metadata")
def get_json_metadata_abstract(decoder: torch.Tensor) -> str:
return ""
Expand Down
5 changes: 5 additions & 0 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,11 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
f"Unsupported key type: {type(key)}. Supported types are int and slice."
)

def _get_key_frame_indices(self) -> list[int]:
return core._get_key_frame_indices(
self._decoder, stream_index=self.stream_index
)

def get_frame_at(self, index: int) -> Frame:
"""Return a single frame at the given index.

Expand Down
43 changes: 43 additions & 0 deletions test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -831,6 +831,49 @@ def test_get_frames_by_pts_in_range_fails(self, device, seek_mode):
with pytest.raises(ValueError, match="Invalid stop seconds"):
frame = decoder.get_frames_played_in_range(0, 23) # noqa

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_key_frame_indices(self, device):
decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode="exact")
key_frame_indices = decoder._get_key_frame_indices()

# The key frame indices were generated from the following command:
# $ ffprobe -v error -hide_banner -select_streams v:1 -show_frames -of csv test/resources/nasa_13013.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left the line-noise of a shell command on a single line so that it's easier to copy-paste. It does make it harder to read, but making it easy to read means breaking it up over several lines, and then when you go to copy-paste, there's the # comment markers in there. Happy to change it to a better way.

# What it's doing:
# 1. Calling ffprobe on the second video stream, which is absolute stream index 3.
# 2. Showing all frames for that stream.
# 3. Using grep to find the "I" frames, which are the key frames. We also get the line
# number, which is also the count of the rames.
# 4. Using cut to extract just the count for the frame.
# Finally, because the above produces a count, which is index + 1, we subtract
# one from all values manually to arrive at the values below.
Comment on lines +846 to +848
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suppose you meant the same thing, but my understanding is that -n will output the line number (not the count), and the cut part will select that line number (which is 1-based)?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I meant the same thing. I can try to clarify.

# TODO: decide if/how we want to incorporate key frame indices into the utils
# framework.
nasa_reference_key_frame_indices = torch.tensor([0, 240])

torch.testing.assert_close(
key_frame_indices, nasa_reference_key_frame_indices, atol=0, rtol=0
)

decoder = VideoDecoder(AV1_VIDEO.path, device=device, seek_mode="exact")
key_frame_indices = decoder._get_key_frame_indices()

# $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/av1_video.mkv | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
av1_reference_key_frame_indices = torch.tensor([0])

torch.testing.assert_close(
key_frame_indices, av1_reference_key_frame_indices, atol=0, rtol=0
)

decoder = VideoDecoder(H265_VIDEO.path, device=device, seek_mode="exact")
key_frame_indices = decoder._get_key_frame_indices()

# ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of csv test/resources/h265_video.mp4 | grep -n ",I," | cut -d ':' -f 1 > key_frames.txt
h265_reference_key_frame_indices = torch.tensor([0, 2, 4, 6, 8])

torch.testing.assert_close(
key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0
)


if __name__ == "__main__":
pytest.main()
Loading