From f61cec1e8864e9da23b7646556f4fd43e383d0ae Mon Sep 17 00:00:00 2001
From: yonigozlan
Date: Fri, 12 Sep 2025 18:46:01 +0000
Subject: [PATCH] make center_crop fast equivalent to slow

---
 .../image_processing_utils_fast.py            | 22 +++++++++++++++++--
 .../image_processing_perceiver_fast.py        |  2 +-
 .../test_image_processing_chinese_clip.py     | 10 +--------
 tests/test_image_processing_common.py         |  5 -----
 4 files changed, 22 insertions(+), 17 deletions(-)

diff --git a/src/transformers/image_processing_utils_fast.py b/src/transformers/image_processing_utils_fast.py
index 38a4a3e32718..45624e013881 100644
--- a/src/transformers/image_processing_utils_fast.py
+++ b/src/transformers/image_processing_utils_fast.py
@@ -405,10 +405,11 @@ def rescale_and_normalize(
     def center_crop(
         self,
         image: "torch.Tensor",
-        size: dict[str, int],
+        size: SizeDict,
         **kwargs,
     ) -> "torch.Tensor":
         """
+        Note: override torchvision's center_crop to have the same behavior as the slow processor.
         Center crop an image to `(size["height"], size["width"])`. If the input size is smaller than `crop_size` along
         any edge, the image is padded with 0's and then center cropped.

@@ -423,7 +424,24 @@
         """
         if size.height is None or size.width is None:
             raise ValueError(f"The size dictionary must have keys 'height' and 'width'. Got {size.keys()}")
-        return F.center_crop(image, (size["height"], size["width"]))
+        image_height, image_width = image.shape[-2:]
+        crop_height, crop_width = size.height, size.width
+
+        if crop_width > image_width or crop_height > image_height:
+            padding_ltrb = [
+                (crop_width - image_width) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height) // 2 if crop_height > image_height else 0,
+                (crop_width - image_width + 1) // 2 if crop_width > image_width else 0,
+                (crop_height - image_height + 1) // 2 if crop_height > image_height else 0,
+            ]
+            image = F.pad(image, padding_ltrb, fill=0)  # PIL uses fill value 0
+            image_height, image_width = image.shape[-2:]
+            if crop_width == image_width and crop_height == image_height:
+                return image
+
+        crop_top = int((image_height - crop_height) / 2.0)
+        crop_left = int((image_width - crop_width) / 2.0)
+        return F.crop(image, crop_top, crop_left, crop_height, crop_width)

     def convert_to_rgb(
         self,
diff --git a/src/transformers/models/perceiver/image_processing_perceiver_fast.py b/src/transformers/models/perceiver/image_processing_perceiver_fast.py
index 640083ba82dd..ecd7f938f569 100644
--- a/src/transformers/models/perceiver/image_processing_perceiver_fast.py
+++ b/src/transformers/models/perceiver/image_processing_perceiver_fast.py
@@ -81,7 +81,7 @@ def center_crop(
         min_dim = min(height, width)
         cropped_height = int((size.height / crop_size.height) * min_dim)
         cropped_width = int((size.width / crop_size.width) * min_dim)
-        return F.center_crop(image, (cropped_height, cropped_width))
+        return super().center_crop(image, SizeDict(height=cropped_height, width=cropped_width))

     def _preprocess(
         self,
diff --git a/tests/models/chinese_clip/test_image_processing_chinese_clip.py b/tests/models/chinese_clip/test_image_processing_chinese_clip.py
index 7acae860b08a..18670bcb4d64 100644
--- a/tests/models/chinese_clip/test_image_processing_chinese_clip.py
+++ b/tests/models/chinese_clip/test_image_processing_chinese_clip.py
@@ -141,7 +141,7 @@ class ChineseCLIPImageProcessingTestFourChannels(ImageProcessingTestMixin, unitt

     def setUp(self):
         super().setUp()
-        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=4, do_center_crop=True)
+        self.image_processor_tester = ChineseCLIPImageProcessingTester(self, num_channels=3, do_center_crop=True)
        self.expected_encoded_image_num_channels = 3

     @property
@@ -160,14 +160,6 @@ def test_image_processor_properties(self):
         self.assertTrue(hasattr(image_processing, "image_std"))
         self.assertTrue(hasattr(image_processing, "do_convert_rgb"))

-    @unittest.skip(reason="ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
-    def test_call_numpy(self):
-        return super().test_call_numpy()
-
-    @unittest.skip(reason="ChineseCLIPImageProcessor does not support 4 channels yet")  # FIXME Amy
-    def test_call_pytorch(self):
-        return super().test_call_torch()
-
     @unittest.skip(
         reason="ChineseCLIPImageProcessor doesn't treat 4 channel PIL and numpy consistently yet"
     )  # FIXME Amy
diff --git a/tests/test_image_processing_common.py b/tests/test_image_processing_common.py
index 635d6a35dc85..ac7ecec0fb70 100644
--- a/tests/test_image_processing_common.py
+++ b/tests/test_image_processing_common.py
@@ -200,11 +200,6 @@ def test_slow_fast_equivalence_batched(self):
         if self.image_processing_class is None or self.fast_image_processing_class is None:
             self.skipTest(reason="Skipping slow/fast equivalence test as one of the image processors is not defined")

-        if hasattr(self.image_processor_tester, "do_center_crop") and self.image_processor_tester.do_center_crop:
-            self.skipTest(
-                reason="Skipping as do_center_crop is True and center_crop functions are not equivalent for fast and slow processors"
-            )
-
         dummy_images = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
         image_processor_slow = self.image_processing_class(**self.image_processor_dict)
         image_processor_fast = self.fast_image_processing_class(**self.image_processor_dict)
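
For context on the core change: torchvision's `F.center_crop` computes its crop anchor with `int(round((image_height - crop_height) / 2.0))`, while the slow (PIL/NumPy) processors floor the same quantity, so the two can select windows one pixel apart whenever the height or width difference is odd. The sketch below is not part of the patch; it assumes torchvision's v2 functional namespace, and the image and crop sizes are made up for illustration:

    import torch
    from torchvision.transforms.v2 import functional as F

    # A 5x5 image with distinct pixel values (0..24), shape (1, 1, 5, 5).
    image = torch.arange(25, dtype=torch.float32).reshape(1, 1, 5, 5)
    crop_height, crop_width = 2, 2

    # torchvision rounds the crop anchor: int(round((5 - 2) / 2.0)) == 2.
    tv_crop = F.center_crop(image, [crop_height, crop_width])

    # The new override floors it, matching the slow processor:
    # int((5 - 2) / 2.0) == 1.
    crop_top = int((image.shape[-2] - crop_height) / 2.0)
    crop_left = int((image.shape[-1] - crop_width) / 2.0)
    slow_style = F.crop(image, crop_top, crop_left, crop_height, crop_width)

    print(tv_crop.flatten().tolist())     # [12.0, 13.0, 17.0, 18.0]
    print(slow_style.flatten().tolist())  # [6.0, 7.0, 11.0, 12.0]

The padding branch of the override mirrors torchvision's own pad-then-crop fallback (same left/top/right/bottom split, zero fill), so the anchor rounding appears to be the only behavioral difference; with it gone, the blanket do_center_crop skip in test_slow_fast_equivalence_batched is no longer needed.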
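
As a quick end-to-end check of the kind the un-skipped test now exercises, one can compare a slow and a fast processor directly. This is a sketch, not code from the repository: the checkpoint is only an example (any processor with do_center_crop=True would do, assuming a fast variant is registered for it), and the tolerances follow the loose ones the common test uses to absorb backend interpolation differences:

    import numpy as np
    import torch
    from PIL import Image
    from transformers import AutoImageProcessor

    # Same checkpoint loaded as a slow (PIL/NumPy) and a fast (torchvision) processor.
    slow = AutoImageProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16", use_fast=False)
    fast = AutoImageProcessor.from_pretrained("OFA-Sys/chinese-clip-vit-base-patch16", use_fast=True)

    image = Image.fromarray(np.random.randint(0, 256, (480, 640, 3), dtype=np.uint8))
    pixels_slow = slow(image, return_tensors="pt").pixel_values
    pixels_fast = fast(image, return_tensors="pt").pixel_values

    # Resize interpolation may still differ slightly between backends,
    # but with the matching center_crop the outputs should now agree closely.
    torch.testing.assert_close(pixels_fast, pixels_slow, atol=1e-1, rtol=1e-3)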