Merged
4 changes: 3 additions & 1 deletion src/torchcodec/_core/SingleStreamDecoder.cpp
@@ -545,12 +545,14 @@ void SingleStreamDecoder::addVideoStream(

   metadataDims_ =
       FrameDims(streamMetadata.height.value(), streamMetadata.width.value());
+  FrameDims currInputDims = metadataDims_;
   for (auto& transform : transforms) {
     TORCH_CHECK(transform != nullptr, "Transforms should never be nullptr!");
     if (transform->getOutputFrameDims().has_value()) {
       resizedOutputDims_ = transform->getOutputFrameDims().value();
     }
-    transform->validate(streamMetadata);
+    transform->validate(currInputDims);
+    currInputDims = resizedOutputDims_.value_or(metadataDims_);
 
     // Note that we are claiming ownership of the transform objects passed in to
     // us.
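The change above threads `currInputDims` through the transform loop so that each transform is validated against the frame its predecessor produces, not against the raw stream metadata. A minimal Python sketch of that bookkeeping, with hypothetical `Resize`/`Crop` stand-ins (not TorchCodec's API):

```python
from typing import Optional, Tuple

Dims = Tuple[int, int]  # (height, width)


class Resize:
    def __init__(self, height: int, width: int) -> None:
        self.dims: Dims = (height, width)

    def validate(self, input_dims: Dims) -> None:
        pass  # a resize accepts any input size

    def output_dims(self) -> Optional[Dims]:
        return self.dims


class Crop:
    def __init__(self, height: int, width: int, x: int, y: int) -> None:
        self.dims: Dims = (height, width)
        self.x, self.y = x, y

    def validate(self, input_dims: Dims) -> None:
        in_h, in_w = input_dims
        h, w = self.dims
        assert self.x + w <= in_w and self.y + h <= in_h, "crop out of bounds"

    def output_dims(self) -> Optional[Dims]:
        return self.dims


metadata_dims: Dims = (1080, 1920)  # what the stream metadata reports
curr_input_dims = metadata_dims
for t in [Resize(480, 640), Crop(100, 200, x=400, y=300)]:
    t.validate(curr_input_dims)  # the crop is checked against 480x640
    out = t.output_dims()
    curr_input_dims = out if out is not None else metadata_dims
```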
44 changes: 37 additions & 7 deletions src/torchcodec/_core/Transform.cpp
@@ -53,15 +53,45 @@ std::optional<FrameDims> CropTransform::getOutputFrameDims() const {
   return outputDims_;
 }
 
-void CropTransform::validate(const StreamMetadata& streamMetadata) const {
-  TORCH_CHECK(x_ <= streamMetadata.width, "Crop x position out of bounds");
+void CropTransform::validate(const FrameDims& inputDims) const {
   TORCH_CHECK(
-      x_ + outputDims_.width <= streamMetadata.width,
-      "Crop x position out of bounds");
-  TORCH_CHECK(y_ <= streamMetadata.height, "Crop y position out of bounds");
+      outputDims_.height <= inputDims.height,
+      "Crop output height (",
+      outputDims_.height,
+      ") is greater than input height (",
+      inputDims.height,
+      ")");
   TORCH_CHECK(
-      y_ + outputDims_.height <= streamMetadata.height,
-      "Crop y position out of bounds");
+      outputDims_.width <= inputDims.width,
+      "Crop output width (",
+      outputDims_.width,
+      ") is greater than input width (",
+      inputDims.width,
+      ")");
+  TORCH_CHECK(
+      x_ <= inputDims.width,
+      "Crop x start position, ",
+      x_,
+      ", out of bounds of input width, ",
+      inputDims.width);
+  TORCH_CHECK(
+      x_ + outputDims_.width <= inputDims.width,
+      "Crop x end position, ",
+      x_ + outputDims_.width,
+      ", out of bounds of input width ",
+      inputDims.width);
+  TORCH_CHECK(
+      y_ <= inputDims.height,
+      "Crop y start position, ",
+      y_,
+      ", out of bounds of input height, ",
+      inputDims.height);
+  TORCH_CHECK(
+      y_ + outputDims_.height <= inputDims.height,
+      "Crop y end position, ",
+      y_ + outputDims_.height,
+      ", out of bounds of input height ",
+      inputDims.height);
 }
 
 } // namespace facebook::torchcodec
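The six TORCH_CHECKs above split into size checks (the crop cannot be larger than its input) and position checks (the crop rectangle must lie within the input). A small Python restatement of the same inequalities with made-up values (the helper is illustrative, not TorchCodec code):

```python
def crop_is_valid(x, y, out_h, out_w, in_h, in_w):
    # Same inequalities, in the same order, as the TORCH_CHECKs above.
    return (
        out_h <= in_h
        and out_w <= in_w
        and x <= in_w
        and x + out_w <= in_w
        and y <= in_h
        and y + out_h <= in_h
    )


assert crop_is_valid(x=0, y=0, out_h=100, out_w=200, in_h=480, in_w=640)
assert not crop_is_valid(x=500, y=0, out_h=100, out_w=200, in_h=480, in_w=640)  # 500 + 200 > 640
```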
5 changes: 2 additions & 3 deletions src/torchcodec/_core/Transform.h
@@ -36,8 +36,7 @@ class Transform {
   //
   // Note that the validation function does not return anything. We expect
   // invalid configurations to throw an exception.
-  virtual void validate(
-      [[maybe_unused]] const StreamMetadata& streamMetadata) const {}
+  virtual void validate([[maybe_unused]] const FrameDims& inputDims) const {}
 };
 
 class ResizeTransform : public Transform {
@@ -64,7 +63,7 @@ class CropTransform : public Transform {

   std::string getFilterGraphCpu() const override;
   std::optional<FrameDims> getOutputFrameDims() const override;
-  void validate(const StreamMetadata& streamMetadata) const override;
+  void validate(const FrameDims& inputDims) const override;
 
  private:
   FrameDims outputDims_;
17 changes: 15 additions & 2 deletions src/torchcodec/_core/custom_ops.cpp
@@ -239,6 +239,19 @@ int checkedToPositiveInt(const std::string& str) {
   return ret;
 }
 
+int checkedToNonNegativeInt(const std::string& str) {
+  int ret = 0;
+  try {
+    ret = std::stoi(str);
+  } catch (const std::invalid_argument&) {
+    TORCH_CHECK(false, "String cannot be converted to an int: " + str);
+  } catch (const std::out_of_range&) {
+    TORCH_CHECK(false, "String would become integer out of range: " + str);
+  }
+  TORCH_CHECK(ret >= 0, "String must be a non-negative integer: " + str);
+  return ret;
+}
+
 // Resize transform specs take the form:
 //
 //   "resize, <height>, <width>"
@@ -270,8 +283,8 @@ Transform* makeCropTransform(
"cropTransformSpec must have 5 elements including its name");
int height = checkedToPositiveInt(cropTransformSpec[1]);
int width = checkedToPositiveInt(cropTransformSpec[2]);
int x = checkedToPositiveInt(cropTransformSpec[3]);
int y = checkedToPositiveInt(cropTransformSpec[4]);
int x = checkedToNonNegativeInt(cropTransformSpec[3]);
int y = checkedToNonNegativeInt(cropTransformSpec[4]);
Contributor Author: The location (0, 0) is a valid image location. 🤦

   return new CropTransform(FrameDims(height, width), x, y);
 }
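Following the spec format in the comments above, a crop spec has five comma-separated elements: name, height, width, x, y. With x and y parsed as non-negative rather than strictly positive, a crop anchored at the origin becomes expressible. A hypothetical spec string (the semicolon joining matches _make_transform_specs in _video_decoder.py below):

```python
# Five elements: name, height, width, x, y. With x/y parsed as non-negative,
# a 100x200 crop anchored at the image origin (0, 0) is now accepted.
crop_spec = "crop, 100, 200, 0, 0"
resize_spec = "resize, 480, 640"

# The core API receives the specs joined by semicolons.
transform_specs = ";".join([resize_spec, crop_spec])
print(transform_specs)  # resize, 480, 640;crop, 100, 200, 0, 0
```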
109 changes: 68 additions & 41 deletions src/torchcodec/decoders/_video_decoder.py
@@ -8,7 +8,7 @@
 import json
 import numbers
 from pathlib import Path
-from typing import List, Literal, Optional, Sequence, Tuple, Union
+from typing import Literal, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import device as torch_device, nn, Tensor
@@ -19,7 +19,7 @@
     create_decoder,
     ERROR_REPORTING_INSTRUCTIONS,
 )
-from torchcodec.transforms import DecoderTransform, Resize
+from torchcodec.transforms import DecoderTransform, RandomCrop, Resize
 
 
 class VideoDecoder:
@@ -167,7 +167,10 @@ def __init__(
             device = str(device)
 
         device_variant = _get_cuda_backend()
-        transform_specs = _make_transform_specs(transforms)
+        transform_specs = _make_transform_specs(
+            transforms,
+            input_dims=(self.metadata.height, self.metadata.width),
+        )
 
         core.add_video_stream(
             self._decoder,
@@ -448,76 +451,100 @@ def _get_and_validate_stream_metadata(
     )
 
 
-def _convert_to_decoder_transforms(
-    transforms: Sequence[Union[DecoderTransform, nn.Module]],
-) -> List[DecoderTransform]:
-    """Convert a sequence of transforms that may contain TorchVision transform
-    objects into a list of only TorchCodec transform objects.
+def _make_transform_specs(
+    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
+    input_dims: Tuple[Optional[int], Optional[int]],
+) -> str:
+    """Given a sequence of transforms, turn those into the specification string
+    the core API expects.
 
     Args:
-        transforms: Squence of transform objects. The objects can be one of two
-            types:
+        transforms: Optional sequence of transform objects. The objects can be
+            one of two types:
             1. torchcodec.transforms.DecoderTransform
             2. torchvision.transforms.v2.Transform, but our type annotation
                only mentions its base, nn.Module. We don't want to take a
                hard dependency on TorchVision.
+        input_dims: Optional (height, width) pair. Note that only some
+            transforms need to know the dimensions. If the user provides
+            transforms that don't need to know the dimensions, and that metadata
+            is missing, everything should still work. That means we assert their
+            existence as late as possible.
 
     Returns:
-        List of DecoderTransform objects.
+        String of transforms in the format the core API expects: transform
+        specifications separated by semicolons.
     """
+    if transforms is None:
+        return ""
+
     try:
         from torchvision.transforms import v2
 
         tv_available = True
     except ImportError:
         tv_available = False
 
-    converted_transforms: list[DecoderTransform] = []
+    # The following loop accomplishes two tasks:
+    #
+    # 1. Converts the transform to a DecoderTransform, if necessary. We
+    #    accept TorchVision transform objects and they must be converted
+    #    to their matching DecoderTransform.
+    # 2. Calculates what the input dimensions are to each transform.
+    #
+    # The order in our transforms list is semantically meaningful, as we
+    # actually have a pipeline where the output of one transform is the input to
+    # the next. For example, if we have the transforms list [A, B, C, D], then
+    # we should understand that as:
+    #
+    #   A -> B -> C -> D
+    #
+    # Where the frame produced by A is the input to B, the frame produced by B
+    # is the input to C, etc. This particularly matters for frame dimensions.
+    # Transforms can both:
+    #
+    # 1. Produce frames with arbitrary dimensions.
+    # 2. Rely on their input frame's dimensions to calculate ahead-of-time
+    #    what their runtime behavior will be.
+    #
+    # The consequence of the above facts is that we need to statically track
+    # frame dimensions in the pipeline while we pre-process it. The input
+    # frame's dimensions to A, our first transform, is always what we know from
+    # our metadata. For each transform, we always calculate its output
+    # dimensions from its input dimensions. We store these with the converted
+    # transform, to be all used together when we generate the specs.
+    converted_transforms: list[
+        Tuple[
+            DecoderTransform,
+            # A (height, width) pair where the values may be missing.
+            Tuple[Optional[int], Optional[int]],
+        ]
+    ] = []
+    curr_input_dims = input_dims
     for transform in transforms:
         if not isinstance(transform, DecoderTransform):
             if not tv_available:
                 raise ValueError(
                     f"The supplied transform, {transform}, is not a TorchCodec "
-                    " DecoderTransform. TorchCodec also accept TorchVision "
+                    " DecoderTransform. TorchCodec also accepts TorchVision "
                     "v2 transforms, but TorchVision is not installed."
                 )
             elif isinstance(transform, v2.Resize):
-                converted_transforms.append(Resize._from_torchvision(transform))
+                transform = Resize._from_torchvision(transform)
+            elif isinstance(transform, v2.RandomCrop):
+                transform = RandomCrop._from_torchvision(transform)
             else:
                 raise ValueError(
                     f"Unsupported transform: {transform}. Transforms must be "
                     "either a TorchCodec DecoderTransform or a TorchVision "
                     "v2 transform."
                 )
-        else:
-            converted_transforms.append(transform)
 
-    return converted_transforms
+        converted_transforms.append((transform, curr_input_dims))
+        output_dims = transform._get_output_dims()
+        curr_input_dims = output_dims if output_dims is not None else curr_input_dims
 
-
-def _make_transform_specs(
-    transforms: Optional[Sequence[Union[DecoderTransform, nn.Module]]],
-) -> str:
-    """Given a sequence of transforms, turn those into the specification string
-    the core API expects.
-
-    Args:
-        transforms: Optional sequence of transform objects. The objects can be
-            one of two types:
-            1. torchcodec.transforms.DecoderTransform
-            2. torchvision.transforms.v2.Transform, but our type annotation
-               only mentions its base, nn.Module. We don't want to take a
-               hard dependency on TorchVision.
-
-    Returns:
-        String of transforms in the format the core API expects: transform
-        specifications separate by semicolons.
-    """
-    if transforms is None:
-        return ""
-
-    transforms = _convert_to_decoder_transforms(transforms)
-    return ";".join([t._make_transform_spec() for t in transforms])
+    return ";".join([t._make_transform_spec(dims) for t, dims in converted_transforms])


 def _read_custom_frame_mappings(
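Putting the pieces together, a usage sketch under stated assumptions: the constructor arguments for Resize and RandomCrop are guessed from their TorchVision counterparts, since the diff only shows that both classes exist and are convertible via _from_torchvision(). The point is that the RandomCrop is validated against the 480x640 frames the Resize produces, not against the stream's native resolution:

```python
from torchcodec.decoders import VideoDecoder
from torchcodec.transforms import RandomCrop, Resize

# Argument forms below are assumed, mirroring TorchVision's v2 transforms.
decoder = VideoDecoder(
    "video.mp4",  # hypothetical input file
    transforms=[Resize(size=(480, 640)), RandomCrop(size=(100, 200))],
)
frame = decoder[0]  # frames come out 100x200: resized first, then cropped
```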
2 changes: 1 addition & 1 deletion src/torchcodec/transforms/__init__.py
@@ -4,4 +4,4 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-from ._decoder_transforms import DecoderTransform, Resize  # noqa
+from ._decoder_transforms import DecoderTransform, RandomCrop, Resize  # noqa