[mm] fix broken MRoPE for GLM-4.1/4.5V #1575
Signed-off-by: AlpinDale <alpindale@gmail.com>
Code Review
This pull request aims to fix MRoPE for GLM-4.1/4.5V by adding the get_mrope_input_positions method. However, the implementation of this new method has several critical bugs that will lead to incorrect behavior for multimodal inputs, especially when videos are present. The logic for identifying video tokens is flawed, it uses incorrect grid data for videos, and the indexing for multimodal data is wrong for interleaved inputs. I've provided a detailed comment on the new method with suggestions for fixing these issues.
    def get_mrope_input_positions(
        self,
        input_tokens: list[int],
        hf_config: "PretrainedConfig",
        image_grid_thw: list[list[int]] | torch.Tensor | None,
        video_grid_thw: list[list[int]] | torch.Tensor | None,
        second_per_grid_ts: list[float] | None = None,
        context_len: int = 0,
        seq_len: int | None = None,
        audio_feature_lengths: torch.Tensor | None = None,
        use_audio_in_video: bool = False,
    ) -> tuple[torch.Tensor, int]:
        """Get mrope input positions and delta value for GLM4V."""

        image_token_id = hf_config.image_token_id
        video_start_token_id = hf_config.video_start_token_id
        video_end_token_id = hf_config.video_end_token_id
        spatial_merge_size = hf_config.vision_config.spatial_merge_size
        llm_pos_ids_list: list = []

        if not (image_grid_thw is None and video_grid_thw is None):
            if isinstance(image_grid_thw, torch.Tensor):
                image_grid_thw = image_grid_thw.tolist()

            input_token_type: list[str] = []
            video_check_flg = False
            for token in input_tokens:
                if token == video_start_token_id:
                    video_check_flg = True
                elif token == video_end_token_id:
                    video_check_flg = False

                if (token == image_token_id) and (video_check_flg is False):
                    input_token_type.append("image")
                elif (token == image_token_id) and (video_check_flg is True):
                    input_token_type.append("video")
                else:
                    input_token_type.append("text")

            input_type_group: list[tuple[str, int, int]] = []
            for key, group_iter in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
                group_list = list(group_iter)
                start_index = group_list[0][0]
                end_index = group_list[-1][0] + 1
                input_type_group.append((key, start_index, end_index))

            video_frame_num = 1
            mm_data_idx = 0
            for modality_type, start_idx, end_idx in input_type_group:
                st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
                if modality_type == "image":
                    t, h, w = (
                        image_grid_thw[mm_data_idx][0],
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t,
                        h // spatial_merge_size,
                        w // spatial_merge_size,
                    )

                    t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
                    mm_data_idx += 1

                elif modality_type == "video":
                    t, h, w = (
                        video_frame_num,
                        image_grid_thw[mm_data_idx][1],
                        image_grid_thw[mm_data_idx][2],
                    )
                    llm_grid_t, llm_grid_h, llm_grid_w = (
                        t,
                        h // spatial_merge_size,
                        w // spatial_merge_size,
                    )

                    for t_idx in range(llm_grid_t):
                        t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                        h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                        w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                        llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                    mm_data_idx += 1
                    video_frame_num += 1

                else:
                    text_len = end_idx - start_idx
                    llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                    video_frame_num = 1

        else:
            text_len = len(input_tokens)
            llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))

        llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
        llm_positions = llm_positions[:, context_len:seq_len]
        mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
        return llm_positions, mrope_position_delta
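To make the position layout easier to check by hand, here is a minimal pure-Python sketch of the arithmetic in the image and text branches of the method above. It is an illustration only, not the PR's code: the helper names (`image_positions`, `text_positions`, `next_start`, `concat`) and the example sizes are assumptions, and plain lists stand in for the torch tensors.

```python
def image_positions(t, h, w, merge, start):
    """Return [t_ids, h_ids, w_ids] for one image grid, offset by `start`."""
    grid_t, grid_h, grid_w = t, h // merge, w // merge
    t_ids, h_ids, w_ids = [], [], []
    # Same ordering as the flattened tensors: t is slowest, w is fastest.
    for ti in range(grid_t):
        for hi in range(grid_h):
            for wi in range(grid_w):
                t_ids.append(ti + start)
                h_ids.append(hi + start)
                w_ids.append(wi + start)
    return [t_ids, h_ids, w_ids]


def text_positions(length, start):
    """Text tokens share one running position on all three axes."""
    ids = [start + i for i in range(length)]
    return [list(ids), list(ids), list(ids)]


def next_start(chunks):
    """Next offset: one past the largest id in the latest chunk, or 0."""
    return max(max(row) for row in chunks[-1]) + 1 if chunks else 0


def concat(chunks):
    """Concatenate chunks along the sequence axis, one row per axis."""
    return [sum((c[axis] for c in chunks), []) for axis in range(3)]


# Hypothetical prompt: 2 text tokens, one 1x4x4 image (merge size 2), 1 text token.
chunks = []
chunks.append(text_positions(2, next_start(chunks)))
chunks.append(image_positions(1, 4, 4, 2, next_start(chunks)))
chunks.append(text_positions(1, next_start(chunks)))

positions = concat(chunks)
num_tokens = len(positions[0])
# Mirrors `llm_positions.max() + 1 - len(input_tokens)` in the method.
delta = max(max(row) for row in positions) + 1 - num_tokens
```

In this example the image occupies four placeholder tokens but advances the running position by only two, so the delta comes out negative.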
The implementation of get_mrope_input_positions has several critical issues that will lead to incorrect position embeddings for multimodal inputs, especially with videos:
- Incorrect Video Token Identification: In lines 1516-1517, video frames are identified by checking for `image_token_id` within a video block. However, the processor logic indicates that `video_token_id` is used for video frames. This will cause video frames to be misclassified.
- Incorrect Data Source for Videos: In lines 1551-1555, `image_grid_thw` is used to get the height and width for video frames. The `video_grid_thw` argument is passed to the function but is never used. This is incorrect and will result in wrong position IDs for videos.
- Flawed Multimodal Indexing: A single index `mm_data_idx` is used for both images and videos (lines 1529, 1548, 1568). This will fail with interleaved image and video inputs, leading to incorrect data access.
- Video Frame Handling Logic: The logic for handling video frames seems to assume one frame per "video" group, which might be incorrect if multiple frames are grouped together. The use of `video_frame_num` as `llm_grid_t` and the loop `for t_idx in range(llm_grid_t):` seems incorrect for generating multi-frame positions.
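The shared-index concern can be reproduced in isolation. The sketch below is a simplified illustration, not the PR's code: the grid values, the `groups` list, and both helper functions are hypothetical, and each video is treated as a single group. It contrasts one counter advanced across both modalities with one counter per modality.

```python
# Hypothetical grids, stored per modality: one image and one two-frame video.
image_grid_thw = [[1, 4, 4]]
video_grid_thw = [[2, 6, 6]]

# Order in which the multimodal groups appear in the prompt (illustrative).
groups = ["image", "video"]


def lookup_with_shared_index(groups, image_grid, video_grid):
    """Mimic a single counter that advances for both modalities."""
    out = []
    idx = 0
    for kind in groups:
        grid = image_grid if kind == "image" else video_grid
        out.append(grid[idx])  # idx counts items of *both* modalities
        idx += 1
    return out


def lookup_with_separate_indices(groups, image_grid, video_grid):
    """Keep one counter per modality so each list is indexed on its own."""
    out = []
    image_idx = 0
    video_idx = 0
    for kind in groups:
        if kind == "image":
            out.append(image_grid[image_idx])
            image_idx += 1
        else:
            out.append(video_grid[video_idx])
            video_idx += 1
    return out
```

With these inputs the shared counter reaches index 1 on a one-element list and raises `IndexError`, while the per-modality counters return both grids in prompt order.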
I recommend refactoring this method to be more robust, similar to the implementation in qwen2_vl.py, which correctly handles separate indexing for images and videos and uses the correct data sources. A corrected implementation would need to:
- Use `hf_config.video_token_id` to identify video tokens.
- Use `video_grid_thw` for video data.
- Maintain separate counters for images and videos.
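As a rough illustration of the first recommendation, the sketch below classifies tokens by a dedicated video token id and then groups contiguous runs, in the same `itertools.groupby` style as the method under review. It assumes, as this comment states, that video frames carry their own `video_token_id`; the numeric ids and the helper names `classify_tokens` and `group_runs` are made up for the example.

```python
import itertools

# Hypothetical token ids, chosen for illustration only.
IMAGE_TOKEN_ID = 101
VIDEO_TOKEN_ID = 102


def classify_tokens(input_tokens, image_token_id, video_token_id):
    """Label each token as "image", "video", or "text" by its own id."""
    types = []
    for token in input_tokens:
        if token == image_token_id:
            types.append("image")
        elif token == video_token_id:
            types.append("video")
        else:
            types.append("text")
    return types


def group_runs(token_types):
    """Collapse contiguous runs into (type, start, end) with exclusive end."""
    groups = []
    for key, run in itertools.groupby(enumerate(token_types), key=lambda pair: pair[1]):
        run = list(run)
        groups.append((key, run[0][0], run[-1][0] + 1))
    return groups


tokens = [1, 101, 101, 2, 102, 102, 102, 3]
token_types = classify_tokens(tokens, IMAGE_TOKEN_ID, VIDEO_TOKEN_ID)
runs = group_runs(token_types)
```

Each resulting run could then be paired with an entry from `image_grid_thw` or `video_grid_thw` using a separate counter per modality.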