From da1096eac25b09a34650a3a47609f12ae856b7c5 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Mon, 6 Oct 2025 21:06:20 +0530 Subject: [PATCH 1/3] start --- .../modular_pipelines/flux/before_denoise.py | 331 ++++++------------ .../modular_pipelines/flux/denoise.py | 3 - .../modular_pipelines/flux/encoders.py | 235 +++++++------ .../modular_pipelines/flux/inputs.py | 239 +++++++++++++ 4 files changed, 458 insertions(+), 350 deletions(-) create mode 100644 src/diffusers/modular_pipelines/flux/inputs.py diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index 95858fbf6eb0..2af13a2e798a 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -13,12 +13,12 @@ # limitations under the License. import inspect -from typing import Any, List, Optional, Tuple, Union +from typing import List, Optional, Union import numpy as np import torch -from ...models import AutoencoderKL +from ...pipelines import FluxPipeline from ...schedulers import FlowMatchEulerDiscreteScheduler from ...utils import logging from ...utils.torch_utils import randn_tensor @@ -104,48 +104,6 @@ def calculate_shift( return mu -# Adapted from the original implementation. -def prepare_latents_img2img( - vae, scheduler, image, timestep, batch_size, num_channels_latents, height, width, dtype, device, generator -): - if isinstance(generator, list) and len(generator) != batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" - f" size of {batch_size}. Make sure the batch size matches the length of the generators." - ) - - vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1) - latent_channels = vae.config.latent_channels - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. - height = 2 * (int(height) // (vae_scale_factor * 2)) - width = 2 * (int(width) // (vae_scale_factor * 2)) - shape = (batch_size, num_channels_latents, height, width) - latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - image = image.to(device=device, dtype=dtype) - if image.shape[1] != latent_channels: - image_latents = _encode_vae_image(image=image, generator=generator) - else: - image_latents = image - if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0: - # expand init_latents for batch_size - additional_image_per_prompt = batch_size // image_latents.shape[0] - image_latents = torch.cat([image_latents] * additional_image_per_prompt, dim=0) - elif batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] != 0: - raise ValueError( - f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts." 
- ) - else: - image_latents = torch.cat([image_latents], dim=0) - - noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) - latents = scheduler.scale_noise(image_latents, timestep, noise) - latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) - return latents, latent_image_ids - - # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents def retrieve_latents( encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" @@ -160,6 +118,7 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") +# TODO: align this with Qwen patchifier def _pack_latents(latents, batch_size, num_channels_latents, height, width): latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2) latents = latents.permute(0, 2, 4, 1, 3, 5) @@ -168,35 +127,6 @@ def _pack_latents(latents, batch_size, num_channels_latents, height, width): return latents -def _prepare_latent_image_ids(batch_size, height, width, device, dtype): - latent_image_ids = torch.zeros(height, width, 3) - latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None] - latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :] - - latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape - - latent_image_ids = latent_image_ids.reshape( - latent_image_id_height * latent_image_id_width, latent_image_id_channels - ) - - return latent_image_ids.to(device=device, dtype=dtype) - - -# Cannot use "# Copied from" because it introduces weird indentation errors. -def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) - else: - image_latents = retrieve_latents(vae.encode(image), generator=generator) - - image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor - - return image_latents - - def _get_initial_timesteps_and_optionals( transformer, scheduler, @@ -231,96 +161,6 @@ def _get_initial_timesteps_and_optionals( return timesteps, num_inference_steps, sigmas, guidance -class FluxInputStep(ModularPipelineBlocks): - model_name = "flux" - - @property - def description(self) -> str: - return ( - "Input processing step that:\n" - " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" - " 2. Adjusts input tensor shapes based on `batch_size` (number of prompts) and `num_images_per_prompt`\n\n" - "All input tensors are expected to have either batch_size=1 or match the batch_size\n" - "of prompt_embeds. The tensors will be duplicated across the batch dimension to\n" - "have a final batch_size of batch_size * num_images_per_prompt." - ) - - @property - def inputs(self) -> List[InputParam]: - return [ - InputParam("num_images_per_prompt", default=1), - InputParam( - "prompt_embeds", - required=True, - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="Pre-generated text embeddings. Can be generated from text_encoder step.", - ), - InputParam( - "pooled_prompt_embeds", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", - ), - # TODO: support negative embeddings? 
- ] - - @property - def intermediate_outputs(self) -> List[str]: - return [ - OutputParam( - "batch_size", - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", - ), - OutputParam( - "dtype", - type_hint=torch.dtype, - description="Data type of model tensor inputs (determined by `prompt_embeds`)", - ), - OutputParam( - "prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", - description="text embeddings used to guide the image generation", - ), - OutputParam( - "pooled_prompt_embeds", - type_hint=torch.Tensor, - kwargs_type="denoiser_input_fields", - description="pooled text embeddings used to guide the image generation", - ), - # TODO: support negative embeddings? - ] - - def check_inputs(self, components, block_state): - if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None: - if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]: - raise ValueError( - "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but" - f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`" - f" {block_state.pooled_prompt_embeds.shape}." - ) - - @torch.no_grad() - def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: - # TODO: consider adding negative embeddings? - block_state = self.get_block_state(state) - self.check_inputs(components, block_state) - - block_state.batch_size = block_state.prompt_embeds.shape[0] - block_state.dtype = block_state.prompt_embeds.dtype - - _, seq_len, _ = block_state.prompt_embeds.shape - block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) - block_state.prompt_embeds = block_state.prompt_embeds.view( - block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 - ) - self.set_block_state(state, block_state) - - return components, state - - class FluxSetTimestepsStep(ModularPipelineBlocks): model_name = "flux" @@ -389,6 +229,10 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.sigmas = sigmas block_state.guidance = guidance + # We set the index here to remove DtoH sync, helpful especially during compilation. 
+ # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 + components.scheduler.set_begin_index(0) + self.set_block_state(state, block_state) return components, state @@ -432,11 +276,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=int, description="The number of denoising steps to perform at inference time", ), - OutputParam( - "latent_timestep", - type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image generation", - ), OutputParam("guidance", type_hint=torch.Tensor, description="Optional guidance to be used."), ] @@ -484,8 +323,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.sigmas = sigmas block_state.guidance = guidance - block_state.latent_timestep = timesteps[:1].repeat(batch_size) - self.set_block_state(state, block_state) return components, state @@ -524,11 +361,6 @@ def intermediate_outputs(self) -> List[OutputParam]: OutputParam( "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process" ), - OutputParam( - "latent_image_ids", - type_hint=torch.Tensor, - description="IDs computed from the image sequence needed for RoPE", - ), ] @staticmethod @@ -552,20 +384,13 @@ def prepare_latents( generator, latents=None, ): - # Couldn't use the `prepare_latents` method directly from Flux because I decided to copy over - # the packing methods here. So, for example, `comp._pack_latents()` won't work if we were - # to go with the "# Copied from ..." approach. Or maybe there's a way? - - # VAE applies 8x compression on images but we must also account for packing which requires - # latent height and width to be divisible by 2. height = 2 * (int(height) // (comp.vae_scale_factor * 2)) width = 2 * (int(width) // (comp.vae_scale_factor * 2)) shape = (batch_size, num_channels_latents, height, width) if latents is not None: - latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - return latents.to(device=device, dtype=dtype), latent_image_ids + return latents.to(device=device, dtype=dtype) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( @@ -576,9 +401,7 @@ def prepare_latents( latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) - latent_image_ids = _prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype) - - return latents, latent_image_ids + return latents @torch.no_grad() def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: @@ -587,12 +410,11 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.height = block_state.height or components.default_height block_state.width = block_state.width or components.default_width block_state.device = components._execution_device - block_state.dtype = torch.bfloat16 # TODO: okay to hardcode this? 
block_state.num_channels_latents = components.num_channels_latents self.check_inputs(components, block_state) batch_size = block_state.batch_size * block_state.num_images_per_prompt - block_state.latents, block_state.latent_image_ids = self.prepare_latents( + block_state.latents = self.prepare_latents( components, batch_size, block_state.num_channels_latents, @@ -613,81 +435,126 @@ class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks): model_name = "flux" @property - def expected_components(self) -> List[ComponentSpec]: - return [ComponentSpec("vae", AutoencoderKL), ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] + def description(self) -> str: + return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified." @property - def description(self) -> str: - return "Step that prepares the latents for the image-to-image generation process" + def expected_components(self) -> List[ComponentSpec]: + return [ + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ] @property - def inputs(self) -> List[Tuple[str, Any]]: + def inputs(self) -> List[InputParam]: return [ - InputParam("height", type_hint=int), - InputParam("width", type_hint=int), - InputParam("latents", type_hint=Optional[torch.Tensor]), - InputParam("num_images_per_prompt", type_hint=int, default=1), - InputParam("generator"), InputParam( - "image_latents", + name="latents", required=True, type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.", + description="The initial random noised, can be generated in prepare latent step.", ), InputParam( - "latent_timestep", + name="image_latents", required=True, type_hint=torch.Tensor, - description="The timestep that represents the initial noise level for image-to-image/inpainting generation. Can be generated in set_timesteps step.", + description="The image latents to use for the denoising process. Can be generated in vae encoder and packed in input step.", ), InputParam( - "batch_size", + name="timesteps", required=True, - type_hint=int, - description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.", + type_hint=torch.Tensor, + description="The timesteps to use for the denoising process. 
Can be generated in set_timesteps step.",
             ),
-            InputParam("dtype", required=True, type_hint=torch.dtype, description="The dtype of the model inputs"),
         ]
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
-            ),
-            OutputParam(
-                "latent_image_ids",
+                name="initial_noise",
                 type_hint=torch.Tensor,
-                description="IDs computed from the image sequence needed for RoPE",
+                description="The initial random noise used for inpainting denoising.",
             ),
         ]
 
