From 654889ae46235234889037b05f9e466b85d1e21c Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 5 Oct 2023 17:19:56 +0200
Subject: [PATCH 01/32] fix: sdxl pipeline when unet is not available.

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 4c1bd857d7cb..3f44d9a6e660 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -158,7 +158,11 @@ def __init__(
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-        self.default_sample_size = self.unet.config.sample_size
+
+        if hasattr(self, "unet"):
+            self.default_sample_size = self.unet.config.sample_size
+        else:
+            self.default_sample_size = None

         add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()

From f204066b86afe38fdb477203016aff8323fdb306 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 5 Oct 2023 17:27:53 +0200
Subject: [PATCH 02/32] fix more

---
 .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index 3f44d9a6e660..a8d331f5522e 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -159,7 +159,7 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

-        if hasattr(self, "unet"):
+        if hasattr(self, "unet") and self.unet is not None:
             self.default_sample_size = self.unet.config.sample_size
         else:
             self.default_sample_size = None

From 1c052635dc80af88af643a31392665e59c4e527a Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 6 Oct 2023 09:41:47 +0200
Subject: [PATCH 03/32] account for text

---
 .../pipeline_stable_diffusion_xl.py | 45 ++++++++++++-------
 1 file changed, 28 insertions(+), 17 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index a8d331f5522e..af0867d6644f 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -158,7 +158,7 @@ def __init__(
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-
+
         if hasattr(self, "unet") and self.unet is not None:
             self.default_sample_size = self.unet.config.sample_size
         else:
@@ -266,16 +266,17 @@ def encode_prompt(

         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
-        if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
-            self._lora_scale = lora_scale
-
-        # dynamically adjust the LoRA scale
-        if not self.use_peft_backend:
-            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
-            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
-        else:
-            scale_lora_layers(self.text_encoder, lora_scale)
-            scale_lora_layers(self.text_encoder_2, lora_scale)
+        if self.text_encoder is not None:
+            if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin):
+                self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if not self.use_peft_backend:
+                adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+                adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+            else:
+                scale_lora_layers(self.text_encoder, lora_scale)
+                scale_lora_layers(self.text_encoder_2, lora_scale)

         prompt = [prompt] if isinstance(prompt, str) else prompt
@@ -391,7 +392,11 @@ def encode_prompt(

         negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        if self.text_encoder_2 is not None:
+            prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        else:
+            prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
@@ -400,7 +405,12 @@ def encode_prompt(
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+
+            if self.text_encoder_2 is not None:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            else:
+                negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device)
+
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -412,10 +422,11 @@ def encode_prompt(
             bs_embed * num_images_per_prompt, -1
         )

-        if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
-            # Retrieve the original scale by scaling back the LoRA layers
-            unscale_lora_layers(self.text_encoder)
-            unscale_lora_layers(self.text_encoder_2)
+        if self.text_encoder is not None:
+            if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder)
+                unscale_lora_layers(self.text_encoder_2)

         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds

From fe2e9eed50e8103c2bce4d5b8162f23288aeb11a Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Fri, 6 Oct 2023 12:05:30 +0200
Subject: [PATCH 04/32] fix more

---
 .../pipeline_stable_diffusion_xl.py | 20 ++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
index af0867d6644f..2ad70f619a43 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -539,11 +539,16 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
         latents = latents * self.scheduler.init_noise_sigma
         return latents

-    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+    def _get_add_time_ids(
+        self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
+    ):
         add_time_ids = list(original_size + crops_coords_top_left + target_size)

+        if text_encoder_projection_dim is None:
+            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim
+
         passed_add_embed_dim = (
-            self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
+            self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
         )
         expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

@@ -850,8 +855,16 @@ def __call__(

         # 7. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
+        if self.text_encoder_2 is None:
+            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
+        else:
+            text_encoder_projection_dim = None
         add_time_ids = self._get_add_time_ids(
-            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
+            original_size,
+            crops_coords_top_left,
+            target_size,
+            dtype=prompt_embeds.dtype,
+            text_encoder_projection_dim=text_encoder_projection_dim,
         )
         if negative_original_size is not None and negative_target_size is not None:
             negative_add_time_ids = self._get_add_time_ids(
                 negative_original_size,
                 negative_crops_coords_top_left,
                 negative_target_size,
                 dtype=prompt_embeds.dtype,
+                text_encoder_projection_dim=text_encoder_projection_dim,
             )
         else:
             negative_add_time_ids = add_time_ids

From e3bf831ed03b3cc8996dd6a80ea8043e671d5aab Mon Sep 17 00:00:00 2001
From: Sayak Paul
Date: Thu, 12 Oct 2023 18:15:59 +0530
Subject: [PATCH 05/32] don't make unet optional.
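
Only the tokenizers and text encoders become optional in this series; the UNet stays a
required component, so `default_sample_size` can read `self.unet.config.sample_size`
unconditionally again. For context, a rough sketch of the usage the series is building
toward (the checkpoint id is illustrative; `encode_prompt` is the pipeline method whose
four return values are shown in PATCH 03 above):

    from diffusers import StableDiffusionXLPipeline

    # Encode once with a pipeline that still carries its text encoders.
    pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
    (prompt_embeds, negative_prompt_embeds,
     pooled_prompt_embeds, negative_pooled_prompt_embeds) = pipe.encode_prompt("an astronaut")

    # Reload without tokenizers/text encoders and feed the precomputed embeddings instead.
    lean_pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None,
    )
    image = lean_pipe(
        prompt_embeds=prompt_embeds,
        negative_prompt_embeds=negative_prompt_embeds,
        pooled_prompt_embeds=pooled_prompt_embeds,
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    ).images[0]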
--- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2ad70f619a43..2dc29d60877b 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -159,10 +159,7 @@ def __init__( self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) - if hasattr(self, "unet") and self.unet is not None: - self.default_sample_size = self.unet.config.sample_size - else: - self.default_sample_size = None + self.default_sample_size = self.unet.config.sample_size add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available() From 8fe73496c432e419e77b6ea9aa33ce076dafb76d Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 16:18:47 +0530 Subject: [PATCH 06/32] Apply suggestions from code review Co-authored-by: Patrick von Platen --- .../pipeline_stable_diffusion_xl.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 2dc29d60877b..72662e737d8d 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -420,6 +420,15 @@ def encode_prompt( ) if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder) @@ -541,8 +550,6 @@ def _get_add_time_ids( ): add_time_ids = list(original_size + crops_coords_top_left + target_size) - if text_encoder_projection_dim is None: - text_encoder_projection_dim = self.text_encoder_2.config.projection_dim passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim @@ -855,7 +862,7 @@ def __call__( if self.text_encoder_2 is None: text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: - text_encoder_projection_dim = None + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, From 8b0bfda11b381c78a600ea837527a1bc7c8ba590 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 16:25:46 +0530 Subject: [PATCH 07/32] split conditionals. 
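
Each text encoder now gets its own guard so that either one can be None on its own.
Compressed into a sketch, this is the pattern the diff below spells out explicitly for
`text_encoder` and `text_encoder_2` (the loop is only for illustration; all names come
from this file):

    # Inside encode_prompt, conceptually:
    for encoder in (self.text_encoder, self.text_encoder_2):
        if encoder is None:
            continue  # that component was removed from the pipeline
        if not self.use_peft_backend:
            adjust_lora_scale_text_encoder(encoder, lora_scale)
        else:
            scale_lora_layers(encoder, lora_scale)

The LoRA unscaling at the end of encode_prompt is split along the same lines.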
--- .../pipeline_stable_diffusion_xl.py | 26 +++++++++---------- 1 file changed, 12 insertions(+), 14 deletions(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 72662e737d8d..5c1e2ed06281 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -263,16 +263,20 @@ def encode_prompt( # set lora scale so that monkey patched LoRA # function of text encoder can correctly access it - if self.text_encoder is not None: - if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): - self._lora_scale = lora_scale + if lora_scale is not None and isinstance(self, StableDiffusionXLLoraLoaderMixin): + self._lora_scale = lora_scale - # dynamically adjust the LoRA scale + # dynamically adjust the LoRA scale + if self.text_encoder is not None: if not self.use_peft_backend: adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) else: scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -420,18 +424,13 @@ def encode_prompt( ) if self.text_encoder is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(self.text_encoder) - - if self.text_encoder_2 is not None: - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder_2) - + + if self.text_encoder_2 is not None: if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -550,7 +549,6 @@ def _get_add_time_ids( ): add_time_ids = list(original_size + crops_coords_top_left + target_size) - passed_add_embed_dim = ( self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) From 9351ea9321c5a8bb5202d7b98b6b2103d1aa9d0a Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 16:28:10 +0530 Subject: [PATCH 08/32] add optional components to sdxl pipeline --- .../stable_diffusion_xl/pipeline_stable_diffusion_xl.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 5c1e2ed06281..dcc91a81c8f7 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -131,6 +131,7 @@ class StableDiffusionXLPipeline( watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, From 646ecd1f9d398ba6dec9ff870606c17941804534 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 16:31:12 +0530 Subject: [PATCH 09/32] propagate changes to the rest of the pipelines. --- .../pipeline_controlnet_inpaint_sd_xl.py | 43 +++++++++++----- .../controlnet/pipeline_controlnet_sd_xl.py | 49 +++++++++++++------ .../pipeline_controlnet_sd_xl_img2img.py | 43 +++++++++++----- .../pipeline_stable_diffusion_xl_img2img.py | 43 +++++++++++----- .../pipeline_stable_diffusion_xl_inpaint.py | 43 +++++++++++----- ...ne_stable_diffusion_xl_instruct_pix2pix.py | 6 ++- .../pipeline_stable_diffusion_xl_adapter.py | 49 +++++++++++++------ 7 files changed, 198 insertions(+), 78 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 41b0d5434386..1e40a8ac6072 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -316,12 +316,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -437,7 +442,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -446,7 +455,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -458,10 +472,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and 
self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 7f230c2ec058..ad21eeaaea49 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -285,12 +285,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -406,7 +411,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -415,7 +424,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -427,10 +441,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not 
None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -706,11 +725,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index aeffc219674d..91396b713a1f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -328,12 +328,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -449,7 +454,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -458,7 +467,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = 
negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -470,10 +484,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 9612a8e28f8e..5786027c85ea 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -273,12 +273,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -394,7 +399,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -403,7 +412,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = 
negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -415,10 +429,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 209c9b339aec..51cfbe8423c0 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -422,12 +422,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -543,7 +548,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -552,7 +561,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, 
num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -564,10 +578,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 6fd1be88b284..5216ee32b7a2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -544,11 +544,13 @@ def prepare_image_latents( return image_latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index b32c852481ab..d0e68b1c78ff 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -283,12 +283,17 @@ def encode_prompt( self._lora_scale = lora_scale # dynamically adjust the LoRA scale - if not self.use_peft_backend: - adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) - adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) - else: - scale_lora_layers(self.text_encoder, lora_scale) - scale_lora_layers(self.text_encoder_2, lora_scale) + if self.text_encoder is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if self.text_encoder_2 is not None: + if not self.use_peft_backend: + adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale) + else: + scale_lora_layers(self.text_encoder_2, lora_scale) prompt = [prompt] if isinstance(prompt, str) else prompt @@ -404,7 +409,11 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = 
prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + if self.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=self.unet.dtype, device=device) + bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) @@ -413,7 +422,12 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + + if self.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.unet.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) @@ -425,10 +439,15 @@ def encode_prompt( bs_embed * num_images_per_prompt, -1 ) - if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: - # Retrieve the original scale by scaling back the LoRA layers - unscale_lora_layers(self.text_encoder) - unscale_lora_layers(self.text_encoder_2) + if self.text_encoder is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder) + + if self.text_encoder_2 is not None: + if isinstance(self, StableDiffusionXLLoraLoaderMixin) and self.use_peft_backend: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder_2) return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds @@ -543,11 +562,13 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype return latents # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids - def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype): + def _get_add_time_ids( + self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None + ): add_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features From 6b0ae28bfae80096a3ba83bcdbe22d553927c204 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 16:54:38 +0530 Subject: [PATCH 10/32] add: test --- .../test_stable_diffusion_xl.py | 106 +++++++++++++++++- 1 file changed, 104 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 65c7526e3aa2..a9f1c4ad3bba 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -41,6 +41,80 @@ enable_full_determinism() +def 
encode_prompt(tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None): + device = text_encoders[0].device + + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + if negative_prompt is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + else: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # for classifier-free guidance + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + # for classifier-free guidance + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionXLPipeline params = TEXT_TO_IMAGE_PARAMS @@ -113,8 +187,6 @@ def get_dummy_components(self): "tokenizer": tokenizer, "text_encoder_2": text_encoder_2, "tokenizer_2": tokenizer_2, - # "safety_checker": None, - # "feature_extractor": None, } return components @@ -226,6 +298,36 @@ def test_stable_diffusion_xl_negative_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 + def test_stable_diffusion_xl_without_text_encoders(self): + components = self.get_dummy_components() + inputs = self.get_dummy_inputs(torch_device) + + tokenizer = 
components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + text_encoder_2 = components.pop("text_encoder_2") + + tokenizers = [tokenizer, tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] + prompt = inputs.pop("prompt") + (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) = encode_prompt( + tokenizers, text_encoders, prompt + ) + + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + sd_pipe = StableDiffusionXLPipeline( + **components, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None + ) + output = sd_pipe(**inputs) + image_slice = output.images[0, -3:, -3:, -1].flatten() + + expected_image_slice = np.array([0.5849, 0.6108, 0.4788, 0.5089, 0.5648, 0.4639, 0.5217, 0.5131, 0.4748]) + assert np.allclose(image_slice, expected_image_slice, atol=1e-4) + def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) From d41ddef85253f8deb746d3b501ec6513f870428f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 19:58:54 +0530 Subject: [PATCH 11/32] add to all --- .../controlnet/pipeline_controlnet_inpaint_sd_xl.py | 2 +- .../pipelines/controlnet/pipeline_controlnet_sd_xl.py | 6 +++--- .../controlnet/pipeline_controlnet_sd_xl_img2img.py | 2 +- .../pipeline_stable_diffusion_xl_img2img.py | 3 +-- .../pipeline_stable_diffusion_xl_inpaint.py | 2 +- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 1 + .../t2i_adapter/pipeline_stable_diffusion_xl_adapter.py | 1 + 7 files changed, 9 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index 1e40a8ac6072..dc9aa1d7507f 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -167,7 +167,7 @@ class StableDiffusionXLControlNetInpaintPipeline( [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index ad21eeaaea49..523cd7b1bb41 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -139,9 +139,9 @@ class StableDiffusionXLControlNetPipeline( watermark output images. If not defined, it defaults to `True` if the package is installed; otherwise no watermarker is used. 
""" - model_cpu_offload_seq = ( - "text_encoder->text_encoder_2->unet->vae" # leave controlnet out on purpose because it iterates with unet - ) + # leave controlnet out on purpose because it iterates with unet + model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 91396b713a1f..03950f83202e 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -182,7 +182,7 @@ class StableDiffusionXLControlNetImg2ImgPipeline( watermarker will be used. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 19de838fb607..17abcffca2ec 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -142,8 +142,7 @@ class StableDiffusionXLImg2ImgPipeline( watermarker will be used. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 5623a918e915..42084d06ee97 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -289,7 +289,7 @@ class StableDiffusionXLInpaintPipeline( """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" - _optional_components = ["tokenizer", "text_encoder"] + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 08fc5d1cdcc9..20ee30f747ac 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -150,6 +150,7 @@ class StableDiffusionXLInstructPix2PixPipeline( watermarker will be used. 
""" model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index d0e68b1c78ff..ef5b2551ace9 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -153,6 +153,7 @@ class StableDiffusionXLAdapterPipeline( Model that extracts features from generated images to be used as inputs for the `safety_checker`. """ model_cpu_offload_seq = "text_encoder->text_encoder_2->unet->vae" + _optional_components = ["tokenizer", "tokenizer_2", "text_encoder", "text_encoder_2"] def __init__( self, From da2185ed8e4ca271638a176a8095200c6586ef94 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Fri, 13 Oct 2023 22:03:13 +0530 Subject: [PATCH 12/32] fix: rest of the pipelines. --- .../pipeline_controlnet_inpaint_sd_xl.py | 17 +++++++++++++++-- .../controlnet/pipeline_controlnet_sd_xl.py | 12 +++++++++++- .../pipeline_controlnet_sd_xl_img2img.py | 10 +++++++++- .../pipeline_stable_diffusion_xl.py | 1 + .../pipeline_stable_diffusion_xl_img2img.py | 9 ++++++++- .../pipeline_stable_diffusion_xl_inpaint.py | 9 ++++++++- ...line_stable_diffusion_xl_instruct_pix2pix.py | 11 ++++++++++- .../pipeline_stable_diffusion_xl_adapter.py | 12 +++++++++++- 8 files changed, 73 insertions(+), 8 deletions(-) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py index dc9aa1d7507f..2027a6d1f29a 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py @@ -903,7 +903,14 @@ def get_timesteps(self, num_inference_steps, strength, device, denoising_start=N return timesteps, num_inference_steps - t_start def _get_add_time_ids( - self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype + self, + original_size, + crops_coords_top_left, + target_size, + aesthetic_score, + negative_aesthetic_score, + dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -913,7 +920,7 @@ def _get_add_time_ids( add_neg_time_ids = list(original_size + crops_coords_top_left + target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1381,6 +1388,11 @@ def denoising_value_valid(dnv): # 10. 
Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1388,6 +1400,7 @@ def denoising_value_valid(dnv): aesthetic_score, negative_aesthetic_score, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py index 523cd7b1bb41..eb64241460f1 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py @@ -1081,8 +1081,17 @@ def __call__( target_size = target_size or (height, width) add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: @@ -1091,6 +1100,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py index 03950f83202e..d23a12156718 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py @@ -850,6 +850,7 @@ def _get_add_time_ids( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -861,7 +862,7 @@ def _get_add_time_ids( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1265,6 +1266,12 @@ def __call__( if negative_target_size is None: negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1275,6 +1282,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git 
a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py index 3c2c95e7155e..9cb60a1f3cab 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py @@ -870,6 +870,7 @@ def __call__( text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) else: text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py index 17abcffca2ec..2ded9a388506 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py @@ -635,6 +635,7 @@ def _get_add_time_ids( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -646,7 +647,7 @@ def _get_add_time_ids( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -972,6 +973,11 @@ def denoising_value_valid(dnv): negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -982,6 +988,7 @@ def denoising_value_valid(dnv): negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py index 42084d06ee97..4b4f6e3ab5b2 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py @@ -854,6 +854,7 @@ def _get_add_time_ids( negative_crops_coords_top_left, negative_target_size, dtype, + text_encoder_projection_dim=None, ): if self.config.requires_aesthetics_score: add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,)) @@ -865,7 +866,7 @@ def _get_add_time_ids( add_neg_time_ids = list(negative_original_size + crops_coords_top_left + negative_target_size) passed_add_embed_dim = ( - self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim + self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim ) expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features @@ -1279,6 +1280,11 @@ def denoising_value_valid(dnv): 
negative_target_size = target_size add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids, add_neg_time_ids = self._get_add_time_ids( original_size, crops_coords_top_left, @@ -1289,6 +1295,7 @@ def denoising_value_valid(dnv): negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index 20ee30f747ac..298315b25ded 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -846,8 +846,17 @@ def __call__( # 10. Prepare added time ids & embeddings add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py index ef5b2551ace9..04fb75051716 100644 --- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py @@ -915,8 +915,17 @@ def __call__( adapter_state[k] = torch.cat([v] * 2, dim=0) add_text_embeds = pooled_prompt_embeds + if self.text_encoder_2 is None: + text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) + else: + text_encoder_projection_dim = self.text_encoder_2.config.projection_dim + add_time_ids = self._get_add_time_ids( - original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype + original_size, + crops_coords_top_left, + target_size, + dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) if negative_original_size is not None and negative_target_size is not None: negative_add_time_ids = self._get_add_time_ids( @@ -924,6 +933,7 @@ def __call__( negative_crops_coords_top_left, negative_target_size, dtype=prompt_embeds.dtype, + text_encoder_projection_dim=text_encoder_projection_dim, ) else: negative_add_time_ids = add_time_ids From 53fe5adf2ced1de050d40344721f52c747dfc33e Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 14 Oct 2023 07:42:11 +0530 Subject: [PATCH 13/32] use pipeline_class variable --- tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index eae6448d74cd..337081388003 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ 
b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -320,7 +320,7 @@ def test_stable_diffusion_xl_without_text_encoders(self): inputs["pooled_prompt_embeds"] = pooled_prompt_embeds inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds - sd_pipe = StableDiffusionXLPipeline( + sd_pipe = self.pipeline_class( **components, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None ) output = sd_pipe(**inputs) From 3f412f68e01383f251c2007c037ab2110425b216 Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Sat, 14 Oct 2023 08:00:49 +0530 Subject: [PATCH 14/32] separate pipeline mixin --- .../test_stable_diffusion_xl.py | 108 +------------- tests/pipelines/test_pipelines_common.py | 135 ++++++++++++++++++ 2 files changed, 140 insertions(+), 103 deletions(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 337081388003..89aae5bc361f 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -35,87 +35,16 @@ from diffusers.utils.testing_utils import enable_full_determinism, require_torch_gpu, torch_device from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin enable_full_determinism() -def encode_prompt(tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None): - device = text_encoders[0].device - if isinstance(prompt, str): - prompt = [prompt] - batch_size = len(prompt) - prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - text_inputs = tokenizer( - prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - text_input_ids = text_inputs.input_ids - - prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) - pooled_prompt_embeds = prompt_embeds[0] - prompt_embeds = prompt_embeds.hidden_states[-2] - prompt_embeds_list.append(prompt_embeds) - - prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) - - if negative_prompt is None: - negative_prompt_embeds = torch.zeros_like(prompt_embeds) - negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) - else: - negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt - - negative_prompt_embeds_list = [] - for tokenizer, text_encoder in zip(tokenizers, text_encoders): - uncond_input = tokenizer( - negative_prompt, - padding="max_length", - max_length=tokenizer.model_max_length, - truncation=True, - return_tensors="pt", - ) - - negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True) - negative_pooled_prompt_embeds = negative_prompt_embeds[0] - negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] - negative_prompt_embeds_list.append(negative_prompt_embeds) - - negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - - bs_embed, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(bs_embed * 
num_images_per_prompt, seq_len, -1) - - # for classifier-free guidance - # duplicate unconditional embeddings for each generation per prompt, using mps friendly method - seq_len = negative_prompt_embeds.shape[1] - - negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) - negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - - pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - - # for classifier-free guidance - negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( - bs_embed * num_images_per_prompt, -1 - ) - return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds - - -class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase): +class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase): pipeline_class = StableDiffusionXLPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -299,41 +228,14 @@ def test_stable_diffusion_xl_negative_prompt_embeds(self): # make sure that it's equal assert np.abs(image_slice_1.flatten() - image_slice_2.flatten()).max() < 1e-4 - def test_stable_diffusion_xl_without_text_encoders(self): - components = self.get_dummy_components() - inputs = self.get_dummy_inputs(torch_device) - - tokenizer = components.pop("tokenizer") - tokenizer_2 = components.pop("tokenizer_2") - text_encoder = components.pop("text_encoder") - text_encoder_2 = components.pop("text_encoder_2") - - tokenizers = [tokenizer, tokenizer_2] - text_encoders = [text_encoder, text_encoder_2] - prompt = inputs.pop("prompt") - (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) = encode_prompt( - tokenizers, text_encoders, prompt - ) - - inputs["prompt_embeds"] = prompt_embeds - inputs["negative_prompt_embeds"] = negative_prompt_embeds - inputs["pooled_prompt_embeds"] = pooled_prompt_embeds - inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds - - sd_pipe = self.pipeline_class( - **components, text_encoder=None, text_encoder_2=None, tokenizer=None, tokenizer_2=None - ) - output = sd_pipe(**inputs) - image_slice = output.images[0, -3:, -3:, -1].flatten() - - expected_image_slice = np.array([0.5849, 0.6108, 0.4788, 0.5089, 0.5648, 0.4639, 0.5217, 0.5131, 0.4748]) - assert np.allclose(image_slice, expected_image_slice, atol=1e-4) - def test_attention_slicing_forward_pass(self): super().test_attention_slicing_forward_pass(expected_max_diff=3e-3) def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() @require_torch_gpu def test_stable_diffusion_xl_offloads(self): diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 6f2674a7b8f6..7f3796d74abb 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -974,6 +974,141 @@ def test_push_to_hub_in_organization(self): delete_repo(self.org_repo_id, token=TOKEN) +# For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders +# and the tokenizers as optional components. 
So, we need to override the `test_save_load_optional_components()` +# test for all such pipelines. This requires us to use a custom `encode_prompt()` function. +class SDXLOptionalComponentsTesterMixin: + def encode_prompt(self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None): + device = text_encoders[0].device + + if isinstance(prompt, str): + prompt = [prompt] + batch_size = len(prompt) + + prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + text_inputs = tokenizer( + prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + text_input_ids = text_inputs.input_ids + + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + pooled_prompt_embeds = prompt_embeds[0] + prompt_embeds = prompt_embeds.hidden_states[-2] + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + if negative_prompt is None: + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + else: + negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt + + negative_prompt_embeds_list = [] + for tokenizer, text_encoder in zip(tokenizers, text_encoders): + uncond_input = tokenizer( + negative_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + negative_prompt_embeds = text_encoder(uncond_input.input_ids.to(device), output_hidden_states=True) + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + bs_embed, seq_len, _ = prompt_embeds.shape + + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # for classifier-free guidance + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + # for classifier-free guidance + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1 + ) + + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def _test_save_load_optional_components(self, expected_max_difference=1e-4): + components = self.get_dummy_components() + pipe = self.pipeline_class(**components) + for component in pipe.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe.to(torch_device) + pipe.set_progress_bar_config(disable=None) + + # set all optional components to None + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + + generator_device = "cpu" + inputs = 
self.get_dummy_inputs(generator_device) + + prompt = inputs.pop("prompt") + tokenizer = components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + text_encoder_2 = components.pop("text_encoder_2") + + tokenizers = [tokenizer, tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] + (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) = self.encode_prompt( + tokenizers, text_encoders, prompt + ) + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + + output = pipe(**inputs) + + with tempfile.TemporaryDirectory() as tmpdir: + pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) + for component in pipe_loaded.components.values(): + if hasattr(component, "set_default_attn_processor"): + component.set_default_attn_processor() + pipe_loaded.to(torch_device) + pipe_loaded.set_progress_bar_config(disable=None) + + for optional_component in pipe._optional_components: + self.assertTrue( + getattr(pipe_loaded, optional_component) is None, + f"`{optional_component}` did not stay set to None after loading.", + ) + + inputs = self.get_dummy_inputs(generator_device) + output_loaded = pipe_loaded(**inputs)[0] + + max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() + self.assertLess(max_diff, expected_max_difference) + + + + # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. 
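Usage sketch (illustrative only, not part of any commit in this series): a pipeline-specific fast-test class opts into the new mixin by inheriting it next to `PipelineTesterMixin` and routing the public test name to the underscore-prefixed helper. The concrete class below is hypothetical; the mixin, tester mixin, and parameter names come from the diffs above, and the dummy-component helpers are assumed to exist as in every diffusers fast-test class.

    import unittest

    from diffusers import StableDiffusionXLPipeline

    from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
    from ..test_pipelines_common import PipelineTesterMixin, SDXLOptionalComponentsTesterMixin


    class DummySDXLPipelineFastTests(PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase):
        pipeline_class = StableDiffusionXLPipeline
        params = TEXT_TO_IMAGE_PARAMS
        batch_params = TEXT_TO_IMAGE_BATCH_PARAMS

        # get_dummy_components() / get_dummy_inputs() are omitted here; they are
        # assumed to build tiny SDXL components exactly as the existing fast tests do.

        def test_save_load_optional_components(self):
            # Shadow the generic PipelineTesterMixin test with the SDXL-aware
            # helper, which precomputes prompt embeddings via encode_prompt()
            # before the optional text encoders/tokenizers are set to None.
            self._test_save_load_optional_components()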
From 04fee728c7469e16027f622398f3a9f6eb6d7bfb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 10:54:21 +0530 Subject: [PATCH 15/32] use safe_serialization --- tests/pipelines/test_pipelines_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 7f3796d74abb..f2752ef54d7c 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1086,7 +1086,7 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): output = pipe(**inputs) with tempfile.TemporaryDirectory() as tmpdir: - pipe.save_pretrained(tmpdir, safe_serialization=False) + pipe.save_pretrained(tmpdir) pipe_loaded = self.pipeline_class.from_pretrained(tmpdir) for component in pipe_loaded.components.values(): if hasattr(component, "set_default_attn_processor"): From 0b556d8ca10cdbfca30d00d49a60e66e416a1fdb Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 11:01:48 +0530 Subject: [PATCH 16/32] fix: test --- .../test_stable_diffusion_xl.py | 9 ++++---- tests/pipelines/test_pipelines_common.py | 23 +++++++++++++------ 2 files changed, 20 insertions(+), 12 deletions(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py index 89aae5bc361f..4906670890e8 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py @@ -41,10 +41,9 @@ enable_full_determinism() - - - -class StableDiffusionXLPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase): +class StableDiffusionXLPipelineFastTests( + PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionXLPipeline params = TEXT_TO_IMAGE_PARAMS batch_params = TEXT_TO_IMAGE_BATCH_PARAMS @@ -233,7 +232,7 @@ def test_attention_slicing_forward_pass(self): def test_inference_batch_single_identical(self): super().test_inference_batch_single_identical(expected_max_diff=3e-3) - + def test_save_load_optional_components(self): self._test_save_load_optional_components() diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index f2752ef54d7c..4d37b62584a9 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -976,9 +976,11 @@ def test_push_to_hub_in_organization(self): # For SDXL and its derivative pipelines (such as ControlNet), we have the text encoders # and the tokenizers as optional components. So, we need to override the `test_save_load_optional_components()` -# test for all such pipelines. This requires us to use a custom `encode_prompt()` function. +# test for all such pipelines. This requires us to use a custom `encode_prompt()` function. 
class SDXLOptionalComponentsTesterMixin: - def encode_prompt(self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None): + def encode_prompt( + self, tokenizers, text_encoders, prompt: str, num_images_per_prompt: int = 1, negative_prompt: str = None + ): device = text_encoders[0].device if isinstance(prompt, str): @@ -1075,9 +1077,12 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): tokenizers = [tokenizer, tokenizer_2] text_encoders = [text_encoder, text_encoder_2] - (prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds) = self.encode_prompt( - tokenizers, text_encoders, prompt - ) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.encode_prompt(tokenizers, text_encoders, prompt) inputs["prompt_embeds"] = prompt_embeds inputs["negative_prompt_embeds"] = negative_prompt_embeds inputs["pooled_prompt_embeds"] = pooled_prompt_embeds @@ -1101,14 +1106,18 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): ) inputs = self.get_dummy_inputs(generator_device) + _ = inputs.pop("prompt") + inputs["prompt_embeds"] = prompt_embeds + inputs["negative_prompt_embeds"] = negative_prompt_embeds + inputs["pooled_prompt_embeds"] = pooled_prompt_embeds + inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds + output_loaded = pipe_loaded(**inputs)[0] max_diff = np.abs(to_np(output) - to_np(output_loaded)).max() self.assertLess(max_diff, expected_max_difference) - - # Some models (e.g. unCLIP) are extremely likely to significantly deviate depending on which hardware is used. # This helper function is used to check that the image doesn't deviate on average more than 10 pixels from a # reference image. From 7aa432ac4f5a271ec7b405c20dceb9dfb2eeb83f Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 11:04:47 +0530 Subject: [PATCH 17/32] access actual output. 
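Pipelines return a `BaseOutput` subclass (`StableDiffusionXLPipelineOutput` here), and the max-diff comparison below needs the underlying image array, not the dataclass. Integer indexing on `BaseOutput` aliases its fields in declaration order, so `[0]` retrieves the images — a quick sketch of the equivalence, assuming the standard SDXL output class:

    out = pipe(**inputs)
    # index 0 is the first declared field of the output dataclass
    images = out[0]  # equivalent to out.images

This also mirrors the existing `pipe_loaded(**inputs)[0]` call, so both sides of the comparison operate on arrays.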
--- tests/pipelines/test_pipelines_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 4d37b62584a9..999930dd2265 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1088,7 +1088,7 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): inputs["pooled_prompt_embeds"] = pooled_prompt_embeds inputs["negative_pooled_prompt_embeds"] = negative_pooled_prompt_embeds - output = pipe(**inputs) + output = pipe(**inputs)[0] with tempfile.TemporaryDirectory() as tmpdir: pipe.save_pretrained(tmpdir) From 91e4e19faf1f678bfbb2d302ff02d7b23a06f0ae Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 11:14:29 +0530 Subject: [PATCH 18/32] add: optional test to adapter and ip2p sdxl pipeline tests/ --- .../test_stable_diffusion_xl_adapter.py | 13 +++++++++++-- ...st_stable_diffusion_xl_instruction_pix2pix.py | 16 ++++++++++++++-- 2 files changed, 25 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py index 92c22ca2c34c..0e7a13bc876b 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py @@ -34,13 +34,19 @@ from diffusers.utils.testing_utils import enable_full_determinism, floats_tensor, torch_device from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS -from ..test_pipelines_common import PipelineTesterMixin, assert_mean_pixel_difference +from ..test_pipelines_common import ( + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + assert_mean_pixel_difference, +) enable_full_determinism() -class StableDiffusionXLAdapterPipelineFastTests(PipelineTesterMixin, unittest.TestCase): +class StableDiffusionXLAdapterPipelineFastTests( + PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase +): pipeline_class = StableDiffusionXLAdapterPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS @@ -215,6 +221,9 @@ def test_total_downscale_factor(self, adapter_type): expected_out_image_size, ) + def test_save_load_optional_components(self): + return self._test_save_load_optional_components() + class StableDiffusionXLMultiAdapterPipelineFastTests( StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py index ca4017d11b79..e20f8a0b54db 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py @@ -36,14 +36,23 @@ TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import ( + PipelineKarrasSchedulerTesterMixin, + PipelineLatentTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, +) enable_full_determinism() class StableDiffusionXLInstructPix2PixPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, 
PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionXLInstructPix2PixPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "cross_attention_kwargs"} @@ -175,3 +184,6 @@ def test_latents_input(self): def test_cfg(self): pass + + def test_save_load_optional_components(self): + self._test_save_load_optional_components() From 7d339281be68de9cc2bbdf0a0b4028756b58bf9e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 11:16:39 +0530 Subject: [PATCH 19/32] add optional test to controlnet sdxl. --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index 4fff88434bc3..b5e067b6dd80 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -42,6 +42,7 @@ PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, ) @@ -49,7 +50,11 @@ class StableDiffusionXLControlNetPipelineFastTests( - PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, + PipelineKarrasSchedulerTesterMixin, + PipelineTesterMixin, + SDXLOptionalComponentsTesterMixin, + unittest.TestCase, ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -179,6 +184,9 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + self._test_save_load_optional_components() + @require_torch_gpu def test_stable_diffusion_xl_offloads(self): pipes = [] From 30239a9c0e15f786e8f90254b2715efb53092e14 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 19:51:45 +0530 Subject: [PATCH 20/32] fix tests --- tests/pipelines/controlnet/test_controlnet_sdxl.py | 4 ++-- .../stable_diffusion_xl/test_stable_diffusion_xl_img2img.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index b5e067b6dd80..cc9fc79f1642 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -332,7 +332,7 @@ def test_controlnet_sdxl_guess(self): class StableDiffusionXLMultiControlNetPipelineFastTests( - PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase + PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS @@ -532,7 +532,7 @@ def test_inference_batch_single_identical(self): class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( - PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLControlNetPipeline params = TEXT_TO_IMAGE_PARAMS diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py 
b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py index ba7d3e8be30f..813443dd2d86 100644 --- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py +++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py @@ -38,7 +38,7 @@ TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS, ) -from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin +from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin enable_full_determinism() @@ -341,7 +341,7 @@ def test_stable_diffusion_xl_img2img_negative_conditions(self): class StableDiffusionXLImg2ImgRefinerOnlyPipelineFastTests( - PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase + PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase ): pipeline_class = StableDiffusionXLImg2ImgPipeline params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"} From 55c22f93b42738c81e169f2c2182660354f24e3a Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 19:53:47 +0530 Subject: [PATCH 21/32] fix ip2p tests --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index b6d06ed421a3..d3d8b4471038 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -402,7 +402,8 @@ def encode_prompt( negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) - prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + prompt_embeds_dtype = self.text_encoder_2.dtype if self.text_encoder_2 is not None else self.unet.dtype + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) bs_embed, seq_len, _ = prompt_embeds.shape # duplicate text embeddings for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) From 40f44b8f8beec7879711f884e6cc33bea81ab28e Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sat, 14 Oct 2023 19:55:29 +0530 Subject: [PATCH 22/32] fix more --- .../pipeline_stable_diffusion_xl_instruct_pix2pix.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py index d3d8b4471038..0427214f8374 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py @@ -412,7 +412,7 @@ def encode_prompt( if do_classifier_free_guidance: # duplicate unconditional embeddings for each generation per prompt, using mps friendly method seq_len = negative_prompt_embeds.shape[1] - negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device) + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) negative_prompt_embeds = 
negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

From fe107e8895d7535b1ff89e1e6f1469a17a1f71ee Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 14 Oct 2023 20:29:30 +0530
Subject: [PATCH 23/32] fix more.

---
 tests/pipelines/controlnet/test_controlnet_sdxl.py          | 3 +++
 .../stable_diffusion_xl/test_stable_diffusion_xl_img2img.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index cc9fc79f1642..8b6578b3fdee 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -710,6 +710,9 @@ def test_xformers_attention_forwardGenerator_pass(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)
 
+    def test_save_load_optional_components(self):
+        self._test_save_load_optional_components()
+
     def test_negative_conditions(self):
         components = self.get_dummy_components()
         pipe = self.pipeline_class(**components)
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 813443dd2d86..97c19108947f 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -600,3 +600,6 @@ def test_attention_slicing_forward_pass(self):
 
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)
+
+    def test_save_load_optional_components(self):
+        self._test_save_load_optional_components()

From 66e71be0cd12e7ba53469bc6e528fe19021721c9 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sat, 14 Oct 2023 20:33:41 +0530
Subject: [PATCH 24/32] use np output type.

---
 tests/pipelines/controlnet/test_controlnet_sdxl.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 8b6578b3fdee..e41ba970f1a2 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -478,7 +478,7 @@ def get_dummy_inputs(self, device, seed=0):
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
             "image": images,
         }
 
@@ -654,7 +654,7 @@ def get_dummy_inputs(self, device, seed=0):
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
             "image": images,
         }
 
From d9f9d6d81e9659de8e826af7e1fb6faf4c11c6e0 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Sun, 15 Oct 2023 08:14:06 +0530
Subject: [PATCH 25/32] fix for StableDiffusionXLMultiControlNetPipelineFastTests.
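StableDiffusionXLMultiControlNetPipelineFastTests picked up SDXLOptionalComponentsTesterMixin in an earlier commit, but the mixin deliberately exposes only the underscore-prefixed helper, so unittest kept collecting the generic `PipelineTesterMixin.test_save_load_optional_components`, which passes a raw `prompt` to a pipeline whose text encoders have been set to None. Wire up the override explicitly — a sketch of the resulting wiring, with the bases as in the diff below:

    class StableDiffusionXLMultiControlNetPipelineFastTests(
        PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
    ):
        def test_save_load_optional_components(self):
            # Shadow the generic test so the SDXL-aware helper (which
            # precomputes prompt embeddings first) runs instead.
            return self._test_save_load_optional_components()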
--- tests/pipelines/controlnet/test_controlnet_sdxl.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py index e41ba970f1a2..be786ebe3000 100644 --- a/tests/pipelines/controlnet/test_controlnet_sdxl.py +++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py @@ -530,6 +530,9 @@ def test_xformers_attention_forwardGenerator_pass(self): def test_inference_batch_single_identical(self): self._test_inference_batch_single_identical(expected_max_diff=2e-3) + def test_save_load_optional_components(self): + return self._test_save_load_optional_components() + class StableDiffusionXLMultiControlNetOneModelPipelineFastTests( PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase From 6cade1a4bdf37531fb89599f54f030abff5f9655 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Sun, 15 Oct 2023 08:28:00 +0530 Subject: [PATCH 26/32] fix: SDXLOptionalComponentsTesterMixin --- tests/pipelines/test_pipelines_common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 999930dd2265..0c65a134b7f2 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1075,8 +1075,8 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): text_encoder = components.pop("text_encoder") text_encoder_2 = components.pop("text_encoder_2") - tokenizers = [tokenizer, tokenizer_2] - text_encoders = [text_encoder, text_encoder_2] + tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] + text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] ( prompt_embeds, negative_prompt_embeds, From 773ca8678c5213a159bd960ce0964281479cc48f Mon Sep 17 00:00:00 2001 From: Sayak Paul Date: Mon, 16 Oct 2023 19:05:24 +0530 Subject: [PATCH 27/32] Apply suggestions from code review Co-authored-by: Patrick von Platen --- tests/pipelines/test_pipelines_common.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 0c65a134b7f2..16d606cac05b 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1055,6 +1055,13 @@ def encode_prompt( def _test_save_load_optional_components(self, expected_max_difference=1e-4): components = self.get_dummy_components() + + tokenizer = components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + text_encoder_2 = components.pop("text_encoder_2") + + components = {k: v if k not in pipe._optional_components else None for k, v in components.items()} pipe = self.pipeline_class(**components) for component in pipe.components.values(): if hasattr(component, "set_default_attn_processor"): @@ -1062,19 +1069,9 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): pipe.to(torch_device) pipe.set_progress_bar_config(disable=None) - # set all optional components to None - for optional_component in pipe._optional_components: - setattr(pipe, optional_component, None) - generator_device = "cpu" inputs = self.get_dummy_inputs(generator_device) - prompt = inputs.pop("prompt") - tokenizer = components.pop("tokenizer") - tokenizer_2 = components.pop("tokenizer_2") - text_encoder = 
components.pop("text_encoder") - text_encoder_2 = components.pop("text_encoder_2") - tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] ( From 5aaa23eb98f2adb81c2ba380801e2d55da90e809 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Oct 2023 19:08:04 +0530 Subject: [PATCH 28/32] fix tests --- tests/pipelines/test_pipelines_common.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 16d606cac05b..2c91fc315834 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1055,14 +1055,14 @@ def encode_prompt( def _test_save_load_optional_components(self, expected_max_difference=1e-4): components = self.get_dummy_components() - + tokenizer = components.pop("tokenizer") tokenizer_2 = components.pop("tokenizer_2") text_encoder = components.pop("text_encoder") text_encoder_2 = components.pop("text_encoder_2") - - components = {k: v if k not in pipe._optional_components else None for k, v in components.items()} + pipe = self.pipeline_class(**components) + components = {k: v if k not in pipe._optional_components else None for k, v in components.items()} for component in pipe.components.values(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() @@ -1074,6 +1074,7 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] + prompt = components.pop("prompt") ( prompt_embeds, negative_prompt_embeds, From 1ddd5867df6ef9c0f74fad78371f715b65d63df9 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Oct 2023 19:22:33 +0530 Subject: [PATCH 29/32] Empty-Commit From 38e16f8f49d3e5ba49e2ab15a6f78ccc2a0f3ae8 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Oct 2023 19:45:38 +0530 Subject: [PATCH 30/32] revert previous --- tests/pipelines/test_pipelines_common.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 2c91fc315834..79815dced72e 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1056,13 +1056,10 @@ def encode_prompt( def _test_save_load_optional_components(self, expected_max_difference=1e-4): components = self.get_dummy_components() - tokenizer = components.pop("tokenizer") - tokenizer_2 = components.pop("tokenizer_2") - text_encoder = components.pop("text_encoder") - text_encoder_2 = components.pop("text_encoder_2") - pipe = self.pipeline_class(**components) - components = {k: v if k not in pipe._optional_components else None for k, v in components.items()} + for optional_component in pipe._optional_components: + setattr(pipe, optional_component, None) + for component in pipe.components.values(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() @@ -1072,6 +1069,11 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): generator_device = "cpu" inputs = self.get_dummy_inputs(generator_device) + tokenizer = components.pop("tokenizer") + tokenizer_2 = components.pop("tokenizer_2") + text_encoder = components.pop("text_encoder") + 
text_encoder_2 = components.pop("text_encoder_2") + tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] prompt = components.pop("prompt") From f5748dcc8e52f94dbb55b5aab23dc515475444c1 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 16 Oct 2023 19:55:03 +0530 Subject: [PATCH 31/32] quality --- tests/pipelines/test_pipelines_common.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 79815dced72e..679250c4af54 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1059,7 +1059,7 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): pipe = self.pipeline_class(**components) for optional_component in pipe._optional_components: setattr(pipe, optional_component, None) - + for component in pipe.components.values(): if hasattr(component, "set_default_attn_processor"): component.set_default_attn_processor() From e30b3e58740593372f79b514edd746c2110bbfd4 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 17 Oct 2023 10:53:03 +0530 Subject: [PATCH 32/32] fix: test --- .../pipelines/versatile_diffusion/modeling_text_unet.py | 2 ++ tests/pipelines/test_pipelines_common.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py index a70903b4bd74..717db3bbdb34 100644 --- a/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/versatile_diffusion/modeling_text_unet.py @@ -5,6 +5,8 @@ import torch.nn as nn import torch.nn.functional as F +from diffusers.utils import deprecate + from ...configuration_utils import ConfigMixin, register_to_config from ...models import ModelMixin from ...models.activations import get_activation diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py index 679250c4af54..ae13d0d3e9fa 100644 --- a/tests/pipelines/test_pipelines_common.py +++ b/tests/pipelines/test_pipelines_common.py @@ -1076,7 +1076,7 @@ def _test_save_load_optional_components(self, expected_max_difference=1e-4): tokenizers = [tokenizer, tokenizer_2] if tokenizer is not None else [tokenizer_2] text_encoders = [text_encoder, text_encoder_2] if text_encoder is not None else [text_encoder_2] - prompt = components.pop("prompt") + prompt = inputs.pop("prompt") ( prompt_embeds, negative_prompt_embeds,