diff --git a/src/torchcodec/_frame.py b/src/torchcodec/_frame.py
index c847f57b8..8137c4579 100644
--- a/src/torchcodec/_frame.py
+++ b/src/torchcodec/_frame.py
@@ -38,6 +38,14 @@ class Frame(Iterable):
     duration_seconds: float
     """The duration of the frame, in seconds (float)."""
 
+    def __post_init__(self):
+        # This is called after __init__() when a Frame is created. We can run
+        # input validation checks here.
+        if not self.data.ndim == 3:
+            raise ValueError(f"data must be 3-dimensional, got {self.data.shape = }")
+        self.pts_seconds = float(self.pts_seconds)
+        self.duration_seconds = float(self.duration_seconds)
+
     def __iter__(self) -> Iterator[Union[Tensor, float]]:
         for field in dataclasses.fields(self):
             yield getattr(self, field.name)
@@ -57,9 +65,54 @@ class FrameBatch(Iterable):
     duration_seconds: Tensor
     """The duration of the frame, in seconds (1-D ``torch.Tensor`` of floats)."""
 
-    def __iter__(self) -> Iterator[Union[Tensor, float]]:
-        for field in dataclasses.fields(self):
-            yield getattr(self, field.name)
+    def __post_init__(self):
+        # This is called after __init__() when a FrameBatch is created. We can
+        # run input validation checks here.
+        if self.data.ndim < 4:
+            raise ValueError(
+                f"data must be at least 4-dimensional. Got {self.data.shape = } "
+                "For 3-dimensional data, create a Frame object instead."
+            )
+
+        leading_dims = self.data.shape[:-3]
+        if not (leading_dims == self.pts_seconds.shape == self.duration_seconds.shape):
+            raise ValueError(
+                "Tried to create a FrameBatch but the leading dimensions of the inputs do not match. "
+                f"Got {self.data.shape = } so we expected the shape of pts_seconds and "
+                f"duration_seconds to be {leading_dims = }, but got "
+                f"{self.pts_seconds.shape = } and {self.duration_seconds.shape = }."
+            )
+
+    def __iter__(self) -> Union[Iterator["FrameBatch"], Iterator[Frame]]:
+        cls = Frame if self.data.ndim == 4 else FrameBatch
+        for data, pts_seconds, duration_seconds in zip(
+            self.data, self.pts_seconds, self.duration_seconds
+        ):
+            yield cls(
+                data=data,
+                pts_seconds=pts_seconds,
+                duration_seconds=duration_seconds,
+            )
+
+    def __getitem__(self, key) -> Union["FrameBatch", Frame]:
+        data = self.data[key]
+        pts_seconds = self.pts_seconds[key]
+        duration_seconds = self.duration_seconds[key]
+        if self.data.ndim == 4:
+            return Frame(
+                data=data,
+                pts_seconds=float(pts_seconds.item()),
+                duration_seconds=float(duration_seconds.item()),
+            )
+        else:
+            return FrameBatch(
+                data=data,
+                pts_seconds=pts_seconds,
+                duration_seconds=duration_seconds,
+            )
+
+    def __len__(self):
+        return len(self.data)
 
     def __repr__(self):
         return _frame_repr(self)
diff --git a/test/test_frame_dataclasses.py b/test/test_frame_dataclasses.py
new file mode 100644
index 000000000..9b79b882f
--- /dev/null
+++ b/test/test_frame_dataclasses.py
@@ -0,0 +1,121 @@
+import pytest
+import torch
+from torchcodec import Frame, FrameBatch
+
+
+def test_frame_unpacking():
+    data, pts_seconds, duration_seconds = Frame(torch.rand(3, 4, 5), 2, 3)  # noqa
+
+
+def test_frame_error():
+    with pytest.raises(ValueError, match="data must be 3-dimensional"):
+        Frame(
+            data=torch.rand(1, 2),
+            pts_seconds=1,
+            duration_seconds=1,
+        )
+    with pytest.raises(ValueError, match="data must be 3-dimensional"):
+        Frame(
+            data=torch.rand(1, 2, 3, 4),
+            pts_seconds=1,
+            duration_seconds=1,
+        )
+
+
+def test_framebatch_error():
+    with pytest.raises(ValueError, match="data must be at least 4-dimensional"):
+        FrameBatch(
+            data=torch.rand(1, 2, 3),
+            pts_seconds=torch.rand(1),
+            duration_seconds=torch.rand(1),
+        )
+
+    with pytest.raises(
+        ValueError, match="leading dimensions of the inputs do not match"
+    ):
+        FrameBatch(
+            data=torch.rand(3, 4, 2, 1),
+            pts_seconds=torch.rand(3),  # ok
+            duration_seconds=torch.rand(2),  # bad
+        )
+
+    with pytest.raises(
+        ValueError, match="leading dimensions of the inputs do not match"
+    ):
+        FrameBatch(
+            data=torch.rand(3, 4, 2, 1),
+            pts_seconds=torch.rand(2),  # bad
+            duration_seconds=torch.rand(3),  # ok
+        )
+
+    with pytest.raises(
+        ValueError, match="leading dimensions of the inputs do not match"
+    ):
+        FrameBatch(
+            data=torch.rand(5, 3, 4, 2, 1),
+            pts_seconds=torch.rand(5, 3),  # ok
+            duration_seconds=torch.rand(5, 2),  # bad
+        )
+
+    with pytest.raises(
+        ValueError, match="leading dimensions of the inputs do not match"
+    ):
+        FrameBatch(
+            data=torch.rand(5, 3, 4, 2, 1),
+            pts_seconds=torch.rand(5, 2),  # bad
+            duration_seconds=torch.rand(5, 3),  # ok
+        )
+
+
+def test_framebatch_iteration():
+    T, N, C, H, W = 7, 6, 3, 2, 4
+
+    fb = FrameBatch(
+        data=torch.rand(T, N, C, H, W),
+        pts_seconds=torch.rand(T, N),
+        duration_seconds=torch.rand(T, N),
+    )
+
+    for sub_fb in fb:
+        assert isinstance(sub_fb, FrameBatch)
+        assert sub_fb.data.shape == (N, C, H, W)
+        assert sub_fb.pts_seconds.shape == (N,)
+        assert sub_fb.duration_seconds.shape == (N,)
+        for frame in sub_fb:
+            assert isinstance(frame, Frame)
+            assert frame.data.shape == (C, H, W)
+            assert isinstance(frame.pts_seconds, float)
+            assert isinstance(frame.duration_seconds, float)
+
+    # Check unpacking behavior
+    first_sub_fb, *_ = fb
+    assert isinstance(first_sub_fb, FrameBatch)
+
+
+def test_framebatch_indexing():
+    T, N, C, H, W = 7, 6, 3, 2, 4
+
+    fb = FrameBatch(
+        data=torch.rand(T, N, C, H, W),
+        pts_seconds=torch.rand(T, N),
+        duration_seconds=torch.rand(T, N),
+    )
+
+    for i in range(len(fb)):
+        assert isinstance(fb[i], FrameBatch)
+        assert fb[i].data.shape == (N, C, H, W)
+        assert fb[i].pts_seconds.shape == (N,)
+        assert fb[i].duration_seconds.shape == (N,)
+        for j in range(len(fb[i])):
+            assert isinstance(fb[i][j], Frame)
+            assert fb[i][j].data.shape == (C, H, W)
+            assert isinstance(fb[i][j].pts_seconds, float)
+            assert isinstance(fb[i][j].duration_seconds, float)
+
+    fb_fancy = fb[torch.arange(3)]
+    assert isinstance(fb_fancy, FrameBatch)
+    assert fb_fancy.data.shape == (3, N, C, H, W)
+
+    fb_fancy = fb[[[0], [1]]]  # select T=0 and N=1.
+    assert isinstance(fb_fancy, FrameBatch)
+    assert fb_fancy.data.shape == (1, C, H, W)
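
The tests above exercise the new FrameBatch semantics end to end. The snippet below is a minimal, illustrative usage sketch and is not part of the patch: it assumes torchcodec exposes Frame and FrameBatch exactly as in the diff, and the shapes and variable names (T, N, fb, sub_fb) are made up for the example.

# Illustrative sketch only -- not part of the diff above. Assumes the
# Frame / FrameBatch behavior introduced by this patch; shapes are arbitrary.
import torch
from torchcodec import Frame, FrameBatch

T, N, C, H, W = 2, 3, 3, 4, 4  # e.g. clips x frames-per-clip x channels x height x width

fb = FrameBatch(
    data=torch.rand(T, N, C, H, W),
    pts_seconds=torch.rand(T, N),
    duration_seconds=torch.rand(T, N),
)

assert len(fb) == T  # __len__ is the size of the leading dimension

sub_fb = fb[0]  # indexing a 5-D batch returns a 4-D FrameBatch
assert isinstance(sub_fb, FrameBatch)

frame = sub_fb[1]  # indexing a 4-D batch returns a single Frame with float pts/duration
assert isinstance(frame, Frame)
assert isinstance(frame.pts_seconds, float)

# Iteration mirrors indexing: a 5-D FrameBatch yields FrameBatch objects,
# a 4-D FrameBatch yields Frame objects, and a Frame unpacks into its fields.
for sub in fb:
    for f in sub:
        data, pts_seconds, duration_seconds = f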