From 5a1ab522d1ac3455ecd7d86cee04f5a10c3b8f05 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 10 Nov 2023 10:02:06 +0530 Subject: [PATCH 1/6] feat: add resolution binning Co-authored-by: lawrence-cj --- .../pipelines/pixart_alpha/__init__.py | 5 +- .../pixart_alpha/pipeline_pixart_alpha.py | 72 +++++++++++++++++++ src/diffusers/utils/__init__.py | 1 + src/diffusers/utils/import_utils.py | 17 +++++ 4 files changed, 93 insertions(+), 2 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/__init__.py b/src/diffusers/pipelines/pixart_alpha/__init__.py index 0bfa28fcde50..9f9bc2d4649a 100644 --- a/src/diffusers/pipelines/pixart_alpha/__init__.py +++ b/src/diffusers/pipelines/pixart_alpha/__init__.py @@ -6,6 +6,7 @@ _LazyModule, get_objects_from_module, is_torch_available, + is_torchvision_available, is_transformers_available, ) @@ -15,7 +16,7 @@ try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_torch_available() and is_torchvision_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 @@ -26,7 +27,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: - if not (is_transformers_available() and is_torch_available()): + if not (is_transformers_available() and is_torch_available() and is_torchvision_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 147e2b76e6c6..ac9ccf9ca30d 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -28,6 +28,7 @@ BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, + is_torchvision_available, logging, replace_example_docstring, ) @@ -43,6 +44,8 @@ if is_ftfy_available(): import ftfy +if is_torchvision_available(): + from torchvision import transforms as T EXAMPLE_DOC_STRING = """ Examples: @@ -60,6 +63,42 @@ ``` """ +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + class PixArtAlphaPipeline(DiffusionPipeline): r""" @@ -495,6 +534,28 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype latents = latents * self.scheduler.init_noise_sigma return latents + @staticmethod + def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]: + """Returns binned height and width.""" + ar = float(height / width) + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + default_hw = ratios[closest_ratio] + return int(default_hw[0]), int(default_hw[1]) + + @staticmethod + def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: + orig_hw = torch.tensor([samples.shape[2], samples.shape[3]]) + custom_hw = torch.tensor([new_height, new_width]) + + if (orig_hw != custom_hw).all(): + ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1]) + resized_width = int(orig_hw[1] * ratio) + resized_height = int(orig_hw[0] * ratio) + transform = T.Compose([T.Resize((resized_height, resized_width)), T.CenterCrop(custom_hw.tolist())]) + return transform(samples) + else: + return samples + @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) def __call__( @@ -518,6 +579,7 @@ def __call__( callback_steps: int = 1, clean_caption: bool = True, mask_feature: bool = True, + use_resolution_bin: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -580,6 +642,10 @@ def __call__( be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. + use_resolution_bin: + (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the + closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, + they are resized back to the requested resolution. Useful for generating non-square images. Examples: @@ -591,6 +657,10 @@ def __call__( # 1. Check inputs. Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor + if use_resolution_bin: + orig_height, orig_width = height, width + height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN) + self.check_inputs( prompt, height, width, negative_prompt, callback_steps, prompt_embeds, negative_prompt_embeds ) @@ -709,6 +779,8 @@ def __call__( if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] + if use_resolution_bin: + image = self.resize_and_crop_tensor(image, orig_width, orig_height) else: image = latents diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index b4d6bdab33eb..9aea30a23a6c 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -77,6 +77,7 @@ is_torch_version, is_torch_xla_available, is_torchsde_available, + is_torchvision_available, is_transformers_available, is_transformers_version, is_unidecode_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b3278af2f6a5..b4c8471e9d67 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -284,6 +284,13 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False +_torchvision_available = importlib.util.find_spec("torchvision") is not None +try: + _torchvision_version = importlib_metadata.version("torchvision") + logger.debug(f"Successfully imported torchvision version {_torchvision_version}") +except importlib_metadata.PackageNotFoundError: + _torchvision_available = False + def is_torch_available(): return _torch_available @@ -377,6 +384,10 @@ def is_peft_available(): return _peft_available +def is_torchvision_available(): + return _torchvision_available + + # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -485,6 +496,11 @@ def is_peft_available(): {0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` """ +# docstyle-ignore +TORCHVISION_IMPORT_ERROR = """ +{0} requires the torchvision library but it was not found in your environment. You can install it with pip: `pip install torchvision` +""" + # docstyle-ignore INVISIBLE_WATERMARK_IMPORT_ERROR = """ {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` @@ -512,6 +528,7 @@ def is_peft_available(): ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), + ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)), ] ) From bbb2ab6f24d150ca188f2e4dffacb0567abd659c Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Fri, 10 Nov 2023 10:20:02 +0530 Subject: [PATCH 2/6] rename --- .../pipelines/pixart_alpha/pipeline_pixart_alpha.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index ac9ccf9ca30d..b75929da621c 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -579,7 +579,7 @@ def __call__( callback_steps: int = 1, clean_caption: bool = True, mask_feature: bool = True, - use_resolution_bin: bool = True, + use_resolution_binning: bool = True, ) -> Union[ImagePipelineOutput, Tuple]: """ Function invoked when calling the pipeline for generation. @@ -642,7 +642,7 @@ def __call__( be installed. If the dependencies are not installed, the embeddings will be created from the raw prompt. mask_feature (`bool` defaults to `True`): If set to `True`, the text embeddings will be masked. - use_resolution_bin: + use_resolution_binning: (`bool` defaults to `True`): If set to `True`, the requested height and width are first mapped to the closest resolutions using `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to the requested resolution. Useful for generating non-square images. @@ -657,7 +657,7 @@ def __call__( # 1. Check inputs. Raise error if not correct height = height or self.transformer.config.sample_size * self.vae_scale_factor width = width or self.transformer.config.sample_size * self.vae_scale_factor - if use_resolution_bin: + if use_resolution_binning: orig_height, orig_width = height, width height, width = self.classify_height_width_bin(height, width, ratios=ASPECT_RATIO_1024_BIN) @@ -779,7 +779,7 @@ def __call__( if not output_type == "latent": image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0] - if use_resolution_bin: + if use_resolution_binning: image = self.resize_and_crop_tensor(image, orig_width, orig_height) else: image = latents From dbba34993a3f966f5f8f2f463ab6d5d2b00f608e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 11 Nov 2023 09:18:24 +0530 Subject: [PATCH 3/6] debug --- .../text_to_image/train_text_to_image_flax.py | 5 +-- .../train_text_to_image_lora_sdxl.py | 5 +-- .../pipelines/pixart_alpha/__init__.py | 5 +-- .../pixart_alpha/pipeline_pixart_alpha.py | 38 +++++++++++-------- src/diffusers/utils/__init__.py | 1 - src/diffusers/utils/import_utils.py | 17 --------- 6 files changed, 27 insertions(+), 44 deletions(-) diff --git a/examples/text_to_image/train_text_to_image_flax.py b/examples/text_to_image/train_text_to_image_flax.py index 9ebe34555310..e62d03c730b1 100644 --- a/examples/text_to_image/train_text_to_image_flax.py +++ b/examples/text_to_image/train_text_to_image_flax.py @@ -272,10 +272,7 @@ def main(): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - data_dir=args.train_data_dir + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) else: data_files = {} diff --git a/examples/text_to_image/train_text_to_image_lora_sdxl.py b/examples/text_to_image/train_text_to_image_lora_sdxl.py index 1a6ef0c856db..b69940603128 100644 --- a/examples/text_to_image/train_text_to_image_lora_sdxl.py +++ b/examples/text_to_image/train_text_to_image_lora_sdxl.py @@ -765,10 +765,7 @@ def load_model_hook(models, input_dir): if args.dataset_name is not None: # Downloading and loading a dataset from the hub. dataset = load_dataset( - args.dataset_name, - args.dataset_config_name, - cache_dir=args.cache_dir, - data_dir=args.train_data_dir + args.dataset_name, args.dataset_config_name, cache_dir=args.cache_dir, data_dir=args.train_data_dir ) else: data_files = {} diff --git a/src/diffusers/pipelines/pixart_alpha/__init__.py b/src/diffusers/pipelines/pixart_alpha/__init__.py index 9f9bc2d4649a..0bfa28fcde50 100644 --- a/src/diffusers/pipelines/pixart_alpha/__init__.py +++ b/src/diffusers/pipelines/pixart_alpha/__init__.py @@ -6,7 +6,6 @@ _LazyModule, get_objects_from_module, is_torch_available, - is_torchvision_available, is_transformers_available, ) @@ -16,7 +15,7 @@ try: - if not (is_transformers_available() and is_torch_available() and is_torchvision_available()): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: from ...utils import dummy_torch_and_transformers_objects # noqa F403 @@ -27,7 +26,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT: try: - if not (is_transformers_available() and is_torch_available() and is_torchvision_available()): + if not (is_transformers_available() and is_torch_available()): raise OptionalDependencyNotAvailable() except OptionalDependencyNotAvailable: diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index b75929da621c..2447c2673a11 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -19,6 +19,7 @@ from typing import Callable, List, Optional, Tuple, Union import torch +import torch.nn.functional as F from transformers import T5EncoderModel, T5Tokenizer from ...image_processor import VaeImageProcessor @@ -28,7 +29,6 @@ BACKENDS_MAPPING, is_bs4_available, is_ftfy_available, - is_torchvision_available, logging, replace_example_docstring, ) @@ -44,9 +44,6 @@ if is_ftfy_available(): import ftfy -if is_torchvision_available(): - from torchvision import transforms as T - EXAMPLE_DOC_STRING = """ Examples: ```py @@ -544,17 +541,28 @@ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[in @staticmethod def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: - orig_hw = torch.tensor([samples.shape[2], samples.shape[3]]) - custom_hw = torch.tensor([new_height, new_width]) - - if (orig_hw != custom_hw).all(): - ratio = max(custom_hw[0] / orig_hw[0], custom_hw[1] / orig_hw[1]) - resized_width = int(orig_hw[1] * ratio) - resized_height = int(orig_hw[0] * ratio) - transform = T.Compose([T.Resize((resized_height, resized_width)), T.CenterCrop(custom_hw.tolist())]) - return transform(samples) - else: - return samples + orig_height, orig_width = samples.shape[2], samples.shape[3] + torch.tensor([new_height, new_width]) + + # Check if resizing is needed + if orig_height != new_height or orig_width != new_width: + ratio = max(new_height / orig_height, new_width / orig_width) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + # Resize + samples = F.interpolate( + samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False + ) + + # Center Crop + start_x = (resized_width - new_width) // 2 + end_x = start_x + new_width + start_y = (resized_height - new_height) // 2 + end_y = start_y + new_height + samples = samples[:, :, start_y:end_y, start_x:end_x] + + return samples @torch.no_grad() @replace_example_docstring(EXAMPLE_DOC_STRING) diff --git a/src/diffusers/utils/__init__.py b/src/diffusers/utils/__init__.py index 9aea30a23a6c..b4d6bdab33eb 100644 --- a/src/diffusers/utils/__init__.py +++ b/src/diffusers/utils/__init__.py @@ -77,7 +77,6 @@ is_torch_version, is_torch_xla_available, is_torchsde_available, - is_torchvision_available, is_transformers_available, is_transformers_version, is_unidecode_available, diff --git a/src/diffusers/utils/import_utils.py b/src/diffusers/utils/import_utils.py index b4c8471e9d67..b3278af2f6a5 100644 --- a/src/diffusers/utils/import_utils.py +++ b/src/diffusers/utils/import_utils.py @@ -284,13 +284,6 @@ except importlib_metadata.PackageNotFoundError: _peft_available = False -_torchvision_available = importlib.util.find_spec("torchvision") is not None -try: - _torchvision_version = importlib_metadata.version("torchvision") - logger.debug(f"Successfully imported torchvision version {_torchvision_version}") -except importlib_metadata.PackageNotFoundError: - _torchvision_available = False - def is_torch_available(): return _torch_available @@ -384,10 +377,6 @@ def is_peft_available(): return _peft_available -def is_torchvision_available(): - return _torchvision_available - - # docstyle-ignore FLAX_IMPORT_ERROR = """ {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the @@ -496,11 +485,6 @@ def is_torchvision_available(): {0} requires the torchsde library but it was not found in your environment. You can install it with pip: `pip install torchsde` """ -# docstyle-ignore -TORCHVISION_IMPORT_ERROR = """ -{0} requires the torchvision library but it was not found in your environment. You can install it with pip: `pip install torchvision` -""" - # docstyle-ignore INVISIBLE_WATERMARK_IMPORT_ERROR = """ {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0` @@ -528,7 +512,6 @@ def is_torchvision_available(): ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)), ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)), ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)), - ("torchvision", (is_torchvision_available, TORCHVISION_IMPORT_ERROR)), ] ) From 429ed2efd61c346b872657b0486dce93916301ea Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 11 Nov 2023 09:43:08 +0530 Subject: [PATCH 4/6] add :test --- tests/pipelines/pixart/test_pixart.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index a04f4e1a8804..8f869a0191d7 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -189,12 +189,35 @@ def test_inference_non_square_images(self): inputs = self.get_dummy_inputs(device) image = pipe(**inputs, height=32, width=48).images image_slice = image[0, -3:, -3:, -1] - self.assertEqual(image.shape, (1, 32, 48, 3)) + expected_slice = np.array([0.3859, 0.2987, 0.2333, 0.5243, 0.6721, 0.4436, 0.5292, 0.5373, 0.4416]) max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) + def test_resolution_binning(self): + device = "cpu" + + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + pipe.to(device) + pipe.set_progress_bar_config(disable=None) + + inputs = self.get_dummy_inputs(device) + image = pipe(**inputs, height=32, width=48).images + image_slice = image[0, -3:, -3:, -1] + + inputs = self.get_dummy_inputs(device) + no_res_binning_image = pipe(**inputs, height=32, width=48, use_resolution_binning=False).images + no_res_binning_image_slice = no_res_binning_image[0, -3:, -3:, -1] + + self.assertEqual(image.shape, (1, 32, 48, 3)) + self.assertEqual(no_res_binning_image.shape, (1, 32, 48, 3)) + + assert np.allclose( + image_slice, no_res_binning_image_slice, atol=1e-3, rtol=1e-3 + ), "Resolution binning should change the results." + def test_inference_with_embeddings_and_multiple_images(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) From be40866ba013dfcf4d0fe6a514cdbed36862ef87 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 11 Nov 2023 09:55:45 +0530 Subject: [PATCH 5/6] remove unused variable --- src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 2447c2673a11..c3f667ba16be 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -542,7 +542,6 @@ def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[in @staticmethod def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor: orig_height, orig_width = samples.shape[2], samples.shape[3] - torch.tensor([new_height, new_width]) # Check if resizing is needed if orig_height != new_height or orig_width != new_width: From fd5a63a48030a92c2073484d0733299837f55c25 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 11 Nov 2023 18:50:59 +0530 Subject: [PATCH 6/6] set resolution_binning to False. --- tests/pipelines/pixart/test_pixart.py | 30 ++++++--------------------- 1 file changed, 6 insertions(+), 24 deletions(-) diff --git a/tests/pipelines/pixart/test_pixart.py b/tests/pipelines/pixart/test_pixart.py index 8f869a0191d7..1fb2560b29b6 100644 --- a/tests/pipelines/pixart/test_pixart.py +++ b/tests/pipelines/pixart/test_pixart.py @@ -89,7 +89,8 @@ def get_dummy_inputs(self, device, seed=0): "generator": generator, "num_inference_steps": 2, "guidance_scale": 5.0, - "output_type": "numpy", + "use_resolution_binning": False, + "output_type": "np", } return inputs @@ -120,6 +121,7 @@ def test_save_load_optional_components(self): "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, + "use_resolution_binning": False, } # set all optional components to None @@ -154,6 +156,7 @@ def test_save_load_optional_components(self): "generator": generator, "num_inference_steps": num_inference_steps, "output_type": output_type, + "use_resolution_binning": False, } output_loaded = pipe_loaded(**inputs)[0] @@ -195,29 +198,6 @@ def test_inference_non_square_images(self): max_diff = np.abs(image_slice.flatten() - expected_slice).max() self.assertLessEqual(max_diff, 1e-3) - def test_resolution_binning(self): - device = "cpu" - - components = self.get_dummy_components() - pipe = self.pipeline_class(**components) - pipe.to(device) - pipe.set_progress_bar_config(disable=None) - - inputs = self.get_dummy_inputs(device) - image = pipe(**inputs, height=32, width=48).images - image_slice = image[0, -3:, -3:, -1] - - inputs = self.get_dummy_inputs(device) - no_res_binning_image = pipe(**inputs, height=32, width=48, use_resolution_binning=False).images - no_res_binning_image_slice = no_res_binning_image[0, -3:, -3:, -1] - - self.assertEqual(image.shape, (1, 32, 48, 3)) - self.assertEqual(no_res_binning_image.shape, (1, 32, 48, 3)) - - assert np.allclose( - image_slice, no_res_binning_image_slice, atol=1e-3, rtol=1e-3 - ), "Resolution binning should change the results." - def test_inference_with_embeddings_and_multiple_images(self): components = self.get_dummy_components() pipe = self.pipeline_class(**components) @@ -242,6 +222,7 @@ def test_inference_with_embeddings_and_multiple_images(self): "num_inference_steps": num_inference_steps, "output_type": output_type, "num_images_per_prompt": 2, + "use_resolution_binning": False, } # set all optional components to None @@ -277,6 +258,7 @@ def test_inference_with_embeddings_and_multiple_images(self): "num_inference_steps": num_inference_steps, "output_type": output_type, "num_images_per_prompt": 2, + "use_resolution_binning": False, } output_loaded = pipe_loaded(**inputs)[0]