Feature/qwen image edit #152
Conversation
Summary of Changes
Hello @Glaceon-Hyy, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces image editing capabilities to the Qwen model within the diffsynth_engine. It integrates a new image processor and updates the Qwen model and pipeline to handle image inputs for editing, alongside general improvements to attention mechanisms and model loading.
Highlights
- Qwen Image Editing Feature: This PR introduces core functionality for image editing by adding an 'edit' input to the QwenImageDiT model and a new 'encode_prompt_with_image' method in the QwenImagePipeline, enabling image-to-image transformations based on text prompts.
- New Image Processing Pipeline: New Qwen2VLImageProcessor and Qwen2VLProcessor classes are added to handle comprehensive image preprocessing (resizing, normalization, patching) and to facilitate multimodal input (text + image) for Qwen-VL models.
- Attention Mechanism Enhancements: The attention module is enhanced to support Scaled Dot Product Attention (SDPA) with causal masking, improving efficiency and flexibility in model computations (see the sketch after this list).
- Model Loading and Device Management Refinements: Rotary embedding initialization and general model loading procedures have been refactored to improve device handling and simplify the transfer of parameters, contributing to more robust model management.
- LoRA Compatibility Improvements: Compatibility for LoRA (Low-Rank Adaptation) state dict conversion is improved by supporting additional naming conventions for LoRA weights, making the system more flexible for various LoRA models.
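For readers unfamiliar with SDPA, the following is a minimal, self-contained sketch of scaled dot product attention with causal masking via torch.nn.functional.scaled_dot_product_attention. The wrapper function and tensor shapes are illustrative assumptions, not the code added in this PR:

```python
import torch
import torch.nn.functional as F

def sdpa_attention(q, k, v, attn_mask=None, is_causal=False):
    # q, k, v: (batch, num_heads, seq_len, head_dim).
    # With is_causal=True each position attends only to itself and earlier
    # positions, so no explicit causal mask tensor has to be materialized.
    return F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, is_causal=is_causal)

# Example: batch of 2, 8 heads, 16 tokens, 64-dim heads.
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
out = sdpa_attention(q, k, v, is_causal=True)  # -> (2, 8, 16, 64)
```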
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.
| Feature | Command | Description |
| --- | --- | --- |
| Code Review | /gemini review | Performs a code review for the current pull request in its current state. |
| Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state. |
| Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
| Help | /gemini help | Displays a list of available commands. |
Customization
To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.
You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
¹ Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double-check it and use code with caution.
Code Review
This pull request introduces a new image editing feature using the Qwen model. The changes are extensive, including new model configurations, a new image processor, and updates to the pipeline to support image-to-image tasks. The implementation looks solid, with several good refactorings like improving device handling in embedding layers and using model.to() for moving models. My review focuses on a few areas that could be improved for correctness, maintainability, and clarity. I've identified a potential regression in the model's generation capability, some magic numbers that should be constants, a confusing code block that could be simplified, and a commented-out test assertion that needs attention.
 if position_ids is None:
     assert attention_mask is None or attention_mask.ndim == 2, "attention mask must be 2D"
-    # calculate RoPE index once per generation in the pre-fill stage only
-    if (cache_position is not None and cache_position[0] == 0) or self.rope_deltas is None:
-        position_ids, rope_deltas = self.get_rope_index(
-            input_ids,
-            image_grid_thw,
-            video_grid_thw,
-            second_per_grid_ts,
-            attention_mask,
-        )
-        self.rope_deltas = rope_deltas
-    # then use the prev pre-calculated rope-deltas to get the correct position ids
-    else:
-        batch_size, seq_length, _ = inputs_embeds.shape
-        delta = (
-            (cache_position[0] + self.rope_deltas).to(inputs_embeds.device) if cache_position is not None else 0
-        )
-        position_ids = torch.arange(seq_length, device=inputs_embeds.device)
-        position_ids = position_ids.view(1, -1).expand(batch_size, -1)
-        if cache_position is not None:  # otherwise `deltas` is an int `0`
-            delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
-        position_ids = position_ids.add(delta)
-        position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
+    position_ids, rope_deltas = self.get_rope_index(
+        input_ids,
+        image_grid_thw,
+        video_grid_thw,
+        second_per_grid_ts,
+        attention_mask,
+    )
+    self.rope_deltas = rope_deltas
The logic for handling subsequent generation steps (incremental decoding) has been removed. The previous implementation had an else block to calculate position_ids when cache_position was not 0. This change limits the model to only perform pre-fill (a single forward pass), which might break any autoregressive text generation capabilities. If this model is intended to be used for multi-step generation, this is a significant regression. Was this intentional?
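To illustrate the path the removed else branch covered, here is a simplified, hypothetical sketch of how an incremental decoding step would derive position ids from the rope_deltas cached at pre-fill. The function name and tensor shapes are stand-ins, not the project's API:

```python
import torch

def positions_for_step(seq_length: int, cache_position: torch.Tensor, rope_deltas: torch.Tensor) -> torch.Tensor:
    # At a decoding step the model only sees the newest token(s); their absolute
    # positions are the current cache offset plus the per-sample delta that was
    # computed once during pre-fill (here: rope_deltas).
    batch_size = rope_deltas.shape[0]
    delta = cache_position[0] + rope_deltas                      # (batch,)
    position_ids = torch.arange(seq_length).view(1, -1).expand(batch_size, -1)
    position_ids = position_ids + delta.view(-1, 1)              # shift by per-sample offset
    return position_ids.unsqueeze(0).expand(3, -1, -1)           # 3 rotary axes (t, h, w)

# Pre-fill computes rope_deltas once; every later step reuses it:
rope_deltas = torch.tensor([5, 2])       # per-sample offsets from pre-fill
cache_position = torch.tensor([10])      # index of the first cached slot for this step
print(positions_for_step(1, cache_position, rope_deltas).shape)  # torch.Size([3, 2, 1])
```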
pos_index = torch.arange(4096)
neg_index = torch.arange(4096).flip(0) * -1 - 1
""" | ||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video | ||
Args: txt_length: an integer representing the length of text | ||
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args: | ||
txt_length: [bs] a list of 1 integers representing the length of the text | ||
""" |
The docstring for the forward method is malformed and the description for txt_seq_lens (previously txt_length) is confusing. Please update it for clarity.
"""
Args:
video_fhw (List[Tuple[int, int, int]]): A list of (frame, height, width) tuples for each video/image in the batch.
txt_seq_lens (int): The maximum length of the text sequences in the batch.
device (str): The device to place tensors on.
"""
if input_image is not None:
    width, height = input_image.size
    width, height = self.calculate_dimensions(1024 * 1024, width / height)
    input_image = input_image.resize((width, height), Image.LANCZOS)

self.validate_image_size(height, width, minimum=64, multiple_of=16)

noise = self.generate_noise((1, 16, height // 8, width // 8), seed=seed, device="cpu", dtype=self.dtype).to(
    device=self.device
)
# dynamic shift
image_seq_len = math.ceil(height // 16) * math.ceil(width // 16)
mu = calculate_shift(image_seq_len, max_shift=0.9, max_seq_len=8192)
if input_image:
    image_latents = self.prepare_image_latents(input_image)
else:
    image_latents = None
This block has a couple of areas for improvement:
- At line 480, 1024 * 1024 is a magic number. It should be defined as a named constant for readability and maintainability.
- The logic for initializing image_latents (lines 491-494) is confusing. The if input_image: check is redundant within the if input_image is not None: block, making the else branch unreachable.
Consider refactoring this section for clarity. For example:
TARGET_PIXEL_AREA = 1024 * 1024

image_latents = None
if input_image is not None:
    width, height = input_image.size
    width, height = self.calculate_dimensions(TARGET_PIXEL_AREA, width / height)
    input_image = input_image.resize((width, height), Image.LANCZOS)
    image_latents = self.prepare_image_latents(input_image)

self.validate_image_size(height, width, minimum=64, multiple_of=16)
self.assertTensorEqual(logits, expected)
expected = vlm(input_ids=input_ids, attention_mask=attention_mask).logits.cpu()
text_logits = self.encoder(input_ids=input_ids, attention_mask=attention_mask)["logits"].cpu()
# self.assertTensorEqual(text_logits, expected)
This assertion is commented out. This often indicates a test that is failing or a known discrepancy that is being ignored. Please either fix the underlying issue and re-enable the assertion, or add a comment explaining why this check is disabled. Leaving commented-out assertions can hide bugs and make the test suite less reliable.
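If the mismatch turns out to be pure floating-point noise, one option is to replace the exact check with a tolerance-based one so the test stays active. This is only a sketch, and assert_tensor_close is a hypothetical helper, not something already in the test suite:

```python
import torch

def assert_tensor_close(actual: torch.Tensor, expected: torch.Tensor, atol: float = 1e-3, rtol: float = 1e-3):
    # Fails with the maximum absolute difference in the message, which is far more
    # informative than a silently commented-out assertion.
    if not torch.allclose(actual, expected, atol=atol, rtol=rtol):
        max_diff = (actual - expected).abs().max().item()
        raise AssertionError(f"tensors differ: max abs diff = {max_diff:.6f} (atol={atol}, rtol={rtol})")

# In the test, instead of commenting the check out:
# assert_tensor_close(text_logits, expected)
```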
         return freqs

-    def forward(self, video_fhw, txt_length, device):
+    def forward(self, video_fhw, txt_seq_lens, device):
We have already moved the max-length extraction outside, so this parameter can stay as it is.
         self.axes_dim = axes_dim
-        pos_index = torch.arange(10000)
-        neg_index = torch.arange(10000).flip(0) * -1 - 1
+        pos_index = torch.arange(4096)
Please change this back to 10000 here.