diff --git a/docs/source/en/api/pipelines/pix2pix.md b/docs/source/en/api/pipelines/pix2pix.md
index 08990048e80b..f921922e4bb5 100644
--- a/docs/source/en/api/pipelines/pix2pix.md
+++ b/docs/source/en/api/pipelines/pix2pix.md
@@ -35,4 +35,12 @@ Make sure to check out the Schedulers [guide](/using-diffusers/schedulers) to le
 	- save_lora_weights
 
 ## StableDiffusionPipelineOutput
-[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
\ No newline at end of file
+[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
+
+## StableDiffusionXLInstructPix2PixPipeline
+[[autodoc]] StableDiffusionXLInstructPix2PixPipeline
+	- __call__
+	- all
+
+## StableDiffusionXLPipelineOutput
+[[autodoc]] pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput
\ No newline at end of file
diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
index b7633acaffa4..fe9fc1a53d32 100644
--- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
+++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -28,6 +28,7 @@
     LoRAXFormersAttnProcessor,
     XFormersAttnProcessor,
 )
+from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     deprecate,
@@ -36,6 +37,7 @@
     is_invisible_watermark_available,
     logging,
     randn_tensor,
+    replace_example_docstring,
 )
 from ..pipeline_utils import DiffusionPipeline
 from . import StableDiffusionXLPipelineOutput
@@ -47,6 +49,36 @@
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import StableDiffusionXLInstructPix2PixPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> resolution = 768
+        >>> image = load_image(
+        ...     "https://hf.co/datasets/diffusers/diffusers-images-docs/resolve/main/mountain.png"
+        ... ).resize((resolution, resolution))
+        >>> edit_instruction = "Turn sky into a cloudy one"
+
+        >>> pipe = StableDiffusionXLInstructPix2PixPipeline.from_pretrained(
+        ...     "diffusers/sdxl-instructpix2pix-768", torch_dtype=torch.float16
+        ... ).to("cuda")
+
+        >>> edited_image = pipe(
+        ...     prompt=edit_instruction,
+        ...     image=image,
+        ...     height=resolution,
+        ...     width=resolution,
+        ...     guidance_scale=3.0,
+        ...     image_guidance_scale=1.5,
+        ...     num_inference_steps=30,
+        ... ).images[0]
+        >>> edited_image
+        ```
+"""
+
 
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
@@ -121,7 +153,6 @@ def __init__(
         tokenizer_2: CLIPTokenizer,
         unet: UNet2DConditionModel,
         scheduler: KarrasDiffusionSchedulers,
-        requires_aesthetics_score: bool = False,
         force_zeros_for_empty_prompt: bool = True,
         add_watermarker: Optional[bool] = None,
     ):
@@ -137,11 +168,9 @@ def __init__(
             scheduler=scheduler,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
-        self.register_to_config(requires_aesthetics_score=requires_aesthetics_score)
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
-
-        self.vae.config.force_upcast = True  # force the VAE to be in float32 mode, as it overflows in float16
+        self.default_sample_size = self.unet.config.sample_size
 
         add_watermarker = add_watermarker if add_watermarker is not None else is_invisible_watermark_available()
 
@@ -213,13 +242,16 @@ def enable_model_cpu_offload(self, gpu_id=0):
         # We'll offload the last model manually.
         self.final_offload_hook = hook
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_prompt
     def encode_prompt(
         self,
-        prompt,
+        prompt: str,
+        prompt_2: Optional[str] = None,
         device: Optional[torch.device] = None,
         num_images_per_prompt: int = 1,
         do_classifier_free_guidance: bool = True,
-        negative_prompt=None,
+        negative_prompt: Optional[str] = None,
+        negative_prompt_2: Optional[str] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
@@ -230,8 +262,11 @@ def encode_prompt(
         Encodes the prompt into text encoder hidden states.
 
         Args:
-            prompt (`str` or `List[str]`, *optional*):
+            prompt (`str` or `List[str]`, *optional*):
                 prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
             device: (`torch.device`):
                 torch device
             num_images_per_prompt (`int`):
@@ -242,6 +277,9 @@ def encode_prompt(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders
             prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
@@ -266,6 +304,10 @@ def encode_prompt(
         if lora_scale is not None and isinstance(self, LoraLoaderMixin):
             self._lora_scale = lora_scale
 
+            # dynamically adjust the LoRA scale
+            adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+            adjust_lora_scale_text_encoder(self.text_encoder_2, lora_scale)
+
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
         elif prompt is not None and isinstance(prompt, list):
@@ -280,9 +322,11 @@ def encode_prompt(
             )
 
         if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
             # textual inversion: procecss multi-vector tokens if necessary
             prompt_embeds_list = []
-            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
+            prompts = [prompt, prompt_2]
+            for prompt, tokenizer, text_encoder in zip(prompts, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
                     prompt = self.maybe_convert_prompt(prompt, tokenizer)
 
@@ -293,6 +337,7 @@ def encode_prompt(
                     truncation=True,
                     return_tensors="pt",
                 )
+                text_input_ids = text_inputs.input_ids
 
                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
@@ -314,11 +359,6 @@ def encode_prompt(
                 pooled_prompt_embeds = prompt_embeds[0]
                 prompt_embeds = prompt_embeds.hidden_states[-2]
 
-                bs_embed, seq_len, _ = prompt_embeds.shape
-                # duplicate text embeddings for each generation per prompt, using mps friendly method
-                prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-
                 prompt_embeds_list.append(prompt_embeds)
 
             prompt_embeds = torch.concat(prompt_embeds_list, dim=-1)
@@ -330,6 +370,8 @@ def encode_prompt(
             negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds)
         elif do_classifier_free_guidance and negative_prompt_embeds is None:
             negative_prompt = negative_prompt or ""
+            negative_prompt_2 = negative_prompt_2 or negative_prompt
+
             uncond_tokens: List[str]
             if prompt is not None and type(prompt) is not type(negative_prompt):
                 raise TypeError(
@@ -337,7 +379,7 @@ def encode_prompt(
                     f" {type(prompt)}."
                 )
             elif isinstance(negative_prompt, str):
-                uncond_tokens = [negative_prompt]
+                uncond_tokens = [negative_prompt, negative_prompt_2]
             elif batch_size != len(negative_prompt):
                 raise ValueError(
                     f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
@@ -345,17 +387,16 @@ def encode_prompt(
                     " the batch size of `prompt`."
                 )
             else:
-                uncond_tokens = negative_prompt
+                uncond_tokens = [negative_prompt, negative_prompt_2]
 
             negative_prompt_embeds_list = []
-            for tokenizer, text_encoder in zip(tokenizers, text_encoders):
-                # textual inversion: procecss multi-vector tokens if necessary
+            for negative_prompt, tokenizer, text_encoder in zip(uncond_tokens, tokenizers, text_encoders):
                 if isinstance(self, TextualInversionLoaderMixin):
-                    uncond_tokens = self.maybe_convert_prompt(uncond_tokens, tokenizer)
+                    negative_prompt = self.maybe_convert_prompt(negative_prompt, tokenizer)
 
                 max_length = prompt_embeds.shape[1]
                 uncond_input = tokenizer(
-                    uncond_tokens,
+                    negative_prompt,
                     padding="max_length",
                     max_length=max_length,
                     truncation=True,
@@ -370,32 +411,30 @@ def encode_prompt(
                 negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]
 
-                if do_classifier_free_guidance:
-                    # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
-                    seq_len = negative_prompt_embeds.shape[1]
-
-                    negative_prompt_embeds = negative_prompt_embeds.to(dtype=text_encoder.dtype, device=device)
-
-                    negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
-                    negative_prompt_embeds = negative_prompt_embeds.view(
-                        batch_size * num_images_per_prompt, seq_len, -1
-                    )
-
-                    # For classifier free guidance, we need to do two forward passes.
-                    # Here we concatenate the unconditional and text embeddings into a single batch
-                    # to avoid doing two forward passes
-
                 negative_prompt_embeds_list.append(negative_prompt_embeds)
 
             negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1)
 
-        bs_embed = pooled_prompt_embeds.shape[0]
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+        bs_embed, seq_len, _ = prompt_embeds.shape
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+        if do_classifier_free_guidance:
+            # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+            seq_len = negative_prompt_embeds.shape[1]
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
         pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
             bs_embed * num_images_per_prompt, -1
         )
-        negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
-            bs_embed * num_images_per_prompt, -1
-        )
+        if do_classifier_free_guidance:
+            negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat(1, num_images_per_prompt).view(
+                bs_embed * num_images_per_prompt, -1
+            )
 
         return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds
 
@@ -417,15 +456,7 @@ def prepare_extra_step_kwargs(self, generator, eta):
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    def get_timesteps(self, num_inference_steps, strength, device):
-        # get the original timestep using init_timestep
-        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
-
-        t_start = max(num_inference_steps - init_timestep, 0)
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
-
-        return timesteps, num_inference_steps - t_start
-
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_instruct_pix2pix.StableDiffusionInstructPix2PixPipeline.check_inputs
     def check_inputs(
         self, prompt, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None
     ):
@@ -463,6 +494,7 @@ def check_inputs(
                 f" {negative_prompt_embeds.shape}."
             )
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
    def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -496,9 +528,9 @@ def prepare_image_latents(
             image_latents = image
         else:
             # make sure the VAE is in float32 mode, as it overflows in float16
-            if self.vae.config.force_upcast:
-                image = image.float()
-                self.vae.to(dtype=torch.float32)
+            if self.vae.dtype == torch.float16 and self.vae.config.force_upcast:
+                self.upcast_vae()
+                image = image.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
 
             if isinstance(generator, list) and len(generator) != batch_size:
                 raise ValueError(
@@ -536,45 +568,24 @@ def prepare_image_latents(
 
         return image_latents
 
-    def _get_add_time_ids(
-        self, original_size, crops_coords_top_left, target_size, aesthetic_score, negative_aesthetic_score, dtype
-    ):
-        if self.config.requires_aesthetics_score:
-            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
-            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
-        else:
-            add_time_ids = list(original_size + crops_coords_top_left + target_size)
-            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._get_add_time_ids
+    def _get_add_time_ids(self, original_size, crops_coords_top_left, target_size, dtype):
+        add_time_ids = list(original_size + crops_coords_top_left + target_size)
 
         passed_add_embed_dim = (
             self.unet.config.addition_time_embed_dim * len(add_time_ids) + self.text_encoder_2.config.projection_dim
         )
         expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
 
-        if (
-            expected_add_embed_dim > passed_add_embed_dim
-            and (expected_add_embed_dim - passed_add_embed_dim) == self.unet.config.addition_time_embed_dim
-        ):
-            raise ValueError(
-                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
-            )
-        elif (
-            expected_add_embed_dim < passed_add_embed_dim
-            and (passed_add_embed_dim - expected_add_embed_dim) == self.unet.config.addition_time_embed_dim
-        ):
-            raise ValueError(
-                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
-            )
-        elif expected_add_embed_dim != passed_add_embed_dim:
+        if expected_add_embed_dim != passed_add_embed_dim:
             raise ValueError(
                 f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
             )
 
         add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
-        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)
-
-        return add_time_ids, add_neg_time_ids
+        return add_time_ids
 
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
     def upcast_vae(self):
         dtype = self.vae.dtype
         self.vae.to(dtype=torch.float32)
@@ -595,14 +606,20 @@ def upcast_vae(self):
         self.vae.decoder.mid_block.to(dtype)
 
     @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
         prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
         image: PipelineImageInput = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 100,
-        guidance_scale: float = 7.5,
+        denoising_end: Optional[float] = None,
+        guidance_scale: float = 5.0,
         image_guidance_scale: float = 1.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: Optional[int] = 1,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
@@ -620,8 +637,6 @@ def __call__(
         original_size: Tuple[int, int] = None,
         crops_coords_top_left: Tuple[int, int] = (0, 0),
         target_size: Tuple[int, int] = None,
-        aesthetic_score: float = 6.0,
-        negative_aesthetic_score: float = 2.5,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -630,12 +645,26 @@ def __call__(
             prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
                 instead.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in both text-encoders
             image (`torch.FloatTensor` or `PIL.Image.Image` or `np.ndarray` or `List[torch.FloatTensor]` or `List[PIL.Image.Image]` or `List[np.ndarray]`):
                 The image(s) to modify with the pipeline.
+            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The height in pixels of the generated image.
+            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+                The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            guidance_scale (`float`, *optional*, defaults to 7.5):
+            denoising_end (`float`, *optional*):
+                When specified, determines the fraction (between 0.0 and 1.0) of the total denoising process to be
+                completed before it is intentionally prematurely terminated. As a result, the returned sample will
+                still retain a substantial amount of noise as determined by the discrete timesteps selected by the
+                scheduler. The denoising_end parameter should ideally be utilized when this pipeline forms a part of a
+                "Mixture of Denoisers" multi-pipeline setup, as elaborated in [**Refining the Image
+                Output**](https://huggingface.co/docs/diffusers/api/pipelines/stable_diffusion/stable_diffusion_xl#refining-the-image-output)
+            guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
                 Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
@@ -650,6 +679,9 @@ def __call__(
                 The prompt or prompts not to guide the image generation. If not defined, one has to pass
                 `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                 less than `1`).
+            negative_prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and
+                `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders.
             num_images_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             eta (`float`, *optional*, defaults to 0.0):
@@ -698,25 +730,34 @@ def __call__(
                 [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                 Guidance rescale factor should fix overexposure when using zero terminal SNR.
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
-                TODO
+                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
+                `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
+                explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             crops_coords_top_left (`Tuple[int]`, *optional*, defaults to (0, 0)):
-                TODO
+                `crops_coords_top_left` can be used to generate an image that appears to be "cropped" from the position
+                `crops_coords_top_left` downwards. Favorable, well-centered images are usually achieved by setting
+                `crops_coords_top_left` to (0, 0). Part of SDXL's micro-conditioning as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
             target_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
-                TODO
-            aesthetic_score (`float`, *optional*, defaults to 6.0):
-                TODO
-            negative_aesthetic_score (`float`, *optional*, defaults to 2.5):
-                TDOO
+                For most cases, `target_size` should be set to the desired height and width of the generated image. If
+                not specified it will default to `(width, height)`. Part of SDXL's micro-conditioning as explained in
+                section 2.2 of [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
 
         Examples:
 
         Returns:
-            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] or `tuple`:
-            [`~pipelines.stable_diffusion.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
-            `tuple. When returning a tuple, the first element is a list with the generated images, and the second
-            element is a list of `bool`s denoting whether the corresponding generated image likely represents
-            "not-safe-for-work" (nsfw) content, according to the `safety_checker`.
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] or `tuple`:
+            [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] if `return_dict` is True, otherwise a
+            `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
+        # 0. Default height and width to unet
+        height = height or self.default_sample_size * self.vae_scale_factor
+        width = width or self.default_sample_size * self.vae_scale_factor
+
+        original_size = original_size or (height, width)
+        target_size = target_size or (height, width)
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(prompt, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds)
 
@@ -750,11 +791,13 @@ def __call__(
             pooled_prompt_embeds,
             negative_pooled_prompt_embeds,
         ) = self.encode_prompt(
-            prompt,
-            device,
-            num_images_per_prompt,
-            do_classifier_free_guidance,
-            negative_prompt,
+            prompt=prompt,
+            prompt_2=prompt_2,
+            device=device,
+            num_images_per_prompt=num_images_per_prompt,
+            do_classifier_free_guidance=do_classifier_free_guidance,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
@@ -780,10 +823,6 @@ def __call__(
             generator,
         )
 
-        height, width = image_latents.shape[-2:]
-        height = height * self.vae_scale_factor
-        width = width * self.vae_scale_factor
-
         # 7. Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         latents = self.prepare_latents(
@@ -811,47 +850,40 @@ def __call__(
         # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
-        original_size = original_size or (height, width)
-        target_size = target_size or (height, width)
-
         # 10. Prepare added time ids & embeddings
         add_text_embeds = pooled_prompt_embeds
-        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
-            original_size,
-            crops_coords_top_left,
-            target_size,
-            aesthetic_score,
-            negative_aesthetic_score,
-            dtype=prompt_embeds.dtype,
+        add_time_ids = self._get_add_time_ids(
+            original_size, crops_coords_top_left, target_size, dtype=prompt_embeds.dtype
         )
-        add_time_ids = add_time_ids.repeat(batch_size * num_images_per_prompt, 1)
-
-        original_prompt_embeds_len = len(prompt_embeds)
-        original_add_text_embeds_len = len(add_text_embeds)
-        original_add_time_ids = len(add_time_ids)
 
         if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds], dim=0)
-            add_text_embeds = torch.cat([add_text_embeds, negative_pooled_prompt_embeds], dim=0)
-            add_neg_time_ids = add_neg_time_ids.repeat(batch_size * num_images_per_prompt, 1)
-            add_time_ids = torch.cat([add_time_ids, add_neg_time_ids], dim=0)
-
-            # Make dimensions consistent
-            add_text_embeds = torch.concat((add_text_embeds, add_text_embeds[:original_add_text_embeds_len]), dim=0)
-            add_time_ids = torch.concat((add_time_ids, add_time_ids.clone()[:original_add_time_ids]), dim=0)
-            prompt_embeds = torch.concat((prompt_embeds, prompt_embeds.clone()[:original_prompt_embeds_len]), dim=0)
+            # The extra concat is similar to how it's done in SD InstructPix2Pix.
+            prompt_embeds = torch.cat([prompt_embeds, negative_prompt_embeds, negative_prompt_embeds], dim=0)
+            add_text_embeds = torch.cat(
+                [add_text_embeds, negative_pooled_prompt_embeds, negative_pooled_prompt_embeds], dim=0
+            )
+            add_time_ids = torch.cat([add_time_ids, add_time_ids, add_time_ids], dim=0)
 
-        prompt_embeds = prompt_embeds.to(device).to(torch.float32)
-        add_text_embeds = add_text_embeds.to(device).to(torch.float32)
-        add_time_ids = add_time_ids.to(device)
+        prompt_embeds = prompt_embeds.to(device)
+        add_text_embeds = add_text_embeds.to(device)
+        add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1)
 
         # 11. Denoising loop
-        self.unet = self.unet.to(torch.float32)
-        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
+
+        if denoising_end is not None and type(denoising_end) == float and denoising_end > 0 and denoising_end < 1:
+            discrete_timestep_cutoff = int(
+                round(
+                    self.scheduler.config.num_train_timesteps
+                    - (denoising_end * self.scheduler.config.num_train_timesteps)
+                )
+            )
+            num_inference_steps = len(list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
+            timesteps = timesteps[:num_inference_steps]
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # Expand the latents if we are doing classifier free guidance.
-                # The latents are expanded 3 times because for pix2pix the guidance\
+                # The latents are expanded 3 times because for pix2pix the guidance
                 # is applied for both the text and the input image.
                 latent_model_input = torch.cat([latents] * 3) if do_classifier_free_guidance else latents
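The triple concatenation of `prompt_embeds`, `add_text_embeds`, and `add_time_ids` above means each denoising step runs three predictions in a single batch: text-and-image conditioned, image-only conditioned, and fully unconditional. A minimal sketch of how those three chunks are combined, mirroring the guidance formula diffusers uses for its InstructPix2Pix pipelines (the helper name and tensor shapes here are illustrative, not part of the diff):

```py
import torch


def combine_pix2pix_guidance(
    noise_pred: torch.Tensor,  # stacked prediction, batch = 3 * actual batch size
    guidance_scale: float = 5.0,
    image_guidance_scale: float = 1.5,
) -> torch.Tensor:
    # Chunk order matches the concat above: [text+image, image-only, unconditional].
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    # Text guidance pushes the image-conditioned prediction toward the
    # text-conditioned one; image guidance pushes the unconditional prediction
    # toward the image-conditioned one, keeping the edit anchored to the input.
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )


# Illustrative shapes: batch of 1, 4 latent channels, 96x96 latents (768 / 8).
stacked = torch.randn(3, 4, 96, 96)
combined = combine_pix2pix_guidance(stacked)
print(combined.shape)  # torch.Size([1, 4, 96, 96])
```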
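The `denoising_end` cutoff added in the denoising-loop setup reduces to simple arithmetic: a fraction of the training schedule is mapped to a prefix of the inference timesteps. A small worked example, assuming the common 1000 training timesteps and an evenly spaced 10-step schedule (illustrative values, plain Python, no diffusers imports):

```py
# Worked example of the denoising_end cutoff logic from the diff.
num_train_timesteps = 1000
denoising_end = 0.7

# Evenly spaced descending timesteps, as a scheduler might produce them.
timesteps = list(range(999, -1, -100))  # [999, 899, ..., 99]

discrete_timestep_cutoff = int(round(num_train_timesteps - denoising_end * num_train_timesteps))
print(discrete_timestep_cutoff)  # 300

kept = [t for t in timesteps if t >= discrete_timestep_cutoff]
print(len(kept))  # 7, i.e. 70% of the 10 steps run before handing off to a refiner
```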
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
index bbb0fe698087..2608886ded98 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_instruction_pix2pix.py
@@ -68,7 +68,7 @@ def get_dummy_components(self):
             addition_embed_type="text_time",
             addition_time_embed_dim=8,
             transformer_layers_per_block=(1, 2),
-            projection_class_embeddings_input_dim=72,  # 5 * 8 + 32
+            projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
             cross_attention_dim=64,
         )
 
@@ -118,12 +118,11 @@ def get_dummy_components(self):
             "tokenizer": tokenizer,
             "text_encoder_2": text_encoder_2,
             "tokenizer_2": tokenizer_2,
-            "requires_aesthetics_score": True,
         }
         return components
 
     def get_dummy_inputs(self, device, seed=0):
-        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
+        image = floats_tensor((1, 3, 64, 64), rng=random.Random(seed)).to(device)
         image = image / 2 + 0.5
         if str(device).startswith("mps"):
             generator = torch.manual_seed(seed)
@@ -142,7 +141,6 @@ def get_dummy_inputs(self, device, seed=0):
 
     def test_components_function(self):
         init_components = self.get_dummy_components()
-        init_components.pop("requires_aesthetics_score")
         pipe = self.pipeline_class(**init_components)
 
         self.assertTrue(hasattr(pipe, "components"))
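The test change from `projection_class_embeddings_input_dim=72` to `80` follows directly from the simplified `_get_add_time_ids`: with the aesthetic-score branch removed, six scalars are always embedded, so the UNet must expect `addition_time_embed_dim * 6 + projection_dim` features. A quick sketch of that bookkeeping, using the values from the dummy test config (plain Python, nothing beyond what the diff states):

```py
# Dimension bookkeeping for SDXL micro-conditioning, with the dummy test config.
addition_time_embed_dim = 8  # from the test UNet config
projection_dim = 32  # text_encoder_2 projection dim in the test

# Six scalars are embedded: (orig_h, orig_w) + (crop_top, crop_left) + (target_h, target_w).
add_time_ids = list((1024, 1024) + (0, 0) + (1024, 1024))
assert len(add_time_ids) == 6

passed_add_embed_dim = addition_time_embed_dim * len(add_time_ids) + projection_dim
print(passed_add_embed_dim)  # 80, hence projection_class_embeddings_input_dim=80

# The removed aesthetic-score path embedded only five scalars
# (orig_h, orig_w, crop_top, crop_left, aesthetic_score): 5 * 8 + 32 = 72.
```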