From 647239cb62057a1e35101a9d08613026a011e797 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 19 Nov 2025 00:59:32 +0800
Subject: [PATCH 1/3] optimize_gpt_bridge_comm

---
 swift/llm/argument/export_args.py  |  5 ++
 swift/megatron/model/gpt_bridge.py | 73 +++++++++++++++++------------
 2 files changed, 46 insertions(+), 32 deletions(-)

diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index 921163e4ae..aa7844922c 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -67,6 +67,11 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_peft_format: bool = False
     exist_ok: bool = False
 
+    def _init_ckpt_dir(self, adapters=None):
+        if self.to_cached_dataset:
+            return
+        super()._init_ckpt_dir(adapters)
+
     def _init_output_dir(self):
         if self.output_dir is None:
             ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py
index 40aff04568..13d92c16eb 100644
--- a/swift/megatron/model/gpt_bridge.py
+++ b/swift/megatron/model/gpt_bridge.py
@@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False):
         self.etp_rank = mpu.get_expert_tensor_parallel_rank()
         self.ep_rank = mpu.get_expert_model_parallel_rank()
 
+        dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
+        expert_decoder_rank_generator = mpu.RankGenerator(
+            tp=self.etp_size,
+            ep=self.ep_size,
+            dp=dp_size,
+            pp=self.pp_size,
+            cp=1,
+            order='tp-cp-ep-dp-pp',
+            rank_offset=0,
+        )
+        rank = dist.get_rank()
+        for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'):
+            group = mpu.create_group(
+                ranks,
+                group_desc='EP-PP-GROUP',
+            )
+            if rank in ranks:
+                self.ep_pp_size = self.ep_size * self.pp_size
+                self.ep_pp_group = group
+                self.ep_pp_rank = dist.get_rank(group)
+
     def _init_meta_hf_model(self):
         with torch.device('meta'):
             self.hf_model, self.processor = get_model_tokenizer(
@@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
         tensor = None if mg_weight is None else mg_weight.to('cuda')
         tp_size = self.etp_size if is_expert else self.tp_size
         tp_group = self.etp_group if is_expert else self.tp_group
+        pp_group = self.ep_pp_group if is_expert else self.pp_group
+        pp_size = self.ep_pp_size if is_expert else self.pp_size
+        pp_rank = self.ep_pp_rank if is_expert else self.pp_rank
         if tensor is not None and tp_dim is not None and tp_size > 1:
             if tp_dim == 0:
                 # save memory
@@ -220,34 +244,25 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 tensor = torch.cat(output, dim=tp_dim)
                 del output
         # pp/ep
-        for parallel_state in ['ep', 'pp']:
-            if parallel_state == 'pp' and self.pp_size > 1:
-                parallel_group = self.pp_group
-                parallel_rank = self.pp_rank
-            elif parallel_state == 'ep' and is_expert and self.ep_size > 1:
-                parallel_group = self.ep_group
-                parallel_rank = self.ep_rank
-            else:
-                continue
-            src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda')
-            dist.all_reduce(src_rank, group=parallel_group)
-            src_rank = dist.get_global_rank(parallel_group, src_rank.item())
+        if pp_size > 1:
+            src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
+            dist.all_reduce(src_rank, group=pp_group)
+            src_rank = dist.get_global_rank(pp_group, src_rank.item())
             meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
             dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
             dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
             if tensor is None:
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                if meta_data[0].item() > 0:
-                    shape = meta_data[1:1 + meta_data[0]].tolist()
-                    dtype = dtype_mapping_r[meta_data[-1].item()]
-                    tensor = torch.empty(shape, device='cuda', dtype=dtype)
-                    dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                shape = meta_data[1:1 + meta_data[0]].tolist()
+                dtype = dtype_mapping_r[meta_data[-1].item()]
+                tensor = torch.empty(shape, device='cuda', dtype=dtype)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
             else:
                 meta_data[0] = tensor.ndim
                 meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
                 meta_data[-1] = dtype_mapping[tensor.dtype]
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
         assert tensor is not None, f'mg_key: {mg_key}'
         if offset:
             tensor = tensor + offset
@@ -273,10 +288,8 @@ def _set_state_dict(self,
             is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
         if not to_mcore:
             state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
-            if self.pp_size > 1:
-                dist.all_reduce(state, group=self.pp_group)
-            if is_expert and self.ep_size > 1:
-                dist.all_reduce(state, group=self.ep_group)
+            if self.ep_pp_size > 1:
+                dist.all_reduce(state, group=self.ep_pp_group)
             is_lora, is_modules_to_save = state
         if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
             if to_mcore:
@@ -627,10 +640,8 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
-            dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
+        if self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:
@@ -779,10 +790,8 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
-            dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
+        if self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:

From da5442fdb2226fe8434c4140fa9b6b7943a38451 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 19 Nov 2025 01:36:23 +0800
Subject: [PATCH 2/3] fix

---
 swift/megatron/model/gpt_bridge.py | 13 ++++++++++---
 1 file changed, 10 insertions(+), 3 deletions(-)

diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py
index 13d92c16eb..451925285d 100644
--- a/swift/megatron/model/gpt_bridge.py
+++ b/swift/megatron/model/gpt_bridge.py
@@ -253,6 +253,7 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
             dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
             if tensor is None:
                 dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
                 shape = meta_data[1:1 + meta_data[0]].tolist()
                 dtype = dtype_mapping_r[meta_data[-1].item()]
                 tensor = torch.empty(shape, device='cuda', dtype=dtype)
                 dist.broadcast(tensor, src=src_rank, group=pp_group)
@@ -288,8 +289,10 @@ def _set_state_dict(self,
             is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
         if not to_mcore:
             state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
-            if self.ep_pp_size > 1:
+            if is_expert and self.ep_pp_size > 1:
                 dist.all_reduce(state, group=self.ep_pp_group)
+            elif not is_expert and self.pp_size > 1:
+                dist.all_reduce(state, group=self.pp_group)
             is_lora, is_modules_to_save = state
         if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
             if to_mcore:
@@ -640,8 +643,10 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.ep_pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
             dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
+            dist.all_reduce(is_lora, group=self.pp_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:
@@ -790,8 +795,10 @@ def _set_mlp_state(self,
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2,
                                                           LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.ep_pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
             dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
+            dist.all_reduce(is_lora, group=self.pp_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:

From e18520656f227a381ef4f234a2ef6898c7dce13d Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Wed, 19 Nov 2025 01:50:59 +0800
Subject: [PATCH 3/3] fix

---
 swift/llm/argument/export_args.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index aa7844922c..16568c24b6 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -67,10 +67,10 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_peft_format: bool = False
     exist_ok: bool = False
 
-    def _init_ckpt_dir(self, adapters=None):
+    def load_args_from_ckpt(self) -> None:
         if self.to_cached_dataset:
             return
-        super()._init_ckpt_dir(adapters)
+        super().load_args_from_ckpt()
 
     def _init_output_dir(self):
         if self.output_dir is None: