@@ -1172,10 +1172,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
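For readers following along, here is a minimal standalone sketch of the dtype drift this change guards against (illustrative tensors only, not the model code): adding fp32 position embeddings to bf16 hidden states silently promotes the sum to fp32, and the new cast restores the original dtype.

```python
import torch

# Illustrative stand-ins, not the actual Qwen3-VL tensors.
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)   # e.g. patch_embed output
pos_embeds = torch.randn(4, 8, dtype=torch.float32)       # e.g. interpolated position embeddings

promoted = hidden_states + pos_embeds
print(promoted.dtype)  # torch.float32, because PyTorch type promotion upcasts the sum

# The patch records the input dtype up front and casts the sum back to it.
input_dtype = hidden_states.dtype
restored = (hidden_states + pos_embeds).to(input_dtype)
print(restored.dtype)  # torch.bfloat16
```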
src/transformers/models/qwen3_vl/modeling_qwen3_vl.py (2 additions, 1 deletion)
@@ -740,10 +740,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
src/transformers/models/qwen3_vl/modular_qwen3_vl.py (2 additions, 1 deletion)
@@ -640,10 +640,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
Member: I think we should cast only pos_embeds to the input dtype here.

Contributor: Agreed. Further, couldn't we fix this in fast_pos_embed_interpolate instead of recasting, to avoid too many conversions? We could, for instance, pass the input_dtype: in `h_idxs = torch.linspace(0, self.num_grid_per_side - 1, h)`, passing the wanted dtype should be enough, no?
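A minimal sketch of the suggestion above, with assumed values for `num_grid_per_side` and the target size `h` (not taken from the actual config): `torch.linspace` accepts a `dtype` argument, so the interpolation indices could be built directly in the input dtype instead of the fp32 default.

```python
import torch

num_grid_per_side = 16   # assumed value, for illustration only
h = 24                   # assumed target grid height

h_idxs_default = torch.linspace(0, num_grid_per_side - 1, h)
h_idxs_bf16 = torch.linspace(0, num_grid_per_side - 1, h, dtype=torch.bfloat16)

print(h_idxs_default.dtype)  # torch.float32 (the default)
print(h_idxs_bf16.dtype)     # torch.bfloat16
```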

Member: Thought the same thing, but I am not sure whether the positional embedding was intentionally done in full precision for better performance 🤔

Author: Yeah, this change intentionally keeps it running in FP32 (the same dtype as the master weights) for now, without changing the numerical dynamics.

Casting only pos_embeds, or fixing this inside fast_pos_embed_interpolate, would have numerical implications, which would require ablations with training results if we want to be careful.

Happy to discuss; I'm leaning towards not changing model behavior for now.
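To make that trade-off concrete, a small self-contained sketch with random stand-in tensors (not model activations): the merged change adds in fp32 and casts the sum back, while casting pos_embeds first would round before the addition, so the low bits of the result can differ.

```python
import torch

torch.manual_seed(0)
hidden_states = torch.randn(4, 8, dtype=torch.bfloat16)  # stand-in for patch embeddings
pos_embeds = torch.randn(4, 8, dtype=torch.float32)      # stand-in for fp32 position embeddings

# Current change: add in fp32, then cast the sum back to the input dtype.
summed_then_cast = (hidden_states + pos_embeds).to(torch.bfloat16)

# Alternative: cast pos_embeds first, then add entirely in bf16.
cast_then_summed = hidden_states + pos_embeds.to(torch.bfloat16)

# Rounding happens at a different point, so results may differ in the last bits.
print((summed_then_cast.float() - cast_then_summed.float()).abs().max())
```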

Author: Friendly reminder. Let me know what you think.

CC @molbap @zucchini-nlp


         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 
@@ -731,10 +731,11 @@ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs)
         Returns:
             `torch.Tensor`: hidden_states.
         """
+        input_dtype = hidden_states.dtype
         hidden_states = self.patch_embed(hidden_states)
 
         pos_embeds = self.fast_pos_embed_interpolate(grid_thw)
-        hidden_states = hidden_states + pos_embeds
+        hidden_states = (hidden_states + pos_embeds).to(input_dtype)
 
         rotary_pos_emb = self.rot_pos_emb(grid_thw)
 