diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index 4ce6408dbb3e..59cfc0415808 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -1172,10 +1172,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
index d41cfa4b090e..e3a9a7804fcd 100644
--- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -740,10 +740,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
index 5d1c88d03bc4..3f32fbc6ff53 100644
--- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
@@ -640,10 +640,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
index 264902c2d8a4..34a83f5f5c57 100644
--- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
+++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -731,10 +731,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
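
Note on the change (not part of the diff): if the positional embeddings returned by `fast_pos_embed_interpolate` come back in a higher-precision dtype than the patch embeddings, PyTorch's type promotion silently upcasts the sum, so the vision tower's activations drift away from the model's compute dtype. The patch records the incoming dtype and casts the sum back. A minimal, self-contained sketch of that promotion behavior, with hypothetical tensors standing in for `hidden_states` and `pos_embeds` (the specific dtypes are assumptions for illustration):

```python
import torch

# Hypothetical stand-ins for the patch embeddings and the interpolated
# position embeddings; bfloat16 vs. float32 is assumed for illustration.
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)
pos_embeds = torch.randn(4, 8, dtype=torch.float32)

# Mixed-dtype addition promotes to the wider dtype...
promoted = hidden_states + pos_embeds
assert promoted.dtype == torch.float32  # activations silently upcast

# ...so the patched forward pins the result back to the input dtype.
input_dtype = hidden_states.dtype
fixed = (hidden_states + pos_embeds).to(input_dtype)
assert fixed.dtype == torch.bfloat16
```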