From a48ad1efaaaaf325c3a3172b888bf9e8f35108e6 Mon Sep 17 00:00:00 2001
From: cyy
Date: Sun, 10 Aug 2025 18:16:02 +0800
Subject: [PATCH] Avoid attention_mask copy in qwen2.5

Signed-off-by: cyy
---
 .../models/qwen2_5_omni/modeling_qwen2_5_omni.py | 15 +++++++++------
 .../models/qwen2_5_omni/modular_qwen2_5_omni.py  | 15 +++++++++------
 .../models/qwen2_5_vl/modeling_qwen2_5_vl.py     | 15 +++++++++------
 .../models/qwen2_5_vl/modular_qwen2_5_vl.py      | 15 +++++++++------
 4 files changed, 36 insertions(+), 24 deletions(-)

diff --git a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
index e618ba861b6c..e8419bcf40c4 100644
--- a/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py
@@ -265,8 +265,8 @@ def get_rope_index(
         mrope_position_deltas = []
         if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
             total_input_ids = input_ids
-            if attention_mask is None:
-                attention_mask = torch.ones_like(total_input_ids)
+            if attention_mask is not None:
+                attention_mask = attention_mask == 1
             position_ids = torch.ones(
                 3,
                 input_ids.shape[0],
@@ -275,9 +275,9 @@ def get_rope_index(
                 device=input_ids.device,
             )
             image_idx, video_idx, audio_idx = 0, 0, 0
-            attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
-                input_ids = input_ids[attention_mask[i] == 1]
+                if attention_mask is not None:
+                    input_ids = input_ids[attention_mask[i]]
                 image_nums, video_nums, audio_nums = 0, 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
@@ -458,9 +458,12 @@ def get_rope_index(

                 llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+                if attention_mask is not None:
+                    position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+                else:
+                    position_ids[..., i, :] = llm_positions.to(position_ids.device)
                 mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids))
-            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+            mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)

             return position_ids, mrope_position_deltas
         else:
diff --git a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
index 53028a99b206..eb5679194b90 100644
--- a/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
+++ b/src/transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py
@@ -1320,8 +1320,8 @@ def get_rope_index(
         mrope_position_deltas = []
         if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
             total_input_ids = input_ids
-            if attention_mask is None:
-                attention_mask = torch.ones_like(total_input_ids)
+            if attention_mask is not None:
+                attention_mask = attention_mask == 1
             position_ids = torch.ones(
                 3,
                 input_ids.shape[0],
@@ -1330,9 +1330,9 @@ def get_rope_index(
                 device=input_ids.device,
             )
             image_idx, video_idx, audio_idx = 0, 0, 0
-            attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
-                input_ids = input_ids[attention_mask[i] == 1]
+                if attention_mask is not None:
+                    input_ids = input_ids[attention_mask[i]]
                 image_nums, video_nums, audio_nums = 0, 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
@@ -1513,9 +1513,12 @@ def get_rope_index(

                 llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)

-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+                if attention_mask is not None:
+                    position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+                else:
+                    position_ids[..., i, :] = llm_positions.to(position_ids.device)
                 mrope_position_deltas.append(llm_positions.max() + 1 - len(input_ids))
-            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+            mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)

             return position_ids, mrope_position_deltas
         else:
diff --git a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
index 8a04e8116eb5..fd67771f542e 100644
--- a/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py
@@ -1033,8 +1033,8 @@ def get_rope_index(
         mrope_position_deltas = []
         if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
             total_input_ids = input_ids
-            if attention_mask is None:
-                attention_mask = torch.ones_like(total_input_ids)
+            if attention_mask is not None:
+                attention_mask = attention_mask == 1
             position_ids = torch.ones(
                 3,
                 input_ids.shape[0],
@@ -1043,9 +1043,9 @@ def get_rope_index(
                 device=input_ids.device,
             )
             image_index, video_index = 0, 0
-            attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
-                input_ids = input_ids[attention_mask[i] == 1]
+                if attention_mask is not None:
+                    input_ids = input_ids[attention_mask[i]]
                 image_nums, video_nums = 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
@@ -1122,9 +1122,12 @@ def get_rope_index(
                     llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                 llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+                if attention_mask is not None:
+                    position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+                else:
+                    position_ids[..., i, :] = llm_positions.to(position_ids.device)
                 mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
-            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+            mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
             return position_ids, mrope_position_deltas
         else:
             if attention_mask is not None:
diff --git a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
index 90ec79edf0ea..55f77c3cfbba 100644
--- a/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
+++ b/src/transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py
@@ -421,8 +421,8 @@ def get_rope_index(
         mrope_position_deltas = []
         if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
             total_input_ids = input_ids
-            if attention_mask is None:
-                attention_mask = torch.ones_like(total_input_ids)
+            if attention_mask is not None:
+                attention_mask = attention_mask == 1
             position_ids = torch.ones(
                 3,
                 input_ids.shape[0],
@@ -431,9 +431,9 @@ def get_rope_index(
                 device=input_ids.device,
             )
             image_index, video_index = 0, 0
-            attention_mask = attention_mask.to(total_input_ids.device)
             for i, input_ids in enumerate(total_input_ids):
-                input_ids = input_ids[attention_mask[i] == 1]
+                if attention_mask is not None:
+                    input_ids = input_ids[attention_mask[i]]
                 image_nums, video_nums = 0, 0
                 vision_start_indices = torch.argwhere(input_ids == vision_start_token_id).squeeze(1)
                 vision_tokens = input_ids[vision_start_indices + 1]
@@ -510,9 +510,12 @@ def get_rope_index(
                     llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)

                 llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
-                position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
+                if attention_mask is not None:
+                    position_ids[..., i, attention_mask[i]] = llm_positions.to(position_ids.device)
+                else:
+                    position_ids[..., i, :] = llm_positions.to(position_ids.device)
                 mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
-            mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
+            mrope_position_deltas = torch.tensor(mrope_position_deltas).unsqueeze(1).to(device=input_ids.device)
             return position_ids, mrope_position_deltas
         else:
             if attention_mask is not None:
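
Note (not part of the patch): a minimal sketch with stand-in tensors, illustrating the equivalence the change relies on. Selecting a row with an all-ones mask keeps every token, so when attention_mask is None the torch.ones_like allocation and the device copy can be dropped and the row used as-is; converting the mask to boolean once (attention_mask == 1) then gives the same per-row selection as comparing against 1 inside the loop.

# Minimal sketch, stand-in tensors only (not the real model inputs).
import torch

input_ids = torch.tensor([5, 7, 9, 11])

# Old path when no mask is given: allocate an all-ones mask and select with it.
ones_mask = torch.ones_like(input_ids)
old = input_ids[ones_mask == 1]

# New path: skip masking entirely; the row is used directly.
new = input_ids
assert torch.equal(old, new)  # selecting with an all-ones mask is a no-op

# When a mask is given: converting it to boolean once matches per-row `== 1` selection.
attention_mask = torch.tensor([[1, 1, 0, 1]])
bool_mask = attention_mask == 1
assert torch.equal(input_ids[bool_mask[0]], input_ids[attention_mask[0] == 1])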