27 changes: 7 additions & 20 deletions src/diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -108,31 +108,16 @@ def prompt_clean(text):
     return text
 
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
-    encoder_output: torch.Tensor,
-    latents_mean: torch.Tensor,
-    latents_std: torch.Tensor,
-    generator: Optional[torch.Generator] = None,
-    sample_mode: str = "sample",
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
 ):
     if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
-        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
-        encoder_output.latent_dist.logvar = torch.clamp(
-            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
-        )
-        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
-        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.sample(generator)
     elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
-        encoder_output.latent_dist.mean = (encoder_output.latent_dist.mean - latents_mean) * latents_std
-        encoder_output.latent_dist.logvar = torch.clamp(
-            (encoder_output.latent_dist.logvar - latents_mean) * latents_std, -30.0, 20.0
-        )
-        encoder_output.latent_dist.std = torch.exp(0.5 * encoder_output.latent_dist.logvar)
-        encoder_output.latent_dist.var = torch.exp(encoder_output.latent_dist.logvar)
         return encoder_output.latent_dist.mode()
     elif hasattr(encoder_output, "latents"):
-        return (encoder_output.latents - latents_mean) * latents_std
+        return encoder_output.latents
     else:
         raise AttributeError("Could not access latents of provided encoder_output")
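Note: with this hunk applied, the helper reduces to the snippet below, again a verbatim copy of the img2img retrieve_latents; the Wan-specific normalization no longer lives inside it and is instead applied at the call site (see the prepare_latents hunk further down). The snippet is only the reconstructed result of the diff, relying on the torch and typing.Optional imports already present at module level:

def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    # Sample from the posterior, take its mode, or fall back to precomputed latents.
    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
        return encoder_output.latent_dist.sample(generator)
    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
        return encoder_output.latent_dist.mode()
    elif hasattr(encoder_output, "latents"):
        return encoder_output.latents
    else:
        raise AttributeError("Could not access latents of provided encoder_output")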

@@ -412,13 +397,15 @@ def prepare_latents(
 
         if isinstance(generator, list):
             latent_condition = [
-                retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, g) for g in generator
+                retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") for _ in generator
             ]
             latent_condition = torch.cat(latent_condition)
         else:
-            latent_condition = retrieve_latents(self.vae.encode(video_condition), latents_mean, latents_std, generator)
+            latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
             latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
 
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
         mask_lat_size[:, :, list(range(1, num_frames))] = 0
         first_frame_mask = mask_lat_size[:, :, 0:1]
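A minimal sketch (not part of this diff) of the resulting flow in prepare_latents, assuming latents_mean and latents_std are built from the Wan VAE config as elsewhere in this pipeline and that latents_std is stored as a reciprocal, which is why the normalization multiplies rather than divides:

# Sketch only: self.vae.config.latents_mean / latents_std / z_dim and
# video_condition are assumed from the surrounding pipeline, not shown in this hunk.
latents_mean = (
    torch.tensor(self.vae.config.latents_mean)
    .view(1, self.vae.config.z_dim, 1, 1, 1)
    .to(video_condition.device, video_condition.dtype)
)
latents_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
    1, self.vae.config.z_dim, 1, 1, 1
).to(video_condition.device, video_condition.dtype)

# Deterministic (argmax) latents first, then a single normalization pass
# instead of normalizing the distribution statistics inside retrieve_latents.
latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
latent_condition = (latent_condition - latents_mean) * latents_std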