
Commit 63188ac

init all DiT modules with device and dtype for speed up (#164)
1 parent ca8a9a5 commit 63188ac

File tree: 1 file changed, +5 -3 lines


diffsynth_engine/models/flux/flux_dit.py

Lines changed: 5 additions & 3 deletions
@@ -243,7 +243,9 @@ def __init__(
         self.norm_msa_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.norm_mlp_a = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.ff_a = nn.Sequential(
-            nn.Linear(dim, dim * 4), nn.GELU(approximate="tanh"), nn.Linear(dim * 4, dim, device=device, dtype=dtype)
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
+            nn.GELU(approximate="tanh"),
+            nn.Linear(dim * 4, dim, device=device, dtype=dtype)
         )
         # Text
         self.norm_msa_b = AdaLayerNormZero(dim, device=device, dtype=dtype)
@@ -313,10 +315,10 @@ def __init__(
         self.norm = AdaLayerNormZero(dim, device=device, dtype=dtype)
         self.attn = FluxSingleAttention(dim, num_heads, attn_kwargs=attn_kwargs, device=device, dtype=dtype)
         self.mlp = nn.Sequential(
-            nn.Linear(dim, dim * 4),
+            nn.Linear(dim, dim * 4, device=device, dtype=dtype),
             nn.GELU(approximate="tanh"),
         )
-        self.proj_out = nn.Linear(dim * 5, dim)
+        self.proj_out = nn.Linear(dim * 5, dim, device=device, dtype=dtype)

     def forward(self, x, t_emb, rope_emb, image_emb=None):
         h, gate = self.norm(x, emb=t_emb)
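Context for the change: when nn.Linear is constructed without device or dtype, its parameters are allocated and initialized on the CPU in the default dtype (typically float32) and must later be copied and cast by a module-wide .to(...) call; passing device and dtype at construction allocates the parameters directly where they are needed. Below is a minimal sketch of the two construction patterns, not part of this commit; the dim value and the bfloat16 dtype are illustrative assumptions, while flux_dit.py takes its sizes from the model config.

import torch
import torch.nn as nn

dim = 3072  # illustrative only
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16

# Slower pattern: weights are allocated and initialized on the CPU in the
# default dtype, then copied to the target device and cast afterwards.
ff_slow = nn.Sequential(
    nn.Linear(dim, dim * 4),
    nn.GELU(approximate="tanh"),
    nn.Linear(dim * 4, dim),
).to(device=device, dtype=dtype)

# Faster pattern (what this commit applies to the remaining layers):
# parameters are created directly on the target device in the target dtype,
# skipping the intermediate CPU allocation and transfer.
ff_fast = nn.Sequential(
    nn.Linear(dim, dim * 4, device=device, dtype=dtype),
    nn.GELU(approximate="tanh"),
    nn.Linear(dim * 4, dim, device=device, dtype=dtype),
)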
