Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 13 additions & 14 deletions src/torchcodec/decoders/_audio_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from torchcodec.decoders import _core as core
from torchcodec.decoders._decoder_utils import (
create_decoder,
get_and_validate_stream_metadata,
ERROR_REPORTING_INSTRUCTIONS,
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changes below: we previously had a common util for audio and video that extracted and validated the metadata: get_and_validate_stream_metadata(). Since we moved a bunch of fields as video-only, this util wasn't generic enough anymore to justify its existence, hence a few edits here and in the video_decoder.py file.

)


Expand Down Expand Up @@ -57,15 +57,20 @@ def __init__(
self._decoder, stream_index=stream_index, sample_rate=sample_rate
)

(
self.metadata,
self.stream_index,
self._begin_stream_seconds,
self._end_stream_seconds,
) = get_and_validate_stream_metadata(
decoder=self._decoder, stream_index=stream_index, media_type="audio"
container_metadata = core.get_container_metadata(self._decoder)
self.stream_index = (
container_metadata.best_audio_stream_index
if stream_index is None
else stream_index
)
if self.stream_index is None:
raise ValueError(
"The best audio stream is unknown and there is no specified stream. "
+ ERROR_REPORTING_INSTRUCTIONS
)
self.metadata = container_metadata.streams[self.stream_index]
assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy

self._desired_sample_rate = (
sample_rate if sample_rate is not None else self.metadata.sample_rate
)
Expand All @@ -90,12 +95,6 @@ def get_samples_played_in_range(
raise ValueError(
f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})."
)
if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds:
raise ValueError(
f"Invalid start seconds: {start_seconds}. "
f"It must be greater than or equal to {self._begin_stream_seconds} "
f"and less than or equal to {self._end_stream_seconds}."
)
frames, first_pts = core.get_frames_by_pts_in_range_audio(
self._decoder,
start_seconds=start_seconds,
Expand Down
9 changes: 7 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ void VideoDecoder::initializeDecoder() {
streamMetadata.durationSeconds =
av_q2d(avStream->time_base) * avStream->duration;
}
if (avStream->start_time != AV_NOPTS_VALUE) {
streamMetadata.beginStreamFromHeader =
av_q2d(avStream->time_base) * avStream->start_time;
}

if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) {
double fps = av_q2d(avStream->r_frame_rate);
Expand Down Expand Up @@ -944,8 +948,9 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio(
TORCH_CHECK(
frames.size() > 0 && firstFramePtsSeconds.has_value(),
"No audio frames were decoded. ",
"This should probably not happen. ",
"Please report an issue on the TorchCodec repo.");
"This is probably because start_seconds is too high? ",
"Current value is ",
startSeconds);

return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds};
}
Expand Down
1 change: 1 addition & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class VideoDecoder {
std::optional<AVCodecID> codecId;
std::optional<std::string> codecName;
std::optional<double> durationSeconds;
std::optional<double> beginStreamFromHeader;
std::optional<int64_t> numFrames;
std::optional<int64_t> numKeyFrames;
std::optional<double> averageFps;
Expand Down
4 changes: 4 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoderOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ std::string get_stream_json_metadata(
if (streamMetadata.numFrames.has_value()) {
map["numFrames"] = std::to_string(*streamMetadata.numFrames);
}
if (streamMetadata.beginStreamFromHeader.has_value()) {
map["beginStreamFromHeader"] =
std::to_string(*streamMetadata.beginStreamFromHeader);
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This reminded me to create Issue #593.

if (streamMetadata.minPtsSecondsFromScan.has_value()) {
map["minPtsSecondsFromScan"] =
std::to_string(*streamMetadata.minPtsSecondsFromScan);
Expand Down
100 changes: 55 additions & 45 deletions src/torchcodec/decoders/_core/_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,27 +32,59 @@ class StreamMetadata:
duration_seconds_from_header: Optional[float]
"""Duration of the stream, in seconds, obtained from the header (float or
None). This could be inaccurate."""
begin_stream_seconds_from_header: Optional[float]
"""Beginning of the stream, in seconds, obtained from the header (float or
None). Usually, this is equal to 0."""
bit_rate: Optional[float]
"""Bit rate of the stream, in seconds (float or None)."""
codec: Optional[str]
"""Codec (str or None)."""
stream_index: int
"""Index of the stream within the video (int)."""

def __repr__(self):
s = self.__class__.__name__ + ":\n"
for field in dataclasses.fields(self):
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
return s


@dataclass
class VideoStreamMetadata(StreamMetadata):
"""Metadata of a single video stream."""

begin_stream_seconds_from_content: Optional[float]
"""Beginning of the stream, in seconds (float or None).
Conceptually, this corresponds to the first frame's :term:`pts`. It is
computed as min(frame.pts) across all frames in the stream. Usually, this is
equal to 0."""
Conceptually, this corresponds to the first frame's :term:`pts`. It is only
computed when a :term:`scan` is done as min(frame.pts) across all frames in
the stream. Usually, this is equal to 0."""
end_stream_seconds_from_content: Optional[float]
"""End of the stream, in seconds (float or None).
Conceptually, this corresponds to last_frame.pts + last_frame.duration. It
is computed as max(frame.pts + frame.duration) across all frames in the
stream. Note that no frame is played at this time value, so calling
:meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with
this value would result in an error. Retrieving the last frame is best done
by simply indexing the :class:`~torchcodec.decoders.VideoDecoder`
object with ``[-1]``.
is only computed when a :term:`scan` is done as max(frame.pts +
frame.duration) across all frames in the stream. Note that no frame is
played at this time value, so calling
:meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with this
value would result in an error. Retrieving the last frame is best done by
simply indexing the :class:`~torchcodec.decoders.VideoDecoder` object with
``[-1]``.
"""
codec: Optional[str]
"""Codec (str or None)."""
stream_index: int
"""Index of the stream within the video (int)."""
width: Optional[int]
"""Width of the frames (int or None)."""
height: Optional[int]
"""Height of the frames (int or None)."""
num_frames_from_header: Optional[int]
"""Number of frames, from the stream's metadata. This is potentially
inaccurate. We recommend using the ``num_frames`` attribute instead.
(int or None)."""
num_frames_from_content: Optional[int]
"""Number of frames computed by TorchCodec by scanning the stream's
content (the scan doesn't involve decoding). This is more accurate
than ``num_frames_from_header``. We recommend using the
``num_frames`` attribute instead. (int or None)."""
average_fps_from_header: Optional[float]
"""Averate fps of the stream, obtained from the header (float or None).
We recommend using the ``average_fps`` attribute instead."""

@property
def duration_seconds(self) -> Optional[float]:
Expand Down Expand Up @@ -94,36 +126,6 @@ def end_stream_seconds(self) -> Optional[float]:
else:
return self.end_stream_seconds_from_content

def __repr__(self):
# Overridden because properites are not printed by default.
s = self.__class__.__name__ + ":\n"
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
for field in dataclasses.fields(self):
s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n"
return s


@dataclass
class VideoStreamMetadata(StreamMetadata):
"""Metadata of a single video stream."""

width: Optional[int]
"""Width of the frames (int or None)."""
height: Optional[int]
"""Height of the frames (int or None)."""
num_frames_from_header: Optional[int]
"""Number of frames, from the stream's metadata. This is potentially
inaccurate. We recommend using the ``num_frames`` attribute instead.
(int or None)."""
num_frames_from_content: Optional[int]
"""Number of frames computed by TorchCodec by scanning the stream's
content (the scan doesn't involve decoding). This is more accurate
than ``num_frames_from_header``. We recommend using the
``num_frames`` attribute instead. (int or None)."""
average_fps_from_header: Optional[float]
"""Averate fps of the stream, obtained from the header (float or None).
We recommend using the ``average_fps`` attribute instead."""

@property
def num_frames(self) -> Optional[int]:
"""Number of frames in the stream. This corresponds to
Expand Down Expand Up @@ -154,6 +156,9 @@ def average_fps(self) -> Optional[float]:

def __repr__(self):
s = super().__repr__()
s += f"{SPACES}duration_seconds: {self.duration_seconds}\n"
s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n"
s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n"
s += f"{SPACES}num_frames: {self.num_frames}\n"
s += f"{SPACES}average_fps: {self.average_fps}\n"
return s
Expand Down Expand Up @@ -224,14 +229,19 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata:
common_meta = dict(
duration_seconds_from_header=stream_dict.get("durationSeconds"),
bit_rate=stream_dict.get("bitRate"),
begin_stream_seconds_from_content=stream_dict.get("minPtsSecondsFromScan"),
end_stream_seconds_from_content=stream_dict.get("maxPtsSecondsFromScan"),
begin_stream_seconds_from_header=stream_dict.get("beginStreamFromHeader"),
codec=stream_dict.get("codec"),
stream_index=stream_index,
)
if stream_dict["mediaType"] == "video":
streams_metadata.append(
VideoStreamMetadata(
begin_stream_seconds_from_content=stream_dict.get(
"minPtsSecondsFromScan"
),
end_stream_seconds_from_content=stream_dict.get(
"maxPtsSecondsFromScan"
),
width=stream_dict.get("width"),
height=stream_dict.get("height"),
num_frames_from_header=stream_dict.get("numFrames"),
Expand Down
54 changes: 1 addition & 53 deletions src/torchcodec/decoders/_decoder_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from pathlib import Path

from typing import Optional, Tuple, Union
from typing import Union

from torch import Tensor
from torchcodec.decoders import _core as core
Expand All @@ -33,55 +33,3 @@ def create_decoder(
f"Unknown source type: {type(source)}. "
"Supported types are str, Path, bytes and Tensor."
)


def get_and_validate_stream_metadata(
*,
decoder: Tensor,
stream_index: Optional[int] = None,
media_type: str,
) -> Tuple[core._metadata.StreamMetadata, int, float, float]:

if media_type not in ("video", "audio"):
raise ValueError(f"Bad {media_type = }, should be audio or video")

container_metadata = core.get_container_metadata(decoder)

if stream_index is None:
best_stream_index = (
container_metadata.best_video_stream_index
if media_type == "video"
else container_metadata.best_audio_stream_index
)
if best_stream_index is None:
raise ValueError(
f"The best {media_type} stream is unknown and there is no specified stream. "
+ ERROR_REPORTING_INSTRUCTIONS
)
stream_index = best_stream_index

# This should be logically true because of the above conditions, but type checker
# is not clever enough.
assert stream_index is not None

metadata = container_metadata.streams[stream_index]

if metadata.begin_stream_seconds is None:
raise ValueError(
"The minimum pts value in seconds is unknown. "
+ ERROR_REPORTING_INSTRUCTIONS
)
begin_stream_seconds = metadata.begin_stream_seconds

if metadata.end_stream_seconds is None:
raise ValueError(
"The maximum pts value in seconds is unknown. "
+ ERROR_REPORTING_INSTRUCTIONS
)
end_stream_seconds = metadata.end_stream_seconds
return (
metadata,
stream_index,
begin_stream_seconds,
end_stream_seconds,
)
63 changes: 51 additions & 12 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

import numbers
from pathlib import Path
from typing import Literal, Optional, Union
from typing import Literal, Optional, Tuple, Union

from torch import device, Tensor

Expand All @@ -15,7 +15,6 @@
from torchcodec.decoders._decoder_utils import (
create_decoder,
ERROR_REPORTING_INSTRUCTIONS,
get_and_validate_stream_metadata,
)


Expand Down Expand Up @@ -108,18 +107,11 @@ def __init__(
self.stream_index,
self._begin_stream_seconds,
self._end_stream_seconds,
) = get_and_validate_stream_metadata(
decoder=self._decoder, stream_index=stream_index, media_type="video"
self._num_frames,
) = _get_and_validate_stream_metadata(
decoder=self._decoder, stream_index=stream_index
)

assert isinstance(self.metadata, core.VideoStreamMetadata) # mypy

if self.metadata.num_frames is None:
raise ValueError(
"The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
)
self._num_frames = self.metadata.num_frames

def __len__(self) -> int:
return self._num_frames

Expand Down Expand Up @@ -338,3 +330,50 @@ def get_frames_played_in_range(
stop_seconds=stop_seconds,
)
return FrameBatch(*frames)


def _get_and_validate_stream_metadata(
*,
decoder: Tensor,
stream_index: Optional[int] = None,
) -> Tuple[core._metadata.VideoStreamMetadata, int, float, float, int]:

container_metadata = core.get_container_metadata(decoder)

if stream_index is None:
if (stream_index := container_metadata.best_video_stream_index) is None:
raise ValueError(
"The best video stream is unknown and there is no specified stream. "
+ ERROR_REPORTING_INSTRUCTIONS
)

metadata = container_metadata.streams[stream_index]
assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy

if metadata.begin_stream_seconds is None:
raise ValueError(
"The minimum pts value in seconds is unknown. "
+ ERROR_REPORTING_INSTRUCTIONS
)
begin_stream_seconds = metadata.begin_stream_seconds

if metadata.end_stream_seconds is None:
raise ValueError(
"The maximum pts value in seconds is unknown. "
+ ERROR_REPORTING_INSTRUCTIONS
)
end_stream_seconds = metadata.end_stream_seconds

if metadata.num_frames is None:
raise ValueError(
"The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS
)
num_frames = metadata.num_frames

return (
metadata,
stream_index,
begin_stream_seconds,
end_stream_seconds,
num_frames,
)
Loading
Loading