Consistent naming for images kwargs #40834
```diff
@@ -79,8 +79,6 @@ def validate_fast_preprocess_arguments(
     do_normalize: Optional[bool] = None,
     image_mean: Optional[Union[float, list[float]]] = None,
     image_std: Optional[Union[float, list[float]]] = None,
-    do_pad: Optional[bool] = None,
-    size_divisibility: Optional[int] = None,
     do_center_crop: Optional[bool] = None,
     crop_size: Optional[SizeDict] = None,
     do_resize: Optional[bool] = None,
```
```diff
@@ -99,8 +97,6 @@ def validate_fast_preprocess_arguments(
         do_normalize=do_normalize,
         image_mean=image_mean,
         image_std=image_std,
-        do_pad=do_pad,
-        size_divisibility=size_divisibility,
         do_center_crop=do_center_crop,
         crop_size=crop_size,
         do_resize=do_resize,
```
```diff
@@ -181,6 +177,8 @@ class DefaultFastImageProcessorKwargs(TypedDict, total=False):
     do_normalize: Optional[bool]
     image_mean: Optional[Union[float, list[float]]]
     image_std: Optional[Union[float, list[float]]]
+    do_pad: Optional[bool]
+    pad_size: Optional[dict[str, int]]
     do_convert_rgb: Optional[bool]
     return_tensors: Optional[Union[str, TensorType]]
     data_format: Optional[ChannelDimension]
```
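Because `do_pad` and `pad_size` now live in `DefaultFastImageProcessorKwargs`, they can be set at construction time or per call like any other image kwarg. A minimal, hypothetical usage sketch (the checkpoint name is a placeholder, not part of this PR):

```python
from PIL import Image
from transformers import AutoImageProcessor

# Placeholder checkpoint; any model with a fast image processor would do.
processor = AutoImageProcessor.from_pretrained("some/checkpoint", use_fast=True)
image = Image.new("RGB", (480, 360))

# pad_size follows the same {"height": ..., "width": ...} convention as size/crop_size.
batch = processor(image, do_pad=True, pad_size={"height": 512, "width": 512}, return_tensors="pt")
```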
```diff
@@ -199,6 +197,8 @@ class BaseImageProcessorFast(BaseImageProcessor):
     crop_size = None
     do_resize = None
     do_center_crop = None
+    do_pad = None
+    pad_size = None
     do_rescale = None
     rescale_factor = 1 / 255
     do_normalize = None
```
```diff
@@ -222,6 +222,9 @@ def __init__(self, **kwargs: Unpack[DefaultFastImageProcessorKwargs]):
         )
         crop_size = kwargs.pop("crop_size", self.crop_size)
         self.crop_size = get_size_dict(crop_size, param_name="crop_size") if crop_size is not None else None
+        pad_size = kwargs.pop("pad_size", self.pad_size)
+        self.pad_size = get_size_dict(size=pad_size, param_name="pad_size") if pad_size is not None else None
+
         for key in self.valid_kwargs.__annotations__:
             kwarg = kwargs.pop(key, None)
             if kwarg is not None:
```
```diff
@@ -239,6 +242,74 @@ def is_fast(self) -> bool:
         """
         return True

+    def pad(
+        self,
+        images: "torch.Tensor",
+        pad_size: SizeDict = None,
+        fill_value: Optional[int] = 0,
+        padding_mode: Optional[str] = "constant",
+        return_mask: Optional[bool] = False,
+        disable_grouping: Optional[bool] = False,
+        **kwargs,
+    ) -> "torch.Tensor":
+        """
+        Pads images to `(pad_size["height"], pad_size["width"])` or to the largest size in the batch.
+
+        Args:
+            images (`torch.Tensor`):
+                Images to pad.
+            pad_size (`SizeDict`, *optional*):
+                Dictionary in the format `{"height": int, "width": int}` specifying the size of the output image.
+            fill_value (`int`, *optional*, defaults to `0`):
+                The constant value used to fill the padded area.
+            padding_mode (`str`, *optional*, defaults to `"constant"`):
+                The padding mode to use. Can be any of the modes supported by
+                `torch.nn.functional.pad` (e.g. constant, reflection, replication).
+            return_mask (`bool`, *optional*, defaults to `False`):
+                Whether to return a pixel mask denoting padded regions.
+            disable_grouping (`bool`, *optional*, defaults to `False`):
+                Whether to disable grouping of images by size.
+
+        Returns:
+            `torch.Tensor`: The padded images, plus the pixel masks if `return_mask=True`.
+        """
+        if pad_size is not None:
+            if not (pad_size.height and pad_size.width):
+                raise ValueError(f"Pad size must contain 'height' and 'width' keys only. Got pad_size={pad_size}.")
+            pad_size = (pad_size.height, pad_size.width)
+        else:
+            pad_size = get_max_height_width(images)
+
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        processed_masks_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            image_size = stacked_images.shape[-2:]
+            padding_height = pad_size[0] - image_size[0]
+            padding_width = pad_size[1] - image_size[1]
+            if padding_height < 0 or padding_width < 0:
+                raise ValueError(
+                    f"Padding dimensions are negative. Please make sure that the `pad_size` is larger than the "
+                    f"image size. Got pad_size={pad_size}, image_size={image_size}."
+                )
+            if image_size != pad_size:
+                padding = (0, 0, padding_width, padding_height)
+                stacked_images = F.pad(stacked_images, padding, fill=fill_value, padding_mode=padding_mode)
+            processed_images_grouped[shape] = stacked_images
+
+            if return_mask:
+                # keep only one channel in the pixel mask
+                stacked_masks = torch.zeros_like(stacked_images, dtype=torch.int64)[..., 0, :, :]
+                stacked_masks[..., : image_size[0], : image_size[1]] = 1
+                processed_masks_grouped[shape] = stacked_masks
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        if return_mask:
+            processed_masks = reorder_images(processed_masks_grouped, grouped_images_index)
+            return processed_images, processed_masks
+
+        return processed_images
```
**Comment on lines +245 to +311**

> Nice!

> I was also wondering if we should include the whole group-and-reorder process in other functions like resize or rescale-and-normalize. But that means we can't have several processing functions under one grouping loop. And a new grouping is only needed if the previous function changed the shape of the images (not needed after a rescale-and-normalize, for example), and they do cost a bit of processing time, so it might be best to leave this to the …

> The grouping for padding was added only because it's supposed to be called outside of the loop if we want to pad all images to the same size. I see what you mean, and it's partially related to the above thread. We need to handle two types of padding. WDYT of adding an arg like …

> Oh yes, that makes sense then, if we're padding to the maximum size. I don't think we need to add anything; we can just override the function when we have a different type of padding needed.

> yep, that is what most processors do. Though it's not "overriding", they use a fn called …
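For context on the pattern the reviewers are discussing: images are grouped by shape so each group can be processed as one batched tensor, then restored to their original order. A simplified, illustrative reimplementation of that round trip (not the library's exact helpers):

```python
import torch

def group_images_by_shape(images: list[torch.Tensor], disable_grouping: bool = False):
    # Stack images sharing the same (H, W) so each group is one batched op.
    grouped, index = {}, []
    for i, img in enumerate(images):
        key = (i,) if disable_grouping else tuple(img.shape[-2:])
        grouped.setdefault(key, []).append(img)
        index.append((key, len(grouped[key]) - 1))
    return {k: torch.stack(v) for k, v in grouped.items()}, index

def reorder_images(grouped: dict, index: list) -> list[torch.Tensor]:
    # Undo the grouping: emit images in their original order.
    return [grouped[key][pos] for key, pos in index]
```

Each regrouping costs a pass over the batch, which is why the thread favors one grouping loop per shape-changing stage rather than grouping inside every op.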
```diff
+
     def resize(
         self,
         image: "torch.Tensor",
```
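A hedged sketch of calling the new method directly; `processor` stands for any `BaseImageProcessorFast` subclass instance, and the `SizeDict` import path is an assumption:

```python
import torch
from transformers.image_processing_utils_fast import SizeDict  # assumed import path

images = [torch.rand(3, 300, 400), torch.rand(3, 224, 224)]

# Pad everything to a fixed 512x512 canvas; the mask marks real (unpadded) pixels.
padded, masks = processor.pad(images, pad_size=SizeDict(height=512, width=512), return_mask=True)

# Without pad_size, images are padded to the largest height/width in the batch.
padded = processor.pad(images)
```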
```diff
@@ -577,6 +648,7 @@ def _further_process_kwargs(
         self,
         size: Optional[SizeDict] = None,
         crop_size: Optional[SizeDict] = None,
+        pad_size: Optional[SizeDict] = None,
         default_to_square: Optional[bool] = None,
         image_mean: Optional[Union[float, list[float]]] = None,
         image_std: Optional[Union[float, list[float]]] = None,
```
```diff
@@ -593,6 +665,8 @@ def _further_process_kwargs(
             size = SizeDict(**get_size_dict(size=size, default_to_square=default_to_square))
         if crop_size is not None:
             crop_size = SizeDict(**get_size_dict(crop_size, param_name="crop_size"))
+        if pad_size is not None:
+            pad_size = SizeDict(**get_size_dict(size=pad_size, param_name="pad_size"))
         if isinstance(image_mean, list):
             image_mean = tuple(image_mean)
         if isinstance(image_std, list):
```
```diff
@@ -602,6 +676,7 @@ def _further_process_kwargs(

         kwargs["size"] = size
         kwargs["crop_size"] = crop_size
+        kwargs["pad_size"] = pad_size
         kwargs["image_mean"] = image_mean
         kwargs["image_std"] = image_std
         kwargs["data_format"] = data_format
```
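`SizeDict` wraps the plain dict so downstream code can use attribute access (`pad_size.height`) and read missing fields as `None` instead of raising. For instance (import path and field set assumed):

```python
from transformers.image_processing_utils import get_size_dict
from transformers.image_processing_utils_fast import SizeDict  # assumed import path

pad_size = SizeDict(**get_size_dict(size={"height": 384, "width": 384}, param_name="pad_size"))
print(pad_size.height, pad_size.width)  # 384 384
print(pad_size.longest_edge)            # None -- absent fields do not raise
```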
```diff
@@ -714,6 +789,8 @@ def _preprocess(
         do_normalize: bool,
         image_mean: Optional[Union[float, list[float]]],
         image_std: Optional[Union[float, list[float]]],
+        do_pad: Optional[bool],
+        pad_size: Optional[SizeDict],
         disable_grouping: Optional[bool],
         return_tensors: Optional[Union[str, TensorType]],
         **kwargs,
```
```diff
@@ -739,10 +816,12 @@ def _preprocess(
                 stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
             )
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
+        if do_pad:
+            processed_images = self.pad(processed_images, pad_size=pad_size, disable_grouping=disable_grouping)
+        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

     def to_dict(self):
```
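Padding runs before `torch.stack` because tensors of different heights and widths cannot be stacked into one batch; padding to a common size is what makes the final stack valid. A small sketch of that ordering (`processor` is assumed as above):

```python
import torch

imgs = [torch.rand(3, 300, 400), torch.rand(3, 224, 224)]
# torch.stack(imgs)                 # would raise: stack expects tensors of equal size
padded = processor.pad(imgs)        # pads both to (3, 300, 400), the batch maximum
batch = torch.stack(padded, dim=0)  # shape: (2, 3, 300, 400)
```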