
Commit 05ebd6c

make style
1 parent da48940 commit 05ebd6c

4 files changed: +107 additions, -117 deletions

src/diffusers/models/embeddings.py

Lines changed: 17 additions & 11 deletions
@@ -1303,17 +1303,25 @@ def forward(self, timestep, caption_feat, caption_mask):
 
 
 class MochiCombinedTimestepCaptionEmbedding(nn.Module):
-    def __init__(self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8) -> None:
+    def __init__(
+        self, embedding_dim: int, pooled_projection_dim: int, time_embed_dim: int = 256, num_attention_heads: int = 8
+    ) -> None:
         super().__init__()
-
-        self.time_proj = Timesteps(
-            num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0
-        )
+
+        self.time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
         self.timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)
-        self.pooler = MochiAttentionPool(num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim)
+        self.pooler = MochiAttentionPool(
+            num_attention_heads=num_attention_heads, embed_dim=pooled_projection_dim, output_dim=embedding_dim
+        )
         self.caption_proj = nn.Linear(embedding_dim, pooled_projection_dim)
 
-    def forward(self, timestep: torch.LongTensor, encoder_hidden_states: torch.Tensor, encoder_attention_mask: torch.Tensor, hidden_dtype: Optional[torch.dtype] = None):
+    def forward(
+        self,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ):
         time_proj = self.time_proj(timestep)
         time_emb = self.timestep_embedder(time_proj.to(dtype=hidden_dtype))
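For orientation, the reflowed constructor wires a sinusoidal Timesteps projection into a TimestepEmbedding MLP, exactly as the one-liners above show. A minimal sketch of just that path, assuming a diffusers install that exposes both helpers from diffusers.models.embeddings; the dimensions and example timesteps below are illustrative, not Mochi's real configuration:

    import torch
    from diffusers.models.embeddings import TimestepEmbedding, Timesteps

    # Mirror the constructor calls from the diff above (sizes are illustrative).
    time_embed_dim = 256
    embedding_dim = 1024

    time_proj = Timesteps(num_channels=time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0.0)
    timestep_embedder = TimestepEmbedding(in_channels=time_embed_dim, time_embed_dim=embedding_dim)

    timestep = torch.tensor([10, 500])                             # (B,) integer diffusion timesteps
    time_proj_out = time_proj(timestep)                            # (B, time_embed_dim) sinusoidal features
    time_emb = timestep_embedder(time_proj_out.to(torch.float32))  # (B, embedding_dim)
    print(time_emb.shape)  # torch.Size([2, 1024])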

@@ -1467,7 +1475,7 @@ def __init__(
         self.to_kv = nn.Linear(embed_dim, 2 * embed_dim)
         self.to_q = nn.Linear(embed_dim, embed_dim)
         self.to_out = nn.Linear(embed_dim, self.output_dim)
-
+
     @staticmethod
     def pool_tokens(x: torch.Tensor, mask: torch.Tensor, *, keepdim=False) -> torch.Tensor:
         """
@@ -1526,9 +1534,7 @@ def forward(self, x: torch.Tensor, mask: torch.BoolTensor) -> torch.Tensor:
         q = q.unsqueeze(2)  # (B, H, 1, head_dim)
 
         # Compute attention.
-        x = F.scaled_dot_product_attention(
-            q, k, v, attn_mask=attn_mask, dropout_p=0.0
-        )  # (B, H, 1, head_dim)
+        x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)  # (B, H, 1, head_dim)
 
         # Concatenate heads and run output.
         x = x.squeeze(2).flatten(1, 2)  # (B, D = H * head_dim)
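The collapsed call above is the pooled-query attention pattern used by MochiAttentionPool: a single query token per head attends over all key/value tokens, so the F.scaled_dot_product_attention output has a singleton query dimension that is then squeezed and flattened. A self-contained sketch of that pattern in plain PyTorch; the shapes and the boolean mask are made up for illustration:

    import torch
    import torch.nn.functional as F

    B, H, L, head_dim = 2, 8, 77, 64                 # illustrative batch, heads, tokens, head size

    q = torch.randn(B, H, head_dim).unsqueeze(2)     # (B, H, 1, head_dim): one pooled query per head
    k = torch.randn(B, H, L, head_dim)
    v = torch.randn(B, H, L, head_dim)
    mask = torch.ones(B, 1, 1, L, dtype=torch.bool)  # broadcastable attn_mask (True = attend)

    # Query length is 1, so the attention output is (B, H, 1, head_dim).
    x = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

    # Concatenate heads: (B, D = H * head_dim), matching the squeeze/flatten in the diff.
    x = x.squeeze(2).flatten(1, 2)
    print(x.shape)  # torch.Size([2, 512])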

src/diffusers/models/normalization.py

Lines changed: 6 additions & 2 deletions
@@ -245,14 +245,18 @@ class MochiRMSNormZero(nn.Module):
         embedding_dim (`int`): The size of each embedding vector.
     """
 
-    def __init__(self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False) -> None:
+    def __init__(
+        self, embedding_dim: int, hidden_dim: int, norm_eps: float = 1e-5, elementwise_affine: bool = False
+    ) -> None:
         super().__init__()
 
         self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, hidden_dim)
         self.norm = RMSNorm(embedding_dim, eps=norm_eps, elementwise_affine=elementwise_affine)
 
-    def forward(self, hidden_states: torch.Tensor, emb: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def forward(
+        self, hidden_states: torch.Tensor, emb: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         emb = self.linear(self.silu(emb))
         scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
         hidden_states = self.norm(hidden_states) * (1 + scale_msa[:, None])
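For reference, the forward shown in context above follows the usual adaLN-style modulation: project the conditioning embedding, chunk it into scale/gate pairs, and rescale the RMS-normalized stream. A minimal sketch with a simplified RMS norm and illustrative sizes; the real module uses diffusers' RMSNorm and returns all four chunks:

    import torch
    import torch.nn as nn

    embedding_dim, hidden_dim = 512, 4 * 512   # hidden_dim = 4 * embedding_dim, as in the Mochi blocks

    silu = nn.SiLU()
    linear = nn.Linear(embedding_dim, hidden_dim)

    def rms_norm(x, eps=1e-5):
        # Simplified stand-in for diffusers' RMSNorm with elementwise_affine=False.
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    hidden_states = torch.randn(2, 77, embedding_dim)   # (B, seq, dim)
    emb = torch.randn(2, embedding_dim)                 # conditioning embedding (B, dim)

    emb = linear(silu(emb))                             # (B, 4 * dim)
    scale_msa, gate_msa, scale_mlp, gate_mlp = emb.chunk(4, dim=1)
    hidden_states = rms_norm(hidden_states) * (1 + scale_msa[:, None])
    print(hidden_states.shape, gate_msa.shape)          # (2, 77, 512) (2, 512)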

src/diffusers/models/transformers/transformer_mochi.py

Lines changed: 48 additions & 34 deletions
@@ -13,7 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -22,7 +22,7 @@
 from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward, JointAttnProcessor2_0
-from ..embeddings import PatchEmbed, MochiCombinedTimestepCaptionEmbedding
+from ..embeddings import MochiCombinedTimestepCaptionEmbedding, PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormContinuous, MochiRMSNormZero, RMSNorm
@@ -46,14 +46,14 @@ def __init__(
         super().__init__()
 
         self.context_pre_only = context_pre_only
-
+
         self.norm1 = MochiRMSNormZero(dim, 4 * dim)
 
         if context_pre_only:
             self.norm1_context = MochiRMSNormZero(dim, 4 * pooled_projection_dim)
         else:
             self.norm1_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
-
+
         self.attn = Attention(
             query_dim=dim,
             heads=num_attention_heads,
@@ -67,7 +67,7 @@ def __init__(
 
         self.norm2 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm2_context = RMSNorm(pooled_projection_dim, eps=1e-6, elementwise_affine=False)
-
+
         self.norm3 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm3_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
 
@@ -76,15 +76,23 @@ def __init__(
 
         self.norm4 = RMSNorm(dim, eps=1e-6, elementwise_affine=False)
         self.norm4_context = RMSNorm(pooled_projection_dim, eps=1e-56, elementwise_affine=False)
-
-    def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor, image_rotary_emb: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         norm_hidden_states, gate_msa, scale_mlp, gate_mlp = self.norm1(hidden_states, temb)
 
         if self.context_pre_only:
-            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(encoder_hidden_states, temb)
+            norm_encoder_hidden_states, enc_gate_msa, enc_scale_mlp, enc_gate_mlp = self.norm1_context(
+                encoder_hidden_states, temb
+            )
         else:
             norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states)
-
+
         attn_hidden_states, context_attn_hidden_states = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
@@ -94,16 +102,20 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
         hidden_states = hidden_states + self.norm2(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
         hidden_states = self.norm3(hidden_states) * (1 + scale_mlp.unsqueeze(1))
         if not self.context_pre_only:
-            encoder_hidden_states = encoder_hidden_states + self.norm2_context(context_attn_hidden_states) * torch.tanh(enc_gate_msa).unsqueeze(1)
-            encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (1 + enc_scale_mlp.unsqueeze(1))
-
+            encoder_hidden_states = encoder_hidden_states + self.norm2_context(
+                context_attn_hidden_states
+            ) * torch.tanh(enc_gate_msa).unsqueeze(1)
+            encoder_hidden_states = encoder_hidden_states + self.norm3_context(encoder_hidden_states) * (
+                1 + enc_scale_mlp.unsqueeze(1)
+            )
+
         ff_output = self.ff(hidden_states)
         context_ff_output = self.ff_context(encoder_hidden_states)
-
+
         hidden_states = hidden_states + ff_output * torch.tanh(gate_mlp).unsqueeze(1)
         if not self.context_pre_only:
             encoder_hidden_states = encoder_hidden_states + context_ff_output * torch.tanh(enc_gate_mlp).unsqueeze(0)
-
+
         return hidden_states, encoder_hidden_states
 
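The rewrapped updates above all share one tanh-gated residual pattern: normalize the branch output, scale it by tanh(gate), and add it back to the stream. A small standalone sketch of that update with dummy tensors; the names, shapes, and simplified RMS norm are illustrative only:

    import torch

    B, seq, dim = 2, 77, 512

    hidden_states = torch.randn(B, seq, dim)
    attn_hidden_states = torch.randn(B, seq, dim)   # stand-in for the attention branch output
    gate_msa = torch.randn(B, dim)                  # per-sample gate from the RMSNormZero projection
    scale_mlp = torch.randn(B, dim)

    def rms_norm(x, eps=1e-6):
        # Simplified stand-in for RMSNorm(dim, elementwise_affine=False).
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

    # Gated residual for the attention branch, then a scale before the feed-forward,
    # mirroring the hidden_states updates in the diff.
    hidden_states = hidden_states + rms_norm(attn_hidden_states) * torch.tanh(gate_msa).unsqueeze(1)
    hidden_states = rms_norm(hidden_states) * (1 + scale_mlp.unsqueeze(1))
    print(hidden_states.shape)  # torch.Size([2, 77, 512])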

@@ -140,33 +152,35 @@ def __init__(
             time_embed_dim=time_embed_dim,
             num_attention_heads=8,
         )
-
+
         self.patch_embed = PatchEmbed(
             patch_size=patch_size,
             in_channels=in_channels,
             embed_dim=inner_dim,
         )
 
-        self.pos_frequencies = nn.Parameter(
-            torch.empty(3, num_attention_heads, attention_head_dim // 2)
+        self.pos_frequencies = nn.Parameter(torch.empty(3, num_attention_heads, attention_head_dim // 2))
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                MochiTransformerBlock(
+                    dim=inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    pooled_projection_dim=pooled_projection_dim,
+                    qk_norm=qk_norm,
+                    activation_fn=activation_fn,
+                    context_pre_only=i < num_layers - 1,
+                )
+                for i in range(num_layers)
+            ]
         )
 
-        self.transformer_blocks = nn.ModuleList([
-            MochiTransformerBlock(
-                dim=inner_dim,
-                num_attention_heads=num_attention_heads,
-                attention_head_dim=attention_head_dim,
-                pooled_projection_dim=pooled_projection_dim,
-                qk_norm=qk_norm,
-                activation_fn=activation_fn,
-                context_pre_only=i < num_layers - 1,
-            )
-            for i in range(num_layers)
-        ])
-
-        self.norm_out = AdaLayerNormContinuous(inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm")
+        self.norm_out = AdaLayerNormContinuous(
+            inner_dim, inner_dim, elementwise_affine=False, eps=1e-6, norm_type="layer_norm"
+        )
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
-
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -193,13 +207,13 @@ def forward(
                 temb=temb,
                 image_rotary_emb=image_rotary_emb,
             )
-
+
         # TODO(aryan): do something with self.pos_frequencies
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
-        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_height, p, p, -1)
+        hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
         hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
         output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
 
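The corrected reshape is the unpatchify step: tokens of size p * p * out_channels are folded back into a (batch, channels, frames, height, width) video, which only works if the second spatial axis is post_patch_width, as the diff now does. A shape-only sketch with toy sizes, assuming the same reshape/permute sequence as the diff:

    import torch

    batch_size, num_frames, out_channels, p = 1, 3, 4, 2
    height, width = 8, 12                      # illustrative latent height/width
    post_patch_height, post_patch_width = height // p, width // p

    # (B, T * H/p * W/p, p * p * C): per-frame patch tokens after proj_out.
    hidden_states = torch.randn(
        batch_size, num_frames * post_patch_height * post_patch_width, p * p * out_channels
    )

    hidden_states = hidden_states.reshape(batch_size, num_frames, post_patch_height, post_patch_width, p, p, -1)
    hidden_states = hidden_states.permute(0, 6, 1, 2, 4, 3, 5)
    output = hidden_states.reshape(batch_size, -1, num_frames, height, width)
    print(output.shape)  # torch.Size([1, 4, 3, 8, 12])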
