swift/megatron/model/mm_gpt/qwen3_vl.py (12 additions, 2 deletions)

@@ -86,7 +86,9 @@ def prepare_model(self, hf_model):

     @staticmethod
     def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config):
+        from ...trainers.utils import split_cp_inputs
Contributor review comment (severity: medium):
For better code organization and to avoid potential circular import issues, it's recommended to move imports to the top of the file. Please move `from ...trainers.utils import split_cp_inputs` to the top-level imports.
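The suggested layout would look like the following (a sketch only; the function-local import may have been deliberate if the trainers package imports back into the model package, which is exactly the circular-dependency risk the comment mentions):

```python
# Top of swift/megatron/model/mm_gpt/qwen3_vl.py (hypothetical placement):
from ...trainers.utils import split_cp_inputs
```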

         input_ids = inputs['input_ids']
+        packed_seq_params = inputs.get('packed_seq_params')
         pixel_values = inputs.get('pixel_values')
         pixel_values_videos = inputs.get('pixel_values_videos')
         image_grid_thw = inputs.get('image_grid_thw')

@@ -139,8 +141,17 @@ def _get_inputs_embeds(inputs_embeds, inputs, visual, processor, config):
             inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
         visual_pos_masks = image_mask[..., 0] | video_mask[..., 0]
         visual_pos_masks = visual_pos_masks.transpose(0, 1)
-        # compat sp
+        # compat cp
         args = get_args()
+        if args.context_parallel_size > 1:
+            assert packed_seq_params is not None
+            device = visual_pos_masks.device
+            cp_mask = torch.full(visual_pos_masks.shape[:1], -1, dtype=torch.long, device=device)
+            cp_mask[visual_pos_masks[:, 0]] = torch.arange(visual_pos_masks.sum(), device=device)
+            cp_mask = split_cp_inputs(cp_mask, packed_seq_params.cu_seqlens_q, 0)
+            visual_pos_masks = split_cp_inputs(visual_pos_masks, packed_seq_params.cu_seqlens_q, 0)
+            deepstack_visual_embeds = deepstack_visual_embeds[:, cp_mask[(cp_mask != -1)]]
+        # compat sp
         tp_world_size = parallel_state.get_tensor_model_parallel_world_size()
         tp_rank = parallel_state.get_tensor_model_parallel_rank()
         if args.sequence_parallel and tp_world_size > 1:
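For readers of this hunk: cp_mask is a sentinel-index trick. Positions holding visual tokens get a running index and everything else gets -1; after the mask is split across context-parallel ranks, the surviving non-negative entries select exactly the rows of deepstack_visual_embeds this rank still owns. A toy, rank-local illustration (not the repo's code; shapes simplified to 1-D):

```python
import torch

seq_len = 8
visual_pos = torch.tensor([False, True, True, False, True, False, True, False])
num_visual = int(visual_pos.sum())                        # 4 visual tokens
deepstack_visual_embeds = torch.randn(2, num_visual, 4)   # [layers, tokens, hidden]

cp_mask = torch.full((seq_len,), -1, dtype=torch.long)
cp_mask[visual_pos] = torch.arange(num_visual)

# stand-in for split_cp_inputs: pretend this CP rank keeps the first half
local_cp_mask = cp_mask[: seq_len // 2]                   # tensor([-1, 0, 1, -1])
local_embeds = deepstack_visual_embeds[:, local_cp_mask[local_cp_mask != -1]]
print(local_embeds.shape)                                 # torch.Size([2, 2, 4])
```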
@@ -445,7 +456,6 @@ def forward(

     def _deepstack_process(self, hidden_states: torch.Tensor, visual_pos_masks: torch.Tensor,
                            visual_embeds: torch.Tensor):
-        # TODO: compat CP
         visual_pos_masks = visual_pos_masks.to(hidden_states.device)
         visual_embeds = visual_embeds.to(hidden_states.device, hidden_states.dtype)
         local_this = hidden_states[visual_pos_masks, :].clone() + visual_embeds

swift/megatron/model/mm_gpt_model.py (6 additions, 6 deletions)

@@ -41,11 +41,14 @@ def _patch_word_embeddings(self, kwargs):
         origin_forward = VocabParallelEmbedding.forward

         def forward(_self, input_):
+            from ..trainers.utils import split_cp_inputs
+            args = get_args()
             reduce_scatter_embeddings = _self.reduce_scatter_embeddings
             _self.reduce_scatter_embeddings = False
             input_ = torch.masked_fill(input_, input_ < 0, 0)
             res = origin_forward(_self, input_)
             _self.reduce_scatter_embeddings = reduce_scatter_embeddings
+            packed_seq_params = kwargs.get('packed_seq_params')
             if self.visual is not None:
                 res = self.visual.get_inputs_embeds(res, **kwargs)
                 kwargs.clear()
@@ -54,6 +57,8 @@ def forward(_self, input_):
                 inputs_embeds = res.pop('inputs_embeds')
                 kwargs.update(res)
                 res = inputs_embeds
+            if args.context_parallel_size > 1:
+                res = split_cp_inputs(res, packed_seq_params.cu_seqlens_q, 1)
             if reduce_scatter_embeddings:
                 res = res.transpose(0, 1).contiguous()
                 res = scatter_to_sequence_parallel_region(res, group=_self.tp_group)
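Why dim=1 here: before the transpose in the reduce-scatter branch, the embedding output is presumably laid out [batch, seq, hidden], so dim 1 is the packed-token axis. A hypothetical shape walk-through with cp_size=2 (the chunking itself is shown in the split_cp_inputs sketch further down):

```python
import torch

res = torch.randn(1, 16, 8)      # assumed [b, s, h] embedding output
cp_size = 2
# split_cp_inputs(res, cu_seqlens, 1) shards dim 1 into 2 * cp_size chunks
# and keeps two of them per rank; the hidden axis stays intact.
local = res.view(1, 2 * cp_size, 16 // (2 * cp_size), 8)[:, [0, 3]].reshape(1, 8, 8)
print(local.shape)               # torch.Size([1, 8, 8]) -- rank 0's share
```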
@@ -80,14 +85,9 @@ def forward(
         if decoder_input is not None:
             pass
         elif self.pre_process:
-            from ..trainers.utils import get_batch_on_this_cp_rank
-            kwargs.update({'input_ids': input_ids})
+            kwargs.update({'input_ids': input_ids, 'packed_seq_params': packed_seq_params})
             with self._patch_word_embeddings(kwargs):
                 decoder_input = self.language_model.embedding(input_ids=input_ids, position_ids=position_ids)
-            decoder_input = get_batch_on_this_cp_rank({
-                'decoder_input': decoder_input,
-                'packed_seq_params': packed_seq_params
-            })['decoder_input']
         else:
             # intermediate stage of pipeline
             # decoder will get hidden_states from encoder.input_tensor

swift/megatron/trainers/utils.py (14 additions, 35 deletions)

@@ -64,42 +64,24 @@ def get_packed_seq_params(position_ids: torch.Tensor) -> PackedSeqParams:
                              qkv_format='thd')


-def _split_tokens(tokens, cu_seqlens):
-    assert tokens.shape[-2] == 1, f'tokens.shape: {tokens.shape}'  # [..., 1, L]
-    new_tokens = []
+def split_cp_inputs(inputs: torch.Tensor, cu_seqlens: torch.Tensor, dim: int):
+    if dim < 0:
+        dim = (dim + inputs.ndim) % inputs.ndim
+    new_inputs = []
     cp_size = mpu.get_context_parallel_world_size()
     cp_rank = mpu.get_context_parallel_rank()
     for i in range(cu_seqlens.shape[0] - 1):
-        val = tokens[..., cu_seqlens[i]:cu_seqlens[i + 1]]
-        val = val.view(
-            *tokens.shape[:-1],
-            2 * cp_size,
-            val.shape[-1] // (2 * cp_size),
-        )
+        slices = [slice(None)] * inputs.ndim
+        slices[dim] = slice(cu_seqlens[i], cu_seqlens[i + 1])
+        val = inputs[slices]
+        view_shape = (*inputs.shape[:dim], 2 * cp_size, val.shape[dim] // (2 * cp_size), *inputs.shape[dim + 1:])
+        val = val.view(view_shape)
         index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
                              pin_memory=True).cuda(non_blocking=True)
-        val = val.index_select(-2, index)
-        new_tokens.append(val.view(*tokens.shape[:-1], -1))
-    return torch.cat(new_tokens, dim=-1)
-
-
-def _split_tokens_decoder_input(tokens, cu_seqlens):
-    assert tokens.shape[1] == 1, f'tokens.shape: {tokens.shape}'  # [L, 1, E]
-    new_tokens = []
-    cp_size = mpu.get_context_parallel_world_size()
-    cp_rank = mpu.get_context_parallel_rank()
-    for i in range(cu_seqlens.shape[0] - 1):
-        val = tokens[cu_seqlens[i]:cu_seqlens[i + 1], ...]
-        val = val.view(
-            2 * cp_size,
-            val.shape[0] // (2 * cp_size),
-            *tokens.shape[1:],
-        )
-        index = torch.tensor([cp_rank, (2 * cp_size - cp_rank - 1)], device='cpu',
-                             pin_memory=True).cuda(non_blocking=True)
-        val = val.index_select(0, index)
-        new_tokens.append(val.view(-1, *tokens.shape[1:]))
-    return torch.cat(new_tokens, dim=0)
+        val = val.index_select(dim, index)
+        view_shape = (*inputs.shape[:dim], -1, *inputs.shape[dim + 1:])
+        new_inputs.append(val.view(view_shape))
+    return torch.cat(new_inputs, dim=dim)


 def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
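For context, a self-contained sketch of what split_cp_inputs computes: each packed sub-sequence (delimited by cu_seqlens) is cut into 2 * cp_size chunks along dim, and rank r keeps chunks r and 2*cp_size-1-r, the usual load-balanced context-parallel split. The parallel state is passed explicitly here so the sketch runs without Megatron:

```python
import torch

def split_cp_inputs_ref(inputs: torch.Tensor, cu_seqlens: torch.Tensor,
                        dim: int, cp_size: int, cp_rank: int) -> torch.Tensor:
    """Same chunking as split_cp_inputs, with parallel state passed explicitly."""
    if dim < 0:
        dim += inputs.ndim
    index = torch.tensor([cp_rank, 2 * cp_size - 1 - cp_rank])
    out = []
    for i in range(cu_seqlens.shape[0] - 1):
        slices = [slice(None)] * inputs.ndim
        slices[dim] = slice(int(cu_seqlens[i]), int(cu_seqlens[i + 1]))
        val = inputs[tuple(slices)]
        # cut the sub-sequence into 2 * cp_size equal chunks along `dim`
        val = val.view(*val.shape[:dim], 2 * cp_size,
                       val.shape[dim] // (2 * cp_size), *val.shape[dim + 1:])
        val = val.index_select(dim, index)  # keep chunk cp_rank and its mirror
        out.append(val.reshape(*val.shape[:dim], -1, *val.shape[dim + 2:]))
    return torch.cat(out, dim=dim)

# Two packed sequences of length 8 each, cp_size=2, rank 0:
tokens = torch.arange(16).unsqueeze(0)  # shape [1, 16]
print(split_cp_inputs_ref(tokens, torch.tensor([0, 8, 16]), -1, cp_size=2, cp_rank=0))
# tensor([[ 0,  1,  6,  7,  8,  9, 14, 15]])
```

Pairing chunk r with its mirror 2*cp_size-1-r keeps per-rank attention cost roughly even under a causal mask, since early (cheap) and late (expensive) chunks travel together.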
@@ -130,10 +112,7 @@ def get_batch_on_this_cp_rank(batch: Dict[str, Any]):
         if args.task_type == 'seq_cls' and key == 'labels':
             continue
         if val is not None:
-            if key == 'decoder_input':
-                batch[key] = _split_tokens_decoder_input(val, packed_seq_params.cu_seqlens_q)
-            else:
-                batch[key] = _split_tokens(val, packed_seq_params.cu_seqlens_q)
+            batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1)
Contributor review comment (severity: high):
This simplification introduces a potential issue. The previous code had special handling for decoder_input, which must be split along dim=0 because of its (sequence_length, batch_size, hidden_size) shape. The new code splits all tensors along dim=-1, which is incorrect for decoder_input: it would split the hidden dimension.

While MultimodalGPTModel now handles this splitting separately, the 'decoder_input' key is still checked for in this function (on line 103), so this code path could still be taken in other cases. This is confusing and could lead to silent errors.

To fix this, please either restore the special handling for decoder_input or, if it is confirmed to be dead code, remove 'decoder_input' from the keys list on line 103.

Suggested change:
-            batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1)
+            if key == 'decoder_input':
+                batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, 0)
+            else:
+                batch[key] = split_cp_inputs(val, packed_seq_params.cu_seqlens_q, -1)


return batch
