4 changes: 4 additions & 0 deletions src/diffusers/__init__.py
@@ -386,6 +386,8 @@
_import_structure["modular_pipelines"].extend(
[
"FluxAutoBlocks",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
"FluxModularPipeline",
"QwenImageAutoBlocks",
"QwenImageEditAutoBlocks",
@@ -1050,6 +1052,8 @@
else:
from .modular_pipelines import (
FluxAutoBlocks,
FluxKontextAutoBlocks,
FluxKontextModularPipeline,
FluxModularPipeline,
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
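With these two export hunks in place, the new classes resolve from the top-level `diffusers` namespace. A minimal sketch of an import check; only the class names are taken from this diff, and it assumes a build that already contains this change:

```python
# Minimal sketch: import check for the newly exported Kontext symbols.
# Assumes a diffusers installation that already includes this PR.
from diffusers import FluxKontextAutoBlocks, FluxKontextModularPipeline

blocks = FluxKontextAutoBlocks()  # preset block collection added by this PR
print(blocks)                     # repr is assumed to list the composed sub-blocks
```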
9 changes: 7 additions & 2 deletions src/diffusers/modular_pipelines/__init__.py
@@ -46,7 +46,12 @@
]
_import_structure["stable_diffusion_xl"] = ["StableDiffusionXLAutoBlocks", "StableDiffusionXLModularPipeline"]
_import_structure["wan"] = ["WanAutoBlocks", "WanModularPipeline"]
_import_structure["flux"] = ["FluxAutoBlocks", "FluxModularPipeline"]
_import_structure["flux"] = [
"FluxAutoBlocks",
"FluxModularPipeline",
"FluxKontextAutoBlocks",
"FluxKontextModularPipeline",
]
_import_structure["qwenimage"] = [
"QwenImageAutoBlocks",
"QwenImageModularPipeline",
@@ -65,7 +70,7 @@
from ..utils.dummy_pt_objects import * # noqa F403
else:
from .components_manager import ComponentsManager
from .flux import FluxAutoBlocks, FluxModularPipeline
from .flux import FluxAutoBlocks, FluxKontextAutoBlocks, FluxKontextModularPipeline, FluxModularPipeline
from .modular_pipeline import (
AutoPipelineBlocks,
BlockState,
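Since `ComponentsManager` is re-exported next to the new blocks, subpackage-level wiring might look like the sketch below; the `init_pipeline` call and the repo id are assumptions based on the existing modular-pipeline workflow, not something this diff shows:

```python
# Hedged sketch of subpackage-level wiring; only the class names come from this diff.
from diffusers.modular_pipelines import ComponentsManager, FluxKontextAutoBlocks

components = ComponentsManager()
blocks = FluxKontextAutoBlocks()

# Assumed to mirror the existing Flux modular workflow (not shown in this PR):
# pipe = blocks.init_pipeline("<modular-repo-id>", components_manager=components)
```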
15 changes: 12 additions & 3 deletions src/diffusers/modular_pipelines/flux/__init__.py
@@ -25,14 +25,18 @@
_import_structure["modular_blocks"] = [
"ALL_BLOCKS",
"AUTO_BLOCKS",
"AUTO_BLOCKS_KONTEXT",
"FLUX_KONTEXT_BLOCKS",
"TEXT2IMAGE_BLOCKS",
"FluxAutoBeforeDenoiseStep",
"FluxAutoBlocks",
"FluxAutoBlocks",
"FluxAutoDecodeStep",
"FluxAutoDenoiseStep",
"FluxKontextAutoBlocks",
"FluxKontextAutoDenoiseStep",
"FluxKontextBeforeDenoiseStep",
]
_import_structure["modular_pipeline"] = ["FluxModularPipeline"]
_import_structure["modular_pipeline"] = ["FluxKontextModularPipeline", "FluxModularPipeline"]

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
try:
@@ -45,13 +49,18 @@
from .modular_blocks import (
ALL_BLOCKS,
AUTO_BLOCKS,
AUTO_BLOCKS_KONTEXT,
FLUX_KONTEXT_BLOCKS,
TEXT2IMAGE_BLOCKS,
FluxAutoBeforeDenoiseStep,
FluxAutoBlocks,
FluxAutoDecodeStep,
FluxAutoDenoiseStep,
FluxKontextAutoBlocks,
FluxKontextAutoDenoiseStep,
FluxKontextBeforeDenoiseStep,
)
from .modular_pipeline import FluxModularPipeline
from .modular_pipeline import FluxKontextModularPipeline, FluxModularPipeline
else:
import sys

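The new `AUTO_BLOCKS_KONTEXT` and `FLUX_KONTEXT_BLOCKS` presets are exported alongside the block classes. A hedged way to inspect them, assuming they follow the same dict-like (step name to block class) convention as the existing `AUTO_BLOCKS` preset; their definitions live in `modular_blocks.py`, which is not part of the excerpt shown here:

```python
# Hedged sketch: inspect the new Kontext presets. Assumes they are dict-like
# mappings from step names to block classes, mirroring the existing AUTO_BLOCKS.
from diffusers.modular_pipelines.flux import AUTO_BLOCKS_KONTEXT, FLUX_KONTEXT_BLOCKS

for name, block in AUTO_BLOCKS_KONTEXT.items():
    print(f"{name}: {getattr(block, '__name__', type(block).__name__)}")
```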
84 changes: 72 additions & 12 deletions src/diffusers/modular_pipelines/flux/before_denoise.py
@@ -118,15 +118,6 @@ def retrieve_latents(
raise AttributeError("Could not access latents of provided encoder_output")


# TODO: align this with Qwen patchifier
def _pack_latents(latents, batch_size, num_channels_latents, height, width):
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
latents = latents.permute(0, 2, 4, 1, 3, 5)
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)

return latents


def _get_initial_timesteps_and_optionals(
transformer,
scheduler,
@@ -398,16 +389,15 @@ def prepare_latents(
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

# TODO: move packing latents code to a patchifier
# TODO: move packing latents code to a patchifier similar to Qwen
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
latents = _pack_latents(latents, batch_size, num_channels_latents, height, width)
latents = FluxPipeline._pack_latents(latents, batch_size, num_channels_latents, height, width)

return latents

@torch.no_grad()
def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

block_state.height = block_state.height or components.default_height
block_state.width = block_state.width or components.default_width
block_state.device = components._execution_device
@@ -557,3 +547,73 @@ def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
self.set_block_state(state, block_state)

return components, state


class FluxKontextRoPEInputsStep(ModularPipelineBlocks):
model_name = "flux-kontext"

@property
def description(self) -> str:
return "Step that prepares the RoPE inputs for the denoising process of Flux Kontext. Should be placed after text encoder and latent preparation steps."

@property
def inputs(self) -> List[InputParam]:
return [
InputParam(name="image_height"),
InputParam(name="image_width"),
InputParam(name="height"),
InputParam(name="width"),
InputParam(name="prompt_embeds"),
]

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
name="txt_ids",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the prompt embeds, used for RoPE calculation.",
),
OutputParam(
name="img_ids",
kwargs_type="denoiser_input_fields",
type_hint=List[int],
description="The sequence lengths of the image latents, used for RoPE calculation.",
),
]

def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

prompt_embeds = block_state.prompt_embeds
device, dtype = prompt_embeds.device, prompt_embeds.dtype
block_state.txt_ids = torch.zeros(prompt_embeds.shape[1], 3).to(
device=prompt_embeds.device, dtype=prompt_embeds.dtype
)

img_ids = None
if (
getattr(block_state, "image_height", None) is not None
and getattr(block_state, "image_width", None) is not None
):
image_latent_height = 2 * (int(block_state.image_height) // (components.vae_scale_factor * 2))
image_latent_width = 2 * (int(block_state.image_width) // (components.vae_scale_factor * 2))
img_ids = FluxPipeline._prepare_latent_image_ids(
None, image_latent_height // 2, image_latent_width // 2, device, dtype
)
# image ids match the latent ids except that the first coordinate (index 0 of the last dimension) is set to 1 instead of 0
img_ids[..., 0] = 1

height = 2 * (int(block_state.height) // (components.vae_scale_factor * 2))
width = 2 * (int(block_state.width) // (components.vae_scale_factor * 2))
latent_ids = FluxPipeline._prepare_latent_image_ids(None, height // 2, width // 2, device, dtype)

if img_ids is not None:
latent_ids = torch.cat([latent_ids, img_ids], dim=0)

block_state.img_ids = latent_ids

self.set_block_state(state, block_state)

return components, state
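For intuition about what `FluxKontextRoPEInputsStep` produces: the text ids are a zero tensor with one row per prompt token, while the latent ids carry a `(flag, row, col)` triplet per 2x2 latent patch, with flag 0 for the target latents and flag 1 for the reference-image latents before both are concatenated along the sequence dimension. The toy sketch below reproduces that layout with plain tensors; `make_latent_ids` is a simplified stand-in for `FluxPipeline._prepare_latent_image_ids`, not the library code.

```python
import torch

def make_latent_ids(latent_height: int, latent_width: int) -> torch.Tensor:
    """Toy id grid: one (flag, row, col) triplet per 2x2 latent patch."""
    ids = torch.zeros(latent_height // 2, latent_width // 2, 3)
    ids[..., 1] += torch.arange(latent_height // 2)[:, None]  # row index
    ids[..., 2] += torch.arange(latent_width // 2)[None, :]   # column index
    return ids.reshape(-1, 3)

# Example: 1024x1024 target and reference images with vae_scale_factor = 8
# give a 128x128 latent, i.e. 64x64 = 4096 packed tokens per image.
latent_ids = make_latent_ids(128, 128)   # flag stays 0 for the target latents
img_ids = make_latent_ids(128, 128)
img_ids[..., 0] = 1                      # flag 1 marks the reference-image tokens

combined = torch.cat([latent_ids, img_ids], dim=0)
print(combined.shape)  # torch.Size([8192, 3])
```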
107 changes: 107 additions & 0 deletions src/diffusers/modular_pipelines/flux/denoise.py
@@ -109,6 +109,96 @@ def __call__(
return components, block_state


class FluxKontextLoopDenoiser(ModularPipelineBlocks):
model_name = "flux-kontext"

@property
def expected_components(self) -> List[ComponentSpec]:
return [ComponentSpec("transformer", FluxTransformer2DModel)]

@property
def description(self) -> str:
return (
"Step within the denoising loop that denoise the latents for Flux Kontext. "
"This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
"object (e.g. `FluxDenoiseLoopWrapper`)"
)

@property
def inputs(self) -> List[Tuple[str, Any]]:
return [
InputParam("joint_attention_kwargs"),
InputParam(
"latents",
required=True,
type_hint=torch.Tensor,
description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"image_latents",
type_hint=torch.Tensor,
description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
),
InputParam(
"guidance",
required=True,
type_hint=torch.Tensor,
description="Guidance scale as a tensor",
),
InputParam(
"prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Prompt embeddings",
),
InputParam(
"pooled_prompt_embeds",
required=True,
type_hint=torch.Tensor,
description="Pooled prompt embeddings",
),
InputParam(
"txt_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from text sequence needed for RoPE",
),
InputParam(
"img_ids",
required=True,
type_hint=torch.Tensor,
description="IDs computed from latent sequence needed for RoPE",
),
]

@torch.no_grad()
def __call__(
self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
) -> PipelineState:
latents = block_state.latents
latent_model_input = latents
image_latents = block_state.image_latents
if image_latents is not None:
latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

timestep = t.expand(latents.shape[0]).to(latents.dtype)
noise_pred = components.transformer(
hidden_states=latent_model_input,
timestep=timestep / 1000,
guidance=block_state.guidance,
encoder_hidden_states=block_state.prompt_embeds,
pooled_projections=block_state.pooled_prompt_embeds,
joint_attention_kwargs=block_state.joint_attention_kwargs,
txt_ids=block_state.txt_ids,
img_ids=block_state.img_ids,
return_dict=False,
)[0]
noise_pred = noise_pred[:, : latents.size(1)]
block_state.noise_pred = noise_pred

return components, block_state


class FluxLoopAfterDenoiser(ModularPipelineBlocks):
model_name = "flux"

@@ -221,3 +311,20 @@ def description(self) -> str:
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)


class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
model_name = "flux-kontext"
block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
block_names = ["denoiser", "after_denoiser"]

@property
def description(self) -> str:
return (
"Denoise step that iteratively denoise the latents. \n"
"Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method \n"
"At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
" - `FluxKontextLoopDenoiser`\n"
" - `FluxLoopAfterDenoiser`\n"
"This block supports both text2image and img2img tasks."
)
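The key difference from the plain Flux loop denoiser is visible in `FluxKontextLoopDenoiser.__call__`: reference-image latents are appended along the sequence dimension for the transformer call only, and the prediction is sliced back to the target-latent length so the scheduler never updates the reference tokens. A shape-level sketch with a random tensor standing in for the transformer output:

```python
import torch

batch, tokens, channels = 1, 4096, 64                 # packed Flux latents: 64 channels per token
latents = torch.randn(batch, tokens, channels)        # target latents being denoised
image_latents = torch.randn(batch, tokens, channels)  # packed reference-image latents

# Concatenate along the sequence dimension, as in the block above.
latent_model_input = torch.cat([latents, image_latents], dim=1)   # (1, 8192, 64)

# Stand-in for `components.transformer(...)`; only the shapes matter here.
fake_noise_pred = torch.randn_like(latent_model_input)

# Keep only the target-latent portion, mirroring `noise_pred[:, : latents.size(1)]`.
noise_pred = fake_noise_pred[:, : latents.size(1)]                # (1, 4096, 64)
print(latent_model_input.shape, noise_pred.shape)
```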