diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index 66f5dd9a9..40216304a 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -185,7 +185,6 @@ def get_frame_at_pts_abstract( def get_frames_by_pts_abstract( decoder: torch.Tensor, *, - stream_index: int, timestamps: List[float], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] @@ -198,7 +197,7 @@ def get_frames_by_pts_abstract( @register_fake("torchcodec_ns::get_frame_at_index") def get_frame_at_index_abstract( - decoder: torch.Tensor, *, stream_index: int, frame_index: int + decoder: torch.Tensor, *, frame_index: int ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(3)] return ( @@ -212,7 +211,6 @@ def get_frame_at_index_abstract( def get_frames_at_indices_abstract( decoder: torch.Tensor, *, - stream_index: int, frame_indices: List[int], ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: image_size = [get_ctx().new_dynamic_size() for _ in range(4)] @@ -227,7 +225,6 @@ def get_frames_at_indices_abstract( def get_frames_in_range_abstract( decoder: torch.Tensor, *, - stream_index: int, start: int, stop: int, step: Optional[int] = None, @@ -244,7 +241,6 @@ def get_frames_in_range_abstract( def get_frames_by_pts_in_range_abstract( decoder: torch.Tensor, *, - stream_index: int, start_seconds: float, stop_seconds: float, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: @@ -257,9 +253,7 @@ def get_frames_by_pts_in_range_abstract( @register_fake("torchcodec_ns::_get_key_frame_indices") -def get_key_frame_indices_abstract( - decoder: torch.Tensor, *, stream_index: int -) -> torch.Tensor: +def get_key_frame_indices_abstract(decoder: torch.Tensor) -> torch.Tensor: return torch.empty([], dtype=torch.int) @@ -282,7 +276,6 @@ def get_stream_json_metadata_abstract(decoder: torch.Tensor, stream_idx: int) -> def _test_frame_pts_equality_abstract( decoder: torch.Tensor, *, - stream_index: int, frame_index: int, pts_seconds_to_test: float, ) -> bool: diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 41cb63fc4..047f00d91 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import contextlib + import numpy import pytest import torch @@ -874,6 +876,38 @@ def test_get_key_frame_indices(self, device): key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 ) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_compile(self, device): + decoder = VideoDecoder(NASA_VIDEO.path, device=device) -if __name__ == "__main__": - pytest.main() + @contextlib.contextmanager + def restore_capture_scalar_outputs(): + try: + original = torch._dynamo.config.capture_scalar_outputs + yield + finally: + torch._dynamo.config.capture_scalar_outputs = original + + # TODO: We get a graph break because we call Tensor.item() to turn the + # tensors in FrameBatch into scalars. When we work on compilation and exportability, + # we should investigate. + with restore_capture_scalar_outputs(): + torch._dynamo.config.capture_scalar_outputs = True + + @torch.compile(fullgraph=True, backend="eager") + def get_some_frames(decoder): + frames = [] + frames.append(decoder.get_frame_at(1)) + frames.append(decoder.get_frame_at(3)) + frames.append(decoder.get_frame_at(5)) + return frames + + frames = get_some_frames(decoder) + + ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) + ref_frame3 = NASA_VIDEO.get_frame_data_by_index(3).to(device) + ref_frame5 = NASA_VIDEO.get_frame_data_by_index(5).to(device) + + assert_frames_equal(ref_frame1, frames[0].data) + assert_frames_equal(ref_frame3, frames[1].data) + assert_frames_equal(ref_frame5, frames[2].data) diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index eb8a0926b..8e91efb71 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -9,7 +9,6 @@ os.environ["TORCH_LOGS"] = "output_code" import json import subprocess -from typing import Tuple import numpy as np import pytest @@ -48,20 +47,6 @@ INDEX_OF_FRAME_AT_6_SECONDS = 180 -class ReferenceDecoder: - def __init__(self, device="cpu"): - self.decoder: torch.Tensor = create_from_file(str(NASA_VIDEO.path)) - add_video_stream(self.decoder, device=device) - - def get_next_frame(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - assert self.decoder is not None - return get_next_frame(self.decoder) - - def seek(self, pts: float): - assert self.decoder is not None - seek_to_pts(self.decoder, pts) - - class TestOps: @pytest.mark.parametrize("device", cpu_and_cuda()) def test_seek_and_next(self, device): @@ -352,27 +337,6 @@ def get_frame1_and_frame_time6(decoder): assert_frames_equal(frame0, reference_frame0.to(device)) assert_frames_equal(frame_time6, reference_frame_time6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_class_based_compile_seek_and_next(self, device): - # TODO_OPEN_ISSUE Scott (T180277797): Ditto as above. - @torch.compile(fullgraph=True, backend="eager") - def class_based_get_frame1_and_frame_time6( - decoder: ReferenceDecoder, - ) -> Tuple[torch.Tensor, torch.Tensor]: - frame0, _, _ = decoder.get_next_frame() - decoder.seek(6.0) - frame_time6, _, _ = decoder.get_next_frame() - return frame0, frame_time6 - - decoder = ReferenceDecoder(device=device) - frame0, frame_time6 = class_based_get_frame1_and_frame_time6(decoder) - reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0) - reference_frame_time6 = NASA_VIDEO.get_frame_data_by_index( - INDEX_OF_FRAME_AT_6_SECONDS - ) - assert_frames_equal(frame0, reference_frame0.to(device)) - assert_frames_equal(frame_time6, reference_frame_time6.to(device)) - @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("create_from", ("file", "tensor", "bytes")) def test_create_decoder(self, create_from, device):