+    @staticmethod
+    def check_inputs(image_latents, latents):
+        if image_latents.shape[0] != latents.shape[0]:
+            raise ValueError(
+                f"`image_latents` must have the same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
+            )
+
+        if image_latents.ndim != 3:
+            raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
+
     @torch.no_grad()
     def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        block_state.device = components._execution_device
-        block_state.dtype = torch.bfloat16  # TODO: okay to hardcode this?
-        block_state.num_channels_latents = components.num_channels_latents
-        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
-        block_state.device = components._execution_device
+        self.check_inputs(image_latents=block_state.image_latents, latents=block_state.latents)
 
-        # TODO: implement `check_inputs`
-        batch_size = block_state.batch_size * block_state.num_images_per_prompt
-        if block_state.latents is None:
-            block_state.latents, block_state.latent_image_ids = prepare_latents_img2img(
-                components.vae,
-                components.scheduler,
-                block_state.image_latents,
-                block_state.latent_timestep,
-                batch_size,
-                block_state.num_channels_latents,
-                block_state.height,
-                block_state.width,
-                block_state.dtype,
-                block_state.device,
-                block_state.generator,
-            )
+        # prepare latent timestep
+        latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
+
+        # make copy of initial_noise
+        block_state.initial_noise = block_state.latents
+
+        # scale noise
+        block_state.latents = components.scheduler.scale_noise(
+            block_state.image_latents, latent_timestep, block_state.latents
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
+class FluxRoPEInputsStep(ModularPipelineBlocks):
+    model_name = "flux"
+
+    @property
+    def description(self) -> str:
+        return (
+            "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps."
+ ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam(name="image_height", required=True), + InputParam(name="image_width", required=True), + InputParam(name="prompt_embeds"), + ] + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + name="txt_ids", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the prompt embeds, used for RoPE calculation.", + ), + OutputParam( + name="img_ids", + kwargs_type="denoiser_input_fields", + type_hint=List[int], + description="The sequence lengths of the image latents, used for RoPE calculation.", + ), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + prompt_embeds = block_state.prompt_embeds + device, dtype = prompt_embeds.device, prompt_embeds.dtype + block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to( + device=prompt_embeds.device, dtype=prompt_embeds.dtype + ) + + height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) + width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) + block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index ffa0a4456f5d..0620a276a740 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -195,9 +195,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip block_state.num_warmup_steps = max( len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0 ) - # We set the index here to remove DtoH sync, helpful especially during compilation. 
- # Check out more details here: https://github.com/huggingface/diffusers/pull/11696 - components.scheduler.set_begin_index(0) with self.progress_bar(total=block_state.num_inference_steps) as progress_bar: for i, t in enumerate(block_state.timesteps): components, block_state = self.loop_step(components, block_state, i=i, t=t) diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 16ddecbadb4f..68ae4e0d05d0 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -25,7 +25,7 @@ from ...models import AutoencoderKL from ...utils import USE_PEFT_BACKEND, is_ftfy_available, logging, scale_lora_layers, unscale_lora_layers from ..modular_pipeline import ModularPipelineBlocks, PipelineState -from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam +from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam from .modular_pipeline import FluxModularPipeline @@ -67,89 +67,154 @@ def retrieve_latents( raise AttributeError("Could not access latents of provided encoder_output") -class FluxVaeEncoderStep(ModularPipelineBlocks): - model_name = "flux" +def encode_vae_image(vae: AutoencoderKL, image: torch.Tensor, generator: torch.Generator, sample_mode="sample"): + if isinstance(generator, list): + image_latents = [ + retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i], sample_mode=sample_mode) + for i in range(image.shape[0]) + ] + image_latents = torch.cat(image_latents, dim=0) + else: + image_latents = retrieve_latents(vae.encode(image), generator=generator, sample_mode=sample_mode) + + image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + + return image_latents + + +class FluxProcessImagesInputStep(ModularPipelineBlocks): + model_name = "Flux" @property def description(self) -> str: - return "Vae Encoder step that encode the input image into a latent representation" + return "Image Preprocess step. 
Resizing is needed in Flux Kontext (will be implemented later.)" @property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("vae", AutoencoderKL), ComponentSpec( "image_processor", VaeImageProcessor, - config=FrozenDict({"vae_scale_factor": 16, "vae_latent_channels": 16}), + config=FrozenDict({"vae_scale_factor": 16}), default_creation_method="from_config", ), ] @property def inputs(self) -> List[InputParam]: - return [ - InputParam("image", required=True), - InputParam("height"), - InputParam("width"), - InputParam("generator"), - InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"), - InputParam( - "preprocess_kwargs", - type_hint=Optional[dict], - description="A kwargs dictionary that if specified is passed along to the `ImageProcessor` as defined under `self.image_processor` in [diffusers.image_processor.VaeImageProcessor]", - ), - ] + return [InputParam("resized_image"), InputParam("image"), InputParam("height"), InputParam("width")] @property def intermediate_outputs(self) -> List[OutputParam]: return [ - OutputParam( - "image_latents", - type_hint=torch.Tensor, - description="The latents representing the reference image for image-to-image/inpainting generation", - ) + OutputParam(name="processed_image"), ] @staticmethod - # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image with self.vae->vae - def _encode_vae_image(vae, image: torch.Tensor, generator: torch.Generator): - if isinstance(generator, list): - image_latents = [ - retrieve_latents(vae.encode(image[i : i + 1]), generator=generator[i]) for i in range(image.shape[0]) - ] - image_latents = torch.cat(image_latents, dim=0) + def check_inputs(height, width, vae_scale_factor): + if height is not None and height % (vae_scale_factor * 2) != 0: + raise ValueError(f"Height must be divisible by {vae_scale_factor * 2} but is {height}") + + if width is not None and width % (vae_scale_factor * 2) != 0: + raise ValueError(f"Width must be divisible by {vae_scale_factor * 2} but is {width}") + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState): + block_state = self.get_block_state(state) + + if block_state.resized_image is None and block_state.image is None: + raise ValueError("`resized_image` and `image` cannot be None at the same time") + + if block_state.resized_image is None: + image = block_state.image + self.check_inputs( + height=block_state.height, width=block_state.width, vae_scale_factor=components.vae_scale_factor + ) + height = block_state.height or components.default_height + width = block_state.width or components.default_width else: - image_latents = retrieve_latents(vae.encode(image), generator=generator) + width, height = block_state.resized_image[0].size + image = block_state.resized_image + + block_state.processed_image = components.image_processor.preprocess(image=image, height=height, width=width) + + self.set_block_state(state, block_state) + return components, state + + +class FluxVaeEncoderDynamicStep(ModularPipelineBlocks): + model_name = "flux" + + def __init__( + self, + input_name: str = "processed_image", + output_name: str = "image_latents", + ): + """Initialize a VAE encoder step for converting images to latent representations. 
+ + Both the input and output names are configurable so this block can be configured to process to different image + inputs (e.g., "processed_image" -> "image_latents", "processed_control_image" -> "control_image_latents"). + + Args: + input_name (str, optional): Name of the input image tensor. Defaults to "processed_image". + Examples: "processed_image" or "processed_control_image" + output_name (str, optional): Name of the output latent tensor. Defaults to "image_latents". + Examples: "image_latents" or "control_image_latents" + + Examples: + # Basic usage with default settings (includes image processor): # FluxImageVaeEncoderDynamicStep() + + # Custom input/output names for control image: # FluxImageVaeEncoderDynamicStep( + input_name="processed_control_image", output_name="control_image_latents" + ) + """ + self._image_input_name = input_name + self._image_latents_output_name = output_name + super().__init__() + + @property + def description(self) -> str: + return f"Dynamic VAE Encoder step that converts {self._image_input_name} into latent representations {self._image_latents_output_name}.\n" - image_latents = (image_latents - vae.config.shift_factor) * vae.config.scaling_factor + @property + def expected_components(self) -> List[ComponentSpec]: + components = [ComponentSpec("vae", AutoencoderKL)] + return components - return image_latents + @property + def inputs(self) -> List[InputParam]: + inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")] + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam( + self._image_latents_output_name, + type_hint=torch.Tensor, + description="The latents representing the reference image", + ) + ] @torch.no_grad() def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: block_state = self.get_block_state(state) - block_state.preprocess_kwargs = block_state.preprocess_kwargs or {} - block_state.device = components._execution_device - block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype - block_state.image = components.image_processor.preprocess( - block_state.image, height=block_state.height, width=block_state.width, **block_state.preprocess_kwargs - ) - block_state.image = block_state.image.to(device=block_state.device, dtype=block_state.dtype) - - block_state.batch_size = block_state.image.shape[0] + device = components._execution_device + dtype = components.vae.dtype - # if generator is a list, make sure the length of it matches the length of images (both should be batch_size) - if isinstance(block_state.generator, list) and len(block_state.generator) != block_state.batch_size: - raise ValueError( - f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch" - f" size of {block_state.batch_size}. Make sure the batch size matches the length of the generators." 
- ) + image = getattr(block_state, self._image_input_name) - block_state.image_latents = self._encode_vae_image( - components.vae, image=block_state.image, generator=block_state.generator + # Encode image into latents + image_latents = encode_vae_image( + image=image, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, ) + setattr(block_state, self._image_latents_output_name, image_latents) self.set_block_state(state, block_state) @@ -161,7 +226,7 @@ class FluxTextEncoderStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Text Encoder step that generate text_embeddings to guide the video generation" + return "Text Encoder step that generate text_embeddings to guide the image generation" @property def expected_components(self) -> List[ComponentSpec]: @@ -172,10 +237,6 @@ def expected_components(self) -> List[ComponentSpec]: ComponentSpec("tokenizer_2", T5TokenizerFast), ] - @property - def expected_configs(self) -> List[ConfigSpec]: - return [] - @property def inputs(self) -> List[InputParam]: return [ @@ -200,12 +261,6 @@ def intermediate_outputs(self) -> List[OutputParam]: type_hint=torch.Tensor, description="pooled text embeddings used to guide the image generation", ), - OutputParam( - "text_ids", - kwargs_type="denoiser_input_fields", - type_hint=torch.Tensor, - description="ids from the text sequence for RoPE", - ), ] @staticmethod @@ -216,16 +271,10 @@ def check_inputs(block_state): @staticmethod def _get_t5_prompt_embeds( - components, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - max_sequence_length: int, - device: torch.device, + components, prompt: Union[str, List[str]], max_sequence_length: int, device: torch.device ): dtype = components.text_encoder_2.dtype - prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if isinstance(components, TextualInversionLoaderMixin): prompt = components.maybe_convert_prompt(prompt, components.tokenizer_2) @@ -251,23 +300,11 @@ def _get_t5_prompt_embeds( prompt_embeds = components.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0] prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) - _, seq_len, _ = prompt_embeds.shape - - # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) - return prompt_embeds @staticmethod - def _get_clip_prompt_embeds( - components, - prompt: Union[str, List[str]], - num_images_per_prompt: int, - device: torch.device, - ): + def _get_clip_prompt_embeds(components, prompt: Union[str, List[str]], device: torch.device): prompt = [prompt] if isinstance(prompt, str) else prompt - batch_size = len(prompt) if isinstance(components, TextualInversionLoaderMixin): prompt = components.maybe_convert_prompt(prompt, components.tokenizer) @@ -297,10 +334,6 @@ def _get_clip_prompt_embeds( prompt_embeds = prompt_embeds.pooler_output prompt_embeds = prompt_embeds.to(dtype=components.text_encoder.dtype, device=device) - # duplicate text embeddings for each generation per prompt, using mps friendly method - prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt) - prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1) - return prompt_embeds @staticmethod @@ -309,34 +342,11 @@ def encode_prompt( prompt: Union[str, 
List[str]], prompt_2: Union[str, List[str]], device: Optional[torch.device] = None, - num_images_per_prompt: int = 1, prompt_embeds: Optional[torch.FloatTensor] = None, pooled_prompt_embeds: Optional[torch.FloatTensor] = None, max_sequence_length: int = 512, lora_scale: Optional[float] = None, ): - r""" - Encodes the prompt into text encoder hidden states. - - Args: - prompt (`str` or `List[str]`, *optional*): - prompt to be encoded - prompt_2 (`str` or `List[str]`, *optional*): - The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is - used in all text-encoders - device: (`torch.device`): - torch device - num_images_per_prompt (`int`): - number of images that should be generated per prompt - prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not - provided, text embeddings will be generated from `prompt` input argument. - pooled_prompt_embeds (`torch.FloatTensor`, *optional*): - Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. - If not provided, pooled text embeddings will be generated from `prompt` input argument. - lora_scale (`float`, *optional*): - A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. - """ device = device or components._execution_device # set lora scale so that monkey patched LoRA @@ -361,12 +371,10 @@ def encode_prompt( components, prompt=prompt, device=device, - num_images_per_prompt=num_images_per_prompt, ) prompt_embeds = FluxTextEncoderStep._get_t5_prompt_embeds( components, prompt=prompt_2, - num_images_per_prompt=num_images_per_prompt, max_sequence_length=max_sequence_length, device=device, ) @@ -381,10 +389,7 @@ def encode_prompt( # Retrieve the original scale by scaling back the LoRA layers unscale_lora_layers(components.text_encoder_2, lora_scale) - dtype = components.text_encoder.dtype if components.text_encoder is not None else torch.bfloat16 - text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype) - - return prompt_embeds, pooled_prompt_embeds, text_ids + return prompt_embeds, pooled_prompt_embeds @torch.no_grad() def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: @@ -400,7 +405,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip if block_state.joint_attention_kwargs is not None else None ) - (block_state.prompt_embeds, block_state.pooled_prompt_embeds, block_state.text_ids) = self.encode_prompt( + block_state.prompt_embeds, block_state.pooled_prompt_embeds = self.encode_prompt( components, prompt=block_state.prompt, prompt_2=None, diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py new file mode 100644 index 000000000000..bbd7c25853d5 --- /dev/null +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -0,0 +1,239 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List + +import torch + +from ..modular_pipeline import ModularPipelineBlocks, PipelineState +from ..modular_pipeline_utils import InputParam, OutputParam +from .modular_pipeline import FluxModularPipeline +from ...pipelines import FluxPipeline +# TODO: consider making these common utilities for modular if they are not pipeline-specific. +from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size + + +class FluxTextInputStep(ModularPipelineBlocks): + model_name = "flux" + + @property + def description(self) -> str: + return ( + "Text input processing step that standardizes text embeddings for the pipeline.\n" + "This step:\n" + " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n" + " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)" + ) + + @property + def inputs(self) -> List[InputParam]: + return [ + InputParam("num_images_per_prompt", default=1), + InputParam( + "prompt_embeds", + required=True, + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated text embeddings. Can be generated from text_encoder step.", + ), + InputParam( + "pooled_prompt_embeds", + kwargs_type="denoiser_input_fields", + type_hint=torch.Tensor, + description="Pre-generated pooled text embeddings. Can be generated from text_encoder step.", + ), + # TODO: support negative embeddings? + ] + + @property + def intermediate_outputs(self) -> List[str]: + return [ + OutputParam( + "batch_size", + type_hint=int, + description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt", + ), + OutputParam( + "dtype", + type_hint=torch.dtype, + description="Data type of model tensor inputs (determined by `prompt_embeds`)", + ), + OutputParam( + "prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="text embeddings used to guide the image generation", + ), + OutputParam( + "pooled_prompt_embeds", + type_hint=torch.Tensor, + kwargs_type="denoiser_input_fields", + description="pooled text embeddings used to guide the image generation", + ), + # TODO: support negative embeddings? + ] + + def check_inputs(self, components, block_state): + if block_state.prompt_embeds is not None and block_state.pooled_prompt_embeds is not None: + if block_state.prompt_embeds.shape[0] != block_state.pooled_prompt_embeds.shape[0]: + raise ValueError( + "`prompt_embeds` and `pooled_prompt_embeds` must have the same batch size when passed directly, but" + f" got: `prompt_embeds` {block_state.prompt_embeds.shape} != `pooled_prompt_embeds`" + f" {block_state.pooled_prompt_embeds.shape}." + ) + + @torch.no_grad() + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + # TODO: consider adding negative embeddings? 
+ block_state = self.get_block_state(state) + self.check_inputs(components, block_state) + + block_state.batch_size = block_state.prompt_embeds.shape[0] + block_state.dtype = block_state.prompt_embeds.dtype + + _, seq_len, _ = block_state.prompt_embeds.shape + block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1) + block_state.prompt_embeds = block_state.prompt_embeds.view( + block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1 + ) + self.set_block_state(state, block_state) + + return components, state + + +# Adapted from `QwenImageInputsDynamicStep` +class FluxInputsDynamicStep(ModularPipelineBlocks): + model_name = "flux" + + def __init__( + self, + image_latent_inputs: List[str] = ["image_latents"], + additional_batch_inputs: List[str] = [], + ): + if not isinstance(image_latent_inputs, list): + image_latent_inputs = [image_latent_inputs] + if not isinstance(additional_batch_inputs, list): + additional_batch_inputs = [additional_batch_inputs] + + self._image_latent_inputs = image_latent_inputs + self._additional_batch_inputs = additional_batch_inputs + super().__init__() + + @property + def description(self) -> str: + # Functionality section + summary_section = ( + "Input processing step that:\n" + " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n" + " 2. For additional batch inputs: Expands batch dimensions to match final batch size" + ) + + # Inputs info + inputs_info = "" + if self._image_latent_inputs or self._additional_batch_inputs: + inputs_info = "\n\nConfigured inputs:" + if self._image_latent_inputs: + inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}" + if self._additional_batch_inputs: + inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}" + + # Placement guidance + placement_section = "\n\nThis block should be placed after the encoder steps and the text input step." + + return summary_section + inputs_info + placement_section + + @property + def inputs(self) -> List[InputParam]: + inputs = [ + InputParam(name="num_images_per_prompt", default=1), + InputParam(name="batch_size", required=True), + InputParam(name="height"), + InputParam(name="width"), + ] + + # Add image latent inputs + for image_latent_input_name in self._image_latent_inputs: + inputs.append(InputParam(name=image_latent_input_name)) + + # Add additional batch inputs + for input_name in self._additional_batch_inputs: + inputs.append(InputParam(name=input_name)) + + return inputs + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="image_height", type_hint=int, description="The height of the image latents"), + OutputParam(name="image_width", type_hint=int, description="The width of the image latents"), + ] + + def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # 1. 
Calculate height/width from latents + height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor) + block_state.height = block_state.height or height + block_state.width = block_state.width or width + + if not hasattr(block_state, "image_height"): + block_state.image_height = height + if not hasattr(block_state, "image_width"): + block_state.image_width = width + + # 2. Patchify the image latent tensor + # TODO: Implement patchifier for Flux. + latent_height, latent_width = image_latent_tensor.shape[2:] + image_latent_tensor = FluxPipeline._pack_latents( + image_latent_tensor, + block_state.batch_size, + image_latent_tensor.shape[1], + latent_height, + latent_width + ) + + # 3. Expand batch size + image_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=image_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, image_latent_input_name, image_latent_tensor) + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state From 7c7e8a44bcf2761ffc97e132c1478a7dc1aaf4c2 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 Oct 2025 13:02:57 +0530 Subject: [PATCH 2/3] fix --- .../modular_pipelines/flux/before_denoise.py | 16 +- .../modular_pipelines/flux/denoise.py | 9 +- .../modular_pipelines/flux/encoders.py | 7 +- .../modular_pipelines/flux/modular_blocks.py | 181 ++++++++++++------ 4 files changed, 136 insertions(+), 77 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py b/src/diffusers/modular_pipelines/flux/before_denoise.py index 2af13a2e798a..4afa84460735 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -398,6 +398,7 @@ def prepare_latents( f" size of {batch_size}. Make sure the batch size matches the length of the generators." ) + # TODO: move packing latents code to a patchifier latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) latents = _pack_latents(latents, batch_size, num_channels_latents, height, width) @@ -436,12 +437,13 @@ class FluxImg2ImgPrepareLatentsStep(ModularPipelineBlocks): @property def description(self) -> str: - return "Step that adds noise to image latents for image-to-image/inpainting. Should be run after set_timesteps, prepare_latents. Both noise and image latents should alreadybe patchified." + return "Step that adds noise to image latents for image-to-image. Should be run after `set_timesteps`," + " `prepare_latents`. Both noise and image latents should already be patchified." 
@property def expected_components(self) -> List[ComponentSpec]: return [ - ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler), + ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler) ] @property @@ -521,9 +523,9 @@ def description(self) -> str: @property def inputs(self) -> List[InputParam]: return [ - InputParam(name="image_height", required=True), - InputParam(name="image_width", required=True), - InputParam(name="prompt_embeds"), + InputParam(name="height", required=True), + InputParam(name="width", required=True), + InputParam(name="prompt_embeds") ] @property @@ -552,8 +554,8 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip device=prompt_embeds.device, dtype=prompt_embeds.dtype ) - height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2)) - width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2)) + height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2)) + width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2)) block_state.img_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/flux/denoise.py b/src/diffusers/modular_pipelines/flux/denoise.py index 0620a276a740..e482c198e835 100644 --- a/src/diffusers/modular_pipelines/flux/denoise.py +++ b/src/diffusers/modular_pipelines/flux/denoise.py @@ -76,18 +76,17 @@ def inputs(self) -> List[Tuple[str, Any]]: description="Pooled prompt embeddings", ), InputParam( - "text_ids", + "txt_ids", required=True, type_hint=torch.Tensor, description="IDs computed from text sequence needed for RoPE", ), InputParam( - "latent_image_ids", + "img_ids", required=True, type_hint=torch.Tensor, description="IDs computed from image sequence needed for RoPE", ), - # TODO: guidance ] @torch.no_grad() @@ -101,8 +100,8 @@ def __call__( encoder_hidden_states=block_state.prompt_embeds, pooled_projections=block_state.pooled_prompt_embeds, joint_attention_kwargs=block_state.joint_attention_kwargs, - txt_ids=block_state.text_ids, - img_ids=block_state.latent_image_ids, + txt_ids=block_state.txt_ids, + img_ids=block_state.img_ids, return_dict=False, )[0] block_state.noise_pred = noise_pred diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index 68ae4e0d05d0..a9d3bdfaf2f8 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -204,15 +204,13 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip dtype = components.vae.dtype image = getattr(block_state, self._image_input_name) + image = image.to(device=device, dtype=dtype) # Encode image into latents image_latents = encode_vae_image( image=image, vae=components.vae, - generator=block_state.generator, - device=device, - dtype=dtype, - latent_channels=components.num_channels_latents, + generator=block_state.generator ) setattr(block_state, self._image_latents_output_name, image_latents) @@ -412,7 +410,6 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip prompt_embeds=None, pooled_prompt_embeds=None, device=block_state.device, - num_images_per_prompt=1, # TODO: hardcoded for now. 
max_sequence_length=block_state.max_sequence_length,
             lora_scale=block_state.text_encoder_lora_scale,
         )
diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py
index ca4f993a11fe..9ef293f9bb38 100644
--- a/src/diffusers/modular_pipelines/flux/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py
@@ -18,21 +18,43 @@
 from .before_denoise import (
     FluxImg2ImgPrepareLatentsStep,
     FluxImg2ImgSetTimestepsStep,
-    FluxInputStep,
     FluxPrepareLatentsStep,
     FluxSetTimestepsStep,
 )
 from .decoders import FluxDecodeStep
 from .denoise import FluxDenoiseStep
-from .encoders import FluxTextEncoderStep, FluxVaeEncoderStep
+from .encoders import FluxTextEncoderStep, FluxVaeEncoderDynamicStep
+from .before_denoise import FluxRoPEInputsStep
+from .inputs import FluxTextInputStep, FluxInputsDynamicStep
+
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
 # vae encoder (run before before_denoise)
+from .encoders import FluxProcessImagesInputStep
+
+FluxImg2ImgVaeEncoderBlocks = InsertableDict(
+    [
+        ("preprocess", FluxProcessImagesInputStep()),
+        ("encode", FluxVaeEncoderDynamicStep()),
+    ]
+)
+
+class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks):
+    model_name = "flux"
+
+    block_classes = FluxImg2ImgVaeEncoderBlocks.values()
+    block_names = FluxImg2ImgVaeEncoderBlocks.keys()
+
+    @property
+    def description(self) -> str:
+        return "Vae encoder step that preprocesses and encodes the image inputs into their latent representations."
+
+
 class FluxAutoVaeEncoderStep(AutoPipelineBlocks):
