diff --git a/src/torchcodec/_core/SingleStreamDecoder.cpp b/src/torchcodec/_core/SingleStreamDecoder.cpp index 8375753ae..1213f381a 100644 --- a/src/torchcodec/_core/SingleStreamDecoder.cpp +++ b/src/torchcodec/_core/SingleStreamDecoder.cpp @@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream( metadataDims_ = FrameDims(streamMetadata.height.value(), streamMetadata.width.value()); + FrameDims currInputDims = metadataDims_; for (auto& transform : transforms) { TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!"); if (transform->getOutputFrameDims().has_value()) { resizedOutputDims_ = transform->getOutputFrameDims().value(); } - transform->validate(streamMetadata); + transform->validate(currInputDims); + currInputDims = resizedOutputDims_.value_or(metadataDims_); // Note that we are claiming ownership of the transform objects passed in to // us. diff --git a/src/torchcodec/_core/Transform.cpp b/src/torchcodec/_core/Transform.cpp index e75fba697..e379c38da 100644 --- a/src/torchcodec/_core/Transform.cpp +++ b/src/torchcodec/_core/Transform.cpp @@ -53,15 +53,45 @@ std::optional<FrameDims> CropTransform::getOutputFrameDims() const { return outputDims_; } -void CropTransform::validate(const StreamMetadata& streamMetadata) const { - TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds"); +void CropTransform::validate(const FrameDims& inputDims) const { TORCH_CHECK( - x_ + outputDims_.width <= streamMetadata.width, - "Crop x position out of bounds") - TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds"); + outputDims_.height <= inputDims.height, + "Crop output height (", + outputDims_.height, + ") is greater than input height (", + inputDims.height, + ")"); TORCH_CHECK( - y_ + outputDims_.height <= streamMetadata.height, - "Crop y position out of bounds"); + outputDims_.width <= inputDims.width, + "Crop output width (", + outputDims_.width, + ") is greater than input width (", + inputDims.width, + ")"); + TORCH_CHECK( + x_ <= inputDims.width, + "Crop x start position, ", + x_, + ", out of bounds of input width, ", + inputDims.width); + TORCH_CHECK( + x_ + outputDims_.width <= inputDims.width, + "Crop x end position, ", + x_ + outputDims_.width, + ", out of bounds of input width ", + inputDims.width); + TORCH_CHECK( + y_ <= inputDims.height, + "Crop y start position, ", + y_, + ", out of bounds of input height, ", + inputDims.height); + TORCH_CHECK( + y_ + outputDims_.height <= inputDims.height, + "Crop y end position, ", + y_ + outputDims_.height, + ", out of bounds of input height ", + inputDims.height); } } // namespace facebook::torchcodec diff --git a/src/torchcodec/_core/Transform.h b/src/torchcodec/_core/Transform.h index d58ab9fc9..4e07e2dc1 100644 --- a/src/torchcodec/_core/Transform.h +++ b/src/torchcodec/_core/Transform.h @@ -36,8 +36,7 @@ class Transform { // // Note that the validation function does not return anything. We expect // invalid configurations to throw an exception.
- virtual void validate( - [[maybe_unused]] const StreamMetadata& streamMetadata) const {} + virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {} }; class ResizeTransform : public Transform { @@ -64,7 +63,7 @@ class CropTransform : public Transform { std::string getFilterGraphCpu() const override; std::optional<FrameDims> getOutputFrameDims() const override; - void validate(const StreamMetadata& streamMetadata) const override; + void validate(const FrameDims& inputDims) const override; private: FrameDims outputDims_; diff --git a/src/torchcodec/_core/custom_ops.cpp b/src/torchcodec/_core/custom_ops.cpp index 4ec72974d..54064942c 100644 --- a/src/torchcodec/_core/custom_ops.cpp +++ b/src/torchcodec/_core/custom_ops.cpp @@ -239,6 +239,19 @@ int checkedToPositiveInt(const std::string& str) { return ret; } +int checkedToNonNegativeInt(const std::string& str) { + int ret = 0; + try { + ret = std::stoi(str); + } catch (const std::invalid_argument&) { + TORCH_CHECK(false, "String cannot be converted to an int:" + str); + } catch (const std::out_of_range&) { + TORCH_CHECK(false, "String would become integer out of range:" + str); + } + TORCH_CHECK(ret >= 0, "String must be a non-negative integer:" + str); + return ret; +} + // Resize transform specs take the form: // // "resize, <height>, <width>" @@ -270,8 +283,8 @@ Transform* makeCropTransform( "cropTransformSpec must have 5 elements including its name"); int height = checkedToPositiveInt(cropTransformSpec[1]); int width = checkedToPositiveInt(cropTransformSpec[2]); - int x = checkedToPositiveInt(cropTransformSpec[3]); - int y = checkedToPositiveInt(cropTransformSpec[4]); + int x = checkedToNonNegativeInt(cropTransformSpec[3]); + int y = checkedToNonNegativeInt(cropTransformSpec[4]); return new CropTransform(FrameDims(height, width), x, y); } diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 1b4d4706d..8de497fa7 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -8,7 +8,7 @@ import json import numbers from pathlib import Path -from typing import List, Literal, Optional, Sequence, Tuple, Union +from typing import Literal, Optional, Sequence, Tuple, Union import torch from torch import device as torch_device, nn, Tensor @@ -19,7 +19,7 @@ create_decoder, ERROR_REPORTING_INSTRUCTIONS, ) -from torchcodec.transforms import DecoderTransform, Resize +from torchcodec.transforms import DecoderTransform, RandomCrop, Resize class VideoDecoder: @@ -167,7 +167,10 @@ def __init__( device = str(device) device_variant = _get_cuda_backend() - transform_specs = _make_transform_specs(transforms) + transform_specs = _make_transform_specs( + transforms, + input_dims=(self.metadata.height, self.metadata.width), + ) core.add_video_stream( self._decoder, @@ -448,23 +451,33 @@ def _get_and_validate_stream_metadata( ) -def _convert_to_decoder_transforms( - transforms: Sequence[Union[DecoderTransform, nn.Module]], -) -> List[DecoderTransform]: - """Convert a sequence of transforms that may contain TorchVision transform - objects into a list of only TorchCodec transform objects. +def _make_transform_specs( + transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], + input_dims: Tuple[Optional[int], Optional[int]], +) -> str: + """Given a sequence of transforms, turn those into the specification string + the core API expects. Args: - transforms: Squence of transform objects.
The objects can be one of two - types: + transforms: Optional sequence of transform objects. The objects can be + one of two types: 1. torchcodec.transforms.DecoderTransform 2. torchvision.transforms.v2.Transform, but our type annotation only mentions its base, nn.Module. We don't want to take a hard dependency on TorchVision. + input_dims: Optional (height, width) pair. Note that only some + transforms need to know the dimensions. If the user provides + transforms that don't need to know the dimensions, and that metadata + is missing, everything should still work. That means we assert their + existence as late as possible. Returns: - List of DecoderTransform objects. + String of transforms in the format the core API expects: transform + specifications separate by semicolons. """ + if transforms is None: + return "" + try: from torchvision.transforms import v2 @@ -472,52 +485,66 @@ def _convert_to_decoder_transforms( except ImportError: tv_available = False - converted_transforms: list[DecoderTransform] = [] + # The following loop accomplishes two tasks: + # + # 1. Converts the transform to a DecoderTransform, if necessary. We + # accept TorchVision transform objects and they must be converted + # to their matching DecoderTransform. + # 2. Calculates what the input dimensions are to each transform. + # + # The order in our transforms list is semantically meaningful, as we + # actually have a pipeline where the output of one transform is the input to + # the next. For example, if we have the transforms list [A, B, C, D], then + # we should understand that as: + # + # A -> B -> C -> D + # + # Where the frame produced by A is the input to B, the frame produced by B + # is the input to C, etc. This particularly matters for frame dimensions. + # Transforms can both: + # + # 1. Produce frames with arbitrary dimensions. + # 2. Rely on their input frame's dimensions to calculate ahead-of-time + # what their runtime behavior will be. + # + # The consequence of the above facts is that we need to statically track + # frame dimensions in the pipeline while we pre-process it. The input + # frame's dimensions to A, our first transform, is always what we know from + # our metadata. For each transform, we always calculate its output + # dimensions from its input dimensions. We store these with the converted + # transform, to be all used together when we generate the specs. + converted_transforms: list[ + Tuple[ + DecoderTransform, + # A (height, width) pair where the values may be missing. + Tuple[Optional[int], Optional[int]], + ] + ] = [] + curr_input_dims = input_dims for transform in transforms: if not isinstance(transform, DecoderTransform): if not tv_available: raise ValueError( f"The supplied transform, {transform}, is not a TorchCodec " - " DecoderTransform. TorchCodec also accept TorchVision " + " DecoderTransform. TorchCodec also accepts TorchVision " "v2 transforms, but TorchVision is not installed." ) elif isinstance(transform, v2.Resize): - converted_transforms.append(Resize._from_torchvision(transform)) + transform = Resize._from_torchvision(transform) + elif isinstance(transform, v2.RandomCrop): + transform = RandomCrop._from_torchvision(transform) else: raise ValueError( f"Unsupported transform: {transform}. Transforms must be " "either a TorchCodec DecoderTransform or a TorchVision " "v2 transform." 
) - else: - converted_transforms.append(transform) - - return converted_transforms + converted_transforms.append((transform, curr_input_dims)) + output_dims = transform._get_output_dims() + curr_input_dims = output_dims if output_dims is not None else curr_input_dims -def _make_transform_specs( - transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]], -) -> str: - """Given a sequence of transforms, turn those into the specification string - the core API expects. - - Args: - transforms: Optional sequence of transform objects. The objects can be - one of two types: - 1. torchcodec.transforms.DecoderTransform - 2. torchvision.transforms.v2.Transform, but our type annotation - only mentions its base, nn.Module. We don't want to take a - hard dependency on TorchVision. - - Returns: - String of transforms in the format the core API expects: transform - specifications separate by semicolons. - """ - if transforms is None: - return "" - - transforms = _convert_to_decoder_transforms(transforms) - return ";".join([t._make_transform_spec() for t in transforms]) + return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms]) def _read_custom_frame_mappings( diff --git a/src/torchcodec/transforms/__init__.py b/src/torchcodec/transforms/__init__.py index 9f4a92f81..c93bad39e 100644 --- a/src/torchcodec/transforms/__init__.py +++ b/src/torchcodec/transforms/__init__.py @@ -4,4 +4,4 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from ._decoder_transforms import DecoderTransform, Resize # noqa +from ._decoder_transforms import DecoderTransform, RandomCrop, Resize # noqa diff --git a/src/torchcodec/transforms/_decoder_transforms.py b/src/torchcodec/transforms/_decoder_transforms.py index dec4704b0..ed38820b1 100644 --- a/src/torchcodec/transforms/_decoder_transforms.py +++ b/src/torchcodec/transforms/_decoder_transforms.py @@ -7,8 +7,9 @@ from abc import ABC, abstractmethod from dataclasses import dataclass from types import ModuleType -from typing import Sequence +from typing import Optional, Sequence, Tuple +import torch from torch import nn @@ -22,8 +23,8 @@ class DecoderTransform(ABC): decoded frames and applying the same kind of transform. Most ``DecoderTransform`` objects have a complementary transform in TorchVision, - specificially in `torchvision.transforms.v2 `_. For such transforms, we - ensure that: + specificially in `torchvision.transforms.v2 `_. + For such transforms, we ensure that: 1. The names are the same. 2. Default behaviors are the same. @@ -37,9 +38,48 @@ class DecoderTransform(ABC): """ @abstractmethod - def _make_transform_spec(self) -> str: + def _make_transform_spec( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> str: + """Makes the transform spec that is used by the `VideoDecoder`. + + Args: + input_dims (Tuple[Optional[int], Optional[int]]): The dimensions of + the input frame in the form (height, width). We cannot know the + dimensions at object construction time because it's dependent on + the video being decoded and upstream transforms in the same + transform pipeline. Not all transforms need to know this; those + that don't will ignore it. The individual values in the tuple are + optional because the original values come from file metadata which + may be missing. We maintain the optionality throughout the APIs so + that we can decide as late as possible that it's necessary for the + values to exist. 
That is, if the values are missing from the + metadata and we have transforms which ignore the input dimensions, + we want that to still work. + + Note: This method is the moral equivalent of TorchVision's + `Transform.make_params()`. + + Returns: + str: A string which contains the spec for the transform that the + `VideoDecoder` knows what to do with. + """ pass + def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: + """Get the dimensions of the output frame. + + Transforms that change the frame dimensions need to override this + method. Transforms that don't change the frame dimensions can rely on + this default implementation. + + Returns: + Optional[Tuple[Optional[int], Optional[int]]]: The output dimensions. + - None: The output dimensions are the same as the input dimensions. + - (int, int): The (height, width) of the output frame. + """ + return None + def import_torchvision_transforms_v2() -> ModuleType: try: @@ -59,35 +99,141 @@ class Resize(DecoderTransform): Interpolation is always bilinear. Anti-aliasing is always on. Args: - size: (sequence of int): Desired output size. Must be a sequence of + size (Sequence[int]): Desired output size. Must be a sequence of the form (height, width). """ size: Sequence[int] - def _make_transform_spec(self) -> str: + def _make_transform_spec( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> str: + # TODO: establish this invariant in the constructor during refactor assert len(self.size) == 2 return f"resize, {self.size[0]}, {self.size[1]}" + def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: + # TODO: establish this invariant in the constructor during refactor + assert len(self.size) == 2 + return (self.size[0], self.size[1]) + @classmethod - def _from_torchvision(cls, resize_tv: nn.Module): + def _from_torchvision(cls, tv_resize: nn.Module): v2 = import_torchvision_transforms_v2() - assert isinstance(resize_tv, v2.Resize) + assert isinstance(tv_resize, v2.Resize) - if resize_tv.interpolation is not v2.InterpolationMode.BILINEAR: + if tv_resize.interpolation is not v2.InterpolationMode.BILINEAR: raise ValueError( "TorchVision Resize transform must use bilinear interpolation." ) - if resize_tv.antialias is False: + if tv_resize.antialias is False: raise ValueError( "TorchVision Resize transform must have antialias enabled." ) - if resize_tv.size is None: + if tv_resize.size is None: raise ValueError("TorchVision Resize transform must have a size specified.") - if len(resize_tv.size) != 2: + if len(tv_resize.size) != 2: raise ValueError( "TorchVision Resize transform must have a (height, width) " - f"pair for the size, got {resize_tv.size}." + f"pair for the size, got {tv_resize.size}." + ) + return cls(size=tv_resize.size) + + +@dataclass +class RandomCrop(DecoderTransform): + """Crop the decoded frame to a given size at a random location in the frame. + + Complementary TorchVision transform: :class:`~torchvision.transforms.v2.RandomCrop`. + Padding of all kinds is disabled. The random location within the frame is + determined during the initialization of the + :class:`~torchcodec.decoders.VideoDecoder` object that owns this transform. + As a consequence, each decoded frame in the video will be cropped at the + same location. Videos with variable resolution may result in undefined + behavior. + + Args: + size (Sequence[int]): Desired output size. Must be a sequence of + the form (height, width). 
+ """ + + size: Sequence[int] + + def _make_transform_spec( + self, input_dims: Tuple[Optional[int], Optional[int]] + ) -> str: + if len(self.size) != 2: + raise ValueError( + f"RandomCrop's size must be a sequence of length 2, got {self.size}. " + "This should never happen, please report a bug." + ) + + height, width = input_dims + if height is None: + raise ValueError( + "Video metadata has no height. " + "RandomCrop can only be used when input frame dimensions are known." + ) + if width is None: + raise ValueError( + "Video metadata has no width. " + "RandomCrop can only be used when input frame dimensions are known." + ) + + # Note: This logic below must match the logic in + # torchvision.transforms.v2.RandomCrop.make_params(). Given + # the same seed, they should get the same result. This is an + # API guarantee with our users. + if height < self.size[0] or width < self.size[1]: + raise ValueError( + f"Input dimensions {input_dims} are smaller than the crop size {self.size}." + ) + + top = int(torch.randint(0, height - self.size[0] + 1, size=()).item()) + left = int(torch.randint(0, width - self.size[1] + 1, size=()).item()) + + return f"crop, {self.size[0]}, {self.size[1]}, {left}, {top}" + + def _get_output_dims(self) -> Optional[Tuple[Optional[int], Optional[int]]]: + # TODO: establish this invariant in the constructor during refactor + assert len(self.size) == 2 + return (self.size[0], self.size[1]) + + @classmethod + def _from_torchvision( + cls, + tv_random_crop: nn.Module, + ): + v2 = import_torchvision_transforms_v2() + + if not isinstance(tv_random_crop, v2.RandomCrop): + raise ValueError( + "Transform must be TorchVision's RandomCrop, " + f"it is instead {type(tv_random_crop).__name__}. " + "This should never happen, please report a bug." ) - return cls(size=resize_tv.size) + + if tv_random_crop.padding is not None: + raise ValueError( + "TorchVision RandomCrop transform must not specify padding." + ) + + if tv_random_crop.pad_if_needed is True: + raise ValueError( + "TorchVision RandomCrop transform must not specify pad_if_needed." + ) + + if tv_random_crop.fill != 0: + raise ValueError("TorchVision RandomCrop fill must be 0.") + + if tv_random_crop.padding_mode != "constant": + raise ValueError("TorchVision RandomCrop padding_mode must be constant.") + + if len(tv_random_crop.size) != 2: + raise ValueError( + "TorchVision RandomCrop transform must have a (height, width) " + f"pair for the size, got {tv_random_crop.size}." + ) + + return cls(size=tv_random_crop.size) diff --git a/test/test_transform_ops.py b/test/test_transform_ops.py index bc42732ef..5839f79a4 100644 --- a/test/test_transform_ops.py +++ b/test/test_transform_ops.py @@ -145,6 +145,169 @@ def test_resize_fails(self): ): VideoDecoder(NASA_VIDEO.path, transforms=[v2.Resize(size=(100))]) + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((0.5, 0.5), (0.25, 0.1), (1.0, 1.0), (0.15, 0.75)), + ) + @pytest.mark.parametrize("video", [NASA_VIDEO, TEST_SRC_2_720P]) + @pytest.mark.parametrize("seed", [0, 1234]) + def test_random_crop_torchvision( + self, + height_scaling_factor, + width_scaling_factor, + video, + seed, + ): + height = int(video.get_height() * height_scaling_factor) + width = int(video.get_width() * width_scaling_factor) + + # We want both kinds of RandomCrop objects to arrive at the same + # locations to crop, so we need to make sure they get the same random + # seed. It's used in RandomCrop's _make_transform_spec() method, called + # by the VideoDecoder.
+ torch.manual_seed(seed) + tc_random_crop = torchcodec.transforms.RandomCrop(size=(height, width)) + decoder_random_crop = VideoDecoder(video.path, transforms=[tc_random_crop]) + + # Resetting manual seed for when TorchCodec's RandomCrop, created from + # the TorchVision RandomCrop, is used inside of the VideoDecoder. It + # needs to match the call above. + torch.manual_seed(seed) + decoder_random_crop_tv = VideoDecoder( + video.path, + transforms=[v2.RandomCrop(size=(height, width))], + ) + + decoder_full = VideoDecoder(video.path) + + num_frames = len(decoder_random_crop_tv) + assert num_frames == len(decoder_full) + + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame_random_crop = decoder_random_crop[frame_index] + frame_random_crop_tv = decoder_random_crop_tv[frame_index] + assert_frames_equal(frame_random_crop, frame_random_crop_tv) + + expected_shape = (video.get_num_color_channels(), height, width) + assert frame_random_crop_tv.shape == expected_shape + + # Resetting manual seed to make sure the invocation of the + # TorchVision RandomCrop matches the two calls above. + torch.manual_seed(seed) + frame_full = decoder_full[frame_index] + frame_tv = v2.RandomCrop(size=(height, width))(frame_full) + assert_frames_equal(frame_random_crop, frame_tv) + + @pytest.mark.parametrize( + "height_scaling_factor, width_scaling_factor", + ((0.25, 0.1), (0.25, 0.25)), + ) + def test_random_crop_nhwc( + self, + height_scaling_factor, + width_scaling_factor, + ): + height = int(TEST_SRC_2_720P.get_height() * height_scaling_factor) + width = int(TEST_SRC_2_720P.get_width() * width_scaling_factor) + + decoder = VideoDecoder( + TEST_SRC_2_720P.path, + transforms=[torchcodec.transforms.RandomCrop(size=(height, width))], + dimension_order="NHWC", + ) + + num_frames = len(decoder) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame = decoder[frame_index] + assert frame.shape == (height, width, 3) + + @pytest.mark.parametrize( + "error_message, params", + ( + ("must not specify padding", dict(size=(100, 100), padding=255)), + ( + "must not specify pad_if_needed", + dict(size=(100, 100), pad_if_needed=True), + ), + ("fill must be 0", dict(size=(100, 100), fill=255)), + ( + "padding_mode must be constant", + dict(size=(100, 100), padding_mode="edge"), + ), + ), + ) + def test_crop_fails(self, error_message, params): + with pytest.raises( + ValueError, + match=error_message, + ): + VideoDecoder( + NASA_VIDEO.path, + transforms=[v2.RandomCrop(**params)], + ) + + @pytest.mark.parametrize("seed", [0, 314]) + def test_random_crop_reusable_objects(self, seed): + torch.manual_seed(seed) + random_crop = torchcodec.transforms.RandomCrop(size=(99, 99)) + + # Create a spec which causes us to calculate the random crop location. + first_spec = random_crop._make_transform_spec((888, 888)) + + # Create a spec again, which should calculate a different random crop + # location. Despite having the same image size, the specs should be + # different because the crop should be at a different location + second_spec = random_crop._make_transform_spec((888, 888)) + assert first_spec != second_spec + + # Create a spec again, but with a different image size. The specs should + # obviously be different, but the original image size should not be in + # the spec at all. 
+ third_spec = random_crop._make_transform_spec((777, 777)) + assert third_spec != first_spec + assert "888" not in third_spec + + @pytest.mark.parametrize( + "resize, random_crop", + [ + (torchcodec.transforms.Resize, torchcodec.transforms.RandomCrop), + (v2.Resize, v2.RandomCrop), + ], + ) + def test_transform_pipeline(self, resize, random_crop): + decoder = VideoDecoder( + TEST_SRC_2_720P.path, + transforms=[ + # resized to bigger than original + resize(size=(2160, 3840)), + # crop to smaller than the resize, but still bigger than original + random_crop(size=(1080, 1920)), + ], + ) + + num_frames = len(decoder) + for frame_index in [ + 0, + int(num_frames * 0.25), + int(num_frames * 0.5), + int(num_frames * 0.75), + num_frames - 1, + ]: + frame = decoder[frame_index] + assert frame.shape == (TEST_SRC_2_720P.get_num_color_channels(), 1080, 1920) + def test_transform_fails(self): with pytest.raises( ValueError, @@ -407,14 +570,14 @@ def test_crop_transform_fails(self): with pytest.raises( RuntimeError, - match="x position out of bounds", + match="x start position, 9999, out of bounds", ): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, transform_specs="crop, 100, 100, 9999, 100") with pytest.raises( RuntimeError, - match="y position out of bounds", + match=r"Crop output height \(999\) is greater than input height \(270\)", ): decoder = create_from_file(str(NASA_VIDEO.path)) add_video_stream(decoder, transform_specs="crop, 999, 100, 100, 100")
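Usage note: a minimal sketch of how the `RandomCrop` decoder transform added in this diff is expected to be used, mirroring `test_random_crop_torchvision`. The video path and crop size are placeholders (the video must be at least as large as the crop); the crop location is drawn from torch's RNG when the `VideoDecoder` is constructed, so seeding before construction makes the result reproducible.

```python
import torch
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import RandomCrop
from torchvision.transforms import v2

# Placeholder path; any decodable video that is at least 270x480 works here.
video_path = "my_video.mp4"

# The crop location is chosen during VideoDecoder construction (inside
# RandomCrop._make_transform_spec), so seed the RNG first for reproducibility.
torch.manual_seed(0)
decoder = VideoDecoder(video_path, transforms=[RandomCrop(size=(270, 480))])
frame = decoder[0]  # tensor of shape (3, 270, 480) in the default NCHW order

# TorchVision's v2.RandomCrop (without padding) is accepted too; it is
# converted to the TorchCodec transform above, and with the same seed it
# picks the same crop location.
torch.manual_seed(0)
decoder_tv = VideoDecoder(video_path, transforms=[v2.RandomCrop(size=(270, 480))])
assert torch.equal(decoder_tv[0], frame)
```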
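A second, illustrative-only sketch of the dimension tracking that `_make_transform_specs` now performs: each transform's `_get_output_dims()` becomes the `input_dims` of the next transform, so a `RandomCrop` placed after a `Resize` is validated against the resized dimensions rather than the file's metadata. It reuses the private helpers introduced in this diff, and the starting dimensions are assumed values standing in for the video's metadata.

```python
import torch
from torchcodec.transforms import RandomCrop, Resize

# Hypothetical metadata dimensions (height, width) for the source video.
curr_input_dims = (720, 1280)

transforms = [
    Resize(size=(2160, 3840)),      # upscale past the original resolution
    RandomCrop(size=(1080, 1920)),  # valid only because it sees the resized dims
]

torch.manual_seed(0)
specs = []
for transform in transforms:
    # The spec is built against the dims produced by the previous transform.
    specs.append(transform._make_transform_spec(curr_input_dims))
    output_dims = transform._get_output_dims()
    curr_input_dims = output_dims if output_dims is not None else curr_input_dims

print(";".join(specs))  # e.g. "resize, 2160, 3840;crop, 1080, 1920, <left>, <top>"
```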