Summary of Changes

Hello @Artiprocher, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed. This pull request significantly expands the model capabilities by integrating the Flux.2-klein architecture. It provides all necessary components, from model definitions and configurations to a new pipeline, enabling advanced image generation and editing features within the system. The changes ensure that the new models are properly structured, configured, and tested for immediate use.
Code Review
This pull request introduces support for Flux.2-klein models, including new pipeline configurations, model definitions for DiT and VAE, and associated utility functions. The changes involve adding new dataclasses for Flux2 pipeline configuration and state dictionaries, implementing the Flux2DiT and Flux2VAE models, and creating a new Flux2Pipeline class to orchestrate the image generation process. Several helper functions for LoRA conversion, prompt encoding, and latent preparation are also included. The new functionality appears to be well-structured, but there are several areas where correctness and maintainability could be improved, particularly concerning batched inference handling, hardcoded parameters, and potential precision issues.
```python
if edit_latents is not None:
    noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]
```
The slicing `noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]` is redundant. At this point, `image_ids.shape[1]` is the concatenated length (original + edit), so subtracting `edit_image_ids.shape[1]` does yield the original image sequence length. However, the `model_fn_flux2` function already performs this slicing internally; doing it again here double-slices `noise_pred` and produces incorrect results. This is a critical correctness issue.
```python
# Return only the original image sequence length if edit images were used
# This slicing is handled by model_fn_flux2, so it's not needed here.
# if edit_latents is not None:
#     noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]
return noise_pred
```

```python
if edit_latents is not None and edit_image_ids is not None:
    latents = torch.concat([latents, edit_latents], dim=1)
    image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
```
Inside `predict_noise`, the concatenation of `latents` and `image_ids` with `edit_latents` and `edit_image_ids` must be guarded by `if edit_latents is not None and edit_image_ids is not None:`; otherwise `torch.concat` raises an error when either input is `None`.
Suggested change:

```diff
+ # Handle edit images by concatenating latents and image IDs
  if edit_latents is not None and edit_image_ids is not None:
      latents = torch.concat([latents, edit_latents], dim=1)
      image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
  embedded_guidance_tensor = torch.tensor([embedded_guidance], device=latents.device)
  noise_pred = self.dit(
      hidden_states=latents,
      timestep=timestep / 1000,
      guidance=embedded_guidance_tensor,
      encoder_hidden_states=prompt_embeds,
      txt_ids=text_ids,
      img_ids=image_ids,
  )
```
```python
image_seq_len = latents.shape[1]
if edit_latents is not None:
    image_seq_len = latents.shape[1]
    latents = torch.concat([latents, edit_latents], dim=1)
    image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
```
The image_seq_len is initialized with latents.shape[1]. If edit_latents is not None, latents is concatenated with edit_latents, but image_seq_len is not updated to reflect the new combined length. This means the subsequent slicing model_output[:, :image_seq_len] will incorrectly truncate the output, potentially discarding the contribution from edit_latents or causing misalignment. This is a critical correctness issue.
Suggested change:

```diff
- image_seq_len = latents.shape[1]
- if edit_latents is not None:
-     image_seq_len = latents.shape[1]
-     latents = torch.concat([latents, edit_latents], dim=1)
-     image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
+ image_seq_len_original = latents.shape[1]
+ if edit_latents is not None:
+     latents = torch.concat([latents, edit_latents], dim=1)
+     image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
  embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
  model_output = dit(
      hidden_states=latents,
      timestep=timestep / 1000,
      guidance=embedded_guidance,
      encoder_hidden_states=prompt_embeds,
      txt_ids=text_ids,
      img_ids=image_ids,
      use_gradient_checkpointing=use_gradient_checkpointing,
      use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
  )
+ model_output = model_output[:, :image_seq_len_original]
```
```python
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
# text prompts of different lengths. Is this a use case we want to support?
if img_ids.ndim == 3:
    img_ids = img_ids[0]
if txt_ids.ndim == 3:
    txt_ids = txt_ids[0]
```
The note correctly identifies a limitation regarding batched inference with varying image resolutions or text prompt lengths. However, the current implementation if img_ids.ndim == 3: img_ids = img_ids[0] and if txt_ids.ndim == 3: txt_ids = txt_ids[0] will incorrectly handle batched inputs where B > 1 by only taking the first element. This is a critical correctness issue if batched inference is expected to work with these inputs. It should either raise an error for unsupported batch sizes or be adapted to handle them correctly.
```python
    return b


def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
    r"""Encode a batch of images using a tiled encoder.
```
```python
t = torch.arange(1)  # [0] - time dimension
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1)  # [0] - layer dimension
```
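The snippet builds one index range per rotary-embedding axis. A sketch of how such ranges are typically combined into one position id per spatial token; the function name and the stacking order `(t, h, w, l)` are assumptions based on the variable names above, not the PR's actual code:

```python
import torch

def build_position_ids(height: int, width: int) -> torch.Tensor:
    """Cartesian product of (time, height, width, layer) indices,
    flattened to one 4-component id per spatial token."""
    t = torch.arange(1)   # [0] - time dimension
    h = torch.arange(height)
    w = torch.arange(width)
    l = torch.arange(1)   # [0] - layer dimension
    # Shape: (height * width, 4); the last axis varies fastest.
    return torch.cartesian_prod(t, h, w, l)
```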
```python
self.heads = out_dim // dim_head if out_dim is not None else heads
```
Similar to Flux2Attention, the calculation of self.heads here (out_dim // dim_head) might lead to inconsistencies if out_dim is provided and out_dim // dim_head does not equal the heads argument passed to the constructor. Ensure consistency between heads and out_dim or clarify the intended behavior.
```python
self.heads = out_dim // dim_head if out_dim is not None else heads
```
The calculation of self.heads here (out_dim // dim_head) might lead to inconsistencies if out_dim is provided and out_dim // dim_head does not equal the heads argument passed to the constructor. It's generally better to ensure self.heads is directly derived from the heads argument or that out_dim is consistent with heads * dim_head.
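A minimal sketch of the consistency check the review asks for. The class name and constructor signature are illustrative assumptions; only the `heads` / `dim_head` / `out_dim` names come from the snippet above:

```python
class AttentionHeadsExample:
    """Illustrative only: validates that an explicit out_dim agrees with
    heads * dim_head before deriving self.heads, instead of silently
    overriding the heads argument."""

    def __init__(self, heads=8, dim_head=64, out_dim=None):
        if out_dim is not None and out_dim != heads * dim_head:
            raise ValueError(
                f"out_dim={out_dim} is inconsistent with "
                f"heads * dim_head = {heads * dim_head}"
            )
        self.heads = heads if out_dim is None else out_dim // dim_head
```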
```python
def edit_image_auto_resize(self, edit_image):
    """Auto resize edit image to optimal dimensions"""
    calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])
```
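The body of `calculate_dimensions` is not shown in this diff. A typical implementation of such a helper picks a width and height whose product approximates the target pixel area while preserving the aspect ratio, snapped to a multiple the VAE/DiT can consume; the rounding multiple of 32 below is an assumption, not the PR's actual value:

```python
import math

def calculate_dimensions(target_area: int, aspect_ratio: float, multiple: int = 32):
    """Return (width, height) with width / height ≈ aspect_ratio and
    width * height ≈ target_area, both rounded to `multiple`."""
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    width = max(multiple, round(width / multiple) * multiple)
    height = max(multiple, round(height / multiple) * multiple)
    return int(width), int(height)
```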
```python
    return image


def process_edit_image_ids(self, image_latents, scale=10):
    """Process image IDs for edit images"""
```
```python
):
    image_seq_len = latents.shape[1]
    if edit_latents is not None:
        image_seq_len = latents.shape[1]
```
```python
)

if config.vae_path is None:
    config.vae_path = fetch_model(config.model_path, path="vae/diffusion_pytorch_model.safetensors")
```
Please point this at a concrete repo_id on ModelScope. As written, the code implicitly assumes model_path is a directory that contains a vae subdirectory, and this implicit assumption makes the interface hard to call. See QwenImagePipeline for reference.
```python
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

if config.encoder_path is None:
    config.encoder_path = fetch_model(config.model_path, path="text_encoder")
```
```python
)

if config.vae_path is None:
    config.vae_path = fetch_model(config.model_path, path="vae/diffusion_pytorch_model.safetensors")
```
As in QwenImagePipeline, hard-code the repo_id on ModelScope. The current code implicitly requires config.model_path to be a directory containing vae and text_encoder subdirectories.
```python
vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

if config.encoder_path is None:
    if config.model_size == "4B":
```
Please change the type of model_size to an enum; otherwise callers have to read the pipeline code to find out which values are valid.
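A minimal sketch of the suggested enum. The class name is hypothetical, and "4B" is the only size value visible in this diff; any further members would mirror the sizes the pipeline actually ships:

```python
from enum import Enum

class Flux2ModelSize(str, Enum):
    # Subclassing str keeps existing comparisons like
    # `config.model_size == "4B"` working unchanged.
    SIZE_4B = "4B"
```

An invalid value such as `Flux2ModelSize("4b")` then fails at construction time with a ValueError listing the valid members, instead of silently falling through the pipeline's if/elif chain.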