-
Notifications
You must be signed in to change notification settings - Fork 995
[mcore-bridge] optimize gpt_bridge comm #6659
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False): | |||||||||||||||||
| self.etp_rank = mpu.get_expert_tensor_parallel_rank() | ||||||||||||||||||
| self.ep_rank = mpu.get_expert_model_parallel_rank() | ||||||||||||||||||
|
|
||||||||||||||||||
| dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size | ||||||||||||||||||
| expert_decoder_rank_generator = mpu.RankGenerator( | ||||||||||||||||||
| tp=self.etp_size, | ||||||||||||||||||
| ep=self.ep_size, | ||||||||||||||||||
| dp=dp_size, | ||||||||||||||||||
| pp=self.pp_size, | ||||||||||||||||||
| cp=1, | ||||||||||||||||||
| order='tp-cp-ep-dp-pp', | ||||||||||||||||||
| rank_offset=0, | ||||||||||||||||||
| ) | ||||||||||||||||||
| rank = dist.get_rank() | ||||||||||||||||||
| for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'): | ||||||||||||||||||
| group = mpu.create_group( | ||||||||||||||||||
| ranks, | ||||||||||||||||||
| group_desc='EP-PP-GROUP', | ||||||||||||||||||
| ) | ||||||||||||||||||
| if rank in ranks: | ||||||||||||||||||
| self.ep_pp_size = self.ep_size * self.pp_size | ||||||||||||||||||
| self.ep_pp_group = group | ||||||||||||||||||
| self.ep_pp_rank = dist.get_rank(group) | ||||||||||||||||||
|
|
||||||||||||||||||
| def _init_meta_hf_model(self): | ||||||||||||||||||
| with torch.device('meta'): | ||||||||||||||||||
| self.hf_model, self.processor = get_model_tokenizer( | ||||||||||||||||||
|
|
@@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl | |||||||||||||||||
| tensor = None if mg_weight is None else mg_weight.to('cuda') | ||||||||||||||||||
| tp_size = self.etp_size if is_expert else self.tp_size | ||||||||||||||||||
| tp_group = self.etp_group if is_expert else self.tp_group | ||||||||||||||||||
| pp_group = self.ep_pp_group if is_expert else self.pp_group | ||||||||||||||||||
| pp_size = self.ep_pp_size if is_expert else self.pp_size | ||||||||||||||||||
| pp_rank = self.ep_pp_rank if is_expert else self.pp_rank | ||||||||||||||||||
| if tensor is not None and tp_dim is not None and tp_size > 1: | ||||||||||||||||||
| if tp_dim == 0: | ||||||||||||||||||
| # save memory | ||||||||||||||||||
|
|
@@ -220,34 +244,26 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl | |||||||||||||||||
| tensor = torch.cat(output, dim=tp_dim) | ||||||||||||||||||
| del output | ||||||||||||||||||
| # pp/ep | ||||||||||||||||||
| for parallel_state in ['ep', 'pp']: | ||||||||||||||||||
| if parallel_state == 'pp' and self.pp_size > 1: | ||||||||||||||||||
| parallel_group = self.pp_group | ||||||||||||||||||
| parallel_rank = self.pp_rank | ||||||||||||||||||
| elif parallel_state == 'ep' and is_expert and self.ep_size > 1: | ||||||||||||||||||
| parallel_group = self.ep_group | ||||||||||||||||||
| parallel_rank = self.ep_rank | ||||||||||||||||||
| else: | ||||||||||||||||||
| continue | ||||||||||||||||||
| src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda') | ||||||||||||||||||
| dist.all_reduce(src_rank, group=parallel_group) | ||||||||||||||||||
| src_rank = dist.get_global_rank(parallel_group, src_rank.item()) | ||||||||||||||||||
| if pp_size > 1: | ||||||||||||||||||
| src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda') | ||||||||||||||||||
| dist.all_reduce(src_rank, group=pp_group) | ||||||||||||||||||
| src_rank = dist.get_global_rank(pp_group, src_rank.item()) | ||||||||||||||||||
| meta_data = torch.zeros(10, dtype=torch.int64, device='cuda') | ||||||||||||||||||
| dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3} | ||||||||||||||||||
| dtype_mapping_r = {v: k for k, v in dtype_mapping.items()} | ||||||||||||||||||
| if tensor is None: | ||||||||||||||||||
| dist.broadcast(meta_data, src=src_rank, group=parallel_group) | ||||||||||||||||||
| if meta_data[0].item() > 0: | ||||||||||||||||||
| shape = meta_data[1:1 + meta_data[0]].tolist() | ||||||||||||||||||
| dtype = dtype_mapping_r[meta_data[-1].item()] | ||||||||||||||||||
| tensor = torch.empty(shape, device='cuda', dtype=dtype) | ||||||||||||||||||
| dist.broadcast(tensor, src=src_rank, group=parallel_group) | ||||||||||||||||||
| dist.broadcast(meta_data, src=src_rank, group=pp_group) | ||||||||||||||||||
| assert meta_data[0].item() > 0, f'meta_data: {meta_data}' | ||||||||||||||||||
| shape = meta_data[1:1 + meta_data[0]].tolist() | ||||||||||||||||||
| dtype = dtype_mapping_r[meta_data[-1].item()] | ||||||||||||||||||
| tensor = torch.empty(shape, device='cuda', dtype=dtype) | ||||||||||||||||||
| dist.broadcast(tensor, src=src_rank, group=pp_group) | ||||||||||||||||||
| else: | ||||||||||||||||||
| meta_data[0] = tensor.ndim | ||||||||||||||||||
| meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda') | ||||||||||||||||||
| meta_data[-1] = dtype_mapping[tensor.dtype] | ||||||||||||||||||
| dist.broadcast(meta_data, src=src_rank, group=parallel_group) | ||||||||||||||||||
| dist.broadcast(tensor, src=src_rank, group=parallel_group) | ||||||||||||||||||
| dist.broadcast(meta_data, src=src_rank, group=pp_group) | ||||||||||||||||||
| dist.broadcast(tensor, src=src_rank, group=pp_group) | ||||||||||||||||||
| assert tensor is not None, f'mg_key: {mg_key}' | ||||||||||||||||||
| if offset: | ||||||||||||||||||
| tensor = tensor + offset | ||||||||||||||||||
|
|
@@ -273,10 +289,10 @@ def _set_state_dict(self, | |||||||||||||||||
| is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper) | ||||||||||||||||||
| if not to_mcore: | ||||||||||||||||||
| state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda') | ||||||||||||||||||
| if self.pp_size > 1: | ||||||||||||||||||
| if is_expert and self.ep_pp_size > 1: | ||||||||||||||||||
| dist.all_reduce(state, group=self.ep_pp_group) | ||||||||||||||||||
| elif not is_expert and self.pp_size > 1: | ||||||||||||||||||
| dist.all_reduce(state, group=self.pp_group) | ||||||||||||||||||
| if is_expert and self.ep_size > 1: | ||||||||||||||||||
| dist.all_reduce(state, group=self.ep_group) | ||||||||||||||||||
| is_lora, is_modules_to_save = state | ||||||||||||||||||
| if is_lora and self._is_peft_format and param_key != 'layer_norm_weight': | ||||||||||||||||||
| if to_mcore: | ||||||||||||||||||
|
|
@@ -627,10 +643,10 @@ def _set_mlp_state(self, | |||||||||||||||||
| is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1, | ||||||||||||||||||
| LoraParallelLinear) and self._is_peft_format | ||||||||||||||||||
| is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') | ||||||||||||||||||
| if self.pp_size > 1: | ||||||||||||||||||
| if is_expert and self.ep_pp_size > 1: | ||||||||||||||||||
| dist.all_reduce(is_lora, group=self.ep_pp_group) | ||||||||||||||||||
| elif not is_expert and self.pp_size > 1: | ||||||||||||||||||
| dist.all_reduce(is_lora, group=self.pp_group) | ||||||||||||||||||
|
Comment on lines
+646
to
649
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Similar to the logic in
Suggested change
|
||||||||||||||||||
| if is_expert and self.ep_size > 1: | ||||||||||||||||||
| dist.all_reduce(is_lora, group=self.ep_group) | ||||||||||||||||||
| if is_lora: | ||||||||||||||||||
| assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.' | ||||||||||||||||||
| if mg_mlp is None: | ||||||||||||||||||
|
|
@@ -779,10 +795,10 @@ def _set_mlp_state(self, | |||||||||||||||||
| is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2, | ||||||||||||||||||
| LoraParallelLinear) and self._is_peft_format | ||||||||||||||||||
| is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda') | ||||||||||||||||||
| if self.pp_size > 1: | ||||||||||||||||||
| if is_expert and self.ep_pp_size > 1: | ||||||||||||||||||
| dist.all_reduce(is_lora, group=self.ep_pp_group) | ||||||||||||||||||
| elif not is_expert and self.pp_size > 1: | ||||||||||||||||||
| dist.all_reduce(is_lora, group=self.pp_group) | ||||||||||||||||||
|
Comment on lines
+798
to
801
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is another instance of the same communication logic. Applying the same refactoring here will further improve code consistency and maintainability across the file.
Suggested change
|
||||||||||||||||||
| if is_expert and self.ep_size > 1: | ||||||||||||||||||
| dist.all_reduce(is_lora, group=self.ep_group) | ||||||||||||||||||
| if is_lora: | ||||||||||||||||||
| assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.' | ||||||||||||||||||
| if mg_mlp is None: | ||||||||||||||||||
|
|
||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This
if/elifblock for theall_reduceoperation is repeated in_set_mlp_stateas well. To improve code maintainability and reduce duplication, this logic can be refactored. By first selecting the appropriate parallel group and size based on whether it's an expert layer, and then performing the communication, the code becomes more concise and easier to read.