Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions docs/source/api_ref_decoders.rst
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,4 @@ torchcodec.decoders
:nosignatures:
:template: dataclass.rst

Frame
FrameBatch
VideoStreamMetadata
16 changes: 16 additions & 0 deletions docs/source/api_ref_torchcodec.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
.. _torchcodec:

===================
torchcodec
===================

.. currentmodule:: torchcodec


.. autosummary::
:toctree: generated/
:nosignatures:
:template: dataclass.rst

Frame
FrameBatch
3 changes: 2 additions & 1 deletion docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ We achieve these capabilities through:
.. grid-item-card:: :octicon:`file-code;1em`
API Reference
:img-top: _static/img/card-background.svg
:link: api_ref_decoders.html
:link: api_ref_torchcodec.html
:link-type: url

The API reference for TorchCodec
Expand Down Expand Up @@ -73,4 +73,5 @@ We achieve these capabilities through:
:caption: API Reference
:hidden:

api_ref_torchcodec
api_ref_decoders
14 changes: 7 additions & 7 deletions examples/basic_example.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# This can be achieved using the
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which
# will return a :class:`~torchcodec.decoders.Frame` and
# :class:`~torchcodec.decoders.FrameBatch` objects respectively.
# will return :class:`~torchcodec.Frame` and
# :class:`~torchcodec.FrameBatch` objects respectively.

last_frame = decoder.get_frame_at(len(decoder) - 1)
print(f"{type(last_frame) = }")
Expand All @@ -138,12 +138,12 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
plot(middle_frames.data, "Middle frames")

# %%
# Both :class:`~torchcodec.decoders.Frame` and
# :class:`~torchcodec.decoders.FrameBatch` have a ``data`` field, which contains
# Both :class:`~torchcodec.Frame` and
# :class:`~torchcodec.FrameBatch` have a ``data`` field, which contains
# the decoded tensor data. They also have the ``pts_seconds`` and
# ``duration_seconds`` fields which are single floats for
# :class:`~torchcodec.decoders.Frame`, and 1-D :class:`torch.Tensor` for
# :class:`~torchcodec.decoders.FrameBatch` (one value per frame in the batch).
# :class:`~torchcodec.Frame`, and 1-D :class:`torch.Tensor` for
# :class:`~torchcodec.FrameBatch` (one value per frame in the batch).

# %%
# Using time-based indexing
Expand All @@ -153,7 +153,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None):
# frames based on *when* they are displayed with
# :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and
# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which
# also returns :class:`~torchcodec.decoders.Frame` and :class:`~torchcodec.decoders.FrameBatch`
# also return :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch`
# respectively.

frame_at_2_seconds = decoder.get_frame_displayed_at(seconds=2)
Expand Down
5 changes: 4 additions & 1 deletion src/torchcodec/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from . import decoders, samplers # noqa # noqa
# Note: usort wants to put Frame and FrameBatch after decoders and samplers,
# but that results in a circular import.
from ._frame import Frame, FrameBatch # usort:skip # noqa
from . import decoders, samplers # noqa

__version__ = "0.0.4.dev"
65 changes: 65 additions & 0 deletions src/torchcodec/_frame.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
from dataclasses import dataclass
from typing import Iterable, Iterator, Union

from torch import Tensor


def _frame_repr(self) -> str:
    """Build a compact, multi-line ``repr`` for ``Frame`` and ``FrameBatch``.

    Meant to replace the dataclass-generated ``__repr__``. Each dataclass
    field is listed on its own line as ``name: value``. The ``data`` field is
    reported by its ``.shape`` rather than its contents, because printing the
    (potentially very long) data tensor itself would swamp the output.

    Args:
        self: A dataclass instance. If it has a field named ``data``, that
            value must expose a ``.shape`` attribute.

    Returns:
        The class name followed by ``":"``, then one indented
        ``name: value`` line per field. Every line, including the last one,
        ends with a newline.
    """
    spaces = " "
    parts = [f"{self.__class__.__name__}:\n"]
    for field in dataclasses.fields(self):
        field_name = field.name
        field_val = getattr(self, field_name)
        if field_name == "data":
            # Show only the shape; the label makes that explicit to the reader.
            field_name = "data (shape)"
            field_val = field_val.shape
        parts.append(f"{spaces}{field_name}: {field_val}\n")
    # Join once at the end instead of growing a string with ``+=`` in the loop.
    return "".join(parts)


@dataclass
class Frame(Iterable):
    """One decoded video frame bundled with its timing metadata.

    Iterating over an instance yields its field values in declaration order,
    so it can be unpacked like a tuple:
    ``data, pts_seconds, duration_seconds = frame``.
    """

    data: Tensor
    """The frame data (3-D ``torch.Tensor``)."""
    pts_seconds: float
    """The :term:`pts` of the frame, in seconds (float)."""
    duration_seconds: float
    """The duration of the frame, in seconds (float)."""

    def __iter__(self) -> Iterator[Union[Tensor, float]]:
        # Walk the dataclass fields in declaration order and hand out values.
        yield from (getattr(self, spec.name) for spec in dataclasses.fields(self))

    def __repr__(self):
        # Delegate to the module-level helper instead of the dataclass default.
        return _frame_repr(self)


