95 changes: 56 additions & 39 deletions src/transformers/models/glm4v/image_processing_glm4v_fast.py
@@ -22,6 +22,8 @@
from ...image_processing_utils_fast import (
BaseImageProcessorFast,
DefaultFastImageProcessorKwargs,
group_images_by_shape,
reorder_images,
)
from ...image_utils import (
OPENAI_CLIP_MEAN,
@@ -128,46 +130,54 @@ def _preprocess(
Preprocess an image or batch of images. Copy of the `preprocess` method from `CLIPImageProcessor`.
"""

processed_images = []
processed_grids = []

all_target_sizes = []
for image in images:
height, width = image.shape[-2:]
resized_height, resized_width = smart_resize(
num_frames=temporal_patch_size,
height=height,
width=width,
temporal_factor=temporal_patch_size,
factor=patch_size * merge_size,
min_pixels=size.shortest_edge,
max_pixels=size.longest_edge,
)
all_target_sizes.append((resized_height, resized_width))

target_height = max([s[0] for s in all_target_sizes])
target_width = max([s[1] for s in all_target_sizes])

for image in images:
grouped_images, grouped_images_index = group_images_by_shape(images, disable_grouping=disable_grouping)
resized_images_grouped = {}
for shape, stacked_images in grouped_images.items():
height, width = stacked_images.shape[-2:]
if do_resize:
image = self.resize(
image,
size=SizeDict(height=target_height, width=target_width),
resized_height, resized_width = smart_resize(
num_frames=temporal_patch_size,
height=height,
width=width,
temporal_factor=temporal_patch_size,
factor=patch_size * merge_size,
min_pixels=size.shortest_edge,
max_pixels=size.longest_edge,
)
stacked_images = self.resize(
stacked_images,
size=SizeDict(height=resized_height, width=resized_width),
Member:
Looks like this might be breaking backward compatibility, as resized size used to be computed as the max of all target sizes in the batch. Not exactly sure why this was the case in the first place, but let's make sure we don't have edge cases here that would make this a breaking change. In particular, having the same resized size for all images in the batch ensured that we could stack the images in the end, not sure this is the case now.

Contributor Author:
The "backward compatibility" concern is fundamentally invalid because the current behavior is already broken—it fails to process mixed-size batches entirely, while same-size batches remain completely unaffected since identical input dimensions produce identical output dimensions, and the Fast version has already proven the safety of this fix through successful implementation and testing.
The Fast version's proven approach:
group images by their original dimensions, process each group independently by applying smart_resize per group, maintain the original batch sequence order through proper reconstruction, and only stack dimensionally compatible tensors.
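A minimal sketch of that group/process/reorder pattern, using simplified stand-in helpers rather than the actual group_images_by_shape / reorder_images from image_processing_utils_fast (the sizes and the interpolate call are placeholders):

```python
import torch

def group_images_by_shape_sketch(images):
    # Group images by (H, W) and remember where each one came from, so the
    # original batch order can be restored after per-group processing.
    grouped, index = {}, []
    for img in images:
        shape = tuple(img.shape[-2:])
        grouped.setdefault(shape, []).append(img)
        index.append((shape, len(grouped[shape]) - 1))
    return {s: torch.stack(imgs) for s, imgs in grouped.items()}, index

def reorder_images_sketch(processed, index):
    # Undo the grouping: pull each result back out in the original order.
    return [processed[shape][pos] for shape, pos in index]

images = [torch.rand(3, 224, 336), torch.rand(3, 448, 448), torch.rand(3, 224, 336)]
grouped, index = group_images_by_shape_sketch(images)
# Stand-in for smart_resize + resize: each shape group gets its own target size.
resized = {
    shape: torch.nn.functional.interpolate(batch, size=(shape[0] // 2, shape[1] // 2))
    for shape, batch in grouped.items()
}
restored = reorder_images_sketch(resized, index)  # same length and order as `images`
```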

Member:
I am not sure I follow you. When you say the current behavior is broken, do you mean it's incorrect because the images are not resized to the correct size, or that it crashes?
A big part of the issue is that there are no image processing tests for this model for some reason. Adding a test file would make it clearer what works and what doesn't.
Would you mind adding this test file? You can look at the other test_image_processing_....py files to see how they should be written.
If you don't have the bandwidth for that, we can open a separate PR.
Thanks a lot!

interpolation=interpolation,
)
resized_images_grouped[shape] = stacked_images

resized_images = reorder_images(resized_images_grouped, grouped_images_index)

grouped_images, grouped_images_index = group_images_by_shape(resized_images, disable_grouping=disable_grouping)
processed_images_grouped = {}
processed_grids = {}

for shape, stacked_images in grouped_images.items():
resized_height, resized_width = stacked_images.shape[-2:]

patches = self.rescale_and_normalize(
stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std
)
if patches.ndim == 4: # (B, C, H, W)
patches = patches.unsqueeze(1) # (B, T=1, C, H, W)

if patches.shape[1] % temporal_patch_size != 0:
repeats = patches[:, -1:].repeat(
1, temporal_patch_size - (patches.shape[1] % temporal_patch_size), 1, 1, 1
)
patches = torch.cat([patches, repeats], dim=1)

batch_size, t_len, channel = patches.shape[:3]
grid_t = t_len // temporal_patch_size
grid_h, grid_w = resized_height // patch_size, resized_width // patch_size

image = self.rescale_and_normalize(
image.unsqueeze(0), do_rescale, rescale_factor, do_normalize, image_mean, image_std
).squeeze(0)

patches = image.unsqueeze(0)
if patches.shape[0] % temporal_patch_size != 0:
repeats = patches[-1:].repeat(temporal_patch_size - (patches.shape[0] % temporal_patch_size), 1, 1, 1)
patches = torch.cat([patches, repeats], dim=0)
channel = patches.shape[1]
grid_t = patches.shape[0] // temporal_patch_size
grid_h, grid_w = target_height // patch_size, target_width // patch_size
patches = patches.view(
batch_size,
grid_t,
temporal_patch_size,
channel,
@@ -178,15 +188,22 @@ def _preprocess(
merge_size,
patch_size,
)
patches = patches.permute(0, 3, 6, 4, 7, 2, 1, 5, 8)
# (B, grid_t, gh, gw, mh, mw, C, tp, ph, pw)
patches = patches.permute(0, 1, 4, 7, 5, 8, 3, 2, 6, 9)

flatten_patches = patches.reshape(
batch_size,
grid_t * grid_h * grid_w,
channel * temporal_patch_size * patch_size * patch_size,
)
processed_images.append(flatten_patches)
processed_grids.append([grid_t, grid_h, grid_w])

pixel_values = torch.stack(processed_images, dim=0)
processed_images_grouped[shape] = flatten_patches
processed_grids[shape] = [[grid_t, grid_h, grid_w]] * batch_size

processed_images = reorder_images(processed_images_grouped, grouped_images_index)
processed_grids = reorder_images(processed_grids, grouped_images_index)

pixel_values = torch.cat(processed_images, dim=0)
image_grid_thw = torch.tensor(processed_grids)

return BatchFeature(
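As a rough sanity check on the flattening in the hunk above (illustrative numbers only; 56×84 is an assumed smart_resize output, not a value taken from this PR): each image contributes grid_t * grid_h * grid_w rows of length channel * temporal_patch_size * patch_size * patch_size, and the rows of all images are concatenated along dim 0.

```python
# Illustrative shape arithmetic for the flattened patches.
patch_size, temporal_patch_size, channels = 14, 2, 3
resized_height, resized_width = 56, 84        # assumed; multiples of patch_size * merge_size = 28
grid_t = 1                                    # a single image padded to one temporal group
grid_h = resized_height // patch_size         # 4
grid_w = resized_width // patch_size          # 6
rows_per_image = grid_t * grid_h * grid_w     # 24
hidden_dim = channels * temporal_patch_size * patch_size * patch_size  # 1176
print(rows_per_image, hidden_dim)             # this image adds a (24, 1176) block to pixel_values
```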
254 changes: 254 additions & 0 deletions tests/models/glm4v/test_image_processing_glm4v.py
@@ -0,0 +1,254 @@
# Copyright 2021 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import unittest

import numpy as np

from transformers.testing_utils import require_torch, require_vision
from transformers.utils import is_torch_available, is_torchvision_available, is_vision_available

from ...test_image_processing_common import ImageProcessingTestMixin, prepare_image_inputs


if is_torch_available():
import torch


if is_vision_available():
from PIL import Image

from transformers import Glm4vImageProcessor
from transformers.models.glm4v.image_processing_glm4v import smart_resize

if is_torchvision_available():
from transformers import Glm4vImageProcessorFast


class Glm4vImageProcessingTester:
def __init__(
self,
parent,
batch_size=7,
num_channels=3,
min_resolution=30,
max_resolution=80,
do_resize=True,
size=None,
do_normalize=True,
image_mean=[0.5, 0.5, 0.5],
image_std=[0.5, 0.5, 0.5],
temporal_patch_size=2,
patch_size=14,
merge_size=2,
):
size = size if size is not None else {"longest_edge": 20, "shortest_edge": 10}
self.parent = parent
self.batch_size = batch_size
self.num_channels = num_channels
self.min_resolution = min_resolution
self.max_resolution = max_resolution
self.do_resize = do_resize
self.size = size
self.do_normalize = do_normalize
self.image_mean = image_mean
self.image_std = image_std
self.temporal_patch_size = temporal_patch_size
self.patch_size = patch_size
self.merge_size = merge_size

def prepare_image_processor_dict(self):
return {
"image_mean": self.image_mean,
"image_std": self.image_std,
"do_normalize": self.do_normalize,
"do_resize": self.do_resize,
"size": self.size,
"temporal_patch_size": self.temporal_patch_size,
"patch_size": self.patch_size,
"merge_size": self.merge_size,
}

def expected_output_image_shape(self, images):
grid_t = 1
hidden_dim = self.num_channels * self.temporal_patch_size * self.patch_size * self.patch_size
seq_len = 0
for image in images:
if isinstance(image, list) and isinstance(image[0], Image.Image):
image = np.stack([np.array(frame) for frame in image])
elif hasattr(image, "shape"):
pass
else:
image = np.array(image)
if hasattr(image, "shape") and len(image.shape) >= 3:
if isinstance(image, np.ndarray):
if len(image.shape) == 4:
height, width = image.shape[1:3]
elif len(image.shape) == 3:
height, width = image.shape[:2]
else:
height, width = self.min_resolution, self.min_resolution
else:
height, width = image.shape[-2:]
else:
height, width = self.min_resolution, self.min_resolution

resized_height, resized_width = smart_resize(
self.temporal_patch_size,
height,
width,
factor=self.patch_size * self.merge_size,
min_pixels=self.size["shortest_edge"],
max_pixels=self.size["longest_edge"],
)
grid_h, grid_w = resized_height // self.patch_size, resized_width // self.patch_size
seq_len += grid_t * grid_h * grid_w
return (seq_len, hidden_dim)

def prepare_image_inputs(self, equal_resolution=False, numpify=False, torchify=False):
return prepare_image_inputs(
batch_size=self.batch_size,
num_channels=self.num_channels,
min_resolution=self.min_resolution,
max_resolution=self.max_resolution,
equal_resolution=equal_resolution,
numpify=numpify,
torchify=torchify,
)


@require_torch
@require_vision
class Glm4vImageProcessingTest(ImageProcessingTestMixin, unittest.TestCase):
image_processing_class = Glm4vImageProcessor if is_vision_available() else None
fast_image_processing_class = Glm4vImageProcessorFast if is_torchvision_available() else None

def setUp(self):
super().setUp()
self.image_processor_tester = Glm4vImageProcessingTester(self)

@property
def image_processor_dict(self):
return self.image_processor_tester.prepare_image_processor_dict()

def test_image_processor_properties(self):
for image_processing_class in self.image_processor_list:
image_processing = image_processing_class(**self.image_processor_dict)
self.assertTrue(hasattr(image_processing, "image_mean"))
self.assertTrue(hasattr(image_processing, "image_std"))
self.assertTrue(hasattr(image_processing, "do_normalize"))
self.assertTrue(hasattr(image_processing, "do_resize"))
self.assertTrue(hasattr(image_processing, "size"))

def test_image_processor_from_dict_with_kwargs(self):
for image_processing_class in self.image_processor_list:
image_processor = image_processing_class.from_dict(self.image_processor_dict)
self.assertEqual(image_processor.size, {"shortest_edge": 10, "longest_edge": 20})

image_processor = image_processing_class.from_dict(
self.image_processor_dict, size={"shortest_edge": 42, "longest_edge": 42}
)
self.assertEqual(image_processor.size, {"shortest_edge": 42, "longest_edge": 42})

# batch size is flattened
def test_call_pil(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PIL images
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False)
for image in image_inputs:
self.assertIsInstance(image, Image.Image)

# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

def test_call_numpy(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random numpy tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)
for image in image_inputs:
self.assertIsInstance(image, np.ndarray)

# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

# Test batched
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

def test_call_pytorch(self):
for image_processing_class in self.image_processor_list:
# Initialize image_processing
image_processing = image_processing_class(**self.image_processor_dict)
# create random PyTorch tensors
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, torchify=True)

for image in image_inputs:
self.assertIsInstance(image, torch.Tensor)

# Test not batched input
encoded_images = image_processing(image_inputs[0], return_tensors="pt").pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

# Test batched
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
encoded_images = image_processing(image_inputs, return_tensors="pt").pixel_values
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

def test_call_numpy_4_channels(self):
for image_processing_class in self.image_processor_list:
# Test that can process images which have an arbitrary number of channels
# Initialize image_processing
image_processor = image_processing_class(**self.image_processor_dict)

# create random numpy tensors
self.image_processor_tester.num_channels = 4
image_inputs = self.image_processor_tester.prepare_image_inputs(equal_resolution=False, numpify=True)

# Test not batched input
encoded_images = image_processor(
image_inputs[0],
return_tensors="pt",
input_data_format="channels_last",
image_mean=0,
image_std=1,
).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape([image_inputs[0]])
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)

# Test batched
encoded_images = image_processor(
image_inputs,
return_tensors="pt",
input_data_format="channels_last",
image_mean=0,
image_std=1,
).pixel_values
expected_output_image_shape = self.image_processor_tester.expected_output_image_shape(image_inputs)
self.assertEqual(tuple(encoded_images.shape), expected_output_image_shape)
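
For reference, a hypothetical quick check that mirrors test_call_pytorch above, run against a mixed-resolution batch (the case the grouping change targets). The constructor arguments are the tester defaults from this file; everything else falls back to the processor's own defaults, and the input sizes simply sit within the tester's 30–80 range.

```python
import torch

from transformers import Glm4vImageProcessorFast

processor = Glm4vImageProcessorFast(
    size={"shortest_edge": 10, "longest_edge": 20},
    patch_size=14,
    merge_size=2,
    temporal_patch_size=2,
)
# Two channels-first uint8 images with different resolutions (mixed-size batch).
images = [
    torch.randint(0, 256, (3, 30, 40), dtype=torch.uint8),
    torch.randint(0, 256, (3, 60, 80), dtype=torch.uint8),
]
out = processor(images, return_tensors="pt")
print(out.pixel_values.shape)  # (sum over images of grid_t * grid_h * grid_w, C * tp * patch * patch)
print(out.image_grid_thw)      # one (t, h, w) grid per image
```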