@@ -160,7 +160,12 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
        self.norm2 = FP32LayerNorm(dim, elementwise_affine=False, bias=False)
        self.ff = AuraFlowFeedForward(dim, dim * 4)

-    def forward(self, hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ):
        residual = hidden_states
        attention_kwargs = attention_kwargs or {}
@@ -224,7 +229,11 @@ def __init__(self, dim, num_attention_heads, attention_head_dim):
        self.ff_context = AuraFlowFeedForward(dim, dim * 4)

    def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor, attention_kwargs: Optional[Dict[str, Any]] = None,
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
    ):
        residual = hidden_states
        residual_context = encoder_hidden_states
@@ -238,7 +247,9 @@ def forward(

        # Attention.
        attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states, **attention_kwargs,
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **attention_kwargs,
        )

        # Process attention outputs for the `hidden_states`.
@@ -492,7 +503,10 @@ def forward(
            else:
                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, attention_kwargs=attention_kwargs,
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    attention_kwargs=attention_kwargs,
                )

        # Single DiT blocks that combine the `hidden_states` (image) and `encoder_hidden_states` (text)
@@ -509,7 +523,9 @@ def forward(
                )
            else:
-                combined_hidden_states = block(hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs)
+                combined_hidden_states = block(
+                    hidden_states=combined_hidden_states, temb=temb, attention_kwargs=attention_kwargs
+                )

        hidden_states = combined_hidden_states[:, encoder_seq_len:]
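For reference, a minimal sketch of how the reformatted `forward()` of the single-stream block would be called with the new `attention_kwargs` argument. The import path and class name (`AuraFlowSingleTransformerBlock`) are assumed from diffusers' AuraFlow transformer module and are not part of this diff; the constructor arguments and the `forward()` signature come from the hunks above, and the tensor sizes are arbitrary toy values.

```python
# Hypothetical usage sketch, not part of this PR.
import torch

# Assumed import path/class name for the single-stream AuraFlow block in diffusers.
from diffusers.models.transformers.auraflow_transformer_2d import AuraFlowSingleTransformerBlock

dim, heads, head_dim = 64, 4, 16  # toy sizes; dim == heads * head_dim
block = AuraFlowSingleTransformerBlock(dim=dim, num_attention_heads=heads, attention_head_dim=head_dim)

hidden_states = torch.randn(2, 8, dim)  # (batch, sequence, dim)
temb = torch.randn(2, dim)              # conditioning/timestep embedding
out = block(
    hidden_states=hidden_states,
    temb=temb,
    attention_kwargs={},  # optional extras forwarded to self.attn(...), per this diff
)
print(out.shape)
```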