-    block_classes = [FluxVaeEncoderStep]
+    block_classes = [FluxImg2ImgVaeEncoderStep]
     block_names = ["img2img"]
     block_trigger_inputs = ["image"]
@@ -41,44 +63,49 @@ def description(self):
         return (
             "Vae encoder step that encode the image inputs into their latent representations.\n"
             + "This is an auto pipeline block that works for img2img tasks.\n"
-            + " - `FluxVaeEncoderStep` (img2img) is used when only `image` is provided."
-            + " - if `image` is provided, step will be skipped."
+            + " - `FluxImg2ImgVaeEncoderStep` (img2img) is used when only `image` is provided."
+            + " - if `image` is not provided, step will be skipped."
         )
 
-# before_denoise: text2img, img2img
-class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [
-        FluxInputStep,
-        FluxPrepareLatentsStep,
-        FluxSetTimestepsStep,
+
+# before_denoise: text2img
+FluxBeforeDenoiseBlocks = InsertableDict(
+    [
+        ("prepare_latents", FluxPrepareLatentsStep()),
+        ("set_timesteps", FluxSetTimestepsStep()),
+        ("prepare_rope_inputs", FluxRoPEInputsStep())
     ]
-    block_names = ["input", "prepare_latents", "set_timesteps"]
+)
+
+class FluxBeforeDenoiseStep(SequentialPipelineBlocks):
+    block_classes = FluxBeforeDenoiseBlocks.values()
+    block_names = FluxBeforeDenoiseBlocks.keys()
 
     @property
     def description(self):
         return (
-            "Before denoise step that prepare the inputs for the denoise step.\n"
-            + "This is a sequential pipeline blocks:\n"
-            + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `FluxPrepareLatentsStep` is used to prepare the latents\n"
-            + " - `FluxSetTimestepsStep` is used to set the timesteps\n"
+            "Before denoise step that prepares the inputs for the denoise step in text-to-image generation."
        )
 
 
 # before_denoise: img2img
+FluxImg2ImgBeforeDenoiseBlocks = InsertableDict(
+    [
+        ("prepare_latents", FluxPrepareLatentsStep()),
+        ("set_timesteps", FluxImg2ImgSetTimestepsStep()),
+        ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()),
+        ("prepare_rope_inputs", FluxRoPEInputsStep())
+    ]
+)
 class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks):
