91 changes: 85 additions & 6 deletions src/transformers/image_processing_utils_fast.py
@@ -79,8 +79,6 @@ def validate_fast_preprocess_arguments(
do_normalize: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[SizeDict] = None,
do_resize: Optional[bool] = None,
@@ -99,8 +97,6 @@ def validate_fast_preprocess_arguments(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisibility,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
@@ -181,6 +177,8 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
do_normalize: Optional[bool]
image_mean: Optional[Union[float, list[float]]]
image_std: Optional[Union[float, list[float]]]
do_pad: Optional[bool]
pad_size: Optional[dict[str, int]]
do_convert_rgb: Optional[bool]
return_tensors: Optional[Union[str, TensorType]]
data_format: Optional[ChannelDimension]
@@ -199,6 +197,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
crop_size = None
do_resize = None
do_center_crop = None
do_pad = None
pad_size = None
do_rescale = None
rescale_factor = 1 / 255
do_normalize = None
@@ -222,6 +222,9 @@ def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
)
crop_size = kwargs.pop("crop_size", self.crop_size)
self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
pad_size = kwargs.pop("pad_size", self.pad_size)
self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None

for key in self.valid_kwargs.__annotations__:
kwarg = kwargs.pop(key, None)
if kwarg is not None:
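
For context, `get_size_dict` is what normalizes the accepted `pad_size` inputs into a plain `{"height": ..., "width": ...}` dict before it is stored. A minimal sketch of that contract (a standalone re-implementation for illustration, not the library function):

# Illustrative stand-in for transformers' get_size_dict under default settings.
def normalize_pad_size(pad_size):
    # An int is read as a square size; a dict must already carry both keys.
    if isinstance(pad_size, int):
        return {"height": pad_size, "width": pad_size}
    if isinstance(pad_size, dict) and {"height", "width"} <= pad_size.keys():
        return dict(pad_size)
    raise ValueError(f"Unsupported pad_size: {pad_size!r}")

assert normalize_pad_size(224) == {"height": 224, "width": 224}
assert normalize_pad_size({"height": 640, "width": 480}) == {"height": 640, "width": 480}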
@@ -239,6 +242,74 @@ def is_fast(self) -> bool:
"""
return True

def pad(
self,
images: "torch.Tensor",
pad_size: Optional[SizeDict] = None,
fill_value: Optional[int] = 0,
padding_mode: Optional[str] = "constant",
return_mask: Optional[bool] = False,
disable_grouping: Optional[bool] = False,
**kwargs,
) -> "torch.Tensor":
"""
Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.

Args:
images (`torch.Tensor`):
Images to pad.
pad_size (`SizeDict`, *optional*):
Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
fill_value (`int`, *optional*, defaults to `0`):
The constant value used to fill the padded area.
padding_mode (`str`, *optional*, defaults to `"constant"`):
The padding mode to use. Can be any of the modes supported by
`torchvision.transforms.v2.functional.pad` (e.g. `"constant"`, `"edge"`, `"reflect"`, `"symmetric"`).
return_mask (`bool`, *optional*, defaults to `False`):
Whether to return a pixel mask to denote padded regions.
disable_grouping (`bool`, *optional*, defaults to `False`):
Whether to disable grouping of images by size.

Returns:
`torch.Tensor` or `tuple[torch.Tensor, torch.Tensor]`: The padded images, plus a pixel mask marking the valid (unpadded) region of each image when `return_mask=True`.
"""
if pad_size is not None:
if not (pad_size.height and pad_size.width):
raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
pad_size = (pad_size.height, pad_size.width)
else:
pad_size = get_max_height_width(images)

grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
processed_masks_grouped = {}
for shape, stacked_images in grouped_images.items():
image_size = stacked_images.shape[-2:]
padding_height = pad_size[0] - image_size[0]
padding_width = pad_size[1] - image_size[1]
if padding_height < 0 or padding_width < 0:
raise ValueError(
f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
f"image size. Got pad_size={pad_size}, image_size={image_size}."
)
if image_size != pad_size:
padding = (0, 0, padding_width, padding_height)
stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
processed_images_grouped[shape] = stacked_images

if return_mask:
# keep a single channel for the pixel mask
stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
stacked_masks[..., : image_size[0], : image_size[1]] = 1
processed_masks_grouped[shape] = stacked_masks

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
if return_mask:
processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
return processed_images, processed_masks

return processed_images
Comment on lines +245 to +311
Member: Nice!

Member: I was also wondering if we should include the whole group-and-reorder process in other functions like `resize` or `rescale_and_normalize`. But that means we couldn't run several processing functions under one grouping loop. And a new grouping is only needed if the previous function changed the shape of the images (it is not needed after a rescale-and-normalize, for example), and groupings do cost a bit of processing time, so it might be best to leave this to the `_preprocess` function. What do you think?

Member (Author): The grouping for padding was added only because `pad` is supposed to be callable outside of the loop if we want to pad all images to the same size. I see what you mean, and it's partially related to the thread above. We need to handle two types of padding. WDYT of adding an arg like `group_images=False`?

Member: Oh yes, that makes sense then, if we're padding to the maximum size. I don't think we need to add anything; we can just override the function when we need a different type of padding.

Member (Author): Yep, that is what most processors do. Though it's not really "overriding": they use a function called `pad_to_square`, which we can actually standardize and add to the base class in the next iteration.

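To make the thread concrete, here is a minimal sketch of driving the new `pad` method; the dummy tensors are illustrative, and it assumes `BaseImageProcessorFast` is directly instantiable with no arguments:

import torch
from transformers.image_processing_utils_fast import BaseImageProcessorFast

processor = BaseImageProcessorFast()  # assumption: no required init args
images = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]  # two CHW dummies

# With pad_size=None the batch is padded to its largest height/width
# (512 x 640 here); return_mask=True also returns the pixel masks.
padded, masks = processor.pad(images, return_mask=True)
assert all(tuple(img.shape[-2:]) == (512, 640) for img in padded)
assert masks[0][:480, :640].all() and not masks[0][480:, :].any()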

def resize(
self,
image: "torch.Tensor",
@@ -577,6 +648,7 @@ def _further_process_kwargs(
self,
size: Optional[SizeDict] = None,
crop_size: Optional[SizeDict] = None,
pad_size: Optional[SizeDict] = None,
default_to_square: Optional[bool] = None,
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
@@ -593,6 +665,8 @@
size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
if crop_size is not None:
crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
if pad_size is not None:
pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
if isinstance(image_mean, list):
image_mean = tuple(image_mean)
if isinstance(image_std, list):
@@ -602,6 +676,7 @@

kwargs["size"] = size
kwargs["crop_size"] = crop_size
kwargs["pad_size"] = pad_size
kwargs["image_mean"] = image_mean
kwargs["image_std"] = image_std
kwargs["data_format"] = data_format
@@ -714,6 +789,8 @@ def _preprocess(
do_normalize: bool,
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: Optional[bool],
pad_size: Optional[SizeDict],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
@@ -739,10 +816,12 @@
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
processed_images_grouped[shape] = stacked_images

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

if do_pad:
processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)

processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)
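
The padding pass deliberately sits outside the per-shape grouping loop: its target size depends on the whole batch, not on any single same-shape group. A self-contained sketch of that reasoning (illustrative helper, not the method above):

import torch

def pad_to_batch_max(images: list[torch.Tensor]) -> list[torch.Tensor]:
    # The target is a batch-level property: the max height/width over all images.
    max_h = max(img.shape[-2] for img in images)
    max_w = max(img.shape[-1] for img in images)
    padded = []
    for img in images:
        pad_h, pad_w = max_h - img.shape[-2], max_w - img.shape[-1]
        # Pad bottom and right only, matching the base class behavior.
        padded.append(torch.nn.functional.pad(img, (0, pad_w, 0, pad_h), value=0))
    return padded

batch = [torch.rand(3, 480, 640), torch.rand(3, 512, 512)]
assert {tuple(img.shape[-2:]) for img in pad_to_batch_max(batch)} == {(512, 640)}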

def to_dict(self):
13 changes: 9 additions & 4 deletions src/transformers/image_utils.py
@@ -525,7 +525,7 @@ def validate_preprocess_arguments(
image_mean: Optional[Union[float, list[float]]] = None,
image_std: Optional[Union[float, list[float]]] = None,
do_pad: Optional[bool] = None,
size_divisibility: Optional[int] = None,
pad_size: Optional[Union[dict[str, int], int]] = None,
do_center_crop: Optional[bool] = None,
crop_size: Optional[dict[str, int]] = None,
do_resize: Optional[bool] = None,
@@ -544,10 +544,15 @@
if do_rescale and rescale_factor is None:
raise ValueError("`rescale_factor` must be specified if `do_rescale` is `True`.")

if do_pad and size_divisibility is None:
# Here, size_divisor might be passed as the value of size
if do_pad and pad_size is None:
# Processors pad images using different args depending on the model, so this check is
# mostly a formality, but we keep it for BC for now. TODO: remove in v5
# Padding is usually requested with:
# - "pad_size"/"size" when padding to specific values
# - "size_divisor" when padding to any value divisible by X
# - None when padding to the maximum image size in the batch
raise ValueError(
"Depending on the model, `size_divisibility`, `size_divisor`, `pad_size` or `size` must be specified if `do_pad` is `True`."
"Depending on the model, `size_divisor` or `pad_size` or `size` must be specified if `do_pad` is `True`."
)

if do_normalize and (image_mean is None or image_std is None):
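As a concrete instance of the `size_divisor` convention above: the padded target rounds each side up to the next multiple of the divisor. A hypothetical helper (not part of this diff):

import math

def divisible_pad_target(height: int, width: int, size_divisor: int) -> tuple[int, int]:
    # Round each side up to the next multiple of size_divisor.
    return (
        math.ceil(height / size_divisor) * size_divisor,
        math.ceil(width / size_divisor) * size_divisor,
    )

assert divisible_pad_target(480, 630, 32) == (480, 640)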
@@ -480,8 +480,6 @@ def preprocess(
do_normalize=do_normalize,
image_mean=image_mean,
image_std=image_std,
do_pad=do_pad,
size_divisibility=size_divisor,
do_center_crop=do_center_crop,
crop_size=crop_size,
do_resize=do_resize,
src/transformers/models/bridgetower/image_processing_bridgetower_fast.py
@@ -25,7 +25,6 @@
SizeDict,
TensorType,
Unpack,
get_max_height_width,
group_images_by_shape,
reorder_images,
)
@@ -99,13 +98,9 @@ class BridgeTowerFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
size_divisor (`int`, *optional*, defaults to 32):
The size by which to make sure both the height and width can be divided. Only has an effect if `do_resize`
is set to `True`. Can be overridden by the `size_divisor` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Whether to pad the image to the `(max_height, max_width)` of the images in the batch. Can be overridden by
the `do_pad` parameter in the `preprocess` method.
"""

size_divisor: Optional[int]
do_pad: Optional[bool]


@auto_docstring
@@ -224,59 +219,6 @@ def _pad_image(
)
return padded_image

def pad(
self,
images: list["torch.Tensor"],
constant_values: Union[float, Iterable[float]] = 0,
return_pixel_mask: bool = True,
disable_grouping: Optional[bool] = False,
) -> tuple:
"""
Pads a batch of images to the bottom and right of the image with zeros to the size of largest height and width
in the batch and optionally returns their corresponding pixel mask.

Args:
image (`torch.Tensor`):
Image to pad.
constant_values (`float` or `Iterable[float]`, *optional*):
The value to use for the padding if `mode` is `"constant"`.
return_pixel_mask (`bool`, *optional*, defaults to `True`):
Whether to return a pixel mask.
disable_grouping (`bool`, *optional*, defaults to `False`):
Whether to disable grouping of images by size.
return_tensors (`str` or `TensorType`, *optional*):
The type of tensors to return. Can be one of:
- Unset: Return a list of `np.ndarray`.
- `TensorType.TENSORFLOW` or `'tf'`: Return a batch of type `tf.Tensor`.
- `TensorType.PYTORCH` or `'pt'`: Return a batch of type `torch.Tensor`.
- `TensorType.NUMPY` or `'np'`: Return a batch of type `np.ndarray`.
- `TensorType.JAX` or `'jax'`: Return a batch of type `jax.numpy.ndarray`.
"""
pad_size = get_max_height_width(images)

grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
processed_images_grouped = {}
processed_masks_grouped = {}
for shape, stacked_images in grouped_images.items():
stacked_images = self._pad_image(
stacked_images,
pad_size,
constant_values=constant_values,
)
processed_images_grouped[shape] = stacked_images

if return_pixel_mask:
stacked_masks = make_pixel_mask(image=stacked_images, output_size=pad_size)
processed_masks_grouped[shape] = stacked_masks

processed_images = reorder_images(processed_images_grouped, grouped_images_index)

processed_masks = None
if return_pixel_mask:
processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)

return processed_images, processed_masks

def _preprocess(
self,
images: list["torch.Tensor"],
Expand Down Expand Up @@ -325,7 +267,7 @@ def _preprocess(
data = {}
if do_pad:
processed_images, processed_masks = self.pad(
processed_images, return_pixel_mask=True, disable_grouping=disable_grouping
processed_images, return_mask=True, disable_grouping=disable_grouping
)
processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
data["pixel_mask"] = processed_masks
src/transformers/models/bridgetower/processing_bridgetower.py
@@ -16,10 +16,17 @@
Processor class for BridgeTower.
"""

from ...processing_utils import ProcessingKwargs, ProcessorMixin
from typing import Optional

from ...processing_utils import ImagesKwargs, ProcessingKwargs, ProcessorMixin


class BridgeTowerImagesKwargs(ImagesKwargs):
size_divisor: Optional[int]


class BridgeTowerProcessorKwargs(ProcessingKwargs, total=False):
images_kwargs: BridgeTowerImagesKwargs
_defaults = {
"text_kwargs": {
"add_special_tokens": True,
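With `size_divisor` now declared on `BridgeTowerImagesKwargs`, a single processor call can route it to the image processor. A hedged usage sketch (the checkpoint id and dummy image are illustrative):

from PIL import Image
from transformers import BridgeTowerProcessor

processor = BridgeTowerProcessor.from_pretrained("BridgeTower/bridgetower-base")
image = Image.new("RGB", (300, 200))  # dummy image for demonstration

inputs = processor(
    images=image,
    text="a photo of a cat",
    size_divisor=32,  # typed on BridgeTowerImagesKwargs, routed to the image processor
    return_tensors="pt",
)
print(inputs["pixel_values"].shape)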
Expand Up @@ -227,6 +227,7 @@ def _preprocess(
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
if crop_to_patches:
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
src/transformers/models/conditional_detr/image_processing_conditional_detr_fast.py
Expand Up @@ -74,23 +74,12 @@ class ConditionalDetrFastImageProcessorKwargs(DefaultFastImageProcessorKwargs):
Controls whether to convert the annotations to the format expected by the CONDITIONAL_DETR model. Converts the
bounding boxes to the format `(center_x, center_y, width, height)` and in the range `[0, 1]`.
Can be overridden by the `do_convert_annotations` parameter in the `preprocess` method.
do_pad (`bool`, *optional*, defaults to `True`):
Controls whether to pad the image. Can be overridden by the `do_pad` parameter in the `preprocess`
method. If `True`, padding will be applied to the bottom and right of the image with zeros.
If `pad_size` is provided, the image will be padded to the specified dimensions.
Otherwise, the image will be padded to the maximum height and width of the batch.
pad_size (`dict[str, int]`, *optional*):
The size `{"height": int, "width" int}` to pad the images to. Must be larger than any image size
provided for preprocessing. If `pad_size` is not provided, images will be padded to the largest
height and width in the batch.
return_segmentation_masks (`bool`, *optional*, defaults to `False`):
Whether to return segmentation masks.
"""

format: Optional[Union[str, AnnotationFormat]]
do_convert_annotations: Optional[bool]
do_pad: Optional[bool]
pad_size: Optional[dict[str, int]]
return_segmentation_masks: Optional[bool]


@@ -629,7 +618,7 @@ def _preprocess(
image_mean: Optional[Union[float, list[float]]],
image_std: Optional[Union[float, list[float]]],
do_pad: bool,
pad_size: Optional[dict[str, int]],
pad_size: Optional[SizeDict],
format: Optional[Union[str, AnnotationFormat]],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
@@ -698,7 +687,7 @@ def _preprocess(
if do_pad:
# depends on all resized image shapes so we need another loop
if pad_size is not None:
padded_size = (pad_size["height"], pad_size["width"])
padded_size = (pad_size.height, pad_size.width)
else:
padded_size = get_max_height_width(images)

@@ -155,6 +155,7 @@ def _preprocess(
image_std: Optional[Union[float, list[float]]],
disable_grouping: Optional[bool],
return_tensors: Optional[Union[str, TensorType]],
**kwargs,
) -> BatchFeature:
# Group images by size for batched resizing
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)