diff --git a/docs/source/api_ref_decoders.rst b/docs/source/api_ref_decoders.rst index 2daaf80c7..b613ce2f8 100644 --- a/docs/source/api_ref_decoders.rst +++ b/docs/source/api_ref_decoders.rst @@ -20,6 +20,4 @@ torchcodec.decoders :nosignatures: :template: dataclass.rst - Frame - FrameBatch VideoStreamMetadata diff --git a/docs/source/api_ref_torchcodec.rst b/docs/source/api_ref_torchcodec.rst new file mode 100644 index 000000000..36def114f --- /dev/null +++ b/docs/source/api_ref_torchcodec.rst @@ -0,0 +1,16 @@ +.. _torchcodec: + +=================== +torchcodec +=================== + +.. currentmodule:: torchcodec + + +.. autosummary:: + :toctree: generated/ + :nosignatures: + :template: dataclass.rst + + Frame + FrameBatch diff --git a/docs/source/index.rst b/docs/source/index.rst index f91cca315..1ce569f3a 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -45,7 +45,7 @@ We achieve these capabilities through: .. grid-item-card:: :octicon:`file-code;1em` API Reference :img-top: _static/img/card-background.svg - :link: api_ref_decoders.html + :link: api_ref_torchcodec.html :link-type: url The API reference for TorchCodec @@ -73,4 +73,5 @@ We achieve these capabilities through: :caption: API Reference :hidden: + api_ref_torchcodec api_ref_decoders diff --git a/examples/basic_example.py b/examples/basic_example.py index cae0668b3..4df03b8a8 100644 --- a/examples/basic_example.py +++ b/examples/basic_example.py @@ -121,8 +121,8 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): # This can be achieved using the # :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and # :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which -# will return a :class:`~torchcodec.decoders.Frame` and -# :class:`~torchcodec.decoders.FrameBatch` objects respectively. +# will return a :class:`~torchcodec.Frame` and +# :class:`~torchcodec.FrameBatch` objects respectively. 
last_frame = decoder.get_frame_at(len(decoder) - 1) print(f"{type(last_frame) = }") @@ -138,12 +138,12 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): plot(middle_frames.data, "Middle frames") # %% -# Both :class:`~torchcodec.decoders.Frame` and -# :class:`~torchcodec.decoders.FrameBatch` have a ``data`` field, which contains +# Both :class:`~torchcodec.Frame` and +# :class:`~torchcodec.FrameBatch` have a ``data`` field, which contains # the decoded tensor data. They also have the ``pts_seconds`` and # ``duration_seconds`` fields which are single ints for -# :class:`~torchcodec.decoders.Frame`, and 1-D :class:`torch.Tensor` for -# :class:`~torchcodec.decoders.FrameBatch` (one value per frame in the batch). +# :class:`~torchcodec.Frame`, and 1-D :class:`torch.Tensor` for +# :class:`~torchcodec.FrameBatch` (one value per frame in the batch). # %% # Using time-based indexing @@ -153,7 +153,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): # frames based on *when* they are displayed with # :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and # :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which -# also returns :class:`~torchcodec.decoders.Frame` and :class:`~torchcodec.decoders.FrameBatch` +# also return :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch` # respectively. frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2) diff --git a/src/torchcodec/__init__.py b/src/torchcodec/__init__.py index cfdf49898..a27a83e04 100644 --- a/src/torchcodec/__init__.py +++ b/src/torchcodec/__init__.py @@ -4,6 +4,9 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from . import decoders, samplers # noqa # noqa +# Note: usort wants to put Frame and FrameBatch after decoders and samplers, +# but that results in a circular import. +from ._frame import Frame, FrameBatch # usort:skip # noqa +from . 
import decoders, samplers # noqa __version__ = "0.0.4.dev" diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py new file mode 100644 index 000000000..c847f57b8 --- /dev/null +++ b/src/torchcodec/_frame.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import dataclasses +from dataclasses import dataclass +from typing import Iterable, Iterator, Union + +from torch import Tensor + + +def _frame_repr(self): + # Utility to replace Frame and FrameBatch __repr__ method. This prints the + # shape of the .data tensor rather than printing the (potentially very long) + # data tensor itself. + s = self.__class__.__name__ + ":\n" + spaces = " " + for field in dataclasses.fields(self): + field_name = field.name + field_val = getattr(self, field_name) + if field_name == "data": + field_name = "data (shape)" + field_val = field_val.shape + s += f"{spaces}{field_name}: {field_val}\n" + return s + + +@dataclass +class Frame(Iterable): + """A single video frame with associated metadata.""" + + data: Tensor + """The frame data as (3-D ``torch.Tensor``).""" + pts_seconds: float + """The :term:`pts` of the frame, in seconds (float).""" + duration_seconds: float + """The duration of the frame, in seconds (float).""" + + def __iter__(self) -> Iterator[Union[Tensor, float]]: + for field in dataclasses.fields(self): + yield getattr(self, field.name) + + def __repr__(self): + return _frame_repr(self) + + +@dataclass +class FrameBatch(Iterable): + """Multiple video frames with associated metadata.""" + + data: Tensor + """The frames data as (4-D ``torch.Tensor``).""" + pts_seconds: Tensor + """The :term:`pts` of the frames, in seconds (1-D ``torch.Tensor`` of floats).""" + duration_seconds: Tensor + """The duration of the frames, in seconds (1-D ``torch.Tensor`` of floats).""" + + def __iter__(self) 
-> Iterator[Union[Tensor, float]]: + for field in dataclasses.fields(self): + yield getattr(self, field.name) + + def __repr__(self): + return _frame_repr(self) diff --git a/src/torchcodec/decoders/__init__.py b/src/torchcodec/decoders/__init__.py index 7673ed45f..307f18f43 100644 --- a/src/torchcodec/decoders/__init__.py +++ b/src/torchcodec/decoders/__init__.py @@ -5,6 +5,6 @@ # LICENSE file in the root directory of this source tree. from ._core import VideoStreamMetadata -from ._video_decoder import Frame, FrameBatch, VideoDecoder # noqa +from ._video_decoder import VideoDecoder # noqa SimpleVideoDecoder = VideoDecoder diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 3e2f535fb..578c8dd61 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -4,71 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import dataclasses import numbers -from dataclasses import dataclass from pathlib import Path -from typing import Iterable, Iterator, Literal, Tuple, Union +from typing import Literal, Tuple, Union from torch import Tensor +from torchcodec import Frame, FrameBatch from torchcodec.decoders import _core as core - -def _frame_repr(self): - # Utility to replace Frame and FrameBatch __repr__ method. This prints the - # shape of the .data tensor rather than printing the (potentially very long) - # data tensor itself. 
- s = self.__class__.__name__ + ":\n" - spaces = " " - for field in dataclasses.fields(self): - field_name = field.name - field_val = getattr(self, field_name) - if field_name == "data": - field_name = "data (shape)" - field_val = field_val.shape - s += f"{spaces}{field_name}: {field_val}\n" - return s - - -@dataclass -class Frame(Iterable): - """A single video frame with associated metadata.""" - - data: Tensor - """The frame data as (3-D ``torch.Tensor``).""" - pts_seconds: float - """The :term:`pts` of the frame, in seconds (float).""" - duration_seconds: float - """The duration of the frame, in seconds (float).""" - - def __iter__(self) -> Iterator[Union[Tensor, float]]: - for field in dataclasses.fields(self): - yield getattr(self, field.name) - - def __repr__(self): - return _frame_repr(self) - - -@dataclass -class FrameBatch(Iterable): - """Multiple video frames with associated metadata.""" - - data: Tensor - """The frames data as (4-D ``torch.Tensor``).""" - pts_seconds: Tensor - """The :term:`pts` of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" - duration_seconds: Tensor - """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" - - def __iter__(self) -> Iterator[Union[Tensor, float]]: - for field in dataclasses.fields(self): - yield getattr(self, field.name) - - def __repr__(self): - return _frame_repr(self) - - _ERROR_REPORTING_INSTRUCTIONS = """ This should never happen. Please report an issue following the steps in https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml. 
diff --git a/src/torchcodec/samplers/_implem.py b/src/torchcodec/samplers/_implem.py index 2c5cccde2..884f91f4b 100644 --- a/src/torchcodec/samplers/_implem.py +++ b/src/torchcodec/samplers/_implem.py @@ -2,7 +2,8 @@ import torch -from torchcodec.decoders import Frame, FrameBatch, VideoDecoder +from torchcodec import Frame, FrameBatch +from torchcodec.decoders import VideoDecoder def _validate_params( diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py index 3496a002d..42303b33f 100644 --- a/test/samplers/test_samplers.py +++ b/test/samplers/test_samplers.py @@ -5,7 +5,8 @@ import pytest import torch -from torchcodec.decoders import FrameBatch, VideoDecoder +from torchcodec import FrameBatch +from torchcodec.decoders import VideoDecoder from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS