Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions src/mcore_bridge/model/gpt_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,10 +116,9 @@ def __init__(
init_method=config.init_method,
bias=False,
skip_bias_add=False,
parallel_mode=None,
parallel_mode='duplicated',
skip_weight_param_allocation=False,
)
Comment thread
Jintao-Huang marked this conversation as resolved.
self.output_layer.weight.average_gradients_across_tp_domain = True
elif self.config.task_type == 'embedding' and self.post_process:
self.output_layer = None

Expand Down Expand Up @@ -302,7 +301,7 @@ def forward(
if self.config.position_embedding_type == 'mrope' and position_ids.ndim == 2: # qwen3_asr
position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
inference_context = deprecate_inference_params(inference_context, inference_params)

# There is a difference in whether rotary_pos_emb can be fused between the decoder and MTP.
decoder_input, rotary_pos_emb, rotary_pos_cos, rotary_pos_sin, sequence_len_offset = (
self._preprocess(
input_ids=input_ids,
Expand Down Expand Up @@ -491,10 +490,6 @@ def _postprocess(
# (so that the output layer, which expects S×B×H, receives only the final token)
hidden_states = inference_context.last_token_logits(hidden_states.squeeze(1).unsqueeze(0)).unsqueeze(1)

if self.config.task_type in {'seq_cls', 'embedding'
} and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
hidden_states = gather_from_sequence_parallel_region(hidden_states, tensor_parallel_output_grad=False)

if self.config.task_type == 'embedding':
logits = F.normalize(hidden_states, p=2, dim=-1)
else:
Expand All @@ -507,6 +502,10 @@ def _postprocess(
positive_token_id = self.tokenizer.convert_tokens_to_ids(positive_token)
negative_token_id = self.tokenizer.convert_tokens_to_ids(negative_token)
logits = (logits[..., positive_token_id] - logits[..., negative_token_id])[..., None]
if self.config.task_type in {'seq_cls', 'embedding'
} and self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1:
Comment thread
Jintao-Huang marked this conversation as resolved.
logits = gather_from_sequence_parallel_region(logits, tensor_parallel_output_grad=False)
Comment thread
Jintao-Huang marked this conversation as resolved.
Comment thread
Jintao-Huang marked this conversation as resolved.

# Restore sequence parallel execution to the output layer if necessary.
if sequence_parallel_override:
assert (in_inference_mode and inference_context.is_dynamic_batching()
Expand Down
7 changes: 7 additions & 0 deletions src/mcore_bridge/model/gpts/minimax_m2.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,13 @@ def __init__(
config=self.config,
eps=self.config.layernorm_epsilon,
)
# These norms operate on the gathered (TP-replicated) tensor, so they must
# NOT participate in the cross-TP SP/QK-layernorm grad AllReduce,
# otherwise the gradients will be amplified by tp_size.
for norm in (self.q_norm, self.k_norm):
for p in norm.parameters():
p.sequence_parallel = False
p.average_gradients_across_tp_domain = True

def get_query_key_value_tensors(self, *_args, **kwargs):
enable_tp = mpu.get_tensor_model_parallel_world_size() > 1
Expand Down
33 changes: 18 additions & 15 deletions src/mcore_bridge/tuners/lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,17 +63,12 @@ def _build_local_te_linear(input_size: int, output_size: int, bias: bool, **kwar
bias=bias,
)
local_kwargs = dict(kwargs)
parallel_mode = None
if is_torch_npu_available():
# Local TE linear layers in MindSpeed 0.15.x use duplicated semantics,
# and this path does not accept tp_group.
local_kwargs.pop('tp_group', None)
parallel_mode = 'duplicated'
local_kwargs.pop('tp_group', None)
return TELinear(
input_size=input_size,
output_size=output_size,
bias=bias,
parallel_mode=parallel_mode,
parallel_mode='duplicated',
skip_weight_param_allocation=False,
**local_kwargs,
)
Expand Down Expand Up @@ -128,8 +123,12 @@ def __init__(
self.sequence_parallel = getattr(base_layer, 'sequence_parallel', False)
if self.is_expert:
self.tp_size = get_expert_tensor_parallel_world_size()
# TODO: For TEGroupedLinear under ETP, initialization must use different random seeds
# across EP ranks and identical random seeds across ETP ranks.
# Additionally, a parameter-averaging all_reduce across the ETP group is required.
# Note that TEGroupedLinear here excludes TEColumnParallelGroupedLinear and TERowParallelGroupedLinear.
if self.tp_size > 1:
raise ValueError('Currently, LoRA does not support ETP.') # TODO: init/all-reduce
raise ValueError('Currently, LoRA does not support ETP.')
else:
self.tp_size = get_tensor_model_parallel_world_size()
self.update_layer(
Expand Down Expand Up @@ -238,12 +237,15 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs):
)
lora_b.parallel_mode = self.base_layer.parallel_mode # fix moe_shared_expert_overlap
Comment thread
Jintao-Huang marked this conversation as resolved.
for lora in [lora_a, lora_b]:
if getattr(lora, 'parallel_mode', None) is None and hasattr(lora, 'weight'): # TODO: experts
if isinstance(self.base_layer, TopKRouter):
sequence_parallel = self.base_layer.weight.sequence_parallel
else:
sequence_parallel = self.sequence_parallel
lora.weight.sequence_parallel = sequence_parallel
# When parallel_mode is set to None by moe_shared_expert_overlap,
# disable UB comm overlap; the corresponding collectives are driven
# externally by the framework.
if isinstance(lora, (TERowParallelLinear, TEColumnParallelLinear)) and lora.parallel_mode is None:
lora.ub_overlap_rs_fprop = False
lora.ub_overlap_ag_dgrad = False
lora.ub_overlap_ag_fprop = False
lora.ub_overlap_rs_dgrad = False

self.lora_A[adapter_name] = lora_a
self.lora_B[adapter_name] = lora_b
if hasattr(self, 'lora_bias'):
Expand All @@ -261,9 +263,10 @@ def update_layer(self, adapter_name, r, *, lora_alpha, **kwargs):
def _get_rng_context(self, lora):
if not get_cuda_rng_tracker().is_initialized():
return nullcontext()
parallel_mode = getattr(lora, 'parallel_mode', None)
if self.is_expert:
rng_context = get_cuda_rng_tracker().fork(get_expert_parallel_rng_tracker_name())
elif getattr(lora, 'parallel_mode', None) is None:
elif parallel_mode in {'duplicated', None}:
rng_context = nullcontext()
else:
rng_context = get_cuda_rng_tracker().fork()
Expand Down
10 changes: 10 additions & 0 deletions src/mcore_bridge/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,16 @@ def new_deepcopy(x, *args, **kwargs):

for key, value in saved_values.items():
setattr(res, key, value)

# copy params attr
for name, param in x.named_parameters():
res_param = deep_getattr(res, name)
if res_param is None:
continue
for k, v in param.__dict__.items():
if not hasattr(res_param, k):
setattr(res_param, k, v)
Comment on lines +97 to +98
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The check if not hasattr(res_param, k) prevents copying attributes that might have been initialized with default values in the constructor of the new module. For example, if a parameter is initialized with sequence_parallel=False by default, it will not be updated even if the source parameter has it set to True. Since __dict__ on an nn.Parameter typically only contains custom metadata added by the user or the framework (such as Megatron's parallelism flags), it is safer to overwrite these attributes to ensure the state is correctly mirrored in the deep-copied module.

Suggested change
if not hasattr(res_param, k):
setattr(res_param, k, v)
for k, v in param.__dict__.items():
setattr(res_param, k, v)


return res

copy.deepcopy = new_deepcopy
Expand Down
Loading