-
Notifications
You must be signed in to change notification settings - Fork 6.3k
[Pipeline] Fix error of SVD pipeline when num_videos_per_prompt > 1 #7786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Hi, Thanks for your PR :) Possible to add a test for this as well? @DN6 seeing a couple of issues/PRs related to using SVD for generating more than one videos per prompt. Should we consider adding a SLOW test for it as well? |
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
Sure, I will try to add a test for this. |
Hello , @sayakpaul I am not familiar with pytest, so I simply followed the instructions here. $ pytest tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py
============================================================== test session starts ===============================================================
platform linux -- Python 3.10.14, pytest-8.2.0, pluggy-1.5.0
rootdir: /home/wu.yushu/src/efficient_diffusion/fix_diffusers/diffusers
configfile: pyproject.toml
plugins: timeout-2.3.1, xdist-3.5.0, requests-mock-1.10.0
collected 26 items
tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py .s.......sss............ss [100%]
================================================================ warnings summary ================================================================
diffuser_dev/lib/python3.10/site-packages/flax/core/meta.py:31
diffuser_dev/lib/python3.10/site-packages/flax/core/meta.py:31: DeprecationWarning: jax.experimental.maps and jax.experimental.maps.xmap are deprecated and will be removed in a future release. Use jax.experimental.shard_map or jax.vmap with the spmd_axis_name argument for expressing SPMD device-parallel computations. Please file an issue on https://github.com/google/jax/issues if neither jax.experimental.shard_map nor jax.vmap are suitable for your use case.
from jax.experimental import maps
tests/pipelines/stable_video_diffusion/test_stable_video_diffusion.py::StableVideoDiffusionPipelineFastTests::test_save_load_float16
diffuser_dev/lib/python3.10/site-packages/torch/nn/modules/conv.py:605: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
return F.conv3d(
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=================================================== 20 passed, 6 skipped, 2 warnings in 22.83s =================================================== I am not sure if it is what you expected. Thanks. |
@DN6 could you give this a look too? |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for catching this! LGTM 👍🏽
…7786) swap the order for do_classifier_free_guidance concat with repeat Co-authored-by: Sayak Paul <spsayakpaul@gmail.com> Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
What does this PR do?
Fixes #Issue
it should repeat latents before concatenate do_classifier_free_guidance latents to align with the CFG process after noise_pred.
Before submitting
documentation guidelines, and
here are tips on formatting docstrings.
Who can review?
@sayakpaul @yiyixuxu @DN6
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.