From fadbec5c4de00dc6c6a60924fef94b0e45163de1 Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Nov 2025 01:52:52 +0100 Subject: [PATCH 1/2] fix --- .../models/transformers/transformer_sana_video.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index aaf96175c0e8..4305992c5f1f 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -189,6 +189,11 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim + + self.t_dim = t_dim + self.h_dim = h_dim + self.w_dim = w_dim + freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 freqs_cos = [] @@ -214,11 +219,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [ - self.attention_head_dim - 2 * (self.attention_head_dim // 3), - self.attention_head_dim // 3, - self.attention_head_dim // 3, - ] + split_sizes = [self.t_dim, self.h_dim, self.w_dim] freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1) From 80765d6450a8ee17cebe9c58568771d42de613ba Mon Sep 17 00:00:00 2001 From: yiyixuxu Date: Wed, 12 Nov 2025 02:15:15 +0100 Subject: [PATCH 2/2] remove copies instead --- .../models/transformers/transformer_sana_video.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/src/diffusers/models/transformers/transformer_sana_video.py b/src/diffusers/models/transformers/transformer_sana_video.py index 4305992c5f1f..424d9ff9d360 100644 --- a/src/diffusers/models/transformers/transformer_sana_video.py +++ b/src/diffusers/models/transformers/transformer_sana_video.py @@ -172,7 +172,6 @@ def apply_rotary_emb( return
hidden_states -# Copied from diffusers.models.transformers.transformer_wan.WanRotaryPosEmbed class WanRotaryPosEmbed(nn.Module): def __init__( self, @@ -189,11 +188,6 @@ def __init__( h_dim = w_dim = 2 * (attention_head_dim // 6) t_dim = attention_head_dim - h_dim - w_dim - - self.t_dim = t_dim - self.h_dim = h_dim - self.w_dim = w_dim - freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64 freqs_cos = [] @@ -219,7 +213,11 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: p_t, p_h, p_w = self.patch_size ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w - split_sizes = [self.t_dim, self.h_dim, self.w_dim] + split_sizes = [ + self.attention_head_dim - 2 * (self.attention_head_dim // 3), + self.attention_head_dim // 3, + self.attention_head_dim // 3, + ] freqs_cos = self.freqs_cos.split(split_sizes, dim=1) freqs_sin = self.freqs_sin.split(split_sizes, dim=1)