diff --git a/src/diffusers/models/transformers/transformer_skyreels_v2.py b/src/diffusers/models/transformers/transformer_skyreels_v2.py index 6b600aa22487..4688ec5a9c71 100644 --- a/src/diffusers/models/transformers/transformer_skyreels_v2.py +++ b/src/diffusers/models/transformers/transformer_skyreels_v2.py @@ -389,6 +389,10 @@ def __init__( t_dim = attention_head_dim - h_dim - w_dim freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + freqs_cos = [] freqs_sin = [] @@ -412,11 +416,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [ - self.attention_head_dim - 2 * (self.attention_head_dim // 3), - self.attention_head_dim // 3, - self.attention_head_dim // 3, - ] + split_sizes = [self.t_dim, self.h_dim, self.w_dim] freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1) diff --git a/src/diffusers/models/transformers/transformer_wan.py b/src/diffusers/models/transformers/transformer_wan.py index dd75fb124f1a..742147d301a0 100644 --- a/src/diffusers/models/transformers/transformer_wan.py +++ b/src/diffusers/models/transformers/transformer_wan.py @@ -362,6 +362,11 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 freqs_cos = [] @@ -387,11 +392,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [ - self.attention_head_dim - 2 * (self.attention_head_dim // 3), - self.attention_head_dim // 3, - self.attention_head_dim // 3, - ] + split_sizes = [self.t_dim, self.h_dim, self.w_dim] freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1)