diff --git a/swift/llm/argument/export_args.py b/swift/llm/argument/export_args.py
index 921163e4ae..16568c24b6 100644
--- a/swift/llm/argument/export_args.py
+++ b/swift/llm/argument/export_args.py
@@ -67,6 +67,11 @@ class ExportArguments(MergeArguments, BaseArguments):
     to_peft_format: bool = False
     exist_ok: bool = False
 
+    def load_args_from_ckpt(self) -> None:
+        if self.to_cached_dataset:
+            return
+        super().load_args_from_ckpt()
+
     def _init_output_dir(self):
         if self.output_dir is None:
             ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
diff --git a/swift/megatron/model/gpt_bridge.py b/swift/megatron/model/gpt_bridge.py
index 40aff04568..451925285d 100644
--- a/swift/megatron/model/gpt_bridge.py
+++ b/swift/megatron/model/gpt_bridge.py
@@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False):
         self.etp_rank = mpu.get_expert_tensor_parallel_rank()
         self.ep_rank = mpu.get_expert_model_parallel_rank()
 
+        dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
+        expert_decoder_rank_generator = mpu.RankGenerator(
+            tp=self.etp_size,
+            ep=self.ep_size,
+            dp=dp_size,
+            pp=self.pp_size,
+            cp=1,
+            order='tp-cp-ep-dp-pp',
+            rank_offset=0,
+        )
+        rank = dist.get_rank()
+        for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'):
+            group = mpu.create_group(
+                ranks,
+                group_desc='EP-PP-GROUP',
+            )
+            if rank in ranks:
+                self.ep_pp_size = self.ep_size * self.pp_size
+                self.ep_pp_group = group
+                self.ep_pp_rank = dist.get_rank(group)
+
     def _init_meta_hf_model(self):
         with torch.device('meta'):
             self.hf_model, self.processor = get_model_tokenizer(
@@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
         tensor = None if mg_weight is None else mg_weight.to('cuda')
         tp_size = self.etp_size if is_expert else self.tp_size
         tp_group = self.etp_group if is_expert else self.tp_group
+        pp_group = self.ep_pp_group if is_expert else self.pp_group
+        pp_size = self.ep_pp_size if is_expert else self.pp_size
+        pp_rank = self.ep_pp_rank if is_expert else self.pp_rank
         if tensor is not None and tp_dim is not None and tp_size > 1:
             if tp_dim == 0:
                 # save memory
@@ -220,34 +244,26 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
                 tensor = torch.cat(output, dim=tp_dim)
                 del output
         # pp/ep
-        for parallel_state in ['ep', 'pp']:
-            if parallel_state == 'pp' and self.pp_size > 1:
-                parallel_group = self.pp_group
-                parallel_rank = self.pp_rank
-            elif parallel_state == 'ep' and is_expert and self.ep_size > 1:
-                parallel_group = self.ep_group
-                parallel_rank = self.ep_rank
-            else:
-                continue
-            src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda')
-            dist.all_reduce(src_rank, group=parallel_group)
-            src_rank = dist.get_global_rank(parallel_group, src_rank.item())
+        if pp_size > 1:
+            src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
+            dist.all_reduce(src_rank, group=pp_group)
+            src_rank = dist.get_global_rank(pp_group, src_rank.item())
             meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
             dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
             dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
             if tensor is None:
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                if meta_data[0].item() > 0:
-                    shape = meta_data[1:1 + meta_data[0]].tolist()
-                    dtype = dtype_mapping_r[meta_data[-1].item()]
-                    tensor = torch.empty(shape, device='cuda', dtype=dtype)
-                    dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
+                shape = meta_data[1:1 + meta_data[0]].tolist()
+                dtype = dtype_mapping_r[meta_data[-1].item()]
+                tensor = torch.empty(shape, device='cuda', dtype=dtype)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
             else:
                 meta_data[0] = tensor.ndim
                 meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
                 meta_data[-1] = dtype_mapping[tensor.dtype]
-                dist.broadcast(meta_data, src=src_rank, group=parallel_group)
-                dist.broadcast(tensor, src=src_rank, group=parallel_group)
+                dist.broadcast(meta_data, src=src_rank, group=pp_group)
+                dist.broadcast(tensor, src=src_rank, group=pp_group)
         assert tensor is not None, f'mg_key: {mg_key}'
         if offset:
             tensor = tensor + offset
@@ -273,10 +289,10 @@ def _set_state_dict(self,
             is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
         if not to_mcore:
             state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
-            if self.pp_size > 1:
+            if is_expert and self.ep_pp_size > 1:
+                dist.all_reduce(state, group=self.ep_pp_group)
+            elif not is_expert and self.pp_size > 1:
                 dist.all_reduce(state, group=self.pp_group)
-            if is_expert and self.ep_size > 1:
-                dist.all_reduce(state, group=self.ep_group)
             is_lora, is_modules_to_save = state
         if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
             if to_mcore:
@@ -627,10 +643,10 @@ def _set_mlp_state(self,
 
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
             dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:
@@ -779,10 +795,10 @@ def _set_mlp_state(self,
 
         is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2, LoraParallelLinear) and self._is_peft_format
         is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
-        if self.pp_size > 1:
+        if is_expert and self.ep_pp_size > 1:
+            dist.all_reduce(is_lora, group=self.ep_pp_group)
+        elif not is_expert and self.pp_size > 1:
             dist.all_reduce(is_lora, group=self.pp_group)
-        if is_expert and self.ep_size > 1:
-            dist.all_reduce(is_lora, group=self.ep_group)
         if is_lora:
             assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
             if mg_mlp is None:
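For readers unfamiliar with the pattern the reworked `_get_weight` relies on: when a rank holds no shard of a parameter, only the source rank knows the tensor's shape and dtype, so a fixed-size metadata buffer (ndim, shape, dtype id) is broadcast first and the payload second. The standalone sketch below illustrates that two-phase broadcast; `broadcast_unknown_tensor`, its arguments, and the dtype table are illustrative stand-ins rather than APIs from this repository, and the sketch assumes a CUDA/NCCL backend with one GPU per rank.

import torch
import torch.distributed as dist

# Hypothetical dtype <-> id table mirroring the one used inside _get_weight.
DTYPE_TO_ID = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
ID_TO_DTYPE = {v: k for k, v in DTYPE_TO_ID.items()}


def broadcast_unknown_tensor(tensor, src_rank, group):
    """Broadcast `tensor` from `src_rank` to every rank in `group`.

    Ranks that do not hold the tensor pass None and receive a freshly
    allocated copy; only the source rank knows the shape and dtype up front.
    """
    meta = torch.zeros(10, dtype=torch.int64, device='cuda')
    if tensor is not None:
        # Pack ndim, shape, and a dtype id into the fixed-size buffer
        # (assumes ndim <= 8 so the shape never overlaps the dtype slot).
        meta[0] = tensor.ndim
        meta[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
        meta[-1] = DTYPE_TO_ID[tensor.dtype]
    # Phase 1: every rank learns ndim/shape/dtype.
    dist.broadcast(meta, src=src_rank, group=group)
    if tensor is None:
        ndim = meta[0].item()
        shape = meta[1:1 + ndim].tolist()
        tensor = torch.empty(shape, device='cuda', dtype=ID_TO_DTYPE[meta[-1].item()])
    # Phase 2: broadcast the payload into the (possibly just allocated) buffer.
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor

Folding EP and PP into a single ep-pp group, as the patch appears to do, lets this two-phase broadcast run once per expert weight over one group instead of once per parallel dimension.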