-    block_classes = [FluxInputStep, FluxImg2ImgSetTimestepsStep, FluxImg2ImgPrepareLatentsStep]
-    block_names = ["input", "set_timesteps", "prepare_latents"]
+    block_classes = FluxImg2ImgBeforeDenoiseBlocks.values()
+    block_names = FluxImg2ImgBeforeDenoiseBlocks.keys()
 
     @property
     def description(self):
         return (
-            "Before denoise step that prepare the inputs for the denoise step for img2img task.\n"
-            + "This is a sequential pipeline blocks:\n"
-            + " - `FluxInputStep` is used to adjust the batch size of the model inputs\n"
-            + " - `FluxImg2ImgSetTimestepsStep` is used to set the timesteps\n"
-            + " - `FluxImg2ImgPrepareLatentsStep` is used to prepare the latents\n"
+            "Before denoise step that prepare the inputs for the denoise step for img2img task."
         )
 
 
@@ -113,7 +140,7 @@ def description(self) -> str:
     )
 
 
-# decode: all task (text2img, img2img, inpainting)
+# decode: all tasks (text2img, img2img)
 class FluxAutoDecodeStep(AutoPipelineBlocks):
     block_classes = [FluxDecodeStep]
     block_names = ["non-inpaint"]
@@ -124,32 +151,73 @@ def description(self):
         return "Decode step that decode the denoised latents into image outputs.\n - `FluxDecodeStep`"
 
 
