diff --git a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py index 86d5496d1623..10fa7b55c36e 100644 --- a/src/diffusers/pipelines/wan/pipeline_wan_i2v.py +++ b/src/diffusers/pipelines/wan/pipeline_wan_i2v.py @@ -409,7 +409,7 @@ def prepare_latents( [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image], dim=2, ) - video_condition = video_condition.to(device=device, dtype=dtype) + video_condition = video_condition.to(device=device, dtype=self.vae.dtype) latents_mean = ( torch.tensor(self.vae.config.latents_mean) @@ -429,6 +429,7 @@ def prepare_latents( latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax") latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1) + latent_condition = latent_condition.to(dtype) latent_condition = (latent_condition - latents_mean) * latents_std mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)