Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions src/diffusers/models/transformers/transformer_skyreels_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def __init__(
t_dim = attention_head_dim - h_dim - w_dim
freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim

freqs_cos = []
freqs_sin = []

Expand All @@ -412,11 +416,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]

freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
Expand Down
11 changes: 6 additions & 5 deletions src/diffusers/models/transformers/transformer_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,11 @@ def __init__(

h_dim = w_dim = 2 * (attention_head_dim // 6)
t_dim = attention_head_dim - h_dim - w_dim

self.t_dim = t_dim
self.h_dim = h_dim
self.w_dim = w_dim

freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64

freqs_cos = []
Expand All @@ -387,11 +392,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
p_t, p_h, p_w = self.patch_size
ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w

split_sizes = [
self.attention_head_dim - 2 * (self.attention_head_dim // 3),
self.attention_head_dim // 3,
self.attention_head_dim // 3,
]
split_sizes = [self.t_dim, self.h_dim, self.w_dim]

freqs_cos = self.freqs_cos.split(split_sizes, dim=1)
freqs_sin = self.freqs_sin.split(split_sizes, dim=1)
Expand Down
Loading