[megatron] support gemma4 megatron#9296

Merged
Jintao-Huang merged 11 commits into modelscope:main from Jintao-Huang:support_gemma4_megatron
May 19, 2026
@Jintao-Huang (Collaborator) commented May 9, 2026

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request updates the documentation to reflect support for Gemma 4 models and refactors embedding handling in Megatron utilities to support multiple modules during device conversion. In swift/model/models/gemma.py, a suggestion was made to use inputs_embeds.device instead of multimodal_mask.device when moving the pad_embedding tensor to ensure better robustness and consistency across operands in the subsequent torch.where call.

Comment thread swift/model/models/gemma.py Outdated

if self.config.get_text_config().hidden_size_per_layer_input:
    pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :]
    pad_embedding = pad_embedding.to(multimodal_mask.device)

medium

Using inputs_embeds.device as the target for the .to() call is generally more robust than multimodal_mask.device. Since inputs_embeds is the primary tensor representing the hidden states in this operation, it serves as the most reliable reference for the execution device, ensuring consistency across all operands in the subsequent torch.where call.

Suggested change
- pad_embedding = pad_embedding.to(multimodal_mask.device)
+ pad_embedding = pad_embedding.to(inputs_embeds.device)

@Jintao-Huang (Collaborator, Author) commented

/gemini review

@gemini-code-assist (Bot, Contributor) left a comment

Code Review

This pull request updates the documentation to reflect support for Gemma-4 models, adds a Megatron SFT training script, and refactors device conversion utilities. Review feedback identifies a potential runtime error in the training script due to inconsistent parallelism configurations where the Expert Parallel size exceeds the Data Parallel size. Additionally, improvements are suggested for the conversion utility to prevent a potential UnboundLocalError by initializing the data type variable and to enhance the readability of boolean logic through the use of parentheses.

Comment on lines +18 to +20
--tensor_model_parallel_size 2 \
--expert_model_parallel_size 4 \
--pipeline_model_parallel_size 2 \

high

The parallelism configuration appears inconsistent with the total number of GPUs (NPROC_PER_NODE=8). With tensor_model_parallel_size=2 and pipeline_model_parallel_size=2, the Data Parallel (DP) size is calculated as 8 / (2 * 2) = 2. In Megatron-Core, the Expert Parallel (EP) size (expert_model_parallel_size) must typically be less than or equal to the DP size (EP <= DP). Setting EP=4 while DP=2 will likely result in a runtime error during model initialization.
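
The arithmetic in this comment can be reproduced with a minimal sketch. The helper names `data_parallel_size` and `ep_fits_dp` are hypothetical and are not part of Megatron-Core or ms-swift. The `EP <= DP` check encodes only the rule as stated in the comment above; the constraint actually enforced may differ between Megatron-Core versions.

```python
def data_parallel_size(world_size: int, tp: int, pp: int) -> int:
    """Return the DP size left over after tensor and pipeline parallelism."""
    model_parallel = tp * pp
    if world_size % model_parallel != 0:
        raise ValueError(
            f"world_size={world_size} is not divisible by tp*pp={model_parallel}"
        )
    return world_size // model_parallel


def ep_fits_dp(world_size: int, tp: int, pp: int, ep: int) -> bool:
    """Apply the rule as stated in the review comment: EP must not exceed DP."""
    return ep <= data_parallel_size(world_size, tp, pp)


# Values from the script under review: 8 GPUs, TP=2, PP=2, EP=4.
dp = data_parallel_size(world_size=8, tp=2, pp=2)
print(dp)                         # 2
print(ep_fits_dp(8, 2, 2, ep=4))  # False
```

Under the same stated rule, lowering `expert_model_parallel_size` to 2 would pass this check for the 8-GPU layout.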

Comment on lines +65 to +71
for module in modules:
    try:
        origin_torch_dtype = next(module.parameters()).dtype
    except StopIteration:
        pass
    else:
        break

high

The variable origin_torch_dtype is not initialized before the loop. If the modules list is empty or if none of the modules contain parameters (causing StopIteration in every iteration), origin_torch_dtype will remain undefined. This will lead to an UnboundLocalError when it is accessed later in the _to_cpu_hook function (line 85). Initializing it to None provides a safe fallback, as module.to(dtype=None) is a no-op in PyTorch.

Suggested change
- for module in modules:
-     try:
-         origin_torch_dtype = next(module.parameters()).dtype
-     except StopIteration:
-         pass
-     else:
-         break
+ origin_torch_dtype = None
+ for module in modules:
+     try:
+         origin_torch_dtype = next(module.parameters()).dtype
+         break
+     except StopIteration:
+         pass
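
The failure mode described in this comment can be reproduced without PyTorch. The sketch below is illustrative only: `FakeModule`, `first_dtype_unsafe`, and `first_dtype_safe` are hypothetical names, and `FakeModule` models nothing but a `parameters()` method that returns an iterator, which is the one behavior the loop depends on.

```python
from types import SimpleNamespace


class FakeModule:
    """Hypothetical stand-in for a module; only parameters() is modeled."""

    def __init__(self, dtypes):
        # Each fake parameter just carries a dtype attribute.
        self._params = [SimpleNamespace(dtype=d) for d in dtypes]

    def parameters(self):
        # Return an iterator, which may be empty.
        return iter(self._params)


def first_dtype_unsafe(modules):
    # Mirrors the original loop: no initialization before the loop.
    for module in modules:
        try:
            origin_torch_dtype = next(module.parameters()).dtype
        except StopIteration:
            pass
        else:
            break
    return origin_torch_dtype  # unbound if no module had parameters


def first_dtype_safe(modules):
    # Mirrors the suggested change: fall back to None.
    origin_torch_dtype = None
    for module in modules:
        try:
            origin_torch_dtype = next(module.parameters()).dtype
            break
        except StopIteration:
            pass
    return origin_torch_dtype


empty = [FakeModule([]), FakeModule([])]
print(first_dtype_safe(empty))  # None
try:
    first_dtype_unsafe(empty)
except UnboundLocalError as exc:
    print("unsafe variant raised:", type(exc).__name__)
```

Both variants return the same dtype as soon as any module has at least one parameter; they differ only in the all-empty case.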


  def _to_cpu_hook(module, args, output):
-     if share_embedding and module in embeddings:
+     if share_embedding and module in embeddings or 'rotaryemb' in module.__class__.__name__.lower():

medium

The boolean expression relies on operator precedence (and before or), which can be error-prone and harder to read. Additionally, checking for 'rotaryemb' in the class name is a bit fragile. While string matching is often used in this context to avoid circular imports, adding parentheses would at least clarify the intended logic.

Suggested change
- if share_embedding and module in embeddings or 'rotaryemb' in module.__class__.__name__.lower():
+ if (share_embedding and module in embeddings) or 'rotaryemb' in module.__class__.__name__.lower():
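
A short truth-table check shows that the suggested parentheses do not change behavior. The three function names are hypothetical; `a`, `b`, and `c` stand in for `share_embedding`, `module in embeddings`, and the `'rotaryemb'` class-name match respectively.

```python
from itertools import product


def unparenthesized(a, b, c):
    return a and b or c


def parenthesized(a, b, c):
    return (a and b) or c


def alternative(a, b, c):
    # The other grouping a reader might mistakenly assume.
    return a and (b or c)


cases = list(product([False, True], repeat=3))

# Python binds `and` tighter than `or`, so these two always agree.
same = all(unparenthesized(*x) == parenthesized(*x) for x in cases)
print(same)  # True

# Inputs where the mistaken grouping would give a different answer.
differs = [x for x in cases if unparenthesized(*x) != alternative(*x)]
print(differs)  # [(False, False, True), (False, True, True)]
```

The two differing cases both have `a` false and `c` true: with `share_embedding` disabled, a rotary-embedding module still takes the branch under the actual grouping, which is the intent the parentheses make explicit.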

@Jintao-Huang Jintao-Huang merged commit 30c6799 into modelscope:main May 19, 2026
1 of 3 checks passed

2 participants