@dataclass
class FrameBatch(Iterable):
    """Multiple video frames with associated metadata.

    Iterating over an instance yields its three fields in declaration order
    (``data``, ``pts_seconds``, ``duration_seconds``), which allows
    tuple-style unpacking. Iteration goes over the fields, not over the
    individual frames of the batch.
    """

    data: Tensor
    """The frames data (4-D ``torch.Tensor``)."""
    pts_seconds: Tensor
    """The :term:`pts` of the frames, in seconds (1-D ``torch.Tensor`` of floats)."""
    duration_seconds: Tensor
    """The duration of the frames, in seconds (1-D ``torch.Tensor`` of floats)."""

    def __iter__(self) -> Iterator[Tensor]:
        # Every field of a FrameBatch is declared as a Tensor, so the return
        # annotation is Iterator[Tensor] rather than a Tensor/float union.
        for field in dataclasses.fields(self):
            yield getattr(self, field.name)

    def __repr__(self):
        # Delegate to the module-level helper instead of the dataclass default.
        return _frame_repr(self)
2 changes: 1 addition & 1 deletion src/torchcodec/decoders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,6 @@
# LICENSE file in the root directory of this source tree.

from ._core import VideoStreamMetadata
from ._video_decoder import Frame, FrameBatch, VideoDecoder # noqa
from ._video_decoder import VideoDecoder # noqa

SimpleVideoDecoder = VideoDecoder
60 changes: 2 additions & 58 deletions src/torchcodec/decoders/_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,15 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import dataclasses
import numbers
from dataclasses import dataclass
from pathlib import Path
from typing import Iterable, Iterator, Literal, Tuple, Union
from typing import Literal, Tuple, Union

from torch import Tensor

from torchcodec import Frame, FrameBatch
from torchcodec.decoders import _core as core


def _frame_repr(self) -> str:
    """Build a compact, multi-line ``repr`` for ``Frame`` and ``FrameBatch``.

    Meant to replace the dataclass-generated ``__repr__``. Each dataclass
    field is listed on its own line as ``name: value``. The ``data`` field is
    reported by its ``.shape`` rather than its contents, because printing the
    (potentially very long) data tensor itself would swamp the output.

    Args:
        self: A dataclass instance. If it has a field named ``data``, that
            value must expose a ``.shape`` attribute.

    Returns:
        The class name followed by ``":"``, then one indented
        ``name: value`` line per field. Every line, including the last one,
        ends with a newline.
    """
    spaces = " "
    parts = [f"{self.__class__.__name__}:\n"]
    for field in dataclasses.fields(self):
        field_name = field.name
        field_val = getattr(self, field_name)
        if field_name == "data":
            # Show only the shape; the label makes that explicit to the reader.
            field_name = "data (shape)"
            field_val = field_val.shape
        parts.append(f"{spaces}{field_name}: {field_val}\n")
    # Join once at the end instead of growing a string with ``+=`` in the loop.
    return "".join(parts)


@dataclass
class Frame(Iterable):
    """One decoded video frame bundled with its timing metadata.

    Iterating over an instance yields its field values in declaration order,
    so it can be unpacked like a tuple:
    ``data, pts_seconds, duration_seconds = frame``.
    """

    data: Tensor
    """The frame data (3-D ``torch.Tensor``)."""
    pts_seconds: float
    """The :term:`pts` of the frame, in seconds (float)."""
    duration_seconds: float
    """The duration of the frame, in seconds (float)."""

    def __iter__(self) -> Iterator[Union[Tensor, float]]:
        # Walk the dataclass fields in declaration order and hand out values.
        yield from (getattr(self, spec.name) for spec in dataclasses.fields(self))

    def __repr__(self):
        # Delegate to the module-level helper instead of the dataclass default.
        return _frame_repr(self)


@dataclass
class FrameBatch(Iterable):
    """Multiple video frames with associated metadata.

    Iterating over an instance yields its three fields in declaration order
    (``data``, ``pts_seconds``, ``duration_seconds``), which allows
    tuple-style unpacking. Iteration goes over the fields, not over the
    individual frames of the batch.
    """

    data: Tensor
    """The frames data (4-D ``torch.Tensor``)."""
    pts_seconds: Tensor
    """The :term:`pts` of the frames, in seconds (1-D ``torch.Tensor`` of floats)."""
    duration_seconds: Tensor
    """The duration of the frames, in seconds (1-D ``torch.Tensor`` of floats)."""

    def __iter__(self) -> Iterator[Tensor]:
        # Every field of a FrameBatch is declared as a Tensor, so the return
        # annotation is Iterator[Tensor] rather than a Tensor/float union.
        for field in dataclasses.fields(self):
            yield getattr(self, field.name)

    def __repr__(self):
        # Delegate to the module-level helper instead of the dataclass default.
        return _frame_repr(self)


_ERROR_REPORTING_INSTRUCTIONS = """
This should never happen. Please report an issue following the steps in
https://github.com/pytorch/torchcodec/issues/new?assignees=&labels=&projects=&template=bug-report.yml.
Expand Down
3 changes: 2 additions & 1 deletion src/torchcodec/samplers/_implem.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@

import torch

from torchcodec.decoders import Frame, FrameBatch, VideoDecoder
from torchcodec import Frame, FrameBatch
from torchcodec.decoders import VideoDecoder


def _validate_params(
Expand Down
3 changes: 2 additions & 1 deletion test/samplers/test_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
import pytest
import torch

from torchcodec.decoders import FrameBatch, VideoDecoder
from torchcodec import FrameBatch
from torchcodec.decoders import VideoDecoder
from torchcodec.samplers import clips_at_random_indices, clips_at_regular_indices
from torchcodec.samplers._implem import _build_all_clips_indices, _POLICY_FUNCTIONS

Expand Down
Loading