fix spatial compression ratio error for AutoEncoderKLWan doing tiled encode #12753
Conversation
Hi, thanks for the PR,
Hello @yiyixuxu, this is my test code and the corresponding model output. The task is text-image-to-video, so WanImageToVideoPipeline is used; the input image and the generated videos are attached below.

```python
from diffusers import WanImageToVideoPipeline
from diffusers.utils import export_to_video
from PIL import Image
import torch
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"

model_path = 'models/Wan2.2-TI2V-5B-Diffusers'
pipe = WanImageToVideoPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16)
pipe = pipe.to('cuda')
pipe.vae.enable_tiling()

prompt = "a cute anime girl with fennec ears and a fluffy tail walking in a beautiful field"
# Standard Wan negative prompt (Chinese); roughly: "garish colors, overexposed,
# static, blurry details, subtitles, style, artwork, painting, still frame,
# overall gray, worst quality, low quality, JPEG compression artifacts, ugly,
# incomplete, extra fingers, poorly drawn hands, poorly drawn face, deformed,
# disfigured, malformed limbs, fused fingers, motionless frame, cluttered
# background, three legs, many people in the background, walking backwards".
negative_prompt = ("色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,"
                   "整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,"
                   "画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,"
                   "静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走")

image_file = 'images/fennec_girl_hug.png'
image = Image.open(image_file).convert('RGB')

video = pipe(
    prompt=prompt,
    negative_prompt=negative_prompt,
    image=image,
    num_frames=81,
    num_inference_steps=40,
    height=800,
    width=1280,
    guidance_scale=6.0,
    generator=torch.Generator(device='cuda').manual_seed(869173064731527),
).frames[0]
export_to_video(video, "fennec_girl_hug.mp4", fps=20)
```

The generated video without tiled image encode and video decode: (attachment)
The generated video with tiling enabled: (attachment)
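For triage, here is a smaller sketch that should hit the same failure without running the full pipeline, assuming the same checkpoint layout (the `subfolder="vae"` name and the frame size are assumptions; any resolution above the tile size should force the tiled path):

```python
import torch
from diffusers import AutoencoderKLWan

# Load only the VAE from the same checkpoint (subfolder name assumed).
vae = AutoencoderKLWan.from_pretrained(
    "models/Wan2.2-TI2V-5B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
vae.enable_tiling()

# A single RGB frame larger than the tile size forces the tiled_encode path;
# the channel mismatch then surfaces in the encoder's first conv3d.
video = torch.randn(1, 3, 1, 800, 1280)
latents = vae.encode(video).latent_dist.sample()  # raises the RuntimeError reported in this PR
```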
Hello @yiyixuxu, could you review this fix or give more feedback?
yiyixuxu left a comment:
thanks!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
What does this PR do?
WanImageToVideoPipeline with vae.enable_tiling() fails with the error below. The cause is that the encode spatial compression ratio is computed without taking vae.config.patch_size into account. This PR fixes that; please consider merging it, thanks!
```
Traceback (most recent call last):
  File "/mnt/bn/lirui926-hl/mlx/users/lirui.926/playground/codes/dit-inference/wan22_ti2v.py", line 402, in <module>
    output_filename, perf_stats = wan.run(
                                  ~~~~~~~^
        task_name=task_name,
        ^^^^^^^^^^^^^^^^^^^^
        **case_param
        ^^^^^^^^^^^^
    )
    ^
  File "/mnt/bn/lirui926-hl/mlx/users/lirui.926/playground/codes/dit-inference/wan22_ti2v.py", line 120, in wrapper
    return func(self, *args, **kwargs)
  File "/mnt/bn/lirui926-hl/mlx/users/lirui.926/playground/codes/dit-inference/wan22_ti2v.py", line 283, in run
    output = self.tasks[task_name](...)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/utils/_contextlib.py", line 120, in decorate_context
    return func(*args, **kwargs)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 699, in __call__
    latents_outputs = self.prepare_latents(
        image,
        ...<9 lines>...
        last_image,
    )
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 455, in prepare_latents
    latent_condition = retrieve_latents(self.vae.encode(video_condition), sample_mode="argmax")
                                        ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1191, in encode
    h = self._encode(x)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1149, in _encode
    return self.tiled_encode(x)
           ~~~~~~~~~~~~~~~~~^^^
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 1313, in tiled_encode
    tile = self.encoder(tile, feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 593, in forward
    x = self.conv_in(x, feat_cache[idx])
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 176, in forward
    return super().forward(x)
           ~~~~~~~~~~~~~~~^^^
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/nn/modules/conv.py", line 717, in forward
    return self._conv_forward(input, self.weight, self.bias)
           ~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/root/miniconda3/envs/torch/lib/python3.13/site-packages/torch/nn/modules/conv.py", line 712, in _conv_forward
    return F.conv3d(
           ~~~~~~~~^
        input, weight, bias, self.stride, self.padding, self.dilation, self.groups
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
    )
    ^
RuntimeError: Given groups=1, weight of size [160, 12, 3, 3, 3], expected input[1, 3, 3, 662, 802] to have 12 channels, but got 3 channels instead
```
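For context on the shape mismatch: the Wan 2.2 VAE patchifies its input before the first convolution, so 3 RGB channels become 3 × patch_size² channels (evidently patch_size = 2 here, since 3 × 2² = 12 matches the conv_in weight [160, 12, 3, 3, 3]). The tiled path, however, slices tiles using a spatial compression ratio that ignores config.patch_size and feeds raw 3-channel tiles to the encoder. Below is a minimal sketch of the shape arithmetic only, not the PR's actual diff; the channel ordering inside the library's real patchify helper may differ.

```python
import torch

def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Fold each patch_size x patch_size spatial block into channels:
    # (B, C, T, H, W) -> (B, C * p * p, T, H // p, W // p).
    b, c, t, h, w = x.shape
    p = patch_size
    x = x.reshape(b, c, t, h // p, p, w // p, p)
    x = x.permute(0, 1, 4, 6, 2, 3, 5)
    return x.reshape(b, c * p * p, t, h // p, w // p)

# With patch_size = 2, an RGB tile gains the 12 channels conv_in expects,
# and the effective pixel-to-latent spatial ratio grows by the same factor.
tile = torch.randn(1, 3, 3, 662, 802)  # shape taken from the error message
print(patchify(tile, 2).shape)  # torch.Size([1, 12, 3, 331, 401])
```

This is why accounting for patch_size when computing the encode spatial compression ratio (and presumably patchifying each tile the same way the non-tiled path does) resolves the channel mismatch.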
Fixes # (issue)
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.