@@ -243,7 +243,9 @@ def __init__(
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.ff_a = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.GELU(approximate="tanh"), nn.Linear(dim * 4, dim, device=device, dtype=dtype)
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype)
         )
         # Text
         self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -313,10 +315,10 @@ def __init__(
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
-            nn.Linear(dim, dim * 4),
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
-        self.proj_out = nn.Linear(dim * 5, dim)
+        self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

     def forward(self, x, t_emb, rope_emb, image_emb=None):
         h, gate = self.norm(x, emb=t_emb)
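
Aside from the patch itself, a minimal sketch of what the added kwargs fix (the dim and dtype values below are illustrative, not taken from this model): an nn.Linear created without explicit device/dtype falls back to the framework defaults (CPU, float32), so a block otherwise built in, say, bfloat16 ends up with mismatched parameters until an extra .to(...) call.

# Sketch only, not part of the patch. Shows how omitting device/dtype on one
# submodule of an nn.Sequential leaves it at the default float32 while the
# rest of the block is bfloat16.
import torch
from torch import nn

dim, dtype = 64, torch.bfloat16

mixed = nn.Sequential(
    nn.Linear(dim, dim * 4),               # no dtype kwarg -> float32
    nn.GELU(approximate="tanh"),
    nn.Linear(dim * 4, dim, dtype=dtype),  # bfloat16
)
consistent = nn.Sequential(
    nn.Linear(dim, dim * 4, dtype=dtype),
    nn.GELU(approximate="tanh"),
    nn.Linear(dim * 4, dim, dtype=dtype),
)

print({name: p.dtype for name, p in mixed.named_parameters()})       # mixed float32 / bfloat16
print({name: p.dtype for name, p in consistent.named_parameters()})  # all bfloat16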