Conversation
Summary of ChangesHello @AlpinDale, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces comprehensive support for the BailingMoe V2 model architecture. It refactors the existing BailingMoe implementation to incorporate advanced features such as configurable query-key normalization in attention, a more flexible Mixture-of-Experts (MoE) routing mechanism with various scoring functions and expert grouping options, and a hybrid MLP structure that allows for dense MLPs in early layers and MoE MLPs in later layers. These changes aim to enhance the model's performance and configurability, ensuring compatibility with the new V2 specification. Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Code Review
This pull request adds support for BailingMoe V2. The changes include adding new features like QK normalization, a more flexible MoE router, and conditional use of dense MLP layers. Several bug fixes and robustness improvements have also been made, such as correcting weight loading for tied embeddings and handling of None residuals in normalization layers. The code quality is good. I have one suggestion to refactor a part of the MoE forward pass for better readability and efficiency.
| if self.shared_experts: | ||
| shared_output = self.shared_experts(hidden_states) | ||
| # router_logits: (num_tokens, n_experts) | ||
| router_logits, _ = self.gate(hidden_states) | ||
| router_logits = self.gate(hidden_states.to(self.router_dtype)) | ||
| router_logits = router_logits.to(hidden_states.dtype) | ||
| final_hidden_states = self.experts(hidden_states=hidden_states, | ||
| router_logits=router_logits) | ||
|
|
||
| if self.num_shared_experts > 0: | ||
| final_hidden_states *= self.routed_scaling_factor | ||
|
|
||
| if self.shared_experts: | ||
| final_hidden_states = final_hidden_states + shared_output |
There was a problem hiding this comment.
The shared_output is computed at the beginning of the forward pass but used only at the end. This can be inefficient if self.shared_experts(hidden_states) is a costly operation. Additionally, the if self.shared_experts: check is performed twice.
For better readability and efficiency, it's better to compute shared_output just before it's used and combine the logic within a single conditional block.
| if self.shared_experts: | |
| shared_output = self.shared_experts(hidden_states) | |
| # router_logits: (num_tokens, n_experts) | |
| router_logits, _ = self.gate(hidden_states) | |
| router_logits = self.gate(hidden_states.to(self.router_dtype)) | |
| router_logits = router_logits.to(hidden_states.dtype) | |
| final_hidden_states = self.experts(hidden_states=hidden_states, | |
| router_logits=router_logits) | |
| if self.num_shared_experts > 0: | |
| final_hidden_states *= self.routed_scaling_factor | |
| if self.shared_experts: | |
| final_hidden_states = final_hidden_states + shared_output | |
| # router_logits: (num_tokens, n_experts) | |
| router_logits = self.gate(hidden_states.to(self.router_dtype)) | |
| router_logits = router_logits.to(hidden_states.dtype) | |
| final_hidden_states = self.experts(hidden_states=hidden_states, | |
| router_logits=router_logits) | |
| final_hidden_states *= self.routed_scaling_factor | |
| if self.shared_experts: | |
| shared_output = self.shared_experts(hidden_states) | |
| final_hidden_states = final_hidden_states + shared_output |
No description provided.