+# inputs: text2image/img2img
+FluxImg2ImgBlocks = InsertableDict(
+    [
+        ("text_inputs", FluxTextInputStep()),
+        ("additional_inputs", FluxInputsDynamicStep())
+    ]
+)
+
+class FluxImg2ImgInputStep(SequentialPipelineBlocks):
+    model_name = "flux"
+    block_classes = FluxImg2ImgBlocks.values()
+    block_names = FluxImg2ImgBlocks.keys()
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the img2img denoising step. It:\n"
+            " - makes sure the text embeddings and the additional inputs (`image_latents`) have a consistent batch size.\n"
+            " - updates height/width based on `image_latents` and patchifies `image_latents`."
+        )
+
+
+class FluxImageAutoInputStep(AutoPipelineBlocks):
+    block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
+    block_names = ["img2img", "text2image"]
+    block_trigger_inputs = ["image_latents", None]
+
+    @property
+    def description(self):
+        return (
+            "Input step that standardizes the inputs for the denoising step, e.g. makes sure inputs have a consistent batch size and are patchified.\n"
+            " This is an auto pipeline block that works for text2image/img2img tasks.\n"
+            + " - `FluxImg2ImgInputStep` (img2img) is used when `image_latents` is provided.\n"
+            + " - `FluxTextInputStep` (text2image) is used when `image_latents` are not provided.\n"
+        )
+
+
 class FluxCoreDenoiseStep(SequentialPipelineBlocks):
+    model_name = "flux"
-    block_classes = [FluxInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
+    block_classes = [FluxImageAutoInputStep, FluxAutoBeforeDenoiseStep, FluxAutoDenoiseStep]
     block_names = ["input", "before_denoise", "denoise"]
 
     @property
     def description(self):
         return (
            "Core step that performs the denoising process. 
\n" - + " - `FluxInputStep` (input) standardizes the inputs for the denoising step.\n" + + " - `FluxImageAutoInputStep` (input) standardizes the inputs for the denoising step.\n" + " - `FluxAutoBeforeDenoiseStep` (before_denoise) prepares the inputs for the denoising step.\n" + " - `FluxAutoDenoiseStep` (denoise) iteratively denoises the latents.\n" - + "This step support text-to-image and image-to-image tasks for Flux:\n" + + "This step supports text-to-image and image-to-image tasks for Flux:\n" + " - for image-to-image generation, you need to provide `image_latents`\n" - + " - for text-to-image generation, all you need to provide is prompt embeddings" + + " - for text-to-image generation, all you need to provide is prompt embeddings." ) -# text2image -class FluxAutoBlocks(SequentialPipelineBlocks): - block_classes = [ - FluxTextEncoderStep, - FluxAutoVaeEncoderStep, - FluxCoreDenoiseStep, - FluxAutoDecodeStep, +# Auto blocks (text2image and img2img) +AUTO_BLOCKS = InsertableDict( + [ + ("text_encoder", FluxTextEncoderStep()), + ("image_encoder", FluxAutoVaeEncoderStep()), + ("denoise", FluxCoreDenoiseStep()), + ("decode", FluxDecodeStep()) ] - block_names = ["text_encoder", "image_encoder", "denoise", "decode"] +) +class FluxAutoBlocks(SequentialPipelineBlocks): + model_name = "flux" + + block_classes = AUTO_BLOCKS.values() + block_names = AUTO_BLOCKS.keys() @property def description(self): @@ -162,35 +230,28 @@ def description(self): TEXT2IMAGE_BLOCKS = InsertableDict( [ - ("text_encoder", FluxTextEncoderStep), - ("input", FluxInputStep), - ("prepare_latents", FluxPrepareLatentsStep), - ("set_timesteps", FluxSetTimestepsStep), - ("denoise", FluxDenoiseStep), - ("decode", FluxDecodeStep), + ("text_encoder", FluxTextEncoderStep()), + ("input", FluxTextInputStep()), + ("prepare_latents", FluxPrepareLatentsStep()), + ("set_timesteps", FluxSetTimestepsStep()), + ("prepare_rope_inputs", FluxRoPEInputsStep()), + ("denoise", FluxDenoiseStep()), + ("decode", FluxDecodeStep()), ] ) IMAGE2IMAGE_BLOCKS = InsertableDict( [ - ("text_encoder", FluxTextEncoderStep), - ("image_encoder", FluxVaeEncoderStep), - ("input", FluxInputStep), - ("set_timesteps", FluxImg2ImgSetTimestepsStep), - ("prepare_latents", FluxImg2ImgPrepareLatentsStep), - ("denoise", FluxDenoiseStep), - ("decode", FluxDecodeStep), + ("text_encoder", FluxTextEncoderStep()), + ("vae_encoder", FluxVaeEncoderDynamicStep()), + ("input", FluxImg2ImgInputStep()), + ("prepare_latents", FluxPrepareLatentsStep()), + ("set_timesteps", FluxImg2ImgSetTimestepsStep()), + ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()), + ("prepare_rope_inputs", FluxRoPEInputsStep()), + ("denoise", FluxDenoiseStep()), + ("decode", FluxDecodeStep()), ] ) -AUTO_BLOCKS = InsertableDict( - [ - ("text_encoder", FluxTextEncoderStep), - ("image_encoder", FluxAutoVaeEncoderStep), - ("denoise", FluxCoreDenoiseStep), - ("decode", FluxAutoDecodeStep), - ] -) - - ALL_BLOCKS = {"text2image": TEXT2IMAGE_BLOCKS, "img2img": IMAGE2IMAGE_BLOCKS, "auto": AUTO_BLOCKS} From 77ef81685b9a59d00ea77db78cdfc933d9dabf28 Mon Sep 17 00:00:00 2001 From: sayakpaul Date: Tue, 7 Oct 2025 13:09:01 +0530 Subject: [PATCH 3/3] up --- .../modular_pipelines/flux/before_denoise.py | 10 ++--- .../modular_pipelines/flux/encoders.py | 6 +-- .../modular_pipelines/flux/inputs.py | 11 ++---- .../modular_pipelines/flux/modular_blocks.py | 38 +++++++++---------- 4 files changed, 25 insertions(+), 40 deletions(-) diff --git a/src/diffusers/modular_pipelines/flux/before_denoise.py 
b/src/diffusers/modular_pipelines/flux/before_denoise.py index 4afa84460735..5f3193af0e35 100644 --- a/src/diffusers/modular_pipelines/flux/before_denoise.py +++ b/src/diffusers/modular_pipelines/flux/before_denoise.py @@ -442,9 +442,7 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: - return [ - ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler) - ] + return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)] @property def inputs(self) -> List[InputParam]: @@ -516,16 +514,14 @@ class FluxRoPEInputsStep(ModularPipelineBlocks): @property def description(self) -> str: - return ( - "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps." - ) + return "Step that prepares the RoPE inputs for the denoising process. Should be placed after text encoder and latent preparation steps." @property def inputs(self) -> List[InputParam]: return [ InputParam(name="height", required=True), InputParam(name="width", required=True), - InputParam(name="prompt_embeds") + InputParam(name="prompt_embeds"), ] @property diff --git a/src/diffusers/modular_pipelines/flux/encoders.py b/src/diffusers/modular_pipelines/flux/encoders.py index a9d3bdfaf2f8..6368086cbb5f 100644 --- a/src/diffusers/modular_pipelines/flux/encoders.py +++ b/src/diffusers/modular_pipelines/flux/encoders.py @@ -207,11 +207,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip image = image.to(device=device, dtype=dtype) # Encode image into latents - image_latents = encode_vae_image( - image=image, - vae=components.vae, - generator=block_state.generator - ) + image_latents = encode_vae_image(image=image, vae=components.vae, generator=block_state.generator) setattr(block_state, self._image_latents_output_name, image_latents) self.set_block_state(state, block_state) diff --git a/src/diffusers/modular_pipelines/flux/inputs.py b/src/diffusers/modular_pipelines/flux/inputs.py index bbd7c25853d5..f9192655d1ac 100644 --- a/src/diffusers/modular_pipelines/flux/inputs.py +++ b/src/diffusers/modular_pipelines/flux/inputs.py @@ -16,12 +16,13 @@ import torch +from ...pipelines import FluxPipeline from ..modular_pipeline import ModularPipelineBlocks, PipelineState from ..modular_pipeline_utils import InputParam, OutputParam -from .modular_pipeline import FluxModularPipeline -from ...pipelines import FluxPipeline + # TODO: consider making these common utilities for modular if they are not pipeline-specific. from ..qwenimage.inputs import calculate_dimension_from_latents, repeat_tensor_to_batch_size +from .modular_pipeline import FluxModularPipeline class FluxTextInputStep(ModularPipelineBlocks): @@ -202,11 +203,7 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> Pip # TODO: Implement patchifier for Flux. latent_height, latent_width = image_latent_tensor.shape[2:] image_latent_tensor = FluxPipeline._pack_latents( - image_latent_tensor, - block_state.batch_size, - image_latent_tensor.shape[1], - latent_height, - latent_width + image_latent_tensor, block_state.batch_size, image_latent_tensor.shape[1], latent_height, latent_width ) # 3. 
Expand batch size diff --git a/src/diffusers/modular_pipelines/flux/modular_blocks.py b/src/diffusers/modular_pipelines/flux/modular_blocks.py index 9ef293f9bb38..b40dfe176207 100644 --- a/src/diffusers/modular_pipelines/flux/modular_blocks.py +++ b/src/diffusers/modular_pipelines/flux/modular_blocks.py @@ -19,22 +19,19 @@ FluxImg2ImgPrepareLatentsStep, FluxImg2ImgSetTimestepsStep, FluxPrepareLatentsStep, + FluxRoPEInputsStep, FluxSetTimestepsStep, ) from .decoders import FluxDecodeStep from .denoise import FluxDenoiseStep -from .encoders import FluxTextEncoderStep, FluxVaeEncoderDynamicStep -from .before_denoise import FluxRoPEInputsStep -from .inputs import FluxTextInputStep, FluxInputsDynamicStep - +from .encoders import FluxProcessImagesInputStep, FluxTextEncoderStep, FluxVaeEncoderDynamicStep +from .inputs import FluxInputsDynamicStep, FluxTextInputStep logger = logging.get_logger(__name__) # pylint: disable=invalid-name # vae encoder (run before before_denoise) -from .encoders import FluxProcessImagesInputStep - FluxImg2ImgVaeEncoderBlocks = InsertableDict( [ ("preprocess", FluxProcessImagesInputStep()), @@ -42,6 +39,7 @@ ] ) + class FluxImg2ImgVaeEncoderStep(SequentialPipelineBlocks): model_name = "flux" @@ -68,25 +66,23 @@ def description(self): ) - # before_denoise: text2img FluxBeforeDenoiseBlocks = InsertableDict( [ ("prepare_latents", FluxPrepareLatentsStep()), ("set_timesteps", FluxSetTimestepsStep()), - ("prepare_rope_inputs", FluxRoPEInputsStep()) + ("prepare_rope_inputs", FluxRoPEInputsStep()), ] ) + class FluxBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = FluxBeforeDenoiseBlocks.values() block_names = FluxBeforeDenoiseBlocks.keys() @property def description(self): - return ( - "Before denoise step that prepares the inputs for the denoise step in text-to-image generation." - ) + return "Before denoise step that prepares the inputs for the denoise step in text-to-image generation." # before_denoise: img2img @@ -95,18 +91,18 @@ def description(self): ("prepare_latents", FluxPrepareLatentsStep()), ("set_timesteps", FluxImg2ImgSetTimestepsStep()), ("prepare_img2img_latents", FluxImg2ImgPrepareLatentsStep()), - ("prepare_rope_inputs", FluxRoPEInputsStep()) + ("prepare_rope_inputs", FluxRoPEInputsStep()), ] ) + + class FluxImg2ImgBeforeDenoiseStep(SequentialPipelineBlocks): block_classes = FluxImg2ImgBeforeDenoiseBlocks.values() block_names = FluxImg2ImgBeforeDenoiseBlocks.keys() @property def description(self): - return ( - "Before denoise step that prepare the inputs for the denoise step for img2img task." - ) + return "Before denoise step that prepare the inputs for the denoise step for img2img task." 
 # before_denoise: all task (text2img, img2img)
