Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/torchcodec/_samplers/video_clip_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,6 @@ def _get_clips_for_index_based_sampling(
]
frames, *_ = get_frames_at_indices(
video_decoder,
stream_index=metadata_json["bestVideoStreamIndex"],
frame_indices=batch_indexes,
)
clips.append(frames)
Expand Down
29 changes: 9 additions & 20 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,24 +39,23 @@ TORCH_LIBRARY(torchcodec_ns, m) {
m.def(
"get_frame_at_pts(Tensor(a!) decoder, float seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frame_at_index(Tensor(a!) decoder, *, int stream_index, int frame_index) -> (Tensor, Tensor, Tensor)");
"get_frame_at_index(Tensor(a!) decoder, *, int frame_index) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_at_indices(Tensor(a!) decoder, *, int stream_index, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
"get_frames_at_indices(Tensor(a!) decoder, *, int[] frame_indices) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_in_range(Tensor(a!) decoder, *, int stream_index, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
"get_frames_in_range(Tensor(a!) decoder, *, int start, int stop, int? step=None) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, int stream_index, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
"get_frames_by_pts_in_range(Tensor(a!) decoder, *, float start_seconds, float stop_seconds) -> (Tensor, Tensor, Tensor)");
m.def(
"get_frames_by_pts(Tensor(a!) decoder, *, int stream_index, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def(
"_get_key_frame_indices(Tensor(a!) decoder, int stream_index) -> Tensor");
"get_frames_by_pts(Tensor(a!) decoder, *, float[] timestamps) -> (Tensor, Tensor, Tensor)");
m.def("_get_key_frame_indices(Tensor(a!) decoder) -> Tensor");
m.def("get_json_metadata(Tensor(a!) decoder) -> str");
m.def("get_container_json_metadata(Tensor(a!) decoder) -> str");
m.def(
"get_stream_json_metadata(Tensor(a!) decoder, int stream_index) -> str");
m.def("_get_json_ffmpeg_library_versions() -> str");
m.def(
"_test_frame_pts_equality(Tensor(a!) decoder, *, int stream_index, int frame_index, float pts_seconds_to_test) -> bool");
"_test_frame_pts_equality(Tensor(a!) decoder, *, int frame_index, float pts_seconds_to_test) -> bool");
m.def("scan_all_streams_to_update_metadata(Tensor(a!) decoder) -> ()");
}

Expand Down Expand Up @@ -245,18 +244,14 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds) {
return makeOpsFrameOutput(result);
}

OpsFrameOutput get_frame_at_index(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
int64_t frame_index) {
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
auto result = videoDecoder->getFrameAtIndex(frame_index);
return makeOpsFrameOutput(result);
}

