Expert parallel support twinkle.TransformersModel #23
Conversation
Summary of Changes

Hello @kevssim, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the distributed training capabilities of twinkle.TransformersModel.
Code Review
This pull request introduces support for Expert Parallelism (EP) in twinkle.TransformersModel, primarily for Mixture-of-Experts (MoE) models, and integrates it with FSDP. The changes are extensive, including a new example script for Qwen3-MoE, modifications to the core EP logic to handle mixed precision, a new gradient clipping utility for mixed DTensor/Tensor environments, and substantial updates to TransformersModel and NativeFSDPStrategy to manage the EP lifecycle. My review has identified a critical bug in the distributed initialization logic that could cause a hang in multi-process scenarios. Additionally, I've provided suggestions to enhance performance by minimizing CPU-GPU synchronization, refactor duplicated code for better maintainability, and make minor improvements to code efficiency and style. Overall, this is a significant feature addition, and addressing the identified issues will improve its robustness and performance.
```python
if rank < 0:
    rank = 0
if world_size <= 1:
    raise RuntimeError(
        "EP+FSDP requires distributed launch with WORLD_SIZE>1 "
        "and initialized rank env vars."
    )
```
There's a critical issue in the distributed initialization logic. The current order of checks can lead to incorrect behavior. If world_size > 1 but RANK is not set, rank is defaulted to 0, causing all processes to identify as rank 0. You should first check for a valid distributed environment (world_size > 1) and then ensure rank is valid.
Suggested change:

```diff
-if rank < 0:
-    rank = 0
-if world_size <= 1:
-    raise RuntimeError(
-        "EP+FSDP requires distributed launch with WORLD_SIZE>1 "
-        "and initialized rank env vars."
-    )
+if world_size <= 1:
+    raise RuntimeError(
+        "EP+FSDP requires distributed launch with WORLD_SIZE>1 "
+        "and initialized rank env vars."
+    )
+if rank < 0:
+    raise RuntimeError(
+        "EP+FSDP requires the RANK environment variable to be set for distributed launch."
+    )
```
```python
except Exception:
    pass
```
Using a bare except Exception: pass is generally discouraged as it can hide unexpected errors. It's better to catch a more specific exception, such as ImportError if you're concerned about the module not being available, or at least log the exception for debugging purposes.
Suggested change:

```diff
-except Exception:
-    pass
+except ImportError:
+    pass
```
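If failures other than a missing module are plausible at this point, another option is to log the unexpected exception instead of silently swallowing it. A sketch using the stdlib `logging` module; `some_optional_module` is a hypothetical optional dependency standing in for whatever the original `try` block imports:

```python
import logging

logger = logging.getLogger(__name__)

try:
    import some_optional_module  # hypothetical optional dependency
except ImportError:
    # Expected when the dependency simply is not installed.
    some_optional_module = None
except Exception:
    # Anything else is unexpected: record it instead of hiding it.
    logger.exception("Unexpected error while importing optional dependency")
    some_optional_module = None
```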
```python
compute_dtype = _module_compute_dtype(expert, input_dtype)
if compute_dtype != input_dtype:
    expert_in = expert_in.to(compute_dtype)
out = expert(expert_in)
if out.dtype != input_dtype:
    out = out.to(input_dtype)
```
This block of code for handling dtype casting appears to be duplicated in _maybe_run_shared_expert (lines 313-318). To improve maintainability and reduce code duplication, consider extracting this logic into a helper function. For example:

```python
def _run_with_casting(module: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    input_dtype = inputs.dtype
    compute_dtype = _module_compute_dtype(module, input_dtype)
    if compute_dtype != input_dtype:
        inputs = inputs.to(compute_dtype)
    out = module(inputs)
    if out.dtype != input_dtype:
        out = out.to(input_dtype)
    return out
```

Then you could simplify this part to `return _run_with_casting(expert, expert_in)` and similarly refactor `_maybe_run_shared_expert`.
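For illustration, here is a fully self-contained variant of the same pattern; `_infer_compute_dtype` below is a hypothetical stand-in for the PR's `_module_compute_dtype`, inferring the compute dtype from the module's own parameters:

```python
import torch
from torch import nn

def _infer_compute_dtype(module: nn.Module, default: torch.dtype) -> torch.dtype:
    # Hypothetical stand-in: use the dtype of the module's first parameter,
    # falling back to the input dtype for parameter-free modules.
    for p in module.parameters():
        return p.dtype
    return default

def run_with_casting(module: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    input_dtype = inputs.dtype
    compute_dtype = _infer_compute_dtype(module, input_dtype)
    if compute_dtype != input_dtype:
        inputs = inputs.to(compute_dtype)  # cast activations to the module's dtype
    out = module(inputs)
    if out.dtype != input_dtype:
        out = out.to(input_dtype)  # restore the caller's dtype
    return out
```

For example, feeding bf16 activations through an fp32 module returns bf16 output while computing in fp32 internally.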
```python
local_norm = 0.0
for grad in grads:
    local_grad = _local_grad(grad)
    if local_grad.numel() == 0:
        continue
    local_norm = max(local_norm, local_grad.detach().abs().max().item())
```
Calling .item() inside a loop over gradients can introduce multiple CPU-GPU synchronization points, which may impact performance, especially with a large number of parameters. To optimize, you could accumulate the maximum values as tensors on the device and perform a single reduction at the end of the loop before calling .item().
For example, you could create a list of tensor max values, then use torch.stack(list).max() to get the final local max before reducing across processes. This avoids CPU-GPU syncs within the loop. A similar optimization can be applied to the L2 norm calculation on lines 78-83.
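A minimal sketch of the on-device accumulation, using plain tensors only; the DTensor unwrapping via `_local_grad` is omitted for brevity:

```python
import torch

def local_max_grad(grads):
    # Keep per-gradient maxima on-device; a single .item() at the end
    # replaces one CPU-GPU synchronization per gradient.
    maxes = [g.detach().abs().max() for g in grads if g.numel() > 0]
    if not maxes:
        return 0.0
    return torch.stack(maxes).max().item()
```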
```python
decay_param_names = [
    n for n, p in params.items() if (n in decay_parameters and p.requires_grad)
]
no_decay_param_names = [
    n for n, p in params.items() if (n not in decay_parameters and p.requires_grad)
]
```
To improve efficiency slightly, you can determine the decay_param_names and no_decay_param_names in a single pass over the params dictionary instead of two separate list comprehensions.
```python
decay_param_names = []
no_decay_param_names = []
for n, p in params.items():
    if p.requires_grad:
        if n in decay_parameters:
            decay_param_names.append(n)
        else:
            no_decay_param_names.append(n)
```

```python
if optimizer_cls in ('AdamW', 'Adam'):
    kwargs['foreach'] = False
elif isinstance(optimizer_cls, type) and issubclass(optimizer_cls, (AdamW, Adam)):
    kwargs['foreach'] = False
```
The logic to check if optimizer_cls is Adam or AdamW is a bit repetitive. You could consolidate this check to improve readability.
```python
is_adam_family = optimizer_cls in ('AdamW', 'Adam') or \
    (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, (AdamW, Adam)))
if is_adam_family:
    kwargs['foreach'] = False
```
Code Review
This pull request introduces significant functionality by adding support for Expert Parallelism (EP) for Mixture-of-Experts (MoE) models, with a focus on compatibility with FSDP. The changes are comprehensive, including a new NativeFSDPStrategy, a new example script for Qwen3-MoE, and modifications to TransformersModel to handle the distributed setup. The implementation also includes a custom gradient clipping utility for mixed DTensor/Tensor scenarios and proper mixed-precision handling within expert execution. The code is well-structured and addresses the complexities of combining these parallelism techniques. My review includes a couple of suggestions to improve code robustness and efficiency.
```python
except Exception:
    pass
```
```diff
 decay_param_names = [
     n for n, p in params.items() if (n in decay_parameters and p.requires_grad)
 ]
 no_decay_param_names = [
     n for n, p in params.items() if (n not in decay_parameters and p.requires_grad)
 ]
 optimizer_grouped_parameters = [
     {
         "params": [
-            p for n, p in params.items() if (n in decay_parameters and p.requires_grad)
+            params[n] for n in decay_param_names
         ],
         "param_names": decay_param_names,
         "weight_decay": weight_decay, 'lr': lr
     },
     {
         "params": [
-            p for n, p in params.items() if (n not in decay_parameters and p.requires_grad)
+            params[n] for n in no_decay_param_names
         ],
         "param_names": no_decay_param_names,
         "weight_decay": 0.0, 'lr': lr
     },
 ]
```
The current implementation iterates over the parameters multiple times to group them for the optimizer. This can be made more efficient and readable by using a single loop to build the parameter groups.
```python
decay_params = []
decay_param_names = []
no_decay_params = []
no_decay_param_names = []
for n, p in params.items():
    if not p.requires_grad:
        continue
    if n in decay_parameters:
        decay_params.append(p)
        decay_param_names.append(n)
    else:
        no_decay_params.append(p)
        no_decay_param_names.append(n)

optimizer_grouped_parameters = [
    {
        "params": decay_params,
        "param_names": decay_param_names,
        "weight_decay": weight_decay, 'lr': lr
    },
    {
        "params": no_decay_params,
        "param_names": no_decay_param_names,
        "weight_decay": 0.0, 'lr': lr
    },
]
```
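As a quick sanity check, the single-pass grouping can be exercised with a toy module; the names and hyperparameters here are illustrative, not the PR's, and in the real code `decay_parameters` would come from the trainer's weight-decay rules:

```python
import torch
from torch import nn

params = dict(nn.Linear(4, 4).named_parameters())  # 'weight', 'bias'
decay_parameters = {"weight"}  # illustrative decay rule
weight_decay, lr = 0.01, 1e-3

decay_params, decay_param_names = [], []
no_decay_params, no_decay_param_names = [], []
for n, p in params.items():
    if not p.requires_grad:
        continue
    if n in decay_parameters:
        decay_params.append(p)
        decay_param_names.append(n)
    else:
        no_decay_params.append(p)
        no_decay_param_names.append(n)

optimizer_grouped_parameters = [
    {"params": decay_params, "param_names": decay_param_names,
     "weight_decay": weight_decay, "lr": lr},
    {"params": no_decay_params, "param_names": no_decay_param_names,
     "weight_decay": 0.0, "lr": lr},
]
# Extra keys such as "param_names" are carried along in optimizer param groups.
opt = torch.optim.AdamW(optimizer_grouped_parameters)
```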
No description provided.