diff --git a/src/diffusers/image_processor.py b/src/diffusers/image_processor.py index 5d2e2c2b9c8c..4d8d73c863ae 100644 --- a/src/diffusers/image_processor.py +++ b/src/diffusers/image_processor.py @@ -531,17 +531,12 @@ def preprocess( ) image = torch.cat(image, axis=0) - # ensure the input is a list of images: - # - if it is a batch of images (4d torch.Tensor or np.ndarray), it is converted to a list of images (a list of 3d torch.Tensor or np.ndarray) - # - if it is a single image, it is converted to a list of one image - if isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim == 4: - image = list(image) - if is_valid_image(image): - image = [image] - if not all(is_valid_image(img) for img in image): + if not is_valid_image_imagelist(image): raise ValueError( - f"Input is in incorrect format: {[type(i) for i in image]}. Currently, we only support {', '.join(supported_formats)}" + f"Input is in incorrect format. Currently, we only support {', '.join(supported_formats)}" ) + if not isinstance(image, list): + image = [image] if isinstance(image[0], PIL.Image.Image): if crops_coords is not None: