
Commit 2b7deff

Authored by vladmandic, DN6, and github-actions[bot]

fix scale_shift_factor being on cpu for wan and ltx (#12347)

* wan fix scale_shift_factor being on cpu
* apply device cast to ltx transformer
* Apply style fixes

Co-authored-by: Dhruv Nair <dhruv.nair@gmail.com>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>

1 parent 941ac9c · commit 2b7deff
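In short: the learned modulation tables (scale_shift_table) can be left on the CPU, for example by offloading via accelerate, while temb is computed on the GPU, and adding the two then fails. A minimal sketch of that failure mode and of the cast this commit applies, using toy tensors rather than the actual diffusers modules:

import torch

# Toy tensors only; a hypothetical reproduction, not the actual diffusers modules.
if torch.cuda.is_available():
    scale_shift_table = torch.randn(6, 8)            # parameter left on the CPU by offloading
    temb = torch.randn(2, 6, 8, device="cuda")       # activation computed on the GPU
    try:
        ada = scale_shift_table + temb               # RuntimeError: tensors on different devices
    except RuntimeError as err:
        print(err)
    # The fix applied in this commit: cast the table to temb's device before the add.
    ada = scale_shift_table.to(temb.device) + temb
    print(ada.device)                                # cuda:0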

File tree

3 files changed: +7 -5 lines changed

src/diffusers/models/transformers/transformer_ltx.py

Lines changed: 3 additions & 1 deletion
@@ -353,7 +353,9 @@ def forward(
         norm_hidden_states = self.norm1(hidden_states)
 
         num_ada_params = self.scale_shift_table.shape[0]
-        ada_values = self.scale_shift_table[None, None] + temb.reshape(batch_size, temb.size(1), num_ada_params, -1)
+        ada_values = self.scale_shift_table[None, None].to(temb.device) + temb.reshape(
+            batch_size, temb.size(1), num_ada_params, -1
+        )
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
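As a shape check for the rewritten lines, a standalone sketch with assumed toy sizes: a (num_ada_params, dim) table and a per-token temb of shape (batch, seq_len, num_ada_params * dim). The names mirror the block above, but the tensors are made up:

import torch

# Assumed toy sizes; mirrors the patched LTX block but with made-up tensors.
batch_size, seq_len, num_ada_params, dim = 2, 4, 6, 8
scale_shift_table = torch.randn(num_ada_params, dim)            # parameter, possibly still on CPU
temb = torch.randn(batch_size, seq_len, num_ada_params * dim)   # conditioning on the runtime device

ada_values = scale_shift_table[None, None].to(temb.device) + temb.reshape(
    batch_size, temb.size(1), num_ada_params, -1
)
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ada_values.unbind(dim=2)
print(shift_msa.shape)   # torch.Size([2, 4, 8])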

src/diffusers/models/transformers/transformer_wan.py

Lines changed: 2 additions & 2 deletions
@@ -682,12 +682,12 @@ def forward(
         # 5. Output norm, projection & unpatchify
         if temb.ndim == 3:
             # batch_size, seq_len, inner_dim (wan 2.2 ti2v)
-            shift, scale = (self.scale_shift_table.unsqueeze(0) + temb.unsqueeze(2)).chunk(2, dim=2)
+            shift, scale = (self.scale_shift_table.unsqueeze(0).to(temb.device) + temb.unsqueeze(2)).chunk(2, dim=2)
             shift = shift.squeeze(2)
             scale = scale.squeeze(2)
         else:
             # batch_size, inner_dim
-            shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+            shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
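The two branches handle the two temb layouts named in the comments. A toy sketch of both, assuming the table broadcasts as (1, 2, inner_dim); not the real WanTransformer3DModel:

import torch

# Assumed toy sizes; hypothetical stand-ins for the module's tensors.
batch_size, seq_len, inner_dim = 2, 4, 8
scale_shift_table = torch.randn(1, 2, inner_dim)   # parameter, possibly still on CPU

# wan 2.2 ti2v: one embedding per token, shape (batch_size, seq_len, inner_dim)
temb3 = torch.randn(batch_size, seq_len, inner_dim)
shift, scale = (scale_shift_table.unsqueeze(0).to(temb3.device) + temb3.unsqueeze(2)).chunk(2, dim=2)
print(shift.squeeze(2).shape)   # torch.Size([2, 4, 8])

# default: one embedding per sample, shape (batch_size, inner_dim)
temb2 = torch.randn(batch_size, inner_dim)
shift, scale = (scale_shift_table.to(temb2.device) + temb2.unsqueeze(1)).chunk(2, dim=1)
print(shift.shape)              # torch.Size([2, 1, 8])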

src/diffusers/models/transformers/transformer_wan_vace.py

Lines changed: 2 additions & 2 deletions
@@ -103,7 +103,7 @@ def forward(
         control_hidden_states = control_hidden_states + hidden_states
 
         shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
-            self.scale_shift_table + temb.float()
+            self.scale_shift_table.to(temb.device) + temb.float()
         ).chunk(6, dim=1)
 
         # 1. Self-attention
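In this hunk the single table feeds both the main branch and the control branch (the c_* values), so the sum is split six ways. A toy sketch, assuming temb arrives as (batch_size, 6, dim); not the real VACE transformer block:

import torch

# Assumed toy sizes; hypothetical stand-ins for the module's tensors.
batch_size, dim = 2, 8
scale_shift_table = torch.randn(1, 6, dim)   # parameter, possibly still on CPU
temb = torch.randn(batch_size, 6, dim)       # per-branch embedding on the runtime device

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
    scale_shift_table.to(temb.device) + temb.float()
).chunk(6, dim=1)
print(gate_msa.shape, c_gate_msa.shape)      # torch.Size([2, 1, 8]) for each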
@@ -361,7 +361,7 @@ def forward(
         hidden_states = hidden_states + control_hint * scale
 
         # 6. Output norm, projection & unpatchify
-        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+        shift, scale = (self.scale_shift_table.to(temb.device) + temb.unsqueeze(1)).chunk(2, dim=1)
 
         # Move the shift and scale tensors to the same device as hidden_states.
         # When using multi-GPU inference via accelerate these will be on the
