diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
index 04ca51b19f05..810a6c8a97de 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py
@@ -726,6 +726,46 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                guidance scale values at which to generate the embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -863,6 +903,8 @@ def __call__(
             control_guidance_end,
         )

+        self._guidance_scale = guidance_scale
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -872,10 +914,6 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -895,7 +933,7 @@ def __call__(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -905,7 +943,7 @@ def __call__(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         # 4. Prepare image
@@ -918,7 +956,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
@@ -934,7 +972,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )

@@ -962,6 +1000,14 @@ def __call__(
             latents,
         )

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -986,11 +1032,11 @@ def __call__(
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1017,7 +1063,7 @@ def __call__(
                     return_dict=False,
                 )

-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1029,6 +1075,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
@@ -1036,7 +1083,7 @@ def __call__(
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

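Aside: the `get_guidance_scale_embedding` helper added above carries no pipeline state, so its behavior is easy to check in isolation. A minimal standalone sketch (the free-function name is ours; the logic is copied from the hunk above):

```python
import torch


def guidance_scale_embedding(w: torch.Tensor, embedding_dim: int = 512, dtype=torch.float32) -> torch.Tensor:
    """Sinusoidal embedding of the guidance weight, mirroring the method added above."""
    assert w.ndim == 1
    w = w * 1000.0  # same scaling convention as timestep embeddings
    half_dim = embedding_dim // 2
    emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
    emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
    emb = w.to(dtype)[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad odd dimensions
        emb = torch.nn.functional.pad(emb, (0, 1))
    return emb


# The pipelines call this with w = guidance_scale - 1, repeated once per generated image:
w = torch.full((2,), 7.5 - 1.0)
print(guidance_scale_embedding(w, embedding_dim=256).shape)  # torch.Size([2, 256])
```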
diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
index d6278c4f046a..4a54957af5a5 100644
--- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
+++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py
@@ -791,6 +791,46 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                guidance scale values at which to generate the embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -986,6 +1026,8 @@ def __call__(
             control_guidance_end,
         )

+        self._guidance_scale = guidance_scale
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -995,10 +1037,6 @@ def __call__(
             batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0

         if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
             controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
@@ -1024,7 +1062,7 @@ def __call__(
             prompt_2,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             negative_prompt_2,
             prompt_embeds=prompt_embeds,
@@ -1045,7 +1083,7 @@ def __call__(
                 num_images_per_prompt=num_images_per_prompt,
                 device=device,
                 dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
                 guess_mode=guess_mode,
             )
             height, width = image.shape[-2:]
@@ -1061,7 +1099,7 @@ def __call__(
                     num_images_per_prompt=num_images_per_prompt,
                     device=device,
                     dtype=controlnet.dtype,
-                    do_classifier_free_guidance=do_classifier_free_guidance,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )

@@ -1089,6 +1127,14 @@ def __call__(
             latents,
         )

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

@@ -1133,7 +1179,7 @@ def __call__(
         else:
             negative_add_time_ids = add_time_ids

-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1154,13 +1200,13 @@ def __call__(
                 if (is_unet_compiled and is_controlnet_compiled) and is_torch_higher_equal_2_1:
                     torch._inductor.cudagraph_mark_step_begin()
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 added_cond_kwargs = {"text_embeds": add_text_embeds, "time_ids": add_time_ids}

                 # controlnet(s) inference
-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infer ControlNet only for the conditional batch.
                     control_model_input = latents
                     control_model_input = self.scheduler.scale_model_input(control_model_input, t)
@@ -1193,7 +1239,7 @@ def __call__(
                     return_dict=False,
                 )

-                if guess_mode and do_classifier_free_guidance:
+                if guess_mode and self.do_classifier_free_guidance:
                     # Infered ControlNet only for the conditional batch.
                     # To apply the output of ControlNet to both the unconditional and conditional batches,
                     # add 0 to the unconditional batch to keep it unchanged.
@@ -1205,6 +1251,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_block_additional_residuals=down_block_res_samples,
                     mid_block_additional_residual=mid_block_res_sample,
@@ -1213,7 +1260,7 @@ def __call__(
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

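For context, here is a hedged end-to-end sketch of what these two ControlNet changes enable. The checkpoint names are illustrative: any LCM-distilled Stable Diffusion base (i.e. one whose UNet sets `time_cond_proj_dim`) plus a compatible ControlNet should behave the same way; the conditioning-image path is a placeholder.

```python
import torch
from diffusers import ControlNetModel, LCMScheduler, StableDiffusionControlNetPipeline
from diffusers.utils import load_image

# Illustrative checkpoints, not the only valid choices.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "SimianLuo/LCM_Dreamshaper_v7", controlnet=controlnet, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

canny_image = load_image("canny_edges.png")  # placeholder path to a precomputed edge map

# Because the UNet has `time_cond_proj_dim`, `do_classifier_free_guidance` is False:
# the guidance weight is injected through `timestep_cond`, so each denoising step
# is a single forward pass and few-step LCM sampling keeps its speed advantage.
image = pipe(
    "a futuristic city at dusk",
    image=canny_image,
    num_inference_steps=4,
    guidance_scale=7.5,
).images[0]
```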
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
index 103cd7095291..2f65b6cd391b 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -610,6 +610,46 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                guidance scale values at which to generate the embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -723,6 +763,8 @@ def __call__(
             prompt, height, width, callback_steps, image, negative_prompt, prompt_embeds, negative_prompt_embeds
         )

+        self._guidance_scale = guidance_scale
+
         if isinstance(self.adapter, MultiAdapter):
             adapter_input = []

@@ -742,17 +784,12 @@ def __call__(
         else:
             batch_size = prompt_embeds.shape[0]

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         prompt_embeds, negative_prompt_embeds = self.encode_prompt(
             prompt,
             device,
             num_images_per_prompt,
-            do_classifier_free_guidance,
+            self.do_classifier_free_guidance,
             negative_prompt,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=negative_prompt_embeds,
@@ -761,7 +798,7 @@ def __call__(
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])

         # 4. Prepare timesteps
@@ -784,6 +821,14 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Denoising loop
         if isinstance(self.adapter, MultiAdapter):
             adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
@@ -796,7 +841,7 @@ def __call__(
         if num_images_per_prompt > 1:
             for k, v in enumerate(adapter_state):
                 adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             for k, v in enumerate(adapter_state):
                 adapter_state[k] = torch.cat([v] * 2, dim=0)

@@ -804,7 +849,7 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
@@ -812,13 +857,14 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     down_intrablock_additional_residuals=[state.clone() for state in adapter_state],
                     return_dict=False,
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

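The behavioral crux repeated across all four pipelines is the new `do_classifier_free_guidance` property: batch doubling is now skipped whenever the UNet carries a guidance-embedding projection, regardless of `guidance_scale`. A toy restatement of that predicate (class name hypothetical, logic copied from the diff):

```python
class CfgGate:
    """Toy restatement of the property added to each pipeline."""

    def __init__(self, guidance_scale: float, time_cond_proj_dim):
        self._guidance_scale = guidance_scale
        self.time_cond_proj_dim = time_cond_proj_dim  # stands in for unet.config.time_cond_proj_dim

    @property
    def do_classifier_free_guidance(self) -> bool:
        return self._guidance_scale > 1 and self.time_cond_proj_dim is None


assert CfgGate(7.5, None).do_classifier_free_guidance  # ordinary SD UNet: CFG batches are doubled
assert not CfgGate(7.5, 256).do_classifier_free_guidance  # LCM-style UNet: guidance via embedding instead
assert not CfgGate(1.0, None).do_classifier_free_guidance  # w = 1: no guidance either way
```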
diff --git a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
index a163a27a8bd7..8b676b8ad964 100644
--- a/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
+++ b/src/diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py
@@ -670,6 +670,46 @@ def disable_freeu(self):
         """Disables the FreeU mechanism if enabled."""
         self.unet.disable_freeu()

+    # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+        """
+        See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                guidance scale values at which to generate the embedding vectors
+            embedding_dim (`int`, *optional*, defaults to 512):
+                dimension of the embeddings to generate
+            dtype:
+                data type of the generated embeddings
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`
+        """
+        assert len(w.shape) == 1
+        w = w * 1000.0
+
+        half_dim = embedding_dim // 2
+        emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+        emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+        emb = w.to(dtype)[:, None] * emb[None, :]
+        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+        if embedding_dim % 2 == 1:  # zero pad
+            emb = torch.nn.functional.pad(emb, (0, 1))
+        assert emb.shape == (w.shape[0], embedding_dim)
+        return emb
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+    # corresponds to doing no classifier free guidance.
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
@@ -882,6 +922,8 @@ def __call__(
             negative_pooled_prompt_embeds,
         )

+        self._guidance_scale = guidance_scale
+
         # 2. Define call parameters
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -892,11 +934,6 @@ def __call__(

         device = self._execution_device

-        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-        # corresponds to doing no classifier free guidance.
-        do_classifier_free_guidance = guidance_scale > 1.0
-
         # 3. Encode input prompt
         (
             prompt_embeds,
@@ -908,7 +945,7 @@ def __call__(
             prompt_2=prompt_2,
             device=device,
             num_images_per_prompt=num_images_per_prompt,
-            do_classifier_free_guidance=do_classifier_free_guidance,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
@@ -939,6 +976,14 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

+        # 6.5 Optionally get Guidance Scale Embedding
+        timestep_cond = None
+        if self.unet.config.time_cond_proj_dim is not None:
+            guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+            timestep_cond = self.get_guidance_scale_embedding(
+                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+            ).to(device=device, dtype=latents.dtype)
+
         # 7. Prepare added time ids & embeddings & adapter features
         if isinstance(self.adapter, MultiAdapter):
             adapter_state = self.adapter(adapter_input, adapter_conditioning_scale)
@@ -951,7 +996,7 @@ def __call__(
         if num_images_per_prompt > 1:
             for k, v in enumerate(adapter_state):
                 adapter_state[k] = v.repeat(num_images_per_prompt, 1, 1, 1)
-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             for k, v in enumerate(adapter_state):
                 adapter_state[k] = torch.cat([v] * 2, dim=0)

@@ -979,7 +1024,7 @@ def __call__(
         else:
             negative_add_time_ids = add_time_ids

-        if do_classifier_free_guidance:
+        if self.do_classifier_free_guidance:
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
             add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
             add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
@@ -1005,7 +1050,7 @@ def __call__(
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 # expand the latents if we are doing classifier free guidance
-                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents

                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

@@ -1021,6 +1066,7 @@ def __call__(
                     latent_model_input,
                     t,
                     encoder_hidden_states=prompt_embeds,
+                    timestep_cond=timestep_cond,
                     cross_attention_kwargs=cross_attention_kwargs,
                     added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
@@ -1028,11 +1074,11 @@ def __call__(
                 )[0]

                 # perform guidance
-                if do_classifier_free_guidance:
+                if self.do_classifier_free_guidance:
                     noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

-                if do_classifier_free_guidance and guidance_rescale > 0.0:
+                if self.do_classifier_free_guidance and guidance_rescale > 0.0:
                     # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                     noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)

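Where does `timestep_cond` actually land? Inside `UNet2DConditionModel`, building the model with `time_cond_proj_dim` gives the timestep embedding an extra linear projection whose output is added to the sinusoidal timestep signal before the usual MLP. A rough sketch of that path (simplified stand-in, not the diffusers source):

```python
from typing import Optional

import torch
import torch.nn as nn


class TimestepEmbeddingSketch(nn.Module):
    """Simplified stand-in for the UNet's timestep MLP with an optional condition projection."""

    def __init__(self, in_channels: int, time_embed_dim: int, cond_proj_dim: Optional[int] = None):
        super().__init__()
        self.linear_1 = nn.Linear(in_channels, time_embed_dim)
        self.act = nn.SiLU()
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
        # Present only when the model is built with `time_cond_proj_dim`.
        self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False) if cond_proj_dim else None

    def forward(self, sample: torch.Tensor, condition: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.cond_proj is not None and condition is not None:
            # The guidance-scale embedding is folded into the timestep signal here,
            # which is how one forward pass can encode the guidance weight.
            sample = sample + self.cond_proj(condition)
        return self.linear_2(self.act(self.linear_1(sample)))


emb = TimestepEmbeddingSketch(in_channels=320, time_embed_dim=1280, cond_proj_dim=256)
t_emb = torch.randn(2, 320)  # sinusoidal timestep embedding
w_emb = torch.randn(2, 256)  # guidance-scale embedding (the new `timestep_cond`)
print(emb(t_emb, w_emb).shape)  # torch.Size([2, 1280])
```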
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index 64baeea910b8..2d8c8869c23c 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -27,6 +27,7 @@
     ControlNetModel,
     DDIMScheduler,
     EulerDiscreteScheduler,
+    LCMScheduler,
     StableDiffusionControlNetPipeline,
     UNet2DConditionModel,
 )
@@ -116,7 +117,7 @@ class ControlNetPipelineFastTests(
     image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

-    def get_dummy_components(self):
+    def get_dummy_components(self, time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(4, 8),
@@ -128,6 +129,7 @@ def get_dummy_components(self):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
             norm_num_groups=1,
+            time_cond_proj_dim=time_cond_proj_dim,
         )
         torch.manual_seed(0)
         controlnet = ControlNetModel(
@@ -221,6 +223,28 @@ def test_xformers_attention_forwardGenerator_pass(self):
     def test_inference_batch_single_identical(self):
         self._test_inference_batch_single_identical(expected_max_diff=2e-3)

+    def test_controlnet_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionControlNetPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array(
+            [0.52700454, 0.3930534, 0.25509018, 0.7132304, 0.53696585, 0.46568912, 0.7095368, 0.7059624, 0.4744786]
+        )
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 class StableDiffusionMultiControlNetPipelineFastTests(
     PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index be786ebe3000..36ddee36eb52 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -24,6 +24,7 @@
     AutoencoderKL,
     ControlNetModel,
     EulerDiscreteScheduler,
+    LCMScheduler,
     StableDiffusionXLControlNetPipeline,
     UNet2DConditionModel,
 )
@@ -62,7 +63,7 @@ class StableDiffusionXLControlNetPipelineFastTests(
     image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
     image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS

-    def get_dummy_components(self):
+    def get_dummy_components(self, time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
@@ -80,6 +81,7 @@ def get_dummy_components(self):
             transformer_layers_per_block=(1, 2),
             projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
             cross_attention_dim=64,
+            time_cond_proj_dim=time_cond_proj_dim,
         )
         torch.manual_seed(0)
         controlnet = ControlNetModel(
@@ -330,6 +332,26 @@ def test_controlnet_sdxl_guess(self):
         # make sure that it's equal
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-4

+    def test_controlnet_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLControlNetPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.7799, 0.614, 0.6162, 0.7082, 0.6662, 0.5833, 0.4148, 0.5182, 0.4866])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 class StableDiffusionXLMultiControlNetPipelineFastTests(
     PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
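The tests above exercise the new path by building the dummy UNet with `time_cond_proj_dim=256` and swapping in `LCMScheduler`. The same thing can be checked directly at the model level; a small smoke test in the spirit of `get_dummy_components` (the `down_block_types` and `sample_size` values here are assumed, as they fall outside the hunks shown):

```python
import torch
from diffusers import UNet2DConditionModel

# Tiny config modeled on the dummy components; some arguments are assumptions.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    block_out_channels=(4, 8),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    norm_num_groups=1,
    time_cond_proj_dim=256,  # enables the guidance-embedding projection
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 8, 32)
timestep_cond = torch.randn(1, 256)  # what the pipelines now pass when the projection exists

out = unet(sample, 1, encoder_hidden_states=encoder_hidden_states, timestep_cond=timestep_cond).sample
print(out.shape)  # torch.Size([1, 4, 32, 32])
```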
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
index 2dcfb9d3612d..2252c8ef8e99 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_adapter.py
@@ -25,6 +25,7 @@
 import diffusers
 from diffusers import (
     AutoencoderKL,
+    LCMScheduler,
     MultiAdapter,
     PNDMScheduler,
     StableDiffusionAdapterPipeline,
@@ -56,7 +57,7 @@ class AdapterTests:
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

-    def get_dummy_components(self, adapter_type):
+    def get_dummy_components(self, adapter_type, time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
@@ -67,6 +68,7 @@ def get_dummy_components(self, adapter_type):
             down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
+            time_cond_proj_dim=time_cond_proj_dim,
         )
         scheduler = PNDMScheduler(skip_prk_steps=True)
         torch.manual_seed(0)
@@ -264,13 +266,13 @@ def test_inference_batch_single_identical(self):
     @parameterized.expand(
         [
             # (dim=264) The internal feature map will be 33x33 after initial pixel unshuffling (downscaled x8).
-            ((4 * 8 + 1) * 8),
+            (((4 * 8 + 1) * 8),),
             # (dim=272) The internal feature map will be 17x17 after the first T2I down block (downscaled x16).
-            ((4 * 4 + 1) * 16),
+            (((4 * 4 + 1) * 16),),
             # (dim=288) The internal feature map will be 9x9 after the second T2I down block (downscaled x32).
-            ((4 * 2 + 1) * 32),
+            (((4 * 2 + 1) * 32),),
             # (dim=320) The internal feature map will be 5x5 after the third T2I down block (downscaled x64).
-            ((4 * 1 + 1) * 64),
+            (((4 * 1 + 1) * 64),),
         ]
     )
     def test_multiple_image_dimensions(self, dim):
@@ -292,10 +294,30 @@ def test_multiple_image_dimensions(self, dim):

         assert image.shape == (1, dim, dim, 3)

+    def test_adapter_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionAdapterPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.4535, 0.5493, 0.4359, 0.5452, 0.6086, 0.4441, 0.5544, 0.501, 0.4859])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 class StableDiffusionFullAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
-    def get_dummy_components(self):
-        return super().get_dummy_components("full_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("full_adapter", time_cond_proj_dim=time_cond_proj_dim)

     def get_dummy_components_with_full_downscaling(self):
         return super().get_dummy_components_with_full_downscaling("full_adapter")
@@ -317,8 +339,8 @@ def test_stable_diffusion_adapter_default_case(self):


 class StableDiffusionLightAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
-    def get_dummy_components(self):
-        return super().get_dummy_components("light_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("light_adapter", time_cond_proj_dim=time_cond_proj_dim)

     def get_dummy_components_with_full_downscaling(self):
         return super().get_dummy_components_with_full_downscaling("light_adapter")
@@ -340,8 +362,8 @@ def test_stable_diffusion_adapter_default_case(self):


 class StableDiffusionMultiAdapterPipelineFastTests(AdapterTests, PipelineTesterMixin, unittest.TestCase):
-    def get_dummy_components(self):
-        return super().get_dummy_components("multi_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)

     def get_dummy_components_with_full_downscaling(self):
         return super().get_dummy_components_with_full_downscaling("multi_adapter")
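A side note on the `@parameterized.expand` change in these test files: the old entries such as `((4 * 8 + 1) * 8),` are just parenthesized integers, so each case was a bare value rather than an argument tuple; the extra trailing comma is what makes a genuine one-element tuple. In plain Python:

```python
dim = (4 * 8 + 1) * 8
print(type((dim)), (dim))    # <class 'int'> 264   -> parentheses are only grouping
print(type((dim,)), (dim,))  # <class 'tuple'> (264,) -> an explicit single-argument case
```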
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index 1c83e80f3d6f..daf46000a1e0 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -26,6 +26,7 @@
 from diffusers import (
     AutoencoderKL,
     EulerDiscreteScheduler,
+    LCMScheduler,
     MultiAdapter,
     StableDiffusionXLAdapterPipeline,
     T2IAdapter,
@@ -59,7 +60,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS

-    def get_dummy_components(self, adapter_type="full_adapter_xl"):
+    def get_dummy_components(self, adapter_type="full_adapter_xl", time_cond_proj_dim=None):
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
@@ -77,6 +78,7 @@ def get_dummy_components(self, adapter_type="full_adapter_xl"):
             transformer_layers_per_block=(1, 2),
             projection_class_embeddings_input_dim=80,  # 6 * 8 + 32
             cross_attention_dim=64,
+            time_cond_proj_dim=time_cond_proj_dim,
         )
         scheduler = EulerDiscreteScheduler(
             beta_start=0.00085,
@@ -309,9 +311,9 @@ def test_stable_diffusion_adapter_default_case(self):
     @parameterized.expand(
         [
             # (dim=144) The internal feature map will be 9x9 after initial pixel unshuffling (downscaled x16).
-            ((4 * 2 + 1) * 16),
+            (((4 * 2 + 1) * 16),),
             # (dim=160) The internal feature map will be 5x5 after the first T2I down block (downscaled x32).
-            ((4 * 1 + 1) * 32),
+            (((4 * 1 + 1) * 32),),
         ]
     )
     def test_multiple_image_dimensions(self, dim):
@@ -367,12 +369,32 @@ def test_total_downscale_factor(self, adapter_type):
     def test_save_load_optional_components(self):
         return self._test_save_load_optional_components()

+    def test_adapter_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLAdapterPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.5425, 0.5385, 0.4964, 0.5045, 0.6149, 0.4974, 0.5469, 0.5332, 0.5426])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 class StableDiffusionXLMultiAdapterPipelineFastTests(
     StableDiffusionXLAdapterPipelineFastTests, PipelineTesterMixin, unittest.TestCase
 ):
-    def get_dummy_components(self):
-        return super().get_dummy_components("multi_adapter")
+    def get_dummy_components(self, time_cond_proj_dim=None):
+        return super().get_dummy_components("multi_adapter", time_cond_proj_dim=time_cond_proj_dim)

     def get_dummy_components_with_full_downscaling(self):
         return super().get_dummy_components_with_full_downscaling("multi_adapter")
@@ -569,6 +591,29 @@ def test_inference_batch_single_identical(
         if test_mean_pixel_difference:
             assert_mean_pixel_difference(output_batch[0][0], output[0][0])

+    def test_adapter_sdxl_lcm(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components(time_cond_proj_dim=256)
+        sd_pipe = StableDiffusionXLAdapterPipeline(**components)
+        sd_pipe.scheduler = LCMScheduler.from_config(sd_pipe.scheduler.config)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 64, 64, 3)
+        expected_slice = np.array([0.5313, 0.5375, 0.4942, 0.5021, 0.6142, 0.4968, 0.5434, 0.5311, 0.5448])
+
+        debug = [str(round(i, 4)) for i in image_slice.flatten().tolist()]
+        print(",".join(debug))
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+

 @slow
 @require_torch_gpu