OpsFrameBatchOutput get_frames_at_indices(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
at::IntArrayRef frame_indices) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<int64_t> frameIndicesVec(
Expand All @@ -267,7 +262,6 @@ OpsFrameBatchOutput get_frames_at_indices(

OpsFrameBatchOutput get_frames_in_range(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step) {
Expand All @@ -278,7 +272,6 @@ OpsFrameBatchOutput get_frames_in_range(

OpsFrameBatchOutput get_frames_by_pts(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
at::ArrayRef<double> timestamps) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
std::vector<double> timestampsVec(timestamps.begin(), timestamps.end());
Expand All @@ -288,7 +281,6 @@ OpsFrameBatchOutput get_frames_by_pts(

OpsFrameBatchOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
double start_seconds,
double stop_seconds) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
Expand Down Expand Up @@ -321,17 +313,14 @@ std::string mapToJson(const std::map<std::string, std::string>& metadataMap) {

bool _test_frame_pts_equality(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index,
int64_t frame_index,
double pts_seconds_to_test) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
return pts_seconds_to_test ==
videoDecoder->getPtsSecondsForFrame(frame_index);
}

torch::Tensor _get_key_frame_indices(
at::Tensor& decoder,
[[maybe_unused]] int64_t stream_index) {
torch::Tensor _get_key_frame_indices(at::Tensor& decoder) {
auto videoDecoder = unwrapTensorToGetDecoder(decoder);
return videoDecoder->getKeyFrameIndices();
}
Expand Down
14 changes: 3 additions & 11 deletions src/torchcodec/decoders/_core/VideoDecoderOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,10 @@ OpsFrameOutput get_frame_at_pts(at::Tensor& decoder, double seconds);
// Return the frames at given ptss for a given stream
OpsFrameBatchOutput get_frames_by_pts(
at::Tensor& decoder,
int64_t stream_index,
at::ArrayRef<double> timestamps);

// Return the frame that is visible at a given index in the video.
OpsFrameOutput get_frame_at_index(
at::Tensor& decoder,
int64_t stream_index,
int64_t frame_index);
OpsFrameOutput get_frame_at_index(at::Tensor& decoder, int64_t frame_index);

// Get the next frame from the video as a tuple that has the frame data, pts and
// duration as tensors.
Expand All @@ -101,14 +97,12 @@ OpsFrameOutput get_next_frame(at::Tensor& decoder);
// Return the frames at given indices for a given stream
OpsFrameBatchOutput get_frames_at_indices(
at::Tensor& decoder,
int64_t stream_index,
at::IntArrayRef frame_indices);

// Return the frames inside a range as a single stacked Tensor. The range is
// defined as [start, stop).
OpsFrameBatchOutput get_frames_in_range(
at::Tensor& decoder,
int64_t stream_index,
int64_t start,
int64_t stop,
std::optional<int64_t> step = std::nullopt);
Expand All @@ -118,7 +112,6 @@ OpsFrameBatchOutput get_frames_in_range(
// order.
OpsFrameBatchOutput get_frames_by_pts_in_range(
at::Tensor& decoder,
int64_t stream_index,
double start_seconds,
double stop_seconds);

Expand All @@ -128,16 +121,15 @@ OpsFrameBatchOutput get_frames_by_pts_in_range(
// We want to make sure that the value is preserved exactly, bit-for-bit, during
// this process.
//
// Returns true if for the given decoder, in the stream stream_index, the pts
// Returns true if for the given decoder, the pts
// value when converted to seconds as a double is exactly pts_seconds_to_test.
// Returns false otherwise.
bool _test_frame_pts_equality(
at::Tensor& decoder,
int64_t stream_index,
int64_t frame_index,
double pts_seconds_to_test);

torch::Tensor _get_key_frame_indices(at::Tensor& decoder, int64_t stream_index);
torch::Tensor _get_key_frame_indices(at::Tensor& decoder);

// Get the metadata from the video as a string.
std::string get_json_metadata(at::Tensor& decoder);
Expand Down
17 changes: 5 additions & 12 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,9 +152,7 @@ def _getitem_int(self, key: int) -> Tensor:
f"Index {key} is out of bounds; length is {self._num_frames}"
)

frame_data, *_ = core.get_frame_at_index(
self._decoder, frame_index=key, stream_index=self.stream_index
)
frame_data, *_ = core.get_frame_at_index(self._decoder, frame_index=key)
return frame_data

def _getitem_slice(self, key: slice) -> Tensor:
Expand All @@ -163,7 +161,6 @@ def _getitem_slice(self, key: slice) -> Tensor:
start, stop, step = key.indices(len(self))
frame_data, *_ = core.get_frames_in_range(
self._decoder,
stream_index=self.stream_index,
start=start,
stop=stop,
step=step,
Expand All @@ -189,9 +186,7 @@ def __getitem__(self, key: Union[numbers.Integral, slice]) -> Tensor:
)

def _get_key_frame_indices(self) -> list[int]:
return core._get_key_frame_indices(
self._decoder, stream_index=self.stream_index
)
return core._get_key_frame_indices(self._decoder)

def get_frame_at(self, index: int) -> Frame:
"""Return a single frame at the given index.
Expand All @@ -208,7 +203,7 @@ def get_frame_at(self, index: int) -> Frame:
f"Index {index} is out of bounds; must be in the range [0, {self._num_frames})."
)
data, pts_seconds, duration_seconds = core.get_frame_at_index(
self._decoder, frame_index=index, stream_index=self.stream_index
self._decoder, frame_index=index
)
return Frame(
data=data,
Expand All @@ -234,7 +229,7 @@ def get_frames_at(self, indices: list[int]) -> FrameBatch:
"""

data, pts_seconds, duration_seconds = core.get_frames_at_indices(
self._decoder, stream_index=self.stream_index, frame_indices=indices
self._decoder, frame_indices=indices
)
return FrameBatch(
data=data,
Expand Down Expand Up @@ -268,7 +263,6 @@ def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatc
raise IndexError(f"Step ({step}) must be greater than 0.")
frames = core.get_frames_in_range(
self._decoder,
stream_index=self.stream_index,
start=start,
stop=stop,
step=step,
Expand Down Expand Up @@ -316,7 +310,7 @@ def get_frames_played_at(self, seconds: list[float]) -> FrameBatch:
FrameBatch: The frames that are played at ``seconds``.
"""
data, pts_seconds, duration_seconds = core.get_frames_by_pts(
self._decoder, timestamps=seconds, stream_index=self.stream_index
self._decoder, timestamps=seconds
)
return FrameBatch(
data=data,
Expand Down Expand Up @@ -359,7 +353,6 @@ def get_frames_played_in_range(
)
frames = core.get_frames_by_pts_in_range(
self._decoder,
stream_index=self.stream_index,
start_seconds=start_seconds,
stop_seconds=stop_seconds,
)
Expand Down
4 changes: 1 addition & 3 deletions test/decoders/manual_smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,5 @@
)
torchcodec.decoders._core.scan_all_streams_to_update_metadata(decoder)
torchcodec.decoders._core.add_video_stream(decoder, stream_index=3)
frame, _, _ = torchcodec.decoders._core.get_frame_at_index(
decoder, stream_index=3, frame_index=180
)
frame, _, _ = torchcodec.decoders._core.get_frame_at_index(decoder, frame_index=180)
write_png(frame, "frame180.png")
Loading
Loading