
Commit da48940

update transformer
1 parent 64275b0 commit da48940

File tree

src/diffusers/models/attention_processor.py
src/diffusers/models/embeddings.py
src/diffusers/models/transformers/transformer_mochi.py

3 files changed: +125 -182 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 138 deletions
@@ -717,144 +717,6 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse


-class AsymmetricAttention(nn.Module):
-    def __init__(
-        self,
-        query_dim: int,
-        query_context_dim: int,
-        num_attentions_heads: int = 8,
-        attention_head_dim: int = 64,
-        bias: bool = False,
-        context_bias: bool = False,
-        out_dim: Optional[int] = None,
-        out_context_dim: Optional[int] = None,
-        qk_norm: Optional[str] = None,
-        eps: float = 1e-5,
-        elementwise_affine: bool = True,
-        out_bias: bool = True,
-        processor: Optional["AttnProcessor"] = None,
-    ) -> None:
-        from .normalization import RMSNorm
-
-        self.query_dim = query_dim
-        self.query_context_dim = query_context_dim
-        self.inner_dim = out_dim if out_dim is not None else num_attentions_heads * attention_head_dim
-        self.out_dim = out_dim if out_dim is not None else query_dim
-        self.out_context_dim = out_context_dim if out_context_dim is not None else query_context_dim
-
-        self.scale = attention_head_dim ** -0.5
-        self.num_attention_heads = out_dim // attention_head_dim if out_dim is not None else num_attentions_heads
-
-        if qk_norm is None:
-            self.norm_q = None
-            self.norm_k = None
-            self.norm_context_q = None
-            self.norm_context_k = None
-        elif qk_norm == "rms_norm":
-            self.norm_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_context_q = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-            self.norm_context_k = RMSNorm(attention_head_dim, eps=eps, elementwise_affine=elementwise_affine)
-        else:
-            raise ValueError((f"Unknown qk_norm: {qk_norm}. Should be None or `rms_norm`."))
-
-        self.to_q = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
-        self.to_k = nn.Linear(query_dim, self.inner_dim, bias=bias)
-
-        self.to_context_q = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-        self.to_context_k = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-        self.to_context_v = nn.Linear(query_context_dim, self.inner_dim, bias=context_bias)
-
-        # TODO(aryan): Take care of dropouts for training purpose in future
-        self.to_out = nn.ModuleList([
-            nn.Linear(self.inner_dim, self.out_dim)
-        ])
-        self.to_out = nn.ModuleList([
-            nn.Linear(self.inner_dim, self.out_context_dim)
-        ])
-
-        if processor is None:
-            processor = AsymmetricAttnProcessor2_0()
-
-        self.set_processor(processor)
-
-
-# Similar to SD3
-# class AsymmetricAttnProcessor2_0:
-#     r"""
-#     Processor for implementing Asymmetric SDPA as described in Genmo/Mochi (TODO(aryan) add link).
-#     """
-
-#     def __init__(self):
-#         if not hasattr(F, "scaled_dot_product_attention"):
-#             raise ImportError("AsymmetricAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-#     def __call__(
-#         self,
-#         attn: AsymmetricAttention,
-#         hidden_states: torch.Tensor,
-#         encoder_hidden_states: torch.Tensor,
-#         temb: torch.Tensor,
-#         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-#     ) -> torch.Tensor:
-#         batch_size = hidden_states.size(0)
-
-#         query = attn.to_q(hidden_states)
-#         key = attn.to_k(hidden_states)
-#         value = attn.to_v(hidden_states)
-
-#         query_context = attn.to_context_q(encoder_hidden_states)
-#         key_context = attn.to_context_k(encoder_hidden_states)
-#         value_context = attn.to_context_v(encoder_hidden_states)
-
-#         inner_dim = key.shape[-1]
-#         head_dim = inner_dim / attn.num_attention_heads
-
-#         query = query.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-#         key = key.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-#         value = value.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-
-#         query_context = query_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-#         key_context = key_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-#         value_context = value_context.unflatten(2, (attn.num_attention_heads, head_dim)).transpose(1, 2)
-
-#         if attn.norm_q is not None:
-#             query = attn.norm_q(query)
-#         if attn.norm_k is not None:
-#             key = attn.norm_k(key)
-
-#         if attn.norm_context_q is not None:
-#             query_context = attn.norm_context_q(query_context)
-#             key_context = attn.norm_context_k(key_context)
-
-#         if image_rotary_emb is not None:
-#             from .embeddings import apply_rotary_emb
-
-#             query = apply_rotary_emb(query, image_rotary_emb)
-#             key = apply_rotary_emb(key, image_rotary_emb)
-
-#         sequence_length = query.size(1)
-#         context_sequence_length = query_context.size(1)
-#         query = torch.cat([query, query_context], dim=1)
-#         key = torch.cat([key, key_context], dim=1)
-#         value = torch.cat([value, value_context], dim=1)
-
-#         hidden_states = F.scaled_dot_product_attention(
-#             query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False
-#         )
-
-#         hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
-#         hidden_states = hidden_states.to(query.dtype)
-
-#         hidden_states, encoder_hidden_states = hidden_states.split_with_sizes([sequence_length, context_sequence_length], dim=1)
-
-#         hidden_states = attn.to_out[0](hidden_states)
-#         encoder_hidden_states = attn.to_context_out[0](encoder_hidden_states)
-
-#         return hidden_states, encoder_hidden_states
-
-
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
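
Note: the commented-out AsymmetricAttnProcessor2_0 removed above follows the same joint-attention pattern that transformer_mochi.py now delegates to JointAttnProcessor2_0: project the image and text streams separately, run a single scaled_dot_product_attention call over the concatenated sequence, then split the result back into the two streams. A minimal standalone sketch of that pattern (the function name and shapes below are illustrative and not part of this diff):

import torch
import torch.nn.functional as F


def joint_sdpa(query, key, value, query_ctx, key_ctx, value_ctx):
    # query/key/value:             (batch, heads, image_seq_len, head_dim)
    # query_ctx/key_ctx/value_ctx: (batch, heads, text_seq_len, head_dim)
    image_seq_len = query.size(2)

    # Concatenate image and text tokens along the sequence axis so a single
    # attention call lets the two streams attend to each other.
    q = torch.cat([query, query_ctx], dim=2)
    k = torch.cat([key, key_ctx], dim=2)
    v = torch.cat([value, value_ctx], dim=2)

    out = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

    # Split the joint output back into the image and text streams.
    return out[:, :, :image_seq_len], out[:, :, image_seq_len:]


image_qkv = [torch.randn(2, 8, 16, 64) for _ in range(3)]  # 16 image tokens
text_qkv = [torch.randn(2, 8, 4, 64) for _ in range(3)]    # 4 text tokens
image_out, text_out = joint_sdpa(*image_qkv, *text_qkv)
print(image_out.shape, text_out.shape)  # (2, 8, 16, 64) (2, 8, 4, 64)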

src/diffusers/models/embeddings.py

Lines changed: 22 additions & 0 deletions
@@ -1302,6 +1302,28 @@ def forward(self, timestep, caption_feat, caption_mask):
         return conditioning


+class MochiCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8) -> None:
+        super().__init__()
+
+        self.time_proj = Timesteps(
+            num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0
+        )
+        self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
+        self.pooler = MochiAttentionPool(num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim)
+        self.caption_proj = nn.Linear(embedding_dim, pooled_projection_dim)
+
+    def forward(self, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, hidden_dtype: Optional[torch.dtype] = None):
+        time_proj = self.time_proj(timestep)
+        time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
+
+        pooled_projections = self.pooler(encoder_hidden_states, encoder_attention_mask)
+        caption_proj = self.caption_proj(encoder_hidden_states)
+
+        conditioning = time_emb + pooled_projections
+        return conditioning, caption_proj
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
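
Note: the new MochiCombinedTimestepCaptionEmbedding adds a sinusoidal-timestep MLP embedding to an attention-pooled caption embedding and separately returns a per-token caption projection. A rough standalone sketch of that combination, using a masked mean pool as a stand-in for MochiAttentionPool and made-up dimensions (none of the names below are diffusers classes):

import torch
import torch.nn as nn


class CombinedTimestepCaptionSketch(nn.Module):
    """Stand-in for the pattern above: conditioning = mlp(sinusoidal(t)) + pool(caption),
    returned together with a per-token caption projection."""

    def __init__(self, embedding_dim=128, caption_dim=64, time_embed_dim=32):
        super().__init__()
        self.time_embed_dim = time_embed_dim
        self.timestep_mlp = nn.Sequential(
            nn.Linear(time_embed_dim, embedding_dim), nn.SiLU(), nn.Linear(embedding_dim, embedding_dim)
        )
        self.caption_pool = nn.Linear(caption_dim, embedding_dim)  # applied to a masked mean pool
        self.caption_proj = nn.Linear(caption_dim, embedding_dim)

    def sinusoidal(self, timestep):
        half = self.time_embed_dim // 2
        freqs = torch.exp(-torch.arange(half, dtype=torch.float32) * torch.log(torch.tensor(10000.0)) / half)
        args = timestep[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)

    def forward(self, timestep, caption_tokens, caption_mask):
        temb = self.timestep_mlp(self.sinusoidal(timestep))
        masked = caption_tokens * caption_mask.unsqueeze(-1)
        pooled = self.caption_pool(masked.sum(dim=1) / caption_mask.sum(dim=1, keepdim=True))
        return temb + pooled, self.caption_proj(caption_tokens)


emb = CombinedTimestepCaptionSketch()
conditioning, caption_proj = emb(torch.tensor([10, 500]), torch.randn(2, 7, 64), torch.ones(2, 7))
print(conditioning.shape, caption_proj.shape)  # (2, 128) (2, 7, 128)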

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 103 additions & 44 deletions
@@ -21,11 +21,11 @@
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import Attention, FeedForward
-from ..embeddings import PatchEmbed, MochiAttentionPool, TimestepEmbedding, Timesteps
+from ..attention import Attention, FeedForward, JointAttnProcessor2_0
+from ..embeddings import PatchEmbed, MochiCombinedTimestepCaptionEmbedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import MochiRMSNormZero, RMSNorm
+from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -38,61 +38,73 @@ def __init__(
         dim: int,
         num_attention_heads: int,
         attention_head_dim: int,
-        caption_dim: int,
+        pooled_projection_dim: int,
+        qk_norm: str = "rms_norm",
         activation_fn: str = "swiglu",
-        update_captions: bool = True,
+        context_pre_only: bool = True,
     ) -> None:
         super().__init__()

-        self.update_captions = update_captions
+        self.context_pre_only = context_pre_only

         self.norm1 = MochiRMSNormZero(dim, 4 * dim)

-        if update_captions:
-            self.norm_context1 = MochiRMSNormZero(dim, 4 * caption_dim)
+        if context_pre_only:
+            self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
-            self.norm_context1 = RMSNorm(caption_dim, eps=1e-5, elementwise_affine=False)
+            self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)

         self.attn = Attention(
             query_dim=dim,
             heads=num_attention_heads,
             attention_head_dim=attention_head_dim,
             out_dim=4 * dim,
-            qk_norm="rms_norm",
-            eps=1e-5,
-            elementwise_affine=False,
-        )
-        self.attn_context = Attention(
-            query_dim=dim,
-            heads=num_attention_heads,
-            attention_head_dim=attention_head_dim,
-            out_dim=4 * caption_dim if update_captions else caption_dim,
-            qk_norm="rms_norm",
-            eps=1e-5,
+            qk_norm=qk_norm,
+            eps=1e-6,
             elementwise_affine=False,
+            processor=JointAttnProcessor2_0(),
         )

+        self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
+        self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
+
+        self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
+        self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
+
         self.ff = FeedForward(dim, mult=4, activation_fn=activation_fn)
-        self.ff_context = FeedForward(caption_dim, mult=4, activation_fn=activation_fn)
+        self.ff_context = FeedForward(pooled_projection_dim, mult=4, activation_fn=activation_fn)
+
+        self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
+        self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)

     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)

-        if self.update_captions:
-            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm_context1(encoder_hidden_states, temb)
+        if self.context_pre_only:
+            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(encoder_hidden_states, temb)
         else:
-            norm_encoder_hidden_states = self.norm_context1(encoder_hidden_states)
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)

-        attn_hidden_states = self.attn(
+        attn_hidden_states, context_attn_hidden_states = self.attn(
             hidden_states=norm_hidden_states,
-            encoder_hidden_states=None,
+            encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
         )
-        attn_encoder_hidden_states = self.attn_context(
-            hidden_states=norm_encoder_hidden_states,
-            encoder_hidden_states=None,
-            image_rotary_emb=None,
-        )
+
+        hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
+        hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
+        if not self.context_pre_only:
+            encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states) * torch.tanh(enc_gate_msa).unsqueeze(1)
+            encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
+
+        ff_output = self.ff(hidden_states)
+        context_ff_output = self.ff_context(encoder_hidden_states)
+
+        hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)
+        if not self.context_pre_only:
+            encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(1)
+
+        return hidden_states, encoder_hidden_states


 @maybe_allow_in_graph
@@ -106,32 +118,35 @@ def __init__(
         num_attention_heads: int = 24,
         attention_head_dim: int = 128,
         num_layers: int = 48,
-        caption_dim=1536,
-        mlp_ratio_x=4.0,
-        mlp_ratio_y=4.0,
+        pooled_projection_dim: int = 1536,
         in_channels=12,
-        qk_norm=True,
-        qkv_bias=False,
-        out_bias=True,
+        out_channels: Optional[int] = None,
+        qk_norm: str = "rms_norm",
         timestep_mlp_bias=True,
         timestep_scale=1000.0,
-        text_embed_dim=4096,
+        text_embed_dim: int = 4096,
+        time_embed_dim: int = 256,
         activation_fn: str = "swiglu",
-        max_sequence_length=256,
+        max_sequence_length: int = 256,
     ) -> None:
         super().__init__()

         inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        self.time_embed = MochiCombinedTimestepCaptionEmbedding(
+            embedding_dim=text_embed_dim,
+            pooled_projection_dim=pooled_projection_dim,
+            time_embed_dim=time_embed_dim,
+            num_attention_heads=8,
+        )

         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
         )

-        self.caption_embedder = MochiAttentionPool(num_attention_heads=8, embed_dim=text_embed_dim, output_dim=inner_dim)
-        self.caption_proj = nn.Linear(text_embed_dim, caption_dim)
-
         self.pos_frequencies = nn.Parameter(
             torch.empty(3, num_attention_heads, attention_head_dim // 2)
         )
@@ -141,9 +156,53 @@ def __init__(
                 dim=inner_dim,
                 num_attention_heads=num_attention_heads,
                 attention_head_dim=attention_head_dim,
-                caption_dim=caption_dim,
+                pooled_projection_dim=pooled_projection_dim,
+                qk_norm=qk_norm,
                 activation_fn=activation_fn,
-                update_captions=i < num_layers - 1,
+                context_pre_only=i < num_layers - 1,
             )
             for i in range(num_layers)
         ])
+
+        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm")
+        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_attention_mask: torch.Tensor,
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        return_dict: bool = True,
+    ) -> torch.Tensor:
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p = self.config.patch_size
+
+        post_patch_height = height // p
+        post_patch_width = width // p
+
+        temb, caption_proj = self.time_embed(timestep, encoder_hidden_states, encoder_attention_mask)
+
+        hidden_states = self.patch_embed(hidden_states)
+
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, encoder_hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+            )
+
+        # TODO(aryan): do something with self.pos_frequencies
+
+        hidden_states = self.norm_out(hidden_states, temb)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
+        hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
+        output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
+
+        if not return_dict:
+            return (output,)
+        return Transformer2DModelOutput(sample=output)
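
Note: the un-patchify at the end of the new forward can be checked in isolation; a small sketch of the same reshape/permute round trip with made-up sizes (not taken from the commit):

import torch

batch_size, num_frames, p, out_channels = 2, 3, 2, 12
post_patch_height, post_patch_width = 4, 5
height, width = post_patch_height * p, post_patch_width * p

# Per-token output of proj_out: one vector of size p * p * out_channels per patch.
hidden_states = torch.randn(batch_size, num_frames * post_patch_height * post_patch_width, p * p * out_channels)

# Same sequence of ops as the forward above: split the token axis back into the
# (frame, patch-row, patch-column) grid, move channels up front, and merge the
# patch pixels back into the spatial axes.
hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
print(output.shape)  # torch.Size([2, 12, 3, 8, 10])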
