[mm] fix broken MRoPE for GLM-4.1/4.5V #1575

Merged
AlpinDale merged 1 commit into main from glm41-v-mrope
Nov 4, 2025

Conversation

@AlpinDale

Member

No description provided.

Signed-off-by: AlpinDale <alpindale@gmail.com>
@AlpinDale AlpinDale merged commit e1f2cc9 into main Nov 4, 2025
1 check passed

@gemini-code-assist gemini-code-assist Bot left a comment

Contributor

Code Review

This pull request aims to fix MRoPE for GLM-4.1/4.5V by adding the get_mrope_input_positions method. However, the new method has several critical bugs that will produce incorrect behavior for multimodal inputs, especially when videos are present: the logic for identifying video tokens is flawed, video frames read the image grid data instead of the video grid data, and the indexing for multimodal data breaks on interleaved inputs. I've provided a detailed comment on the new method with suggestions for fixing these issues.

Comment on lines +1482 to +1583
def get_mrope_input_positions(
    self,
    input_tokens: list[int],
    hf_config: "PretrainedConfig",
    image_grid_thw: list[list[int]] | torch.Tensor | None,
    video_grid_thw: list[list[int]] | torch.Tensor | None,
    second_per_grid_ts: list[float] | None = None,
    context_len: int = 0,
    seq_len: int | None = None,
    audio_feature_lengths: torch.Tensor | None = None,
    use_audio_in_video: bool = False,
) -> tuple[torch.Tensor, int]:
    """Get mrope input positions and delta value for GLM4V."""

    image_token_id = hf_config.image_token_id
    video_start_token_id = hf_config.video_start_token_id
    video_end_token_id = hf_config.video_end_token_id
    spatial_merge_size = hf_config.vision_config.spatial_merge_size
    llm_pos_ids_list: list = []

    if not (image_grid_thw is None and video_grid_thw is None):
        if isinstance(image_grid_thw, torch.Tensor):
            image_grid_thw = image_grid_thw.tolist()

        input_token_type: list[str] = []
        video_check_flg = False
        for token in input_tokens:
            if token == video_start_token_id:
                video_check_flg = True
            elif token == video_end_token_id:
                video_check_flg = False

            if (token == image_token_id) and (video_check_flg is False):
                input_token_type.append("image")
            elif (token == image_token_id) and (video_check_flg is True):
                input_token_type.append("video")
            else:
                input_token_type.append("text")

        input_type_group: list[tuple[str, int, int]] = []
        for key, group_iter in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
            group_list = list(group_iter)
            start_index = group_list[0][0]
            end_index = group_list[-1][0] + 1
            input_type_group.append((key, start_index, end_index))

        video_frame_num = 1
        mm_data_idx = 0
        for modality_type, start_idx, end_idx in input_type_group:
            st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
            if modality_type == "image":
                t, h, w = (
                    image_grid_thw[mm_data_idx][0],
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t,
                    h // spatial_merge_size,
                    w // spatial_merge_size,
                )

                t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
                w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
                llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
                mm_data_idx += 1

            elif modality_type == "video":
                t, h, w = (
                    video_frame_num,
                    image_grid_thw[mm_data_idx][1],
                    image_grid_thw[mm_data_idx][2],
                )
                llm_grid_t, llm_grid_h, llm_grid_w = (
                    t,
                    h // spatial_merge_size,
                    w // spatial_merge_size,
                )

                for t_idx in range(llm_grid_t):
                    t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
                    h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
                    w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
                    llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)

                mm_data_idx += 1
                video_frame_num += 1

            else:
                text_len = end_idx - start_idx
                llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
                video_frame_num = 1

    else:
        text_len = len(input_tokens)
        llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1))

    llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
    llm_positions = llm_positions[:, context_len:seq_len]
    mrope_position_delta = (llm_positions.max() + 1 - len(input_tokens)).item()
    return llm_positions, mrope_position_delta

Contributor

critical

The implementation of get_mrope_input_positions has several critical issues that will lead to incorrect position embeddings for multimodal inputs, especially with videos:

  1. Incorrect Video Token Identification: In lines 1516-1517, video frames are identified by checking for image_token_id within a video block. However, the processor logic indicates that video_token_id is used for video frames. This will cause video frames to be misclassified.

  2. Incorrect Data Source for Videos: In lines 1551-1555, image_grid_thw is used to get the height and width for video frames. The video_grid_thw argument is passed to the function but is never used. This is incorrect and will result in wrong position IDs for videos.

  3. Flawed Multimodal Indexing: A single index mm_data_idx is used for both images and videos (lines 1529, 1548, 1568). This will fail with interleaved image and video inputs, leading to incorrect data access.

  4. Video Frame Handling Logic: The logic for handling video frames appears to assume one frame per "video" group, which may be wrong if multiple frames are grouped together. Using video_frame_num as llm_grid_t and then looping with for t_idx in range(llm_grid_t): also looks incorrect for generating multi-frame positions.
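
To make issues 2 and 3 concrete, here is a minimal pure-Python sketch of the indexing pattern. It is a simplified model rather than the PR's code: the helper names `pick_grids_shared` and `pick_grids_separate` are hypothetical, the grid values are made up, and each non-text group is treated as one multimodal item, which ignores the per-frame grouping the real method performs.

```python
def pick_grids_shared(groups, image_grid_thw):
    """Simplified model of the reviewed indexing: one counter, image grids only."""
    picked = []
    mm_data_idx = 0
    for modality in groups:
        if modality in ("image", "video"):
            # Videos also read image_grid_thw here, mirroring the reviewed code.
            picked.append((modality, image_grid_thw[mm_data_idx]))
            mm_data_idx += 1
    return picked


def pick_grids_separate(groups, image_grid_thw, video_grid_thw):
    """Sketch of the suggested fix: one counter and one grid list per modality."""
    picked = []
    image_idx = 0
    video_idx = 0
    for modality in groups:
        if modality == "image":
            picked.append((modality, image_grid_thw[image_idx]))
            image_idx += 1
        elif modality == "video":
            picked.append((modality, video_grid_thw[video_idx]))
            video_idx += 1
    return picked


# Hypothetical interleaved prompt: image, then video, then image.
groups = ["text", "image", "text", "video", "text", "image"]
image_grid_thw = [[1, 4, 4], [1, 8, 8]]  # made-up (t, h, w) grids
video_grid_thw = [[2, 6, 6]]

print(pick_grids_separate(groups, image_grid_thw, video_grid_thw))
try:
    print(pick_grids_shared(groups, image_grid_thw))
except IndexError:
    # The video consumed the second image's entry, so the last image has none left.
    print("shared counter ran past image_grid_thw")
```

In this model the shared counter hands the video the second image's grid and then runs out of entries for the final image, while the per-modality counters pair every item with its own grid.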

I recommend refactoring this method to be more robust, similar to the implementation in qwen2_vl.py, which correctly handles separate indexing for images and videos and uses the correct data sources. A corrected implementation would need to:

  • Use hf_config.video_token_id to identify video tokens.
  • Use video_grid_thw for video data.
  • Maintain separate counters for images and videos.
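
The first recommendation could be sketched as follows, under the review's assumption that the processor emits a dedicated video_token_id for video frames; that assumption is not verified here against the GLM-4V processor. The function name `group_token_types` and the token ids are hypothetical placeholders.

```python
import itertools

# Hypothetical token ids, chosen only for illustration.
IMAGE_TOKEN_ID = 101
VIDEO_TOKEN_ID = 102


def group_token_types(input_tokens, image_token_id, video_token_id):
    """Split a token sequence into (type, start, end) runs of text/image/video."""

    def token_type(token):
        if token == image_token_id:
            return "image"
        if token == video_token_id:
            return "video"
        return "text"

    groups = []
    position = 0
    # groupby merges consecutive tokens that share the same type.
    for key, run in itertools.groupby(input_tokens, key=token_type):
        length = sum(1 for _ in run)
        groups.append((key, position, position + length))
        position += length
    return groups


tokens = [1, 101, 101, 2, 102, 102, 102, 3]
print(group_token_types(tokens, IMAGE_TOKEN_ID, VIDEO_TOKEN_ID))
# [('text', 0, 1), ('image', 1, 3), ('text', 3, 4), ('video', 4, 7), ('text', 7, 8)]
```

Classifying by token id removes the need for a start/end flag, but it is only correct if video frames really are encoded with their own token, which is the point the review raises and the PR would need to confirm.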

@AlpinDale AlpinDale deleted the glm41-v-mrope branch November 4, 2025 05:05