Closed
Description
System Info
When using Qwen2.5-VL for generation, the method `prepare_inputs_for_generation` calls `get_rope_index` without passing `second_per_grid_ts`:
```python
def prepare_inputs_for_generation(
    self,
    input_ids,
    past_key_values=None,
    attention_mask=None,
    inputs_embeds=None,
    cache_position=None,
    position_ids=None,
    use_cache=True,
    pixel_values=None,
    pixel_values_videos=None,
    image_grid_thw=None,
    video_grid_thw=None,
    second_per_grid_ts=None,
    **kwargs,
):
    # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
    model_inputs = super().prepare_inputs_for_generation(
        input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        inputs_embeds=inputs_embeds,
        cache_position=cache_position,
        position_ids=position_ids,
        pixel_values=pixel_values,
        pixel_values_videos=pixel_values_videos,
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,
        use_cache=use_cache,
        **kwargs,
    )

    # Qwen2-5-VL position_ids are prepared with rope_deltas
    if position_ids is None:
        # Calculate RoPE index once per generation in the pre-fill stage only.
        # When compiling, we can't check tensor values thus we check only input length
        # It is safe to assume that `length!=1` means we're in pre-fill because compiled
        # models currently cannot do assisted decoding
        if cache_position[0] == 0 or self.model.rope_deltas is None:
            vision_positions, rope_deltas = self.model.get_rope_index(
                model_inputs.get("input_ids", None),
                image_grid_thw=image_grid_thw,
                video_grid_thw=video_grid_thw,
                attention_mask=attention_mask,
                # ⚠ second_per_grid_ts is missing here
            )
            self.model.rope_deltas = rope_deltas
        # then use the prev pre-calculated rope-deltas to get the correct position ids
        elif "position_ids" in model_inputs:
            ...
    return model_inputs
```

This is inconsistent with the training setup, where `second_per_grid_ts` is provided in `Qwen2_5_VLModel.forward`:
```python
@auto_docstring
def forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    pixel_values: Optional[torch.Tensor] = None,
    pixel_values_videos: Optional[torch.FloatTensor] = None,
    image_grid_thw: Optional[torch.LongTensor] = None,
    video_grid_thw: Optional[torch.LongTensor] = None,
    rope_deltas: Optional[torch.LongTensor] = None,
    cache_position: Optional[torch.LongTensor] = None,
    second_per_grid_ts: Optional[torch.Tensor] = None,
    **kwargs: Unpack[TransformersKwargs],
) -> Union[tuple, Qwen2_5_VLModelOutputWithPast]:
    r"""
    image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
        The temporal, height and width of feature shape of each image in LLM.
    video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
        The temporal, height and width of feature shape of each video in LLM.
    rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
        The rope index difference between sequence length and multimodal rope.
    second_per_grid_ts (`torch.Tensor` of shape `(num_videos)`, *optional*):
        The time interval (in seconds) for each grid along the temporal dimension in the 3D position IDs.
    """
    ...
    if position_ids is None:
        # Calculate RoPE index once per generation in the pre-fill stage only.
        # When compiling, we can't check tensor values thus we check only input length
        # It is safe to assume that `length!=1` means we're in pre-fill because compiled
        # models currently cannot do assisted decoding
        prefill_compiled_stage = is_torchdynamo_compiling() and (
            (input_ids is not None and input_ids.shape[1] != 1)
            or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
        )
        prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
            (cache_position is not None and cache_position[0] == 0)
            or (past_key_values is None or past_key_values.get_seq_length() == 0)
        )
        if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
            position_ids, rope_deltas = self.get_rope_index(
                input_ids,
                image_grid_thw,
                video_grid_thw,
                second_per_grid_ts=second_per_grid_ts,  # Here passing the second_per_grid_ts
                attention_mask=attention_mask,
            )
            self.rope_deltas = rope_deltas
        else:
            batch_size, seq_length, _ = inputs_embeds.shape
            position_ids = torch.arange(seq_length, device=inputs_embeds.device)
            position_ids = position_ids.view(1, 1, -1).expand(3, batch_size, -1)
            if cache_position is not None:
                delta = (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
            else:
                delta = torch.zeros((batch_size, seq_length), device=inputs_embeds.device)
            delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=1)
            position_ids = position_ids + delta.to(position_ids.device)
    ...
    return output if return_dict else output.to_tuple()
```

As a result, during generation the model may compute temporal RoPE indices for video inputs that differ from those used in training, especially when variable frame rates or temporal scaling are involved.
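To make the impact concrete, here is a simplified, pure-Python sketch of how the temporal (t-axis) index of video tokens is assumed to be derived. It assumes that `get_rope_index` scales each temporal grid step `t` by `second_per_grid_t * tokens_per_second` and falls back to `second_per_grid_t = 1.0` when `second_per_grid_ts` is `None`, and that the processor sets `second_per_grid_t = temporal_patch_size / fps`. The helpers `seconds_per_grid` and `temporal_indices`, and the values `temporal_patch_size=2` and `tokens_per_second=2`, are hypothetical and only illustrate the arithmetic; this is not the library code.

```python
# Hypothetical, simplified model of the temporal RoPE index for video tokens.
# Names and default values are illustrative assumptions, not transformers code.

def seconds_per_grid(fps: float, temporal_patch_size: int = 2) -> float:
    """Seconds covered by one temporal grid step (assumed: temporal_patch_size / fps)."""
    return temporal_patch_size / fps


def temporal_indices(grid_t: int, tokens_per_second: float, second_per_grid_t=None) -> list:
    """Temporal position index for each temporal grid step of one video.

    When second_per_grid_t is None we fall back to 1.0 second per grid step,
    which models what happens if the value is never passed to get_rope_index.
    """
    if second_per_grid_t is None:
        second_per_grid_t = 1.0
    # int() truncates; for non-negative values this matches a floor/.long() cast.
    return [int(t * second_per_grid_t * tokens_per_second) for t in range(grid_t)]


grid_t = 4              # temporal grid steps in the video
tokens_per_second = 2   # hypothetical config value
sec = seconds_per_grid(fps=4.0)  # 0.5 s per temporal grid step

with_ts = temporal_indices(grid_t, tokens_per_second, sec)   # training-style path
without_ts = temporal_indices(grid_t, tokens_per_second)     # generation path in this issue
print(with_ts)     # [0, 1, 2, 3]
print(without_ts)  # [0, 2, 4, 6]
```

Under these assumptions the two paths agree only when `temporal_patch_size / fps == 1.0` (for example 2 fps with a temporal patch of 2 frames); any other sampling rate shifts the temporal positions seen at inference time relative to training.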
Who can help?
No response
Information
- The official example scripts
- My own modified scripts
Tasks
- An officially supported task in the `examples` folder (such as GLUE/SQuAD, ...)
- My own task or dataset (give details below)
Reproduction
- Load Qwen2.5-VL from Hugging Face:

  ```python
  from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

  model_id = "Qwen/Qwen2.5-VL-7B-Instruct"  # example checkpoint; any Qwen2.5-VL variant applies
  model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_id)
  processor = AutoProcessor.from_pretrained(model_id)
  ```

- Run generation with video inputs.
- Check `prepare_inputs_for_generation`:

  ```python
  vision_positions, rope_deltas = self.model.get_rope_index(
      model_inputs.get("input_ids", None),
      image_grid_thw=image_grid_thw,
      video_grid_thw=video_grid_thw,
      attention_mask=attention_mask,
      # ⚠ second_per_grid_ts is missing here
  )
  ```

Expected behavior
- `second_per_grid_ts` should be passed into `get_rope_index` during generation, to align with training behavior.
- This ensures temporal RoPE indices are computed consistently between training and inference.
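The requested change amounts to forwarding one keyword in the pre-fill branch. Below is a minimal sketch, assuming the keyword names used by `Qwen2_5_VLModel.forward` in its own `get_rope_index` call: `prefill_rope_index` is a hypothetical standalone copy of the pre-fill branch with `second_per_grid_ts` forwarded, and `StubRopeModel` is a hypothetical stand-in that only records what `get_rope_index` receives, so the check runs without torch or a checkpoint.

```python
class StubRopeModel:
    """Hypothetical stand-in for the inner model; records get_rope_index kwargs."""

    def __init__(self):
        self.rope_deltas = None
        self.last_kwargs = None

    def get_rope_index(self, input_ids, image_grid_thw=None, video_grid_thw=None,
                       second_per_grid_ts=None, attention_mask=None):
        # Record what the caller forwarded instead of computing real positions.
        self.last_kwargs = {
            "image_grid_thw": image_grid_thw,
            "video_grid_thw": video_grid_thw,
            "second_per_grid_ts": second_per_grid_ts,
            "attention_mask": attention_mask,
        }
        # Placeholder outputs; the real method returns tensors.
        return None, None


def prefill_rope_index(model, model_inputs, image_grid_thw=None, video_grid_thw=None,
                       second_per_grid_ts=None, attention_mask=None):
    """Hypothetical copy of the pre-fill branch with the proposed fix applied."""
    vision_positions, rope_deltas = model.get_rope_index(
        model_inputs.get("input_ids", None),
        image_grid_thw=image_grid_thw,
        video_grid_thw=video_grid_thw,
        second_per_grid_ts=second_per_grid_ts,  # previously missing, now forwarded
        attention_mask=attention_mask,
    )
    model.rope_deltas = rope_deltas
    return vision_positions


# Usage: the value given to the helper must reach get_rope_index unchanged.
model = StubRopeModel()
prefill_rope_index(
    model,
    {"input_ids": [[101, 102, 103]]},
    video_grid_thw=[[4, 2, 2]],
    second_per_grid_ts=[0.5],
)
print(model.last_kwargs["second_per_grid_ts"])  # [0.5]
```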