@@ -153,12 +149,10 @@ def description(self):

 # inputs: text2image/img2img
 FluxImg2ImgBlocks = InsertableDict(
-    [
-        ("text_inputs", FluxTextInputStep()),
-        ("additional_inputs", FluxInputsDynamicStep())
-    ]
+    [("text_inputs", FluxTextInputStep()), ("additional_inputs", FluxInputsDynamicStep())]
 )

+
 class FluxImg2ImgInputStep(SequentialPipelineBlocks):
     model_name = "flux"
     block_classes = FluxImg2ImgBlocks.values()
@@ -174,7 +168,7 @@ def description(self):
 class FluxImageAutoInputStep(AutoPipelineBlocks):
     block_classes = [FluxImg2ImgInputStep, FluxTextInputStep]
     block_names = ["img2img", "text2image"]
-    block_trigger_inputs = [ "image_latents", None]
+    block_trigger_inputs = ["image_latents", None]

     @property
     def description(self):
@@ -210,9 +204,11 @@ def description(self):
         ("text_encoder", FluxTextEncoderStep()),
         ("image_encoder", FluxAutoVaeEncoderStep()),
         ("denoise", FluxCoreDenoiseStep()),
-        ("decode", FluxDecodeStep())
+        ("decode", FluxDecodeStep()),
     ]
 )
+
+
 class FluxAutoBlocks(SequentialPipelineBlocks):
     model_name = "flux"