From 267859cb4f29989e8ab104a2f60713cbe901e7ef Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 28 Feb 2025 13:59:48 -0800 Subject: [PATCH] Remove compile test --- test/decoders/test_video_decoder.py | 38 ----------------------------- 1 file changed, 38 deletions(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 047f00d91..de05fd08d 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -4,8 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import contextlib - import numpy import pytest import torch @@ -875,39 +873,3 @@ def test_get_key_frame_indices(self, device): torch.testing.assert_close( key_frame_indices, h265_reference_key_frame_indices, atol=0, rtol=0 ) - - @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_compile(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) - - @contextlib.contextmanager - def restore_capture_scalar_outputs(): - try: - original = torch._dynamo.config.capture_scalar_outputs - yield - finally: - torch._dynamo.config.capture_scalar_outputs = original - - # TODO: We get a graph break because we call Tensor.item() to turn the - # tensors in FrameBatch into scalars. When we work on compilation and exportability, - # we should investigate. - with restore_capture_scalar_outputs(): - torch._dynamo.config.capture_scalar_outputs = True - - @torch.compile(fullgraph=True, backend="eager") - def get_some_frames(decoder): - frames = [] - frames.append(decoder.get_frame_at(1)) - frames.append(decoder.get_frame_at(3)) - frames.append(decoder.get_frame_at(5)) - return frames - - frames = get_some_frames(decoder) - - ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) - ref_frame3 = NASA_VIDEO.get_frame_data_by_index(3).to(device) - ref_frame5 = NASA_VIDEO.get_frame_data_by_index(5).to(device) - - assert_frames_equal(ref_frame1, frames[0].data) - assert_frames_equal(ref_frame3, frames[1].data) - assert_frames_equal(ref_frame5, frames[2].data)