From e29e04603145e7af986b15abf91fdcda486e36ef Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 28 Oct 2024 11:05:48 +0000 Subject: [PATCH 1/2] get_frames_at -> get_frames_in_range --- README.md | 2 +- examples/basic_example.py | 4 ++-- src/torchcodec/decoders/_video_decoder.py | 2 +- test/decoders/test_video_decoder.py | 14 +++++++------- 4 files changed, 11 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 7aadbc652..90d7b49c5 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,7 @@ decoder.get_frame_at(len(decoder) - 1) # pts_seconds: 9.960000038146973 # duration_seconds: 0.03999999910593033 -decoder.get_frames_at(start=10, stop=30, step=5) +decoder.get_frames_in_range(start=10, stop=30, step=5) # FrameBatch: # data (shape): torch.Size([4, 3, 400, 640]) # pts_seconds: tensor([0.4000, 0.6000, 0.8000, 1.0000]) diff --git a/examples/basic_example.py b/examples/basic_example.py index 4df03b8a8..db00cd040 100644 --- a/examples/basic_example.py +++ b/examples/basic_example.py @@ -120,7 +120,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): # their :term:`pts` (Presentation Time Stamp), and their duration. # This can be achieved using the # :meth:`~torchcodec.decoders.VideoDecoder.get_frame_at` and -# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_at` methods, which +# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_in_range` methods, which # will return a :class:`~torchcodec.Frame` and # :class:`~torchcodec.FrameBatch` objects respectively. @@ -129,7 +129,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): print(last_frame) # %% -middle_frames = decoder.get_frames_at(start=10, stop=20, step=2) +middle_frames = decoder.get_frames_in_range(start=10, stop=20, step=2) print(f"{type(middle_frames) = }") print(middle_frames) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 4a03fcbc2..b1d4e7bf3 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -181,7 +181,7 @@ def get_frame_at(self, index: int) -> Frame: duration_seconds=duration_seconds.item(), ) - def get_frames_at(self, start: int, stop: int, step: int = 1) -> FrameBatch: + def get_frames_in_range(self, start: int, stop: int, step: int = 1) -> FrameBatch: """Return multiple frames at the given index range. Frames are in [start, stop). diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 55a8256d8..0ab50af7a 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -366,14 +366,14 @@ def test_get_frame_displayed_at_fails(self): frame = decoder.get_frame_displayed_at(100.0) # noqa @pytest.mark.parametrize("stream_index", [0, 3, None]) - def test_get_frames_at(self, stream_index): + def test_get_frames_in_range(self, stream_index): decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index) # test degenerate case where we only actually get 1 frame ref_frames9 = NASA_VIDEO.get_frame_data_by_range( start=9, stop=10, stream_index=stream_index ) - frames9 = decoder.get_frames_at(start=9, stop=10) + frames9 = decoder.get_frames_in_range(start=9, stop=10) assert_tensor_equal(ref_frames9, frames9.data) assert frames9.pts_seconds[0].item() == pytest.approx( @@ -389,7 +389,7 @@ def test_get_frames_at(self, stream_index): ref_frames0_9 = NASA_VIDEO.get_frame_data_by_range( start=0, stop=10, stream_index=stream_index ) - frames0_9 = decoder.get_frames_at(start=0, stop=10) + frames0_9 = decoder.get_frames_in_range(start=0, stop=10) assert frames0_9.data.shape == torch.Size( [ 10, @@ -412,7 +412,7 @@ def test_get_frames_at(self, stream_index): ref_frames0_8_2 = NASA_VIDEO.get_frame_data_by_range( start=0, stop=10, step=2, stream_index=stream_index ) - frames0_8_2 = decoder.get_frames_at(start=0, stop=10, step=2) + frames0_8_2 = decoder.get_frames_in_range(start=0, stop=10, step=2) assert frames0_8_2.data.shape == torch.Size( [ 5, @@ -434,13 +434,13 @@ def test_get_frames_at(self, stream_index): ) # test numpy.int64 for indices - frames0_8_2 = decoder.get_frames_at( + frames0_8_2 = decoder.get_frames_in_range( start=numpy.int64(0), stop=numpy.int64(10), step=numpy.int64(2) ) assert_tensor_equal(ref_frames0_8_2, frames0_8_2.data) # an empty range is valid! - empty_frames = decoder.get_frames_at(5, 5) + empty_frames = decoder.get_frames_in_range(5, 5) assert_tensor_equal( empty_frames.data, NASA_VIDEO.get_empty_chw_tensor(stream_index=stream_index), @@ -456,7 +456,7 @@ def test_get_frames_at(self, stream_index): ( lambda decoder: decoder[0], lambda decoder: decoder.get_frame_at(0).data, - lambda decoder: decoder.get_frames_at(0, 4).data, + lambda decoder: decoder.get_frames_in_range(0, 4).data, lambda decoder: decoder.get_frame_displayed_at(0).data, # TODO: uncomment once D60001893 lands # lambda decoder: decoder.get_frames_displayed_at(0, 1).data, From a4e9c5fd356aea49b8578d20fb24816fd4699b58 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 28 Oct 2024 11:07:44 +0000 Subject: [PATCH 2/2] get_frames_displayed_at -> get_frames_displayed_in_range --- examples/basic_example.py | 4 ++-- src/torchcodec/decoders/_video_decoder.py | 2 +- test/decoders/test_video_decoder.py | 28 +++++++++++------------ 3 files changed, 17 insertions(+), 17 deletions(-) diff --git a/examples/basic_example.py b/examples/basic_example.py index db00cd040..b8beac9e4 100644 --- a/examples/basic_example.py +++ b/examples/basic_example.py @@ -152,7 +152,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): # So far, we have retrieved frames based on their index. We can also retrieve # frames based on *when* they are displayed with # :meth:`~torchcodec.decoders.VideoDecoder.get_frame_displayed_at` and -# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_at`, which +# :meth:`~torchcodec.decoders.VideoDecoder.get_frames_displayed_in_range`, which # also returns :class:`~torchcodec.Frame` and :class:`~torchcodec.FrameBatch` # respectively. @@ -161,7 +161,7 @@ def plot(frames: torch.Tensor, title : Optional[str] = None): print(frame_at_2_seconds) # %% -first_two_seconds = decoder.get_frames_displayed_at( +first_two_seconds = decoder.get_frames_displayed_in_range( start_seconds=0, stop_seconds=2, ) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index b1d4e7bf3..58a8a9916 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -238,7 +238,7 @@ def get_frame_displayed_at(self, seconds: float) -> Frame: duration_seconds=duration_seconds.item(), ) - def get_frames_displayed_at( + def get_frames_displayed_in_range( self, start_seconds: float, stop_seconds: float ) -> FrameBatch: """Returns multiple frames in the given range. diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 0ab50af7a..f57314657 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -459,7 +459,7 @@ def test_get_frames_in_range(self, stream_index): lambda decoder: decoder.get_frames_in_range(0, 4).data, lambda decoder: decoder.get_frame_displayed_at(0).data, # TODO: uncomment once D60001893 lands - # lambda decoder: decoder.get_frames_displayed_at(0, 1).data, + # lambda decoder: decoder.get_frames_displayed_in_range(0, 1).data, ), ) def test_dimension_order(self, dimension_order, frame_getter): @@ -487,7 +487,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): decoder = VideoDecoder(NASA_VIDEO.path, stream_index=stream_index) # Note that we are comparing the results of VideoDecoder's method: - # get_frames_displayed_at() + # get_frames_displayed_in_range() # With the testing framework's method: # get_frame_data_by_range() # That is, we are testing the correctness of a pts-based range against an index- @@ -504,7 +504,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): # value for frame 5 that we have access to on the Python side is slightly less than the pts # value on the C++ side. This test still produces the correct result because a slightly # less value still falls into the correct window. - frames0_4 = decoder.get_frames_displayed_at( + frames0_4 = decoder.get_frames_displayed_in_range( decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(5).pts_seconds ) assert_tensor_equal( @@ -513,7 +513,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): ) # Range where the stop seconds is about halfway between pts values for two frames. - also_frames0_4 = decoder.get_frames_displayed_at( + also_frames0_4 = decoder.get_frames_displayed_in_range( decoder.get_frame_at(0).pts_seconds, decoder.get_frame_at(4).pts_seconds + HALF_DURATION, ) @@ -521,7 +521,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): # Again, the intention here is to provide the exact values we care about. In practice, our # pts values are slightly smaller, so we nudge the start upwards. - frames5_9 = decoder.get_frames_displayed_at( + frames5_9 = decoder.get_frames_displayed_in_range( decoder.get_frame_at(5).pts_seconds, decoder.get_frame_at(10).pts_seconds, ) @@ -533,7 +533,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): # Range where we provide start_seconds and stop_seconds that are different, but # also should land in the same window of time between two frame's pts values. As # a result, we should only get back one frame. - frame6 = decoder.get_frames_displayed_at( + frame6 = decoder.get_frames_displayed_in_range( decoder.get_frame_at(6).pts_seconds, decoder.get_frame_at(6).pts_seconds + HALF_DURATION, ) @@ -543,7 +543,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): ) # Very small range that falls in the same frame. - frame35 = decoder.get_frames_displayed_at( + frame35 = decoder.get_frames_displayed_in_range( decoder.get_frame_at(35).pts_seconds, decoder.get_frame_at(35).pts_seconds + 1e-10, ) @@ -555,7 +555,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): # Single frame where the start seconds is before frame i's pts, and the stop is # after frame i's pts, but before frame i+1's pts. In that scenario, we expect # to see frames i-1 and i. - frames7_8 = decoder.get_frames_displayed_at( + frames7_8 = decoder.get_frames_displayed_in_range( NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds - HALF_DURATION, NASA_VIDEO.get_frame_info(8, stream_index=stream_index).pts_seconds @@ -567,7 +567,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): ) # Start and stop seconds are the same value, which should not return a frame. - empty_frame = decoder.get_frames_displayed_at( + empty_frame = decoder.get_frames_displayed_in_range( NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds, NASA_VIDEO.get_frame_info(4, stream_index=stream_index).pts_seconds, ) @@ -583,7 +583,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): ) # Start and stop seconds land within the first frame. - frame0 = decoder.get_frames_displayed_at( + frame0 = decoder.get_frames_displayed_in_range( NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds, NASA_VIDEO.get_frame_info(0, stream_index=stream_index).pts_seconds + HALF_DURATION, @@ -595,7 +595,7 @@ def test_get_frames_by_pts_in_range(self, stream_index): # We should be able to get all frames by giving the beginning and ending time # for the stream. - all_frames = decoder.get_frames_displayed_at( + all_frames = decoder.get_frames_displayed_in_range( decoder.metadata.begin_stream_seconds, decoder.metadata.end_stream_seconds ) assert_tensor_equal(all_frames.data, decoder[:]) @@ -604,13 +604,13 @@ def test_get_frames_by_pts_in_range_fails(self): decoder = VideoDecoder(NASA_VIDEO.path) with pytest.raises(ValueError, match="Invalid start seconds"): - frame = decoder.get_frames_displayed_at(100.0, 1.0) # noqa + frame = decoder.get_frames_displayed_in_range(100.0, 1.0) # noqa with pytest.raises(ValueError, match="Invalid start seconds"): - frame = decoder.get_frames_displayed_at(20, 23) # noqa + frame = decoder.get_frames_displayed_in_range(20, 23) # noqa with pytest.raises(ValueError, match="Invalid stop seconds"): - frame = decoder.get_frames_displayed_at(0, 23) # noqa + frame = decoder.get_frames_displayed_in_range(0, 23) # noqa if __name__ == "__main__":