support flux.2-klein #227

Merged

qzzz95 merged 5 commits into main from flux-klein on Jan 23, 2026
Conversation

@Artiprocher
Contributor

No description provided.

@gemini-code-assist
Contributor

Summary of Changes

Hello @Artiprocher, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly expands the model capabilities by integrating the Flux.2-klein architecture. It provides all necessary components, from model definitions and configurations to a new pipeline, enabling advanced image generation and editing features within the system. The changes ensure that the new models are properly structured, configured, and tested for immediate use.

Highlights

  • New Model Integration: Introduced comprehensive support for the Flux.2-klein model, including its Diffusion Transformer (DiT) and VAE components.
  • Pipeline and Configuration: Added a dedicated Flux2Pipeline for image generation and editing, along with new configuration classes (Flux2PipelineConfig) and state dicts (Flux2StateDicts).
  • Model Architecture Definitions: Implemented the Flux2DiT and Flux2VAE model architectures, including their respective sub-modules like attention mechanisms, feedforward networks, and positional embeddings.
  • Testing and Validation: Included new unit tests for the Flux2Pipeline across different model sizes (4B and 9B, including base versions) to ensure functionality for text-to-image generation and image editing.



@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request introduces support for Flux.2-klein models, including new pipeline configurations, model definitions for DiT and VAE, and associated utility functions. The changes involve adding new dataclasses for Flux2 pipeline configuration and state dictionaries, implementing the Flux2DiT and Flux2VAE models, and creating a new Flux2Pipeline class to orchestrate the image generation process. Several helper functions for LoRA conversion, prompt encoding, and latent preparation are also included. The new functionality appears to be well-structured, but there are several areas where correctness and maintainability could be improved, particularly concerning batched inference handling, hardcoded parameters, and potential precision issues.

Comment on lines +451 to +452
if edit_latents is not None:
noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]

critical

The slice noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]] is redundant. At this point image_ids.shape[1] is the concatenated length (original + edit), so subtracting edit_image_ids.shape[1] does recover the original image sequence length; however, model_fn_flux2 already performs exactly this trimming internally. Repeating it here at best duplicates work, and at worst produces incorrect results if the model function's output shape ever changes. This is a critical correctness issue.

        # Return only the original image sequence length if edit images were used
        # This slicing is handled by model_fn_flux2, so it's not needed here.
        # if edit_latents is not None:
        #     noise_pred = noise_pred[:, :image_ids.shape[1] - edit_image_ids.shape[1]]
            
        return noise_pred
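To make the redundancy concrete, here is a toy pure-Python sketch (list lengths stand in for tensor sequence dimensions; all names are illustrative, not the pipeline's):

```python
# Toy sketch of why the second slice in predict_noise is redundant:
# model_fn_flux2 already trims the model output back to the original
# image sequence length before returning it.
orig_len, edit_len = 4, 2
concat_ids_len = orig_len + edit_len        # image_ids length after concatenation

# What model_fn_flux2 returns: already trimmed to the original tokens.
noise_pred = list(range(orig_len))

# The slice in predict_noise recomputes the same length from the ids...
recomputed = concat_ids_len - edit_len
assert recomputed == orig_len               # ...so it is a no-op here
assert noise_pred[:recomputed] == noise_pred
```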

Comment on lines +435 to +437
if edit_latents is not None and edit_image_ids is not None:
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)

critical

The torch.concat calls will fail if edit_latents or edit_image_ids is None, so both concatenations must stay inside the combined guard if edit_latents is not None and edit_image_ids is not None:. A short comment documenting why edit latents are appended along the sequence dimension also makes this branch of predict_noise easier to follow.

Suggested change

    # Handle edit images by concatenating latents and image IDs
    if edit_latents is not None and edit_image_ids is not None:
        latents = torch.concat([latents, edit_latents], dim=1)
        image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
    embedded_guidance_tensor = torch.tensor([embedded_guidance], device=latents.device)
    noise_pred = self.dit(
        hidden_states=latents,
        timestep=timestep / 1000,
        guidance=embedded_guidance_tensor,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=image_ids,
    )

Comment on lines +104 to +108
image_seq_len = latents.shape[1]
if edit_latents is not None:
image_seq_len = latents.shape[1]
latents = torch.concat([latents, edit_latents], dim=1)
image_ids = torch.concat([image_ids, edit_image_ids], dim=1)

critical

image_seq_len is assigned latents.shape[1] twice: once before the if edit_latents branch and again, redundantly, inside it. The intent is to capture the original (pre-concatenation) sequence length so that model_output[:, :image_seq_len] later discards the edit-token positions. The duplicate assignment obscures that intent and invites an incorrect "fix" that moves the assignment after the concatenation, which would stop the slice from removing the edit tokens. Rename the variable to make the intent explicit and drop the redundancy.

Suggested change

    image_seq_len_original = latents.shape[1]
    if edit_latents is not None:
        latents = torch.concat([latents, edit_latents], dim=1)
        image_ids = torch.concat([image_ids, edit_image_ids], dim=1)
    embedded_guidance = torch.tensor([embedded_guidance], device=latents.device)
    model_output = dit(
        hidden_states=latents,
        timestep=timestep / 1000,
        guidance=embedded_guidance,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=image_ids,
        use_gradient_checkpointing=use_gradient_checkpointing,
        use_gradient_checkpointing_offload=use_gradient_checkpointing_offload,
    )
    model_output = model_output[:, :image_seq_len_original]

Comment on lines +1008 to +1013
# NOTE: the below logic means that we can't support batched inference with images of different resolutions or
# text prompts of differents lengths. Is this a use case we want to support?
if img_ids.ndim == 3:
img_ids = img_ids[0]
if txt_ids.ndim == 3:
txt_ids = txt_ids[0]

critical

The note correctly identifies a limitation regarding batched inference with varying image resolutions or text prompt lengths. However, if img_ids.ndim == 3: img_ids = img_ids[0] and if txt_ids.ndim == 3: txt_ids = txt_ids[0] mishandle batched inputs with B > 1 by silently keeping only the first batch element. This is a critical correctness issue if batched inference is expected to work here; the code should either raise an error for unsupported batch shapes or handle them correctly.
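One way to make the limitation explicit, sketched under the assumption that the ids tensors are shaped (B, L, D); squeeze_ids is a hypothetical helper, not part of the PR:

```python
import torch

def squeeze_ids(ids: torch.Tensor) -> torch.Tensor:
    """Reduce (B, L, D) position ids to (L, D), refusing mixed batches
    instead of silently dropping every element but the first."""
    if ids.ndim == 3:
        if ids.shape[0] > 1 and not (ids == ids[0:1]).all():
            raise ValueError(
                "Batched inference with per-sample position ids is "
                "unsupported; all batch elements must share the same ids."
            )
        ids = ids[0]
    return ids
```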

return b

def _tiled_encode(self, x: torch.Tensor) -> torch.Tensor:
r"""Encode a batch of images using a tiled encoder.

high

The attributes self.tile_sample_min_size and self.tile_overlap_factor are used here but do not appear to be initialized in the Flux2VAE constructor, which will raise an AttributeError the first time _tiled_encode runs unless they are set elsewhere. Ensure all attributes used by _tiled_encode are initialized in __init__.
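A minimal sketch of the fix, assuming constructor injection; the attribute names come from the review, but the default values are illustrative, not the values Flux2VAE actually uses:

```python
# Constructor excerpt only: guarantee the tiling attributes exist before
# _tiled_encode ever runs (defaults here are illustrative placeholders).
class Flux2VAE:
    def __init__(self,
                 tile_sample_min_size: int = 512,
                 tile_overlap_factor: float = 0.25):
        self.tile_sample_min_size = tile_sample_min_size
        self.tile_overlap_factor = tile_overlap_factor
```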

Comment on lines +408 to +411
t = torch.arange(1) # [0] - time dimension
h = torch.arange(height)
w = torch.arange(width)
l = torch.arange(1) # [0] - layer dimension

medium

Similar to prepare_text_ids, the t and l coordinates for image IDs are hardcoded to torch.arange(1). This assumes a single time step and layer dimension for image tokens. If these dimensions can vary or need to be dynamic, this hardcoding could lead to correctness issues or limit flexibility.
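A hedged sketch of how the time/layer extents could be parameterized instead of hardcoded; this pure-Python stand-in (prepare_image_ids here is hypothetical) only illustrates the (t, h, w, l) coordinate-grid shape:

```python
from itertools import product

def prepare_image_ids(height: int, width: int, t_len: int = 1, l_len: int = 1):
    """Build (t, h, w, l) coordinate tuples; t_len/l_len default to the
    current single-step behavior but are no longer hardcoded."""
    return [(t, h, w, l)
            for t, h, w, l in product(range(t_len), range(height),
                                      range(width), range(l_len))]

ids = prepare_image_ids(2, 3)
assert len(ids) == 2 * 3          # one id per spatial position when t = l = 1
```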

Comment on lines +602 to +603
self.heads = out_dim // dim_head if out_dim is not None else heads


medium

Similar to Flux2Attention, the calculation of self.heads here (out_dim // dim_head) might lead to inconsistencies if out_dim is provided and out_dim // dim_head does not equal the heads argument passed to the constructor. Ensure consistency between heads and out_dim or clarify the intended behavior.

Comment on lines +470 to +471
self.heads = out_dim // dim_head if out_dim is not None else heads


medium

The calculation of self.heads here (out_dim // dim_head) might lead to inconsistencies if out_dim is provided and out_dim // dim_head does not equal the heads argument passed to the constructor. It's generally better to ensure self.heads is directly derived from the heads argument or that out_dim is consistent with heads * dim_head.


def edit_image_auto_resize(self, edit_image):
"""Auto resize edit image to optimal dimensions"""
calculated_width, calculated_height = self.calculate_dimensions(1024 * 1024, edit_image.size[0] / edit_image.size[1])

medium

The target area 1024 * 1024 for resizing edit images is hardcoded. This should be a configurable parameter, possibly passed through the pipeline's __call__ method or defined in the Flux2PipelineConfig, to allow for different target resolutions.
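A hedged sketch of making the target area configurable, with the current 1024 * 1024 as the default; calculate_dimensions below is an illustrative stand-in that snaps to multiples of 16, not the pipeline's exact implementation:

```python
import math

DEFAULT_EDIT_TARGET_AREA = 1024 * 1024  # current hardcoded value, now named

def calculate_dimensions(target_area: int, aspect_ratio: float,
                         multiple: int = 16) -> tuple[int, int]:
    """Pick width/height near target_area preserving aspect_ratio,
    rounded to a multiple (illustrative, not the pipeline's exact math)."""
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    snap = lambda v: max(multiple, int(round(v / multiple)) * multiple)
    return snap(width), snap(height)
```

A pipeline could then accept edit_target_area in its __call__ or config and pass it through instead of the literal.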

return image

def process_edit_image_ids(self, image_latents, scale=10):
"""Process image IDs for edit images"""

medium

The scale=10 used in t_coords calculation is a magic number. It would be more maintainable to define this as a named constant or make it a configurable parameter.
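A minimal sketch of lifting the magic number into a named, documented constant; the constant name and the t-coordinate formula are illustrative, only the value 10 comes from the diff:

```python
# Illustrative name for the hardcoded scale=10 in process_edit_image_ids.
EDIT_IMAGE_ID_SCALE = 10  # spacing applied to t-coordinates of edit-image tokens

def edit_t_coord(edit_index: int, scale: int = EDIT_IMAGE_ID_SCALE) -> int:
    """Hypothetical formula: give each successive edit image a distinct,
    well-separated t-coordinate."""
    return (edit_index + 1) * scale
```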

):
image_seq_len = latents.shape[1]
if edit_latents is not None:
image_seq_len = latents.shape[1]

Redundant code.

)

if config.vae_path is None:
config.vae_path = fetch_model(config.model_path, path="vae/diffusion_pytorch_model.safetensors")

@qzzz95 qzzz95 Jan 22, 2026


Specify a concrete repo_id on ModelScope here. The current code assumes model_path is a directory and that it contains a vae subdirectory; this implicit assumption makes the interface hard to call. See QwenImagePipeline for reference.
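A hedged sketch of the suggested shape, modeled on the fetch_model call visible in the diff; the default repo id string is hypothetical and the stub fetch_model just joins its arguments:

```python
from types import SimpleNamespace

DEFAULT_FLUX2_REPO = "black-forest-labs/FLUX.2-klein"  # hypothetical repo id

def resolve_vae_path(config, fetch_model):
    """Fill config.vae_path, falling back to an explicit default repo id
    instead of assuming model_path is a local directory with a vae/ subdir."""
    if config.vae_path is None:
        repo = config.model_path or DEFAULT_FLUX2_REPO
        config.vae_path = fetch_model(
            repo, path="vae/diffusion_pytorch_model.safetensors")
    return config.vae_path

# Demo with a stub fetch_model.
cfg = SimpleNamespace(model_path=None, vae_path=None)
resolve_vae_path(cfg, lambda repo, path: f"{repo}/{path}")
```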

vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

if config.encoder_path is None:
config.encoder_path = fetch_model(config.model_path, path="text_encoder")

Same issue as vae_path.

)

if config.vae_path is None:
config.vae_path = fetch_model(config.model_path, path="vae/diffusion_pytorch_model.safetensors")

Following QwenImagePipeline, hard-code this to a repo_id on ModelScope. Right now there is an implicit convention that config.model_path must be a directory containing vae and text_encoder subdirectories.

vae_state_dict = cls.load_model_checkpoint(config.vae_path, device="cpu", dtype=config.vae_dtype)

if config.encoder_path is None:
if config.model_size == "4B":

Change the type of model_size to an enum; otherwise callers have to read the pipeline code to know which values are valid.
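A minimal sketch of the suggested enum; the class name is hypothetical, and the member values mirror the "4B"/"9B" strings used by the tests in this PR:

```python
from enum import Enum

class Flux2ModelSize(str, Enum):
    """Valid model sizes, discoverable without reading the pipeline source."""
    SIZE_4B = "4B"
    SIZE_9B = "9B"

# The str mixin keeps existing string-based configs working:
assert Flux2ModelSize("4B") is Flux2ModelSize.SIZE_4B
```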

@qzzz95 qzzz95 merged commit 31dce8f into main Jan 23, 2026
@qzzz95 qzzz95 deleted the flux-klein branch January 23, 2026 02:21