5 changes: 5 additions & 0 deletions swift/llm/argument/export_args.py
@@ -67,6 +67,11 @@ class ExportArguments(MergeArguments, BaseArguments):
to_peft_format: bool = False
exist_ok: bool = False

def load_args_from_ckpt(self) -> None:
if self.to_cached_dataset:
return
super().load_args_from_ckpt()

def _init_output_dir(self):
if self.output_dir is None:
ckpt_dir = self.ckpt_dir or f'./{self.model_suffix}'
74 changes: 45 additions & 29 deletions swift/megatron/model/gpt_bridge.py
@@ -64,6 +64,27 @@ def __init__(self, disable_tqmd: bool = False):
self.etp_rank = mpu.get_expert_tensor_parallel_rank()
self.ep_rank = mpu.get_expert_model_parallel_rank()

dp_size = dist.get_world_size() // self.etp_size // self.ep_size // self.pp_size
expert_decoder_rank_generator = mpu.RankGenerator(
tp=self.etp_size,
ep=self.ep_size,
dp=dp_size,
pp=self.pp_size,
cp=1,
order='tp-cp-ep-dp-pp',
rank_offset=0,
)
rank = dist.get_rank()
for ranks in expert_decoder_rank_generator.get_ranks('ep-pp'):
group = mpu.create_group(
ranks,
group_desc='EP-PP-GROUP',
)
if rank in ranks:
self.ep_pp_size = self.ep_size * self.pp_size
self.ep_pp_group = group
self.ep_pp_rank = dist.get_rank(group)

def _init_meta_hf_model(self):
with torch.device('meta'):
self.hf_model, self.processor = get_model_tokenizer(
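
Note (context only, not part of the PR): the block above creates one process group per fixed (tp, dp) coordinate that spans every expert-parallel and pipeline-parallel position, and records its size as ep_size * pp_size. The dependency-free sketch below illustrates that enumeration; it assumes the usual reading of order='tp-cp-ep-dp-pp' (tp varies fastest, pp slowest), and the parallel sizes are made-up examples rather than values from this diff.

```python
# Illustrative sketch only: enumerate combined EP-PP rank groups without Megatron.
# Assumes 'tp-cp-ep-dp-pp' means tp is the fastest-varying dimension and pp the
# slowest; the sizes below are arbitrary examples, not values from the PR.
import itertools

tp, cp, ep, dp, pp = 2, 1, 2, 2, 2          # world_size = 16

def rank_of(i_tp, i_cp, i_ep, i_dp, i_pp):
    """Map a (tp, cp, ep, dp, pp) coordinate to a global rank."""
    return (((i_pp * dp + i_dp) * ep + i_ep) * cp + i_cp) * tp + i_tp

# One EP-PP group per fixed (dp, cp, tp) coordinate; each group contains every
# combination of expert-parallel and pipeline-parallel positions, so its size
# is ep * pp -- matching self.ep_pp_size = self.ep_size * self.pp_size above.
ep_pp_groups = []
for i_dp, i_cp, i_tp in itertools.product(range(dp), range(cp), range(tp)):
    ranks = [rank_of(i_tp, i_cp, i_ep, i_dp, i_pp)
             for i_pp in range(pp) for i_ep in range(ep)]
    ep_pp_groups.append(ranks)

print(len(ep_pp_groups), ep_pp_groups[0])   # 4 groups, e.g. [0, 2, 8, 10]
```

The real grouping is done by mpu.RankGenerator.get_ranks('ep-pp') as shown in the diff; the sketch only shows why each group has ep_size * pp_size members.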
Expand Down Expand Up @@ -198,6 +219,9 @@ def _get_weight(self, mg_weight: torch.Tensor, mg_key: Optional[str], offset: fl
tensor = None if mg_weight is None else mg_weight.to('cuda')
tp_size = self.etp_size if is_expert else self.tp_size
tp_group = self.etp_group if is_expert else self.tp_group
pp_group = self.ep_pp_group if is_expert else self.pp_group
pp_size = self.ep_pp_size if is_expert else self.pp_size
pp_rank = self.ep_pp_rank if is_expert else self.pp_rank
if tensor is not None and tp_dim is not None and tp_size > 1:
if tp_dim == 0:
# save memory
@@ -220,34 +244,26 @@
tensor = torch.cat(output, dim=tp_dim)
del output
# pp/ep
for parallel_state in ['ep', 'pp']:
if parallel_state == 'pp' and self.pp_size > 1:
parallel_group = self.pp_group
parallel_rank = self.pp_rank
elif parallel_state == 'ep' and is_expert and self.ep_size > 1:
parallel_group = self.ep_group
parallel_rank = self.ep_rank
else:
continue
src_rank = torch.tensor([0 if tensor is None else parallel_rank], dtype=torch.int64, device='cuda')
dist.all_reduce(src_rank, group=parallel_group)
src_rank = dist.get_global_rank(parallel_group, src_rank.item())
if pp_size > 1:
src_rank = torch.tensor([0 if tensor is None else pp_rank], dtype=torch.int64, device='cuda')
dist.all_reduce(src_rank, group=pp_group)
src_rank = dist.get_global_rank(pp_group, src_rank.item())
meta_data = torch.zeros(10, dtype=torch.int64, device='cuda')
dtype_mapping = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
dtype_mapping_r = {v: k for k, v in dtype_mapping.items()}
if tensor is None:
dist.broadcast(meta_data, src=src_rank, group=parallel_group)
if meta_data[0].item() > 0:
shape = meta_data[1:1 + meta_data[0]].tolist()
dtype = dtype_mapping_r[meta_data[-1].item()]
tensor = torch.empty(shape, device='cuda', dtype=dtype)
dist.broadcast(tensor, src=src_rank, group=parallel_group)
dist.broadcast(meta_data, src=src_rank, group=pp_group)
assert meta_data[0].item() > 0, f'meta_data: {meta_data}'
shape = meta_data[1:1 + meta_data[0]].tolist()
dtype = dtype_mapping_r[meta_data[-1].item()]
tensor = torch.empty(shape, device='cuda', dtype=dtype)
dist.broadcast(tensor, src=src_rank, group=pp_group)
else:
meta_data[0] = tensor.ndim
meta_data[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
meta_data[-1] = dtype_mapping[tensor.dtype]
dist.broadcast(meta_data, src=src_rank, group=parallel_group)
dist.broadcast(tensor, src=src_rank, group=parallel_group)
dist.broadcast(meta_data, src=src_rank, group=pp_group)
dist.broadcast(tensor, src=src_rank, group=pp_group)
assert tensor is not None, f'mg_key: {mg_key}'
if offset:
tensor = tensor + offset
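
Note (a sketch under assumptions, not code from this PR): the hunk above replaces the old per-dimension loop with a single metadata-then-payload broadcast over the selected (ep-)pp group. The helper below condenses that pattern for one tensor and an arbitrary process group; it assumes torch.distributed is already initialized, that CUDA is available, and that exactly one rank in the group passes a non-None tensor. The function name is hypothetical.

```python
import torch
import torch.distributed as dist

_DTYPE_CODES = {torch.float64: 0, torch.float32: 1, torch.float16: 2, torch.bfloat16: 3}
_CODES_DTYPE = {v: k for k, v in _DTYPE_CODES.items()}

def broadcast_from_owner(tensor, group):
    """Return the tensor on every rank of `group`; only one rank passes a non-None tensor."""
    group_rank = dist.get_rank(group)
    # The owner announces its group rank; non-owners contribute 0 to the sum.
    src = torch.tensor([0 if tensor is None else group_rank], dtype=torch.int64, device='cuda')
    dist.all_reduce(src, group=group)
    src_rank = dist.get_global_rank(group, src.item())

    # First broadcast a small int64 header: [ndim, *shape, ..., dtype_code].
    meta = torch.zeros(10, dtype=torch.int64, device='cuda')
    if tensor is None:
        dist.broadcast(meta, src=src_rank, group=group)
        ndim = meta[0].item()
        shape = meta[1:1 + ndim].tolist()
        tensor = torch.empty(shape, device='cuda', dtype=_CODES_DTYPE[meta[-1].item()])
    else:
        meta[0] = tensor.ndim
        meta[1:1 + tensor.ndim] = torch.tensor(tensor.shape, dtype=torch.int64, device='cuda')
        meta[-1] = _DTYPE_CODES[tensor.dtype]
        dist.broadcast(meta, src=src_rank, group=group)
    # Then broadcast the payload itself.
    dist.broadcast(tensor, src=src_rank, group=group)
    return tensor
```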
@@ -273,10 +289,10 @@ def _set_state_dict(self,
is_modules_to_save = isinstance(sub_module, ModulesToSaveWrapper)
if not to_mcore:
state = torch.tensor([is_lora, is_modules_to_save], dtype=torch.bool, device='cuda')
if self.pp_size > 1:
if is_expert and self.ep_pp_size > 1:
dist.all_reduce(state, group=self.ep_pp_group)
elif not is_expert and self.pp_size > 1:
dist.all_reduce(state, group=self.pp_group)
Comment on lines +292 to 295
Contributor

medium

This if/elif block for the all_reduce operation is repeated in _set_mlp_state as well. To improve code maintainability and reduce duplication, this logic can be refactored. By first selecting the appropriate parallel group and size based on whether it's an expert layer, and then performing the communication, the code becomes more concise and easier to read.

Suggested change
if is_expert and self.ep_pp_size > 1:
dist.all_reduce(state, group=self.ep_pp_group)
elif not is_expert and self.pp_size > 1:
dist.all_reduce(state, group=self.pp_group)
pp_group = self.ep_pp_group if is_expert else self.pp_group
pp_size = self.ep_pp_size if is_expert else self.pp_size
if pp_size > 1:
dist.all_reduce(state, group=pp_group)
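
Purely as an illustration of the reviewer's point (this helper is hypothetical and not part of the PR or the suggested change), the duplicated reduction in _set_state_dict and the two _set_mlp_state sites could also share a small method on the bridge class, assuming the attributes set up in __init__ above:

```python
# Hypothetical helper (not in the PR): pick the right group once and all_reduce
# a small flag tensor, shared by _set_state_dict and both _set_mlp_state sites.
def _sync_flags(self, flags: torch.Tensor, is_expert: bool) -> torch.Tensor:
    pp_group = self.ep_pp_group if is_expert else self.pp_group
    pp_size = self.ep_pp_size if is_expert else self.pp_size
    if pp_size > 1:
        dist.all_reduce(flags, group=pp_group)
    if is_expert and self.ep_size > 1:
        dist.all_reduce(flags, group=self.ep_group)
    return flags
```

Call sites would then reduce to something like state = self._sync_flags(state, is_expert) and is_lora = self._sync_flags(is_lora, is_expert).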

if is_expert and self.ep_size > 1:
dist.all_reduce(state, group=self.ep_group)
is_lora, is_modules_to_save = state
if is_lora and self._is_peft_format and param_key != 'layer_norm_weight':
if to_mcore:
@@ -627,10 +643,10 @@ def _set_mlp_state(self,
is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc1,
LoraParallelLinear) and self._is_peft_format
is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
if self.pp_size > 1:
if is_expert and self.ep_pp_size > 1:
dist.all_reduce(is_lora, group=self.ep_pp_group)
elif not is_expert and self.pp_size > 1:
dist.all_reduce(is_lora, group=self.pp_group)
Comment on lines +646 to 649
Contributor

medium

Similar to the logic in _set_state_dict, this if/elif block for all_reduce can be refactored to improve readability and avoid code duplication. This will make the codebase more consistent and easier to maintain.

Suggested change
if is_expert and self.ep_pp_size > 1:
dist.all_reduce(is_lora, group=self.ep_pp_group)
elif not is_expert and self.pp_size > 1:
dist.all_reduce(is_lora, group=self.pp_group)
pp_group = self.ep_pp_group if is_expert else self.pp_group
pp_size = self.ep_pp_size if is_expert else self.pp_size
if pp_size > 1:
dist.all_reduce(is_lora, group=pp_group)

if is_expert and self.ep_size > 1:
dist.all_reduce(is_lora, group=self.ep_group)
if is_lora:
assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
if mg_mlp is None:
@@ -779,10 +795,10 @@ def _set_mlp_state(self,
is_lora = False if mg_mlp is None else isinstance(mg_mlp.linear_fc2,
LoraParallelLinear) and self._is_peft_format
is_lora = torch.tensor([is_lora], dtype=torch.bool, device='cuda')
if self.pp_size > 1:
if is_expert and self.ep_pp_size > 1:
dist.all_reduce(is_lora, group=self.ep_pp_group)
elif not is_expert and self.pp_size > 1:
dist.all_reduce(is_lora, group=self.pp_group)
Comment on lines +798 to 801
Contributor

medium

This is another instance of the same communication logic. Applying the same refactoring here will further improve code consistency and maintainability across the file.

Suggested change
if is_expert and self.ep_pp_size > 1:
dist.all_reduce(is_lora, group=self.ep_pp_group)
elif not is_expert and self.pp_size > 1:
dist.all_reduce(is_lora, group=self.pp_group)
pp_group = self.ep_pp_group if is_expert else self.pp_group
pp_size = self.ep_pp_size if is_expert else self.pp_size
if pp_size > 1:
dist.all_reduce(is_lora, group=pp_group)

if is_expert and self.ep_size > 1:
dist.all_reduce(is_lora, group=self.ep_group)
if is_lora:
assert not hf_grouped, 'Currently, hf_grouped with LoRA is not supported.'
if mg_mlp is None: