
Conversation

@qzzz95 (Contributor) commented Nov 17, 2025

No description provided.

@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello @qzzz95, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant refactoring of the model initialization process for both the image and text encoders. The primary goal is to optimize memory usage and potentially improve loading times by ensuring that model parameters are allocated directly on the target computational device (e.g., GPU) from the outset, bypassing the intermediate allocation on the CPU and the subsequent transfer.

Highlights

  • Direct Device Initialization for Image Encoder: The WanImageEncoder now initializes its internal CLIP model directly on the specified device (e.g., GPU) by passing the device argument to the clip_xlm_roberta_vit_h_14 function, rather than defaulting to CPU.
  • Device Parameter Propagation in Text Encoder: Multiple components within the WanTextEncoder, including T5Attention, T5FeedForward, T5SelfAttention, T5RelativeEmbedding, and the main T5Encoder's token and positional embeddings, now accept and use a device parameter during initialization, so their internal nn.Linear and nn.Embedding layers are created directly on the target device (see the sketch after this list).
  • Non-blocking Transfer Adjustment: The non_blocking parameter in the model.to() call within T5Encoder.from_state_dict has been changed from True to False, so the call now blocks until the state transfer to the target device has completed.
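
The pattern behind these changes, as a minimal sketch (the layer names and sizes are illustrative, not the repository's actual modules): passing device through the constructor allocates parameters on the target device from the start, instead of on CPU followed by a later .to() move.

```python
import torch
import torch.nn as nn

class FeedForwardBefore(nn.Module):
    def __init__(self, dim, dim_ffn):
        super().__init__()
        # Allocated on CPU; a later model.to(device) must copy every tensor.
        self.fc1 = nn.Linear(dim, dim_ffn)
        self.fc2 = nn.Linear(dim_ffn, dim)

class FeedForwardAfter(nn.Module):
    def __init__(self, dim, dim_ffn, device):
        super().__init__()
        # Allocated directly on the target device; no intermediate CPU copy.
        self.fc1 = nn.Linear(dim, dim_ffn, device=device)
        self.fc2 = nn.Linear(dim_ffn, dim, device=device)

# Usage: device can be a concrete device, or "meta" to skip allocation entirely.
ff = FeedForwardAfter(512, 2048, device="meta")
```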

@gemini-code-assist bot left a review:

Code Review

This pull request updates the model initialization to allow specifying a device, which is a great improvement for workflows that use device='meta' to avoid unnecessary weight initialization. The changes are generally in the right direction.

However, I've identified a couple of issues:

  • The custom T5LayerNorm module is not being initialized on the specified device. This will likely cause errors when device='meta' is used. I've left comments pointing out the specific locations.
  • There's an inconsistency in the use of non_blocking when moving models to a device. The change to non_blocking=False is good for safety, and I'd recommend applying it consistently across the codebase.

Addressing these points will make the implementation more robust and consistent. To make the device='meta' workflow concrete, a sketch follows.
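
A minimal sketch of the meta-device loading pattern, using a stand-in nn.Linear rather than the repository's actual T5Encoder (the function name is illustrative, and assign=True needs a reasonably recent PyTorch):

```python
import torch
import torch.nn as nn

def load_without_cpu_init(state_dict, device, dtype=torch.bfloat16):
    # Build the module skeleton on the meta device: no memory is allocated
    # and no random weight initialization runs.
    with torch.device("meta"):
        model = nn.Linear(4096, 4096)
    # assign=True swaps the meta parameters for the checkpoint tensors
    # instead of copying into them (meta tensors have no storage to copy into).
    model.load_state_dict(state_dict, assign=True)
    # Synchronous move: with non_blocking=False the transfer is guaranteed
    # to be complete when this call returns.
    return model.to(device=device, dtype=dtype, non_blocking=False)
```

Note that every submodule must be constructible on the meta device for this to work, which is exactly why the T5LayerNorm gaps flagged below matter.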


```diff
 class T5SelfAttention(nn.Module):
-    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos=True, dropout=0.0):
+    def __init__(self, dim, dim_attn, dim_ffn, num_heads, num_buckets, device, shared_pos=True, dropout=0.0):
```
gemini-code-assist bot — severity: high

While you've correctly passed the device parameter to T5Attention, T5FeedForward, and T5RelativeEmbedding, the T5LayerNorm modules (self.norm1 and self.norm2, initialized on lines 126 and 128) are still initialized without the device parameter. This will cause issues when using device='meta', as their weights will be created on a concrete device instead of as meta tensors.

To fix this, you should probably update T5LayerNorm to accept a device parameter and use it during nn.Parameter creation, then pass the device during instantiation here; a sketch follows.
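
A minimal sketch of that fix, assuming T5LayerNorm is the usual RMS-style norm with a single scale parameter (the internals in this repository may differ):

```python
import torch
import torch.nn as nn

class T5LayerNorm(nn.Module):
    def __init__(self, dim, device, eps=1e-6):
        super().__init__()
        self.eps = eps
        # Create the scale parameter on the target device, so device='meta'
        # yields a meta tensor instead of materializing real storage.
        self.weight = nn.Parameter(torch.ones(dim, device=device))

    def forward(self, x):
        # T5-style RMS normalization: no mean subtraction, no bias.
        variance = x.pow(2).mean(dim=-1, keepdim=True)
        return self.weight * x * torch.rsqrt(variance + self.eps)
```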

```diff
 self.blocks = nn.ModuleList(
     [
-        T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, shared_pos, dropout)
+        T5SelfAttention(dim, dim_attn, dim_ffn, num_heads, num_buckets, device, shared_pos, dropout)
```
gemini-code-assist bot — severity: high

Similar to my other comment, self.norm (a T5LayerNorm instance initialized on line 273) is created without the device parameter. This should also be updated to ensure consistent device placement, especially for device='meta', to avoid materializing tensors when it's not intended.

```diff
 model = model.requires_grad_(False)
 model.load_state_dict(state_dict, assign=True)
-model.to(device=device, dtype=dtype, non_blocking=True)
+model.to(device=device, dtype=dtype, non_blocking=False)
```
gemini-code-assist bot — severity: medium

Changing non_blocking=True to False is a good defensive move to ensure the model is fully moved to the target device before proceeding. For consistency, you might want to consider applying this change to WanImageEncoder.from_state_dict and the base PreTrainedModel.from_state_dict as well, as they currently use non_blocking=True. A short illustration of the difference follows.
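
For reference, an illustration of the difference (shapes are arbitrary): non_blocking=True returns before the copy finishes, which is safe for subsequent GPU work queued on the same stream but requires an explicit synchronize before the host reuses the source buffer; non_blocking=False waits inside the call.

```python
import torch

src = torch.randn(1024, 1024).pin_memory()

# Asynchronous host-to-device copy: returns immediately. Kernels queued on
# the same CUDA stream will see the completed copy, but the host must not
# reuse or free `src` until the transfer is done.
dst = src.to("cuda", non_blocking=True)
torch.cuda.synchronize()  # explicit barrier before touching `src` again

# Synchronous copy: the call itself waits, so no extra care is needed.
# This is the behavior the PR switches to for model loading.
dst = src.to("cuda", non_blocking=False)
```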


```diff
 class T5Attention(nn.Module):
-    def __init__(self, dim, dim_attn, num_heads, dropout=0.0):
+    def __init__(self, dim, dim_attn, num_heads, device, dropout=0.0):
```
A contributor commented:

For the model classes, give device a default value of "cuda:0" and move it to the end of the parameter list (a sketch of that signature follows).
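
That suggestion, rendered as a sketch (the layer layout inside T5Attention is assumed; the point is the trailing keyword argument with a default):

```python
import torch.nn as nn

class T5Attention(nn.Module):
    # device as a trailing keyword argument with a "cuda:0" default, so
    # existing positional call sites keep working unchanged.
    def __init__(self, dim, dim_attn, num_heads, dropout=0.0, device="cuda:0"):
        super().__init__()
        self.num_heads = num_heads
        self.q = nn.Linear(dim, dim_attn, bias=False, device=device)
        self.k = nn.Linear(dim, dim_attn, bias=False, device=device)
        self.v = nn.Linear(dim, dim_attn, bias=False, device=device)
        self.o = nn.Linear(dim_attn, dim, bias=False, device=device)
        self.dropout = nn.Dropout(dropout)
```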

@qzzz95 force-pushed the dev/qz/set_module_device branch from 055b3c5 to e6b3020 on November 17, 2025 at 10:22
@akaitsuki-ii merged commit 93c34f3 into main on November 18, 2025
@akaitsuki-ii deleted the dev/qz/set_module_device branch on November 18, 2025 at 07:22