From a8a8cea74f6965094f5e29ff859db1adc3a0bcb9 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 17 Nov 2025 06:17:14 -0800 Subject: [PATCH 01/13] Committing to move on --- .../transforms/_decoder_transforms.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index dec4704b0..95a82542c 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -91,3 +91,42 @@ def _from_torchvision(cls, resize_tv: nn.Module): f"pair for the size, got {resize_tv.size}." ) return cls(size=resize_tv.size) + +@dataclass +class RandomCrop(DecoderTransform): + + size: Sequence[int] + _top: Optional[int] = None + _left: Optional[int] = None + + def _make_transform_spec(self) -> str: + assert len(self.size) == 2 + return f"crop, {self.size[0]}, {self.size[1]}, {_left}, {_top}" + + @classmethod + def _from_torchvision(cls, random_crop_tv: nn.Module): + v2 = import_torchvision_transforms_v2() + + assert isinstance(random_crop_tv, v2.RandomCrop) + + if random_crop_tv.padding is not None: + raise ValueError( + "TorchVision RandomCrop transform must not specify padding." + ) + if random_crop_tv.pad_if_needed is True: + raise ValueError( + "TorchVision RandomCrop transform must not specify pad_if_needed." + ) + if random_crop_tv.fill != 0: + raise ValueError("TorchVision RandomCrop must specify fill of 0.") + if random_crop_tv.padding_mode != "constant": + raise ValueError( + "TorchVision RandomCrop must specify padding_mode of constant." + ) + if len(random_crop_tv.size) != 2: + raise ValueError( + "TorchVision RandcomCrop transform must have a (height, width) " + f"pair for the size, got {random_crop_tv.size}." + ) + params = random_crop_tv.make_params([]) + return cls(size=random_crop_tv.size) From aa157651fbff7d8ab86f373565e92f1e0d3c2801 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Nov 2025 19:33:31 -0800 Subject: [PATCH 02/13] It... works? --- src/torchcodec/decoders/_video_decoder.py | 19 ++++++-- src/torchcodec/transforms/__init__.py | 2 +- .../transforms/_decoder_transforms.py | 39 ++++++++++++++--- test/test_transform_ops.py | 43 +++++++++++++++++++ 4 files changed, 93 insertions(+), 10 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 1b4d4706d..86ea4f064 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -19,7 +19,7 @@ create_decoder, ERROR_REPORTING_INSTRUCTIONS, ) -from torchcodec.transforms import DecoderTransform, Resize +from torchcodec.transforms import DecoderTransform, RandomCrop, Resize class VideoDecoder: @@ -167,7 +167,9 @@ def __init__( device = str(device) device_variant = _get_cuda_backend() - transform_specs = _make_transform_specs(transforms) + transform_specs = _make_transform_specs( + transforms, input_dims=(self.metadata.height, self.metadata.width) + ) core.add_video_stream( self._decoder, @@ -450,6 +452,7 @@ def _get_and_validate_stream_metadata( def _convert_to_decoder_transforms( transforms: Sequence[Union[DecoderTransform, nn.Module]], + input_dims: Tuple[int, int], ) -> List[DecoderTransform]: """Convert a sequence of transforms that may contain TorchVision transform objects into a list of only TorchCodec transform objects. @@ -482,7 +485,13 @@ def _convert_to_decoder_transforms( "v2 transforms, but TorchVision is not installed." 
) elif isinstance(transform, v2.Resize): - converted_transforms.append(Resize._from_torchvision(transform)) + transform_tc = Resize._from_torchvision(transform) + input_dims = transform_tc._get_output_dims(input_dims) + converted_transforms.append(transform_tc) + elif isinstance(transform, v2.RandomCrop): + transform_tc = RandomCrop._from_torchvision(transform, input_dims) + input_dims = transform_tc._get_output_dims(input_dims) + converted_transforms.append(transform_tc) else: raise ValueError( f"Unsupported transform: {transform}. Transforms must be " @@ -490,6 +499,7 @@ def _convert_to_decoder_transforms( "v2 transform." ) else: + intput_dims = transform._get_output_dims(input_dims) converted_transforms.append(transform) return converted_transforms @@ -497,6 +507,7 @@ def _convert_to_decoder_transforms( def _make_transform_specs( transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], + input_dims: Tuple[int, int], ) -> str: """Given a sequence of transforms, turn those into the specification string the core API expects. @@ -516,7 +527,7 @@ def _make_transform_specs( if transforms is None: return "" - transforms = _convert_to_decoder_transforms(transforms) + transforms = _convert_to_decoder_transforms(transforms, input_dims) return ";".join([t._make_transform_spec() for t in transforms]) diff --git a/src/torchcodec/transforms/__init__.py b/src/torchcodec/transforms/__init__.py index 9f4a92f81..c93bad39e 100644 --- a/src/torchcodec/transforms/__init__.py +++ b/src/torchcodec/transforms/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._decoder_transforms import DecoderTransform, Resize # noqa +from ._decoder_transforms import DecoderTransform, RandomCrop, Resize # noqa diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 95a82542c..aa9b86f5f 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -7,8 +7,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from types import ModuleType -from typing import Sequence +from typing import Optional, Sequence, Tuple +import torch from torch import nn @@ -40,6 +41,9 @@ class DecoderTransform(ABC): def _make_transform_spec(self) -> str: pass + def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: + return input_dims + def import_torchvision_transforms_v2() -> ModuleType: try: @@ -69,6 +73,9 @@ def _make_transform_spec(self) -> str: assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" + def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: + return self.size + @classmethod def _from_torchvision(cls, resize_tv: nn.Module): v2 = import_torchvision_transforms_v2() @@ -92,19 +99,38 @@ def _from_torchvision(cls, resize_tv: nn.Module): ) return cls(size=resize_tv.size) + @dataclass class RandomCrop(DecoderTransform): size: Sequence[int] _top: Optional[int] = None _left: Optional[int] = None + _input_dims: Optional[Tuple[int, int]] = None def _make_transform_spec(self) -> str: assert len(self.size) == 2 - return f"crop, {self.size[0]}, {self.size[1]}, {_left}, {_top}" + if self._top is None or self._left is None: + assert self._input_dims is not None + if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]: + raise ValueError( + f"Input dimensions {input_dims} are smaller than the crop size 
{self.size}." + ) + self._top = torch.randint( + 0, self._input_dims[0] - self.size[0] + 1, size=() + ) + self._left = torch.randint( + 0, self._input_dims[1] - self.size[1] + 1, size=() + ) + + return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}" + + def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: + self._input_dims = input_dims + return self.size @classmethod - def _from_torchvision(cls, random_crop_tv: nn.Module): + def _from_torchvision(cls, random_crop_tv: nn.Module, input_dims: Tuple[int, int]): v2 = import_torchvision_transforms_v2() assert isinstance(random_crop_tv, v2.RandomCrop) @@ -128,5 +154,8 @@ def _from_torchvision(cls, random_crop_tv: nn.Module): "TorchVision RandcomCrop transform must have a (height, width) " f"pair for the size, got {random_crop_tv.size}." ) - params = random_crop_tv.make_params([]) - return cls(size=random_crop_tv.size) + params = random_crop_tv.make_params( + torch.empty(size=(3, *input_dims), dtype=torch.uint8) + ) + assert random_crop_tv.size == (params["height"], params["width"]) + return cls(size=random_crop_tv.size, _top=params["top"], _left=params["left"]) diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index bc42732ef..7eb82cfaf 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -145,6 +145,49 @@ def test_resize_fails(self): ): VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))]) + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((0.5, 0.5), (0.25, 0.1)), + ) + @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + def test_random_crop_torchvision( + self, video, height_scaling_factor, width_scaling_factor + ): + height = int(video.get_height() * height_scaling_factor) + width = int(video.get_width() * width_scaling_factor) + + torch.manual_seed(0) + tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width)) + decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop]) + + torch.manual_seed(0) + decoder_random_crop_tv = VideoDecoder( + video.path, transforms=[v2.RandomCrop(size=(height, width))] + ) + + decoder_full = VideoDecoder(video.path) + + num_frames = len(decoder_random_crop_tv) + assert num_frames == len(decoder_full) + + for frame_index in [ + 0, + int(num_frames * 0.1), + int(num_frames * 0.2), + int(num_frames * 0.3), + int(num_frames * 0.4), + int(num_frames * 0.5), + int(num_frames * 0.75), + int(num_frames * 0.90), + num_frames - 1, + ]: + frame_random_crop = decoder_random_crop[frame_index] + frame_random_crop_tv = decoder_random_crop_tv[frame_index] + assert_frames_equal(frame_random_crop, frame_random_crop_tv) + + expected_shape = (video.get_num_color_channels(), height, width) + assert frame_random_crop_tv.shape == expected_shape + def test_transform_fails(self): with pytest.raises( ValueError, From fd8f7a5b0d37b4ef38a798672aa1e021836f041c Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Thu, 20 Nov 2025 19:37:33 -0800 Subject: [PATCH 03/13] Lint --- src/torchcodec/decoders/_video_decoder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 86ea4f064..c9b10a548 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -499,7 +499,7 @@ def _convert_to_decoder_transforms( "v2 transform." 
) else: - intput_dims = transform._get_output_dims(input_dims) + input_dims = transform._get_output_dims(input_dims) converted_transforms.append(transform) return converted_transforms From 7e43313315898fa5181a5e61e08b8cd387866e3a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Nov 2025 12:25:18 -0800 Subject: [PATCH 04/13] Docstrings, better error checking, better testing --- src/torchcodec/_core/custom_ops.cpp | 17 ++++- .../transforms/_decoder_transforms.py | 48 ++++++++++--- test/test_transform_ops.py | 72 ++++++++++++++++++- 3 files changed, 125 insertions(+), 12 deletions(-) diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 3c6048187..9c7b8ac7f 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -239,6 +239,19 @@ int checkedToPositiveInt(const std::string& str) { return ret; } +int checkedToNonNegativeInt(const std::string& str) { + int ret = 0; + try { + ret = std::stoi(str); + } catch (const std::invalid_argument&) { + TORCH_CHECK(false, "String cannot be converted to an int:" + str); + } catch (const std::out_of_range&) { + TORCH_CHECK(false, "String would become integer out of range:" + str); + } + TORCH_CHECK(ret >= 0, "String must be a non-negative integer:" + str); + return ret; +} + // Resize transform specs take the form: // // "resize, , " @@ -270,8 +283,8 @@ Transform* makeCropTransform( "cropTransformSpec must have 5 elements including its name"); int height = checkedToPositiveInt(cropTransformSpec[1]); int width = checkedToPositiveInt(cropTransformSpec[2]); - int x = checkedToPositiveInt(cropTransformSpec[3]); - int y = checkedToPositiveInt(cropTransformSpec[4]); + int x = checkedToNonNegativeInt(cropTransformSpec[3]); + int y = checkedToNonNegativeInt(cropTransformSpec[4]); return new CropTransform(FrameDims(height, width), x, y); } diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index aa9b86f5f..2d2215049 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -23,8 +23,8 @@ class DecoderTransform(ABC): decoded frames and applying the same kind of transform. Most ``DecoderTransform`` objects have a complementary transform in TorchVision, - specificially in `torchvision.transforms.v2 `_. For such transforms, we - ensure that: + specificially in `torchvision.transforms.v2 `_. + For such transforms, we ensure that: 1. The names are the same. 2. Default behaviors are the same. @@ -74,7 +74,7 @@ def _make_transform_spec(self) -> str: return f"resize, {self.size[0]}, {self.size[1]}" def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: - return self.size + return (*self.size,) @classmethod def _from_torchvision(cls, resize_tv: nn.Module): @@ -102,6 +102,20 @@ def _from_torchvision(cls, resize_tv: nn.Module): @dataclass class RandomCrop(DecoderTransform): + """Crop the decoded frame to a given size at a random location in the frame. + + Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`. + Padding of all kinds is disabled. The random location within the frame is + determined during the initialization of the + :class:~`torchcodec.decoders.VideoDecoder` object that owns this transform. + As a consequence, each decoded frame in the video will be cropped at the + same location. Videos with variable resolution may result in undefined + behavior. + + Args: + size: (sequence of int): Desired output size. 
Must be a sequence of + the form (height, width). + """ size: Sequence[int] _top: Optional[int] = None @@ -109,13 +123,30 @@ class RandomCrop(DecoderTransform): _input_dims: Optional[Tuple[int, int]] = None def _make_transform_spec(self) -> str: - assert len(self.size) == 2 + if len(self.size) != 2: + raise ValueError( + f"RandomCrop's size must be a sequence of length 2, got {self.size}. " + "This should never happen, please report a bug." + ) + if self._top is None or self._left is None: - assert self._input_dims is not None + # TODO: It would be very strange if only ONE of those is None. But should we + # make it an error? We can continue, but it would probably mean + # something bad happened. Dear reviewer, please register an opinion here: + if self._input_dims is None: + raise ValueError( + "RandomCrop's input_dims must be set before calling _make_transform_spec(). " + "This should never happen, please report a bug." + ) if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]: raise ValueError( f"Input dimensions {input_dims} are smaller than the crop size {self.size}." ) + + # Note: This logic must match the logic in + # torchvision.transforms.v2.RandomCrop.make_params(). Given + # the same seed, they should get the same result. This is an + # API guarantee with our users. self._top = torch.randint( 0, self._input_dims[0] - self.size[0] + 1, size=() ) @@ -144,17 +175,16 @@ def _from_torchvision(cls, random_crop_tv: nn.Module, input_dims: Tuple[int, int "TorchVision RandomCrop transform must not specify pad_if_needed." ) if random_crop_tv.fill != 0: - raise ValueError("TorchVision RandomCrop must specify fill of 0.") + raise ValueError("TorchVision RandomCrop fill must be 0.") if random_crop_tv.padding_mode != "constant": - raise ValueError( - "TorchVision RandomCrop must specify padding_mode of constant." - ) + raise ValueError("TorchVision RandomCrop padding_mode must be constant.") if len(random_crop_tv.size) != 2: raise ValueError( "TorchVision RandcomCrop transform must have a (height, width) " f"pair for the size, got {random_crop_tv.size}." ) params = random_crop_tv.make_params( + # TODO: deal with NCHW versus NHWC; video decoder knows torch.empty(size=(3, *input_dims), dtype=torch.uint8) ) assert random_crop_tv.size == (params["height"], params["width"]) diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 7eb82cfaf..b1b5f4f83 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -147,7 +147,7 @@ def test_resize_fails(self): @pytest.mark.parametrize( "height_scaling_factor, width_scaling_factor", - ((0.5, 0.5), (0.25, 0.1)), + ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.25, 0.25)), ) @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) def test_random_crop_torchvision( @@ -156,6 +156,9 @@ def test_random_crop_torchvision( height = int(video.get_height() * height_scaling_factor) width = int(video.get_width() * width_scaling_factor) + # We want both kinds of RandomCrop objects to get arrive at the same + # locations to crop, so we need to make sure they get the same random + # seed. 
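        # (Both the TorchCodec RandomCrop and the converted TorchVision one
        # draw the crop's (top, left) with torch's RNG while the VideoDecoder
        # builds its transform spec at construction time, so seeding right
        # before each construction keeps the chosen locations identical.)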
torch.manual_seed(0) tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width)) decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop]) @@ -188,6 +191,73 @@ def test_random_crop_torchvision( expected_shape = (video.get_num_color_channels(), height, width) assert frame_random_crop_tv.shape == expected_shape + frame_full = decoder_full[frame_index] + frame_tv = v2.functional.crop( + frame_full, + top=tc_random_crop._top, + left=tc_random_crop._left, + height=tc_random_crop.size[0], + width=tc_random_crop.size[1], + ) + assert_frames_equal(frame_random_crop, frame_tv) + + def test_crop_fails(self): + with pytest.raises( + ValueError, + match="must not specify padding", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[ + v2.RandomCrop( + size=(100, 100), + padding=255, + ) + ], + ) + + with pytest.raises( + ValueError, + match="must not specify pad_if_needed", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[ + v2.RandomCrop( + size=(100, 100), + pad_if_needed=True, + ) + ], + ) + + with pytest.raises( + ValueError, + match="fill must be 0", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[ + v2.RandomCrop( + size=(100, 100), + fill=255, + ) + ], + ) + + with pytest.raises( + ValueError, + match="padding_mode must be constant", + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[ + v2.RandomCrop( + size=(100, 100), + padding_mode="edge", + ) + ], + ) + def test_transform_fails(self): with pytest.raises( ValueError, From 8e6a8f2a7177c376fb730a791afe1024a4b09abb Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 21 Nov 2025 18:37:58 -0800 Subject: [PATCH 05/13] Way more defensive programming --- src/torchcodec/decoders/_video_decoder.py | 23 +++- .../transforms/_decoder_transforms.py | 111 +++++++++++++----- test/test_transform_ops.py | 98 +++++++++------- 3 files changed, 152 insertions(+), 80 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index c9b10a548..4a249d1a7 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -168,7 +168,9 @@ def __init__( device_variant = _get_cuda_backend() transform_specs = _make_transform_specs( - transforms, input_dims=(self.metadata.height, self.metadata.width) + transforms, + input_dims=(self.metadata.height, self.metadata.width), + dimension_order=dimension_order, ) core.add_video_stream( @@ -452,7 +454,8 @@ def _get_and_validate_stream_metadata( def _convert_to_decoder_transforms( transforms: Sequence[Union[DecoderTransform, nn.Module]], - input_dims: Tuple[int, int], + input_dims: Tuple[Optional[int], Optional[int]], + dimension_order: Literal["NCHW", "NHWC"], ) -> List[DecoderTransform]: """Convert a sequence of transforms that may contain TorchVision transform objects into a list of only TorchCodec transform objects. @@ -489,7 +492,16 @@ def _convert_to_decoder_transforms( input_dims = transform_tc._get_output_dims(input_dims) converted_transforms.append(transform_tc) elif isinstance(transform, v2.RandomCrop): - transform_tc = RandomCrop._from_torchvision(transform, input_dims) + if dimension_order != "NCHW": + raise ValueError( + "TorchVision v2 RandomCrop is only supported for NCHW " + "dimension order. Please use the TorchCodec RandomCrop " + "transform instead." 
+ ) + transform_tc = RandomCrop._from_torchvision( + transform, + input_dims, + ) input_dims = transform_tc._get_output_dims(input_dims) converted_transforms.append(transform_tc) else: @@ -507,7 +519,8 @@ def _convert_to_decoder_transforms( def _make_transform_specs( transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], - input_dims: Tuple[int, int], + input_dims: Tuple[Optional[int], Optional[int]], + dimension_order: Literal["NCHW", "NHWC"], ) -> str: """Given a sequence of transforms, turn those into the specification string the core API expects. @@ -527,7 +540,7 @@ def _make_transform_specs( if transforms is None: return "" - transforms = _convert_to_decoder_transforms(transforms, input_dims) + transforms = _convert_to_decoder_transforms(transforms, input_dims, dimension_order) return ";".join([t._make_transform_spec() for t in transforms]) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 2d2215049..7ac45c1cd 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -41,7 +41,9 @@ class DecoderTransform(ABC): def _make_transform_spec(self) -> str: pass - def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: + def _get_output_dims( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> Tuple[Optional[int], Optional[int]]: return input_dims @@ -70,34 +72,39 @@ class Resize(DecoderTransform): size: Sequence[int] def _make_transform_spec(self) -> str: + # TODO: establish this invariant in the constructor during refactor assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" - def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: - return (*self.size,) + def _get_output_dims( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> Tuple[Optional[int], Optional[int]]: + # TODO: establish this invariant in the constructor during refactor + assert len(self.size) == 2 + return (self.size[0], self.size[1]) @classmethod - def _from_torchvision(cls, resize_tv: nn.Module): + def _from_torchvision(cls, tv_resize: nn.Module): v2 = import_torchvision_transforms_v2() - assert isinstance(resize_tv, v2.Resize) + assert isinstance(tv_resize, v2.Resize) - if resize_tv.interpolation is not v2.InterpolationMode.BILINEAR: + if tv_resize.interpolation is not v2.InterpolationMode.BILINEAR: raise ValueError( "TorchVision Resize transform must use bilinear interpolation." ) - if resize_tv.antialias is False: + if tv_resize.antialias is False: raise ValueError( "TorchVision Resize transform must have antialias enabled." ) - if resize_tv.size is None: + if tv_resize.size is None: raise ValueError("TorchVision Resize transform must have a size specified.") - if len(resize_tv.size) != 2: + if len(tv_resize.size) != 2: raise ValueError( "TorchVision Resize transform must have a (height, width) " - f"pair for the size, got {resize_tv.size}." + f"pair for the size, got {tv_resize.size}." ) - return cls(size=resize_tv.size) + return cls(size=tv_resize.size) @dataclass @@ -140,52 +147,92 @@ def _make_transform_spec(self) -> str: ) if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]: raise ValueError( - f"Input dimensions {input_dims} are smaller than the crop size {self.size}." + f"Input dimensions {self._input_dims} are smaller than the crop size {self.size}." ) # Note: This logic must match the logic in # torchvision.transforms.v2.RandomCrop.make_params(). 
Given # the same seed, they should get the same result. This is an # API guarantee with our users. - self._top = torch.randint( - 0, self._input_dims[0] - self.size[0] + 1, size=() + self._top = int( + torch.randint(0, self._input_dims[0] - self.size[0] + 1, size=()).item() ) - self._left = torch.randint( - 0, self._input_dims[1] - self.size[1] + 1, size=() + self._left = int( + torch.randint(0, self._input_dims[1] - self.size[1] + 1, size=()).item() ) return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}" - def _get_output_dims(self, input_dims: Tuple[int, int]) -> Tuple[int, int]: - self._input_dims = input_dims - return self.size + def _get_output_dims( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> Tuple[Optional[int], Optional[int]]: + # TODO: establish this invariant in the constructor during refactor + assert len(self.size) == 2 + + height, width = input_dims + if height is None: + raise ValueError( + "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known." + ) + if width is None: + raise ValueError( + "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known." + ) + + self._input_dims = (height, width) + return (self.size[0], self.size[1]) @classmethod - def _from_torchvision(cls, random_crop_tv: nn.Module, input_dims: Tuple[int, int]): + def _from_torchvision( + cls, + tv_random_crop: nn.Module, + input_dims: Tuple[Optional[int], Optional[int]], + ): v2 = import_torchvision_transforms_v2() - assert isinstance(random_crop_tv, v2.RandomCrop) + assert isinstance(tv_random_crop, v2.RandomCrop) - if random_crop_tv.padding is not None: + if tv_random_crop.padding is not None: raise ValueError( "TorchVision RandomCrop transform must not specify padding." ) - if random_crop_tv.pad_if_needed is True: + + if tv_random_crop.pad_if_needed is True: raise ValueError( "TorchVision RandomCrop transform must not specify pad_if_needed." ) - if random_crop_tv.fill != 0: + + if tv_random_crop.fill != 0: raise ValueError("TorchVision RandomCrop fill must be 0.") - if random_crop_tv.padding_mode != "constant": + + if tv_random_crop.padding_mode != "constant": raise ValueError("TorchVision RandomCrop padding_mode must be constant.") - if len(random_crop_tv.size) != 2: + + if len(tv_random_crop.size) != 2: raise ValueError( "TorchVision RandcomCrop transform must have a (height, width) " - f"pair for the size, got {random_crop_tv.size}." + f"pair for the size, got {tv_random_crop.size}." + ) + + height, width = input_dims + if height is None: + raise ValueError( + "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known." + ) + if width is None: + raise ValueError( + "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known." ) - params = random_crop_tv.make_params( - # TODO: deal with NCHW versus NHWC; video decoder knows - torch.empty(size=(3, *input_dims), dtype=torch.uint8) + + # Note that TorchVision v2 transforms only accept NCHW tensors. + params = tv_random_crop.make_params( + torch.empty(size=(3, height, width), dtype=torch.uint8) ) - assert random_crop_tv.size == (params["height"], params["width"]) - return cls(size=random_crop_tv.size, _top=params["top"], _left=params["left"]) + + if tv_random_crop.size != (params["height"], params["width"]): + raise ValueError( + f"TorchVision RandomCrop's provided size, {tv_random_crop.size} " + f"must match the computed size, {params['height'], params['width']}." 
+ ) + + return cls(size=tv_random_crop.size, _top=params["top"], _left=params["left"]) diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index b1b5f4f83..fd4a7de85 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -151,7 +151,10 @@ def test_resize_fails(self): ) @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) def test_random_crop_torchvision( - self, video, height_scaling_factor, width_scaling_factor + self, + height_scaling_factor, + width_scaling_factor, + video, ): height = int(video.get_height() * height_scaling_factor) width = int(video.get_width() * width_scaling_factor) @@ -165,7 +168,8 @@ def test_random_crop_torchvision( torch.manual_seed(0) decoder_random_crop_tv = VideoDecoder( - video.path, transforms=[v2.RandomCrop(size=(height, width))] + video.path, + transforms=[v2.RandomCrop(size=(height, width))], ) decoder_full = VideoDecoder(video.path) @@ -201,61 +205,69 @@ def test_random_crop_torchvision( ) assert_frames_equal(frame_random_crop, frame_tv) - def test_crop_fails(self): - with pytest.raises( - ValueError, - match="must not specify padding", - ): - VideoDecoder( - NASA_VIDEO.path, - transforms=[ - v2.RandomCrop( - size=(100, 100), - padding=255, - ) - ], - ) + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((0.25, 0.1), (0.25, 0.25)), + ) + def test_random_crop_nhwc( + self, + height_scaling_factor, + width_scaling_factor, + ): + height = int(TEST_SRC_2_720P.get_height() * height_scaling_factor) + width = int(TEST_SRC_2_720P.get_width() * width_scaling_factor) - with pytest.raises( - ValueError, - match="must not specify pad_if_needed", - ): - VideoDecoder( - NASA_VIDEO.path, - transforms=[ - v2.RandomCrop( - size=(100, 100), - pad_if_needed=True, - ) - ], - ) + decoder = VideoDecoder( + TEST_SRC_2_720P.path, + transforms=[torchcodec.transforms.RandomCrop(size=(height, width))], + dimension_order="NHWC", + ) + + num_frames = len(decoder) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame = decoder[frame_index] + assert frame.shape == (height, width, 3) + @pytest.mark.parametrize( + "error_message, params", + ( + ("must not specify padding", dict(size=(100, 100), padding=255)), + ( + "must not specify pad_if_needed", + dict(size=(100, 100), pad_if_needed=True), + ), + ("fill must be 0", dict(size=(100, 100), fill=255)), + ( + "padding_mode must be constant", + dict(size=(100, 100), padding_mode="edge"), + ), + ), + ) + def test_crop_fails(self, error_message, params): with pytest.raises( ValueError, - match="fill must be 0", + match=error_message, ): VideoDecoder( NASA_VIDEO.path, - transforms=[ - v2.RandomCrop( - size=(100, 100), - fill=255, - ) - ], + transforms=[v2.RandomCrop(**params)], ) + def test_tv_random_crop_nhwc_fails(self): with pytest.raises( ValueError, - match="padding_mode must be constant", + match="TorchVision v2 RandomCrop is only supported for NCHW", ): VideoDecoder( NASA_VIDEO.path, - transforms=[ - v2.RandomCrop( - size=(100, 100), - padding_mode="edge", - ) - ], + transforms=[v2.RandomCrop(size=(100, 100))], + dimension_order="NHWC", ) def test_transform_fails(self): From d8b7ed003d0c86c76c8b55b6c3c9a0a1610222ed Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 1 Dec 2025 10:58:44 -0800 Subject: [PATCH 06/13] Refactor all the things --- src/torchcodec/decoders/_video_decoder.py | 118 +++++++++--------- .../transforms/_decoder_transforms.py | 92 +++++--------- 
test/test_transform_ops.py | 25 +--- 3 files changed, 99 insertions(+), 136 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 4a249d1a7..9aefa2e69 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -8,7 +8,7 @@ import json import numbers from pathlib import Path -from typing import List, Literal, Optional, Sequence, Tuple, Union +from typing import Literal, Optional, Sequence, Tuple, Union import torch from torch import device as torch_device, nn, Tensor @@ -170,7 +170,6 @@ def __init__( transform_specs = _make_transform_specs( transforms, input_dims=(self.metadata.height, self.metadata.width), - dimension_order=dimension_order, ) core.add_video_stream( @@ -452,25 +451,33 @@ def _get_and_validate_stream_metadata( ) -def _convert_to_decoder_transforms( - transforms: Sequence[Union[DecoderTransform, nn.Module]], +def _make_transform_specs( + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], input_dims: Tuple[Optional[int], Optional[int]], - dimension_order: Literal["NCHW", "NHWC"], -) -> List[DecoderTransform]: - """Convert a sequence of transforms that may contain TorchVision transform - objects into a list of only TorchCodec transform objects. +) -> str: + """Given a sequence of transforms, turn those into the specification string + the core API expects. Args: - transforms: Squence of transform objects. The objects can be one of two - types: + transforms: Optional sequence of transform objects. The objects can be + one of two types: 1. torchcodec.transforms.DecoderTransform 2. torchvision.transforms.v2.Transform, but our type annotation only mentions its base, nn.Module. We don't want to take a hard dependency on TorchVision. + input_dims: Optional (height, width) pair. Note that only some + transforms need to know the dimensions. If the user provides + transforms that don't need to know the dimensions, and that metadata + is missing, everything should still work. That means we assert their + existence as late as possible. Returns: - List of DecoderTransform objects. + String of transforms in the format the core API expects: transform + specifications separate by semicolons. """ + if transforms is None: + return "" + try: from torchvision.transforms import v2 @@ -478,70 +485,63 @@ def _convert_to_decoder_transforms( except ImportError: tv_available = False - converted_transforms: list[DecoderTransform] = [] + # The following loop accomplishes two tasks: + # + # 1. Converts the transform to a DecoderTransform, if necessary. We + # accept TorchVision transform objects and they must be converted + # to their matching DecoderTransform. + # 2. Calculates what the input dimensions are to each transform. + # + # The order in our transforms list is semantically meaningful, as we + # actually have a pipeline where the output of one transform is the input to + # the next. For example, if we have the transforms list [A, B, C, D], then + # we should understand that as: + # A -> B -> C -> D + # Where the frame produced by A is the input to B, the frame produced by B + # is the input to C, etc. This particularly matters for frame dimensions. + # Transforms can both: + # + # 1. Produce frames with arbitrary dimensions. + # 2. Rely on their input frame's dimensions to calculate ahead-of-time + # what their runtime behavior will be. 
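    #
    # For example (illustrative numbers only): with 720x1280 metadata and the
    # list [Resize(size=(270, 480)), RandomCrop(size=(100, 100))], the
    # RandomCrop must draw its (top, left) against the 270x480 Resize output,
    # not against the original 720x1280 frame.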
+ # + # The consequence of the above facts is that we need to statically track + # frame dimensions in the pipeline while we pre-process it. The input + # frame's dimensions to A, our first transform, is always what we know from + # our metadata. For each transform, we always calculate its output + # dimensions from its input dimensions. We store these with the converted + # transform, to be all used together when we generate the specs. + converted_transforms: list[(DecoderTransform, Tuple[int, int])] = [] + curr_input_dims = input_dims for transform in transforms: - if not isinstance(transform, DecoderTransform): + if isinstance(transform, DecoderTransform): + output_dims = transform._calculate_output_dims(curr_input_dims) + converted_transforms.append((transform, curr_input_dims)) + else: if not tv_available: raise ValueError( f"The supplied transform, {transform}, is not a TorchCodec " - " DecoderTransform. TorchCodec also accept TorchVision " + " DecoderTransform. TorchCodec also accepts TorchVision " "v2 transforms, but TorchVision is not installed." ) elif isinstance(transform, v2.Resize): - transform_tc = Resize._from_torchvision(transform) - input_dims = transform_tc._get_output_dims(input_dims) - converted_transforms.append(transform_tc) + tc_transform = Resize._from_torchvision(transform) + output_dims = tc_transform._calculate_output_dims(curr_input_dims) + converted_transforms.append((tc_transform, curr_input_dims)) elif isinstance(transform, v2.RandomCrop): - if dimension_order != "NCHW": - raise ValueError( - "TorchVision v2 RandomCrop is only supported for NCHW " - "dimension order. Please use the TorchCodec RandomCrop " - "transform instead." - ) - transform_tc = RandomCrop._from_torchvision( - transform, - input_dims, - ) - input_dims = transform_tc._get_output_dims(input_dims) - converted_transforms.append(transform_tc) + tc_transform = RandomCrop._from_torchvision(transform) + output_dims = tc_transform._calculate_output_dims(curr_input_dims) + converted_transforms.append((tc_transform, curr_input_dims)) else: raise ValueError( f"Unsupported transform: {transform}. Transforms must be " "either a TorchCodec DecoderTransform or a TorchVision " "v2 transform." ) - else: - input_dims = transform._get_output_dims(input_dims) - converted_transforms.append(transform) - - return converted_transforms - -def _make_transform_specs( - transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], - input_dims: Tuple[Optional[int], Optional[int]], - dimension_order: Literal["NCHW", "NHWC"], -) -> str: - """Given a sequence of transforms, turn those into the specification string - the core API expects. - - Args: - transforms: Optional sequence of transform objects. The objects can be - one of two types: - 1. torchcodec.transforms.DecoderTransform - 2. torchvision.transforms.v2.Transform, but our type annotation - only mentions its base, nn.Module. We don't want to take a - hard dependency on TorchVision. - - Returns: - String of transforms in the format the core API expects: transform - specifications separate by semicolons. 
- """ - if transforms is None: - return "" + curr_input_dims = output_dims - transforms = _convert_to_decoder_transforms(transforms, input_dims, dimension_order) - return ";".join([t._make_transform_spec() for t in transforms]) + return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) def _read_custom_frame_mappings( diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 7ac45c1cd..d426fb063 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -38,10 +38,10 @@ class DecoderTransform(ABC): """ @abstractmethod - def _make_transform_spec(self) -> str: + def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str: pass - def _get_output_dims( + def _calculate_output_dims( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> Tuple[Optional[int], Optional[int]]: return input_dims @@ -71,12 +71,12 @@ class Resize(DecoderTransform): size: Sequence[int] - def _make_transform_spec(self) -> str: + def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str: # TODO: establish this invariant in the constructor during refactor assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" - def _get_output_dims( + def _calculate_output_dims( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> Tuple[Optional[int], Optional[int]]: # TODO: establish this invariant in the constructor during refactor @@ -125,45 +125,37 @@ class RandomCrop(DecoderTransform): """ size: Sequence[int] + + # Note that these values are never read by this object or the decoder. We + # record them for testing purposes only. _top: Optional[int] = None _left: Optional[int] = None - _input_dims: Optional[Tuple[int, int]] = None - def _make_transform_spec(self) -> str: + def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str: if len(self.size) != 2: raise ValueError( f"RandomCrop's size must be a sequence of length 2, got {self.size}. " "This should never happen, please report a bug." ) - if self._top is None or self._left is None: - # TODO: It would be very strange if only ONE of those is None. But should we - # make it an error? We can continue, but it would probably mean - # something bad happened. Dear reviewer, please register an opinion here: - if self._input_dims is None: - raise ValueError( - "RandomCrop's input_dims must be set before calling _make_transform_spec(). " - "This should never happen, please report a bug." - ) - if self._input_dims[0] < self.size[0] or self._input_dims[1] < self.size[1]: - raise ValueError( - f"Input dimensions {self._input_dims} are smaller than the crop size {self.size}." - ) - - # Note: This logic must match the logic in - # torchvision.transforms.v2.RandomCrop.make_params(). Given - # the same seed, they should get the same result. This is an - # API guarantee with our users. - self._top = int( - torch.randint(0, self._input_dims[0] - self.size[0] + 1, size=()).item() - ) - self._left = int( - torch.randint(0, self._input_dims[1] - self.size[1] + 1, size=()).item() + # Note: This logic below must match the logic in + # torchvision.transforms.v2.RandomCrop.make_params(). Given + # the same seed, they should get the same result. This is an + # API guarantee with our users. + if input_dims[0] < self.size[0] or input_dims[1] < self.size[1]: + raise ValueError( + f"Input dimensions {input_dims} are smaller than the crop size {self.size}." 
) - return f"crop, {self.size[0]}, {self.size[1]}, {self._left}, {self._top}" + top = int(torch.randint(0, input_dims[0] - self.size[0] + 1, size=()).item()) + self._top = top + + left = int(torch.randint(0, input_dims[1] - self.size[1] + 1, size=()).item()) + self._left = left + + return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" - def _get_output_dims( + def _calculate_output_dims( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> Tuple[Optional[int], Optional[int]]: # TODO: establish this invariant in the constructor during refactor @@ -172,25 +164,30 @@ def _get_output_dims( height, width = input_dims if height is None: raise ValueError( - "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known." + "Video metadata has no height. " + "RandomCrop can only be used when input frame dimensions are known." ) if width is None: raise ValueError( - "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known." + "Video metadata has no width. " + "RandomCrop can only be used when input frame dimensions are known." ) - self._input_dims = (height, width) return (self.size[0], self.size[1]) @classmethod def _from_torchvision( cls, tv_random_crop: nn.Module, - input_dims: Tuple[Optional[int], Optional[int]], ): v2 = import_torchvision_transforms_v2() - assert isinstance(tv_random_crop, v2.RandomCrop) + if not isinstance(tv_random_crop, v2.RandomCrop): + raise ValueError( + "Transform must be TorchVision's RandomCrop, " + f"it is instead {type(tv_random_crop).__name__}. " + "This should never happen, please report a bug." + ) if tv_random_crop.padding is not None: raise ValueError( @@ -214,25 +211,4 @@ def _from_torchvision( f"pair for the size, got {tv_random_crop.size}." ) - height, width = input_dims - if height is None: - raise ValueError( - "Video metadata has no height. RandomCrop can only be used when input frame dimensions are known." - ) - if width is None: - raise ValueError( - "Video metadata has no width. RandomCrop can only be used when input frame dimensions are known." - ) - - # Note that TorchVision v2 transforms only accept NCHW tensors. - params = tv_random_crop.make_params( - torch.empty(size=(3, height, width), dtype=torch.uint8) - ) - - if tv_random_crop.size != (params["height"], params["width"]): - raise ValueError( - f"TorchVision RandomCrop's provided size, {tv_random_crop.size} " - f"must match the computed size, {params['height'], params['width']}." 
- ) - - return cls(size=tv_random_crop.size, _top=params["top"], _left=params["left"]) + return cls(size=tv_random_crop.size) diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index fd4a7de85..62890334e 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -147,14 +147,16 @@ def test_resize_fails(self): @pytest.mark.parametrize( "height_scaling_factor, width_scaling_factor", - ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.25, 0.25)), + ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)), ) @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + @pytest.mark.parametrize("seed", [0, 1234]) def test_random_crop_torchvision( self, height_scaling_factor, width_scaling_factor, video, + seed, ): height = int(video.get_height() * height_scaling_factor) width = int(video.get_width() * width_scaling_factor) @@ -162,11 +164,11 @@ def test_random_crop_torchvision( # We want both kinds of RandomCrop objects to get arrive at the same # locations to crop, so we need to make sure they get the same random # seed. - torch.manual_seed(0) + torch.manual_seed(seed) tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width)) decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop]) - torch.manual_seed(0) + torch.manual_seed(seed) decoder_random_crop_tv = VideoDecoder( video.path, transforms=[v2.RandomCrop(size=(height, width))], @@ -179,13 +181,9 @@ def test_random_crop_torchvision( for frame_index in [ 0, - int(num_frames * 0.1), - int(num_frames * 0.2), - int(num_frames * 0.3), - int(num_frames * 0.4), + int(num_frames * 0.25), int(num_frames * 0.5), int(num_frames * 0.75), - int(num_frames * 0.90), num_frames - 1, ]: frame_random_crop = decoder_random_crop[frame_index] @@ -259,17 +257,6 @@ def test_crop_fails(self, error_message, params): transforms=[v2.RandomCrop(**params)], ) - def test_tv_random_crop_nhwc_fails(self): - with pytest.raises( - ValueError, - match="TorchVision v2 RandomCrop is only supported for NCHW", - ): - VideoDecoder( - NASA_VIDEO.path, - transforms=[v2.RandomCrop(size=(100, 100))], - dimension_order="NHWC", - ) - def test_transform_fails(self): with pytest.raises( ValueError, From 705d1ef8f55b143ae9c842a711f09f921a067b48 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 1 Dec 2025 11:18:37 -0800 Subject: [PATCH 07/13] Comment formatting pedantry --- src/torchcodec/decoders/_video_decoder.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 9aefa2e69..93db20f1e 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -496,7 +496,9 @@ def _make_transform_specs( # actually have a pipeline where the output of one transform is the input to # the next. For example, if we have the transforms list [A, B, C, D], then # we should understand that as: + # # A -> B -> C -> D + # # Where the frame produced by A is the input to B, the frame produced by B # is the input to C, etc. This particularly matters for frame dimensions. 
# Transforms can both: From 62eb58565e65ae622ca9984be3a677c908004d47 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Mon, 1 Dec 2025 11:38:28 -0800 Subject: [PATCH 08/13] Type checking --- src/torchcodec/decoders/_video_decoder.py | 4 ++- .../transforms/_decoder_transforms.py | 30 +++++++++++++++---- 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 93db20f1e..0d6e517ba 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -513,7 +513,9 @@ def _make_transform_specs( # our metadata. For each transform, we always calculate its output # dimensions from its input dimensions. We store these with the converted # transform, to be all used together when we generate the specs. - converted_transforms: list[(DecoderTransform, Tuple[int, int])] = [] + converted_transforms: list[ + Tuple[DecoderTransform, Tuple[Optional[int], Optional[int]]] + ] = [] curr_input_dims = input_dims for transform in transforms: if isinstance(transform, DecoderTransform): diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index d426fb063..1c9d73ea1 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -38,7 +38,9 @@ class DecoderTransform(ABC): """ @abstractmethod - def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str: + def _make_transform_spec( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> str: pass def _calculate_output_dims( @@ -71,7 +73,9 @@ class Resize(DecoderTransform): size: Sequence[int] - def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str: + def _make_transform_spec( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> str: # TODO: establish this invariant in the constructor during refactor assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" @@ -131,26 +135,40 @@ class RandomCrop(DecoderTransform): _top: Optional[int] = None _left: Optional[int] = None - def _make_transform_spec(self, input_dims: Tuple[int, int]) -> str: + def _make_transform_spec( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> str: if len(self.size) != 2: raise ValueError( f"RandomCrop's size must be a sequence of length 2, got {self.size}. " "This should never happen, please report a bug." ) + height, width = input_dims + if height is None: + raise ValueError( + "Video metadata has no height. " + "RandomCrop can only be used when input frame dimensions are known." + ) + if width is None: + raise ValueError( + "Video metadata has no width. " + "RandomCrop can only be used when input frame dimensions are known." + ) + # Note: This logic below must match the logic in # torchvision.transforms.v2.RandomCrop.make_params(). Given # the same seed, they should get the same result. This is an # API guarantee with our users. - if input_dims[0] < self.size[0] or input_dims[1] < self.size[1]: + if height < self.size[0] or width < self.size[1]: raise ValueError( f"Input dimensions {input_dims} are smaller than the crop size {self.size}." 
) - top = int(torch.randint(0, input_dims[0] - self.size[0] + 1, size=()).item()) + top = int(torch.randint(0, height - self.size[0] + 1, size=()).item()) self._top = top - left = int(torch.randint(0, input_dims[1] - self.size[1] + 1, size=()).item()) + left = int(torch.randint(0, width - self.size[1] + 1, size=()).item()) self._left = left return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" From f8844f48d42112b31e6d3ab1ccb7da187651896a Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Tue, 2 Dec 2025 20:39:13 -0800 Subject: [PATCH 09/13] Simplify; handle pipelines --- src/torchcodec/_core/SingleStreamDecoder.cpp | 4 +- src/torchcodec/_core/Transform.cpp | 44 ++++++++++++++--- src/torchcodec/_core/Transform.h | 5 +- src/torchcodec/decoders/_video_decoder.py | 8 +-- .../transforms/_decoder_transforms.py | 30 +++--------- test/test_transform_ops.py | 49 ++++++++++++++++++- 6 files changed, 100 insertions(+), 40 deletions(-) diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 8375753ae..1213f381a 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream( metadataDims_ = FrameDims(streamMetadata.height.value(), streamMetadata.width.value()); + FrameDims currInputDims = metadataDims_; for (auto& transform : transforms) { TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); if (transform->getOutputFrameDims().has_value()) { resizedOutputDims_ = transform->getOutputFrameDims().value(); } - transform->validate(streamMetadata); + transform->validate(currInputDims); + currInputDims = resizedOutputDims_.value_or(metadataDims_); // Note that we are claiming ownership of the transform objects passed in to // us. 
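The reworked CropTransform::validate() in the file below checks that the crop
rectangle fits entirely inside the frame produced by the previous stage of the
pipeline. A minimal Python sketch of those checks, for intuition only; the
helper name and the 270x480 input are illustrative assumptions, not part of
this patch:

    def crop_fits(input_h, input_w, out_h, out_w, x, y):
        # The crop must be no larger than the input, and the (x, y) offset
        # plus the output size must stay within the input bounds.
        return (
            out_h <= input_h
            and out_w <= input_w
            and x + out_w <= input_w
            and y + out_h <= input_h
        )

    # A spec of "crop, 100, 100, 400, 50" on a 270x480 input fails because
    # x + output width = 500 exceeds the input width of 480.
    assert not crop_fits(270, 480, 100, 100, 400, 50)
    # A spec of "crop, 100, 100, 10, 20" on the same input fits.
    assert crop_fits(270, 480, 100, 100, 10, 20)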
diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index e75fba697..e379c38da 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -53,15 +53,45 @@ std::optional CropTransform::getOutputFrameDims() const { return outputDims_; } -void CropTransform::validate(const StreamMetadata& streamMetadata) const { - TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds"); +void CropTransform::validate(const FrameDims& inputDims) const { TORCH_CHECK( - x_ + outputDims_.width <= streamMetadata.width, - "Crop x position out of bounds") - TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds"); + outputDims_.height <= inputDims.height, + "Crop output height (", + outputDims_.height, + ") is greater than input height (", + inputDims.height, + ")"); TORCH_CHECK( - y_ + outputDims_.height <= streamMetadata.height, - "Crop y position out of bounds"); + outputDims_.width <= inputDims.width, + "Crop output width (", + outputDims_.width, + ") is greater than input width (", + inputDims.width, + ")"); + TORCH_CHECK( + x_ <= inputDims.width, + "Crop x start position, ", + x_, + ", out of bounds of input width, ", + inputDims.width); + TORCH_CHECK( + x_ + outputDims_.width <= inputDims.width, + "Crop x end position, ", + x_ + outputDims_.width, + ", out of bounds of input width ", + inputDims.width); + TORCH_CHECK( + y_ <= inputDims.height, + "Crop y start position, ", + y_, + ", out of bounds of input height, ", + inputDims.height); + TORCH_CHECK( + y_ + outputDims_.height <= inputDims.height, + "Crop y end position, ", + y_ + outputDims_.height, + ", out of bounds of input height ", + inputDims.height); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index d58ab9fc9..4e07e2dc1 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -36,8 +36,7 @@ class Transform { // // Note that the validation function does not return anything. We expect // invalid configurations to throw an exception. 
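  // For example, CropTransform::validate() throws when the requested crop
  // rectangle does not fit inside the input frame it is given.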
- virtual void validate( - [[maybe_unused]] const StreamMetadata& streamMetadata) const {} + virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {} }; class ResizeTransform : public Transform { @@ -64,7 +63,7 @@ class CropTransform : public Transform { std::string getFilterGraphCpu() const override; std::optional getOutputFrameDims() const override; - void validate(const StreamMetadata& streamMetadata) const override; + void validate(const FrameDims& inputDims) const override; private: FrameDims outputDims_; diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 0d6e517ba..4a9817616 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -519,7 +519,7 @@ def _make_transform_specs( curr_input_dims = input_dims for transform in transforms: if isinstance(transform, DecoderTransform): - output_dims = transform._calculate_output_dims(curr_input_dims) + output_dims = transform._get_output_dims() converted_transforms.append((transform, curr_input_dims)) else: if not tv_available: @@ -530,11 +530,11 @@ def _make_transform_specs( ) elif isinstance(transform, v2.Resize): tc_transform = Resize._from_torchvision(transform) - output_dims = tc_transform._calculate_output_dims(curr_input_dims) + output_dims = tc_transform._get_output_dims() converted_transforms.append((tc_transform, curr_input_dims)) elif isinstance(transform, v2.RandomCrop): tc_transform = RandomCrop._from_torchvision(transform) - output_dims = tc_transform._calculate_output_dims(curr_input_dims) + output_dims = tc_transform._get_output_dims() converted_transforms.append((tc_transform, curr_input_dims)) else: raise ValueError( @@ -543,7 +543,7 @@ def _make_transform_specs( "v2 transform." ) - curr_input_dims = output_dims + curr_input_dims = output_dims if output_dims is not None else curr_input_dims return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 1c9d73ea1..8601d5b25 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -43,10 +43,11 @@ def _make_transform_spec( ) -> str: pass - def _calculate_output_dims( - self, input_dims: Tuple[Optional[int], Optional[int]] - ) -> Tuple[Optional[int], Optional[int]]: - return input_dims + # Transforms that change the dimensions of their input frame return a value. + # Transforms that don't return None; they can rely on this default + # implementation. 
+ def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: + return None def import_torchvision_transforms_v2() -> ModuleType: @@ -80,9 +81,7 @@ def _make_transform_spec( assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" - def _calculate_output_dims( - self, input_dims: Tuple[Optional[int], Optional[int]] - ) -> Tuple[Optional[int], Optional[int]]: + def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: # TODO: establish this invariant in the constructor during refactor assert len(self.size) == 2 return (self.size[0], self.size[1]) @@ -173,24 +172,9 @@ def _make_transform_spec( return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" - def _calculate_output_dims( - self, input_dims: Tuple[Optional[int], Optional[int]] - ) -> Tuple[Optional[int], Optional[int]]: + def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: # TODO: establish this invariant in the constructor during refactor assert len(self.size) == 2 - - height, width = input_dims - if height is None: - raise ValueError( - "Video metadata has no height. " - "RandomCrop can only be used when input frame dimensions are known." - ) - if width is None: - raise ValueError( - "Video metadata has no width. " - "RandomCrop can only be used when input frame dimensions are known." - ) - return (self.size[0], self.size[1]) @classmethod diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 62890334e..7492a1638 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -257,6 +257,51 @@ def test_crop_fails(self, error_message, params): transforms=[v2.RandomCrop(**params)], ) + @pytest.mark.parametrize("seed", [0, 314]) + def test_random_crop_reusable_objects(self, seed): + torch.manual_seed(seed) + random_crop = torchcodec.transforms.RandomCrop(size=(100, 100)) + + # Create a spec which causes us to calculate the random crop location. + _ = random_crop._make_transform_spec((1000, 1000)) + first_top = random_crop._top + first_left = random_crop._left + + # Create a spec again, which should calculate a different random crop + # location. 
+ _ = random_crop._make_transform_spec((1000, 1000)) + assert first_top != random_crop._top + assert first_left != random_crop._left + + @pytest.mark.parametrize( + "resize, random_crop", + [ + (torchcodec.transforms.Resize, torchcodec.transforms.RandomCrop), + (v2.Resize, v2.RandomCrop), + ], + ) + def test_transform_pipeline(self, resize, random_crop): + decoder = VideoDecoder( + TEST_SRC_2_720P.path, + transforms=[ + # resized to bigger than original + resize(size=(2160, 3840)), + # crop to smaller than the resize, but still bigger than original + random_crop(size=(1080, 1920)), + ], + ) + + num_frames = len(decoder) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame = decoder[frame_index] + assert frame.shape == (TEST_SRC_2_720P.get_num_color_channels(), 1080, 1920) + def test_transform_fails(self): with pytest.raises( ValueError, @@ -519,14 +564,14 @@ def test_crop_transform_fails(self): with pytest.raises( RuntimeError, - match="x position out of bounds", + match="x start position, 9999, out of bounds", ): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100") with pytest.raises( RuntimeError, - match="y position out of bounds", + match=r"Crop output height \(999\) is greater than input height \(270\)", ): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100") From 2fed9c3f35ce835a8f16bdd4ad6f860931dc3150 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 3 Dec 2025 08:18:43 -0800 Subject: [PATCH 10/13] Moare refactor pleasze --- src/torchcodec/decoders/_video_decoder.py | 15 +++----- .../transforms/_decoder_transforms.py | 8 ---- test/test_transform_ops.py | 38 +++++++++++-------- 3 files changed, 27 insertions(+), 34 deletions(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 4a9817616..3620efea2 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -518,10 +518,7 @@ def _make_transform_specs( ] = [] curr_input_dims = input_dims for transform in transforms: - if isinstance(transform, DecoderTransform): - output_dims = transform._get_output_dims() - converted_transforms.append((transform, curr_input_dims)) - else: + if not isinstance(transform, DecoderTransform): if not tv_available: raise ValueError( f"The supplied transform, {transform}, is not a TorchCodec " @@ -529,13 +526,9 @@ def _make_transform_specs( "v2 transforms, but TorchVision is not installed." ) elif isinstance(transform, v2.Resize): - tc_transform = Resize._from_torchvision(transform) - output_dims = tc_transform._get_output_dims() - converted_transforms.append((tc_transform, curr_input_dims)) + transform = Resize._from_torchvision(transform) elif isinstance(transform, v2.RandomCrop): - tc_transform = RandomCrop._from_torchvision(transform) - output_dims = tc_transform._get_output_dims() - converted_transforms.append((tc_transform, curr_input_dims)) + transform = RandomCrop._from_torchvision(transform) else: raise ValueError( f"Unsupported transform: {transform}. Transforms must be " @@ -543,6 +536,8 @@ def _make_transform_specs( "v2 transform." 
) + converted_transforms.append((transform, curr_input_dims)) + output_dims = transform._get_output_dims() curr_input_dims = output_dims if output_dims is not None else curr_input_dims return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 8601d5b25..32dc6aae0 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -129,11 +129,6 @@ class RandomCrop(DecoderTransform): size: Sequence[int] - # Note that these values are never read by this object or the decoder. We - # record them for testing purposes only. - _top: Optional[int] = None - _left: Optional[int] = None - def _make_transform_spec( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> str: @@ -165,10 +160,7 @@ def _make_transform_spec( ) top = int(torch.randint(0, height - self.size[0] + 1, size=()).item()) - self._top = top - left = int(torch.randint(0, width - self.size[1] + 1, size=()).item()) - self._left = left return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index 7492a1638..5839f79a4 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -163,11 +163,15 @@ def test_random_crop_torchvision( # We want both kinds of RandomCrop objects to get arrive at the same # locations to crop, so we need to make sure they get the same random - # seed. + # seed. It's used in RandomCrop's _make_transform_spec() method, called + # by the VideoDecoder. torch.manual_seed(seed) tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width)) decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop]) + # Resetting manual seed for when TorchCodec's RandomCrop, created from + # the TorchVision RandomCrop, is used inside of the VideoDecoder. It + # needs to match the call above. torch.manual_seed(seed) decoder_random_crop_tv = VideoDecoder( video.path, @@ -193,14 +197,11 @@ def test_random_crop_torchvision( expected_shape = (video.get_num_color_channels(), height, width) assert frame_random_crop_tv.shape == expected_shape + # Resetting manual seed to make sure the invocation of the + # TorchVision RandomCrop matches the two calls above. + torch.manual_seed(seed) frame_full = decoder_full[frame_index] - frame_tv = v2.functional.crop( - frame_full, - top=tc_random_crop._top, - left=tc_random_crop._left, - height=tc_random_crop.size[0], - width=tc_random_crop.size[1], - ) + frame_tv = v2.RandomCrop(size=(height, width))(frame_full) assert_frames_equal(frame_random_crop, frame_tv) @pytest.mark.parametrize( @@ -260,18 +261,23 @@ def test_crop_fails(self, error_message, params): @pytest.mark.parametrize("seed", [0, 314]) def test_random_crop_reusable_objects(self, seed): torch.manual_seed(seed) - random_crop = torchcodec.transforms.RandomCrop(size=(100, 100)) + random_crop = torchcodec.transforms.RandomCrop(size=(99, 99)) # Create a spec which causes us to calculate the random crop location. - _ = random_crop._make_transform_spec((1000, 1000)) - first_top = random_crop._top - first_left = random_crop._left + first_spec = random_crop._make_transform_spec((888, 888)) # Create a spec again, which should calculate a different random crop - # location. - _ = random_crop._make_transform_spec((1000, 1000)) - assert first_top != random_crop._top - assert first_left != random_crop._left + # location. 
Despite having the same image size, the specs should be + # different because the crop should be at a different location + second_spec = random_crop._make_transform_spec((888, 888)) + assert first_spec != second_spec + + # Create a spec again, but with a different image size. The specs should + # obviously be different, but the original image size should not be in + # the spec at all. + third_spec = random_crop._make_transform_spec((777, 777)) + assert third_spec != first_spec + assert "888" not in third_spec @pytest.mark.parametrize( "resize, random_crop", From e7367421cca7bc475c9d7f033d4dedbdbd30ad14 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 3 Dec 2025 08:51:35 -0800 Subject: [PATCH 11/13] Better doc strings --- .../transforms/_decoder_transforms.py | 43 ++++++++++++++++--- 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 32dc6aae0..133865532 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -41,12 +41,43 @@ class DecoderTransform(ABC): def _make_transform_spec( self, input_dims: Tuple[Optional[int], Optional[int]] ) -> str: + """Makes the transform spec that is used by the `VideoDecoder`. + + Args: + input_dims (Tuple[Optional[int], Optional[int]]): The dimensions of + the input frame in the form (height, width). We cannot know the + dimensions at object construction time because it's dependent on + the video being decoded and upstream transforms in the same + transform pipeline. Not all transforms need to know this; those + that don't will ignore it. The individual values in the tuple are + optional because the original values come from file metadata which + may be missing. We maintain the optionality throughout the APIs so + that we can decide as late as possible that it's necessary for the + values to exist. That is, if the values are missing from the + metadata and we have transforms which ignore the input dimensions, + we want that to still work. + + Note: This method is the moral equivalent of TorchVision's + `Transformer.make_params()`. + + Returns: + str: A string which contains the spec for the transform that the + `VideoDecoder` knows what to do with. + """ pass - # Transforms that change the dimensions of their input frame return a value. - # Transforms that don't return None; they can rely on this default - # implementation. def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: + """Get the dimensions of the output frame. + + Transforms that change the frame dimensions need to override this + method. Transforms that don't change the frame dimensions can rely on + this default implementation. + + Returns: + Optional[Tuple[Optional[int], Optional[int]]]: The output dimensions. + - None: The output dimensions are the same as the input dimensions. + - (int, int): The (height, width) of the output frame. + """ return None @@ -68,7 +99,7 @@ class Resize(DecoderTransform): Interpolation is always bilinear. Anti-aliasing is always on. Args: - size: (sequence of int): Desired output size. Must be a sequence of + size (Sequence[int]): Desired output size. Must be a sequence of the form (height, width). """ @@ -117,13 +148,13 @@ class RandomCrop(DecoderTransform): Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`. Padding of all kinds is disabled. 
The random location within the frame is determined during the initialization of the - :class:~`torchcodec.decoders.VideoDecoder` object that owns this transform. + :class:`~torchcodec.decoders.VideoDecoder` object that owns this transform. As a consequence, each decoded frame in the video will be cropped at the same location. Videos with variable resolution may result in undefined behavior. Args: - size: (sequence of int): Desired output size. Must be a sequence of + size (Sequence[int]): Desired output size. Must be a sequence of the form (height, width). """ From 643272706890586e0006bb2229a29cfd45cefcba Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 3 Dec 2025 08:55:39 -0800 Subject: [PATCH 12/13] Comment --- src/torchcodec/decoders/_video_decoder.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 3620efea2..8de497fa7 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -514,7 +514,11 @@ def _make_transform_specs( # dimensions from its input dimensions. We store these with the converted # transform, to be all used together when we generate the specs. converted_transforms: list[ - Tuple[DecoderTransform, Tuple[Optional[int], Optional[int]]] + Tuple[ + DecoderTransform, + # A (height, width) pair where the values may be missing. + Tuple[Optional[int], Optional[int]], + ] ] = [] curr_input_dims = input_dims for transform in transforms: From be8ed26065b7d13df02dbb0f27981fb68d33a216 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 3 Dec 2025 10:22:10 -0800 Subject: [PATCH 13/13] Transformer -> Transform --- src/torchcodec/transforms/_decoder_transforms.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index 133865532..ed38820b1 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -58,7 +58,7 @@ def _make_transform_spec( we want that to still work. Note: This method is the moral equivalent of TorchVision's - `Transformer.make_params()`. + `Transform.make_params()`. Returns: str: A string which contains the spec for the transform that the
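
A minimal end-to-end usage sketch of the pipeline this series builds up, for reference. It assumes the public API shown in the diffs above (torchcodec.transforms.Resize / RandomCrop and the VideoDecoder transforms argument) lands as written; the file path, sizes, and seed below are placeholders, and the expected frame shape assumes a 3-channel source.

    import torch
    from torchcodec.decoders import VideoDecoder
    from torchcodec.transforms import RandomCrop, Resize

    # RandomCrop picks its crop location once, when the VideoDecoder is
    # constructed, so seed *before* building the decoder if the crop needs
    # to be reproducible.
    torch.manual_seed(0)

    decoder = VideoDecoder(
        "my_video.mp4",  # placeholder path
        transforms=[
            # Upscale past the source resolution...
            Resize(size=(2160, 3840)),
            # ...then crop back down; every decoded frame is cropped at the
            # same (randomly chosen) location.
            RandomCrop(size=(1080, 1920)),
        ],
    )

    frame = decoder[0]
    assert frame.shape == (3, 1080, 1920)  # (C, H, W), assuming an RGB source

    # Internally the decoder flattens the pipeline into a spec string such as
    #   "resize, 2160, 3840;crop, 1080, 1920, <left>, <top>"
    # where <left>/<top> are drawn from torch's RNG against the *resized*
    # dimensions, which is why input dims are threaded through the
    # conversion in _make_transform_specs().

Because the crop offsets are baked into the spec string at construction time, re-seeding after the decoder exists does not move the crop for that decoder; construct a new VideoDecoder (as test_random_crop_torchvision does) to sample a different location.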