Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,7 @@ def forward(
rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
attn_mask: Optional[torch.Tensor] = None,
attn_kwargs: Optional[Dict[str, Any]] = None,
modulate_index: Optional[List[int]] = None,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

While adding modulate_index fixes the API signature mismatch, it is currently unused within the forward method. This parameter is crucial for features like zero_cond_t, and ignoring it can lead to incorrect behavior.

To properly fix this, modulate_index should be passed to the _modulate calls within this method. This will also require updating _modulate to accept and handle this new parameter, likely by adapting the logic from the base QwenImageTransformerBlock's _modulate method.

) -> Tuple[torch.Tensor, torch.Tensor]:
if self.use_nunchaku_awq:
img_mod_params = self.img_mod(temb) # [B, 6*dim]
Expand Down