From d978c00e73b2c89124731a7f51cd57e4dc2e798e Mon Sep 17 00:00:00 2001 From: okaris Date: Tue, 29 Jul 2025 14:35:07 +0000 Subject: [PATCH 1/2] enable caching for WanImageToVideoPipeline --- .../pipelines/wan/pipeline_wan_i2v.py | 29 ++++++++++--------- 1 file changed, 16 insertions(+), 13 deletions(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 24e9cccdb440..ea64dfe48540 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -750,25 +750,28 @@ def __call__( latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype) timestep = t.expand(latents.shape[0]) - noise_pred = current_model( - hidden_states=latent_model_input, - timestep=timestep, - encoder_hidden_states=prompt_embeds, - encoder_hidden_states_image=image_embeds, - attention_kwargs=attention_kwargs, - return_dict=False, - )[0] - - if self.do_classifier_free_guidance: - noise_uncond = current_model( + with current_model.cache_context("cond"): + noise_pred = current_model( hidden_states=latent_model_input, timestep=timestep, - encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states=prompt_embeds, encoder_hidden_states_image=image_embeds, attention_kwargs=attention_kwargs, return_dict=False, )[0] - noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) + + + if self.do_classifier_free_guidance: + with current_model.cache_context("uncond"): + noise_uncond = current_model( + hidden_states=latent_model_input, + timestep=timestep, + encoder_hidden_states=negative_prompt_embeds, + encoder_hidden_states_image=image_embeds, + attention_kwargs=attention_kwargs, + return_dict=False, + )[0] + noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) # compute the previous noisy sample x_t -> x_t-1 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] From fc0677f979408dda1c7cfa9e750f39bd56f284ac Mon Sep 17 00:00:00 2001 From: okaris Date: Wed, 30 Jul 2025 09:42:09 +0000 Subject: [PATCH 2/2] ruff format --- src/diffusers/pipelines/wan/pipeline_wan_i2v.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index ea64dfe48540..a072824a4854 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -760,7 +760,6 @@ def __call__( return_dict=False, )[0] - if self.do_classifier_free_guidance: with current_model.cache_context("uncond"): noise_uncond = current_model(