diff --git a/src/transformers/models/glm4v/image_processing_glm4v_fast.py b/src/transformers/models/glm4v/image_processing_glm4v_fast.py
index d93bc5370219..061654519d21 100644
--- a/src/transformers/models/glm4v/image_processing_glm4v_fast.py
+++ b/src/transformers/models/glm4v/image_processing_glm4v_fast.py
@@ -22,6 +22,8 @@
 from ...image_processing_utils_fast import (
     BaseImageProcessorFast,
     DefaultFastImageProcessorKwargs,
+    group_images_by_shape,
+    reorder_images,
 )
 from ...image_utils import (
     OPENAI_CLIP_MEAN,
@@ -128,46 +130,54 @@ def _preprocess(
         Preprocess an image or batch of images.
         Copy of the `preprocess` method from `CLIPImageProcessor`.
         """
-        processed_images = []
-        processed_grids = []
-
-        all_target_sizes = []
-        for image in images:
-            height, width = image.shape[-2:]
-            resized_height, resized_width = smart_resize(
-                num_frames=temporal_patch_size,
-                height=height,
-                width=width,
-                temporal_factor=temporal_patch_size,
-                factor=patch_size * merge_size,
-                min_pixels=size.shortest_edge,
-                max_pixels=size.longest_edge,
-            )
-            all_target_sizes.append((resized_height, resized_width))
-
-        target_height = max([s[0] for s in all_target_sizes])
-        target_width = max([s[1] for s in all_target_sizes])
-
-        for image in images:
+        grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
+        resized_images_grouped = {}
+        for shape, stacked_images in grouped_images.items():
+            height, width = stacked_images.shape[-2:]
             if do_resize:
-                image = self.resize(
-                    image,
-                    size=SizeDict(height=target_height, width=target_width),
+                resized_height, resized_width = smart_resize(
+                    num_frames=temporal_patch_size,
+                    height=height,
+                    width=width,
+                    temporal_factor=temporal_patch_size,
+                    factor=patch_size * merge_size,
+                    min_pixels=size.shortest_edge,
+                    max_pixels=size.longest_edge,
+                )
+                stacked_images = self.resize(
+                    stacked_images,
+                    size=SizeDict(height=resized_height, width=resized_width),
                     interpolation=interpolation,
                 )
+            resized_images_grouped[shape] = stacked_images
+
+        resized_images = reorder_images(resized_images_grouped, grouped_images_index)
+
+        grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
+        processed_images_grouped = {}
+        processed_grids = {}
+
+        for shape, stacked_images in grouped_images.items():
+            resized_height, resized_width = stacked_images.shape[-2:]
+
+            patches = self.rescale_and_normalize(
+                stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
+            )
+            if patches.ndim == 4:  # (B, C, H, W)
+                patches = patches.unsqueeze(1)  # (B, T=1, C, H, W)
+
+            if patches.shape[1] % temporal_patch_size != 0:
+                repeats = patches[:, -1:].repeat(
+                    1, temporal_patch_size - (patches.shape[1] % temporal_patch_size), 1, 1, 1
+                )
+                patches = torch.cat([patches, repeats], dim=1)
+
+            batch_size, t_len, channel = patches.shape[:3]
+            grid_t = t_len // temporal_patch_size
+            grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

-            image = self.rescale_and_normalize(
-                image.unsqueeze(0), do_rescale, rescale_factor, do_normalize, image_mean, image_std
-            ).squeeze(0)
-
-            patches = image.unsqueeze(0)
-            if patches.shape[0] % temporal_patch_size != 0:
-                repeats = patches[-1:].repeat(temporal_patch_size - (patches.shape[0] % temporal_patch_size), 1, 1, 1)
-                patches = torch.cat([patches, repeats], dim=0)
-            channel = patches.shape[1]
-            grid_t = patches.shape[0] // temporal_patch_size
-            grid_h, grid_w = target_height // patch_size, target_width // patch_size
             patches = patches.view(
+                batch_size,
                 grid_t,
                 temporal_patch_size,
                 channel,
@@ -178,15 +188,22 @@ def _preprocess(
                 merge_size,
                 patch_size,
             )
-            patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
+            # (B, grid_t, gh, gw, mh, mw, C, tp, ph, pw)
+            patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
             flatten_patches = patches.reshape(
-                grid_t * grid_h * grid_w, channel * temporal_patch_size * patch_size * patch_size
+                batch_size,
+                grid_t * grid_h * grid_w,
+                channel * temporal_patch_size * patch_size * patch_size,
             )
-            processed_images.append(flatten_patches)
-            processed_grids.append([grid_t, grid_h, grid_w])
-        pixel_values = torch.stack(processed_images, dim=0)
+            processed_images_grouped[shape] = flatten_patches
+            processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size
+
+        processed_images = reorder_images(processed_images_grouped, grouped_images_index)
+        processed_grids = reorder_images(processed_grids, grouped_images_index)
+
+        pixel_values = torch.cat(processed_images, dim=0)
         image_grid_thw = torch.tensor(processed_grids)
         return BatchFeature(
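The refactor leans on two helpers from `image_processing_utils_fast`. As a rough illustration of their contract (a simplified sketch, not the library implementation — the real helpers also handle `disable_grouping` and nested inputs): `group_images_by_shape` stacks same-shaped images into one tensor per shape so each group can be resized and normalized in a single batched op, and `reorder_images` restores the original batch order afterwards.

import torch

def group_images_by_shape_sketch(images):
    # Stack same-shaped images so each group can be processed as one tensor.
    grouped, index = {}, {}
    for i, img in enumerate(images):
        shape = tuple(img.shape)
        index[i] = (shape, len(grouped.setdefault(shape, [])))
        grouped[shape].append(img)
    return {s: torch.stack(v) for s, v in grouped.items()}, index

def reorder_images_sketch(grouped, index):
    # Undo the grouping: pull each image back out of its stacked group.
    return [grouped[shape][pos] for shape, pos in (index[i] for i in sorted(index))]

images = [torch.rand(3, 28, 28), torch.rand(3, 56, 28), torch.rand(3, 28, 28)]
grouped, index = group_images_by_shape_sketch(images)
restored = reorder_images_sketch(grouped, index)
assert all(torch.equal(a, b) for a, b in zip(restored, images))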
diff --git a/tests/models/glm4v/test_image_processing_glm4v.py b/tests/models/glm4v/test_image_processing_glm4v.py
new file mode 100644
index 000000000000..cb5af4b275d2
--- /dev/null
+++ b/tests/models/glm4v/test_image_processing_glm4v.py
@@ -0,0 +1,254 @@
+# Copyright 2021 HuggingFace Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import unittest
+
+import numpy as np
+
+from transformers.testing_utils import require_torch, require_vision
+from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available
+
+from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs
+
+
+if is_torch_available():
+    import torch
+
+
+if is_vision_available():
+    from PIL import Image
+
+    from transformers import Glm4vImageProcessor
+    from transformers.models.glm4v.image_processing_glm4v import smart_resize
+
+    if is_torchvision_available():
+        from transformers import Glm4vImageProcessorFast
+
+
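A quick self-contained sanity check of the reshaping change in the processor diff above (toy sizes and helper names made up for illustration): prepending the batch dimension and shifting every old permute index up by one should make the batched path agree exactly with the old per-image path.

import torch

# Toy dimensions: 2 images, 3 channels, temporal/patch/merge sizes as in GLM-4V.
B, C, tp, ps, ms = 2, 3, 2, 14, 2
grid_t, grid_h, grid_w = 1, 4, 6
frames = torch.rand(B, grid_t * tp, C, grid_h * ps, grid_w * ps)

def per_image(p):
    # Old path: (T, C, H, W) -> (num_patches, patch_dim), one image at a time.
    p = p.view(grid_t, tp, C, grid_h // ms, ms, ps, grid_w // ms, ms, ps)
    p = p.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
    return p.reshape(grid_t * grid_h * grid_w, C * tp * ps * ps)

def batched(p):
    # New path: (B, T, C, H, W) -> (B, num_patches, patch_dim) in one shot;
    # each old permute index moves up by one to make room for the batch dim.
    p = p.view(B, grid_t, tp, C, grid_h // ms, ms, ps, grid_w // ms, ms, ps)
    p = p.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)
    return p.reshape(B, grid_t * grid_h * grid_w, C * tp * ps * ps)

assert torch.equal(batched(frames), torch.stack([per_image(f) for f in frames]))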
+class Glm4vImageProcessingTester:
+    def __init__(
+        self,
+        parent,
+        batch_size=7,
+        num_channels=3,
+        min_resolution=30,
+        max_resolution=80,
+        do_resize=True,
+        size=None,
+        do_normalize=True,
+        image_mean=[0.5, 0.5, 0.5],
+        image_std=[0.5, 0.5, 0.5],
+        temporal_patch_size=2,
+        patch_size=14,
+        merge_size=2,
+    ):
+        size = size if size is not None else {"longest_edge": 20, "shortest_edge": 10}
+        self.parent = parent
+        self.batch_size = batch_size
+        self.num_channels = num_channels
+        self.min_resolution = min_resolution
+        self.max_resolution = max_resolution
+        self.do_resize = do_resize
+        self.size = size
+        self.do_normalize = do_normalize
+        self.image_mean = image_mean
+        self.image_std = image_std
+        self.temporal_patch_size = temporal_patch_size
+        self.patch_size = patch_size
+        self.merge_size = merge_size
+
+    def prepare_image_processor_dict(self):
+        return {
+            "image_mean": self.image_mean,
+            "image_std": self.image_std,
+            "do_normalize": self.do_normalize,
+            "do_resize": self.do_resize,
+            "size": self.size,
+            "temporal_patch_size": self.temporal_patch_size,
+            "patch_size": self.patch_size,
+            "merge_size": self.merge_size,
+        }
+
+    def expected_output_image_shape(self, images):
+        grid_t = 1
+        hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
+        seq_len = 0
+        for image in images:
+            # Normalize the input to something with a known shape: a list of PIL
+            # frames becomes a stacked numpy array; arrays and tensors pass through.
+            if isinstance(image, list) and isinstance(image[0], Image.Image):
+                image = np.stack([np.array(frame) for frame in image])
+            elif hasattr(image, "shape"):
+                pass
+            else:
+                image = np.array(image)
+            # numpy inputs are channels_last (..., H, W, C); torch tensors are channels_first (C, H, W).
+            if hasattr(image, "shape") and len(image.shape) >= 3:
+                if isinstance(image, np.ndarray):
+                    if len(image.shape) == 4:
+                        height, width = image.shape[1:3]
+                    elif len(image.shape) == 3:
+                        height, width = image.shape[:2]
+                    else:
+                        height, width = self.min_resolution, self.min_resolution
+                else:
+                    height, width = image.shape[-2:]
+            else:
+                height, width = self.min_resolution, self.min_resolution
+
+            resized_height, resized_width = smart_resize(
+                self.temporal_patch_size,
+                height,
+                width,
+                factor=self.patch_size * self.merge_size,
+                min_pixels=self.size["shortest_edge"],
+                max_pixels=self.size["longest_edge"],
+            )
+            grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
+            seq_len += grid_t * grid_h * grid_w
+        return (seq_len, hidden_dim)
+
+    def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
+        return prepare_image_inputs(
+            batch_size=self.batch_size,
+            num_channels=self.num_channels,
+            min_resolution=self.min_resolution,
+            max_resolution=self.max_resolution,
+            equal_resolution=equal_resolution,
+            numpify=numpify,
+            torchify=torchify,
+        )
+
+
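To make the tester's arithmetic concrete, here is a hypothetical walk-through for a single 60×40 image under the test config above; the exact resized values depend on `smart_resize` internals, so it prints rather than asserts.

from transformers.models.glm4v.image_processing_glm4v import smart_resize

patch_size, merge_size, temporal_patch_size, num_channels = 14, 2, 2, 3
resized_height, resized_width = smart_resize(
    temporal_patch_size, 60, 40, factor=patch_size * merge_size, min_pixels=10, max_pixels=20
)
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size
seq_len = 1 * grid_h * grid_w  # grid_t is 1 for still images
hidden_dim = num_channels * temporal_patch_size * patch_size**2  # 3 * 2 * 14 * 14 = 1176
print((seq_len, hidden_dim))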
+@require_torch
+@require_vision
+class Glm4vImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
+    image_processing_class = Glm4vImageProcessor if is_vision_available() else None
+    fast_image_processing_class = Glm4vImageProcessorFast if is_torchvision_available() else None
+
+    def setUp(self):
+        super().setUp()
+        self.image_processor_tester = Glm4vImageProcessingTester(self)
+
+    @property
+    def image_processor_dict(self):
+        return self.image_processor_tester.prepare_image_processor_dict()
+
+    def test_image_processor_properties(self):
+        for image_processing_class in self.image_processor_list:
+            image_processing = image_processing_class(**self.image_processor_dict)
+            self.assertTrue(hasattr(image_processing, "image_mean"))
+            self.assertTrue(hasattr(image_processing, "image_std"))
+            self.assertTrue(hasattr(image_processing, "do_normalize"))
+            self.assertTrue(hasattr(image_processing, "do_resize"))
+            self.assertTrue(hasattr(image_processing, "size"))
+
+    def test_image_processor_from_dict_with_kwargs(self):
+        for image_processing_class in self.image_processor_list:
+            image_processor = image_processing_class.from_dict(self.image_processor_dict)
+            self.assertEqual(image_processor.size, {"shortest_edge": 10, "longest_edge": 20})
+
+            image_processor = image_processing_class.from_dict(
+                self.image_processor_dict, size={"shortest_edge": 42, "longest_edge": 42}
+            )
+            self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 42})
+
+    # GLM-4V flattens the batch into one patch sequence, so expected shapes are (seq_len, hidden_dim)
+    def test_call_pil(self):
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PIL images
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
+            for image in image_inputs:
+                self.assertIsInstance(image, Image.Image)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+    def test_call_numpy(self):
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random numpy tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+            for image in image_inputs:
+                self.assertIsInstance(image, np.ndarray)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+            # Test batched
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+    def test_call_pytorch(self):
+        for image_processing_class in self.image_processor_list:
+            # Initialize image_processing
+            image_processing = image_processing_class(**self.image_processor_dict)
+            # create random PyTorch tensors
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)
+
+            for image in image_inputs:
+                self.assertIsInstance(image, torch.Tensor)
+
+            # Test not batched input
+            encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+            # Test batched
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+    def test_call_numpy_4_channels(self):
+        for image_processing_class in self.image_processor_list:
+            # Test that images with an arbitrary number of channels can be processed
+            # Initialize image_processing
+            image_processor = image_processing_class(**self.image_processor_dict)
+
+            # create random numpy tensors
+            self.image_processor_tester.num_channels = 4
+            image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
+
+            # Test not batched input
+            encoded_images = image_processor(
+                image_inputs[0],
+                return_tensors="pt",
+                input_data_format="channels_last",
+                image_mean=0,
+                image_std=1,
+            ).pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
+
+            # Test batched
+            encoded_images = image_processor(
+                image_inputs,
+                return_tensors="pt",
+                input_data_format="channels_last",
+                image_mean=0,
+                image_std=1,
+            ).pixel_values
+            expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
+            self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
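Finally, a minimal end-to-end smoke example of the behavior the new tests pin down, assuming the fast processor accepts the same constructor kwargs as the tester config above (input sizes here are illustrative):

import torch
from transformers import Glm4vImageProcessorFast

processor = Glm4vImageProcessorFast(
    size={"shortest_edge": 10, "longest_edge": 20},
    patch_size=14,
    merge_size=2,
    temporal_patch_size=2,
)
images = [torch.randint(0, 256, (3, 64, 48), dtype=torch.uint8) for _ in range(3)]
outputs = processor(images, return_tensors="pt")
# The batch is flattened into one patch sequence: (total_seq_len, hidden_dim).
print(outputs.pixel_values.shape)
# One (t, h, w) grid per input image.
print(outputs.image_grid_thw)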