
Add SVD #5895

Merged · 228 commits · Nov 29, 2023

Changes from 142 commits

Commits (228)
2f56481
begin model
patil-suraj Nov 21, 2023
58883ee
finish blocks
patil-suraj Nov 22, 2023
7de5d7c
add_embedding
patil-suraj Nov 22, 2023
cad51d4
addition_time_embed_dim
patil-suraj Nov 22, 2023
45c9b56
use TimestepEmbedding
patil-suraj Nov 22, 2023
669824e
fix temporal res block
patil-suraj Nov 22, 2023
ee9d7b8
fix time_pos_embed
patil-suraj Nov 22, 2023
ac94731
fix add_embedding
patil-suraj Nov 22, 2023
5df09ef
add conversion script
patil-suraj Nov 22, 2023
c93606c
fix model
patil-suraj Nov 23, 2023
7b64d3a
up
patil-suraj Nov 23, 2023
edf7121
add new resnet blocks
DN6 Nov 23, 2023
1bd09b1
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 23, 2023
d4cdfa3
make forward work
patil-suraj Nov 23, 2023
165ed7c
return sample in original shape
patil-suraj Nov 23, 2023
28dee6e
fix temb shape in TemporalResnetBlock
patil-suraj Nov 23, 2023
85846f7
add spatio temporal transformers
DN6 Nov 23, 2023
8ee2807
add vae blocks
DN6 Nov 23, 2023
5218f46
fix blocks
DN6 Nov 23, 2023
47684da
update
DN6 Nov 24, 2023
9c9d467
update
DN6 Nov 24, 2023
6f87490
fix shapes in AlphaBlender and add time activation in res block
patil-suraj Nov 24, 2023
ffd9e26
use new blocks
patil-suraj Nov 24, 2023
c8ec445
style
patil-suraj Nov 24, 2023
678d19f
fix temb shape
patil-suraj Nov 24, 2023
b0fc4fd
fix SpatioTemporalResBlock
patil-suraj Nov 24, 2023
5a523e2
reuse TemporalBasicTransformerBlock
patil-suraj Nov 24, 2023
20efe54
fix TemporalBasicTransformerBlock
patil-suraj Nov 24, 2023
6610331
use TransformerSpatioTemporalModel
patil-suraj Nov 24, 2023
29551f8
fix TransformerSpatioTemporalModel
patil-suraj Nov 24, 2023
af1e86a
fix time_context dim
patil-suraj Nov 24, 2023
9117547
clean up
DN6 Nov 24, 2023
8c3fd58
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 24, 2023
6481e94
make temb optional
DN6 Nov 24, 2023
6c69c7a
add blocks
patil-suraj Nov 24, 2023
8e1851a
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 24, 2023
f976f5a
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 24, 2023
1f34311
rename model
patil-suraj Nov 24, 2023
f1457b7
update conversion script
patil-suraj Nov 24, 2023
576fa1c
remove UNetMidBlockSpatioTemporal
patil-suraj Nov 24, 2023
f9def2a
add in init
patil-suraj Nov 24, 2023
6c28367
remove unused arg
patil-suraj Nov 24, 2023
d8c9e67
remove unused arg
patil-suraj Nov 24, 2023
9f22651
remove more unused args
patil-suraj Nov 24, 2023
dff26ce
up
patil-suraj Nov 24, 2023
0c4192b
up
patil-suraj Nov 24, 2023
24b5c43
check for None
patil-suraj Nov 24, 2023
e684243
update vae
DN6 Nov 24, 2023
05eaec2
Merge branch 'test-v-old' into test-v
DN6 Nov 24, 2023
eefed8a
update up/mid blocks for decoder
DN6 Nov 24, 2023
37c428a
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 24, 2023
122a6bd
begin pipeline
patil-suraj Nov 24, 2023
3e47d3c
adapt scheduler
patil-suraj Nov 24, 2023
b336529
add guidance scalings
patil-suraj Nov 24, 2023
2f35e8c
fix norm eps in temporal transformers
patil-suraj Nov 24, 2023
132fe97
add temporal autoencoder
DN6 Nov 24, 2023
beaaf18
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 24, 2023
efb1e5e
make pipeline run
patil-suraj Nov 24, 2023
e779833
fix frame decoding
patil-suraj Nov 25, 2023
f9954a0
decode in float32
patil-suraj Nov 25, 2023
4d4469e
decode n frames at a time
patil-suraj Nov 25, 2023
9da55b3
pass decoding_t to decode_latents
patil-suraj Nov 25, 2023
4346ddd
fix decode_latents
patil-suraj Nov 25, 2023
7ddd14b
vae encode/decode in fp32
patil-suraj Nov 25, 2023
df98627
fix dtype in TransformerSpatioTemporalModel
patil-suraj Nov 25, 2023
0cf6c6b
type image_latents same as image_embeddings
patil-suraj Nov 25, 2023
d0017d9
allow using different eps in temporal block for video decoder
patil-suraj Nov 25, 2023
9af07d1
fix default values in vae
patil-suraj Nov 25, 2023
5316fb5
pass num frames in decode
patil-suraj Nov 25, 2023
b071aaa
switch spatial to temporal for mixing in VAE
patil-suraj Nov 26, 2023
8bcf43d
fix num frames during split decoding
patil-suraj Nov 26, 2023
268ffea
cast alpha to sample dtype
patil-suraj Nov 26, 2023
d930977
fix attention in MidBlockTemporalDecoder
patil-suraj Nov 26, 2023
21148de
fix typo
patil-suraj Nov 26, 2023
712b995
fix guidance_scales dtype
patil-suraj Nov 26, 2023
cf70b9a
fix missing activation in TemporalDecoder
patil-suraj Nov 26, 2023
c3bdeb8
skip_post_quant_conv
patil-suraj Nov 26, 2023
6827a1d
add vae conversion
patil-suraj Nov 26, 2023
96af28f
style
patil-suraj Nov 26, 2023
e34e9d9
take guidance scale as input
patil-suraj Nov 26, 2023
2a46326
up
patil-suraj Nov 26, 2023
fdd182f
allow passing PIL to export_video
patil-suraj Nov 26, 2023
1ce8ff5
accept fps as arg
patil-suraj Nov 26, 2023
cb49cbd
add pipeline and vae in init
patil-suraj Nov 26, 2023
13b646e
remove hack
patil-suraj Nov 26, 2023
d614a33
use AutoencoderKLTemporalDecoder
patil-suraj Nov 26, 2023
f651c12
don't scale image latents
patil-suraj Nov 26, 2023
760333d
add unet tests
DN6 Nov 27, 2023
af85fb1
clean up unet
patil-suraj Nov 27, 2023
6adae54
clean TransformerSpatioTemporalModel
patil-suraj Nov 27, 2023
7b6a0d4
add slow svd test
DN6 Nov 27, 2023
ab8076f
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 27, 2023
f7cf8c3
clean up
DN6 Nov 27, 2023
3fbe123
make temb optional in Decoder mid block
DN6 Nov 27, 2023
b8d84c4
fix norm eps in TransformerSpatioTemporalModel
patil-suraj Nov 27, 2023
1b3cf2d
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 27, 2023
403a81c
clean up temp decoder
DN6 Nov 27, 2023
26ed460
clean up
DN6 Nov 27, 2023
82cf608
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 27, 2023
c9d1727
clean up
DN6 Nov 27, 2023
a193e49
use c_noise values for timesteps
patil-suraj Nov 27, 2023
804bdeb
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 27, 2023
a08ef00
use math for log
patil-suraj Nov 27, 2023
3178b16
update
DN6 Nov 27, 2023
847bd0a
fix copies
patil-suraj Nov 27, 2023
18930e0
doc
patil-suraj Nov 27, 2023
90d8e83
upcast vae
patil-suraj Nov 27, 2023
8620851
update forward pass for gradient checkpointing
DN6 Nov 27, 2023
ee9f7d2
make added_time_ids a tensor
patil-suraj Nov 27, 2023
c452d9c
up
patil-suraj Nov 27, 2023
55b4d09
fix upcasting
patil-suraj Nov 27, 2023
8bc4251
remove post quant conv
DN6 Nov 27, 2023
c5941a2
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 27, 2023
56e8fca
Merge branch 'main' into test-v
DN6 Nov 27, 2023
63335d2
add _resize_with_antialiasing
patil-suraj Nov 27, 2023
479f58c
fix _compute_padding
patil-suraj Nov 27, 2023
da3e46b
cleanup model
patil-suraj Nov 27, 2023
04171ee
more cleanup
patil-suraj Nov 27, 2023
ad213ee
more cleanup
patil-suraj Nov 27, 2023
18cb6d5
more cleanup
patil-suraj Nov 27, 2023
25cfe79
remove freeu
patil-suraj Nov 27, 2023
58f1d61
remove attn slice
patil-suraj Nov 27, 2023
924813a
small clean
patil-suraj Nov 27, 2023
ddad380
up
patil-suraj Nov 27, 2023
0567fd0
up
patil-suraj Nov 27, 2023
05c631e
remove extra step kwargs
patil-suraj Nov 27, 2023
ac00e32
remove eta
patil-suraj Nov 27, 2023
200314d
remove dropout
patil-suraj Nov 27, 2023
782205e
remove callback
patil-suraj Nov 27, 2023
b095e2e
remove merge factor args
patil-suraj Nov 27, 2023
e60b2fe
clean
patil-suraj Nov 27, 2023
3d03e44
clean up
DN6 Nov 28, 2023
2613335
move to dedicated folder
DN6 Nov 28, 2023
0e64d43
remove attention_head_dim
patil-suraj Nov 28, 2023
8e33cb3
docstr and small fix
patil-suraj Nov 28, 2023
2dc556c
update unet doc strings
patil-suraj Nov 28, 2023
e3404fa
rename decoding_t
patil-suraj Nov 28, 2023
73386b4
correct linting
patrickvonplaten Nov 28, 2023
be346ac
store c_skip and c_out
patil-suraj Nov 28, 2023
b74e587
cleanup
patil-suraj Nov 28, 2023
b5e6097
clean TemporalResnetBlock
patil-suraj Nov 28, 2023
783f18d
more cleanup
patil-suraj Nov 28, 2023
51aa79a
clean up vae
DN6 Nov 28, 2023
c5fc4f0
clean up
DN6 Nov 28, 2023
e10e159
begin doc
patil-suraj Nov 28, 2023
ad50592
more cleanup
patil-suraj Nov 28, 2023
a5c7782
up
patil-suraj Nov 28, 2023
a4ba8ef
up
patil-suraj Nov 28, 2023
169ae20
doc
patil-suraj Nov 28, 2023
c2d83f0
Improve
patrickvonplaten Nov 28, 2023
dda9337
better naming
patrickvonplaten Nov 28, 2023
d7a71ed
better naming
patrickvonplaten Nov 28, 2023
550b73f
better naming
patrickvonplaten Nov 28, 2023
9cbe7d6
better naming
patrickvonplaten Nov 28, 2023
1a1067a
better naming
patrickvonplaten Nov 28, 2023
532b861
better naming
patrickvonplaten Nov 28, 2023
878e3ea
better naming
patrickvonplaten Nov 28, 2023
29e57f4
better naming
patrickvonplaten Nov 28, 2023
036c04f
Apply suggestions from code review
patrickvonplaten Nov 28, 2023
889b9e9
Default chunk size to None
patrickvonplaten Nov 28, 2023
eb30dde
add example
patil-suraj Nov 28, 2023
fbe0936
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 28, 2023
724a134
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 28, 2023
aed458f
Better
patrickvonplaten Nov 28, 2023
4ca4b33
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patrickvonplaten Nov 28, 2023
e732921
Apply suggestions from code review
patrickvonplaten Nov 28, 2023
994bf57
update doc
patil-suraj Nov 28, 2023
ffc2a1e
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 28, 2023
ad87aa4
Update src/diffusers/pipelines/stable_diffusion_video/pipeline_stable…
patil-suraj Nov 28, 2023
4e60bb7
style
patil-suraj Nov 28, 2023
f37b782
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 28, 2023
36df75c
Get torch compile working
patrickvonplaten Nov 28, 2023
f107be7
up
patil-suraj Nov 28, 2023
dbc2d2d
rename
patil-suraj Nov 28, 2023
b69e753
fix doc
patil-suraj Nov 28, 2023
57f11d6
add chunking
DN6 Nov 28, 2023
9bce8bb
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 28, 2023
43b63d6
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patrickvonplaten Nov 28, 2023
d27999e
torch compile
patrickvonplaten Nov 28, 2023
e17dda8
torch compile
patrickvonplaten Nov 28, 2023
381ea56
add modelling outputs
DN6 Nov 28, 2023
0df06dd
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 28, 2023
79fbd84
torch compile
patrickvonplaten Nov 28, 2023
4601fc1
Improve chunking
patrickvonplaten Nov 28, 2023
562d9d0
Apply suggestions from code review
patrickvonplaten Nov 28, 2023
2d513f7
Update docs/source/en/using-diffusers/svd.md
patrickvonplaten Nov 28, 2023
5f3a2b8
Close diff tag
apolinario Nov 28, 2023
6aba6e5
remove slicing
patil-suraj Nov 28, 2023
efd0a72
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 28, 2023
9162734
resnet docstr
patil-suraj Nov 28, 2023
a7342a1
add docstr in resnet
patil-suraj Nov 28, 2023
d409239
rename
patil-suraj Nov 28, 2023
52ab94b
Apply suggestions from code review
patrickvonplaten Nov 28, 2023
eac5399
update tests
DN6 Nov 28, 2023
9fa5d12
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 28, 2023
ecc7882
Fix output type latents
patrickvonplaten Nov 28, 2023
5143e01
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patrickvonplaten Nov 28, 2023
58814b0
fix more
patrickvonplaten Nov 28, 2023
21e627f
fix more
patrickvonplaten Nov 28, 2023
8510c7e
Update docs/source/en/using-diffusers/svd.md
patrickvonplaten Nov 28, 2023
557f638
fix more
patrickvonplaten Nov 28, 2023
deee57e
add pipeline tests
DN6 Nov 28, 2023
b33e42e
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 28, 2023
3b76055
remove unused arg
patil-suraj Nov 29, 2023
5f278af
clean up
DN6 Nov 29, 2023
9320cb7
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
DN6 Nov 29, 2023
d73fa34
make sure get_scaling receives tensors
patil-suraj Nov 29, 2023
7e42e28
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 29, 2023
5857cbc
fix euler scheduler
patil-suraj Nov 29, 2023
206f457
fix get_scalings
patil-suraj Nov 29, 2023
877e8bd
simplify euler for now
patil-suraj Nov 29, 2023
5619c72
remove old test file
patil-suraj Nov 29, 2023
c888b98
use randn_tensor to create noise
patil-suraj Nov 29, 2023
109971b
fix device for rand tensor
patil-suraj Nov 29, 2023
f1be9ce
increase expected_max_difference
patil-suraj Nov 29, 2023
4e75f06
fix test_inference_batch_single_identical
patil-suraj Nov 29, 2023
46b129b
actually fix test_inference_batch_single_identical
patil-suraj Nov 29, 2023
367426e
disable test_save_load_float16
patil-suraj Nov 29, 2023
d0895b1
skip test_float16_inference
patil-suraj Nov 29, 2023
614f9ad
skip test_inference_batch_single_identical
patil-suraj Nov 29, 2023
60625db
fix test_xformers_attention_forwardGenerator_pass
patil-suraj Nov 29, 2023
8fc51ab
Apply suggestions from code review
patrickvonplaten Nov 29, 2023
fcf0790
update StableVideoDiffusionPipelineSlowTests
patil-suraj Nov 29, 2023
66ded24
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patil-suraj Nov 29, 2023
9962f91
update image
patil-suraj Nov 29, 2023
fbb131c
add diffusers example
patrickvonplaten Nov 29, 2023
896485a
Merge branch 'test-v' of https://github.com/huggingface/diffusers int…
patrickvonplaten Nov 29, 2023
4c04ca2
fix more
patrickvonplaten Nov 29, 2023
730 changes: 730 additions & 0 deletions scripts/convert_svd_to_diffusers.py

Large diffs are not rendered by default.
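
The conversion script is collapsed in this view. As a rough illustration of what such converters do, the sketch below shows the typical key-remapping pattern: load the original checkpoint's state dict and rename its keys to the diffusers module layout. Every prefix here is an illustrative assumption, not the actual mapping in `scripts/convert_svd_to_diffusers.py`.

```python
# Minimal sketch of the key-remapping pattern common to checkpoint
# conversion scripts. All key prefixes below are illustrative
# assumptions, NOT the real mapping in convert_svd_to_diffusers.py.

# Hypothetical prefix renames: original SVD layout -> diffusers layout
KEY_RENAMES = {
    "model.diffusion_model.": "unet.",             # assumed prefix
    "first_stage_model.": "vae.",                  # assumed prefix
    "conditioner.embedders.0.": "image_encoder.",  # assumed prefix
}

def convert_state_dict(original: dict) -> dict:
    """Rename keys from the original layout to the diffusers layout."""
    converted = {}
    for key, tensor in original.items():
        new_key = key
        for old_prefix, new_prefix in KEY_RENAMES.items():
            if key.startswith(old_prefix):
                new_key = new_prefix + key[len(old_prefix):]
                break
        converted[new_key] = tensor
    return converted
```

Real converters often also reshape attention weights and split fused projections per layer; the sketch covers only the renaming step.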

6 changes: 6 additions & 0 deletions src/diffusers/__init__.py
@@ -76,6 +76,7 @@
[
"AsymmetricAutoencoderKL",
"AutoencoderKL",
"AutoencoderKLTemporalDecoder",
"AutoencoderTiny",
"ConsistencyDecoderVAE",
"ControlNetModel",
@@ -92,6 +93,7 @@
"UNet2DModel",
"UNet3DConditionModel",
"UNetMotionModel",
"UNetSpatioTemporalConditionModel",
"VQModel",
]
)
@@ -267,6 +269,7 @@
"StableDiffusionPix2PixZeroPipeline",
"StableDiffusionSAGPipeline",
"StableDiffusionUpscalePipeline",
"StableDiffusionVideoPipeline",
"StableDiffusionXLAdapterPipeline",
"StableDiffusionXLControlNetImg2ImgPipeline",
"StableDiffusionXLControlNetInpaintPipeline",
@@ -446,6 +449,7 @@
from .models import (
AsymmetricAutoencoderKL,
AutoencoderKL,
AutoencoderKLTemporalDecoder,
AutoencoderTiny,
ConsistencyDecoderVAE,
ControlNetModel,
@@ -462,6 +466,7 @@
UNet2DModel,
UNet3DConditionModel,
UNetMotionModel,
UNetSpatioTemporalConditionModel,
VQModel,
)
from .optimization import (
@@ -616,6 +621,7 @@
StableDiffusionPix2PixZeroPipeline,
StableDiffusionSAGPipeline,
StableDiffusionUpscalePipeline,
StableDiffusionVideoPipeline,
StableDiffusionXLAdapterPipeline,
StableDiffusionXLControlNetImg2ImgPipeline,
StableDiffusionXLControlNetInpaintPipeline,
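
With these exports, the new model and pipeline classes become importable from the top-level `diffusers` package. The sketch below is a hedged usage example assuming an image-to-video call signature; the checkpoint id and the keyword arguments are assumptions inferred from the commit messages above ("decode n frames at a time", "accept fps as arg"), not API confirmed at this commit.

```python
# Hedged usage sketch for the pipeline exported above. The checkpoint id
# and keyword arguments are assumptions, not confirmed API at this commit.
import torch
from diffusers import StableDiffusionVideoPipeline
from diffusers.utils import export_to_video, load_image

pipe = StableDiffusionVideoPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid",  # assumed checkpoint id
    torch_dtype=torch.float16,
)
pipe.to("cuda")

image = load_image("rocket.png")  # conditioning image for image-to-video
# num_frames and decode_chunk_size are assumed names ("decode n frames at a time")
frames = pipe(image, num_frames=14, decode_chunk_size=8).frames
export_to_video(frames, "generated.mp4", fps=7)  # fps per "accept fps as arg"
```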
11 changes: 10 additions & 1 deletion src/diffusers/models/__init__.py
@@ -14,7 +14,12 @@

from typing import TYPE_CHECKING

from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
from ..utils import (
DIFFUSERS_SLOW_IMPORT,
_LazyModule,
is_flax_available,
is_torch_available,
)


_import_structure = {}
@@ -23,6 +28,7 @@
_import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
_import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
_import_structure["autoencoder_kl"] = ["AutoencoderKL"]
_import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
_import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
_import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
_import_structure["controlnet"] = ["ControlNetModel"]
@@ -38,6 +44,7 @@
_import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
_import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
_import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
_import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
_import_structure["vq_model"] = ["VQModel"]

if is_flax_available():
@@ -51,6 +58,7 @@
from .adapter import MultiAdapter, T2IAdapter
from .autoencoder_asym_kl import AsymmetricAutoencoderKL
from .autoencoder_kl import AutoencoderKL
from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
from .autoencoder_tiny import AutoencoderTiny
from .consistency_decoder_vae import ConsistencyDecoderVAE
from .controlnet import ControlNetModel
@@ -66,6 +74,7 @@
from .unet_3d_condition import UNet3DConditionModel
from .unet_kandi3 import Kandinsky3UNet
from .unet_motion_model import MotionAdapter, UNetMotionModel
from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
from .vq_model import VQModel

if is_flax_available():
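
These registrations follow the lazy-import convention used throughout diffusers: class names are listed in `_import_structure`, and `_LazyModule` defers the actual (torch-heavy) submodule import until an attribute is first accessed. A simplified sketch of that pattern — not the real `_LazyModule` implementation, which lives in `diffusers.utils`:

```python
# Simplified sketch of the lazy-import pattern used by diffusers'
# __init__ files. The real implementation is diffusers.utils._LazyModule;
# this is a stripped-down illustration, not the actual class.
import importlib

class LazyModule:
    def __init__(self, name: str, import_structure: dict):
        self._name = name
        # map exported attribute -> submodule that defines it
        self._attr_to_module = {
            attr: module for module, attrs in import_structure.items() for attr in attrs
        }

    def __getattr__(self, attr: str):
        module_name = self._attr_to_module.get(attr)
        if module_name is None:
            raise AttributeError(f"module {self._name!r} has no attribute {attr!r}")
        # the submodule is imported only on first access
        module = importlib.import_module(f"{self._name}.{module_name}")
        return getattr(module, attr)

# e.g. import_structure = {"autoencoder_kl_temporal_decoder": ["AutoencoderKLTemporalDecoder"]}
```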
142 changes: 141 additions & 1 deletion src/diffusers/models/attention.py
@@ -194,7 +194,12 @@ def __init__(
if not self.use_ada_layer_norm_single:
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
self.ff = FeedForward(
dim,
dropout=dropout,
activation_fn=activation_fn,
final_dropout=final_dropout,
)

# 4. Fuser
if attention_type == "gated" or attention_type == "gated-text-image":
@@ -339,6 +344,141 @@ def forward(
return hidden_states


@maybe_allow_in_graph
class TemporalBasicTransformerBlock(nn.Module):
r"""
A basic Transformer block for video-like data.

Parameters:
dim (`int`): The number of channels in the input and output.
time_mix_inner_dim (`int`): The number of channels for temporal attention.
num_attention_heads (`int`): The number of heads to use for multi-head attention.
attention_head_dim (`int`): The number of channels in each head.
cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
"""

def __init__(
self,
dim: int,
time_mix_inner_dim: int,
num_attention_heads: int,
attention_head_dim: int,
cross_attention_dim: Optional[int] = None,
):
super().__init__()
self.is_res = dim == time_mix_inner_dim

# Define 3 blocks. Each block has its own normalization layer.
# 1. Self-Attn
self.norm_in = nn.LayerNorm(dim)
self.ff_in = FeedForward(
dim,
dim_out=time_mix_inner_dim,
activation_fn="geglu",
)

self.norm1 = nn.LayerNorm(time_mix_inner_dim)
self.attn1 = Attention(
query_dim=time_mix_inner_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
cross_attention_dim=None,
)

# 2. Cross-Attn
if cross_attention_dim is not None:
# We currently only use AdaLayerNormZero for self-attention, where there will only be one attention block,
# i.e. the number of modulation chunks returned from AdaLayerNormZero would not make sense if returned
# during the second cross-attention block.
self.norm2 = nn.LayerNorm(time_mix_inner_dim)
self.attn2 = Attention(
query_dim=time_mix_inner_dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
) # is self-attn if encoder_hidden_states is None
else:
self.norm2 = None
self.attn2 = None

# 3. Feed-forward
self.norm3 = nn.LayerNorm(time_mix_inner_dim)
self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")

# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0

def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
# Sets chunk feed-forward
self._chunk_size = chunk_size
self._chunk_dim = dim

def forward(
self,
hidden_states: torch.FloatTensor,
num_frames: int,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
) -> torch.FloatTensor:
# Notice that normalization is always applied before the real computation in the following blocks.
# 0. Self-Attention
batch_frames, seq_length, channels = hidden_states.shape
batch_size = batch_frames // num_frames

hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)

residual = hidden_states
hidden_states = self.norm_in(hidden_states)
hidden_states = self.ff_in(hidden_states)
if self.is_res:
hidden_states = hidden_states + residual

norm_hidden_states = self.norm1(hidden_states)
attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
hidden_states = attn_output + hidden_states

# 3. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = self.norm2(hidden_states)
attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
hidden_states = attn_output + hidden_states

# 4. Feed-forward
norm_hidden_states = self.norm3(hidden_states)

if self._chunk_size is not None:
# "feed_forward_chunk_size" can be used to save memory
if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
raise ValueError(
f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
)

num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
ff_output = torch.cat(
[self.ff(hid_slice) for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)],
dim=self._chunk_dim,
)
else:
ff_output = self.ff(norm_hidden_states)

if self.is_res:
hidden_states = ff_output + hidden_states
else:
hidden_states = ff_output

hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
hidden_states = hidden_states.permute(0, 2, 1, 3)
hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)

return hidden_states


class FeedForward(nn.Module):
r"""
A feed-forward layer.
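
The key bookkeeping in `TemporalBasicTransformerBlock.forward` is the reshape round-trip: frames are unfolded from the batch and spatial positions folded into it, so self-attention mixes information across frames at each spatial location, and the inverse reshape restores the original layout afterwards. A standalone sketch with toy dimensions makes the shape accounting concrete:

```python
# Standalone sketch of the frame/sequence reshape round-trip performed in
# TemporalBasicTransformerBlock.forward. Dimensions are toy values.
import torch

batch_size, num_frames, seq_length, channels = 2, 4, 16, 8
hidden_states = torch.randn(batch_size * num_frames, seq_length, channels)

# fold frames out of the batch, then fold spatial positions into it,
# so attention over dim=1 now runs across frames per spatial token
x = hidden_states.reshape(batch_size, num_frames, seq_length, channels)
x = x.permute(0, 2, 1, 3)  # (B, S, F, C)
x = x.reshape(batch_size * seq_length, num_frames, channels)

# ... temporal attention over dim=1 (num_frames) would run here ...

# invert the reshape to recover the original (B*F, S, C) layout
x = x.reshape(batch_size, seq_length, num_frames, channels)
x = x.permute(0, 2, 1, 3)  # (B, F, S, C)
x = x.reshape(batch_size * num_frames, seq_length, channels)

assert torch.equal(x, hidden_states)  # the round-trip is exact
```

The same block's `set_chunk_feed_forward` hook then lets the feed-forward layer run on slices of this tensor along a chosen dimension, trading a little speed for lower peak memory.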