From 14825293eb134d39d07b739d9896e3ba7b774ce9 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 14:18:44 +0100 Subject: [PATCH 1/3] Frame and FrameBatch improvements --- src/torchcodec/_frame.py | 47 ++++++++++++- test/test_frame_dataclasses.py | 121 +++++++++++++++++++++++++++++++++ 2 files changed, 165 insertions(+), 3 deletions(-) create mode 100644 test/test_frame_dataclasses.py diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index c847f57b8..e6013ba82 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -38,6 +38,13 @@ class Frame(Iterable): duration_seconds: float """The duration of the frame, in seconds (float).""" + def __post_init__(self): + if not self.data.ndim == 3: + raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") + + self.pts_seconds = float(self.pts_seconds) + self.duration_seconds = float(self.duration_seconds) + def __iter__(self) -> Iterator[Union[Tensor, float]]: for field in dataclasses.fields(self): yield getattr(self, field.name) @@ -57,9 +64,43 @@ class FrameBatch(Iterable): duration_seconds: Tensor """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" - def __iter__(self) -> Iterator[Union[Tensor, float]]: - for field in dataclasses.fields(self): - yield getattr(self, field.name) + def __post_init__(self): + if self.data.ndim < 4: + raise ValueError( + f"data must be at least 4-dimensional. Got {self.data.shape = } " + "For 3-dimensional data, create a Frame object instead." + ) + + leading_dims = self.data.shape[:-3] + if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape): + raise ValueError( + "Tried to create a FrameBatch but the leading dimensions of the inputs do not match. " + f"Got {self.data.shape = } so we expected the shape of pts_seconds and " + f"duration_seconds to be {leading_dims = }, but got " + f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }." + ) + + def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]: + cls = Frame if self.data.ndim == 4 else FrameBatch + for data, pts_seconds, duration_seconds in zip( + self.data, self.pts_seconds, self.duration_seconds + ): + yield cls( + data=data, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + ) + + def __getitem__(self, key) -> Union["FrameBatch", Frame]: + cls = Frame if self.data.ndim == 4 else FrameBatch + return cls( + self.data[key], + self.pts_seconds[key], + self.duration_seconds[key], + ) + + def __len__(self): + return len(self.data) def __repr__(self): return _frame_repr(self) diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py new file mode 100644 index 000000000..9b79b882f --- /dev/null +++ b/test/test_frame_dataclasses.py @@ -0,0 +1,121 @@ +import pytest +import torch +from torchcodec import Frame, FrameBatch + + +def test_frame_unpacking(): + data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3) # noqa + + +def test_frame_error(): + with pytest.raises(ValueError, match="data must be 3-dimensional"): + Frame( + data=torch.rand(1, 2), + pts_seconds=1, + duration_seconds=1, + ) + with pytest.raises(ValueError, match="data must be 3-dimensional"): + Frame( + data=torch.rand(1, 2, 3, 4), + pts_seconds=1, + duration_seconds=1, + ) + + +def test_framebatch_error(): + with pytest.raises(ValueError, match="data must be at least 4-dimensional"): + FrameBatch( + data=torch.rand(1, 2, 3), + pts_seconds=torch.rand(1), + duration_seconds=torch.rand(1), + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(3, 4, 2, 1), + pts_seconds=torch.rand(3), # ok + duration_seconds=torch.rand(2), # bad + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(3, 4, 2, 1), + pts_seconds=torch.rand(2), # bad + duration_seconds=torch.rand(3), # ok + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(5, 3, 4, 2, 1), + pts_seconds=torch.rand(5, 3), # ok + duration_seconds=torch.rand(5, 2), # bad + ) + + with pytest.raises( + ValueError, match="leading dimensions of the inputs do not match" + ): + FrameBatch( + data=torch.rand(5, 3, 4, 2, 1), + pts_seconds=torch.rand(5, 2), # bad + duration_seconds=torch.rand(5, 3), # ok + ) + + +def test_framebatch_iteration(): + T, N, C, H, W = 7, 6, 3, 2, 4 + + fb = FrameBatch( + data=torch.rand(T, N, C, H, W), + pts_seconds=torch.rand(T, N), + duration_seconds=torch.rand(T, N), + ) + + for sub_fb in fb: + assert isinstance(sub_fb, FrameBatch) + assert sub_fb.data.shape == (N, C, H, W) + assert sub_fb.pts_seconds.shape == (N,) + assert sub_fb.duration_seconds.shape == (N,) + for frame in sub_fb: + assert isinstance(frame, Frame) + assert frame.data.shape == (C, H, W) + assert isinstance(frame.pts_seconds, float) + assert isinstance(frame.duration_seconds, float) + + # Check unpacking behavior + first_sub_fb, *_ = fb + assert isinstance(first_sub_fb, FrameBatch) + + +def test_framebatch_indexing(): + T, N, C, H, W = 7, 6, 3, 2, 4 + + fb = FrameBatch( + data=torch.rand(T, N, C, H, W), + pts_seconds=torch.rand(T, N), + duration_seconds=torch.rand(T, N), + ) + + for i in range(len(fb)): + assert isinstance(fb[i], FrameBatch) + assert fb[i].data.shape == (N, C, H, W) + assert fb[i].pts_seconds.shape == (N,) + assert fb[i].duration_seconds.shape == (N,) + for j in range(len(fb[i])): + assert isinstance(fb[i][j], Frame) + assert fb[i][j].data.shape == (C, H, W) + assert isinstance(fb[i][j].pts_seconds, float) + assert isinstance(fb[i][j].duration_seconds, float) + + fb_fancy = fb[torch.arange(3)] + assert isinstance(fb_fancy, FrameBatch) + assert fb_fancy.data.shape == (3, N, C, H, W) + + fb_fancy = fb[[[0], [1]]] # select T=0 and N=1. + assert isinstance(fb_fancy, FrameBatch) + assert fb_fancy.data.shape == (1, C, H, W) From 46612374677ff4dbbea5fdf2b175bc1090c4d1fa Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 14:28:56 +0100 Subject: [PATCH 2/3] Fix mypy? --- src/torchcodec/_frame.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index e6013ba82..b9542df4a 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -41,7 +41,6 @@ class Frame(Iterable): def __post_init__(self): if not self.data.ndim == 3: raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") - self.pts_seconds = float(self.pts_seconds) self.duration_seconds = float(self.duration_seconds) @@ -92,12 +91,21 @@ def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]: ) def __getitem__(self, key) -> Union["FrameBatch", Frame]: - cls = Frame if self.data.ndim == 4 else FrameBatch - return cls( - self.data[key], - self.pts_seconds[key], - self.duration_seconds[key], - ) + data = self.data[key] + pts_seconds = self.pts_seconds[key] + duration_seconds = self.duration_seconds[key] + if self.data.ndim == 4: + return Frame( + data=data, + pts_seconds=float(pts_seconds.item()), + duration_seconds=float(duration_seconds.item()), + ) + else: + return FrameBatch( + data=data, + pts_seconds=pts_seconds, + duration_seconds=duration_seconds, + ) def __len__(self): return len(self.data) From c6b594c371686c624bdce22c442beda902f1e909 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Wed, 23 Oct 2024 17:38:58 +0100 Subject: [PATCH 3/3] Added comment --- src/torchcodec/_frame.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py index b9542df4a..8137c4579 100644 --- a/src/torchcodec/_frame.py +++ b/src/torchcodec/_frame.py @@ -39,6 +39,8 @@ class Frame(Iterable): """The duration of the frame, in seconds (float).""" def __post_init__(self): + # This is called after __init__() when a Frame is created. We can run + # input validation checks here. if not self.data.ndim == 3: raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }") self.pts_seconds = float(self.pts_seconds) @@ -64,6 +66,8 @@ class FrameBatch(Iterable): """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats).""" def __post_init__(self): + # This is called after __init__() when a FrameBatch is created. We can + # run input validation checks here. if self.data.ndim < 4: raise ValueError( f"data must be at least 4-dimensional. Got {self.data.shape = } "