[megatron] support gemma4 megatron#9296
Conversation
There was a problem hiding this comment.
Code Review
This pull request updates the documentation to reflect support for Gemma 4 models and refactors embedding handling in Megatron utilities to support multiple modules during device conversion. In swift/model/models/gemma.py, a suggestion was made to use inputs_embeds.device instead of multimodal_mask.device when moving the pad_embedding tensor to ensure better robustness and consistency across operands in the subsequent torch.where call.
|
|
||
| if self.config.get_text_config().hidden_size_per_layer_input: | ||
| pad_embedding = self.language_model.embed_tokens.weight[self.config.text_config.pad_token_id, :] | ||
| pad_embedding = pad_embedding.to(multimodal_mask.device) |
There was a problem hiding this comment.
Using inputs_embeds.device as the target for the .to() call is generally more robust than multimodal_mask.device. Since inputs_embeds is the primary tensor representing the hidden states in this operation, it serves as the most reliable reference for the execution device, ensuring consistency across all operands in the subsequent torch.where call.
| pad_embedding = pad_embedding.to(multimodal_mask.device) | |
| pad_embedding = pad_embedding.to(inputs_embeds.device) |
|
/gemini review |
There was a problem hiding this comment.
Code Review
This pull request updates the documentation to reflect support for Gemma-4 models, adds a Megatron SFT training script, and refactors device conversion utilities. Review feedback identifies a potential runtime error in the training script due to inconsistent parallelism configurations where the Expert Parallel size exceeds the Data Parallel size. Additionally, improvements are suggested for the conversion utility to prevent a potential UnboundLocalError by initializing the data type variable and to enhance the readability of boolean logic through the use of parentheses.
| --tensor_model_parallel_size 2 \ | ||
| --expert_model_parallel_size 4 \ | ||
| --pipeline_model_parallel_size 2 \ |
There was a problem hiding this comment.
The parallelism configuration appears inconsistent with the total number of GPUs (NPROC_PER_NODE=8). With tensor_model_parallel_size=2 and pipeline_model_parallel_size=2, the Data Parallel (DP) size is calculated as 8 / (2 * 2) = 2. In Megatron-Core, the Expert Parallel (EP) size (expert_model_parallel_size) must typically be less than or equal to the DP size (EP <= DP). Setting EP=4 while DP=2 will likely result in a runtime error during model initialization.
| for module in modules: | ||
| try: | ||
| origin_torch_dtype = next(module.parameters()).dtype | ||
| except StopIteration: | ||
| pass | ||
| else: | ||
| break |
There was a problem hiding this comment.
The variable origin_torch_dtype is not initialized before the loop. If the modules list is empty or if none of the modules contain parameters (causing StopIteration in every iteration), origin_torch_dtype will remain undefined. This will lead to an UnboundLocalError when it is accessed later in the _to_cpu_hook function (line 85). Initializing it to None provides a safe fallback, as module.to(dtype=None) is a no-op in PyTorch.
| for module in modules: | |
| try: | |
| origin_torch_dtype = next(module.parameters()).dtype | |
| except StopIteration: | |
| pass | |
| else: | |
| break | |
| origin_torch_dtype = None | |
| for module in modules: | |
| try: | |
| origin_torch_dtype = next(module.parameters()).dtype | |
| break | |
| except StopIteration: | |
| pass |
|
|
||
| def _to_cpu_hook(module, args, output): | ||
| if share_embedding and module in embeddings: | ||
| if share_embedding and module in embeddings or 'rotaryemb' in module.__class__.__name__.lower(): |
There was a problem hiding this comment.
The boolean expression relies on operator precedence (and before or), which can be error-prone and harder to read. Additionally, checking for 'rotaryemb' in the class name is a bit fragile. While string matching is often used in this context to avoid circular imports, adding parentheses would at least clarify the intended logic.
| if share_embedding and module in embeddings or 'rotaryemb' in module.__class__.__name__.lower(): | |
| if (share_embedding and module in embeddings) or 'rotaryemb' in module.__class__.__name__.lower(): |
modelscope/mcore-bridge#56