set module device to skip weight init #207
Conversation
Summary of Changes
Hello @qzzz95, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request refactors the model initialization process for both the image and text encoders. The primary goal is to optimize memory usage and potentially improve loading times by allocating model parameters directly on the target computational device (e.g., GPU) from the outset, bypassing the intermediate allocation on the CPU and the subsequent transfer.
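As a rough illustration of the approach (a minimal sketch in plain PyTorch; the repository's actual encoder classes are more involved), threading a `device` argument down to each parameter-holding submodule lets callers pass `device='meta'`, so construction allocates no real storage and weight initialization becomes effectively a no-op:

```python
import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    # Hypothetical module: `device` is forwarded to every parameter-holding layer.
    def __init__(self, dim, device=None):
        super().__init__()
        self.proj = nn.Linear(dim, dim, device=device)
        self.norm = nn.LayerNorm(dim, device=device)

# device="meta" creates parameter shells only: no memory is allocated and no real
# initialization work is done; checkpoint weights can be attached later.
skeleton = TinyEncoder(dim=1024, device="meta")
print(skeleton.proj.weight.is_meta)  # True

# device="cuda:0" (when available) would allocate directly on the GPU,
# skipping the usual CPU allocation followed by a host-to-device copy.
```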
Code Review
This pull request updates the model initialization to allow specifying a device, which is a great improvement for workflows that use device='meta' to avoid unnecessary weight initialization. The changes are generally in the right direction.
However, I've identified a couple of issues:
- The custom T5LayerNorm module is not being initialized on the specified device. This will likely cause errors when device='meta' is used. I've left comments pointing out the specific locations.
- There's an inconsistency in the use of non_blocking when moving models to a device. The change to non_blocking=False is good for safety, and I'd recommend applying it consistently across the codebase.
Addressing these points will make the implementation more robust and consistent.
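To make the first point concrete, here is a toy sketch of the failure pattern (the module and attribute names are placeholders, not the repository's code): one submodule that ignores the requested device leaves concrete CPU parameters inside an otherwise meta-device model.

```python
import torch
import torch.nn as nn

class LeakyMetaModule(nn.Module):
    # Toy example of the bug pattern: one submodule does not receive `device`.
    def __init__(self, dim, device=None):
        super().__init__()
        self.attn = nn.Linear(dim, dim, device=device)  # honors the requested device
        self.norm = nn.LayerNorm(dim)                   # forgets device -> real CPU tensors

m = LeakyMetaModule(64, device="meta")
stray = [name for name, p in m.named_parameters() if not p.is_meta]
print(stray)  # ['norm.weight', 'norm.bias'] — these defeat the meta-device workflow
```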
```diff
 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, device, shared_pos=True, dropout=0.0):
```
While you've correctly passed the device parameter to T5Attention, T5FeedForward, and T5RelativeEmbedding, the T5LayerNorm modules (self.norm1 and self.norm2, initialized on lines 126 and 128) are still initialized without the device parameter. This will cause issues when using device='meta', as their weights will be created on a concrete device instead of as meta tensors.
To fix this, you should probably update T5LayerNorm to accept a device parameter and use it during nn.Parameter creation, then pass the device during instantiation here.
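A minimal sketch of that suggestion (assuming a T5-style RMS norm; the repository's actual T5LayerNorm may differ in details):

```python
import torch
import torch.nn as nn

class T5LayerNorm(nn.Module):
    # Sketch: accept `device` and create the weight directly on it, so that
    # device='meta' yields a meta tensor instead of a concrete CPU tensor.
    def __init__(self, dim, eps=1e-6, device=None):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim, device=device))

    def forward(self, x):
        # T5-style RMS norm: no mean subtraction, no bias.
        x = x * torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + self.eps)
        return self.weight * x.type_as(self.weight)
```

At the call sites this would become `self.norm1 = T5LayerNorm(dim, device=device)`, and likewise for `self.norm2`.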
```diff
         self.blocks = nn.ModuleList(
             [
-                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
+                T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, device, shared_pos, dropout)
```
```diff
     model = model.requires_grad_(False)
     model.load_state_dict(state_dict, assign=True)
-    model.to(device=device, dtype=dtype, non_blocking=True)
+    model.to(device=device, dtype=dtype, non_blocking=False)
```
Changing non_blocking=True to False is a good defensive move to ensure the model is fully moved to the target device before proceeding. For consistency, you might want to consider applying this change to WanImageEncoder.from_state_dict and the base PreTrainedModel.from_state_dict as well, as they currently use non_blocking=True.
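A rough sketch of the shared loading pattern with the synchronous move applied consistently (the function and argument names here are illustrative, not the repository's actual API):

```python
import torch

def finalize_model(model, state_dict, device, dtype):
    # Freeze, attach the checkpoint tensors in place of the (possibly meta) parameters,
    # then move synchronously so the model is guaranteed to be resident on `device`
    # before any forward pass or further setup runs.
    model.requires_grad_(False)
    model.load_state_dict(state_dict, assign=True)
    return model.to(device=device, dtype=dtype, non_blocking=False)
```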
```diff
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
+    def __init__(self, dim, dim_attn, num_heads, device, dropout=0.0):
```
For the model classes, please give the device parameter a default value of "cuda:0" and move it to the end of the parameter list.
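A sketch of what that suggestion would look like for one of the classes (using T5Attention as the example; the same change would apply to the other model classes, and the projection names below are placeholders):

```python
import torch.nn as nn

class T5Attention(nn.Module):
    # `device` gets a default and sits after the existing keyword arguments,
    # so existing positional call sites keep working unchanged.
    def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
        super().__init__()
        self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
        # ... remaining projections would be created the same way ...
```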
Force-pushed from 055b3c5 to e6b3020 (Compare)
No description provided.