Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SVD #5895

Merged
merged 228 commits into from
Nov 29, 2023
Merged

Add SVD #5895

merged 228 commits into from
Nov 29, 2023

Conversation

patil-suraj
Copy link
Contributor

@patil-suraj patil-suraj commented Nov 22, 2023

What does this PR do?

Adds Stable Video Diffusion.

@HuggingFaceDocBuilderDev
Copy link

HuggingFaceDocBuilderDev commented Nov 22, 2023

The documentation is not available anymore as the PR was closed or merged.

@drhead
Copy link
Contributor

drhead commented Nov 22, 2023

Is this PR going to add support for the temporally-aware VAE? I am currently working on porting that module and don't want to end up creating any conflicts.

edit: can disregard, I can now see that after the model components implemented here are complete, implementation of the VAE decoder itself would be a trivial matter.

@patil-suraj
Copy link
Contributor Author

@drhead Yes, this PR will support everything related to SVD.

@tin2tin
Copy link

tin2tin commented Nov 24, 2023

@tin2tin tin2tin mentioned this pull request Nov 24, 2023
2 tasks
@patrickvonplaten patrickvonplaten merged commit 63f767e into main Nov 29, 2023
22 checks passed
@patil-suraj patil-suraj deleted the test-v branch November 29, 2023 18:14
```

<video width="1024" height="576" controls>
<source src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd/rocket_generated.mp4?download=true" type="video/mp4">
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove ?download=true

@liuquande
Copy link

Hi, would this PR consider adding model training/finetuning script for stable video diffusion, thanks!

@jeff-da
Copy link

jeff-da commented Dec 18, 2023

FPS should be set as a constant somewhere? I see both 7 and 8 used.

def export_to_video(

@shliu0
Copy link

shliu0 commented Dec 20, 2023

Hi, is this PR going to support LCM-LoRA like what have been done in SD image models?

yoonseokjin pushed a commit to yoonseokjin/diffusers that referenced this pull request Dec 25, 2023
* begin model

* finish blocks

* add_embedding

* addition_time_embed_dim

* use TimestepEmbedding

* fix temporal res block

* fix time_pos_embed

* fix add_embedding

* add conversion script

* fix model

* up

* add new resnet blocks

* make forward work

* return sample in original shape

* fix temb shape in TemporalResnetBlock

* add spatio temporal transformers

* add vae blocks

* fix blocks

* update

* update

* fix shapes in Alphablender and add time activation in res blcok

* use new blocks

* style

* fix temb shape

* fix SpatioTemporalResBlock

* reuse TemporalBasicTransformerBlock

* fix TemporalBasicTransformerBlock

* use TransformerSpatioTemporalModel

* fix TransformerSpatioTemporalModel

* fix time_context dim

* clean up

* make temb optional

* add blocks

* rename model

* update conversion script

* remove UNetMidBlockSpatioTemporal

* add in init

* remove unused arg

* remove unused arg

* remove more unsed args

* up

* up

* check for None

* update vae

* update up/mid blocks for decoder

* begin pipeline

* adapt scheduler

* add guidance scalings

* fix norm eps in temporal transformers

* add temporal autoencoder

* make pipeline run

* fix frame decodig

* decode in float32

* decode n frames at a time

* pass decoding_t to decode_latents

* fix decode_latents

* vae encode/decode in fp32

* fix dtype in TransformerSpatioTemporalModel

* type image_latents same as image_embeddings

* allow using differnt eps in temporal block for video decoder

* fix default values in vae

* pass num frames in decode

* switch spatial to temporal for mixing in VAE

* fix num frames during split decoding

* cast alpha to sample dtype

* fix attention in MidBlockTemporalDecoder

* fix typo

* fix guidance_scales dtype

* fix missing activation in TemporalDecoder

* skip_post_quant_conv

* add vae conversion

* style

* take guidance scale as input

* up

* allow passing PIL to export_video

* accept fps as arg

* add pipeline and vae in init

* remove hack

* use AutoencoderKLTemporalDecoder

* don't scale image latents

* add unet tests

* clean up unet

* clean TransformerSpatioTemporalModel

* add slow svd test

* clean up

* make temb optional in Decoder mid block

* fix norm eps in TransformerSpatioTemporalModel

* clean up temp decoder

* clean up

* clean up

* use c_noise values for timesteps

* use math for log

* update

* fix copies

* doc

* upcast vae

* update forward pass for gradient checkpointing

* make added_time_ids is tensor

* up

* fix upcasting

* remove post quant conv

* add _resize_with_antialiasing

* fix _compute_padding

* cleanup model

* more cleanup

* more cleanup

* more cleanup

* remove freeu

* remove attn slice

* small clean

* up

* up

* remove extra step kwargs

* remove eta

* remove dropout

* remove callback

* remove merge factor args

* clean

* clean up

* move to dedicated folder

* remove attention_head_dim

* docstr and small fix

* update unet doc strings

* rename decoding_t

* correct linting

* store c_skip and c_out

* cleanup

* clean TemporalResnetBlock

* more cleanup

* clean up vae

* clean up

* begin doc

* more cleanup

* up

* up

* doc

* Improve

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* Apply suggestions from code review

* Default chunk size to None

* add example

* Better

* Apply suggestions from code review

* update doc

* Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable_diffusion_video.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

* Get torch compile working

* up

* rename

* fix doc

* add chunking

* torch compile

* torch compile

* add modelling outputs

* torch compile

* Improve chunking

* Apply suggestions from code review

* Update docs/source/en/using-diffusers/svd.md

* Close diff tag

* remove slicing

* resnet docstr

* add docstr in resnet

* rename

* Apply suggestions from code review

* update tests

* Fix output type latents

* fix more

* fix more

* Update docs/source/en/using-diffusers/svd.md

* fix more

* add pipeline tests

* remove unused arg

* clean  up

* make sure get_scaling receives tensors

* fix euler scheduler

* fix get_scalings

* simply euler for now

* remove old test file

* use randn_tensor to create noise

* fix device for rand tensor

* increase expected_max_difference

* fix test_inference_batch_single_identical

* actually fix test_inference_batch_single_identical

* disable test_save_load_float16

* skip test_float16_inference

* skip test_inference_batch_single_identical

* fix test_xformers_attention_forwardGenerator_pass

* Apply suggestions from code review

* update StableVideoDiffusionPipelineSlowTests

* update image

* add diffusers example

* fix more

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
@Kaihua-Chen
Copy link

Thanks for supporting stable video diffusion! Should we consider this as the official implementation (e.g., was the performance verified with the original paper)?

AmericanPresidentJimmyCarter pushed a commit to AmericanPresidentJimmyCarter/diffusers that referenced this pull request Apr 26, 2024
* begin model

* finish blocks

* add_embedding

* addition_time_embed_dim

* use TimestepEmbedding

* fix temporal res block

* fix time_pos_embed

* fix add_embedding

* add conversion script

* fix model

* up

* add new resnet blocks

* make forward work

* return sample in original shape

* fix temb shape in TemporalResnetBlock

* add spatio temporal transformers

* add vae blocks

* fix blocks

* update

* update

* fix shapes in Alphablender and add time activation in res blcok

* use new blocks

* style

* fix temb shape

* fix SpatioTemporalResBlock

* reuse TemporalBasicTransformerBlock

* fix TemporalBasicTransformerBlock

* use TransformerSpatioTemporalModel

* fix TransformerSpatioTemporalModel

* fix time_context dim

* clean up

* make temb optional

* add blocks

* rename model

* update conversion script

* remove UNetMidBlockSpatioTemporal

* add in init

* remove unused arg

* remove unused arg

* remove more unsed args

* up

* up

* check for None

* update vae

* update up/mid blocks for decoder

* begin pipeline

* adapt scheduler

* add guidance scalings

* fix norm eps in temporal transformers

* add temporal autoencoder

* make pipeline run

* fix frame decodig

* decode in float32

* decode n frames at a time

* pass decoding_t to decode_latents

* fix decode_latents

* vae encode/decode in fp32

* fix dtype in TransformerSpatioTemporalModel

* type image_latents same as image_embeddings

* allow using differnt eps in temporal block for video decoder

* fix default values in vae

* pass num frames in decode

* switch spatial to temporal for mixing in VAE

* fix num frames during split decoding

* cast alpha to sample dtype

* fix attention in MidBlockTemporalDecoder

* fix typo

* fix guidance_scales dtype

* fix missing activation in TemporalDecoder

* skip_post_quant_conv

* add vae conversion

* style

* take guidance scale as input

* up

* allow passing PIL to export_video

* accept fps as arg

* add pipeline and vae in init

* remove hack

* use AutoencoderKLTemporalDecoder

* don't scale image latents

* add unet tests

* clean up unet

* clean TransformerSpatioTemporalModel

* add slow svd test

* clean up

* make temb optional in Decoder mid block

* fix norm eps in TransformerSpatioTemporalModel

* clean up temp decoder

* clean up

* clean up

* use c_noise values for timesteps

* use math for log

* update

* fix copies

* doc

* upcast vae

* update forward pass for gradient checkpointing

* make added_time_ids is tensor

* up

* fix upcasting

* remove post quant conv

* add _resize_with_antialiasing

* fix _compute_padding

* cleanup model

* more cleanup

* more cleanup

* more cleanup

* remove freeu

* remove attn slice

* small clean

* up

* up

* remove extra step kwargs

* remove eta

* remove dropout

* remove callback

* remove merge factor args

* clean

* clean up

* move to dedicated folder

* remove attention_head_dim

* docstr and small fix

* update unet doc strings

* rename decoding_t

* correct linting

* store c_skip and c_out

* cleanup

* clean TemporalResnetBlock

* more cleanup

* clean up vae

* clean up

* begin doc

* more cleanup

* up

* up

* doc

* Improve

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* better naming

* Apply suggestions from code review

* Default chunk size to None

* add example

* Better

* Apply suggestions from code review

* update doc

* Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable_diffusion_video.py

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* style

* Get torch compile working

* up

* rename

* fix doc

* add chunking

* torch compile

* torch compile

* add modelling outputs

* torch compile

* Improve chunking

* Apply suggestions from code review

* Update docs/source/en/using-diffusers/svd.md

* Close diff tag

* remove slicing

* resnet docstr

* add docstr in resnet

* rename

* Apply suggestions from code review

* update tests

* Fix output type latents

* fix more

* fix more

* Update docs/source/en/using-diffusers/svd.md

* fix more

* add pipeline tests

* remove unused arg

* clean  up

* make sure get_scaling receives tensors

* fix euler scheduler

* fix get_scalings

* simply euler for now

* remove old test file

* use randn_tensor to create noise

* fix device for rand tensor

* increase expected_max_difference

* fix test_inference_batch_single_identical

* actually fix test_inference_batch_single_identical

* disable test_save_load_float16

* skip test_float16_inference

* skip test_inference_batch_single_identical

* fix test_xformers_attention_forwardGenerator_pass

* Apply suggestions from code review

* update StableVideoDiffusionPipelineSlowTests

* update image

* add diffusers example

* fix more

---------

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: apolinário <joaopaulo.passos@gmail.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet