diff --git a/examples/community/cogvideox_ddim_inversion.py b/examples/community/cogvideox_ddim_inversion.py
index 36d95901c65f..a964d4618025 100644
--- a/examples/community/cogvideox_ddim_inversion.py
+++ b/examples/community/cogvideox_ddim_inversion.py
@@ -522,14 +522,15 @@ def sample(
         timestep = t.expand(latent_model_input.shape[0])

         # predict noise model_output
-        noise_pred = self.transformer(
-            hidden_states=latent_model_input,
-            encoder_hidden_states=prompt_embeds,
-            timestep=timestep,
-            image_rotary_emb=image_rotary_emb,
-            attention_kwargs=attention_kwargs,
-            return_dict=False,
-        )[0]
+        with self.transformer.cache_context("cond"):
+            noise_pred = self.transformer(
+                hidden_states=latent_model_input,
+                encoder_hidden_states=prompt_embeds,
+                timestep=timestep,
+                image_rotary_emb=image_rotary_emb,
+                attention_kwargs=attention_kwargs,
+                return_dict=False,
+            )[0]
         noise_pred = noise_pred.float()

         if reference_latents is not None:  # Recover the original batch size
diff --git a/examples/community/pipeline_flux_differential_img2img.py b/examples/community/pipeline_flux_differential_img2img.py
index 3677e73136f7..563f50226281 100644
--- a/examples/community/pipeline_flux_differential_img2img.py
+++ b/examples/community/pipeline_flux_differential_img2img.py
@@ -954,17 +954,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/examples/community/pipeline_flux_kontext_multiple_images.py b/examples/community/pipeline_flux_kontext_multiple_images.py
index 9e6ae427dbfa..ee0e8497f23c 100644
--- a/examples/community/pipeline_flux_kontext_multiple_images.py
+++ b/examples/community/pipeline_flux_kontext_multiple_images.py
@@ -1150,33 +1150,35 @@ def __call__(
                 latent_model_input = torch.cat([latents, image_latents], dim=1)
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-                noise_pred = noise_pred[:, : latents.size(1)]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                         img_ids=latent_ids,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+                noise_pred = noise_pred[:, : latents.size(1)]
+
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=negative_text_ids,
+                            img_ids=latent_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     neg_noise_pred = neg_noise_pred[:, : latents.size(1)]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

diff --git a/examples/community/pipeline_flux_rf_inversion.py b/examples/community/pipeline_flux_rf_inversion.py
index 2cd6eb088cd8..c9e1d2cbfc4e 100644
--- a/examples/community/pipeline_flux_rf_inversion.py
+++ b/examples/community/pipeline_flux_rf_inversion.py
@@ -888,17 +888,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 latents_dtype = latents.dtype
                 if do_rf_inversion:
@@ -1058,17 +1059,18 @@ def invert(
                 timestep = torch.tensor(t_i, dtype=Y_t.dtype, device=device).repeat(batch_size)

                 # get the unconditional vector field
-                u_t_i = self.transformer(
-                    hidden_states=Y_t,
-                    timestep=timestep,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("uncond"):
+                    u_t_i = self.transformer(
+                        hidden_states=Y_t,
+                        timestep=timestep,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # get the conditional vector field
                 u_t_i_cond = (y_1 - Y_t) / (1 - t_i)
diff --git a/examples/community/pipeline_flux_semantic_guidance.py b/examples/community/pipeline_flux_semantic_guidance.py
index 74cd5c6981b0..6917e2da64a6 100644
--- a/examples/community/pipeline_flux_semantic_guidance.py
+++ b/examples/community/pipeline_flux_semantic_guidance.py
@@ -1135,23 +1135,25 @@ def __call__(
                 else:
                     guidance = None

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 if enable_edit_guidance and max_edit_cooldown_steps >= i >= min_edit_warmup_steps:
                     noise_pred_edit_concepts = []
                     for e_embed, pooled_e_embed, e_text_id in zip(
                         editing_prompts_embeds, pooled_editing_prompt_embeds, edit_text_ids
                     ):
+                        # TODO-context
                         noise_pred_edit = self.transformer(
                             hidden_states=latents,
                             timestep=timestep / 1000,
@@ -1168,17 +1170,18 @@ def __call__(
                 if do_true_cfg:
                     if negative_image_embeds is not None:
                         self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    noise_pred_uncond = self.transformer(
-                        hidden_states=latents,
-                        timestep=timestep / 1000,
-                        guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=text_ids,
-                        img_ids=latent_image_ids,
-                        joint_attention_kwargs=self.joint_attention_kwargs,
-                        return_dict=False,
-                    )[0]
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_guidance = true_cfg_scale * (noise_pred - noise_pred_uncond)
                 else:
                     noise_pred_uncond = noise_pred
diff --git a/examples/community/pipeline_flux_with_cfg.py b/examples/community/pipeline_flux_with_cfg.py
index 5bc13f7e5e11..f396d45d9800 100644
--- a/examples/community/pipeline_flux_with_cfg.py
+++ b/examples/community/pipeline_flux_with_cfg.py
@@ -815,17 +815,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latent_model_input.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 if do_true_cfg:
                     neg_noise_pred, noise_pred = noise_pred.chunk(2)
diff --git a/examples/community/pipeline_hunyuandit_differential_img2img.py b/examples/community/pipeline_hunyuandit_differential_img2img.py
index fb7a4cb5e472..88d46e8285ac 100644
--- a/examples/community/pipeline_hunyuandit_differential_img2img.py
+++ b/examples/community/pipeline_hunyuandit_differential_img2img.py
@@ -1074,18 +1074,19 @@ def __call__(
                 )

                 # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t_expand,
-                    encoder_hidden_states=prompt_embeds,
-                    text_embedding_mask=prompt_attention_mask,
-                    encoder_hidden_states_t5=prompt_embeds_2,
-                    text_embedding_mask_t5=prompt_attention_mask_2,
-                    image_meta_size=add_time_ids,
-                    style=style,
-                    image_rotary_emb=image_rotary_emb,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        t_expand,
+                        encoder_hidden_states=prompt_embeds,
+                        text_embedding_mask=prompt_attention_mask,
+                        encoder_hidden_states_t5=prompt_embeds_2,
+                        text_embedding_mask_t5=prompt_attention_mask_2,
+                        image_meta_size=add_time_ids,
+                        style=style,
+                        image_rotary_emb=image_rotary_emb,
+                        return_dict=False,
+                    )[0]

                 noise_pred, _ = noise_pred.chunk(2, dim=1)

diff --git a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
index 1803cf60cc4b..6f6a9b4d6c46 100644
--- a/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
+++ b/examples/community/pipeline_stable_diffusion_3_differential_img2img.py
@@ -918,13 +918,14 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
diff --git a/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py b/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
index d9cee800e8ad..2e411ac86004 100644
--- a/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
+++ b/examples/community/pipeline_stable_diffusion_3_instruct_pix2pix.py
@@ -1178,14 +1178,15 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])
                 scaled_latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

-                noise_pred = self.transformer(
-                    hidden_states=scaled_latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=scaled_latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
@@ -1204,6 +1205,7 @@ def __call__(
                     if skip_guidance_layers is not None and should_skip_layers:
                         timestep = t.expand(latents.shape[0])
                         latent_model_input = latents
+                        # TODO-context
                         noise_pred_skip_layers = self.transformer(
                             hidden_states=latent_model_input,
                             timestep=timestep,
diff --git a/examples/community/pipeline_stg_cogvideox.py b/examples/community/pipeline_stg_cogvideox.py
index bdb6aecc30c3..9e29ba2f58c5 100644
--- a/examples/community/pipeline_stg_cogvideox.py
+++ b/examples/community/pipeline_stg_cogvideox.py
@@ -793,14 +793,15 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/examples/community/pipeline_stg_hunyuan_video.py b/examples/community/pipeline_stg_hunyuan_video.py
index 028d54d047e4..ebdcf8f61862 100644
--- a/examples/community/pipeline_stg_hunyuan_video.py
+++ b/examples/community/pipeline_stg_hunyuan_video.py
@@ -751,17 +751,17 @@ def __call__(
                     self.transformer.transformer_blocks[i].forward = types.MethodType(
                         forward_without_stg, self.transformer.transformer_blocks[i]
                     )
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    pooled_projections=pooled_prompt_embeds,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        pooled_projections=pooled_prompt_embeds,
+                        guidance=guidance,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 if self.do_spatio_temporal_guidance:
                     for i in stg_applied_layers_idx:
@@ -769,6 +769,7 @@ def __call__(
                             forward_with_stg, self.transformer.transformer_blocks[i]
                         )

+                    # TODO-context
                     noise_pred_perturb = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
diff --git a/examples/community/pipeline_stg_ltx.py b/examples/community/pipeline_stg_ltx.py
index 70069a33f5d9..8943b7a8529f 100644
--- a/examples/community/pipeline_stg_ltx.py
+++ b/examples/community/pipeline_stg_ltx.py
@@ -791,18 +791,19 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
diff --git a/examples/community/pipeline_stg_ltx_image2video.py b/examples/community/pipeline_stg_ltx_image2video.py
index c32805e1419f..f03e7ca4d133 100644
--- a/examples/community/pipeline_stg_ltx_image2video.py
+++ b/examples/community/pipeline_stg_ltx_image2video.py
@@ -864,18 +864,19 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])
                 timestep = timestep.unsqueeze(-1) * (1 - conditioning_mask)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    num_frames=latent_num_frames,
-                    height=latent_height,
-                    width=latent_width,
-                    rope_interpolation_scale=rope_interpolation_scale,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        num_frames=latent_num_frames,
+                        height=latent_height,
+                        width=latent_width,
+                        rope_interpolation_scale=rope_interpolation_scale,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 if self.do_classifier_free_guidance and not self.do_spatio_temporal_guidance:
diff --git a/examples/community/pipeline_stg_mochi.py b/examples/community/pipeline_stg_mochi.py
index ad9317f6bc9d..1e7b115ef2e6 100644
--- a/examples/community/pipeline_stg_mochi.py
+++ b/examples/community/pipeline_stg_mochi.py
@@ -777,14 +777,15 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    encoder_attention_mask=prompt_attention_mask,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        encoder_attention_mask=prompt_attention_mask,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # Mochi CFG + Sampling runs in FP32
                 noise_pred = noise_pred.to(torch.float32)
diff --git a/examples/community/pipeline_stg_wan.py b/examples/community/pipeline_stg_wan.py
index 39f208bad7c5..c6f30d0770a2 100644
--- a/examples/community/pipeline_stg_wan.py
+++ b/examples/community/pipeline_stg_wan.py
@@ -579,26 +579,30 @@ def __call__(
                 for idx, block in enumerate(self.transformer.blocks):
                     block.forward = types.MethodType(forward_without_stg, block)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]

                 if self.do_spatio_temporal_guidance:
                     for idx, block in enumerate(self.transformer.blocks):
                         if idx in stg_applied_layers_idx:
                             block.forward = types.MethodType(forward_with_stg, block)
+
+                    # TODO-context
                     noise_perturb = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
diff --git a/examples/community/pipline_flux_fill_controlnet_Inpaint.py b/examples/community/pipline_flux_fill_controlnet_Inpaint.py
index cc642a767f87..5095ce61c6fa 100644
--- a/examples/community/pipline_flux_fill_controlnet_Inpaint.py
+++ b/examples/community/pipline_flux_fill_controlnet_Inpaint.py
@@ -1246,20 +1246,21 @@ def __call__(
                 masked_image_latents_fill = torch.cat((masked_latents_fill, mask_fill), dim=-1)
                 latent_model_input = torch.cat([latents, masked_image_latents_fill], dim=2)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_block_samples=controlnet_block_samples,
-                    controlnet_single_block_samples=controlnet_single_block_samples,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                    controlnet_blocks_repeat=controlnet_blocks_repeat,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_block_samples=controlnet_block_samples,
+                        controlnet_single_block_samples=controlnet_single_block_samples,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        controlnet_blocks_repeat=controlnet_blocks_repeat,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/examples/research_projects/anytext/anytext.py b/examples/research_projects/anytext/anytext.py
index 7ae6ae57c22a..6deb78d6106f 100644
--- a/examples/research_projects/anytext/anytext.py
+++ b/examples/research_projects/anytext/anytext.py
@@ -697,7 +697,8 @@ def forward(self, text, **kwargs):
        z_list = []
        for tokens in tokens_list:
            tokens = tokens.to(self.device)
-            _z = self.transformer(input_ids=tokens, **kwargs)
+            with self.transformer.cache_context("cond"):
+                _z = self.transformer(input_ids=tokens, **kwargs)
            z_list += [_z]
        return torch.cat(z_list, dim=1)

diff --git a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
index 89228983d4d8..3711ac3349d6 100644
--- a/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
+++ b/examples/research_projects/pixart/pipeline_pixart_alpha_controlnet.py
@@ -1042,17 +1042,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 current_timestep = current_timestep.expand(latent_model_input.shape[0])

-                # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=current_timestep,
-                    controlnet_cond=image_latents,
-                    # rc todo: controlnet_conditioning_scale=1.0,
-                    added_cond_kwargs=added_cond_kwargs,
-                    return_dict=False,
-                )[0]
+                context = "uncond,cond" if do_classifier_free_guidance else "cond"
+                with self.transformer.cache_context(context):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        timestep=current_timestep,
+                        controlnet_cond=image_latents,
+                        # rc todo: controlnet_conditioning_scale=1.0,
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/examples/research_projects/sana/train_sana_sprint_diffusers.py b/examples/research_projects/sana/train_sana_sprint_diffusers.py
index d127fee5fd0d..b5eb443871d3 100644
--- a/examples/research_projects/sana/train_sana_sprint_diffusers.py
+++ b/examples/research_projects/sana/train_sana_sprint_diffusers.py
@@ -813,13 +813,14 @@ def get_features(module, input, output):
            if i in self.block_hooks:
                hooks.append(block.register_forward_hook(get_features))

-        self.transformer(
-            hidden_states=hidden_states,
-            timestep=timestep,
-            encoder_hidden_states=encoder_hidden_states,
-            return_logvar=False,
-            **kwargs,
-        )
+        with self.transformer.cache_context("cond"):
+            self.transformer(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                return_logvar=False,
+                **kwargs,
+            )

        for hook in hooks:
            hook.remove()
diff --git a/src/diffusers/pipelines/allegro/pipeline_allegro.py b/src/diffusers/pipelines/allegro/pipeline_allegro.py
index 3be0129088fb..c9931d7b7358 100644
--- a/src/diffusers/pipelines/allegro/pipeline_allegro.py
+++ b/src/diffusers/pipelines/allegro/pipeline_allegro.py
@@ -929,14 +929,15 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        encoder_attention_mask=prompt_attention_mask,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/amused/pipeline_amused.py b/src/diffusers/pipelines/amused/pipeline_amused.py
index 131e34d1a4a1..37071168c2c9 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused.py
@@ -281,13 +281,14 @@ def __call__(
            else:
                model_input = latents

-            model_output = self.transformer(
-                model_input,
-                micro_conds=micro_conds,
-                pooled_text_emb=prompt_embeds,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-            )
+            with self.transformer.cache_context("cond"):
+                model_output = self.transformer(
+                    model_input,
+                    micro_conds=micro_conds,
+                    pooled_text_emb=prompt_embeds,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )

            if guidance_scale > 1.0:
                uncond_logits, cond_logits = model_output.chunk(2)
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
index a122c12236dd..534968ee0613 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_img2img.py
@@ -309,13 +309,14 @@ def __call__(
            else:
                model_input = latents

-            model_output = self.transformer(
-                model_input,
-                micro_conds=micro_conds,
-                pooled_text_emb=prompt_embeds,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-            )
+            with self.transformer.cache_context("cond"):
+                model_output = self.transformer(
+                    model_input,
+                    micro_conds=micro_conds,
+                    pooled_text_emb=prompt_embeds,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )

            if guidance_scale > 1.0:
                uncond_logits, cond_logits = model_output.chunk(2)
diff --git a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
index f4bd4944ff9a..20017000069f 100644
--- a/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
+++ b/src/diffusers/pipelines/amused/pipeline_amused_inpaint.py
@@ -339,13 +339,14 @@ def __call__(
            else:
                model_input = latents

-            model_output = self.transformer(
-                model_input,
-                micro_conds=micro_conds,
-                pooled_text_emb=prompt_embeds,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-            )
+            with self.transformer.cache_context("cond"):
+                model_output = self.transformer(
+                    model_input,
+                    micro_conds=micro_conds,
+                    pooled_text_emb=prompt_embeds,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                )

            if guidance_scale > 1.0:
                uncond_logits, cond_logits = model_output.chunk(2)
diff --git a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
index bb9884e41381..d5fbb57e6b3c 100644
--- a/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
+++ b/src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py
@@ -615,13 +615,14 @@ def __call__(
                 timestep = timestep.to(latents.device, dtype=latents.dtype)

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        return_dict=False,
+                        attention_kwargs=self.attention_kwargs,
+                    )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/bria/pipeline_bria.py b/src/diffusers/pipelines/bria/pipeline_bria.py
index a22a756005ac..86383af7f814 100644
--- a/src/diffusers/pipelines/bria/pipeline_bria.py
+++ b/src/diffusers/pipelines/bria/pipeline_bria.py
@@ -662,15 +662,16 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # This is predicts "v" from flow-matching or eps from diffusion
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=self.attention_kwargs,
-                    return_dict=False,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        attention_kwargs=self.attention_kwargs,
+                        return_dict=False,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
index 8fd29756b290..99289cff48e4 100644
--- a/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
+++ b/src/diffusers/pipelines/bria_fibo/pipeline_bria_fibo.py
@@ -705,16 +705,17 @@ def __call__(
                 )

                 # This is predicts "v" from flow-matching or eps from diffusion
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    text_encoder_layers=prompt_layers,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        text_encoder_layers=prompt_layers,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                    )[0]

                 # perform guidance
                 if guidance_scale > 1:
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma.py b/src/diffusers/pipelines/chroma/pipeline_chroma.py
index ed6c2c2105b6..c2a1eb679c09 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma.py
@@ -906,30 +906,33 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    attention_mask=attention_mask,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                         img_ids=latent_image_ids,
-                        attention_mask=negative_attention_mask,
+                        attention_mask=attention_mask,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=negative_text_ids,
+                            img_ids=latent_image_ids,
+                            attention_mask=negative_attention_mask,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = neg_noise_pred + guidance_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
index 470c746e4146..f899cee74040 100644
--- a/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
+++ b/src/diffusers/pipelines/chroma/pipeline_chroma_img2img.py
@@ -989,31 +989,33 @@ def __call__(
                 if image_embeds is not None:
                     self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    attention_mask=attention_mask,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        txt_ids=negative_text_ids,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
                         img_ids=latent_image_ids,
-                        attention_mask=negative_attention_mask,
+                        attention_mask=attention_mask,
                         joint_attention_kwargs=self.joint_attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            txt_ids=negative_text_ids,
+                            img_ids=latent_image_ids,
+                            attention_mask=negative_attention_mask,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py b/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
index 79f6580fbed6..d319f1ede04e 100644
--- a/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
+++ b/src/diffusers/pipelines/chronoedit/pipeline_chronoedit.py
@@ -680,24 +680,26 @@ def __call__(
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
index 304a5c5ad00b..4fb56babbc3e 100644
--- a/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
+++ b/src/diffusers/pipelines/cogview3/pipeline_cogview3plus.py
@@ -616,15 +616,16 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        original_size=original_size,
+                        target_size=target_size,
+                        crop_coords=crops_coords_top_left,
+                        return_dict=False,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
index e26b7ba415de..f0f4a01492ad 100644
--- a/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
+++ b/src/diffusers/pipelines/cogview4/pipeline_cogview4_control.py
@@ -668,22 +668,10 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0])

-                noise_pred_cond = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    original_size=original_size,
-                    target_size=target_size,
-                    crop_coords=crops_coords_top_left,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                # perform guidance
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred_cond = self.transformer(
                         hidden_states=latent_model_input,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         timestep=timestep,
                         original_size=original_size,
                         target_size=target_size,
@@ -692,6 +680,20 @@ def __call__(
                         return_dict=False,
                     )[0]

+                # perform guidance
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            timestep=timestep,
+                            original_size=original_size,
+                            target_size=target_size,
+                            crop_coords=crops_coords_top_left,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+
                     noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
                 else:
                     noise_pred = noise_pred_cond
diff --git a/src/diffusers/pipelines/consisid/pipeline_consisid.py b/src/diffusers/pipelines/consisid/pipeline_consisid.py
index 3e6c149d7f80..491acd680bc0 100644
--- a/src/diffusers/pipelines/consisid/pipeline_consisid.py
+++ b/src/diffusers/pipelines/consisid/pipeline_consisid.py
@@ -902,16 +902,17 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
-                    image_rotary_emb=image_rotary_emb,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                    id_vit_hidden=id_vit_hidden,
-                    id_cond=id_cond,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        encoder_hidden_states=prompt_embeds,
+                        timestep=timestep,
+                        image_rotary_emb=image_rotary_emb,
+                        attention_kwargs=attention_kwargs,
+                        return_dict=False,
+                        id_vit_hidden=id_vit_hidden,
+                        id_cond=id_cond,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
index 2b5684de9511..455b370a7dad 100644
--- a/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py
@@ -983,19 +983,20 @@ def __call__(
                 )[0]

                 # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t_expand,
-                    encoder_hidden_states=prompt_embeds,
-                    text_embedding_mask=prompt_attention_mask,
-                    encoder_hidden_states_t5=prompt_embeds_2,
-                    text_embedding_mask_t5=prompt_attention_mask_2,
-                    image_meta_size=add_time_ids,
-                    style=style,
-                    image_rotary_emb=image_rotary_emb,
-                    return_dict=False,
-                    controlnet_block_samples=control_block_samples,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        t_expand,
+                        encoder_hidden_states=prompt_embeds,
+                        text_embedding_mask=prompt_attention_mask,
+                        encoder_hidden_states_t5=prompt_embeds_2,
+                        text_embedding_mask_t5=prompt_attention_mask_2,
+                        image_meta_size=add_time_ids,
+                        style=style,
+                        image_rotary_emb=image_rotary_emb,
+                        return_dict=False,
+                        controlnet_block_samples=control_block_samples,
+                    )[0]

                 noise_pred, _ = noise_pred.chunk(2, dim=1)

diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
index d605eac1f2b1..a518339aa6ff 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py
@@ -1182,15 +1182,16 @@ def __call__(
                     return_dict=False,
                 )[0]

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    block_controlnet_hidden_states=control_block_samples,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        block_controlnet_hidden_states=control_block_samples,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
index 9d0158c6b654..a8da1d8b0a3a 100644
--- a/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py
@@ -1201,15 +1201,16 @@ def __call__(
                     return_dict=False,
                 )[0]

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    block_controlnet_hidden_states=control_block_samples,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        block_controlnet_hidden_states=control_block_samples,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py
index 66490c2be159..19ea0521c801 100644
--- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py
@@ -588,23 +588,25 @@ def __call__(
                 latent_model_input = latents * c_in
                 latent_model_input = latent_model_input.to(transformer_dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    padding_mask=padding_mask,
-                    return_dict=False,
-                )[0]
-                noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
-
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         padding_mask=padding_mask,
                         return_dict=False,
                     )[0]
+                noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            padding_mask=padding_mask,
+                            return_dict=False,
+                        )[0]
                     noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
                     noise_pred = noise_pred + self.guidance_scale * (noise_pred - noise_pred_uncond)

diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py
index 23a74ad00f93..42c298ddb9e6 100644
--- a/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py
@@ -696,15 +696,16 @@ def __call__(
                 cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
                 cond_timestep = cond_timestep.to(transformer_dtype)

-                noise_pred = self.transformer(
-                    hidden_states=cond_latent,
-                    timestep=cond_timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    fps=fps,
-                    condition_mask=cond_mask,
-                    padding_mask=padding_mask,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=cond_latent,
+                        timestep=cond_timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        fps=fps,
+                        condition_mask=cond_mask,
+                        padding_mask=padding_mask,
+                        return_dict=False,
+                    )[0]
                 noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(transformer_dtype)
                 noise_pred = cond_indicator * conditioning_latents + (1 - cond_indicator) * noise_pred

@@ -715,15 +716,16 @@ def __call__(
                     uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
                     uncond_timestep = uncond_timestep.to(transformer_dtype)

-                    noise_pred_uncond = self.transformer(
-                        hidden_states=uncond_latent,
-                        timestep=uncond_timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        fps=fps,
-                        condition_mask=uncond_mask,
-                        padding_mask=padding_mask,
-                        return_dict=False,
-                    )[0]
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=uncond_latent,
+                            timestep=uncond_timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            fps=fps,
+                            condition_mask=uncond_mask,
+                            padding_mask=padding_mask,
+                            return_dict=False,
+                        )[0]
                     noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(transformer_dtype)
                     noise_pred_uncond = (
                         uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * noise_pred_uncond
                     )
diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
index f0aa1ecf0e0f..1ed999a11b83 100644
--- a/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py
@@ -566,25 +566,27 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                 latent_model_input = latent_model_input.to(transformer_dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    fps=fps,
-                    padding_mask=padding_mask,
-                    return_dict=False,
-                )[0]
-
-                sample = latents
-                if self.do_classifier_free_guidance:
-                    noise_pred_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         fps=fps,
                         padding_mask=padding_mask,
                         return_dict=False,
                     )[0]
+
+                sample = latents
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            fps=fps,
+                            padding_mask=padding_mask,
+                            return_dict=False,
+                        )[0]
                     noise_pred = torch.cat([noise_pred_uncond, noise_pred])
                     sample = torch.cat([sample, sample])

diff --git a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
index cd5a734cc311..59b90d5c89d8 100644
--- a/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
+++ b/src/diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py
@@ -707,15 +707,16 @@ def __call__(
                 cond_latent = self.scheduler.scale_model_input(cond_latent, t)
                 cond_latent = cond_latent.to(transformer_dtype)

-                noise_pred = self.transformer(
-                    hidden_states=cond_latent,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    fps=fps,
-                    condition_mask=cond_mask,
-                    padding_mask=padding_mask,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=cond_latent,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        fps=fps,
+                        condition_mask=cond_mask,
+                        padding_mask=padding_mask,
+                        return_dict=False,
+                    )[0]

                 sample = latents
                 if self.do_classifier_free_guidance:
@@ -727,15 +728,16 @@ def __call__(
                     uncond_latent = self.scheduler.scale_model_input(uncond_latent, t)
                     uncond_latent = uncond_latent.to(transformer_dtype)

-                    noise_pred_uncond = self.transformer(
-                        hidden_states=uncond_latent,
-                        timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
-                        fps=fps,
-                        condition_mask=uncond_mask,
-                        padding_mask=padding_mask,
-                        return_dict=False,
-                    )[0]
+                    with self.transformer.cache_context("uncond"):
+                        noise_pred_uncond = self.transformer(
+                            hidden_states=uncond_latent,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            fps=fps,
+                            condition_mask=uncond_mask,
+                            padding_mask=padding_mask,
+                            return_dict=False,
+                        )[0]
                     noise_pred = torch.cat([noise_pred_uncond, noise_pred])
                     sample = torch.cat([sample, sample])

diff --git a/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py b/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py
index e8617a54b691..99dfb1e829db 100644
--- a/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py
+++ b/src/diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py
@@ -266,7 +266,10 @@ def __call__(
            # predict the un-noised image
            # model_output == `log_p_x_0`
-            model_output = self.transformer(latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t).sample
+            with self.transformer.cache_context("cond"):
+                model_output = self.transformer(
+                    latent_model_input, encoder_hidden_states=prompt_embeds, timestep=t
+                ).sample

            if do_classifier_free_guidance:
                model_output_uncond, model_output_text = model_output.chunk(2)
diff --git a/src/diffusers/pipelines/dit/pipeline_dit.py b/src/diffusers/pipelines/dit/pipeline_dit.py
index 68ff6c9b559a..f1c8d204eb95 100644
--- a/src/diffusers/pipelines/dit/pipeline_dit.py
+++ b/src/diffusers/pipelines/dit/pipeline_dit.py
@@ -200,9 +200,10 @@ def __call__(
            # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
            timesteps = timesteps.expand(latent_model_input.shape[0])
            # predict noise model_output
-            noise_pred = self.transformer(
-                latent_model_input, timestep=timesteps, class_labels=class_labels_input
-            ).sample
+            with self.transformer.cache_context("cond"):
+                noise_pred = self.transformer(
+                    latent_model_input, timestep=timesteps, class_labels=class_labels_input
+                ).sample

            # perform guidance
            if guidance_scale > 1:
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
index 92239c0d32f0..25a3c560b2eb 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate.py
@@ -716,12 +716,13 @@ def __call__(
                 )

                 # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t_expand,
-                    encoder_hidden_states=prompt_embeds,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        t_expand,
+                        encoder_hidden_states=prompt_embeds,
+                        return_dict=False,
+                    )[0]

                 if noise_pred.size()[1] != self.vae.config.latent_channels:
                     noise_pred, _ = noise_pred.chunk(2, dim=1)
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
index f74a11f87d75..5611f39d9c58 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py
@@ -940,13 +940,14 @@ def __call__(
                     dtype=latent_model_input.dtype
                 )

                 # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t_expand,
-                    encoder_hidden_states=prompt_embeds,
-                    control_latents=control_latents,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        t_expand,
+                        encoder_hidden_states=prompt_embeds,
+                        control_latents=control_latents,
+                        return_dict=False,
+                    )[0]

                 if noise_pred.size()[1] != self.vae.config.latent_channels:
                     noise_pred, _ = noise_pred.chunk(2, dim=1)
diff --git a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
index b16ef92d8e6b..c5313ac48e55 100755
--- a/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
+++ b/src/diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py
@@ -1165,13 +1165,14 @@ def __call__(
                 )

                 # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t_expand,
-                    encoder_hidden_states=prompt_embeds,
-                    inpaint_latents=inpaint_latents,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        t_expand,
+                        encoder_hidden_states=prompt_embeds,
+                        inpaint_latents=inpaint_latents,
+                        return_dict=False,
+                    )[0]

                 if noise_pred.size()[1] != self.vae.config.latent_channels:
                     noise_pred, _ = noise_pred.chunk(2, dim=1)
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control.py b/src/diffusers/pipelines/flux/pipeline_flux_control.py
index 848d7bd39254..09d882fce833 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control.py
@@ -857,17 +857,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
index 262345c75afc..0dacb9666702 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_img2img.py
@@ -886,17 +886,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
index 6915a83a7ca7..b8341d09ac14 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_control_inpaint.py
@@ -1093,18 +1093,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
index 507ec687347c..e21a8c7bbd62 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -1099,30 +1099,13 @@ def __call__(
                 )
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_block_samples=controlnet_block_samples,
-                    controlnet_single_block_samples=controlnet_single_block_samples,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                    controlnet_blocks_repeat=controlnet_blocks_repeat,
-                )[0]
-
-                if do_true_cfg:
-                    if negative_image_embeds is not None:
-                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
-                    neg_noise_pred = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latents,
                         timestep=timestep / 1000,
                         guidance=guidance,
-                        pooled_projections=negative_pooled_prompt_embeds,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         controlnet_block_samples=controlnet_block_samples,
                         controlnet_single_block_samples=controlnet_single_block_samples,
                         txt_ids=text_ids,
@@ -1131,6 +1114,25 @@ def __call__(
                         return_dict=False,
                         controlnet_blocks_repeat=controlnet_blocks_repeat,
                     )[0]
+
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    with self.transformer.cache_context("uncond"):
+                        neg_noise_pred = self.transformer(
+                            hidden_states=latents,
+                            timestep=timestep / 1000,
+                            guidance=guidance,
+                            pooled_projections=negative_pooled_prompt_embeds,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            controlnet_block_samples=controlnet_block_samples,
+                            controlnet_single_block_samples=controlnet_single_block_samples,
+                            txt_ids=text_ids,
+                            img_ids=latent_image_ids,
+                            joint_attention_kwargs=self.joint_attention_kwargs,
+                            return_dict=False,
+                            controlnet_blocks_repeat=controlnet_blocks_repeat,
+                        )[0]
                     noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
index 582c7bbad84e..bd15455fcead 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -950,20 +950,21 @@ def __call__(
                 )
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_block_samples=controlnet_block_samples,
-                    controlnet_single_block_samples=controlnet_single_block_samples,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                    controlnet_blocks_repeat=controlnet_blocks_repeat,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_block_samples=controlnet_block_samples,
+                        controlnet_single_block_samples=controlnet_single_block_samples,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        controlnet_blocks_repeat=controlnet_blocks_repeat,
+                    )[0]

                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
index f7f34ef231af..4f141bb3116d 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -1132,20 +1132,21 @@ def __call__(
                 else:
                     guidance = None

-                noise_pred = self.transformer(
-                    hidden_states=latents,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    controlnet_block_samples=controlnet_block_samples,
-                    controlnet_single_block_samples=controlnet_single_block_samples,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                    controlnet_blocks_repeat=controlnet_blocks_repeat,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        controlnet_block_samples=controlnet_block_samples,
+                        controlnet_single_block_samples=controlnet_single_block_samples,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                        controlnet_blocks_repeat=controlnet_blocks_repeat,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_fill.py b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
index 5cb9c82204b2..c8ad127900ad 100644
--- a/src/diffusers/pipelines/flux/pipeline_flux_fill.py
+++ b/src/diffusers/pipelines/flux/pipeline_flux_fill.py
@@ -1006,17 +1006,18 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)

-                noise_pred = self.transformer(
-                    hidden_states=torch.cat((latents, masked_image_latents), dim=2),
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=torch.cat((latents, masked_image_latents), dim=2),
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py
b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py index ab9140dae921..a954e3032b25 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_img2img.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_img2img.py @@ -1021,32 +1021,36 @@ def __call__( self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py index 3bfe82cf4382..2e4ed1c05c19 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_inpaint.py @@ -1119,32 +1119,35 @@ def __call__( self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latents, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_image_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latents, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, txt_ids=text_ids, img_ids=latent_image_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: 
+ if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=text_ids, + img_ids=latent_image_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py index 94ae460afcd0..e05932cc29cc 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext.py @@ -1085,33 +1085,36 @@ def __call__( latent_model_input = torch.cat([latents, image_latents], dim=1) timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - img_ids=latent_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, img_ids=latent_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) diff --git a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py index b6f957981e14..2ab746c31697 100644 --- a/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +++ b/src/diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py @@ -1400,33 +1400,36 @@ def __call__( latent_model_input = torch.cat([latents, image_latents], dim=1) timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep / 1000, - guidance=guidance, - pooled_projections=pooled_prompt_embeds, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, - 
img_ids=latent_ids, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] - noise_pred = noise_pred[:, : latents.size(1)] - - if do_true_cfg: - if negative_image_embeds is not None: - self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep / 1000, guidance=guidance, - pooled_projections=negative_pooled_prompt_embeds, - encoder_hidden_states=negative_prompt_embeds, - txt_ids=negative_text_ids, + pooled_projections=pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, img_ids=latent_ids, joint_attention_kwargs=self.joint_attention_kwargs, return_dict=False, )[0] + noise_pred = noise_pred[:, : latents.size(1)] + + if do_true_cfg: + if negative_image_embeds is not None: + self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds + + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep / 1000, + guidance=guidance, + pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=negative_prompt_embeds, + txt_ids=negative_text_ids, + img_ids=latent_ids, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] neg_noise_pred = neg_noise_pred[:, : latents.size(1)] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) diff --git a/src/diffusers/pipelines/flux2/pipeline_flux2.py b/src/diffusers/pipelines/flux2/pipeline_flux2.py index b54a43dd89a5..ac1aa657d170 100644 --- a/src/diffusers/pipelines/flux2/pipeline_flux2.py +++ b/src/diffusers/pipelines/flux2/pipeline_flux2.py @@ -968,16 +968,17 @@ def __call__( latent_model_input = torch.cat([latents, image_latents], dim=1).to(self.transformer.dtype) latent_image_ids = torch.cat([latent_ids, image_latent_ids], dim=1) - noise_pred = self.transformer( - hidden_states=latent_model_input, # (B, image_seq_len, C) - timestep=timestep / 1000, - guidance=guidance, - encoder_hidden_states=prompt_embeds, - txt_ids=text_ids, # B, text_seq_len, 4 - img_ids=latent_image_ids, # B, image_seq_len, 4 - joint_attention_kwargs=self._attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, # (B, image_seq_len, C) + timestep=timestep / 1000, + guidance=guidance, + encoder_hidden_states=prompt_embeds, + txt_ids=text_ids, # B, text_seq_len, 4 + img_ids=latent_image_ids, # B, image_seq_len, 4 + joint_attention_kwargs=self._attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred[:, : latents.size(1) :] diff --git a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py index b6af23bca8fd..10c9e451b859 100644 --- a/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py +++ b/src/diffusers/pipelines/hidream_image/pipeline_hidream_image.py @@ -990,14 +990,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timesteps=timestep, - encoder_hidden_states_t5=prompt_embeds_t5, - encoder_hidden_states_llama3=prompt_embeds_llama3, - pooled_embeds=pooled_prompt_embeds, - return_dict=False, - )[0] + with 
self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timesteps=timestep, + encoder_hidden_states_t5=prompt_embeds_t5, + encoder_hidden_states_llama3=prompt_embeds_llama3, + pooled_embeds=pooled_prompt_embeds, + return_dict=False, + )[0] noise_pred = -noise_pred # perform guidance diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py index b50a6ae3ed27..8cbdd90d3f30 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py @@ -767,28 +767,30 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latents.shape[0]).to(latents.dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - pooled_projections=pooled_prompt_embeds, - guidance=guidance, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_attention_mask=negative_prompt_attention_mask, - pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py index 8006514f47ea..055eb03491d4 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py @@ -1001,32 +1001,13 @@ def __call__( self._current_timestep = t timestep = t.expand(latents.shape[0]) - noise_pred = self.transformer( - hidden_states=latents.to(transformer_dtype), - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - pooled_projections=pooled_prompt_embeds, - image_embeds=image_embeds, - indices_latents=indices_latents, - guidance=guidance, - latents_clean=latents_clean.to(transformer_dtype), - indices_latents_clean=indices_clean_latents, - latents_history_2x=latents_history_2x.to(transformer_dtype), - indices_latents_history_2x=indices_latents_history_2x, - latents_history_4x=latents_history_4x.to(transformer_dtype), - indices_latents_history_4x=indices_latents_history_4x, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - neg_noise_pred = self.transformer( + with 
self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latents.to(transformer_dtype), timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_attention_mask=negative_prompt_attention_mask, - pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, image_embeds=image_embeds, indices_latents=indices_latents, guidance=guidance, @@ -1039,6 +1020,27 @@ def __call__( attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latents.to(transformer_dtype), + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + image_embeds=image_embeds, + indices_latents=indices_latents, + guidance=guidance, + latents_clean=latents_clean.to(transformer_dtype), + indices_latents_clean=indices_clean_latents, + latents_history_2x=latents_history_2x.to(transformer_dtype), + indices_latents_history_2x=indices_latents_history_2x, + latents_history_4x=latents_history_4x.to(transformer_dtype), + indices_latents_history_4x=indices_latents_history_4x, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py index aa04e6509730..9c8e966d9e28 100644 --- a/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +++ b/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py @@ -932,28 +932,30 @@ def __call__( elif image_condition_type == "token_replace": latent_model_input = torch.cat([image_latents, latents[:, :, 1:]], dim=2).to(transformer_dtype) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - pooled_projections=pooled_prompt_embeds, - guidance=guidance, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if do_true_cfg: - neg_noise_pred = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_attention_mask=negative_prompt_attention_mask, - pooled_projections=negative_pooled_prompt_embeds, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + pooled_projections=pooled_prompt_embeds, guidance=guidance, attention_kwargs=attention_kwargs, return_dict=False, )[0] + + if do_true_cfg: + with self.transformer.cache_context("uncond"): + neg_noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + pooled_projections=negative_pooled_prompt_embeds, + guidance=guidance, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred) # compute the previous noisy sample x_t -> x_t-1 diff --git 
a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py index e2f935aaf4b9..470a944ab1b9 100644 --- a/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +++ b/src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py @@ -840,18 +840,19 @@ def __call__( ) # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - t_expand, - encoder_hidden_states=prompt_embeds, - text_embedding_mask=prompt_attention_mask, - encoder_hidden_states_t5=prompt_embeds_2, - text_embedding_mask_t5=prompt_attention_mask_2, - image_meta_size=add_time_ids, - style=style, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] noise_pred, _ = noise_pred.chunk(2, dim=1) diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py index 2b666f0ec697..38c63d6378e0 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky.py @@ -877,31 +877,33 @@ def __call__( timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) # Predict noise residual - pred_velocity = self.transformer( - hidden_states=latents.to(dtype), - encoder_hidden_states=prompt_embeds_qwen.to(dtype), - pooled_projections=prompt_embeds_clip.to(dtype), - timestep=timestep.to(dtype), - visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=scale_factor, - sparse_params=sparse_params, - return_dict=True, - ).sample - - if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: - uncond_pred_velocity = self.transformer( + with self.transformer.cache_context("cond"): + pred_velocity = self.transformer( hidden_states=latents.to(dtype), - encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), - pooled_projections=negative_prompt_embeds_clip.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, + text_rope_pos=text_rope_pos, scale_factor=scale_factor, sparse_params=sparse_params, return_dict=True, ).sample + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + with self.transformer.cache_context("uncond"): + uncond_pred_velocity = self.transformer( +
hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) # Compute previous sample using the scheduler latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py index f965cdad8f3e..6e44e1f0159c 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2i.py @@ -771,31 +771,33 @@ def __call__( timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt) # Predict noise residual - pred_velocity = self.transformer( - hidden_states=latents.to(dtype), - encoder_hidden_states=prompt_embeds_qwen.to(dtype), - pooled_projections=prompt_embeds_clip.to(dtype), - timestep=timestep.to(dtype), - visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=scale_factor, - sparse_params=sparse_params, - return_dict=True, - ).sample - - if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: - uncond_pred_velocity = self.transformer( + with self.transformer.cache_context("cond"): + pred_velocity = self.transformer( hidden_states=latents.to(dtype), - encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), - pooled_projections=negative_prompt_embeds_clip.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, + text_rope_pos=text_rope_pos, scale_factor=scale_factor, sparse_params=sparse_params, return_dict=True, ).sample + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + with self.transformer.cache_context("uncond"): + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) latents[:, :, :, :, :num_channels_latents] = self.scheduler.step( diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py index d457c9b69657..08317f41fc5d 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_i2v.py @@ -958,31 +958,33 @@ def __call__( timestep = t.unsqueeze(0).repeat(batch_size * num_videos_per_prompt) # Predict noise residual - pred_velocity = self.transformer( - hidden_states=latents.to(dtype), - encoder_hidden_states=prompt_embeds_qwen.to(dtype), - pooled_projections=prompt_embeds_clip.to(dtype), - timestep=timestep.to(dtype), - visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=scale_factor, - sparse_params=sparse_params, - return_dict=True, - ).sample - - if 
self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: - uncond_pred_velocity = self.transformer( + with self.transformer.cache_context("cond"): + pred_velocity = self.transformer( hidden_states=latents.to(dtype), - encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), - pooled_projections=negative_prompt_embeds_clip.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, + text_rope_pos=text_rope_pos, scale_factor=scale_factor, sparse_params=sparse_params, return_dict=True, ).sample + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + with self.transformer.cache_context("uncond"): + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) latents[:, 1:, :, :, :num_channels_latents] = self.scheduler.step( diff --git a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py index bb5c40327b4e..2f1ad6f0209b 100644 --- a/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py +++ b/src/diffusers/pipelines/kandinsky5/pipeline_kandinsky_t2i.py @@ -728,31 +728,33 @@ def __call__( timestep = t.unsqueeze(0).repeat(batch_size * num_images_per_prompt) # Predict noise residual - pred_velocity = self.transformer( - hidden_states=latents.to(dtype), - encoder_hidden_states=prompt_embeds_qwen.to(dtype), - pooled_projections=prompt_embeds_clip.to(dtype), - timestep=timestep.to(dtype), - visual_rope_pos=visual_rope_pos, - text_rope_pos=text_rope_pos, - scale_factor=scale_factor, - sparse_params=sparse_params, - return_dict=True, - ).sample - - if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: - uncond_pred_velocity = self.transformer( + with self.transformer.cache_context("cond"): + pred_velocity = self.transformer( hidden_states=latents.to(dtype), - encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), - pooled_projections=negative_prompt_embeds_clip.to(dtype), + encoder_hidden_states=prompt_embeds_qwen.to(dtype), + pooled_projections=prompt_embeds_clip.to(dtype), timestep=timestep.to(dtype), visual_rope_pos=visual_rope_pos, - text_rope_pos=negative_text_rope_pos, + text_rope_pos=text_rope_pos, scale_factor=scale_factor, sparse_params=sparse_params, return_dict=True, ).sample + if self.guidance_scale > 1.0 and negative_prompt_embeds_qwen is not None: + with self.transformer.cache_context("uncond"): + uncond_pred_velocity = self.transformer( + hidden_states=latents.to(dtype), + encoder_hidden_states=negative_prompt_embeds_qwen.to(dtype), + pooled_projections=negative_prompt_embeds_clip.to(dtype), + timestep=timestep.to(dtype), + visual_rope_pos=visual_rope_pos, + text_rope_pos=negative_text_rope_pos, + scale_factor=scale_factor, + sparse_params=sparse_params, + return_dict=True, + ).sample + pred_velocity = uncond_pred_velocity + guidance_scale * (pred_velocity - uncond_pred_velocity) latents = self.scheduler.step(pred_velocity[:, :], t, latents, return_dict=False)[0] 
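Every hunk in this patch follows the same pattern: the conditional denoiser forward pass runs inside with self.transformer.cache_context("cond"), and where a pipeline computes true classifier-free guidance, the negative forward pass runs inside cache_context("uncond"). This lets cache-based acceleration hooks keep separate state for the two branches, so the cached block outputs of one branch are never overwritten by (or served to) the other within the same denoising step. Below is a minimal sketch of how such a named cache context could behave; the SimpleCacheMixin class and its cache_get/cache_put helpers are assumptions for illustration only, not the diffusers implementation of cache_context.

from contextlib import contextmanager


class SimpleCacheMixin:
    """Illustrative sketch (assumed API): per-name cache buckets selected by a context manager."""

    def __init__(self):
        self._caches = {}    # one bucket per context name ("cond", "uncond", ...)
        self._active = None  # bucket selected by the currently active context

    @contextmanager
    def cache_context(self, name):
        # Route all cache reads/writes inside the `with` block to `name`'s bucket;
        # re-entering with the same name on a later step reuses the same bucket.
        previous, self._active = self._active, name
        self._caches.setdefault(name, {})
        try:
            yield
        finally:
            self._active = previous

    def cache_get(self, key, default=None):
        return self._caches.get(self._active, {}).get(key, default)

    def cache_put(self, key, value):
        if self._active is None:
            raise RuntimeError("cache_put called outside of cache_context")
        self._caches[self._active][key] = value


# Usage mirroring the pattern in this patch: cond and uncond passes see only
# their own cached state.
model = SimpleCacheMixin()
with model.cache_context("cond"):
    model.cache_put("block_0", "cond activations")
with model.cache_context("uncond"):
    assert model.cache_get("block_0") is None  # the uncond bucket starts empty

Under this assumption, the only contract the pipelines rely on is that the same string selects the same bucket on every step, which is why "cond" and "uncond" are used verbatim across the whole patch.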
diff --git a/src/diffusers/pipelines/latte/pipeline_latte.py b/src/diffusers/pipelines/latte/pipeline_latte.py index 4d42a7049ec9..9145d60e8785 100644 --- a/src/diffusers/pipelines/latte/pipeline_latte.py +++ b/src/diffusers/pipelines/latte/pipeline_latte.py @@ -816,13 +816,14 @@ def __call__( current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - hidden_states=latent_model_input, - encoder_hidden_states=prompt_embeds, - timestep=current_timestep, - enable_temporal_attentions=enable_temporal_attentions, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + encoder_hidden_states=prompt_embeds, + timestep=current_timestep, + enable_temporal_attentions=enable_temporal_attentions, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/lumina/pipeline_lumina.py b/src/diffusers/pipelines/lumina/pipeline_lumina.py index b59c265646cd..4a668daa97aa 100644 --- a/src/diffusers/pipelines/lumina/pipeline_lumina.py +++ b/src/diffusers/pipelines/lumina/pipeline_lumina.py @@ -864,15 +864,16 @@ def __call__( ntk_factor=ntk_factor, ) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=current_timestep, - encoder_hidden_states=prompt_embeds, - encoder_mask=prompt_attention_mask, - image_rotary_emb=image_rotary_emb, - cross_attention_kwargs=cross_attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=current_timestep, + encoder_hidden_states=prompt_embeds, + encoder_mask=prompt_attention_mask, + image_rotary_emb=image_rotary_emb, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] noise_pred = noise_pred.chunk(2, dim=1)[0] # perform guidance scale diff --git a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py index 937803edbcbc..bd28d7867cc3 100644 --- a/src/diffusers/pipelines/lumina2/pipeline_lumina2.py +++ b/src/diffusers/pipelines/lumina2/pipeline_lumina2.py @@ -724,25 +724,27 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML current_timestep = current_timestep.expand(latents.shape[0]) - noise_pred_cond = self.transformer( - hidden_states=latents, - timestep=current_timestep, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - return_dict=False, - attention_kwargs=self.attention_kwargs, - )[0] - - # perform normalization-based guidance scale on a truncated timestep interval - if self.do_classifier_free_guidance and not do_classifier_free_truncation: - noise_pred_uncond = self.transformer( + with self.transformer.cache_context("cond"): + noise_pred_cond = self.transformer( hidden_states=latents, timestep=current_timestep, - encoder_hidden_states=negative_prompt_embeds, - encoder_attention_mask=negative_prompt_attention_mask, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, return_dict=False, attention_kwargs=self.attention_kwargs, )[0] + + # perform normalization-based guidance scale on a truncated timestep interval + if self.do_classifier_free_guidance and not do_classifier_free_truncation: + with self.transformer.cache_context("uncond"): + noise_pred_uncond = self.transformer( + hidden_states=latents, + timestep=current_timestep, + 
encoder_hidden_states=negative_prompt_embeds, + encoder_attention_mask=negative_prompt_attention_mask, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond) # apply normalization after classifier-free guidance if cfg_normalization: diff --git a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py index 090cb46aace4..b68328f465b0 100644 --- a/src/diffusers/pipelines/omnigen/pipeline_omnigen.py +++ b/src/diffusers/pipelines/omnigen/pipeline_omnigen.py @@ -490,16 +490,17 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - input_ids=processed_data["input_ids"], - input_img_latents=input_img_latents, - input_image_sizes=processed_data["input_image_sizes"], - attention_mask=processed_data["attention_mask"], - position_ids=processed_data["position_ids"], - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + input_ids=processed_data["input_ids"], + input_img_latents=input_img_latents, + input_image_sizes=processed_data["input_image_sizes"], + attention_mask=processed_data["attention_mask"], + position_ids=processed_data["position_ids"], + return_dict=False, + )[0] if num_cfg == 2: cond, uncond, img_cond = torch.split(noise_pred, len(noise_pred) // 3, dim=0) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py index d156eac8f3f7..f01ee6ec5d12 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py @@ -886,18 +886,19 @@ def __call__( ) # predict the noise residual - noise_pred = self.transformer( - latent_model_input, - t_expand, - encoder_hidden_states=prompt_embeds, - text_embedding_mask=prompt_attention_mask, - encoder_hidden_states_t5=prompt_embeds_2, - text_embedding_mask_t5=prompt_attention_mask_2, - image_meta_size=add_time_ids, - style=style, - image_rotary_emb=image_rotary_emb, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input, + t_expand, + encoder_hidden_states=prompt_embeds, + text_embedding_mask=prompt_attention_mask, + encoder_hidden_states_t5=prompt_embeds_2, + text_embedding_mask_t5=prompt_attention_mask_2, + image_meta_size=add_time_ids, + style=style, + image_rotary_emb=image_rotary_emb, + return_dict=False, + )[0] noise_pred, _ = noise_pred.chunk(2, dim=1) diff --git a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py index 9031877b5b8d..df93f887c08f 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py @@ -819,14 +819,15 @@ def __call__( current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - latent_model_input, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=current_timestep, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input, + 
encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] # perform guidance if self.do_perturbed_attention_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sana.py b/src/diffusers/pipelines/pag/pipeline_pag_sana.py index 9e91ccbe8006..4cad59da0e2c 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sana.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sana.py @@ -899,13 +899,14 @@ def __call__( timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype) # predict noise model_output - noise_pred = self.transformer( - latent_model_input, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=timestep, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py index 941b675099b9..7efc57eae3df 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3.py @@ -923,14 +923,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] # perform guidance if self.do_perturbed_attention_guidance: diff --git a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py index f40dd52fc244..b210afffbb2b 100644 --- a/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +++ b/src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py @@ -987,14 +987,15 @@ def __call__( # broadcast to batch dimension in a way that's compatible with ONNX/Core ML timestep = t.expand(latent_model_input.shape[0]) - noise_pred = self.transformer( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - pooled_projections=pooled_prompt_embeds, - joint_attention_kwargs=self.joint_attention_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=prompt_embeds, + pooled_projections=pooled_prompt_embeds, + joint_attention_kwargs=self.joint_attention_kwargs, + return_dict=False, + )[0] # perform guidance if self.do_perturbed_attention_guidance: diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py index 1d718a4852a4..d9bd6593c027 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py @@ -921,14 
+921,15 @@ def __call__( current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - latent_model_input, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=current_timestep, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py index bb169ac5c443..72db2cddcc7c 100644 --- a/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +++ b/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py @@ -854,14 +854,15 @@ def __call__( current_timestep = current_timestep.expand(latent_model_input.shape[0]) # predict noise model_output - noise_pred = self.transformer( - latent_model_input, - encoder_hidden_states=prompt_embeds, - encoder_attention_mask=prompt_attention_mask, - timestep=current_timestep, - added_cond_kwargs=added_cond_kwargs, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input, + encoder_hidden_states=prompt_embeds, + encoder_attention_mask=prompt_attention_mask, + timestep=current_timestep, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + )[0] # perform guidance if do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/prx/pipeline_prx.py b/src/diffusers/pipelines/prx/pipeline_prx.py index 873f25316e6d..38c7f992d4ed 100644 --- a/src/diffusers/pipelines/prx/pipeline_prx.py +++ b/src/diffusers/pipelines/prx/pipeline_prx.py @@ -750,13 +750,14 @@ def __call__( t_cont = (t.float() / self.scheduler.config.num_train_timesteps).view(1).to(device) # Forward through transformer - noise_pred = self.transformer( - hidden_states=latents_in, - timestep=t_cont, - encoder_hidden_states=ca_embed, - attention_mask=ca_mask, - return_dict=False, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + hidden_states=latents_in, + timestep=t_cont, + encoder_hidden_states=ca_embed, + attention_mask=ca_mask, + return_dict=False, + )[0] # Apply CFG if self.do_classifier_free_guidance: diff --git a/src/diffusers/pipelines/sana/pipeline_sana.py b/src/diffusers/pipelines/sana/pipeline_sana.py index 2beff802c6e0..1bc721789f7e 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana.py +++ b/src/diffusers/pipelines/sana/pipeline_sana.py @@ -964,14 +964,15 @@ def __call__( timestep = timestep * self.transformer.config.timestep_scale # predict noise model_output - noise_pred = self.transformer( - latent_model_input.to(dtype=transformer_dtype), - encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), - encoder_attention_mask=prompt_attention_mask, - timestep=timestep, - return_dict=False, - attention_kwargs=self.attention_kwargs, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] 
noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py index 55ed7b84ebdf..f2e61efae9cc 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_controlnet.py @@ -1059,15 +1059,18 @@ def __call__( )[0] # predict noise model_output - noise_pred = self.transformer( - latent_model_input.to(dtype=transformer_dtype), - encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), - encoder_attention_mask=prompt_attention_mask, - timestep=timestep, - return_dict=False, - attention_kwargs=self.attention_kwargs, - controlnet_block_samples=tuple(t.to(dtype=transformer_dtype) for t in controlnet_block_samples), - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + timestep=timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + controlnet_block_samples=tuple( + t.to(dtype=transformer_dtype) for t in controlnet_block_samples + ), + )[0] noise_pred = noise_pred.float() # perform guidance diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py index 04f45f817efb..ce3c0e03110b 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint.py @@ -847,15 +847,16 @@ def __call__( ) # predict noise model_output - noise_pred = self.transformer( - latent_model_input.to(dtype=transformer_dtype), - encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), - encoder_attention_mask=prompt_attention_mask, - guidance=guidance, - timestep=scm_timestep, - return_dict=False, - attention_kwargs=self.attention_kwargs, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] noise_pred = ( (1 - 2 * scm_timestep_expanded) * latent_model_input diff --git a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py index 8899ed84c4e5..011b4ff14ce2 100644 --- a/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +++ b/src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py @@ -935,15 +935,16 @@ def __call__( ) # predict noise model_output - noise_pred = self.transformer( - latent_model_input.to(dtype=transformer_dtype), - encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), - encoder_attention_mask=prompt_attention_mask, - guidance=guidance, - timestep=scm_timestep, - return_dict=False, - attention_kwargs=self.attention_kwargs, - )[0] + with self.transformer.cache_context("cond"): + noise_pred = self.transformer( + latent_model_input.to(dtype=transformer_dtype), + encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype), + encoder_attention_mask=prompt_attention_mask, + guidance=guidance, + timestep=scm_timestep, + return_dict=False, + attention_kwargs=self.attention_kwargs, + )[0] noise_pred = ( (1 - 2 * scm_timestep_expanded) * latent_model_input diff --git 
a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
index a786275e45a9..021f8b402c9a 100644
--- a/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video.py
@@ -936,14 +936,15 @@ def __call__(
                 timestep = t.expand(latent_model_input.shape[0])

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input.to(dtype=transformer_dtype),
-                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
-                    return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input.to(dtype=transformer_dtype),
+                        encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+                        encoder_attention_mask=prompt_attention_mask,
+                        timestep=timestep,
+                        return_dict=False,
+                        attention_kwargs=self.attention_kwargs,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
index e87880b64cee..8004834b269d 100644
--- a/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
+++ b/src/diffusers/pipelines/sana_video/pipeline_sana_video_i2v.py
@@ -979,14 +979,15 @@ def __call__(
                 timestep = timestep * (1 - conditioning_mask)

                 # predict noise model_output
-                noise_pred = self.transformer(
-                    latent_model_input.to(dtype=transformer_dtype),
-                    encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
-                    return_dict=False,
-                    attention_kwargs=self.attention_kwargs,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input.to(dtype=transformer_dtype),
+                        encoder_hidden_states=prompt_embeds.to(dtype=transformer_dtype),
+                        encoder_attention_mask=prompt_attention_mask,
+                        timestep=timestep,
+                        return_dict=False,
+                        attention_kwargs=self.attention_kwargs,
+                    )[0]
                 noise_pred = noise_pred.float()

                 # perform guidance
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
index d6cd7d7feceb..1b1c8ee097c5 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py
@@ -545,22 +545,24 @@ def __call__(
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
index 089f92632d38..2dedc168823c 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py
@@ -886,26 +886,27 @@ def __call__(
                         * noise_factor
                     )
                     timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    enable_diffusion_forcing=True,
-                    fps=fps_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         enable_diffusion_forcing=True,
                         fps=fps_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            enable_diffusion_forcing=True,
+                            fps=fps_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 update_mask_i = step_update_mask[i]
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
index 2951a9447386..10820d3e4e5b 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py
@@ -966,25 +966,27 @@ def __call__(
                     )
                     timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    enable_diffusion_forcing=True,
-                    fps=fps_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         enable_diffusion_forcing=True,
                         fps=fps_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            enable_diffusion_forcing=True,
+                            fps=fps_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 update_mask_i = step_update_mask[i]
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
index 6fedfc795a40..ae2664fbb867 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py
@@ -974,25 +974,27 @@ def __call__(
                     )
                     timestep[:, valid_interval_start:prefix_video_latents_frames] = addnoise_condition

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    enable_diffusion_forcing=True,
-                    fps=fps_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         enable_diffusion_forcing=True,
                         fps=fps_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            enable_diffusion_forcing=True,
+                            fps=fps_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 update_mask_i = step_update_mask[i]
diff --git a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
index d61b687eadc3..d1df7f5f34cb 100644
--- a/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
+++ b/src/diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py
@@ -678,24 +678,26 @@ def __call__(
                 latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
index b7faf097ab0d..05e787eab1ba 100644
--- a/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
+++ b/src/diffusers/pipelines/stable_audio/pipeline_stable_audio.py
@@ -718,14 +718,15 @@ def __call__(
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
-                noise_pred = self.transformer(
-                    latent_model_input,
-                    t.unsqueeze(0),
-                    encoder_hidden_states=text_audio_duration_embeds,
-                    global_hidden_states=audio_duration_embeds,
-                    rotary_embedding=rotary_embedding,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        latent_model_input,
+                        t.unsqueeze(0),
+                        encoder_hidden_states=text_audio_duration_embeds,
+                        global_hidden_states=audio_duration_embeds,
+                        rotary_embedding=rotary_embedding,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
index 660d9801df56..3a56dca0a52c 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -1061,14 +1061,15 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
@@ -1083,6 +1084,8 @@ def __call__(
                     if skip_guidance_layers is not None and should_skip_layers:
                         timestep = t.expand(latents.shape[0])
                         latent_model_input = latents
+
+                        # TODO-context
                         noise_pred_skip_layers = self.transformer(
                             hidden_states=latent_model_input,
                             timestep=timestep,
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 9b11bc8781e7..586e845d570d 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -1093,14 +1093,15 @@ def __call__(
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
index b947cbff0914..cc7ace6c09b8 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py
@@ -1297,14 +1297,15 @@ def __call__(
                 if num_channels_transformer == 33:
                     latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    pooled_projections=pooled_prompt_embeds,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep,
+                        encoder_hidden_states=prompt_embeds,
+                        pooled_projections=pooled_prompt_embeds,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # perform guidance
                 if self.do_classifier_free_guidance:
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
index 0ddcbf735770..393758644212 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_text_decoder.py
@@ -154,7 +154,8 @@ def forward(
         if labels is not None:
             dummy_token = self.get_dummy_token(input_ids.shape[0], input_ids.device)
             labels = torch.cat((dummy_token, input_ids), dim=1)
-        out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
+        with self.transformer.cache_context("cond"):
+            out = self.transformer(inputs_embeds=embedding_cat, labels=labels, attention_mask=attention_mask)
         if self.prefix_hidden_dim is not None:
             return out, hidden
         else:
@@ -250,6 +251,7 @@ def generate_beam(
             generated = self.transformer.transformer.wte(input_ids)

         for i in range(entry_length):
+            # TODO-context
             outputs = self.transformer(inputs_embeds=generated)
             logits = outputs.logits
             logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
diff --git a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
index 2a04ec2e4030..2814e303a3fb 100644
--- a/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
+++ b/src/diffusers/pipelines/unidiffuser/modeling_uvit.py
@@ -1151,16 +1151,17 @@ def forward(
         hidden_states = self.pos_embed_drop(hidden_states)

         # 2. Blocks
-        hidden_states = self.transformer(
-            hidden_states,
-            encoder_hidden_states=encoder_hidden_states,
-            timestep=None,
-            class_labels=None,
-            cross_attention_kwargs=cross_attention_kwargs,
-            return_dict=False,
-            hidden_states_is_embedding=True,
-            unpatchify=False,
-        )[0]
+        with self.transformer.cache_context("cond"):
+            hidden_states = self.transformer(
+                hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                timestep=None,
+                class_labels=None,
+                cross_attention_kwargs=cross_attention_kwargs,
+                return_dict=False,
+                hidden_states_is_embedding=True,
+                unpatchify=False,
+            )[0]

         # 3. Output
         # Split out the predicted noise representation.
diff --git a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
index e12995106bcf..d717130b1b15 100644
--- a/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
+++ b/src/diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py
@@ -901,17 +901,18 @@ def __call__(
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
                 latent_model_input = torch.cat((latents, masked_image_latents), dim=2)

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep / 1000,
-                    guidance=guidance,
-                    pooled_projections=pooled_prompt_embeds,
-                    encoder_hidden_states=prompt_embeds,
-                    txt_ids=text_ids,
-                    img_ids=latent_image_ids,
-                    joint_attention_kwargs=self.joint_attention_kwargs,
-                    return_dict=False,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
+                        hidden_states=latent_model_input,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=pooled_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]

                 # Compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
diff --git a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
index a976126da7fe..67c4429a8867 100644
--- a/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
+++ b/src/diffusers/pipelines/wan/pipeline_wan_video2video.py
@@ -660,22 +660,24 @@ def __call__(
                 latent_model_input = latents.to(transformer_dtype)
                 timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
-
-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                with self.transformer.cache_context("cond"):
+                    noise_pred = self.transformer(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
+
+                if self.do_classifier_free_guidance:
+                    with self.transformer.cache_context("uncond"):
+                        noise_uncond = self.transformer(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
                     noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image.py b/src/diffusers/pipelines/z_image/pipeline_z_image.py
index 82bdd7d361b7..5c66b29cb990 100644
--- a/src/diffusers/pipelines/z_image/pipeline_z_image.py
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image.py
@@ -524,9 +524,10 @@ def __call__(
                 latent_model_input = latent_model_input.unsqueeze(2)
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))

-                model_out_list = self.transformer(
-                    latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    model_out_list = self.transformer(
+                        latent_model_input_list, timestep_model_input, prompt_embeds_model_input, return_dict=False
+                    )[0]

                 if apply_cfg:
                     # Perform CFG
diff --git a/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py
index 2b3e80a2082b..94cdc11c6dca 100644
--- a/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py
+++ b/src/diffusers/pipelines/z_image/pipeline_z_image_img2img.py
@@ -637,11 +637,12 @@ def __call__(
                 latent_model_input = latent_model_input.unsqueeze(2)
                 latent_model_input_list = list(latent_model_input.unbind(dim=0))

-                model_out_list = self.transformer(
-                    latent_model_input_list,
-                    timestep_model_input,
-                    prompt_embeds_model_input,
-                )[0]
+                with self.transformer.cache_context("cond"):
+                    model_out_list = self.transformer(
+                        latent_model_input_list,
+                        timestep_model_input,
+                        prompt_embeds_model_input,
+                    )[0]

                 if apply_cfg:
                     # Perform CFG
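For context on the pattern applied throughout this patch: the "cond"/"uncond" labels only take effect once a caching method is enabled on the denoiser, at which point each label keeps its own cached hidden states so the conditional and unconditional guidance branches do not overwrite each other between steps. A minimal usage sketch follows, assuming the public enable_cache API and the FirstBlockCacheConfig export from a recent diffusers release; model id and threshold are illustrative, not part of this patch.

import torch
from diffusers import FluxPipeline, FirstBlockCacheConfig

# Load a pipeline whose denoising loop wraps each guidance branch in
# cache_context("cond") / cache_context("uncond"), as in the hunks above.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Enable a caching method on the transformer. Cached states are then tracked
# per cache-context label, so the cond and uncond forward passes maintain
# separate caches instead of evicting each other's entries.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))

image = pipe("a photo of a cat", num_inference_steps=28).images[0]
image.save("cat.png")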