diff --git a/src/torchcodec/decoders/_audio_decoder.py b/src/torchcodec/decoders/_audio_decoder.py index cac968e5d..d0e7ede00 100644 --- a/src/torchcodec/decoders/_audio_decoder.py +++ b/src/torchcodec/decoders/_audio_decoder.py @@ -13,7 +13,7 @@ from torchcodec.decoders import _core as core from torchcodec.decoders._decoder_utils import ( create_decoder, - get_and_validate_stream_metadata, + ERROR_REPORTING_INSTRUCTIONS, ) @@ -57,15 +57,20 @@ def __init__( self._decoder, stream_index=stream_index, sample_rate=sample_rate ) - ( - self.metadata, - self.stream_index, - self._begin_stream_seconds, - self._end_stream_seconds, - ) = get_and_validate_stream_metadata( - decoder=self._decoder, stream_index=stream_index, media_type="audio" + container_metadata = core.get_container_metadata(self._decoder) + self.stream_index = ( + container_metadata.best_audio_stream_index + if stream_index is None + else stream_index ) + if self.stream_index is None: + raise ValueError( + "The best audio stream is unknown and there is no specified stream. " + + ERROR_REPORTING_INSTRUCTIONS + ) + self.metadata = container_metadata.streams[self.stream_index] assert isinstance(self.metadata, core.AudioStreamMetadata) # mypy + self._desired_sample_rate = ( sample_rate if sample_rate is not None else self.metadata.sample_rate ) @@ -90,12 +95,6 @@ def get_samples_played_in_range( raise ValueError( f"Invalid start seconds: {start_seconds}. It must be less than or equal to stop seconds ({stop_seconds})." ) - if not self._begin_stream_seconds <= start_seconds < self._end_stream_seconds: - raise ValueError( - f"Invalid start seconds: {start_seconds}. " - f"It must be greater than or equal to {self._begin_stream_seconds} " - f"and less than or equal to {self._end_stream_seconds}." 
- ) frames, first_pts = core.get_frames_by_pts_in_range_audio( self._decoder, start_seconds=start_seconds, diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index c435ea72d..80b6fa7c5 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -147,6 +147,10 @@ void VideoDecoder::initializeDecoder() { streamMetadata.durationSeconds = av_q2d(avStream->time_base) * avStream->duration; } + if (avStream->start_time != AV_NOPTS_VALUE) { + streamMetadata.beginStreamFromHeader = + av_q2d(avStream->time_base) * avStream->start_time; + } if (avStream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO) { double fps = av_q2d(avStream->r_frame_rate); @@ -944,8 +948,9 @@ VideoDecoder::AudioFramesOutput VideoDecoder::getFramesPlayedInRangeAudio( TORCH_CHECK( frames.size() > 0 && firstFramePtsSeconds.has_value(), "No audio frames were decoded. ", - "This should probably not happen. ", - "Please report an issue on the TorchCodec repo."); + "This is probably because start_seconds is too high? 
", + "Current value is ", + startSeconds); return AudioFramesOutput{torch::cat(frames, 1), *firstFramePtsSeconds}; } diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 6c727ff6b..c480ed3ea 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -59,6 +59,7 @@ class VideoDecoder { std::optional codecId; std::optional codecName; std::optional durationSeconds; + std::optional beginStreamFromHeader; std::optional numFrames; std::optional numKeyFrames; std::optional averageFps; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index adbed7cae..786d3f327 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -473,6 +473,10 @@ std::string get_stream_json_metadata( if (streamMetadata.numFrames.has_value()) { map["numFrames"] = std::to_string(*streamMetadata.numFrames); } + if (streamMetadata.beginStreamFromHeader.has_value()) { + map["beginStreamFromHeader"] = + std::to_string(*streamMetadata.beginStreamFromHeader); + } if (streamMetadata.minPtsSecondsFromScan.has_value()) { map["minPtsSecondsFromScan"] = std::to_string(*streamMetadata.minPtsSecondsFromScan); diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index 806526370..6ecc8e33e 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -32,27 +32,59 @@ class StreamMetadata: duration_seconds_from_header: Optional[float] """Duration of the stream, in seconds, obtained from the header (float or None). This could be inaccurate.""" + begin_stream_seconds_from_header: Optional[float] + """Beginning of the stream, in seconds, obtained from the header (float or + None). 
Usually, this is equal to 0.""" bit_rate: Optional[float] """Bit rate of the stream, in seconds (float or None).""" + codec: Optional[str] + """Codec (str or None).""" + stream_index: int + """Index of the stream within the video (int).""" + + def __repr__(self): + s = self.__class__.__name__ + ":\n" + for field in dataclasses.fields(self): + s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n" + return s + + +@dataclass +class VideoStreamMetadata(StreamMetadata): + """Metadata of a single video stream.""" + begin_stream_seconds_from_content: Optional[float] """Beginning of the stream, in seconds (float or None). - Conceptually, this corresponds to the first frame's :term:`pts`. It is - computed as min(frame.pts) across all frames in the stream. Usually, this is - equal to 0.""" + Conceptually, this corresponds to the first frame's :term:`pts`. It is only + computed when a :term:`scan` is done as min(frame.pts) across all frames in + the stream. Usually, this is equal to 0.""" end_stream_seconds_from_content: Optional[float] """End of the stream, in seconds (float or None). Conceptually, this corresponds to last_frame.pts + last_frame.duration. It - is computed as max(frame.pts + frame.duration) across all frames in the - stream. Note that no frame is played at this time value, so calling - :meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with - this value would result in an error. Retrieving the last frame is best done - by simply indexing the :class:`~torchcodec.decoders.VideoDecoder` - object with ``[-1]``. + is only computed when a :term:`scan` is done as max(frame.pts + + frame.duration) across all frames in the stream. Note that no frame is + played at this time value, so calling + :meth:`~torchcodec.decoders.VideoDecoder.get_frame_played_at` with this + value would result in an error. Retrieving the last frame is best done by + simply indexing the :class:`~torchcodec.decoders.VideoDecoder` object with + ``[-1]``. 
""" - codec: Optional[str] - """Codec (str or None).""" - stream_index: int - """Index of the stream within the video (int).""" + width: Optional[int] + """Width of the frames (int or None).""" + height: Optional[int] + """Height of the frames (int or None).""" + num_frames_from_header: Optional[int] + """Number of frames, from the stream's metadata. This is potentially + inaccurate. We recommend using the ``num_frames`` attribute instead. + (int or None).""" + num_frames_from_content: Optional[int] + """Number of frames computed by TorchCodec by scanning the stream's + content (the scan doesn't involve decoding). This is more accurate + than ``num_frames_from_header``. We recommend using the + ``num_frames`` attribute instead. (int or None).""" + average_fps_from_header: Optional[float] + """Average fps of the stream, obtained from the header (float or None). + We recommend using the ``average_fps`` attribute instead.""" @property def duration_seconds(self) -> Optional[float]: @@ -94,36 +126,6 @@ def end_stream_seconds(self) -> Optional[float]: else: return self.end_stream_seconds_from_content - def __repr__(self): - # Overridden because properites are not printed by default. - s = self.__class__.__name__ + ":\n" - s += f"{SPACES}duration_seconds: {self.duration_seconds}\n" - for field in dataclasses.fields(self): - s += f"{SPACES}{field.name}: {getattr(self, field.name)}\n" - return s - - -@dataclass -class VideoStreamMetadata(StreamMetadata): - """Metadata of a single video stream.""" - - width: Optional[int] - """Width of the frames (int or None).""" - height: Optional[int] - """Height of the frames (int or None).""" - num_frames_from_header: Optional[int] - """Number of frames, from the stream's metadata. This is potentially - inaccurate. We recommend using the ``num_frames`` attribute instead. 
- (int or None).""" - num_frames_from_content: Optional[int] - """Number of frames computed by TorchCodec by scanning the stream's - content (the scan doesn't involve decoding). This is more accurate - than ``num_frames_from_header``. We recommend using the - ``num_frames`` attribute instead. (int or None).""" - average_fps_from_header: Optional[float] - """Averate fps of the stream, obtained from the header (float or None). - We recommend using the ``average_fps`` attribute instead.""" - @property def num_frames(self) -> Optional[int]: """Number of frames in the stream. This corresponds to @@ -154,6 +156,9 @@ def average_fps(self) -> Optional[float]: def __repr__(self): s = super().__repr__() + s += f"{SPACES}duration_seconds: {self.duration_seconds}\n" + s += f"{SPACES}begin_stream_seconds: {self.begin_stream_seconds}\n" + s += f"{SPACES}end_stream_seconds: {self.end_stream_seconds}\n" s += f"{SPACES}num_frames: {self.num_frames}\n" s += f"{SPACES}average_fps: {self.average_fps}\n" return s @@ -224,14 +229,19 @@ def get_container_metadata(decoder: torch.Tensor) -> ContainerMetadata: common_meta = dict( duration_seconds_from_header=stream_dict.get("durationSeconds"), bit_rate=stream_dict.get("bitRate"), - begin_stream_seconds_from_content=stream_dict.get("minPtsSecondsFromScan"), - end_stream_seconds_from_content=stream_dict.get("maxPtsSecondsFromScan"), + begin_stream_seconds_from_header=stream_dict.get("beginStreamFromHeader"), codec=stream_dict.get("codec"), stream_index=stream_index, ) if stream_dict["mediaType"] == "video": streams_metadata.append( VideoStreamMetadata( + begin_stream_seconds_from_content=stream_dict.get( + "minPtsSecondsFromScan" + ), + end_stream_seconds_from_content=stream_dict.get( + "maxPtsSecondsFromScan" + ), width=stream_dict.get("width"), height=stream_dict.get("height"), num_frames_from_header=stream_dict.get("numFrames"), diff --git a/src/torchcodec/decoders/_decoder_utils.py b/src/torchcodec/decoders/_decoder_utils.py index 
d2750dd95..bb882fbf5 100644 --- a/src/torchcodec/decoders/_decoder_utils.py +++ b/src/torchcodec/decoders/_decoder_utils.py @@ -6,7 +6,7 @@ from pathlib import Path -from typing import Optional, Tuple, Union +from typing import Union from torch import Tensor from torchcodec.decoders import _core as core @@ -33,55 +33,3 @@ def create_decoder( f"Unknown source type: {type(source)}. " "Supported types are str, Path, bytes and Tensor." ) - - -def get_and_validate_stream_metadata( - *, - decoder: Tensor, - stream_index: Optional[int] = None, - media_type: str, -) -> Tuple[core._metadata.StreamMetadata, int, float, float]: - - if media_type not in ("video", "audio"): - raise ValueError(f"Bad {media_type = }, should be audio or video") - - container_metadata = core.get_container_metadata(decoder) - - if stream_index is None: - best_stream_index = ( - container_metadata.best_video_stream_index - if media_type == "video" - else container_metadata.best_audio_stream_index - ) - if best_stream_index is None: - raise ValueError( - f"The best {media_type} stream is unknown and there is no specified stream. " - + ERROR_REPORTING_INSTRUCTIONS - ) - stream_index = best_stream_index - - # This should be logically true because of the above conditions, but type checker - # is not clever enough. - assert stream_index is not None - - metadata = container_metadata.streams[stream_index] - - if metadata.begin_stream_seconds is None: - raise ValueError( - "The minimum pts value in seconds is unknown. " - + ERROR_REPORTING_INSTRUCTIONS - ) - begin_stream_seconds = metadata.begin_stream_seconds - - if metadata.end_stream_seconds is None: - raise ValueError( - "The maximum pts value in seconds is unknown. 
" - + ERROR_REPORTING_INSTRUCTIONS - ) - end_stream_seconds = metadata.end_stream_seconds - return ( - metadata, - stream_index, - begin_stream_seconds, - end_stream_seconds, - ) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 6c7db62ce..081f332b4 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -6,7 +6,7 @@ import numbers from pathlib import Path -from typing import Literal, Optional, Union +from typing import Literal, Optional, Tuple, Union from torch import device, Tensor @@ -15,7 +15,6 @@ from torchcodec.decoders._decoder_utils import ( create_decoder, ERROR_REPORTING_INSTRUCTIONS, - get_and_validate_stream_metadata, ) @@ -108,18 +107,11 @@ def __init__( self.stream_index, self._begin_stream_seconds, self._end_stream_seconds, - ) = get_and_validate_stream_metadata( - decoder=self._decoder, stream_index=stream_index, media_type="video" + self._num_frames, + ) = _get_and_validate_stream_metadata( + decoder=self._decoder, stream_index=stream_index ) - assert isinstance(self.metadata, core.VideoStreamMetadata) # mypy - - if self.metadata.num_frames is None: - raise ValueError( - "The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS - ) - self._num_frames = self.metadata.num_frames - def __len__(self) -> int: return self._num_frames @@ -338,3 +330,50 @@ def get_frames_played_in_range( stop_seconds=stop_seconds, ) return FrameBatch(*frames) + + +def _get_and_validate_stream_metadata( + *, + decoder: Tensor, + stream_index: Optional[int] = None, +) -> Tuple[core._metadata.VideoStreamMetadata, int, float, float, int]: + + container_metadata = core.get_container_metadata(decoder) + + if stream_index is None: + if (stream_index := container_metadata.best_video_stream_index) is None: + raise ValueError( + "The best video stream is unknown and there is no specified stream. 
" + + ERROR_REPORTING_INSTRUCTIONS + ) + + metadata = container_metadata.streams[stream_index] + assert isinstance(metadata, core._metadata.VideoStreamMetadata) # mypy + + if metadata.begin_stream_seconds is None: + raise ValueError( + "The minimum pts value in seconds is unknown. " + + ERROR_REPORTING_INSTRUCTIONS + ) + begin_stream_seconds = metadata.begin_stream_seconds + + if metadata.end_stream_seconds is None: + raise ValueError( + "The maximum pts value in seconds is unknown. " + + ERROR_REPORTING_INSTRUCTIONS + ) + end_stream_seconds = metadata.end_stream_seconds + + if metadata.num_frames is None: + raise ValueError( + "The number of frames is unknown. " + ERROR_REPORTING_INSTRUCTIONS + ) + num_frames = metadata.num_frames + + return ( + metadata, + stream_index, + begin_stream_seconds, + end_stream_seconds, + num_frames, + ) diff --git a/test/decoders/test_decoders.py b/test/decoders/test_decoders.py index 0442e1df7..e8dfc675c 100644 --- a/test/decoders/test_decoders.py +++ b/test/decoders/test_decoders.py @@ -955,7 +955,7 @@ def test_metadata(self, asset): == decoder.metadata.stream_index == asset.default_stream_index ) - assert decoder.metadata.duration_seconds == pytest.approx( + assert decoder.metadata.duration_seconds_from_header == pytest.approx( asset.duration_seconds ) assert decoder.metadata.sample_rate == asset.sample_rate @@ -967,13 +967,18 @@ def test_error(self, asset): decoder = AudioDecoder(asset.path) with pytest.raises(ValueError, match="Invalid start seconds"): - decoder.get_samples_played_in_range(start_seconds=-1300) + decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2) - with pytest.raises(ValueError, match="Invalid start seconds"): + with pytest.raises(RuntimeError, match="No audio frames were decoded"): decoder.get_samples_played_in_range(start_seconds=9999) - with pytest.raises(ValueError, match="Invalid start seconds"): - decoder.get_samples_played_in_range(start_seconds=3, stop_seconds=2) + 
@pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) + def test_negative_start(self, asset): + decoder = AudioDecoder(asset.path) + samples = decoder.get_samples_played_in_range(start_seconds=-1300) + reference_samples = decoder.get_samples_played_in_range() + torch.testing.assert_close(samples.data, reference_samples.data) + assert samples.pts_seconds == reference_samples.pts_seconds @pytest.mark.parametrize("asset", (NASA_AUDIO, NASA_AUDIO_MP3)) @pytest.mark.parametrize("stop_seconds", (None, "duration", 99999999)) diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index dec986035..e4e64d6bd 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -7,6 +7,7 @@ import functools import pytest +from torchcodec.decoders import AudioDecoder, VideoDecoder from torchcodec.decoders._core import ( AudioStreamMetadata, @@ -20,6 +21,10 @@ from ..utils import NASA_AUDIO_MP3, NASA_VIDEO +# TODO: Expected values in these tests should be based on the assets' +# attributes rather than on hard-coded values. 
+ + def _get_container_metadata(path, seek_mode): decoder = create_from_file(str(path), seek_mode=seek_mode) return get_container_metadata(decoder) @@ -73,6 +78,7 @@ def test_get_metadata(metadata_getter): assert best_video_stream_metadata.duration_seconds == pytest.approx( 13.013, abs=0.001 ) + assert best_video_stream_metadata.begin_stream_seconds_from_header == 0 assert best_video_stream_metadata.bit_rate == 128783 assert best_video_stream_metadata.average_fps == pytest.approx(29.97, abs=0.001) assert best_video_stream_metadata.codec == "h264" @@ -85,9 +91,8 @@ def test_get_metadata(metadata_getter): best_audio_stream_metadata = metadata.streams[metadata.best_audio_stream_index] assert isinstance(best_audio_stream_metadata, AudioStreamMetadata) assert best_audio_stream_metadata is metadata.best_audio_stream - assert best_audio_stream_metadata.duration_seconds == pytest.approx( - 13.056, abs=0.001 - ) + assert best_audio_stream_metadata.duration_seconds_from_header == 13.056 + assert best_audio_stream_metadata.begin_stream_seconds_from_header == 0 assert best_audio_stream_metadata.bit_rate == 128837 assert best_audio_stream_metadata.codec == "aac" assert best_audio_stream_metadata.sample_format == "fltp" @@ -105,9 +110,8 @@ def test_get_metadata_audio_file(metadata_getter): best_audio_stream_metadata = metadata.streams[metadata.best_audio_stream_index] assert isinstance(best_audio_stream_metadata, AudioStreamMetadata) assert best_audio_stream_metadata is metadata.best_audio_stream - assert best_audio_stream_metadata.duration_seconds == pytest.approx( - 13.248, abs=0.001 - ) + assert best_audio_stream_metadata.duration_seconds_from_header == 13.248 + assert best_audio_stream_metadata.begin_stream_seconds_from_header == 0.138125 assert best_audio_stream_metadata.bit_rate == 64000 assert best_audio_stream_metadata.codec == "mp3" assert best_audio_stream_metadata.sample_format == "fltp" @@ -126,6 +130,7 @@ def test_num_frames_fallback( bit_rate=123, 
num_frames_from_header=num_frames_from_header, num_frames_from_content=num_frames_from_content, + begin_stream_seconds_from_header=0, begin_stream_seconds_from_content=0, end_stream_seconds_from_content=4, codec="whatever", @@ -136,3 +141,44 @@ def test_num_frames_fallback( ) assert metadata.num_frames == expected_num_frames + + +def test_repr(): + # Test for calls to print(), str(), etc. Useful to make sure we don't forget + # to add additional @properties to __repr__ + assert ( + str(VideoDecoder(NASA_VIDEO.path).metadata) + == """VideoStreamMetadata: + duration_seconds_from_header: 13.013 + begin_stream_seconds_from_header: 0.0 + bit_rate: 128783.0 + codec: h264 + stream_index: 3 + begin_stream_seconds_from_content: 0.0 + end_stream_seconds_from_content: 13.013 + width: 480 + height: 270 + num_frames_from_header: 390 + num_frames_from_content: 390 + average_fps_from_header: 29.97003 + duration_seconds: 13.013 + begin_stream_seconds: 0.0 + end_stream_seconds: 13.013 + num_frames: 390 + average_fps: 29.97002997002997 +""" + ) + + assert ( + str(AudioDecoder(NASA_AUDIO_MP3.path).metadata) + == """AudioStreamMetadata: + duration_seconds_from_header: 13.248 + begin_stream_seconds_from_header: 0.138125 + bit_rate: 64000.0 + codec: mp3 + stream_index: 0 + sample_rate: 8000 + num_channels: 2 + sample_format: fltp +""" + )