
Expert parallel support twinkle.TransformersModel #23

Merged
kevssim merged 4 commits into dev from expert_parallel on Feb 3, 2026

Conversation

@kevssim
Collaborator

@kevssim kevssim commented Feb 2, 2026

No description provided.

@kevssim kevssim marked this pull request as draft February 2, 2026 10:09
@gemini-code-assist
Contributor

Summary of Changes

Hello @kevssim, I'm Gemini Code Assist[1]! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the distributed training capabilities of twinkle.TransformersModel by integrating robust Expert Parallel (EP) support. The changes enable efficient training of Mixture-of-Experts (MoE) models by properly handling mixed precision for experts, introducing a specialized gradient clipping utility for distributed environments, and dynamically configuring the training strategy and optimizers to work seamlessly with EP. A new example script for Qwen3-MoE demonstrates these new features.

Highlights

  • Expert Parallel (EP) Integration: Introduced comprehensive support for Expert Parallelism within the twinkle.TransformersModel framework, allowing for efficient distributed training of Mixture-of-Experts (MoE) models, particularly when combined with FSDP.
  • Mixed Precision Handling for Experts: Enhanced the expert_parallel.py module to correctly handle mixed precision operations for both shared and individual experts, ensuring proper data type conversions during forward passes.
  • Advanced Gradient Clipping: Added a new utility for gradient clipping, normalize_and_clip_grad_norm, which specifically addresses scenarios involving mixed DTensor and regular Tensor gradients in distributed training.
  • Dynamic Strategy Selection and Initialization: Implemented logic within TransformersModel to dynamically select NativeFSDPStrategy when Expert Parallel is enabled, and to ensure proper distributed environment initialization and placement of experts on local devices.
  • Optimizer Configuration for EP: Automated the setting of foreach=False for Adam and AdamW optimizers when Expert Parallel is active, mitigating potential DTensor mixed-type errors during EP runs.
  • Qwen3-MoE Example: Provided a new example script demonstrating the usage of Expert Parallel with FSDP for the Qwen3-30B-A3B-Instruct-2507 model, showcasing the new capabilities.
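The optimizer tweak in the "Optimizer Configuration for EP" highlight can be sketched as follows. This is an illustrative sketch, not the PR's actual code: the helper name `disable_foreach_for_adam` and its string-or-class handling are assumptions; the real change sets `foreach=False` inside TransformersModel's optimizer setup.

```python
def disable_foreach_for_adam(optimizer_cls, expert_parallel, kwargs):
    """Hypothetical helper: when Expert Parallel is active, force
    per-tensor updates (foreach=False) for Adam/AdamW, since the
    foreach kernels can raise DTensor mixed-type errors under EP."""
    if not expert_parallel:
        return kwargs
    # Accept either a class name string or an optimizer class object.
    name = optimizer_cls if isinstance(optimizer_cls, str) \
        else getattr(optimizer_cls, "__name__", "")
    if name in ("Adam", "AdamW"):
        kwargs = dict(kwargs, foreach=False)
    return kwargs
```

For example, `disable_foreach_for_adam("AdamW", True, {"lr": 1e-4})` adds `foreach=False`, while non-Adam optimizers and non-EP runs leave the kwargs untouched.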


Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

  • Code Review (/gemini review): Performs a code review for the current pull request in its current state.
  • Pull Request Summary (/gemini summary): Provides a summary of the current pull request in its current state.
  • Comment (@gemini-code-assist): Responds in comments when explicitly tagged, both in pull request comments and review comments.
  • Help (/gemini help): Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@kevssim
Collaborator Author

kevssim commented Feb 2, 2026

/gemini review

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for Expert Parallelism (EP) in twinkle.TransformersModel, primarily for Mixture-of-Experts (MoE) models, and integrates it with FSDP. The changes are extensive, including a new example script for Qwen3-MoE, modifications to the core EP logic to handle mixed precision, a new gradient clipping utility for mixed DTensor/Tensor environments, and substantial updates to TransformersModel and NativeFSDPStrategy to manage the EP lifecycle. My review has identified a critical bug in the distributed initialization logic that could cause a hang in multi-process scenarios. Additionally, I've provided suggestions to enhance performance by minimizing CPU-GPU synchronization, refactor duplicated code for better maintainability, and make minor improvements to code efficiency and style. Overall, this is a significant feature addition, and addressing the identified issues will improve its robustness and performance.

Comment on lines +261 to +267
if rank < 0:
    rank = 0
if world_size <= 1:
    raise RuntimeError(
        "EP+FSDP requires distributed launch with WORLD_SIZE>1 "
        "and initialized rank env vars."
    )
Contributor


critical

There's a critical issue in the distributed initialization logic. The current order of checks can lead to incorrect behavior. If world_size > 1 but RANK is not set, rank is defaulted to 0, causing all processes to identify as rank 0. You should first check for a valid distributed environment (world_size > 1) and then ensure rank is valid.

Suggested change

if world_size <= 1:
    raise RuntimeError(
        "EP+FSDP requires distributed launch with WORLD_SIZE>1 "
        "and initialized rank env vars."
    )
if rank < 0:
    raise RuntimeError(
        "EP+FSDP requires the RANK environment variable to be set for distributed launch."
    )

Comment on lines +59 to +60
except Exception:
    pass
Contributor


medium

Using a bare except Exception: pass is generally discouraged as it can hide unexpected errors. It's better to catch a more specific exception, such as ImportError if you're concerned about the module not being available, or at least log the exception for debugging purposes.

Suggested change

except ImportError:
    pass

Comment on lines +341 to +346
compute_dtype = _module_compute_dtype(expert, input_dtype)
if compute_dtype != input_dtype:
    expert_in = expert_in.to(compute_dtype)
out = expert(expert_in)
if out.dtype != input_dtype:
    out = out.to(input_dtype)
Contributor


medium

This block of code for handling dtype casting appears to be duplicated in _maybe_run_shared_expert (lines 313-318). To improve maintainability and reduce code duplication, consider extracting this logic into a helper function. For example:

def _run_with_casting(module: nn.Module, inputs: torch.Tensor) -> torch.Tensor:
    input_dtype = inputs.dtype
    compute_dtype = _module_compute_dtype(module, input_dtype)
    if compute_dtype != input_dtype:
        inputs = inputs.to(compute_dtype)
    out = module(inputs)
    if out.dtype != input_dtype:
        out = out.to(input_dtype)
    return out

Then you could simplify this part to return _run_with_casting(expert, expert_in) and similarly refactor _maybe_run_shared_expert.

Comment on lines +67 to +72
local_norm = 0.0
for grad in grads:
    local_grad = _local_grad(grad)
    if local_grad.numel() == 0:
        continue
    local_norm = max(local_norm, local_grad.detach().abs().max().item())
Contributor


medium

Calling .item() inside a loop over gradients can introduce multiple CPU-GPU synchronization points, which may impact performance, especially with a large number of parameters. To optimize, you could accumulate the maximum values as tensors on the device and perform a single reduction at the end of the loop before calling .item().

For example, you could create a list of tensor max values, then use torch.stack(list).max() to get the final local max before reducing across processes. This avoids CPU-GPU syncs within the loop. A similar optimization can be applied to the L2 norm calculation on lines 78-83.
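As a hedged sketch of that suggestion (assuming plain torch.Tensor gradients; the PR's `_local_grad` DTensor unwrapping and the cross-process reduction are omitted), per-tensor maxima can be accumulated on-device and synchronized once:

```python
import torch

def local_max_grad(grads):
    # Keep per-tensor maxima as device tensors; the single .item() at
    # the end is the only CPU-GPU synchronization point, instead of
    # one sync per gradient inside the loop.
    maxima = [g.detach().abs().max() for g in grads if g.numel() > 0]
    if not maxima:
        return 0.0
    return torch.stack(maxima).max().item()
```

The same pattern (accumulate device tensors, reduce once) applies to the L2-norm accumulation mentioned for lines 78-83.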

Comment on lines +531 to +536
decay_param_names = [
    n for n, p in params.items() if (n in decay_parameters and p.requires_grad)
]
no_decay_param_names = [
    n for n, p in params.items() if (n not in decay_parameters and p.requires_grad)
]
Contributor


medium

To improve efficiency slightly, you can determine the decay_param_names and no_decay_param_names in a single pass over the params dictionary instead of two separate list comprehensions.

        decay_param_names = []
        no_decay_param_names = []
        for n, p in params.items():
            if p.requires_grad:
                if n in decay_parameters:
                    decay_param_names.append(n)
                else:
                    no_decay_param_names.append(n)

Comment on lines +672 to +675
if optimizer_cls in ('AdamW', 'Adam'):
    kwargs['foreach'] = False
elif isinstance(optimizer_cls, type) and issubclass(optimizer_cls, (AdamW, Adam)):
    kwargs['foreach'] = False
Contributor


medium

The logic to check if optimizer_cls is Adam or AdamW is a bit repetitive. You could consolidate this check to improve readability.

            is_adam_family = optimizer_cls in ('AdamW', 'Adam') or \
                (isinstance(optimizer_cls, type) and issubclass(optimizer_cls, (AdamW, Adam)))
            if is_adam_family:
                kwargs['foreach'] = False

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces significant functionality by adding support for Expert Parallelism (EP) for Mixture-of-Experts (MoE) models, with a focus on compatibility with FSDP. The changes are comprehensive, including a new NativeFSDPStrategy, a new example script for Qwen3-MoE, and modifications to TransformersModel to handle the distributed setup. The implementation also includes a custom gradient clipping utility for mixed DTensor/Tensor scenarios and proper mixed-precision handling within expert execution. The code is well-structured and addresses the complexities of combining these parallelism techniques. My review includes a couple of suggestions to improve code robustness and efficiency.

Comment on lines +59 to +60
except Exception:
    pass
Contributor


medium

Using a broad except Exception: can hide unexpected errors. It's better to catch more specific exceptions that you anticipate, such as ImportError or AttributeError, to make the code more robust and easier to debug.

Suggested change

except (ImportError, AttributeError):
    pass

Comment on lines +531 to +552

decay_param_names = [
    n for n, p in params.items() if (n in decay_parameters and p.requires_grad)
]
no_decay_param_names = [
    n for n, p in params.items() if (n not in decay_parameters and p.requires_grad)
]
optimizer_grouped_parameters = [
    {
        "params": [
            params[n] for n in decay_param_names
        ],
        "param_names": decay_param_names,
        "weight_decay": weight_decay, 'lr': lr
    },
    {
        "params": [
            params[n] for n in no_decay_param_names
        ],
        "param_names": no_decay_param_names,
        "weight_decay": 0.0, 'lr': lr
    },
]
Contributor


medium

The current implementation iterates over the parameters multiple times to group them for the optimizer. This can be made more efficient and readable by using a single loop to build the parameter groups.

        decay_params = []
        decay_param_names = []
        no_decay_params = []
        no_decay_param_names = []
        for n, p in params.items():
            if not p.requires_grad:
                continue
            if n in decay_parameters:
                decay_params.append(p)
                decay_param_names.append(n)
            else:
                no_decay_params.append(p)
                no_decay_param_names.append(n)

        optimizer_grouped_parameters = [
            {
                "params": decay_params,
                "param_names": decay_param_names,
                "weight_decay": weight_decay, 'lr': lr
            },
            {
                "params": no_decay_params,
                "param_names": no_decay_param_names,
                "weight_decay": 0.0, 'lr': lr
            },
        ]

@kevssim kevssim marked this pull request as ready for review February 3, 2026 03:27
@kevssim kevssim merged commit 5cda1af into dev Feb 3, 2026